Unverified commit 818e4de2 by HL, committed by GitHub

megatron: fix config error and add compute log prob interface (#186)

parent fbc8fe82
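
The config part of this change swaps direct attribute access for .get() with a default wherever a key may be absent from the resolved trainer config on the Megatron path (ulysses_sequence_parallel_size, use_kl_loss, ppo_micro_batch_size). A minimal sketch of that pattern, using an illustrative OmegaConf config that is not part of this diff:

    from omegaconf import OmegaConf

    # Hypothetical stand-in for the actor section of the trainer config; a
    # Megatron run omits FSDP-only keys such as ulysses_sequence_parallel_size.
    actor_cfg = OmegaConf.create({"ppo_mini_batch_size": 256})

    # Old pattern: on the Hydra-loaded (struct-mode) config, attribute access
    # raises when the key is absent:
    #     sp_size = actor_cfg.ulysses_sequence_parallel_size
    # New pattern, mirroring the hunks below: read the key with a default.
    sp_size = actor_cfg.get("ulysses_sequence_parallel_size", 1)
    assert sp_size == 1
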
name: e2e_gsm8k_megatron

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_gsm8k_megatron.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_gsm8k_megatron.yml
      - "tests/e2e/*.sh"

jobs:
  e2e_gsm8k_megatron:
    runs-on: [self-hosted, l20-0]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test]
      - name: Prepare gsm8k dataset
        run: |
          python3 examples/data_preprocess/gsm8k.py
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron
        run: |
          ray stop --force
          [ ! -d "$HOME/Megatron-LM" ] && git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM $HOME/Megatron-LM
          export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
          bash tests/e2e/run_deepseek_megatron.sh
\ No newline at end of file
@@ -62,6 +62,7 @@ You can also get the Megatron code after verl's patch via
.. code:: bash

    git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM
    export PYTHONPATH=$PYTHONPATH:$(pwd)/Megatron-LM

Install from custom environment
---------------------------------
......
set -x

-python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\
+# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml
+python3 -m verl.trainer.main_ppo --config-path=config \
+    --config-name='ppo_megatron_trainer.yaml' \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
......
set -x

# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml

huggingface-cli download deepseek-ai/deepseek-coder-1.3b-instruct

python3 -m verl.trainer.main_ppo --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.val_batch_size=1312 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \
    actor_rollout_ref.actor.optim.lr=2e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
    critic.optim.lr=2e-5 \
    critic.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=4 \
    critic.megatron.tensor_model_parallel_size=2 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='verl_megatron_gsm8k_examples' \
    trainer.experiment_name='deepseek_llm_1b3_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=1 \
    trainer.total_epochs=15 \
    trainer.total_training_steps=3 $@
@@ -21,7 +21,7 @@ from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobal
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup

-# NOTE(sgm): for opensource megatron-core
+# NOTE(sgm): for open-source megatron-core
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """
    MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
......
@@ -637,7 +637,7 @@ class RayPPOTrainer(object):
                batch.batch['token_level_scores'] = reward_tensor

                # compute rewards. apply_kl_penalty if available
-                if not self.config.actor_rollout_ref.actor.use_kl_loss:
+                if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
                    batch, kl_metrics = apply_kl_penalty(batch,
                                                         kl_ctrl=self.kl_ctrl,
                                                         kl_penalty=self.config.algorithm.kl_penalty)
......
@@ -74,27 +74,27 @@ def validate_config(config):
        # ppo_mini_batch_size is divisible by ppo_micro_batch_size
        # ppo_micro_batch_size * sequence_parallel_size >= n_gpus
        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
-            sp_size = config.actor_rollout_ref.actor.ulysses_sequence_parallel_size
+            sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1)
            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
                assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus

        # critic
        if not config.critic.use_dynamic_bsz:
-            sp_size = config.critic.ulysses_sequence_parallel_size
+            sp_size = config.critic.get('ulysses_sequence_parallel_size', 1)
            if config.critic.ppo_micro_batch_size is not None:
                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

    # Check if use_remove_padding is enabled when using sequence parallelism for fsdp
    if config.actor_rollout_ref.actor.strategy == 'fsdp':
-        if config.actor_rollout_ref.actor.ulysses_sequence_parallel_size > 1 or \
-                config.actor_rollout_ref.ref.ulysses_sequence_parallel_size > 1:
+        if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \
+                config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1:
            assert config.actor_rollout_ref.model.use_remove_padding, \
                "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."

    if config.critic.strategy == 'fsdp':
-        if config.critic.ulysses_sequence_parallel_size > 1:
+        if config.critic.get('ulysses_sequence_parallel_size', 1) > 1:
            assert config.critic.model.use_remove_padding, \
                "When using sequence parallelism for critic, you must enable `use_remove_padding`."
......
@@ -107,6 +107,7 @@ class MegatronPPOActor(BasePPOActor):
        >>> actor_optimizer=actor_optimizer)
        """
        super().__init__(config)
+        self._validate_config(config)
        self.model_config = model_config
        self.megatron_config = megatron_config
        # self.megatron_args = get_args()
@@ -126,6 +127,10 @@ class MegatronPPOActor(BasePPOActor):
            'reduce_grads_use_alltoall': False
        })

+    def _validate_config(self, config) -> None:
+        """Validate config options not implemented for Megatron backend"""
+        assert config.get('ulysses_sequence_parallel_size', 1) == 1
+
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids
......
@@ -43,7 +43,7 @@ class MegatronPPOCritic(BasePPOCritic):
    def __init__(self, config, model_config, megatron_config, critic_module: nn.ModuleList,
                 critic_optimizer: DistributedOptimizer, critic_optimizer_config: OptimizerConfig):
        super().__init__(config=config)
+        self._validate_config(config)
        self.model_config = model_config
        self.megatron_config = megatron_config
@@ -74,6 +74,10 @@ class MegatronPPOCritic(BasePPOCritic):
        else:
            raise NotImplementedError

+    def _validate_config(self, config) -> None:
+        """Validate config options not implemented for Megatron backend"""
+        assert config.get('ulysses_sequence_parallel_size', 1) == 1
+
    def compute_values(self, data: DataProto) -> DataProto:
        # data.batch = data.batch.to(self.critic_module.module.device)
        responses = data.batch['responses']
......
@@ -112,7 +112,7 @@ class ActorRolloutRefWorker(MegatronWorker):
        # normalize config
        if self._is_actor and self._is_rollout:
            self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
-            if self.config.actor.ppo_micro_batch_size is not None:
+            if self.config.actor.get('ppo_micro_batch_size', None):
                self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
                self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
@@ -122,7 +122,7 @@ class ActorRolloutRefWorker(MegatronWorker):
            self._is_offload_grad = self.config.actor.get('grad_offload', False)
            self._is_offload_optimizer = self.config.actor.get('optimizer_offload', False)
        elif self._is_ref:
-            if self.config.ref.ppo_micro_batch_size is not None:
+            if self.config.ref.get('ppo_micro_batch_size', None):
                self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
                self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.ppo_micro_batch_size
            self._is_offload_param = self.config.ref.get('param_offload', False)
@@ -364,14 +364,6 @@ class ActorRolloutRefWorker(MegatronWorker):
        output = self.sharding_manager.postprocess_data(output)

        validate = prompts.meta_info.get('validate', False)
-        if self._is_actor and not validate:
-            # we should always recompute old_log_probs when it is HybridEngine
-            output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
-            output.meta_info['temperature'] = self.config.rollout.temperature
-            old_log_probs = self.actor.compute_log_prob(data=output)
-            output.batch['old_log_probs'] = old_log_probs
-
        output = output.to('cpu')

        # clear kv cache
        torch.cuda.empty_cache()
@@ -397,6 +389,22 @@ class ActorRolloutRefWorker(MegatronWorker):
        torch.cuda.empty_cache()
        return output

+    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
+    def compute_log_prob(self, data: DataProto):
+        assert self._is_actor
+        data = data.to('cuda')
+        output = data
+        # we should always recompute old_log_probs when it is HybridEngine
+        output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
+        output.meta_info['temperature'] = self.config.rollout.temperature
+        old_log_probs = self.actor.compute_log_prob(data=output)
+        output.batch['old_log_probs'] = old_log_probs
+        output = output.to('cpu')
+        # clear kv cache
+        torch.cuda.empty_cache()
+        log_gpu_memory_usage('After recompute log prob', logger=logger)
+        return output
+
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, checkpoint_path):
        pass
@@ -445,7 +453,7 @@ class CriticWorker(MegatronWorker):
        # normalize config
        self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
-        if self.config.ppo_micro_batch_size is not None:
+        if self.config.get('ppo_micro_batch_size', None):
            self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
......
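
For reference, the quantity the new compute_log_prob interface recomputes is the per-token log probability of each sampled response token under the current policy (with the rollout temperature applied to the logits inside the actual worker). A self-contained sketch of that computation with made-up shapes, independent of verl's DataProto plumbing:

    import torch
    import torch.nn.functional as F

    # Toy shapes for illustration only.
    batch_size, response_len, vocab_size, temperature = 2, 4, 16, 1.0
    response_ids = torch.randint(0, vocab_size, (batch_size, response_len))
    logits = torch.randn(batch_size, response_len, vocab_size)  # logits aligned with response tokens

    # Log-softmax over the vocabulary, then gather the log prob of each sampled token.
    log_probs = F.log_softmax(logits / temperature, dim=-1)  # (B, T, V)
    old_log_probs = torch.gather(log_probs, dim=-1, index=response_ids.unsqueeze(-1)).squeeze(-1)  # (B, T)
    print(old_log_probs.shape)  # torch.Size([2, 4])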