Commit 6a820b61 authored by Zefan Wang, committed by GitHub

algo: Rloo advantage estimator (#341)

Implement the RLOO algorithm following https://arxiv.org/abs/2402.14740
parent 76352ae9
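For readers new to the method: RLOO (REINFORCE Leave-One-Out) scores each sampled response against the mean reward of the *other* responses drawn for the same prompt, so no learned value function is needed as a baseline. Below is a minimal, framework-free sketch of that estimator; the function name, prompt ids and reward values are illustrative only and are not part of this patch.

from collections import defaultdict

def rloo_advantages(prompt_ids, rewards):
    """Leave-one-out advantage: A_i = r_i - mean(r_j for j != i, same prompt)."""
    groups = defaultdict(list)
    for pid, r in zip(prompt_ids, rewards):
        groups[pid].append(r)
    advantages = []
    for pid, r in zip(prompt_ids, rewards):
        k = len(groups[pid])
        # With a single sample there is nothing to average over, so no baseline.
        baseline = (sum(groups[pid]) - r) / (k - 1) if k > 1 else 0.0
        advantages.append(r - baseline)
    return advantages

# Three responses sampled for the same prompt with rewards 1.0, 0.0, 0.5
# yield advantages 0.75, -0.75, 0.0.
print(rloo_advantages(["p0", "p0", "p0"], [1.0, 0.0, 0.5]))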
@@ -130,7 +130,7 @@ Actor/Rollout/Reference Policy
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
-n: 1 # > 1 for grpo
+n: 1 # > 1 for grpo, rloo
**Common config for actor, rollout and reference model**
@@ -328,7 +328,7 @@ Algorithm
- ``gamma``: discount factor
- ``lam``: Trade-off between bias and variance in the GAE estimator
-- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``.
+- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``rloo``.
- ``kl_penalty``: Support ``kl``, ``abs``, ``mse`` and ``full``. Controls how the
  KL divergence between the actor and the reference policy is calculated. For
  specific options, refer to `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py#L192>`_ .
......
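As a side note on the ``kl_penalty`` options above, the sketch below is a rough, standalone illustration of what such per-token penalties typically compute (the usual k1/k2/k3-style approximations of the KL divergence). It is not a copy of verl's implementation; the linked core_algos.py is authoritative.

import torch

def approx_kl_penalty(logprob: torch.Tensor, ref_logprob: torch.Tensor, kind: str) -> torch.Tensor:
    # Per-token divergence between actor and reference policy (illustrative only).
    log_ratio = logprob - ref_logprob
    if kind == "kl":          # plain log-ratio (k1 estimator)
        return log_ratio
    if kind == "abs":         # absolute log-ratio
        return log_ratio.abs()
    if kind == "mse":         # half squared log-ratio (k2 estimator)
        return 0.5 * log_ratio.square()
    if kind == "low_var_kl":  # k3 estimator, cf. kl_loss_type in the script below
        return (ref_logprob - logprob).exp() - (ref_logprob - logprob) - 1
    raise NotImplementedError(f"unsupported kl_penalty: {kind}")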
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=rloo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_rloo_example_gsm8k' \
trainer.experiment_name='qwen2_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
\ No newline at end of file
@@ -154,6 +154,51 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
    return scores, scores


def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
                                   eos_mask: torch.Tensor,
                                   index: torch.Tensor,
                                   epsilon: float = 1e-6):
    """
    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: prompt ids used to group responses sampled from the same prompt
            shape: (bs,)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    # Sequence-level (outcome) reward for each sampled response.
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        # Group scores by prompt id so each response can be compared against
        # the other responses sampled for the same prompt.
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            response_num = len(id2score[index[i]])
            if response_num > 1:
                # Leave-one-out baseline: r_i - mean_{j != i} r_j, expressed via the group mean.
                scores[i] = scores[i] * response_num / (response_num - 1) \
                    - id2mean[index[i]] * response_num / (response_num - 1)
        # Broadcast the sequence-level advantage to every valid response token.
        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask

    return scores, scores


def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, eos_mask: torch.Tensor,
                                                  gamma: torch.Tensor):
    """
......
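As a usage sketch of the new estimator (assuming a verl checkout that includes this patch; the uids, rewards and masks below are made up): the in-place update above is just the identity r_i - (S - r_i) / (k - 1) = r_i * k / (k - 1) - (S / k) * k / (k - 1), where S and k are the sum and count of scores sharing the same prompt index.

import numpy as np
import torch

from verl.trainer.ppo import core_algos

# Three responses sampled for the same prompt ("p0"), each carrying a single
# outcome reward on its final token.
token_level_rewards = torch.tensor([[0.0, 1.0],
                                    [0.0, 0.0],
                                    [0.0, 0.5]])
eos_mask = torch.ones_like(token_level_rewards)   # every response token is valid here
index = np.array(["p0", "p0", "p0"])              # uid grouping: one group of three

advantages, returns = core_algos.compute_rloo_outcome_advantage(
    token_level_rewards=token_level_rewards, eos_mask=eos_mask, index=index)

# Leave-one-out baselines are [0.25, 0.75, 0.5], so every token of the three
# responses gets advantage 0.75, -0.75 and 0.0 respectively.
print(advantages)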
@@ -171,6 +171,18 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    elif adv_estimator == 'rloo':
        token_level_rewards = data.batch['token_level_rewards']
        index = data.non_tensor_batch['uid']
        responses = data.batch['responses']
        response_length = responses.size(-1)
        attention_mask = data.batch['attention_mask']
        response_mask = attention_mask[:, -response_length:]
        advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards=token_level_rewards,
                                                                        eos_mask=response_mask,
                                                                        index=index)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    else:
        raise NotImplementedError
    return data
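One practical consequence of the ``uid`` grouping above, shown as a quick hypothetical check under the same assumptions as the previous sketch: with ``actor_rollout_ref.rollout.n=1`` every group contains a single response, the leave-one-out baseline cannot be formed, and the advantage degenerates to the raw return. This is why the config comment earlier marks ``n`` as "> 1 for grpo, rloo" and the example script samples ``n=5`` responses per prompt.

import numpy as np
import torch

from verl.trainer.ppo import core_algos

# One prompt, one sampled response: there are no other samples to average over.
rewards = torch.tensor([[0.0, 1.0]])   # outcome reward 1.0 on the last token
mask = torch.ones_like(rewards)

adv, ret = core_algos.compute_rloo_outcome_advantage(
    token_level_rewards=rewards, eos_mask=mask, index=np.array(["p0"]))

print(adv)  # tensor([[1., 1.]]) -- the raw return, with no baseline subtracted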
@@ -368,11 +380,7 @@ class RayPPOTrainer(object):
         if self.config.algorithm.adv_estimator == 'gae':
             self.use_critic = True
-        elif self.config.algorithm.adv_estimator == 'grpo':
-            self.use_critic = False
-        elif self.config.algorithm.adv_estimator == 'reinforce_plus_plus':
-            self.use_critic = False
-        elif self.config.algorithm.adv_estimator == 'remax':
+        elif self.config.algorithm.adv_estimator in ['grpo', 'reinforce_plus_plus', 'remax', 'rloo']:
             self.use_critic = False
         else:
             raise NotImplementedError
......