Commit cafa8371 by ZhangXiaoyun

Initial commit

parent 468adf22

@@ -43,10 +43,10 @@ dependencies = [
 "pybind11",
 "pylatexenc",
 "ray>=2.10",
-"tensordict<0.6",
+"tensordict==0.6",
 "torchdata",
 "transformers",
-"vllm<=0.6.3",
+"vllm==0.6.3",
 'wandb',
 ]

@@ -14,8 +14,8 @@ pyarrow>=15.0.0
 pybind11
 pylatexenc
 ray[default]
-tensordict<0.6
+tensordict==0.6
 torchdata
 transformers
-# vllm==0.6.3.post1
+vllm==0.8.0
 wandb

@@ -35,10 +35,10 @@ install_requires = [
 'pybind11',
 'pylatexenc',
 'ray>=2.10',
-'tensordict<0.6',
+'tensordict==0.6',
 'torchdata',
 'transformers',
-'vllm<=0.6.3',
+'vllm==0.8.0',
 'wandb',
 ]
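
Note on the dependency hunks above: tensordict is pinned to 0.6 in all three lists, but vllm ends up pinned to 0.6.3 in the first list and to 0.8.0 in the other two, so the effective version depends on which file drives the install. A small runtime check can surface such a mismatch early. This is an illustrative sketch, not part of the commit, and the expected versions below are assumptions:

    from importlib.metadata import PackageNotFoundError, version

    # Illustrative expectations; adjust to whichever pin set your install actually uses.
    EXPECTED = {"tensordict": "0.6", "vllm": "0.8.0"}

    def check_pins(expected=EXPECTED):
        for pkg, want in expected.items():
            try:
                have = version(pkg)
            except PackageNotFoundError:
                print(f"{pkg}: not installed (expected {want})")
                continue
            status = "ok" if have.startswith(want) else "MISMATCH"
            print(f"{pkg}: installed {have}, expected {want} -> {status}")

    check_pins()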

@@ -278,6 +278,16 @@ class RayPPOTrainer(object):
 self.role_worker_mapping = role_worker_mapping
 self.resource_pool_manager = resource_pool_manager
 self.use_reference_policy = Role.RefPolicy in role_worker_mapping
+if self.config.actor_rollout_ref.actor.use_kl_loss:
+    self.use_reference_policy &= (self.config.algorithm.kl_ctrl.kl_coef != 0.0 and self.config.actor_rollout_ref.actor.kl_loss_coef != 0.0)
+else:
+    self.use_reference_policy &= (self.config.algorithm.kl_ctrl.kl_coef != 0.0)
+print("==" * 60)
+print(f" self.use_reference_policy is {self.use_reference_policy}")
+print("==" * 60)
 self.use_rm = Role.RewardModel in role_worker_mapping
 self.ray_worker_group_cls = ray_worker_group_cls
 self.validation_generations_logger = ValidationGenerationsLogger()
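
The RayPPOTrainer hunk above keeps the reference policy worker only while some KL term can actually contribute: with use_kl_loss it requires both kl_ctrl.kl_coef and kl_loss_coef to be non-zero, otherwise only kl_ctrl.kl_coef. A minimal standalone sketch of that gating rule; the function and argument names are hypothetical, with the config objects flattened for illustration:

    def needs_reference_policy(has_ref_worker: bool, use_kl_loss: bool,
                               kl_coef: float, kl_loss_coef: float) -> bool:
        """Keep the reference policy only if a KL coefficient can take effect."""
        keep = has_ref_worker
        if use_kl_loss:
            # Mirrors the commit: this branch requires BOTH coefficients to be non-zero.
            keep &= (kl_coef != 0.0 and kl_loss_coef != 0.0)
        else:
            keep &= (kl_coef != 0.0)
        return keep

    # A run that uses the in-loss KL term but zeroes the reward-side coefficient
    # would drop the reference policy under this rule.
    print(needs_reference_policy(True, use_kl_loss=True, kl_coef=0.0, kl_loss_coef=0.001))  # False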

@@ -37,6 +37,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N
 elif data_source in ['hiyouga/geometry3k']:
     from . import geo3k
     res = geo3k.compute_score(solution_str, ground_truth)
+elif data_source in ['deepscaler']:
+    from deepscaler.rewards.math_reward import deepscaler_reward_fn
+    res = deepscaler_reward_fn(solution_str, ground_truth)
 else:
     raise NotImplementedError
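
The hunk above routes the 'deepscaler' data source to deepscaler_reward_fn from the external deepscaler package, which is not added to any of the dependency lists touched above, so the lazy import only resolves if that package is installed separately. A minimal sketch of the same lazy-import dispatch pattern as a standalone function; the function name here is illustrative, not the repo's API:

    def compute_score(data_source, solution_str, ground_truth, extra_info=None):
        # Lazily import the scorer for each data source, as in the dispatcher above.
        if data_source == "deepscaler":
            # Requires the external `deepscaler` package to be installed.
            from deepscaler.rewards.math_reward import deepscaler_reward_fn
            return deepscaler_reward_fn(solution_str, ground_truth)
        raise NotImplementedError(f"no reward function registered for {data_source!r}")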

@@ -233,7 +233,7 @@ class DataParallelPPOActor(BasePPOActor):
 temperature = data.meta_info['temperature']  # temperature must be in data.meta_info to avoid a silent error
 select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
-if self.config.use_kl_loss:
+if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
     select_keys.append('ref_log_prob')
 batch = data.select(batch_keys=select_keys).batch
 has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()

@@ -296,7 +296,7 @@ class DataParallelPPOActor(BasePPOActor):
 # compute policy loss
 policy_loss = pg_loss - entropy_loss * entropy_coeff
-if self.config.use_kl_loss:
+if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
     ref_log_prob = data['ref_log_prob']
     # compute kl loss
     kld = core_algos.kl_penalty(logprob=log_prob,
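
Both DataParallelPPOActor hunks above treat kl_loss_coef == 0.0 the same as use_kl_loss being off: ref_log_prob is neither selected into the batch nor used in the loss. A rough sketch of how the gated KL term would then enter the policy loss; the actual estimator lives in core_algos.kl_penalty, which this diff does not show, so the plain log-ratio below and all tensor names and shapes are assumptions for illustration:

    import torch

    def add_kl_loss(policy_loss, log_prob, ref_log_prob, response_mask,
                    kl_loss_coef, use_kl_loss=True):
        """Add the KL penalty to the policy loss only when the gate is open."""
        if not (use_kl_loss and kl_loss_coef != 0.0):
            return policy_loss  # ref_log_prob is never needed in this case
        # Token-level log-ratio estimate of KL(pi || pi_ref), masked to the response.
        kld = log_prob - ref_log_prob
        kl_loss = (kld * response_mask).sum() / response_mask.sum().clamp(min=1)
        return policy_loss + kl_loss_coef * kl_loss

    # Toy example: batch of 2 responses, 4 tokens each.
    log_prob = torch.randn(2, 4)
    ref_log_prob = torch.randn(2, 4)
    mask = torch.ones(2, 4)
    print(add_kl_loss(torch.tensor(0.5), log_prob, ref_log_prob, mask, kl_loss_coef=1e-3))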

@@ -223,7 +223,7 @@ class MegatronPPOActor(BasePPOActor):
 """
 select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
-if self.config.use_kl_loss:
+if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
     select_keys.append('ref_log_prob')
 data = data.select(batch_keys=select_keys)
 return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,

@@ -293,7 +293,7 @@ class MegatronPPOActor(BasePPOActor):
 policy_loss = pg_loss - entropy_loss * entropy_coeff
 metrics = {}
-if self.config.use_kl_loss:
+if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
     ref_log_prob = data['ref_log_prob']
     # compute kl loss
     kld = core_algos.kl_penalty(logprob=log_prob,
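
The MegatronPPOActor hunks apply the same gate to the Megatron backend, keeping behaviour consistent across backends: when the KL term cannot contribute, 'ref_log_prob' is simply not selected into the mini-batch iterator. A tiny sketch of that key-selection rule as a standalone helper (not the repo's API):

    def training_batch_keys(use_kl_loss: bool, kl_loss_coef: float) -> list:
        """Select 'ref_log_prob' only when the KL loss can actually contribute."""
        keys = ['responses', 'input_ids', 'attention_mask', 'position_ids',
                'old_log_probs', 'advantages']
        if use_kl_loss and kl_loss_coef != 0.0:
            keys.append('ref_log_prob')
        return keys

    print(training_batch_keys(use_kl_loss=True, kl_loss_coef=0.0))   # no 'ref_log_prob'
    print(training_batch_keys(use_kl_loss=True, kl_loss_coef=1e-3))  # includes 'ref_log_prob'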