Commit e75f97ef by Yaoyu Zhu

add adaptive batch size and fix bugs for history sampling

parent d08a947e
#!/bin/bash
set -x
set -euxo pipefail
project_name='DAPO'
exp_name='DAPO-Early-Qwen2.5-32B'
adv_estimator=grpo
kl_coef=0.0
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 1))
overlong_penalty_factor=1.0
# An early version for DAPO
enable_filter_groups=True
train_prompt_bsz=128
train_prompt_mini_bsz=64
n_resp_per_prompt=16
use_token_level_loss=True
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
# Algorithm
## Train
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 16))
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4
ppo_max_token_len_per_gpu=32768
num_gpu=$(($USER_GPUS_PER_NODE * $SLURM_JOB_NUM_NODES))
export VLLM_USE_V1=1
echo "$WANDB_DIR"
echo "$SAVE_DIR"
echo "$WANDB_API_KEY"
# Set default model path if not provided
MODEL_PATH="/nfs_global/CodeV-R1-Distill-Qwen-7B"
# Train over a single node, 8 A100-80GB GPUs.
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=/nfs_global/verl/data/codev/v1/3.1k_r1_filtered/train.parquet \
data.val_files=/nfs_global/verl/data/codev/v1/3.1k_r1_filtered/test.parquet \
data.train_batch_size=${train_prompt_bsz} \
data.val_batch_size=512 \
data.max_prompt_length=2048 \
data.max_response_length=$max_response_length \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.max_num_gen_batches=999 \
algorithm.filter_groups.metric=acc \
data.gen_batch_size=$((($train_prompt_bsz * 4 + $num_gpu - 1) / $num_gpu * $num_gpu)) \
actor_rollout_ref.model.path=$MODEL_PATH \
+actor_rollout_ref.model.override_config.attention_dropout=0. \
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
+actor_rollout_ref.model.use_liger=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0.0 \
actor_rollout_ref.actor.use_dynamic_bsz=True\
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.00 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=0.5 \
actor_rollout_ref.actor.use_token_level_loss=${use_token_level_loss} \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$(($ppo_max_token_len_per_gpu*2)) \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.rollout.val_kwargs.n=4 \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
reward_model.reward_manager=prime \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
custom_reward_function.overlong_buffer.enable=${enable_overlong_buffer} \
custom_reward_function.overlong_buffer.len=${overlong_buffer_len} \
custom_reward_function.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
custom_reward_function.train.path=verl/utils/reward_score/codev.py \
custom_reward_function.train.name=compute_score_wrapper \
custom_reward_function.train.continuous_reward.enable=False \
algorithm.kl_ctrl.kl_coef=0.0 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='codev' \
trainer.experiment_name='codev-7b-3.1kdata' \
trainer.n_gpus_per_node=$USER_GPUS_PER_NODE \
trainer.nnodes=$SLURM_JOB_NUM_NODES \
+trainer.val_before_train=False \
trainer.default_local_dir=$SAVE_DIR \
trainer.resume_mode=auto \
trainer.default_hdfs_dir=null \
trainer.save_freq=20 \
trainer.test_freq=20 \
trainer.total_epochs=100 "${@:1}"
# custom_reward_function.path=/nfs_global/S/zhuyaoyu/projects/dapo/verl/utils/reward_score/codev.py \
\ No newline at end of file
#!/bin/bash
set -x
set -euxo pipefail
project_name='DAPO'
exp_name='DAPO-Early-Qwen2.5-32B'
adv_estimator=grpo
kl_coef=0.0
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 1))
overlong_penalty_factor=1.0
# An early version for DAPO
enable_filter_groups=True
train_prompt_bsz=128
train_prompt_mini_bsz=64
n_resp_per_prompt=16
use_token_level_loss=True
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
# Algorithm
## Train
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 16))
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4
ppo_max_token_len_per_gpu=32768
num_gpu=$(($USER_GPUS_PER_NODE * $SLURM_JOB_NUM_NODES))
export VLLM_USE_V1=1
echo "$WANDB_DIR"
echo "$SAVE_DIR"
echo "$WANDB_API_KEY"
# Set default model path if not provided
MODEL_PATH="/nfs_global/CodeV-R1-Distill-Qwen-7B"
# Train over a single node, 8 A100-80GB GPUs.
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=/nfs_global/verl/data/codev/v1/3.1k_r1_filtered/train.parquet \
data.val_files=/nfs_global/verl/data/codev/v1/3.1k_r1_filtered/test.parquet \
data.train_batch_size=${train_prompt_bsz} \
data.val_batch_size=512 \
data.max_prompt_length=2048 \
data.max_response_length=$max_response_length \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.accelerate=False \
algorithm.filter_groups.max_num_gen_batches=999 \
algorithm.filter_groups.metric=acc \
data.gen_batch_size=$((($train_prompt_bsz * 2 + $num_gpu - 1) / $num_gpu * $num_gpu)) \
actor_rollout_ref.model.path=$MODEL_PATH \
+actor_rollout_ref.model.override_config.attention_dropout=0. \
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
+actor_rollout_ref.model.use_liger=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0.0 \
actor_rollout_ref.actor.use_dynamic_bsz=True\
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.00 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=0.5 \
actor_rollout_ref.actor.use_token_level_loss=${use_token_level_loss} \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$(($ppo_max_token_len_per_gpu*2)) \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.rollout.val_kwargs.n=4 \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
reward_model.reward_manager=prime \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
custom_reward_function.overlong_buffer.enable=${enable_overlong_buffer} \
custom_reward_function.overlong_buffer.len=${overlong_buffer_len} \
custom_reward_function.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
custom_reward_function.train.path=verl/utils/reward_score/codev.py \
custom_reward_function.train.name=compute_score_wrapper \
custom_reward_function.train.continuous_reward.enable=False \
algorithm.kl_ctrl.kl_coef=0.0 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='codev' \
trainer.experiment_name='codev-7b-3.1kdata' \
trainer.n_gpus_per_node=$USER_GPUS_PER_NODE \
trainer.nnodes=$SLURM_JOB_NUM_NODES \
+trainer.val_before_train=False \
trainer.default_local_dir=$SAVE_DIR \
trainer.resume_mode=auto \
trainer.default_hdfs_dir=null \
trainer.save_freq=20 \
trainer.test_freq=20 \
trainer.total_epochs=100 "${@:1}"
# custom_reward_function.path=/nfs_global/S/zhuyaoyu/projects/dapo/verl/utils/reward_score/codev.py \
\ No newline at end of file
......@@ -203,6 +203,7 @@ algorithm:
kl_coef: 0.001
filter_groups:
enable: False # We try to avoid forgetting to set enable
accelerate: True # accelerate DAPO
metric: null # acc / score / seq_reward / seq_final_reward / ...
max_num_gen_batches: 0 # Non-positive values mean no upper limit
......
......@@ -316,6 +316,7 @@ class RayPPOTrainer(object):
# effective dapo
self.filter_flags = {}
self.dapo_accelerate = self.config.algorithm.filter_groups.accelerate
def _validate_config(self):
......@@ -502,8 +503,10 @@ class RayPPOTrainer(object):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
self.redundant_ratio = 1.1
self.select_ratio = 1.0 * self.redundant_ratio * self.config.data.train_batch_size / self.config.data.gen_batch_size
# self.redundant_ratio = 1.1
# self.select_ratio = 1.0 * self.redundant_ratio * self.config.data.train_batch_size / self.config.data.gen_batch_size
self.select_ratio = 1.0
self.select_ratio_part = 1.0
def _maybe_log_val_generations(self, inputs, outputs, scores):
"""Log a table of validation samples to the configured logger (wandb or swanlab)"""
......@@ -820,10 +823,11 @@ class RayPPOTrainer(object):
with open(local_latest_checkpointed_iteration, 'w') as f:
f.write(str(self.global_steps))
# save filter flags
local_latest_filter_flags_path = os.path.join(local_global_step_folder, 'latest_filter_flags.json')
with open(local_latest_filter_flags_path, 'w', encoding='utf-8') as f:
json.dump(self.filter_flags, f)
if self.dapo_accelerate:
# save filter flags
local_latest_filter_flags_path = os.path.join(local_global_step_folder, 'latest_filter_flags.json')
with open(local_latest_filter_flags_path, 'w', encoding='utf-8') as f:
json.dump(self.filter_flags, f)
def _load_checkpoint(self):
if self.config.trainer.resume_mode == 'disable':
......@@ -878,14 +882,15 @@ class RayPPOTrainer(object):
else:
print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
# load filter flags
local_latest_filter_flags_path = os.path.join(global_step_folder, 'latest_filter_flags.json')
try:
with open(local_latest_filter_flags_path, 'r', encoding='utf-8') as f:
self.filter_flags = json.load(f)
except Exception as e:
print(f'Failed to load filter flags from {local_latest_filter_flags_path}.')
self.filter_flags = {}
if self.dapo_accelerate:
# load filter flags
local_latest_filter_flags_path = os.path.join(global_step_folder, 'latest_filter_flags.json')
try:
with open(local_latest_filter_flags_path, 'r', encoding='utf-8') as f:
self.filter_flags = json.load(f)
except Exception as e:
print(f'Failed to load filter flags from {local_latest_filter_flags_path}.')
self.filter_flags = {}
def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
......@@ -944,50 +949,59 @@ class RayPPOTrainer(object):
num_prompt_in_batch = 0
num_gen_batches = 0
new_batch_flag = True
dapo_raw_batch_size_accumulate = 0
for epoch in range(self.config.trainer.total_epochs):
for batch_part in self.train_dataloader:
# 切小批量样本生成
if new_batch_flag:
new_batch = None
new_batch_flag = False
num_gen_remaining = self.config.data.train_batch_size - num_prompt_in_batch
metrics = {}
new_batch = DataProto.from_single_dict(batch_part) if new_batch is None else DataProto.concat([new_batch, DataProto.from_single_dict(batch_part)])
print('Length of batch_part and new_batch:', len(batch_part), len(new_batch))
print(batch_part.keys())
if len(new_batch) < num_gen_remaining / self.select_ratio * self.redundant_ratio:
continue
else:
new_batch_flag = True
if 'problem_id' in new_batch.non_tensor_batch.keys():
filter_key = 'problem_id'
elif 'question' in new_batch.batch.keys():
# if 'question' in new_batch.batch.keys():
filter_key = 'question'
else:
filter_key = 'raw_prompt_ids'
if self.dapo_accelerate:
num_gen_remaining = self.config.data.train_batch_size - num_prompt_in_batch
select_idx = []
for id, item in enumerate(new_batch.non_tensor_batch[filter_key]):
# 这里需不需要改成一个计数?多少次再把这条数据给放出来
if item in self.filter_flags and self.filter_flags[item] == 1:
new_batch = DataProto.from_single_dict(batch_part) if new_batch is None else DataProto.concat([new_batch, DataProto.from_single_dict(batch_part)])
print('Length of batch_part and new_batch:', len(batch_part), len(new_batch))
# if len(new_batch) < num_gen_remaining / self.select_ratio * self.redundant_ratio:
if len(new_batch) < num_gen_remaining / max(self.select_ratio, self.select_ratio_part):
continue
select_idx.append(id)
self.select_ratio += 0.1 * (num_gen_remaining / self.config.data.train_batch_size) * (len(select_idx) / num_gen_remaining - self.select_ratio)
print("Now, self.select_ratio is", self.select_ratio)
else:
new_batch_flag = True
if 'problem_id' in new_batch.non_tensor_batch.keys():
filter_key = 'problem_id'
elif 'question' in new_batch.batch.keys():
# if 'question' in new_batch.batch.keys():
filter_key = 'question'
else:
filter_key = 'raw_prompt_ids'
select_idx = []
for id, item in enumerate(new_batch.non_tensor_batch[filter_key]):
# 这里需不需要改成一个计数?多少次再把这条数据给放出来
# json.dump再load之后key就全变成字符串了,这边统一成字符串吧
item = str(item)
if item in self.filter_flags and self.filter_flags[item] == 1:
continue
select_idx.append(id)
print(f'len(select_idx) is {len(select_idx)}, len(new_batch) is {len(new_batch)}')
dapo_raw_batch_size = len(new_batch)
dapo_raw_batch_size_accumulate += dapo_raw_batch_size
if len(select_idx) == 0:
# 换下一批样本继续生成吧
continue
new_batch = new_batch.select_idxs(select_idx)
if len(select_idx) == 0:
# 换下一批样本继续生成吧
continue
# pad to be divisible by dp_size
new_batch, pad_size = pad_dataproto_to_divisor(new_batch, self.actor_rollout_wg.world_size)
new_batch = new_batch.select_idxs(select_idx)
# pad to be divisible by dp_size
new_batch, pad_size = pad_dataproto_to_divisor(new_batch, self.actor_rollout_wg.world_size)
else:
if len(new_batch) < self.config.data.gen_batch_size:
continue
else:
new_batch_flag = True
num_gen_batches += 1
# pop those keys for generation
......@@ -1104,18 +1118,28 @@ class RayPPOTrainer(object):
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
# filter uids
filter_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std == 0 and np.mean(prompt_uid2metric_vals[uid]) == 1]
filter_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
if traj_from_prompt_uid in filter_prompt_uids:
filter_traj_idxs.append(idx)
for idx in filter_traj_idxs:
# print(new_batch.non_tensor_batch[filter_key][idx])
self.filter_flags[new_batch.non_tensor_batch[filter_key][idx]] = 1
if self.dapo_accelerate:
# filter uids, here we assume reward for correct is 1
filter_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std == 0 and np.mean(prompt_uid2metric_vals[uid]) == 1]
filter_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
if traj_from_prompt_uid in filter_prompt_uids:
filter_traj_idxs.append(idx)
for idx in filter_traj_idxs:
# print(new_batch.non_tensor_batch[filter_key][idx])
self.filter_flags[str(new_batch.non_tensor_batch[filter_key][idx])] = 1
print(f'num_gen_remaining is {num_gen_remaining}, train_batch_size is {self.config.data.train_batch_size}, len(kept_prompt_uids) is {len(kept_prompt_uids)}')
print(f'len(new_batch) is {len(new_batch)}, select_ratio is {self.select_ratio}, select_ratio_part is {self.select_ratio_part}')
# self.select_ratio += 0.1 * num_gen_remaining / self.config.data.train_batch_size * (1.0 * len(kept_prompt_uids) / dapo_raw_batch_size - self.select_ratio)
# 一步到位,select_ratio按当前转化率最低的轮次做保守估计
self.select_ratio_part = min(self.select_ratio_part, 1.0 * len(kept_prompt_uids) / dapo_raw_batch_size)
if len(kept_prompt_uids) >= num_gen_remaining:
self.select_ratio = min(self.select_ratio, 1.0 * self.config.data.train_batch_size / dapo_raw_batch_size_accumulate)
self.select_ratio_part = 1.0
print(f"Now, self.select_ratio is {self.select_ratio}, self.select_ratio_part is {self.select_ratio_part}")
new_batch = new_batch[kept_traj_idxs]
if batch is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment