Releases: huggingface/trl
v0.17.0
Major and breaking
The TRL v0.17 release introduces three major changes that, together, enable significantly faster generation performance in GRPO—up to 10x faster in some configurations.
These three changes are:
- Data parallelism (DP) for the vLLM server
- A new GRPO training strategy that generates once per effective batch
- Support for the V1 engine in vLLM
Below, we provide a summary of these changes and how to use them.
⚡ Up to 4x faster: Data Parallel for vLLM server
The TRL vLLM server now supports data parallelism (DP), enabling significantly faster generation speeds, especially for smaller models. This new feature can be used by adding the `--data_parallel_size N` argument when launching the vLLM server.
trl vllm-serve --model Qwen/Qwen2.5-14B-Instruct --tensor_parallel_size 2 --data_parallel_size 2
by @qgallouedec in #3310
☝️ [GRPO] Generate once per effective batch
Previously, GRPO made one generation request per global batch (the concatenation of all local batches, not accounting for gradient accumulation). In other words, with gradient accumulation set to 8, GRPO made 8 generation requests per optimization step.
Now, GRPO groups these global batches into a single "effective batch" and makes only one generation request per effective batch. Since vLLM applies optimizations that are especially effective for large batches, this new approach leads to significantly faster training overall.
No changes are required in the training script, as this is handled internally by the GRPO trainer.
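For illustration, here is a hedged sizing sketch with hypothetical numbers: with 8 GPUs, a per-device batch size of 4, and 8 gradient accumulation steps, the effective batch contains 8 × 4 × 8 = 256 completions, so GRPO now issues one generation request covering all 256 completions instead of 8 requests of 32.
from trl import GRPOConfig

# Hypothetical values, for sizing illustration only.
training_args = GRPOConfig(
    output_dir="grpo-sizing-sketch",   # hypothetical output directory
    per_device_train_batch_size=4,     # counted in completions per device
    gradient_accumulation_steps=8,
    num_generations=8,                 # completions generated per prompt
)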
by @qgallouedec in #3283
⏱️ Fix vLLM server to support V1 Engine
vLLM provides two versions of its engine (V0 and V1), and V1 is significantly faster. This version is now supported by TRL and requires vLLM version 0.8.3 or higher.
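A minimal launch sketch: with vLLM ≥ 0.8.3 the V1 engine is selected by default, and the `VLLM_USE_V1=1` environment variable can be set to force it explicitly.
VLLM_USE_V1=1 trl vllm-serve --model Qwen/Qwen2.5-7B-Instruct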
👎 [GRPO] Adds option to disable dropout
Disabling dropout has been shown to stabilize training. You can now disable dropout in GRPO by setting the `disable_dropout` argument to `True` in the GRPO config.
from trl import GRPOConfig
training_args = GRPOConfig(..., disable_dropout=True)
by @edbeeching in #3234
🩺 Dr. GRPO loss
GRPO now supports the various losses proposed in the recent literature, including the Dr. GRPO loss. The loss type can be set in the GRPO config:
from trl import GRPOConfig
training_args = GRPOConfig(..., loss_type="dr_grpo")
by @qgallouedec in #3256
🎲 [GRPO] Make training dataset shuffle optional
The GRPO trainer now has an option to disable shuffling of the training dataset. This is useful for curriculum learning, where the order of the training data is important.
from trl import GRPOConfig
training_args = GRPOConfig(..., shuffle_dataset=False)
by @LeonEricsson in #3334
☕ Overlong-filtering for GRPO
Overlong filtering has been shown to significantly stabilize learning and improve performance. You can now use it in TRL!
It simply consists of masking out the loss of truncated samples:
from trl import GRPOConfig
training_args = GRPOConfig(..., mask_truncated_completions=True)
by @shirinyamani in #3248
🐯 Integrate Liger GRPO Loss to GRPO Trainer
Liger significantly reduces the memory peak of the loss computation. You can now use it in TRL with the `use_liger_loss` argument in the GRPO config:
from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_loss=True)
by @shivam15s in #3184
Bug fixes
- Fix: Multi gpu hang for ORPO and CPO Trainer by @NanoCode012 in #3069
- 📊 Fix `clip_ratio` logging and better document logged values by @qgallouedec in #3145
- ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint by @PenutChen in #3148
- 📎 Fix is_clipped to compute the effective clip_ratio by @pandong2011 in #3175
- 😷 Fix SFT masking EOS when equal to PAD by @qgallouedec in #3200
- ⏯️ Fix logging when resuming from checkpoint GRPO by @qgallouedec in #3185
- 💠 Fix multi-gpu padding free by @qgallouedec in #3245
- 🕷 Fix online DPO crash when model is a DataParallel object by @wilrop in #3225
- 🏁 Fix adding special tokens in SFT by @qgallouedec in #3328
- 🍡 Fix using reward model and DeepSpeed ZeRO 3 by @qgallouedec in #3326
What's Changed
- Fix: Multi gpu hang for ORPO and CPO Trainer by @NanoCode012 in #3069
- 📊 Fix `clip_ratio` logging and better document logged values by @qgallouedec in #3145
- BCOTrainer version upgrade fixes by @claralp in #2867
- 🐇 [Research] Layer Skip SFT by @ariG23498 in #3111
- 🤝 Align GRPO equation doc with the implementation by @qgallouedec in #3151
- Enable number of printed completions to be set by @lewtun in #3149
- 🩹 Fix CI by @qgallouedec in #3155
- ⚰️ Remove deprecated by @qgallouedec in #3153
- 🔫 Disable triggering CI when PR is draft by @qgallouedec in #3154
- 👨🍳 vLLM serve: destroy process group on exit and pass `worker_cls` as string by @qgallouedec in #3159
- 💰 Richer rich table - log all the rewards by @qgallouedec in #3156
- 💎 Gemma 3 VLM SFT example script for single-image and multi-image by @sergiopaniego in #3131
- [Liger] Liger KTO support by @vaibhavjindal in #2812
- 🏃 Migrate CI to self-hosted runners by @qgallouedec in #3174
- ❤️🩹 [CI] fix transformers dev CI failure by @kashif in #3176
- ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint by @PenutChen in #3148
- 📎 Fix is_clipped to compute the effective clip_ratio by @pandong2011 in #3175
- Fix breaking typo for flash_attention reducing_memory_usage.md by @burtenshaw in #3190
- Show unique prompts in GRPO WandB tables by @lewtun in #3191
- 🐗 [CI] Fix trufflehog false positives by @lewtun in #3192
- [GRPO] Improve completion length logging by @edbeeching in #3188
- 😷 Fix SFT masking EOS when equal to PAD by @qgallouedec in #3200
- 🗝️ Fix type hint in vLLM client by @qgallouedec in #3205
- 📚 Accumulate completions for logging by @lewtun in #3217
- Group completion metrics by common prefix by @lewtun in #3212
- 🐯 Integrate Liger GRPO Loss to GRPO Trainer by @shivam15s in #3184
- Update ruff to 11.3 and base Python version to 3.9 by @cyyever in #3230
- ⏯️ Fix logging when resuming from checkpoint GRPO by @qgallouedec in #3185
- 📢 Improve GRPO trainer error message for invalid num_generations by @AliBakly in #3199
- 🎀 Simplify logging text by @qgallouedec in #3219
- 🌊 Add error for iterable datasets in GRPOTrainer by @qgallouedec in #3216
- ⏳ PPOTrainer: fix progress bar for num_mini_batches > 1 by @dawidm in #2531
- ☑ Update PULL_REQUEST_TEMPLATE.md by @qgallouedec in #3241
- 🔭 Add support for better KL estimator (k3) in PPOTrainer by @AMindToThink in #3240
- 🏃 Fix and make CI faster by @qgallouedec in #3160
- 🗑️ Deprecate `ConstantLengthDataset` by @qgallouedec in #3242
- 📦 [SFT] Deprecate batched `formatting_func` by @YeFD in #3147
- 💠 Fix multi-gpu padding free by @qgallouedec in #3245
- ☕ Overlong-filtering for GRPO by @shirinyamani in #3248
- 📜 Fix license and copyrights by @qgallouedec in #3264
- ⛏️ Add cli dict parsing for grpo_config by @Tavish9 in #3082
- 🐯 `is_liger_kernel_available` with min version by @qgal...
v0.16.1
What's Changed
- 😷 Fix SFT masking EOS when equal to PAD by @qgallouedec in #3200
- 📉 Add `learning_rate` argument to `_maybe_log_save_evaluate` by @qgallouedec in #3206
Full Changelog: v0.16.0...v0.16.1
v0.16.0
Major and breaking
🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication
Previously, vLLM could only be used by dedicating a single GPU, preventing both the scalability benefits of vLLM and multi-node training. This limitation has now been removed!
GRPO can now scale efficiently with models exceeding 70B parameters, supporting multi-node training with super-fast performance.
To take advantage of this, simply launch a vLLM server using the following command:
trl vllm-serve --model <model_name> --tensor_parallel_size <tp_size>
Then, start GRPO training with `use_vllm=True`.
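As a hedged client-side sketch (assuming the `vllm_server_host` and `vllm_server_port` arguments introduced with the server mode, and a server running on the same host with its default port), the trainer is pointed at the server via the GRPO config:
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_server_host="0.0.0.0",  # address of the machine running `trl vllm-serve` (assumed default)
    vllm_server_port=8000,       # assumed default server port
)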
Below is a comparison of GRPO throughput with and without vLLM, across different TP values and model sizes.
by @binary-husky and @qgallouedec in #3094
🐦🔥 6x faster GRPO with multi-step optimization
This release introduces the multi-step trick, which allows for the reuse of generated data across multiple steps, speeding up the training process.
To support this, we've implemented importance sampling and clipping logic. This enhancement should lead to significant improvements in training speed.

To use it, simply set `num_iterations` to a value greater than 1.
training_args = GRPOConfig(..., num_iterations=4)
by @qgallouedec in #2899
🌍 Use global normalization in GRPO
As demonstrated in Dr. GRPO, sequence-level normalization can introduce a response-level length bias.
To address this, we have now switched to normalizing the loss by the total number of tokens in the batch, ensuring more consistent and unbiased training.
- loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
by @edbeeching in #2881
⚖️ Add option not to scale rewards
As demonstrated in Dr GRPO, scaling rewards can introduce a question-level difficulty bias. To address this, we have now added an option to disable reward scaling in GRPO.
training_args = GRPOConfig(..., scale_rewards=False)
  advantages = rewards - mean_grouped_rewards
- advantages = advantages / std_grouped_rewards
+ if self.args.scale_rewards:
+     advantages = advantages / std_grouped_rewards
It is likely that we'll make this (`scale_rewards=False`) the default behavior in the future.
by @qgallouedec in #3135
🤸♀️ Domain-specific rewards in GRPO
When optimizing across multiple domains, not all reward functions are relevant for every sample. For example, a math verifier's reward does not apply to grammar samples, and a grammar verifier's reward does not apply to math samples.
It is now possible to return `None` for rewards that do not make sense for a given sample. For instance, when the domain is specified in a column like `domain`, you can implement it as follows:
def math_reward(completions, domain, **kwargs):
    rewards = []
    for completion, dom in zip(completions, domain):
        if dom == "math":
            rewards.append(verify(completion))  # verify() stands in for the user's math verifier
        else:
            rewards.append(None)  # this reward does not apply to non-math samples
    return rewards
This allows for more domain-specific reward handling, ensuring that irrelevant rewards are ignored and don’t interfere with optimization.
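Both domain rewards can then be passed to the trainer together. This is a hedged sketch: `grammar_reward` is a hypothetical second reward function defined analogously, and `training_args` and `dataset` (which is assumed to contain a `domain` column) are defined as usual.
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[math_reward, grammar_reward],  # irrelevant rewards return None per sample
    args=training_args,
    train_dataset=dataset,
)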
by @shirinyamani in #3079
🍃 Do not load reference model when beta == 0.0
It has been observed that not minimizing the KL divergence between the trained model and the reference model can still yield good results, while significantly reducing memory usage and compute. This is because there is no need to store the reference model in memory or perform a forward pass for it.
When `beta` is set to `0.0`, the reference model is not loaded, and the KL divergence is not computed, leading to savings in both time and memory.
training_args = GRPOConfig(..., beta=0.0)
🕊️ Padding-free for SFT
Padding-free batching is an alternative approach to packing for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
To enable padding-free batching in SFT, simply set `padding_free=True` in the `SFTConfig`, and make sure to use `flash_attention_2` as the attention implementation.
training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
by @qgallouedec in #3076
🎬 Clip Higher for Better Exploration
As outlined in the DAPO paper, increasing the upper bound epsilon leads to higher entropy during generation, promoting better exploration. To enable this, we’ve added support for adjusting the upper bound epsilon directly in the default GRPO trainer.
training_args = GRPOConfig(epsilon_high=0.28)
by @shirinyamani in #3118
Bug fixes
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🪂 Don't gather logits in SFT to avoid hanging by @qgallouedec in #2890
- ♻️ Fix caching in SFT by @qgallouedec in #2945
- 🐯 Fix LigerKernel for SFTTrainer by @lewtun @kashif and @qgallouedec in #2874, #2940 and #2949
- 🫔 [GRPO] Pass wrapped model to `unwrap_model_for_generation` for DeepSpeed Stage-3 compatibility by @kiddj in #2871
- 🛣️ `inference_mode` to `no_grad` when computing `old_per_token_logps` by @qgallouedec in #2987
- 🏊 [SFT] Compatibility with padding free and iterable dataset by @qgallouedec in #3053
- Fixing JSD loss computation in GKDTrainer as per definition by @abhigoyal1997 in #3043
Minor
- 💬 Add `maybe_convert_to_chatml` map for conversational datasets in SFT by @kashif in #2862
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW and @DanFosing in #2863 and #2939
- ✨ Add vLLM guided decoding support to GRPO Trainer by @kldzj in #2811
- 🩳 `max_seq_length` to `max_length` by @qgallouedec in #2895 and #2947
- Optimize vllm num_generations by @edbeeching in #2855
- 📍 [GRPO] add gradient_checkpointing by @kashif in #2848
- 🪪 Adds profiling decorators for GRPOTrainer by @edbeeching in #2889 and #2975
- 🐈 Bye bye chat by @qgallouedec in #2934
- 📇 GRPO: print completions to console and update docs by @nopepper in #2951
- 👧🏽 Adding DoRA support to model config by @nbasyl in #2974
- 🧗 Add GRPO Trainer support for third-party accelerators by @ji-huazhong in #2836
- 🪙 [SFT] Log `num_tokens` and some logging fixes by @qgallouedec in #3006
- 🌡️ Fix temperature inconsistency in GRPO trainer by @Aladoro in #3029
- ⛔ Add EOS token to processed input in SFT by @qgallouedec in #3091
- ⚡ Pack 300 times faster, truncate 100 times faster by @mariosasko in #3009
What's Changed
- [SFT] fix check for AutoLigerKernelForCausalLM by @kashif in #2874
- 🆙 Bump vLLM min version to 0.7.2 by @edbeeching in #2860
- [GRPO] Fix loss normalization by @edbeeching in #2881
- 💬 Add `maybe_convert_to_chatml` map for conversational datasets in SFT by @kashif in #2862
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW in #2863
- Optimize vllm num_generations ...
v0.15.2
What changed
- ♻️ Fix caching in SFT by @qgallouedec in #2945
- 🐯 Fix LigerKernel for SFTTrainer by @lewtun in #2940
- 📌 Pin liger-kernel and vLLM by @qgallouedec in #2952
Full Changelog: v0.15.1...v0.15.2
v0.15.1
What's Changed
- 💬 Add `maybe_convert_to_chatml` map for conversational datasets in SFT by @kashif in #2862
- [SFT] fix check for AutoLigerKernelForCausalLM by @kashif in #2874
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW in #2863
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🪂 Don't gather logits in SFT to avoid hanging by @qgallouedec in #2890
- Release: v0.15.1 by @qgallouedec
Full Changelog: v0.15.0...v0.15.1
v0.15.0
Major and breaking changes
Coming soon
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #2689
- 📦 `trl.templates` in excluded packages by @qgallouedec in #2690
- 📖 Docs fix spelling issues by @nnsW3 in #2682
- 📄 Add GRPO batch size note in docs by @sdpkjc in #2672
- 🙈 Fixed typo in the GRPO documentation by @famouswizard in #2691
- docs: Fix broken "Good First Issue" link in CONTRIBUTING.md by @famouswizard in #2693
- 🧠 Fix typo in "understand" in ppo_trainer.md by @famouswizard in #2695
- ☠️ Remove deprecated by @qgallouedec in #2692
- 💡 Add "Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial" by @qgallouedec in #2697
- 📋 Add eval loss logging during prediction in GRPO by @kashif in #2694
- fix: Fix typo in filename Update ultrafeedback.py by @brawncode in #2699
- 📖 Add GRPOTrainer to README.md by @burtenshaw in #2713
- Improve GRPO example by @lewtun in #2717
- 📖 Nit Fix in Documentation by @ParagEkbote in #2722
- 🏰 `num_logits_to_keep` to `logits_to_keep` by @qgallouedec in #2721
- 💰 Fix incorrect calculation in Olivia's baguette spending logic by @defiberrys in #2727
- fix: Fix typo in filename in ultrafeedback-prompt.py by @brawncode in #2716
- docs: Fix typos in alias descriptions by @defiberrys in #2729
- ⚠️ Fix Attention Masking in GRPO by @andyl98 in #2708
- 🔂 Use vLLM prefix caching for speedup by @winglian in #2757
- 💔 Decouple loss computing and generation in GRPO by @qgallouedec in #2762
- 📌 vLLM >= 0.7.1 for device fix by @ctjlewis in #2766
- 📐 Add vLLM dtype configuration for GRPO trainer by @joey00072 in #2738
- 📖 Clarification max len in Reward documentation by @ParagEkbote in #2740
- 🔎 Add missing script argument in PPO documentation by @JohnConnor123 in #2720
- 🤖 Properly unwrap torch.compile-ed models in GRPO by @winglian in #2750
- 🔁 🦈 Support iterative GRPO by @shirinyamani in #2700
- 🚧 Add Optional ZeRO-3 Weight Gathering for GRPO in Sequence Generation by @SeungyounShin in #2667
- ↔️ GRPO: Set max_model_len when initializing vLLM instance by @mirceapricop in #2728
- 💡 GRPO vram-efficiency improvement; only compute relevant logprobs by @tyler-romero in #2773
- 🙃 Fix reward function in GRPO example by @junuMoon in #2777
- 💡 Add 'Post training an LLM for reasoning with GRPO in TRL' tutorial by @sergiopaniego in #2785
- 📉 Optimize GRPO memory usage by redefining `per_device_batch_size` as generations per device by @qgallouedec in #2776
- 🆚 Distinguish padding and eos when they differ by @binary-husky in #2793
- 🎯 [SFT] add token accuracy metric by @kashif in #2597
- 📠 Log completions for GRPO by @qgallouedec in #2772
- 🔬 SFT simplification by @qgallouedec in #2405
- ➖ Fix GRPO example in README by @qgallouedec in #2800
- ⛰️ Reduce peak vram consumption with efficient selective log_softmax by @tyler-romero in #2799
- fix: typos in documentation files by @maximevtush in #2804
- 📤 GRPO refactor loading the model weights to vllm by @winglian in #2817
- 🫘 Add `set_seed()` call in GRPO to ensure unique seed for each process by @qgallouedec in #2824
- ⚖️ Add reward weight in multi-reward settings for GRPO by @hesamsheikh in #2676
- 🙌 Share vLLM device with training when only 1 available by @qgallouedec in #2827
- 👴 Update `tokenizer` parameter to `processing_class` in tests by @qgallouedec in #2828
- 🥾 Allow bootstrap GRPO by @qgallouedec in #2829
- ⚡ Fix GRPO PEFT by @qgallouedec in #2725
- Fix PeftModel check when moving weights to vllm by @edbeeching in #2850
- 🪆 Fix for Incorrect ValueError Handling in reward_weights in grpo_trainer.py by @loveychen in #2843
- 👨👩👧 GRPO + PEFT + vLLM by @winglian in #2818
New Contributors
- @nnsW3 made their first contribution in #2682
- @sdpkjc made their first contribution in #2672
- @famouswizard made their first contribution in #2691
- @brawncode made their first contribution in #2699
- @ParagEkbote made their first contribution in #2722
- @defiberrys made their first contribution in #2727
- @ctjlewis made their first contribution in #2766
- @joey00072 made their first contribution in #2738
- @JohnConnor123 made their first contribution in #2720
- @shirinyamani made their first contribution in #2700
- @mirceapricop made their first contribution in #2728
- @tyler-romero made their first contribution in #2773
- @junuMoon made their first contribution in #2777
- @binary-husky made their first contribution in #2793
- @maximevtush made their first contribution in #2804
- @hesamsheikh made their first contribution in #2676
- @loveychen made their first contribution in #2843
Full Changelog: v0.9.6...v0.15.0
v0.14.0
Major and breaking changes
👨👨👧👧 GRPO
by @qgallouedec in #2565
What's Changed
- ⚰️ Remove deprecated by @qgallouedec in #2485
- 🗣️ Improve prose for smol course by @burtenshaw in #2487
- 🤩 Add SmolVLM tutorials to Community Tutorials page by @sergiopaniego in #2498
- 🏞️ Proper dataset for documentation images by @qgallouedec in #2499
- 🗂️ Reorganize documentation by @qgallouedec in #2483
- [ORPO] fix orpo chosen-nll loss by @kashif in #2502
- 🏚 Remove unused components by @qgallouedec in #2480
- Update community_tutorials.md by @qgallouedec in #2509
- ❎ Remove RLOO example test by @qgallouedec in #2513
- 👨🍳 Clarify DPO data preparation by @qgallouedec in #2512
- 💧 Generalize `disable_dropout` by @qgallouedec in #2511
- 👬 Rename collator `PreferenceCollator` to `DataCollatorForPreference` by @qgallouedec in #2510
- 📦 Packing documentation by @qgallouedec in #2503
- ☄️ Update Comet integration to include LogCompletionsCallback and Trainer.evaluation_loop() by @yaricom in #2501
- Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM by @Abhishek-TAMU in #2158
- 🚜 Use field in dataclasses by @qgallouedec in #2494
- ©️ Update copyrights year by @qgallouedec in #2547
- 🧑🤝🧑 Proper metrics gathering across ranks before logging by @zhc7 in #2474
- ✒️ Fix typo in `formatting_func`'s documentation in `ConstantLengthDataset` by @SamuelLarkin in #2549
- 🕊️ DPO padding free by @qgallouedec in #2520
- ℹ️ XPU support for DPO by @faaany in #2533
- 🔠 Fix SFT truncation documentation by @umbilnm in #2521
- ↩️ Revert ORPO loss changes by @kashif in #2527
- 🎴 Add readme for datasets by @August-murr in #2491
- 💔 Fix dataset type unpair conversion docs by @claralp in #2550
- [RLOO] Reinforce++ by @kashif in #2552
- 🏛️ Improve DPO configuration documentation structure by @qgallouedec in #2561
- ✨ Refine model card method docstring by @qgallouedec in #2566
- 🪄 Minor comment style modif by @qgallouedec in #2582
- 🏎️ vllm for Online DPO by @qgallouedec in #2558
- 🔖 Issues Auto-Labeller by @August-murr in #2542
- 🐛 Simplify bug report template by @qgallouedec in #2585
- [RLOO] fix token_level_kl by @kashif in #2575
- ✂️ Truncate by default by @qgallouedec in #2587
- 🫢 Add `max_prompt_length` parameter in tests by @qgallouedec in #2588
- 🎞️ Fix documentation SFT - `max_seq_length` instead of `max_length` by @skandermoalla in #2590
- 👨👩👧👧 GRPO by @qgallouedec in #2565
- 🫣 Ignore CLI test for Python 3.9 by @qgallouedec in #2592
- Fix merge error by @qgallouedec in #2595
- 🧰 Tool fine-tuning support DPO by @August-murr in #2479
- 💾 Reduce memory peak in GRPO by adding `max_prompt_length` and loop usage in logp computation by @qgallouedec in #2598
- ⚡ Add uv installation instructions by @stevhliu in #2601
- 🧩 PPO/RLOO/OnlineDPO sequence generation: make deepsped 3 weight gathering optional by @dawidm in #2557
- 🫷 Include stop token in policy model's generation_config by @dawidm in #2528
- ✂️ Reintroduce `truncation_mode` in `DPOTrainer` by @anakin87 in #2551
- 👋 Drop MDX by @qgallouedec in #2611
- 💎 Rename an inner var in GRPO to improve clarity by @qgallouedec in #2616
- 🏆 Custom reward function for GRPO and shiny doc by @qgallouedec in #2606
- 🥞 Fix DPO gradient accumulation loss scaling by @winglian in #2615
- 🥞 Fix BCO gradient accumulation loss scaling by @qgallouedec in #2638
- 🍭 Custom reward function for RLOO by @August-murr in #2612
- 🌯 Fix context manager runtime error when gather is disabled by @Superskyyy in #2639
- 🥞 Fix CPO gradient accumulation loss scaling by @qgallouedec in #2645
- 🥞 Fix GRPO gradient accumulation loss scaling by @qgallouedec in #2647
- 🥞 Fix KTO gradient accumulation loss scaling by @qgallouedec in #2648
- 🚛 Provide all columns of the dataset to the reward function by @qgallouedec in #2650
- 👐 DeepSpeed integration for GRPO by @qgallouedec in #2652
- 🔎 Finegrained reward logging for GRPO by @qgallouedec in #2651
- 📍 Disable caching when grad checkpointing enable in GRPO by @qgallouedec in #2653
- 📏 Log completion length in GRPO by @qgallouedec in #2659
- 🌀 Fix GRPO default completion length doc by @andyl98 in #2662
- 🏷️ Add model tags to model trained with GRPO by @qgallouedec in #2663
- 🖊 Fix typos by @omahs in #2673
- ⚡ vLLM for fast generation in GRPO by @qgallouedec in #2600
- 📉 Use `num_logits_to_keep` to reduce memory usage in GRPO by @qgallouedec in #2683
New Contributors
- @Abhishek-TAMU made their first contribution in #2158
- @zhc7 made their first contribution in #2474
- @SamuelLarkin made their first contribution in #2549
- @umbilnm made their first contribution in #2521
- @stevhliu made their first contribution in #2601
- @dawidm made their first contribution in #2557
- @Superskyyy made their first contribution in #2639
- @andyl98 made their first contribution in #2662
- @omahs made their first contribution in #2673
Full Changelog: v0.13.0...v0.14.0
v0.13.0
Major and breaking changes
🐾 Process-supervised RM Trainer
We introduced a new trainer to train a Process-supervised Reward Model (PRM) in TRL. A PRM rewards the quality of intermediate steps, promoting structured reasoning over focusing solely on the final outcome. With this trainer, we introduce a new dataset type: stepwise supervision, a variant of the prompt-completion type in which the completion is divided into several intermediate steps, each associated with a label. Find out more in the stepwise-supervision section of the TRL documentation.
Here is an example of how to use the `PRMTrainer` to train a PRM on the Math Shepherd dataset:
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")
training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
For more information, check out the PRMTrainer documentation.
by @qgallouedec and @gaetanlop in #2127 and #2148
🔀 Add `MergeModelCallback`
Various works show that model merging can non-trivially improve performance, especially if the models belong to the same architecture. TRL now features a callback that merges the reference model with the current policy and optionally pushes the merged checkpoint to the Hub. This can be done at step/epoch end and/or at the end of training. This callback uses Arcee's mergekit lib: https://github.com/arcee-ai/mergekit
from trl import DPOTrainer, MergeModelCallback
from trl.mergekit_utils import MergeConfig
config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
by @August-murr in #2282
🔨 Support for tools for data utils
TRL preprocessing utils now support tooling. A first step toward agent fine-tuning.
from trl import apply_chat_template

def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0

example = apply_chat_template(example, tokenizer, tools=[get_current_temperature])
by @August-murr in #2455
🌋 Add support for LLaVA-Next in DPOTrainer
VLMs have their own specificities which require special treatment in the trainer. `DPOTrainer` now supports LLaVA-Next models natively.
from transformers import AutoModelForVision2Seq
from trl import DPOTrainer

model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
trainer = DPOTrainer(model=model, ...)
by @chenweize1998 in #2413
🕹️ CLI and TRLParser refactor
TRL CLI has been refactored to be more user-friendly and easy to extend. We plan to extend the support to all trainers soon.
(simplified output, for readability)
$ trl dpo --help
usage: trl dpo [-h] --dataset_name DATASET_NAME [--dataset_config DATASET_CONFIG] --output_dir OUTPUT_DIR [--loss_type {sigmoid,hinge,ipo}]
options:
-h, --help show this help message and exit
--dataset_name DATASET_NAME, --dataset-name DATASET_NAME
--dataset_config DATASET_CONFIG, --dataset-config DATASET_CONFIG
--output_dir OUTPUT_DIR, --output-dir OUTPUT_DIR
The output directory where the model predictions and checkpoints will be written. (default: None)
--loss_type {sigmoid,hinge,ipo}, --loss-type {sigmoid,hinge,ipo}
by @qgallouedec in #2380 and #2412
🤝 Mixture of judges
TRL features a new judge, `AllTrueJudge`, that unifies the decisions of multiple binary judges. This judge implements the Mixture of Judges as described in the CGPO paper.
import random

from trl import AllTrueJudge, BaseBinaryJudge

class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        return [random.choice([0, 1, -1]) for _ in range(len(prompts))]

prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
judgements = judge.judge(prompts=prompts, completions=completions)
print(judgements)  # [0, 1]
by @gaetanlop in #2159
❄️ DPO trainer supports `num_logits_to_keep` to save memory
Save memory by only keeping the top `num_logits_to_keep` logits in the DPO trainer.
training_args = DPOConfig(..., use_num_logits_to_keep=True)
🗺️ Implementation DiscoPOP Loss
The DiscoPOP paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0).
training_args = DPOConfig(..., loss_type="discopop", discopop_tau=0.05)
🧑🍳 Add precompute batch size argument in `DPOTrainer` for reference model
We can now control the batch size for precomputing reference model logits.
training_args = DPOConfig(
    ...,
    precompute_ref_log_probs=True,
    precompute_ref_batch_size=4,
)
by @SwayamInSync in #2426
📦 Support for packing tokenized datasets for SFT
`SFTTrainer` has supported packing datasets for faster training. Now, it supports packing tokenized datasets as well.
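A minimal sketch of what this looks like (hypothetical `tokenized_dataset`, assumed to already contain an `input_ids` column):
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(..., packing=True)
trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",
    args=training_args,
    train_dataset=tokenized_dataset,  # pre-tokenized dataset with an "input_ids" column
)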
📉 Add PEFT support for PPOTrainer
`PPOTrainer` now supports PEFT for efficient training.
PPOTrainer(
    ...,
    peft_config=peft_config,
)
💾 Deprecate `config` in favor of `args` in `PPOTrainer`
`config` has been deprecated in favor of `args` in `PPOTrainer`.
PPOTrainer(
- config=training_args,
+ args=training_args,
)
by @qgallouedec in #2384
👮 Deprecate `policy` in favor of `model` in `PPOTrainer`
`policy` has been deprecated in favor of `model` in `PPOTrainer`.
PPOTrainer(
- policy=model,
+ model=model,
)
by @qgallouedec in #2386
What's Changed
- ⏫ Bump dev version to `0.13.0.dev0` by @qgallouedec in #2305
- 📰 Update blog posts in documentation by @qgallouedec in #2319
- ⚰️ Remove deprecated args, script arguments, and PPOv2 by @qgallouedec in #2306
- 🧽 Fix judge doc by @qgallouedec in #2320
- 🪧 Fix slack notification titles by @qgallouedec in #2322
- 🪪 Check with `token_id` instead of `token` in `DPOTrainer` by @qgallouedec in #2324
- Fix wrong truncating index of tensor in DPOTrainer's concatenated_forward() by @yanghh2000 in #2332
- Fix gradient_checkpointing_kwargs assignment in examples by @Galaxy-Husky in #2331
- Bump liger-kernel to 0.4.0 by @ByronHsu in #2333
- DPO trainer supports num_logits_to_keep to save memory by @xyangk in #2129
- 🧞 Add `output_layer` to the list of `lm_head_namings` in `AutoModelForCausalLMWithValueHead` by @qgallouedec in #2328
- 🫴 Better guide users in error reporting by @qgallouedec in #2327
- 🪡 Various RLOO fixes by @qgallouedec in #2325
- 💣 Remove transformers version check by @xyangk in #2343
- 👈 Add `tokenizer` arg back and add deprecation guidelines by @qgallouedec in #2348
- 🖨️ Fix error text in BCO and KTO tokenizing function by @PhilipMay in #2286
- Adding video llm fine-tuning example by @mfarre in #2336
- 👋 Remove deprecated `tokenizer` argument in BCO, GKD, Iterative SFT, Nash MD and XPO by @qgallouedec in #2349
- ⚖️ Add `use_soft_judge` option to `WinRateCallback` by @kashif in #2347
- 🪜 Stepwise supervision dataset type by @qgallouedec in #2148
- 🔮 Inference mode in `GeometricMixtureWrapper.forward` by @kashif in #2345
- 🗃️ Use specified `data_collator` in `RLOOTrainer` and `PPOTrainer` by @bartoszzuk in h...
v0.12.2
v0.12.1
What's Changed
- 👈 Add `tokenizer` arg back and add deprecation guidelines by @qgallouedec in #2348
Full Changelog: v0.12.0...v0.12.1