Unverified Commit cd52d8b3 by Guangming Sheng Committed by GitHub

[algo] feat: support GRPO algorithm (#124)

- Implement the KL loss, the GRPO outcome advantage, and n-samples-per-prompt (best-of-n style) rollouts; the quantities involved are sketched right after this message
- Provide example scripts for DeepSeek and Qwen on GSM8K; more can be provided for other datasets
- Support sequence balancing (seq balance)
- Training Qwen2-7B, the GSM8K score can reach 0.89
parent 5b90cd7d
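Note (not part of the patch; the notation here is mine): for a prompt with n sampled responses and scalar outcome rewards r_1, ..., r_n, the GRPO advantage assigned to every token t of response i is the group-normalized score

\[ \hat{A}_{i,t} = \frac{r_i - \operatorname{mean}(r_1,\dots,r_n)}{\operatorname{std}(r_1,\dots,r_n) + \epsilon}, \]

and, instead of folding a KL penalty into the token-level reward as in the PPO path, the actor loss carries a per-token low-variance KL estimate (the k3 estimator from Schulman's KL-approximation note), weighted by kl_loss_coef:

\[ \widehat{\mathrm{KL}}_t = \exp\!\bigl(\log\pi_{\mathrm{ref}} - \log\pi_\theta\bigr) - \bigl(\log\pi_{\mathrm{ref}} - \log\pi_\theta\bigr) - 1 . \]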
# Example: GRPO on GSM8K with DeepSeek-LLM-7B-Chat, function-based reward, fixed PPO micro-batch sizes.
set -x
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=128 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=256 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size=256 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='deepseek_llm_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
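Note that the script forwards $@, so any extra Hydra-style overrides given on the command line (e.g. a different trainer.n_gpus_per_node) are appended to the python3 -m verl.trainer.main_ppo invocation above. The same applies to the remaining scripts below.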

# Example: GRPO on GSM8K with DeepSeek-LLM-7B-Chat, function-based reward, sequence balancing (dynamic batch size capped by ppo_max_token_len_per_gpu).
set -x
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
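A rough reading of the sequence-balancing knobs in this script (my interpretation, not stated in the patch): with data.max_prompt_length=512 and data.max_response_length=512 a packed sample is at most 1024 tokens, so actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 lets each GPU pack roughly 23 worst-case sequences into one forward/backward pass, and more when responses are shorter.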

# Example: GRPO on GSM8K with Qwen2-7B-Instruct, function-based reward, fixed PPO micro-batch sizes; the vLLM attention backend is pinned to XFORMERS.
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=128 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=256 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size=256 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='qwen2_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@

# Example: GRPO on GSM8K with Qwen2-7B-Instruct, function-based reward, sequence balancing. The leading '+' on trainer.val_before_train is Hydra syntax for adding a key that is absent from the base config.
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \
+trainer.val_before_train=False \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
@@ -58,4 +58,4 @@ python3 -m verl.trainer.main_ppo \
 trainer.nnodes=1 \
 trainer.save_freq=-1 \
 trainer.test_freq=5 \
-trainer.total_epochs=100 $@
+trainer.total_epochs=15 $@

@@ -49,4 +49,4 @@ python3 -m verl.trainer.main_ppo \
 trainer.nnodes=1 \
 trainer.save_freq=-1 \
 trainer.test_freq=5 \
-trainer.total_epochs=100 $@
+trainer.total_epochs=15 $@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -28,6 +28,9 @@ actor_rollout_ref:
     grad_clip: 1.0
     clip_ratio: 0.2
     entropy_coeff: 0.0
+    use_kl_loss: False # True for GRPO
+    kl_loss_coef: 0.001 # for grpo
+    kl_loss_type: low_var_kl # for grpo
     ppo_epochs: 1
     shuffle: False
     ulysses_sequence_parallel_size: 1 # sp size
@@ -48,5 +48,5 @@ if __name__ == '__main__':
             best_reward = reward
     print(f'Best reward is {best_reward}')
-    assert best_reward > 0.2, f'Best reward must be greater than 0.3. best_reward: {best_reward}'
+    assert best_reward > 0.2, f'Best reward must be greater than 0.2. best_reward: {best_reward}'
     print('Check passes')
@@ -27,6 +27,9 @@ actor_rollout_ref:
     grad_clip: 1.0
     clip_ratio: 0.2
     entropy_coeff: 0.001
+    use_kl_loss: False # True for GRPO
+    kl_loss_coef: 0.001 # for grpo
+    kl_loss_type: low_var_kl # for grpo
     ppo_epochs: 1
     shuffle: False
     ulysses_sequence_parallel_size: 1 # sp size
@@ -20,6 +20,7 @@ implement PPO
 import numpy as np
 import torch
+from collections import defaultdict
 import verl.utils.torch_functional as verl_F

@@ -106,6 +107,54 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
     return advantages, returns

+# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
+def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
+                                   eos_mask: torch.Tensor,
+                                   index: torch.Tensor,
+                                   epsilon: float = 1e-6):
+    """
+    Compute advantage for GRPO, operating only on Outcome reward
+    (with only one scalar reward for each response).
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        Returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+    """
+    response_length = token_level_rewards.shape[-1]
+    non_zero_mask = (token_level_rewards != 0)
+    scores = (token_level_rewards * non_zero_mask).sum(dim=-1)
+
+    id2score = defaultdict(list)
+    id2mean = {}
+    id2std = {}
+
+    with torch.no_grad():
+        bsz = scores.shape[0]
+        for i in range(bsz):
+            id2score[index[i]].append(scores[i])
+        for idx in id2score:
+            if len(id2score[idx]) == 1:
+                id2mean[idx] = torch.tensor(0.0)
+                id2std[idx] = torch.tensor(1.0)
+            elif len(id2score[idx]) > 1:
+                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
+            else:
+                raise ValueError(f"no score in prompt index: {idx}")
+        for i in range(bsz):
+            scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+    return scores, scores

 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     kl = old_log_prob - ref_log_prob
     return token_level_scores - kl * kl_ratio

@@ -210,6 +259,14 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
     if kl_penalty == "mse":
         return 0.5 * (logprob - ref_logprob).square()

+    # J. Schulman. Approximating kl divergence, 2020.
+    # URL http://joschu.net/blog/kl-approx.html.
+    if kl_penalty == 'low_var_kl':
+        kl = ref_logprob - logprob
+        ratio = torch.exp(kl)
+        kld = (ratio - kl - 1).contiguous()
+        return torch.clamp(kld, min=-10, max=10)
+
     if kl_penalty == "full":
         # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
         raise NotImplementedError
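Note: a small standalone sanity check of the two additions above (a sketch, assuming verl is installed and that both the new function and the new kl_penalty branch live in verl.trainer.ppo.core_algos, as in this patch):

import numpy as np
import torch

from verl.trainer.ppo.core_algos import compute_grpo_outcome_advantage, kl_penalty

# two prompts, three sampled responses each (rollout.n = 3), responses of length 4
uid = np.array(['p0', 'p0', 'p0', 'p1', 'p1', 'p1'], dtype=object)
eos_mask = torch.ones(6, 4)  # pretend every response token is valid
token_level_rewards = torch.zeros(6, 4)
token_level_rewards[:, -1] = torch.tensor([1., 0., 0., 1., 1., 0.])  # outcome reward sits on the last token

adv, ret = compute_grpo_outcome_advantage(token_level_rewards, eos_mask, index=uid)
print(adv[:, 0])  # per-response scores normalized by the group mean/std, broadcast over tokens

# the low-variance (k3) KL estimate is non-negative and close to 0 when the two policies agree
logp = torch.log(torch.tensor([0.5, 0.3, 0.2]))
ref_logp = torch.log(torch.tensor([0.4, 0.4, 0.2]))
print(kl_penalty(logprob=logp, ref_logprob=ref_logp, kl_penalty='low_var_kl'))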
@@ -17,6 +17,7 @@ This trainer supports model-agonistic model initialization with huggingface
 """

 import os
+import uuid
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum

@@ -112,16 +113,16 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
     return data, metrics


-def compute_advantage(data: DataProto, gamma, lam, adv_estimator):
-    values = data.batch['values']
-    responses = data.batch['responses']
-    response_length = responses.size(1)
-    attention_mask = data.batch['attention_mask']
-    response_mask = attention_mask[:, -response_length:]
-    token_level_rewards = data.batch['token_level_rewards']
-
+def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
+    # prepare response group
     # TODO: add other ways to estimate advantages
     if adv_estimator == 'gae':
+        values = data.batch['values']
+        responses = data.batch['responses']
+        response_length = responses.size(-1)
+        attention_mask = data.batch['attention_mask']
+        response_mask = attention_mask[:, -response_length:]
+        token_level_rewards = data.batch['token_level_rewards']
         advantages, returns = core_algos.compute_gae_advantage_return(token_level_rewards=token_level_rewards,
                                                                       values=values,
                                                                       eos_mask=response_mask,

@@ -129,6 +130,18 @@ def compute_advantage(data: DataProto, gamma, lam, adv_estimator):
                                                                       lam=lam)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
+    elif adv_estimator == 'grpo':
+        token_level_rewards = data.batch['token_level_rewards']
+        index = data.non_tensor_batch['uid']
+        responses = data.batch['responses']
+        response_length = responses.size(-1)
+        attention_mask = data.batch['attention_mask']
+        response_mask = attention_mask[:, -response_length:]
+        advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=token_level_rewards,
+                                                                        eos_mask=response_mask,
+                                                                        index=index)
+        data.batch['advantages'] = advantages
+        data.batch['returns'] = returns
     else:
         raise NotImplementedError
     return data
@@ -156,14 +169,13 @@ def _compute_response_info(batch):
     )


-def compute_data_metrics(batch):
+def compute_data_metrics(batch, use_critic=True):
     # TODO: add response length
     sequence_score = batch.batch['token_level_scores'].sum(-1)
     sequence_reward = batch.batch['token_level_rewards'].sum(-1)

     advantages = batch.batch['advantages']
     returns = batch.batch['returns']
-    values = batch.batch['values']

     max_response_length = batch.batch['responses'].shape[-1]

@@ -178,10 +190,12 @@ def compute_data_metrics(batch):
     valid_adv = torch.masked_select(advantages, response_mask)
     valid_returns = torch.masked_select(returns, response_mask)
-    valid_values = torch.masked_select(values, response_mask)

-    return_diff_var = torch.var(valid_returns - valid_values)
-    return_var = torch.var(valid_returns)
+    if use_critic:
+        values = batch.batch['values']
+        valid_values = torch.masked_select(values, response_mask)
+        return_diff_var = torch.var(valid_returns - valid_values)
+        return_var = torch.var(valid_returns)

     metrics = {
         # score

@@ -212,13 +226,15 @@ def compute_data_metrics(batch):
             torch.max(valid_returns).detach().item(),
         'critic/returns/min':
             torch.min(valid_returns).detach().item(),
-        # values
-        'critic/values/mean':
-            torch.mean(valid_values).detach().item(),
-        'critic/values/max':
-            torch.max(valid_values).detach().item(),
-        'critic/values/min':
-            torch.min(valid_values).detach().item(),
+        **({
+            # values
+            'critic/values/mean': torch.mean(valid_values).detach().item(),
+            'critic/values/max': torch.max(valid_values).detach().item(),
+            'critic/values/min': torch.min(valid_values).detach().item(),
+            # vf explained var
+            'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
+        } if use_critic else {}),

         # response length
         'response_length/mean':
             torch.mean(response_length).detach().item(),

@@ -237,8 +253,6 @@ def compute_data_metrics(batch):
             torch.min(prompt_length).detach().item(),
         'prompt_length/clip_ratio':
             torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
-        # vf explained var
-        'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
     }
     return metrics
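The compute_data_metrics change above relies on conditional dict unpacking so the critic entries are only merged in when a critic exists (e.g. not for GRPO); the idiom in isolation, with made-up numbers:

use_critic = False
metrics = {
    'score/mean': 0.83,
    **({'critic/values/mean': 0.1} if use_critic else {}),  # merged only when use_critic is True
}
print(metrics)  # {'score/mean': 0.83}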
@@ -449,8 +463,9 @@ class RayPPOTrainer(object):
             critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
             self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls
             self.use_critic = True
+        elif self.config.algorithm.adv_estimator == 'grpo':
+            self.use_critic = False
         else:
-            # support GRPO and ReMax
             raise NotImplementedError

         # create reference policy if needed
@@ -572,6 +587,8 @@
                 with _timer('gen', timing_raw):
                     gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

+                batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
+                                                          dtype=object)
                 # repeat to align with repeated responses in rollout
                 batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                 batch = batch.union(gen_batch_output)

@@ -591,9 +608,10 @@
                     batch = batch.union(ref_log_prob)

                 # compute values
-                with _timer('values', timing_raw):
-                    values = self.critic_wg.compute_values(batch)
-                    batch = batch.union(values)
+                if self.use_critic:
+                    with _timer('values', timing_raw):
+                        values = self.critic_wg.compute_values(batch)
+                        batch = batch.union(values)

                 with _timer('adv', timing_raw):
                     # compute scores. Support both model and function-based.
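The grouping in compute_grpo_outcome_advantage works because each prompt is tagged with a uid before the batch is repeated, so the rollout.n responses generated for one prompt share that uid. A minimal illustration with plain numpy (the interleaved behaviour of DataProto.repeat(..., interleave=True) is assumed from its arguments):

import uuid
import numpy as np

n = 3  # rollout.n
uids = np.array([str(uuid.uuid4()) for _ in range(2)], dtype=object)  # one uid per prompt
repeated = np.repeat(uids, n)  # interleaved: [u0, u0, u0, u1, u1, u1]
print(repeated)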
@@ -609,16 +627,20 @@
                     batch.batch['token_level_scores'] = reward_tensor

                     # compute rewards. apply_kl_penalty if available
-                    batch, kl_metrics = apply_kl_penalty(batch,
-                                                         kl_ctrl=self.kl_ctrl,
-                                                         kl_penalty=self.config.algorithm.kl_penalty)
-                    metrics.update(kl_metrics)
+                    if not self.config.actor_rollout_ref.actor.use_kl_loss:
+                        batch, kl_metrics = apply_kl_penalty(batch,
+                                                             kl_ctrl=self.kl_ctrl,
+                                                             kl_penalty=self.config.algorithm.kl_penalty)
+                        metrics.update(kl_metrics)
+                    else:
+                        batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                     # compute advantages, executed on the driver process
                     batch = compute_advantage(batch,
-                                              self.config.algorithm.gamma,
-                                              self.config.algorithm.lam,
-                                              adv_estimator=self.config.algorithm.adv_estimator)
+                                              adv_estimator=self.config.algorithm.adv_estimator,
+                                              gamma=self.config.algorithm.gamma,
+                                              lam=self.config.algorithm.lam,
+                                              num_repeat=self.config.actor_rollout_ref.rollout.n)

                 # update critic
                 if self.use_critic:
@@ -648,7 +670,7 @@
                         self._save_checkpoint()

                 # collect metrics
-                metrics.update(compute_data_metrics(batch=batch))
+                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

                 # TODO: make a canonical logger that supports various backend
@@ -142,4 +142,8 @@ class RLHFDataset(Dataset):
         if self.return_raw_chat:
             row_dict['raw_prompt'] = chat.tolist()

-        return row_dict
\ No newline at end of file
+        # add index for each prompt
+        index = row_dict.get("extra_info", {}).get("index", 0)
+        row_dict["index"] = index
+
+        return row_dict
@@ -26,7 +26,7 @@ from verl import DataProto
 from verl.trainer.ppo import core_algos
 from verl.workers.actor import BasePPOActor
 from verl.utils.py_functional import append_to_dict
-from verl.utils.torch_functional import logprobs_from_logits, log_probs_from_logits_all_rmpad
+from verl.utils.torch_functional import logprobs_from_logits, masked_mean
 from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
 from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
 import verl.utils.torch_functional as verl_F

@@ -209,6 +209,8 @@ class DataParallelPPOActor(BasePPOActor):
         temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error

         select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
+        if self.config.use_kl_loss:
+            select_keys.append('ref_log_prob')
         batch = data.select(batch_keys=select_keys).batch

         # Split to make minibatch iterator for updating the actor

@@ -254,6 +256,18 @@ class DataParallelPPOActor(BasePPOActor):

                 # compute policy loss
                 policy_loss = pg_loss - entropy_loss * entropy_coeff
+
+                if self.config.use_kl_loss:
+                    ref_log_prob = data['ref_log_prob']
+                    # compute kl loss
+                    kld = core_algos.kl_penalty(logprob=log_prob,
+                                                ref_logprob=ref_log_prob,
+                                                kl_penalty=self.config.kl_loss_type)
+                    kl_loss = masked_mean(kld, response_mask)
+
+                    policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef
+                    metrics['actor/kl_loss'] = kl_loss.detach().item()
+                    metrics['actor/kl_coef'] = self.config.kl_loss_coef
+
                 loss = policy_loss / self.gradient_accumulation
                 loss.backward()