Commit 04f41b77 by Shawn/Yuxuan Tong

feat: docs & metric for filtering

parent 1156f863
# DAPO
🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper](https://dapo-sia.github.io/static/pdf/dapo_paper.pdf) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo)
🐶 [Chinese Blog @ 机器之心](https://mp.weixin.qq.com/s/_w_HtjNQiG-yP5LEN85o0Q)
> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to the Qwen2.5-32B base model outperforms the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** fewer training steps.
>
> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png)
>
> DAPO samples a group of outputs $\left\{o_i\right\}_{i=1}^G$ for each question $q$ paired with the answer $a$, and optimizes the policy via the following objective:
$$
\begin{aligned}
\mathcal{J}_{\text{DAPO}}(\theta) = & \mathbb{E}_{(q, a) \sim \mathcal{D},\, \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)} \\
& \left[\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \min \left(r_{i,t}(\theta) \hat{A}_{i,t},\ \operatorname{clip}\left(r_{i,t}(\theta), 1-\varepsilon_{\text{low}}, 1+\varepsilon_{\text{high}}\right) \hat{A}_{i,t}\right)\right] \\
\text{s.t.} \quad & 0 < \left|\left\{o_i \mid \text{is\_equivalent}(a, o_i)\right\}\right| < G,
\end{aligned}
$$
> where
$$
r_{i, t}(\theta)=\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)}, \quad \hat{A}_{i, t}=\frac{R_i-\operatorname{mean}\left(\left\{R_i\right\}_{i=1}^G\right)}{\operatorname{std}\left(\left\{R_i\right\}_{i=1}^G\right)} .
$$
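For reference, the group-normalized advantage above can be computed in a few lines. The sketch below is only illustrative; the function and variable names are ours, not verl's API:
```python
import torch

def group_normalized_advantage(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """GRPO-style advantage for one group of G sampled outputs.

    rewards: shape (G,), one scalar reward R_i per sampled output o_i.
    Returns: shape (G,), the advantage shared by every token of output o_i.
    """
    # \hat{A}_{i,t} = (R_i - mean({R_j})) / std({R_j}); eps guards against zero std.
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Example: a group of G = 4 outputs scored by a rule-based verifier (1 = correct, 0 = wrong).
advantages = group_normalized_advantage(torch.tensor([1.0, 0.0, 0.0, 1.0]))
```
Note that the `s.t.` constraint is exactly what the group filtering described below enforces: groups whose outputs are all correct or all wrong yield zero advantage everywhere and are dropped.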
## Quickstart
1. Prepare the datasets **on the Ray cluster**:
```bash
bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default
```
2. Submit the job to the Ray cluster **from any machine**:
```bash
cd verl # Repo root
export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to
export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster
# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml
export RUNTIME_ENV="./verl/trainer/runtime_env.yaml"
bash recipe/dapo/run_dapo_qwen2.5_32b.sh
```
## Configuration
### Separated Clip Epsilons (-> Clip-Higher)
An example configuration:
```yaml
actor_rollout_ref:
  actor:
    clip_ratio_low: 0.2
    clip_ratio_high: 0.28
```
`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text {low }}$ and $\varepsilon_{\text {high }}$ in the DAPO objective.
Core relevant code:
```python
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
```
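For intuition, here is a minimal, self-contained sketch (the tensor values are made up) of how the asymmetric clip range changes the per-token loss compared with the symmetric PPO default of 0.2:
```python
import torch

# Hypothetical per-token importance ratios and (positive) advantages.
ratio = torch.tensor([0.7, 1.25, 1.40])
advantages = torch.tensor([1.0, 1.0, 1.0])
cliprange_low, cliprange_high = 0.2, 0.28

pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
print(pg_losses)  # tensor([-0.7000, -1.2500, -1.2800])
```
With a positive advantage, a ratio of 1.25 would already be clipped at 1.2 under the symmetric default, but survives un-clipped under `clip_ratio_high: 0.28`; this is the "clip-higher" effect that keeps gradients flowing when low-probability tokens are up-weighted.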
### Dynamic Sampling (with Group Filtering)
An example configuration:
```yaml
data:
  gen_batch_size: 1536
  train_batch_size: 512
algorithm:
  filter_groups:
    enable: True
    metric: acc # / reward / seq_final_reward
    fill_to_train_bsz: True
    drop_last_mini_batch: True
```
Setting `filter_groups.enable` to `True` filters out groups whose outputs' `metric` values are all identical, e.g., for `acc`, groups whose outputs are all correct (accuracy 1) or all wrong (accuracy 0).
Setting `fill_to_train_bsz` to `True` repeats sampling with `gen_batch_size` until there are enough qualified groups to fill `train_batch_size`.
Setting `drop_last_mini_batch` to `True` can be helpful when `fill_to_train_bsz` is `False`, since the last mini-batch might be incomplete due to possibly strong filtering.
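A simplified sketch of what group filtering does (the names below are ours; the actual logic lives in the trainer, see the diff further down):
```python
import numpy as np

def keep_group(metric_vals: list) -> bool:
    """Keep a group only if its outputs' metric values are not all identical,
    i.e. the within-group standard deviation is strictly positive."""
    return np.std(metric_vals) > 0

# Example with metric == "acc": groups that are all correct or all wrong carry
# no learning signal under group-normalized advantages and are dropped.
groups = {
    "q1": [1, 1, 1, 1],  # all correct -> filtered out
    "q2": [0, 0, 0, 0],  # all wrong   -> filtered out
    "q3": [1, 0, 1, 0],  # mixed       -> kept
}
kept = [uid for uid, vals in groups.items() if keep_group(vals)]  # ["q3"]
```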
### Token-level Policy Gradient Loss
An example configuration:
```yaml
actor_rollout_ref:
  actor:
    use_token_level_loss: True
```
Setting `use_token_level_loss` to `True` averages the policy gradient loss over all tokens of all sequences in a batch, instead of first averaging within each sequence and then across sequences.
Core relevant code:
```python
if use_token_level_loss:
    pg_loss = verl_F.masked_mean(pg_losses, eos_mask)
else:
    pg_loss = torch.sum(pg_losses * eos_mask, dim=1) / seq_len_per_sample
    pg_loss = torch.mean(pg_loss)
```
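A hedged toy comparison (made-up values; the explicit masked mean below mirrors what `verl_F.masked_mean` computes) showing why the two averaging schemes differ:
```python
import torch

# Two sequences in a batch: one with 2 valid tokens, one with 4 valid tokens.
pg_losses = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                          [2.0, 2.0, 2.0, 2.0]])
eos_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                         [1.0, 1.0, 1.0, 1.0]])

# Token-level: every token weighs the same, so longer sequences dominate.
token_level = (pg_losses * eos_mask).sum() / eos_mask.sum()  # (2 + 8) / 6 ≈ 1.67

# Sample-level: average within each sequence first, then across sequences,
# so short and long sequences carry equal weight.
seq_len_per_sample = eos_mask.sum(dim=1)
sample_level = ((pg_losses * eos_mask).sum(dim=1) / seq_len_per_sample).mean()  # (1 + 2) / 2 = 1.5
```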
### Overlong Reward Shaping
An example configuration:
```yaml
data:
  max_response_length: 20480 # 16384 + 4096
custom_reward_function:
  overlong_buffer:
    enable: True
    len: 4096
    penalty_factor: 1.0
```
Setting `overlong_buffer.enable` to `True` penalizes outputs whose length enters the last `overlong_buffer.len` tokens before `max_response_length`.
The penalty increases linearly from 0 to `overlong_buffer.penalty_factor`.
Core relevant code:
```python
if self.overlong_buffer_cfg.enable:
    overlong_buffer_len = self.overlong_buffer_cfg.len
    expected_len = self.max_resp_len - overlong_buffer_len
    exceed_len = valid_response_length - expected_len
    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
    overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
    final_reward += overlong_reward
```
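For example, with `max_response_length: 20480` and `overlong_buffer.len: 4096`, responses up to 16384 tokens incur no penalty, and the penalty then grows linearly to `-penalty_factor` at the full 20480 tokens. A small standalone sketch under exactly those assumptions:
```python
def overlong_penalty(valid_response_length: int,
                     max_resp_len: int = 20480,
                     overlong_buffer_len: int = 4096,
                     penalty_factor: float = 1.0) -> float:
    """Linear penalty once the response enters the last `overlong_buffer_len` tokens."""
    expected_len = max_resp_len - overlong_buffer_len  # 16384
    exceed_len = valid_response_length - expected_len
    return min(-exceed_len / overlong_buffer_len * penalty_factor, 0.0)

print(overlong_penalty(12000))  # 0.0  (well within the length budget)
print(overlong_penalty(18432))  # -0.5 (halfway into the buffer)
print(overlong_penalty(20480))  # -1.0 (at the hard length limit)
```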
## Reproduction Runs
### DAPO w/o Token-level PG loss & Dynamic Sampling -> 44% on AIME 2024
We achieved 44% accuracy on AIME 2024 with an early version of DAPO, i.e., w/o Token-level PG loss & Dynamic Sampling.
The training record will be available on WandB soon.
The corresponding training script is [run_dapo_early_qwen2.5_32b.sh](./run_dapo_early_qwen2.5_32b.sh).
### DAPO -> 50% on AIME 2024
Coming soon.
......@@ -2,20 +2,25 @@
set -euxo pipefail
project_name='DAPO'
exp_name='DAPO-Pre-Qwen2.5-32B'
exp_name='DAPO-Early-Qwen2.5-32B'
adv_estimator=grpo
kl_coef=0.0
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
overlong_buffer_len=$((1024 * 4))
# Intermediate ablation for DAPO
# An early version for DAPO
use_token_level_loss=False
enable_filter_groups=False
gen_prompt_bsz=512 # NOTE: no filtering here
train_prompt_bsz=512
train_prompt_mini_bsz=32
n_resp_per_prompt=16
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
RUNTIME_ENV=${RUNTIME_ENV:-"./verl/trainer/runtime_env.yaml"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
......@@ -23,15 +28,10 @@ MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
## Train
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
gen_prompt_bsz=512 # NOTE: no filtering here
train_prompt_bsz=512
train_prompt_mini_bsz=32
n_resp_per_prompt=16
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
......@@ -45,7 +45,7 @@ offload=True
gen_tp=4
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${PWD}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
......
......@@ -11,10 +11,15 @@ clip_ratio_high=0.28
overlong_buffer_len=$((1024 * 4))
use_token_level_loss=True
enable_filter_groups=True
gen_prompt_bsz=1024
train_prompt_bsz=512
train_prompt_mini_bsz=32
n_resp_per_prompt=16
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
RUNTIME_ENV=${RUNTIME_ENV:-"./verl/trainer/runtime_env.yaml"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
......@@ -27,10 +32,6 @@ TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
## Train
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
gen_prompt_bsz=1024
train_prompt_bsz=512
train_prompt_mini_bsz=32
n_resp_per_prompt=16
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
......@@ -44,7 +45,7 @@ offload=True
gen_tp=4
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${PWD}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
......
......@@ -6,7 +6,8 @@ exp_name='DAPO-Qwen2.5-7B-Math-Test'
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
RUNTIME_ENV=${RUNTIME_ENV:-"./verl/trainer/runtime_env.yaml"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
......@@ -34,7 +35,7 @@ train_micro_batch_size=null
offload=False
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${PWD}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
......@@ -52,7 +53,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
algorithm.adv_estimator=grpo \
algorithm.kl_ctrl.kl_coef=0.0 \
algorithm.filter_groups.enable=True \
algorithm.filter_groups.fill_train_batch=True \
algorithm.filter_groups.fill_to_train_batch_size=True \
algorithm.filter_groups.drop_last_mini_batch=True \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
......
......@@ -29,7 +29,8 @@ actor_rollout_ref:
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: False
use_torch_compile: True # False to disable torch compile
clip_ratio: 0.2
# pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified
clip_ratio_low: 0.2
clip_ratio_high: 0.2
use_token_level_loss: True
......@@ -176,6 +177,8 @@ algorithm:
kl_coef: 0.001
filter_groups:
enable: False
metric: acc # / reward / seq_final_reward
# `fill_to_train_bsz` will repeat sampling with generation batch size until there are enough qualified groups for the training batch size
fill_to_train_bsz: True
drop_last_mini_batch: True
......
......@@ -32,7 +32,8 @@ actor_rollout_ref:
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
clip_ratio: 0.2
# pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified
clip_ratio_low: 0.2
clip_ratio_high: 0.2
use_token_level_loss: True
......@@ -184,6 +185,8 @@ algorithm:
kl_coef: 0.001
filter_groups:
enable: False
metric: acc # / reward / seq_final_reward
# `fill_to_train_bsz` will repeat sampling with generation batch size until there are enough qualified groups for the training batch size
fill_to_train_bsz: True
drop_last_mini_batch: True
......
......@@ -903,9 +903,20 @@ class RayPPOTrainer(object):
batch = batch.union(reward_tensor)
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
reward_extra_infos_dict: dict[str, list]
try:
reward_result = self.reward_fn(batch, return_dict=True)
reward_tensor = reward_result['reward_tensor']
reward_extra_infos_dict = reward_result['extra_info']
except Exception as e:
reward_tensor = self.reward_fn(batch)
reward_extra_infos_dict = {}
batch.batch['token_level_scores'] = reward_tensor
if reward_extra_infos_dict:
batch.non_tensor_batch.update(reward_extra_infos_dict)
# compute rewards. apply_kl_penalty if available
if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
batch, kl_metrics = apply_kl_penalty(batch,
......@@ -917,28 +928,32 @@ class RayPPOTrainer(object):
if self.config.algorithm.filter_groups.enable:
filter_metric_dict = {}
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
batch.non_tensor_batch["seq_final_reward"] = batch.batch['token_level_scores'].sum(
dim=-1).tolist()
# Collect the sequence reward for each trajectory
prompt_uid2seq_rewards = defaultdict(list)
for uid, tok_rewards in zip(batch.non_tensor_batch['uid'], batch.batch['token_level_rewards']):
seq_reward = torch.sum(tok_rewards).item()
prompt_uid2seq_rewards[uid].append(seq_reward)
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(batch.non_tensor_batch['uid'], batch.batch[metric_name]):
prompt_uid2metric_vals[uid].append(metric_val)
prompt_uid2seq_reward_std = {}
for prompt_uid, seq_rewards in prompt_uid2seq_rewards.items():
prompt_uid2seq_reward_std[prompt_uid] = np.std(seq_rewards)
prompt_uid2metric_std = {}
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
kept_prompt_uids = [uid for uid, std in prompt_uid2seq_reward_std.items() if std > 0]
filter_metric_dict["non_uniform_reward_prompt_ratio"] = len(kept_prompt_uids) / len(
prompt_uid2seq_rewards)
filter_metric_dict["non_uniform_reward_prompt_bsz"] = len(kept_prompt_uids)
kept_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std > 0]
filter_metric_dict[f"qualified_prompt_ratio/{metric_name}"] = len(kept_prompt_uids) / len(
prompt_uid2metric_vals)
filter_metric_dict[f"qualified_prompt_bsz/{metric_name}"] = len(kept_prompt_uids)
train_prompt_bsz = self.config.data.train_batch_size
fill_to_train_bsz = self.config.algorithm.filter_groups.fill_to_train_bsz
if len(kept_prompt_uids) > train_prompt_bsz or not fill_to_train_bsz:
kept_prompt_uids = kept_prompt_uids[:train_prompt_bsz]
else:
for prompt_uid in prompt_uid2seq_reward_std.keys():
for prompt_uid in prompt_uid2metric_std.keys():
if prompt_uid not in kept_prompt_uids:
kept_prompt_uids.append(prompt_uid)
if len(kept_prompt_uids) == train_prompt_bsz:
......@@ -948,7 +963,7 @@ class RayPPOTrainer(object):
for traj_idx, traj_prompt_uid in enumerate(batch.non_tensor_batch['uid']):
if traj_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(traj_idx)
filter_metric_dict["non_uniform_reward_traj_bsz"] = len(kept_traj_idxs)
filter_metric_dict[f"qualified_traj_bsz/{metric_name}"] = len(kept_traj_idxs)
world_size = self.actor_rollout_wg.world_size
kept_traj_idxs = kept_traj_idxs[:len(kept_traj_idxs) // world_size * world_size]
......
......@@ -139,7 +139,8 @@ class NaiveRewardManager:
if self.overlong_buffer_cfg.enable:
overlong_buffer_len = self.overlong_buffer_cfg.len
exceed_len = valid_response_length - (self.max_resp_len - overlong_buffer_len)
expected_len = self.max_resp_len - overlong_buffer_len
exceed_len = valid_response_length - expected_len
overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
final_reward += overlong_reward
......