Commit cafa8371 by ZhangXiaoyun

Initial commit

parent 468adf22
@@ -43,10 +43,10 @@ dependencies = [
     "pybind11",
     "pylatexenc",
     "ray>=2.10",
-    "tensordict<0.6",
+    "tensordict==0.6",
     "torchdata",
     "transformers",
-    "vllm<=0.6.3",
+    "vllm==0.6.3",
     'wandb',
 ]
...
@@ -14,8 +14,8 @@ pyarrow>=15.0.0
 pybind11
 pylatexenc
 ray[default]
-tensordict<0.6
+tensordict==0.6
 torchdata
 transformers
-# vllm==0.6.3.post1
+vllm==0.8.0
 wandb
@@ -35,10 +35,10 @@ install_requires = [
     'pybind11',
     'pylatexenc',
     'ray>=2.10',
-    'tensordict<0.6',
+    'tensordict==0.6',
     'torchdata',
     'transformers',
-    'vllm<=0.6.3',
+    'vllm==0.8.0',
     'wandb',
 ]
...
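The three hunks above replace version ranges with exact pins for tensordict and vllm across what appear to be the project's packaging metadata files; note that the first hunk pins vllm==0.6.3 while the other two pin vllm==0.8.0. As a purely illustrative aid, a minimal runtime check of the installed versions against the requirements-style pins might look like the sketch below; the expected-version mapping is an assumption taken from these hunks, and only the standard library is used.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the requirements-style hunk above; swap in vllm==0.6.3 if the
# first hunk's pin is the one that applies to your environment.
EXPECTED = {"tensordict": "0.6", "vllm": "0.8.0"}

for name, expected in EXPECTED.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    # Loose prefix match so that e.g. 0.6.0 satisfies the 0.6 pin.
    status = "OK" if installed.startswith(expected) else "MISMATCH"
    print(f"{name}: installed {installed}, pinned {expected} -> {status}")
```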
@@ -278,6 +278,16 @@ class RayPPOTrainer(object):
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
         self.use_reference_policy = Role.RefPolicy in role_worker_mapping
+        # Only keep the reference policy when some KL term can actually consume its log-probs.
+        if self.config.actor_rollout_ref.actor.use_kl_loss:
+            self.use_reference_policy &= (self.config.algorithm.kl_ctrl.kl_coef != 0.0
+                                          or self.config.actor_rollout_ref.actor.kl_loss_coef != 0.0)
+        else:
+            self.use_reference_policy &= (self.config.algorithm.kl_ctrl.kl_coef != 0.0)
+        print("==" * 60)
+        print(f" self.use_reference_policy is {self.use_reference_policy}")
+        print("==" * 60)
         self.use_rm = Role.RewardModel in role_worker_mapping
         self.ray_worker_group_cls = ray_worker_group_cls
         self.validation_generations_logger = ValidationGenerationsLogger()
...
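The guard added above keeps the reference policy only when some KL coefficient can actually consume its log-probs. A standalone restatement of that gating, with flat arguments standing in for verl's nested Hydra config (the function and argument names below are illustrative, not part of verl), is:

```python
def needs_reference_policy(has_ref_worker: bool,
                           use_kl_loss: bool,
                           kl_loss_coef: float,
                           kl_reward_coef: float) -> bool:
    """Return True only if reference log-probs would actually be used."""
    if not has_ref_worker:
        return False
    if use_kl_loss:
        # Either the in-loss KL term or the KL reward penalty may need the reference policy.
        return kl_loss_coef != 0.0 or kl_reward_coef != 0.0
    # Without an in-loss KL term, only the KL reward penalty needs it.
    return kl_reward_coef != 0.0

# GRPO-style setup: KL only in the loss still requires the reference policy.
assert needs_reference_policy(True, use_kl_loss=True, kl_loss_coef=0.001, kl_reward_coef=0.0)
# No KL anywhere: the reference worker can be skipped entirely.
assert not needs_reference_policy(True, use_kl_loss=False, kl_loss_coef=0.0, kl_reward_coef=0.0)
```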
@@ -37,6 +37,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
     elif data_source in ['hiyouga/geometry3k']:
         from . import geo3k
         res = geo3k.compute_score(solution_str, ground_truth)
+    elif data_source in ['deepscaler']:
+        from deepscaler.rewards.math_reward import deepscaler_reward_fn
+        res = deepscaler_reward_fn(solution_str, ground_truth)
     else:
         raise NotImplementedError
...
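With the new branch, samples whose data_source is 'deepscaler' are scored by DeepScaleR's math reward function. A hypothetical end-to-end call through the default scorer could look like the following; the import path follows verl's usual layout but may differ by version, the deepscaler package must be installed separately, and the sample strings are made up.

```python
# Hypothetical usage sketch; assumes verl exposes _default_compute_score at this path
# and that the deepscaler package is importable.
from verl.utils.reward_score import _default_compute_score

score = _default_compute_score(
    data_source="deepscaler",
    solution_str=r"... so the final answer is \boxed{42}.",
    ground_truth="42",
)
print(score)  # expected to be a correctness score for the boxed answer
```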
@@ -233,7 +233,7 @@ class DataParallelPPOActor(BasePPOActor):
         temperature = data.meta_info['temperature']  # temperature must be in the data.meta_info to avoid silent error
         select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
-        if self.config.use_kl_loss:
+        if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
             select_keys.append('ref_log_prob')
         batch = data.select(batch_keys=select_keys).batch
         has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()
@@ -296,7 +296,7 @@ class DataParallelPPOActor(BasePPOActor):
         # compute policy loss
         policy_loss = pg_loss - entropy_loss * entropy_coeff
-        if self.config.use_kl_loss:
+        if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
             ref_log_prob = data['ref_log_prob']
             # compute kl loss
             kld = core_algos.kl_penalty(logprob=log_prob,
...
@@ -223,7 +223,7 @@ class MegatronPPOActor(BasePPOActor):
         """
         select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
-        if self.config.use_kl_loss:
+        if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
             select_keys.append('ref_log_prob')
         data = data.select(batch_keys=select_keys)
         return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
@@ -293,7 +293,7 @@ class MegatronPPOActor(BasePPOActor):
         policy_loss = pg_loss - entropy_loss * entropy_coeff
         metrics = {}
-        if self.config.use_kl_loss:
+        if self.config.use_kl_loss and self.config.kl_loss_coef != 0.0:
             ref_log_prob = data['ref_log_prob']
             # compute kl loss
             kld = core_algos.kl_penalty(logprob=log_prob,
...
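Both actor implementations now skip the KL branch when kl_loss_coef is exactly zero, so 'ref_log_prob' is neither selected into the batch nor read during the update. The toy sketch below reproduces that guard around a low-variance KL estimator; it is self-contained with random tensors and does not use the actual core_algos.kl_penalty call or verl's masked aggregation.

```python
import torch

def policy_loss_with_optional_kl(pg_loss: torch.Tensor,
                                 log_prob: torch.Tensor,
                                 ref_log_prob,
                                 use_kl_loss: bool,
                                 kl_loss_coef: float) -> torch.Tensor:
    """Add a KL(pi || pi_ref) penalty to the policy-gradient loss only when it has weight."""
    loss = pg_loss
    if use_kl_loss and kl_loss_coef != 0.0:
        # k3-style low-variance estimator: exp(r) - r - 1 with r = log pi_ref - log pi.
        log_ratio = ref_log_prob - log_prob
        kld = torch.exp(log_ratio) - log_ratio - 1.0
        loss = loss + kl_loss_coef * kld.mean()
    return loss

pg = torch.tensor(0.5)
log_prob = torch.randn(4, 16)
# With a zero coefficient the reference log-probs are never touched, matching the new guard.
assert torch.equal(policy_loss_with_optional_kl(pg, log_prob, None, True, 0.0), pg)
```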