Unverified Commit 769b8d04 by Ziniu Li Committed by GitHub

Feature/add remax support (#234)

## Description
Added [ReMax](https://arxiv.org/abs/2310.10505) support to verl. ReMax
is a simple, efficient, and stable RL algorithm customized for LLM
training, with theoretical guarantees for variance reduction.

The [HybridFlow](https://arxiv.org/pdf/2409.19256v2) paper experimented with ReMax, but verl did not yet provide an implementation, so this PR adds one.
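For context, ReMax drops the learned critic and instead subtracts a per-prompt baseline: the reward of the greedy (non-sampled) response to the same prompt. A minimal sketch of the resulting advantage computation, using hypothetical helper names rather than verl's actual API:

```python
import torch

def remax_advantage(sampled_reward: torch.Tensor,
                    greedy_reward: torch.Tensor,
                    response_mask: torch.Tensor) -> torch.Tensor:
    """Sketch of the ReMax advantage: sampled reward minus the greedy-rollout baseline.

    sampled_reward: (batch,) outcome reward of each sampled response
    greedy_reward:  (batch,) outcome reward of the greedy response (the baseline)
    response_mask:  (batch, response_length) 1 for valid response tokens, 0 for padding
    """
    # Broadcast the baseline-corrected scalar reward over the response tokens,
    # masking out padding so only real tokens receive a non-zero advantage.
    return (sampled_reward - greedy_reward).unsqueeze(-1) * response_mask
```

Because the baseline depends only on the prompt and not on the sampled response, the policy-gradient estimate remains unbiased while its variance is reduced.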


## Changes
- Added RayReMaxTrainer implementation
- Added example scripts for ReMax training
- Added documentation for ReMax algorithm

## Testing
- Tested ReMax example scripts with Qwen models

Validation reward when optimizing Qwen2.5-3B-Instruct on the GSM8K dataset:

<img width="501" alt="Screenshot 2025-02-09 20 51 14"
src="https://github.com/user-attachments/assets/742c2eab-6877-4c3c-b0a2-4159bd109add"
/>

The curve demonstrates the effectiveness of ReMax; its performance can likely be improved further with additional hyperparameter tuning.

## Documentation
- Added ReMax documentation
- Updated example configurations

## Checklist
- [x] Code follows project's style guidelines (yapf formatted)
- [x] Tests added/updated and passing
- [x] Documentation updated
- [x] Example scripts added
parent 5a66ed26
...@@ -50,6 +50,10 @@ jobs:
        run: |
          ray stop --force
          bash tests/e2e/run_qwen_gsm8k_function_rm_grpo.sh
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using function rm (ReMax)
        run: |
          ray stop --force
          bash tests/e2e/run_qwen_gsm8k_function_rm_remax.sh
      - name: Running gsm8k e2e without rmpad using function rm and load ckpt from previous step
        run: |
          ray stop --force
...
...@@ -41,7 +41,7 @@ verl is fast with:
- **vLLM** and **TGI** for rollout generation, **SGLang** support coming soon.
- huggingface models support
- Supervised fine-tuning
- Reinforcement learning from human feedback with [PPO](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer), [GRPO](https://github.com/volcengine/verl/tree/main/examples/grpo_trainer), and [ReMax](https://github.com/volcengine/verl/tree/main/examples/remax_trainer)
- Support model-based reward and function-based reward (verifiable reward)
- flash-attention, [sequence packing](examples/ppo_trainer/run_qwen2-7b_seq_balance.sh), [long context](examples/ppo_trainer/run_deepseek7b_llm_sp2.sh) support via DeepSpeed Ulysses, [LoRA](examples/sft/gsm8k/run_qwen_05_peft.sh), [Liger-kernel](examples/sft/gsm8k/run_qwen_05_sp2_liger.sh)
- scales up to 70B models and hundreds of GPUs
...
...@@ -19,6 +19,8 @@ Refer to the table below to reproduce PPO training from different pre-trained mo
.. _Megatron PPO Command and Logs: https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/deepseek-llm-7b-chat-megatron-bsz256_4-prompt512-resp512-0.695.log
.. _Qwen7b GRPO Script: https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
.. _Megatron wandb: https://wandb.ai/verl-team/verl_megatron_gsm8k_examples/runs/10fetyr3
.. _Qwen7b ReMax Script: https://github.com/eric-haibin-lin/verl/blob/main/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh
.. _Qwen7b ReMax Wandb: https://wandb.ai/liziniu1997/verl_remax_example_gsm8k/runs/vxl10pln
+----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+
| Model | Method | Test score | Details |
...@@ -37,6 +39,7 @@ Refer to the table below to reproduce PPO training from different pre-trained mo
+----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+
| Qwen/Qwen2-7B-Instruct | GRPO | 89 | `Qwen7b GRPO Script`_ |
+----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+
| Qwen/Qwen2.5-7B-Instruct | ReMax | 97 | `Qwen7b ReMax Script`_, `Qwen7b ReMax Wandb`_ |
+----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+
.. [1] During evaluation, we only extract answers that follow the "####" format. A more flexible answer extraction, a longer response length, and better prompt engineering may lead to a higher score.
\ No newline at end of file
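# ReMax example: Qwen2.5-3B-Instruct on GSM8K with a function-based reward and dynamic
# (sequence-balanced) batching; presumably the run_qwen2.5-3b_seq_balance.sh script
# referenced in the docs above.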
set -x
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=remax \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=512 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_remax_example_gsm8k' \
trainer.experiment_name='qwen2.5_3b_function_rm_kl1e-3' \
+trainer.val_before_train=False \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=5 $@
\ No newline at end of file
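# ReMax example: Qwen2.5-7B-Instruct on GSM8K; the same recipe as the 3B script above with a
# larger train batch size (1024), a smaller per-GPU token budget, and more epochs.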
set -x
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=remax \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_remax_example_gsm8k' \
trainer.experiment_name='qwen2.5_7b_function_rm_kl1e-3' \
+trainer.val_before_train=False \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=10 $@
\ No newline at end of file
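# CI smoke test for ReMax: Qwen2.5-0.5B with a function-based reward, running a single
# training step; presumably the tests/e2e/run_qwen_gsm8k_function_rm_remax.sh script invoked
# by the workflow change above.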
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
algorithm.adv_estimator=remax \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='qwen_e2e_ci_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
...@@ -188,6 +188,37 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
    return advantages, returns
def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor,
                                    eos_mask: torch.Tensor):
    """
    Compute advantage for ReMax, operating only on the outcome reward
    (a single scalar reward per response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        # reverse cumulative sum: every token receives the total outcome reward of the sequence
        returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
        # subtract the greedy-rollout baseline reward at every valid response token
        advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask

    return advantages, returns
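# A minimal, illustrative check of the function above (not part of the diff itself):
# with a single outcome reward on the final token, the reverse cumulative sum spreads that
# scalar over every response token, and the greedy-rollout baseline is subtracted wherever
# eos_mask is 1.
#
#   token_level_rewards = torch.tensor([[0., 0., 1.]])   # reward only at the last token
#   reward_baselines = torch.tensor([0.5])               # reward of the greedy response
#   eos_mask = torch.tensor([[1., 1., 1.]])
#   adv, ret = compute_remax_outcome_advantage(token_level_rewards, reward_baselines, eos_mask)
#   # adv == tensor([[0.5, 0.5, 0.5]]), ret == tensor([[1., 1., 1.]])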
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    kl = old_log_prob - ref_log_prob
    return token_level_scores - kl * kl_ratio
...
...@@ -23,6 +23,7 @@ from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Type, Dict
from copy import deepcopy
import numpy as np
from codetiming import Timer
...@@ -154,6 +155,22 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
            token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    elif adv_estimator == 'remax':
        token_level_rewards = data.batch['token_level_rewards']
        index = data.non_tensor_batch['uid']
        responses = data.batch['responses']
        response_length = responses.size(-1)
        attention_mask = data.batch['attention_mask']
        response_mask = attention_mask[:, -response_length:]

        # the baseline is the reward of the greedy rollout, stored earlier in the training loop
        reward_baselines = data.batch['reward_baselines']

        advantages, returns = core_algos.compute_remax_outcome_advantage(token_level_rewards=token_level_rewards,
                                                                         reward_baselines=reward_baselines,
                                                                         eos_mask=response_mask)

        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    else:
        raise NotImplementedError
    return data
...@@ -355,6 +372,8 @@ class RayPPOTrainer(object):
            self.use_critic = False
        elif self.config.algorithm.adv_estimator == 'reinforce_plus_plus':
            self.use_critic = False
        elif self.config.algorithm.adv_estimator == 'remax':
            self.use_critic = False
        else:
            raise NotImplementedError
...@@ -826,6 +845,22 @@ class RayPPOTrainer(object):
                with _timer('gen', timing_raw):
                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                # ReMax baseline: regenerate the batch greedily (do_sample=False) and use the
                # reward of the greedy response as a per-prompt baseline for variance reduction.
                if self.config.algorithm.adv_estimator == 'remax':
                    with _timer('gen_max', timing_raw):
                        gen_baseline_batch = deepcopy(gen_batch)
                        gen_baseline_batch.meta_info['do_sample'] = False
                        gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                        batch = batch.union(gen_baseline_output)
                        reward_baseline_tensor = self.reward_fn(batch)
                        reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                        # drop the greedy-rollout tensors and keep only the scalar baseline rewards
                        batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                        batch.batch['reward_baselines'] = reward_baseline_tensor

                        del gen_baseline_batch, gen_baseline_output
                batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                          dtype=object)
                # repeat to align with repeated responses in rollout
...