Unverified Commit 6e8667bd by Guangming Sheng Committed by GitHub

[example] add a split placement tutorial (#43)

* [example] add a split placement tutorial

* lint
# Split Placement Example
Here we introduce how to run a naive implementation of split placement for the PPO algorithm.
We will release the complete version of flexible placement in the near future.
For a quick start, you only need to follow Step 2 to modify the code and then follow Step 4 to run the split placement example.
### Step 1: Placing the models on different GPUs
Specify the placement and resource allocation. In this example, we place the actor and reference policy on the first half of the GPUs and map the critic and reward model (if any) to the second half of the GPUs.
```python
actor_rollout_ref_pool_id = 'actor_rollout_ref_pool'
critic_pool_id = 'critic_pool'
if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:
    resource_pool_spec = {
        actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,
        critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,
    }
else:
    resource_pool_spec = {
        actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
        critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
    }
print(f'resource_pool_spec: {resource_pool_spec}')
mapping = {
    Role.ActorRollout: actor_rollout_ref_pool_id,
    Role.Critic: critic_pool_id,
    Role.RefPolicy: actor_rollout_ref_pool_id,
}
mapping[Role.RewardModel] = critic_pool_id
```
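To make the arithmetic concrete, the standalone sketch below replays the same branching logic for a hypothetical single node with 8 GPUs (`nnodes=1`, `n_gpus_per_node=8`); each pool ends up with 4 of that node's GPUs. With 2 or more nodes, the `else` branch assigns whole nodes to each pool instead.
```python
# Minimal standalone sketch of the pool-splitting arithmetic above,
# assuming a hypothetical single-node, 8-GPU setup (nnodes=1, n_gpus_per_node=8).
nnodes, n_gpus_per_node = 1, 8

if nnodes // 2 == 0 and n_gpus_per_node // 2 > 0:
    # fewer than 2 nodes: split each node's GPUs between the two pools
    resource_pool_spec = {
        'actor_rollout_ref_pool': [n_gpus_per_node // 2] * nnodes,
        'critic_pool': [n_gpus_per_node // 2] * nnodes,
    }
else:
    # 2+ nodes: assign whole nodes to each pool
    resource_pool_spec = {
        'actor_rollout_ref_pool': [n_gpus_per_node] * (nnodes // 2),
        'critic_pool': [n_gpus_per_node] * (nnodes // 2),
    }

print(resource_pool_spec)
# {'actor_rollout_ref_pool': [4], 'critic_pool': [4]}
```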
### Step 2: Make the models execute asynchronously
Based on the model placement, we need to make the models execute asynchronously.
To do so, you need to turn off the `blocking` flag (i.e., set `blocking=False`) in the decorator of the relevant model operations.
For example, if we want the actor update and critic update to run in parallel, we need to make the following modification in `fsdp_workers.py`:
```python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
def update_actor(self, data: DataProto):
    ...

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
def update_critic(self, data: DataProto):
    ...
```
We can also parallelize the computation of `ref_log_prob`, `values`, and `rewards` in the split placement. For simplicity of the tutorial, we only make the actor and critic updates asynchronous in this example.
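If you do want to go further, a hypothetical extension (not required for this tutorial) is to register the other rollout-phase worker methods, such as `compute_values` and `compute_ref_log_prob`, as non-blocking in the same way. The sketch below assumes these methods keep the dispatch mode they already use in the worker class; the driver would then defer the corresponding `get` calls until all of these futures have been issued.
```python
# Hypothetical extension: apply the same non-blocking registration to other
# worker methods in fsdp_workers.py. The dispatch mode shown here is an
# assumption; keep whatever dispatch mode each method already uses.
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
def compute_values(self, data: DataProto):
    ...

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
def compute_ref_log_prob(self, data: DataProto):
    ...
```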
### Step 3: Execute these operations in parallel in the single controller process
To implement the parallel execution of the actor and critic updates, the only modification needed in `ray_trainer.py` is to `get` the concurrent `futures` on the single controller process.
```python
critic_output = critic_output.get()
actor_output = actor_output.get()
```
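For context, here is a sketch of the control flow this enables in the driver loop (variable names follow the `fit` function shipped with this example): because `update_critic` and `update_actor` are now non-blocking, each call returns a `DataProtoFuture` immediately, so both updates are dispatched to their respective GPU pools before the driver blocks on either result.
```python
# Sketch of the driver-side flow with blocking=False (names follow this example's fit()).
critic_output = self.critic_wg.update_critic(batch)       # returns a DataProtoFuture immediately
actor_output = self.actor_rollout_wg.update_actor(batch)  # returns a DataProtoFuture immediately

# Both updates now run concurrently on their separate GPU pools;
# block only when the results are actually needed.
critic_output = critic_output.get()
actor_output = actor_output.get()
```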
### Step 4: Run the split placement example
```
bash run_deepseek7b_llm.sh
```
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
  val_batch_size: 1312
  return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/models/deepseek-llm-7b-chat
    external_lib: null
    override_config: {}
    enable_gradient_checkpointing: False
  actor:
    strategy: fsdp  # This is for backward-compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: 64
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    ppo_epochs: 1
    shuffle: True
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
      min_lr_ratio: null  # only useful for warmup with cosine
      warmup_style: constant  # select from constant/cosine
      total_training_steps: -1  # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      grad_offload: False
      optimizer_offload: False
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size: 128
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1  # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length}  # not used for opensource
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16  # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: 128
    # for hf rollout
    do_sample: True

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
    min_lr_ratio: null  # only useful for warmup with cosine
    warmup_style: constant  # select from constant/cosine
    total_training_steps: -1  # must be overridden by the program
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: {}
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: False
    fsdp_config:
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: 64
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    fsdp_config:
      min_num_params: 0
      param_offload: False
  micro_batch_size: 64
  max_length: null

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  kl_penalty: kl  # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  total_epochs: 30
  project_name: verl_examples
  experiment_name: gsm8k
  logger: ['console', 'tracking']
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  test_freq: 2
  critic_warmup: 0
  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
from verl import DataProto
import torch
from verl.utils.reward_score import gsm8k, math
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
def _select_rm_score_fn(data_source):
if data_source == 'openai/gsm8k':
return gsm8k.compute_score
elif data_source == 'lighteval/MATH':
return math.compute_score
else:
raise NotImplementedError
class RewardManager():

    def __init__(self, tokenizer, num_examine) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console

    def __call__(self, data: DataProto):
        """We will expand this function gradually based on the available datasets"""
        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
        if 'rm_scores' in data.batch.keys():
            return data.batch['rm_scores']

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch['prompts']
            prompt_length = prompt_ids.shape[-1]

            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch['responses']
            valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

            # select rm_score
            data_source = data_item.non_tensor_batch['data_source']
            compute_score_fn = _select_rm_score_fn(data_source)

            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
            reward_tensor[i, valid_response_length - 1] = score

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print(sequences_str)

        return reward_tensor
import ray
import hydra
from split_monkey_patch import fit


@hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None)
def main(config):
    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})

    ray.get(main_task.remote(config))


@ray.remote
def main_task(config):
    from verl.utils.fs import copy_local_path_from_hdfs
    from transformers import AutoTokenizer

    # print initial config
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    tokenizer = AutoTokenizer.from_pretrained(local_path)
    from verl.utils import set_pad_token_id
    set_pad_token_id(tokenizer)

    # define worker classes
    if config.actor_rollout_ref.actor.strategy == 'fsdp':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
        from single_controller.ray import RayWorkerGroup
        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == 'megatron':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
        from single_controller.ray.megatron import NVMegatronRayWorkerGroup
        ray_worker_group_cls = NVMegatronRayWorkerGroup

    else:
        raise NotImplementedError

    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ActorRolloutRefWorker,
        Role.Critic: CriticWorker,
        Role.RefPolicy: ActorRolloutRefWorker
    }

    # NOTE: initialize two resource pools
    actor_rollout_ref_pool_id = 'actor_rollout_ref_pool'
    critic_pool_id = 'critic_pool'
    if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0:
        resource_pool_spec = {
            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,
            critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes,
        }
    else:
        resource_pool_spec = {
            actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
            critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2),
        }
    print(f'resource_pool_spec: {resource_pool_spec}')
    mapping = {
        Role.ActorRollout: actor_rollout_ref_pool_id,
        Role.Critic: critic_pool_id,
        Role.RefPolicy: actor_rollout_ref_pool_id,
    }

    # we should adopt a multi-source reward function here
    # - for rule-based rm, we directly call a reward score
    # - for model-based rm, we call a model
    # - for code related prompt, we send to a sandbox if there are test cases
    # - finally, we combine all the rewards together
    # - The reward type depends on the tag of the data
    if config.reward_model.enable:
        if config.reward_model.strategy == 'fsdp':
            from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker
        elif config.reward_model.strategy == 'megatron':
            from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker
        else:
            raise NotImplementedError
        role_worker_mapping[Role.RewardModel] = RewardModelWorker
        mapping[Role.RewardModel] = critic_pool_id

    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)

    # Note that we always use function-based RM for validation
    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    RayPPOTrainer.fit = fit
    trainer = RayPPOTrainer(config=config,
                            tokenizer=tokenizer,
                            role_worker_mapping=role_worker_mapping,
                            resource_pool_manager=resource_pool_manager,
                            ray_worker_group_cls=ray_worker_group_cls,
                            reward_fn=reward_fn,
                            val_reward_fn=val_reward_fn)
    trainer.init_workers()
    trainer.fit()


if __name__ == '__main__':
    main()
set -x

python3 main_ppo_split.py \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.val_batch_size=1312 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size=16 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.grad_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tracking'] \
    trainer.project_name='verl_example_gsm8k' \
    trainer.experiment_name='deepseek_llm_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.total_epochs=15 $@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
An naive implementation of split placment example
"""
import os
from pprint import pprint
from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl import DataProto
from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls
from codetiming import Timer
def fit(self):
    """
    The training loop of PPO.
    The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """
    from verl.utils.tracking import Tracking
    from omegaconf import OmegaConf

    logger = Tracking(project_name=self.config.trainer.project_name,
                      experiment_name=self.config.trainer.experiment_name,
                      default_backend=self.config.trainer.logger,
                      config=OmegaConf.to_container(self.config, resolve=True))

    global_steps = 0

    # perform validation before training
    # currently, we only support validation using the reward_function.
    if self.val_reward_fn is not None:
        val_metrics = self._validate()
        pprint(f'Initial validation metrics: {val_metrics}')

    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            metrics = {}

            batch: DataProto = DataProto.from_single_dict(batch_dict)
            # batch = batch.to('cuda')

            # pop those keys for generation
            gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

            # generate a batch
            with Timer(name='gen', logger=None) as timer:
                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
            metrics['timing/gen'] = timer.last

            batch = batch.union(gen_batch_output)

            if self.use_reference_policy:
                # compute reference log_prob
                with Timer(name='ref', logger=None) as timer:
                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                    batch = batch.union(ref_log_prob)
                metrics['timing/ref'] = timer.last

            # compute values
            with Timer(name='values', logger=None) as timer:
                values = self.critic_wg.compute_values(batch)
                batch = batch.union(values)
            metrics['timing/values'] = timer.last

            with Timer(name='adv', logger=None) as timer:
                # compute scores. Support both model and function-based.
                # We first compute the scores using reward model. Then, we call reward_fn to combine
                # the results from reward model and rule-based results.
                if self.use_rm:
                    # we first compute reward model score
                    reward_tensor = self.rm_wg.compute_rm_score(batch)
                    batch = batch.union(reward_tensor)

                # we combine with rule-based rm
                reward_tensor = self.reward_fn(batch)
                batch.batch['token_level_scores'] = reward_tensor

                # compute rewards. apply_kl_penalty if available
                batch, kl_metrics = apply_kl_penalty(batch,
                                                     kl_ctrl=self.kl_ctrl,
                                                     kl_penalty=self.config.algorithm.kl_penalty)
                metrics.update(kl_metrics)

                # compute advantages, executed on the driver process
                batch = compute_advantage(batch,
                                          self.config.algorithm.gamma,
                                          self.config.algorithm.lam,
                                          adv_estimator=self.config.algorithm.adv_estimator)
            metrics['timing/adv'] = timer.last

            # update critic
            if self.use_critic:
                with Timer(name='update_critic_call', logger=None) as timer:
                    critic_output = self.critic_wg.update_critic(batch)
                metrics['timing/update_critic_call'] = timer.last

            # implement critic warmup
            if self.config.trainer.critic_warmup <= global_steps:
                # update actor
                with Timer(name='update_actor_call', logger=None) as timer:
                    actor_output = self.actor_rollout_wg.update_actor(batch)
                metrics['timing/update_actor_call'] = timer.last

            # NOTE: make sure you set blocking=False in update_actor and update_critic in the worker class
            with Timer(name='update_actor_critic', logger=None) as timer:
                # NOTE: get the DataProtoFuture
                critic_output = critic_output.get()
                critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                metrics.update(critic_output_metrics)

                # NOTE: get the DataProtoFuture
                actor_output = actor_output.get()
                actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                metrics.update(actor_output_metrics)
            metrics['timing/update_actor_critic'] = timer.last

            # validate
            if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
                with Timer(name='testing', logger=None) as timer:
                    val_metrics: dict = self._validate()
                    val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
                metrics['timing/testing'] = timer.last
                metrics.update(val_metrics)

            # collect metrics
            data_metrics = compute_data_metrics(batch=batch)
            metrics.update(data_metrics)

            # TODO: make a canonical logger that supports various backend
            logger.log(data=metrics, step=global_steps)

            if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
                actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
                                                f'global_step_{global_steps}')
                actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')
                self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

                if self.use_critic:
                    critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
                                                     f'global_step_{global_steps}')
                    critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
                    self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

            global_steps += 1

    # perform validation after training
    if self.val_reward_fn is not None:
        val_metrics = self._validate()
        pprint(f'Final validation metrics: {val_metrics}')
@@ -63,7 +63,10 @@ class ResourcePoolManager:
             # Due to the Ray issue, we can only support max_colocate_count=1 for now.
             # This means that each GPU can only have one process.
             # We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385
-            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1)
+            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
+                                            use_gpu=True,
+                                            max_colocate_count=1,
+                                            name_prefix=resource_pool_name)
             self.resource_pool_dict[resource_pool_name] = resource_pool
 
     def get_resource_pool(self, role: Role) -> RayResourcePool:
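One detail worth noting from the `ray_trainer.py` change above: since there are now two resource pools in the same job, each `RayResourcePool` is created with its pool name as `name_prefix`, presumably so the placement groups backing the two pools can be told apart. Below is a minimal sketch of how the two pools from Step 1 would be constructed under this constructor signature, assuming a running Ray cluster and the hypothetical 1-node, 8-GPU split used earlier.
```python
import ray
from single_controller.ray import RayResourcePool  # import path as used elsewhere in this example

ray.init()

# hypothetical split of a single 8-GPU node, as computed in Step 1
resource_pool_spec = {'actor_rollout_ref_pool': [4], 'critic_pool': [4]}

resource_pool_dict = {}
for pool_name, process_on_nodes in resource_pool_spec.items():
    # one pool per entry; max_colocate_count=1 means one process per GPU,
    # and a distinct name_prefix identifies each pool's placement groups
    resource_pool_dict[pool_name] = RayResourcePool(process_on_nodes=process_on_nodes,
                                                    use_gpu=True,
                                                    max_colocate_count=1,
                                                    name_prefix=pool_name)
```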