Commit 818e4de2 (Unverified)
authored Feb 02, 2025 by HL, committed by GitHub Feb 02, 2025
parent fbc8fe82

megatron: fix config error and add compute log prob interface (#186)

Showing 10 changed files with 132 additions and 20 deletions (+132 -20)
.github/workflows/e2e_gsm8k_megatron.yml       +50  -0
docs/start/install.rst                          +1  -0
examples/ppo_trainer/run_deepseek_megatron.sh   +4  -1
tests/e2e/run_deepseek_megatron.sh             +41  -0
verl/single_controller/ray/megatron.py          +1  -1
verl/trainer/ppo/ray_trainer.py                 +1  -1
verl/utils/config.py                            +5  -5
verl/workers/actor/megatron_actor.py            +5  -0
verl/workers/critic/megatron_critic.py          +5  -1
verl/workers/megatron_workers.py               +19 -11
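The recurring "config error" fix in this commit replaces direct attribute access on optional config keys (ulysses_sequence_parallel_size, use_kl_loss, ppo_micro_batch_size) with .get(key, default), so configs that simply omit those keys still validate. A minimal sketch of the pattern, assuming an OmegaConf/Hydra-style config (illustrative only, not verl code):

    from omegaconf import OmegaConf

    # A config that does not define the optional key at all.
    actor_cfg = OmegaConf.create({"ppo_mini_batch_size": 256})

    # Pattern used throughout this commit: fall back to a sensible default
    # (here, sequence-parallel size 1) instead of requiring the key in the YAML.
    sp_size = actor_cfg.get("ulysses_sequence_parallel_size", 1)
    print(sp_size)  # -> 1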
.github/workflows/e2e_gsm8k_megatron.yml (new file, mode 100644)

name: e2e_gsm8k_megatron

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_gsm8k_megatron.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_gsm8k_megatron.yml
      - "tests/e2e/*.sh"

jobs:
  e2e_gsm8k_megatron:
    runs-on: [self-hosted, l20-0]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test]
      - name: Prepare gsm8k dataset
        run: |
          python3 examples/data_preprocess/gsm8k.py
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron
        run: |
          ray stop --force
          [ ! -d "$HOME/Megatron-LM" ] && git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM $HOME/Megatron-LM
          export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
          bash tests/e2e/run_deepseek_megatron.sh
\ No newline at end of file
docs/start/install.rst

@@ -62,6 +62,7 @@ You can also get the Megatron code after verl's patch via

 .. code:: bash

     git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM
+    export PYTHONPATH=$PYTHONPATH:$(pwd)/Megatron-LM

 Install from custom environment
 ---------------------------------
 ...
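A quick, optional sanity check (illustrative only, not part of the commit) that the Megatron-LM checkout is importable once PYTHONPATH has been extended as the docs describe:

    import importlib.util

    # Prints True only if a 'megatron' package is visible on the current PYTHONPATH.
    print(importlib.util.find_spec("megatron") is not None)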
examples/ppo_trainer/run_deepseek_megatron.sh
View file @
818e4de2
set
-x
set
-x
python3
-m
verl.trainer.main_ppo
--config-path
=
./config
--config-name
=
'ppo_megatron_trainer'
\
# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml
python3
-m
verl.trainer.main_ppo
--config-path
=
config
\
--config-name
=
'ppo_megatron_trainer.yaml'
\
data.train_files
=
$HOME
/data/gsm8k/train.parquet
\
data.train_files
=
$HOME
/data/gsm8k/train.parquet
\
data.val_files
=
$HOME
/data/gsm8k/test.parquet
\
data.val_files
=
$HOME
/data/gsm8k/test.parquet
\
data.train_batch_size
=
1024
\
data.train_batch_size
=
1024
\
...
...
tests/e2e/run_deepseek_megatron.sh
0 → 100644
View file @
818e4de2
set
-x
# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml
huggingface-cli download deepseek-ai/deepseek-coder-1.3b-instruct
python3
-m
verl.trainer.main_ppo
--config-path
=
config
\
--config-name
=
'ppo_megatron_trainer.yaml'
\
data.train_files
=
$HOME
/data/gsm8k/train.parquet
\
data.val_files
=
$HOME
/data/gsm8k/test.parquet
\
data.train_batch_size
=
1024
\
data.val_batch_size
=
1312
\
data.max_prompt_length
=
512
\
data.max_response_length
=
512
\
actor_rollout_ref.model.path
=
deepseek-ai/deepseek-coder-1.3b-instruct
\
actor_rollout_ref.actor.optim.lr
=
2e-6
\
actor_rollout_ref.actor.ppo_mini_batch_size
=
256
\
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
=
4
\
actor_rollout_ref.actor.megatron.tensor_model_parallel_size
=
2
\
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu
=
8
\
actor_rollout_ref.rollout.tensor_model_parallel_size
=
2
\
actor_rollout_ref.rollout.name
=
vllm
\
actor_rollout_ref.rollout.gpu_memory_utilization
=
0.5
\
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu
=
16
\
actor_rollout_ref.ref.megatron.tensor_model_parallel_size
=
2
\
critic.optim.lr
=
2e-5
\
critic.model.path
=
deepseek-ai/deepseek-coder-1.3b-instruct
\
critic.model.enable_gradient_checkpointing
=
False
\
critic.ppo_micro_batch_size_per_gpu
=
4
\
critic.megatron.tensor_model_parallel_size
=
2
\
algorithm.kl_ctrl.kl_coef
=
0.001
\
trainer.critic_warmup
=
0
\
trainer.logger
=[
'console'
]
\
trainer.project_name
=
'verl_megatron_gsm8k_examples'
\
trainer.experiment_name
=
'deepseek_llm_1b3_function_rm'
\
trainer.n_gpus_per_node
=
8
\
trainer.nnodes
=
1
\
trainer.save_freq
=
-1
\
trainer.test_freq
=
1
\
trainer.total_epochs
=
15
\
trainer.total_training_steps
=
3
$@
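The dotted key=value arguments above are Hydra/OmegaConf-style overrides of ppo_megatron_trainer.yaml. A small illustrative sketch (not part of the commit) of how such a dotlist becomes a nested config:

    from omegaconf import OmegaConf

    overrides = [
        "data.train_batch_size=1024",
        "actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2",
        "trainer.total_training_steps=3",
    ]
    cfg = OmegaConf.from_dotlist(overrides)
    print(OmegaConf.to_yaml(cfg))  # nested data / actor_rollout_ref / trainer sections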
verl/single_controller/ray/megatron.py

@@ -21,7 +21,7 @@ from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobal
 from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup

-# NOTE(sgm): for opensource megatron-core
+# NOTE(sgm): for open-source megatron-core
 class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
     """
     MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
     ...
verl/trainer/ppo/ray_trainer.py

@@ -637,7 +637,7 @@ class RayPPOTrainer(object):
                 batch.batch['token_level_scores'] = reward_tensor

                 # compute rewards. apply_kl_penalty if available
-                if not self.config.actor_rollout_ref.actor.use_kl_loss:
+                if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
                     batch, kl_metrics = apply_kl_penalty(batch,
                                                          kl_ctrl=self.kl_ctrl,
                                                          kl_penalty=self.config.algorithm.kl_penalty)
                 ...
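For context, apply_kl_penalty folds a KL term between the policy and reference log-probs into the token-level scores. The sketch below is a simplified stand-in; the function name apply_kl_penalty_sketch and the plain log-prob-difference estimator are assumptions, not verl's implementation:

    import torch

    def apply_kl_penalty_sketch(token_level_scores, log_probs, ref_log_probs, kl_coef=0.001):
        # naive per-token KL estimate: log pi(a|s) - log pi_ref(a|s)
        kl = log_probs - ref_log_probs
        return token_level_scores - kl_coef * kl

    scores = torch.zeros(2, 4)
    penalized = apply_kl_penalty_sketch(scores, torch.randn(2, 4), torch.randn(2, 4))
    print(penalized.shape)  # torch.Size([2, 4])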
verl/utils/config.py

@@ -74,27 +74,27 @@ def validate_config(config):
     # ppo_mini_batch_size is divisible by ppo_micro_batch_size
     # ppo_micro_batch_size * sequence_parallel_size >= n_gpus
     if not config.actor_rollout_ref.actor.use_dynamic_bsz:
-        sp_size = config.actor_rollout_ref.actor.ulysses_sequence_parallel_size
+        sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1)
         if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
             assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
             assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus

     # critic
     if not config.critic.use_dynamic_bsz:
-        sp_size = config.critic.ulysses_sequence_parallel_size
+        sp_size = config.critic.get('ulysses_sequence_parallel_size', 1)
         if config.critic.ppo_micro_batch_size is not None:
             assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
             assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

     # Check if use_remove_padding is enabled when using sequence parallelism for fsdp
     if config.actor_rollout_ref.actor.strategy == 'fsdp':
-        if config.actor_rollout_ref.actor.ulysses_sequence_parallel_size > 1 or \
-                config.actor_rollout_ref.ref.ulysses_sequence_parallel_size > 1:
+        if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \
+                config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1:
             assert config.actor_rollout_ref.model.use_remove_padding, \
                 "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."

     if config.critic.strategy == 'fsdp':
-        if config.critic.ulysses_sequence_parallel_size > 1:
+        if config.critic.get('ulysses_sequence_parallel_size', 1) > 1:
             assert config.critic.model.use_remove_padding, \
                 "When using sequence parallelism for critic, you must enable `use_remove_padding`."
     ...
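The invariants validate_config enforces here read cleanly in isolation: the PPO mini batch must split evenly into micro batches, and the micro batch scaled by the sequence-parallel size must cover all GPUs. A standalone sketch with hypothetical numbers (not taken from the commit):

    def check_batch_sizes(ppo_mini_batch_size, ppo_micro_batch_size, sp_size, n_gpus):
        # mini batch is divisible by micro batch
        assert ppo_mini_batch_size % ppo_micro_batch_size == 0
        # enough micro-batch elements, after sequence parallelism, to keep every GPU busy
        assert ppo_micro_batch_size * sp_size >= n_gpus

    check_batch_sizes(ppo_mini_batch_size=256, ppo_micro_batch_size=8, sp_size=1, n_gpus=8)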
verl/workers/actor/megatron_actor.py

@@ -107,6 +107,7 @@ class MegatronPPOActor(BasePPOActor):
             >>> actor_optimizer=actor_optimizer)
         """
         super().__init__(config)
+        self._validate_config(config)
         self.model_config = model_config
         self.megatron_config = megatron_config
         # self.megatron_args = get_args()
 ...
@@ -126,6 +127,10 @@ class MegatronPPOActor(BasePPOActor):
             'reduce_grads_use_alltoall': False
         })

+    def _validate_config(self, config) -> None:
+        """Validate config options not implemented for Megatron backend"""
+        assert config.get('ulysses_sequence_parallel_size', 1) == 1
+
     def compute_log_prob(self, data: DataProto) -> torch.Tensor:
         """Compute the log probability of the responses given input_ids, attention_mask and position_ids
         ...
verl/workers/critic/megatron_critic.py

@@ -43,7 +43,7 @@ class MegatronPPOCritic(BasePPOCritic):
     def __init__(self, config, model_config, megatron_config, critic_module: nn.ModuleList,
                  critic_optimizer: DistributedOptimizer, critic_optimizer_config: OptimizerConfig):
         super().__init__(config=config)
+        self._validate_config(config)
         self.model_config = model_config
         self.megatron_config = megatron_config
 ...
@@ -74,6 +74,10 @@ class MegatronPPOCritic(BasePPOCritic):
         else:
             raise NotImplementedError

+    def _validate_config(self, config) -> None:
+        """Validate config options not implemented for Megatron backend"""
+        assert config.get('ulysses_sequence_parallel_size', 1) == 1
+
     def compute_values(self, data: DataProto) -> DataProto:
         # data.batch = data.batch.to(self.critic_module.module.device)
         responses = data.batch['responses']
         ...
verl/workers/megatron_workers.py

@@ -112,7 +112,7 @@ class ActorRolloutRefWorker(MegatronWorker):
         # normalize config
         if self._is_actor and self._is_rollout:
             self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
-            if self.config.actor.ppo_micro_batch_size is not None:
+            if self.config.actor.get('ppo_micro_batch_size', None):
                 self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
                 self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
                 self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
 ...
@@ -122,7 +122,7 @@ class ActorRolloutRefWorker(MegatronWorker):
             self._is_offload_grad = self.config.actor.get('grad_offload', False)
             self._is_offload_optimizer = self.config.actor.get('optimizer_offload', False)
         elif self._is_ref:
-            if self.config.ref.ppo_micro_batch_size is not None:
+            if self.config.ref.get('ppo_micro_batch_size', None):
                 self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
                 self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.ppo_micro_batch_size
             self._is_offload_param = self.config.ref.get('param_offload', False)
 ...
@@ -364,14 +364,6 @@ class ActorRolloutRefWorker(MegatronWorker):
         output = self.sharding_manager.postprocess_data(output)

-        validate = prompts.meta_info.get('validate', False)
-        if self._is_actor and not validate:
-            # we should always recompute old_log_probs when it is HybridEngine
-            output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
-            output.meta_info['temperature'] = self.config.rollout.temperature
-            old_log_probs = self.actor.compute_log_prob(data=output)
-            output.batch['old_log_probs'] = old_log_probs
-
         output = output.to('cpu')

         # clear kv cache
         torch.cuda.empty_cache()
 ...
@@ -397,6 +389,22 @@ class ActorRolloutRefWorker(MegatronWorker):
         torch.cuda.empty_cache()
         return output

+    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
+    def compute_log_prob(self, data: DataProto):
+        assert self._is_actor
+        data = data.to('cuda')
+        output = data
+        # we should always recompute old_log_probs when it is HybridEngine
+        output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
+        output.meta_info['temperature'] = self.config.rollout.temperature
+        old_log_probs = self.actor.compute_log_prob(data=output)
+        output.batch['old_log_probs'] = old_log_probs
+        output = output.to('cpu')
+        # clear kv cache
+        torch.cuda.empty_cache()
+        log_gpu_memory_usage('After recompute log prob', logger=logger)
+        return output
+
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def load_checkpoint(self, checkpoint_path):
         pass
 ...
@@ -445,7 +453,7 @@ class CriticWorker(MegatronWorker):
         # normalize config
         self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
-        if self.config.ppo_micro_batch_size is not None:
+        if self.config.get('ppo_micro_batch_size', None):
             self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
             self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
         ...
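The net effect of this file's changes: old-log-prob recomputation is removed from generate_sequences and exposed as a standalone compute_log_prob worker method, so the trainer decides when to recompute. A self-contained sketch of that control-flow split (stub classes only; none of these names are verl's real API):

    import torch

    class ActorStub:
        def compute_log_prob(self, data):
            # stand-in for the Megatron forward pass returning per-token log-probs
            return torch.randn(data["responses"].shape, dtype=torch.float32)

    class WorkerSketch:
        def __init__(self):
            self.actor = ActorStub()

        def generate_sequences(self, prompts):
            # rollout only; no log-prob recomputation here any more
            return {"responses": torch.randint(0, 100, (2, 4))}

        def compute_log_prob(self, data):
            data["old_log_probs"] = self.actor.compute_log_prob(data)
            return data

    worker = WorkerSketch()
    batch = worker.generate_sequences(prompts=None)
    batch = worker.compute_log_prob(batch)  # now an explicit, separate step
    print(batch["old_log_probs"].shape)     # torch.Size([2, 4])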