Unverified Commit 3b1aef2f by _T_L_R_ Committed by GitHub

[Fix] Using an enumeration class to avoid spelling errors in adv_esti… (#377)

#369

---------

Co-authored-by: Thom <zhangyi@zhangyideMacBook-Pro.local>
parent 0dc8e859
@@ -16,7 +16,7 @@ An naive implementation of split placment example
 """
 from pprint import pprint
 from verl import DataProto
-from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics
+from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics, AdvantageEstimator
 from copy import deepcopy
 import numpy as np
 import torch
@@ -69,7 +69,7 @@ def fit(self):
                 with _timer('gen', timing_raw):
                     gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
 
-                if self.config.algorithm.adv_estimator == 'remax':
+                if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                     with _timer('gen_max', timing_raw):
                         gen_baseline_batch = deepcopy(gen_batch)
                         gen_baseline_batch.meta_info['do_sample'] = False
......
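Note: because `AdvantageEstimator` subclasses `str`, its members compare equal to the plain string values that existing configs still pass in (e.g. `adv_estimator: remax`), so the comparison above keeps working for both forms. A minimal standalone sketch of that behavior (illustrative only, not part of the diff):

```python
from enum import Enum

class AdvantageEstimator(str, Enum):
    """Trimmed mirror of the enum added in this PR; the str mixin keeps string compatibility."""
    GAE = 'gae'
    REMAX = 'remax'

# Enum members behave like their string values, so configs carrying plain
# strings still match the new enum-based comparisons.
assert AdvantageEstimator.REMAX == 'remax'
assert 'remax' == AdvantageEstimator.REMAX
assert AdvantageEstimator('remax') is AdvantageEstimator.REMAX  # lookup by value
```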
@@ -54,6 +54,17 @@ class Role(Enum):
     ActorRolloutRef = 6
 
 
+class AdvantageEstimator(str, Enum):
+    """
+    Using an enumeration class to avoid spelling errors in adv_estimator
+    """
+    GAE = 'gae'
+    GRPO = 'grpo'
+    REINFORCE_PLUS_PLUS = 'reinforce_plus_plus'
+    REMAX = 'remax'
+    RLOO = 'rloo'
+
+
 @dataclass
 class ResourcePoolManager:
     """
@@ -119,7 +130,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
 def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
     # prepare response group
     # TODO: add other ways to estimate advantages
-    if adv_estimator == 'gae':
+    if adv_estimator == AdvantageEstimator.GAE:
         values = data.batch['values']
         responses = data.batch['responses']
         response_length = responses.size(-1)
@@ -133,7 +144,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
                                                       lam=lam)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
-    elif adv_estimator == 'grpo':
+    elif adv_estimator == AdvantageEstimator.GRPO:
         token_level_rewards = data.batch['token_level_rewards']
         index = data.non_tensor_batch['uid']
         responses = data.batch['responses']
@@ -145,7 +156,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
                                                                  index=index)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
-    elif adv_estimator == 'reinforce_plus_plus':
+    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
         token_level_rewards = data.batch['token_level_rewards']
         responses = data.batch['responses']
         response_length = responses.size(-1)
@@ -155,7 +166,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
             token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
-    elif adv_estimator == 'remax':
+    elif adv_estimator == AdvantageEstimator.REMAX:
         token_level_rewards = data.batch['token_level_rewards']
         index = data.non_tensor_batch['uid']
         responses = data.batch['responses']
@@ -171,7 +182,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
-    elif adv_estimator == 'rloo':
+    elif adv_estimator == AdvantageEstimator.RLOO:
         token_level_rewards = data.batch['token_level_rewards']
         index = data.non_tensor_batch['uid']
         responses = data.batch['responses']
@@ -378,9 +389,12 @@ class RayPPOTrainer(object):
         else:
             self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.)
 
-        if self.config.algorithm.adv_estimator == 'gae':
+        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
             self.use_critic = True
-        elif self.config.algorithm.adv_estimator in ['grpo', 'reinforce_plus_plus', 'remax', 'rloo']:
+        elif self.config.algorithm.adv_estimator in [
+                AdvantageEstimator.GRPO, AdvantageEstimator.REINFORCE_PLUS_PLUS, AdvantageEstimator.REMAX,
+                AdvantageEstimator.RLOO
+        ]:
             self.use_critic = False
         else:
             raise NotImplementedError
@@ -869,7 +883,7 @@ class RayPPOTrainer(object):
                 with _timer('gen', timing_raw):
                     gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
 
-                if self.config.algorithm.adv_estimator == 'remax':
+                if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                     with _timer('gen_max', timing_raw):
                         gen_baseline_batch = deepcopy(gen_batch)
                         gen_baseline_batch.meta_info['do_sample'] = False
......
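A possible follow-up (not part of this PR): since `AdvantageEstimator` now enumerates every supported value, a config could be validated up front instead of falling through to `raise NotImplementedError`. Hypothetical sketch; `check_adv_estimator` and its error message are illustrative, not verl API:

```python
from verl.trainer.ppo.ray_trainer import AdvantageEstimator

def check_adv_estimator(name: str) -> AdvantageEstimator:
    """Hypothetical helper: map a config string to the enum, failing loudly on typos."""
    try:
        return AdvantageEstimator(name)
    except ValueError:
        valid = [e.value for e in AdvantageEstimator]
        raise ValueError(f"Unknown adv_estimator {name!r}; expected one of {valid}") from None

# check_adv_estimator('reinforce_plus_plus') -> AdvantageEstimator.REINFORCE_PLUS_PLUS,
# while a misspelling such as 'reinfroce_plus_plus' raises immediately.
```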