Unverified Commit 3b1aef2f by _T_L_R_ Committed by GitHub

[Fix] Using an enumeration class to avoid spelling errors in adv_esti… (#377)

#369

---------

Co-authored-by: Thom <zhangyi@zhangyideMacBook-Pro.local>
parent 0dc8e859
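
The core of the change is that AdvantageEstimator mixes in str, so enum members still compare equal to the raw strings used in existing configs, while a misspelled value can be rejected explicitly. A minimal, self-contained sketch of that behaviour (standalone Python for illustration, not the verl sources themselves):

from enum import Enum


class AdvantageEstimator(str, Enum):
    """Copy of the enum added in this commit, reproduced here for illustration."""
    GAE = 'gae'
    GRPO = 'grpo'
    REINFORCE_PLUS_PLUS = 'reinforce_plus_plus'
    REMAX = 'remax'
    RLOO = 'rloo'


# The str mixin keeps the change backward compatible: members compare equal
# to the plain strings that configs already contain.
assert AdvantageEstimator.GAE == 'gae'
assert 'remax' == AdvantageEstimator.REMAX

# Constructing the enum from an unknown value fails loudly, which is the
# spelling-error protection the commit message refers to.
try:
    AdvantageEstimator('riloo')
except ValueError as err:
    print(err)  # e.g. 'riloo' is not a valid AdvantageEstimator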
@@ -16,7 +16,7 @@ An naive implementation of split placment example
"""
from pprint import pprint
from verl import DataProto
-from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics
+from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics, AdvantageEstimator
from copy import deepcopy
import numpy as np
import torch
@@ -69,7 +69,7 @@ def fit(self):
with _timer('gen', timing_raw):
    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
-if self.config.algorithm.adv_estimator == 'remax':
+if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
    with _timer('gen_max', timing_raw):
        gen_baseline_batch = deepcopy(gen_batch)
        gen_baseline_batch.meta_info['do_sample'] = False
......
@@ -54,6 +54,17 @@ class Role(Enum):
    ActorRolloutRef = 6

+class AdvantageEstimator(str, Enum):
+    """
+    Using an enumeration class to avoid spelling errors in adv_estimator
+    """
+    GAE = 'gae'
+    GRPO = 'grpo'
+    REINFORCE_PLUS_PLUS = 'reinforce_plus_plus'
+    REMAX = 'remax'
+    RLOO = 'rloo'
+
@dataclass
class ResourcePoolManager:
"""
@@ -119,7 +130,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
    # prepare response group
    # TODO: add other ways to estimate advantages
-    if adv_estimator == 'gae':
+    if adv_estimator == AdvantageEstimator.GAE:
        values = data.batch['values']
        responses = data.batch['responses']
        response_length = responses.size(-1)
@@ -133,7 +144,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
            lam=lam)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
-    elif adv_estimator == 'grpo':
+    elif adv_estimator == AdvantageEstimator.GRPO:
        token_level_rewards = data.batch['token_level_rewards']
        index = data.non_tensor_batch['uid']
        responses = data.batch['responses']
@@ -145,7 +156,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
            index=index)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
-    elif adv_estimator == 'reinforce_plus_plus':
+    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
        token_level_rewards = data.batch['token_level_rewards']
        responses = data.batch['responses']
        response_length = responses.size(-1)
@@ -155,7 +166,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
            token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
-    elif adv_estimator == 'remax':
+    elif adv_estimator == AdvantageEstimator.REMAX:
        token_level_rewards = data.batch['token_level_rewards']
        index = data.non_tensor_batch['uid']
        responses = data.batch['responses']
@@ -171,7 +182,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
-    elif adv_estimator == 'rloo':
+    elif adv_estimator == AdvantageEstimator.RLOO:
        token_level_rewards = data.batch['token_level_rewards']
        index = data.non_tensor_batch['uid']
        responses = data.batch['responses']
@@ -378,9 +389,12 @@ class RayPPOTrainer(object):
        else:
            self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.)
-        if self.config.algorithm.adv_estimator == 'gae':
+        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
            self.use_critic = True
-        elif self.config.algorithm.adv_estimator in ['grpo', 'reinforce_plus_plus', 'remax', 'rloo']:
+        elif self.config.algorithm.adv_estimator in [
+                AdvantageEstimator.GRPO, AdvantageEstimator.REINFORCE_PLUS_PLUS, AdvantageEstimator.REMAX,
+                AdvantageEstimator.RLOO
+        ]:
            self.use_critic = False
        else:
            raise NotImplementedError
@@ -869,7 +883,7 @@ class RayPPOTrainer(object):
with _timer('gen', timing_raw):
    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
-if self.config.algorithm.adv_estimator == 'remax':
+if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
    with _timer('gen_max', timing_raw):
        gen_baseline_batch = deepcopy(gen_batch)
        gen_baseline_batch.meta_info['do_sample'] = False
......
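
A natural follow-up (not part of this commit; load_adv_estimator is a hypothetical helper sketched here under the assumption that the AdvantageEstimator enum is defined as in the diff above) is to convert the config string into the enum once, at trainer construction time, so an unknown adv_estimator is reported with the list of valid choices instead of falling through to the raise NotImplementedError branch:

# Hypothetical helper, not in the diff: normalize the config value once.
def load_adv_estimator(name: str) -> AdvantageEstimator:
    try:
        return AdvantageEstimator(name)
    except ValueError:
        valid = [member.value for member in AdvantageEstimator]
        raise ValueError(f'unknown adv_estimator {name!r}, expected one of {valid}') from None

# load_adv_estimator('gae')   -> AdvantageEstimator.GAE
# load_adv_estimator('riloo') -> ValueError listing gae, grpo, reinforce_plus_plus, remax, rloo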