ZhangXiaoyun / verl / Commits

Commit b677a61e (unverified)
Authored Mar 02, 2025 by Weizhe Chen, committed by GitHub on Mar 02, 2025
rollout: FIRE sampling added. (#58)
Parent: 128781cf
Showing 7 changed files with 306 additions and 2 deletions.
.github/workflows/e2e_digit_completion_fire.yml         +46  -0
tests/e2e/run_ray_trainer_fire_sampling.sh               +40  -0
verl/trainer/config/ppo_trainer.yaml                      +1  -0
verl/workers/fsdp_workers.py                              +5  -1
verl/workers/rollout/vllm_rollout/__init__.py             +1  -0
verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py  +211  -0
verl/workers/rollout/vllm_rollout/vllm_rollout.py         +2  -1
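
The thread running through these files is FIRE sampling (Flaming-hot Initiation with Regular Execution, https://arxiv.org/abs/2410.21236): sample the first token of each response at a very high temperature with a top-k cap, then continue with the regular sampling parameters. As a rough orientation before the diffs, here is a minimal sketch of that idea; logits_fn (a stand-in for the model, mapping a token-id prefix to next-token logits) and both helper names are hypothetical and not part of this commit:

    import torch

    def sample_token(logits: torch.Tensor, temperature: float, top_k: int = -1) -> torch.Tensor:
        # temperature scaling, optional top-k truncation, then categorical sampling
        logits = logits / max(temperature, 1e-6)
        if top_k > 0:
            kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, float('-inf'))
        return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

    def fire_generate(logits_fn, prompt_ids, max_tokens):
        ids = list(prompt_ids)
        # phase 0: one "flaming-hot" token (temperature 30, top_k 16, mirroring kwargs_0 in fire_vllm_rollout.py below)
        ids.append(sample_token(logits_fn(ids), temperature=30.0, top_k=16).item())
        # phase 1: the remaining max_tokens - 1 tokens with the regular parameters
        for _ in range(max_tokens - 1):
            ids.append(sample_token(logits_fn(ids), temperature=1.0).item())
        return ids

In the actual implementation the two phases are two vLLM generate calls rather than a token-by-token loop.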
.github/workflows/e2e_digit_completion_fire.yml (new file, mode 100644)
name: e2e_digit_completion_fire

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_digit_completion_fire.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_digit_completion_fire.yml
      - "tests/e2e/*.sh"

# Declare permissions: just read content.
permissions:
  contents: read

jobs:
  e2e_digit_completion:
    runs-on: [self-hosted, l20-0]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test]
      - name: Running digit completion e2e training tests on 8 L20 GPUs
        run: |
          ray stop --force
          bash tests/e2e/run_ray_trainer_fire_sampling.sh
tests/e2e/run_ray_trainer_fire_sampling.sh (new file, mode 100644)
#!/usr/bin/env bash

set -e -x

OUTPUT_FILE="/tmp/output_ray_trainer.txt"
export PATH=$PATH:~/.local/bin

rm -rf $OUTPUT_FILE

python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
    data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
    data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
    data.train_batch_size=800 \
    data.val_batch_size=200 \
    data.max_prompt_length=16 \
    data.max_response_length=32 \
    data.return_raw_input_ids=True \
    actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
    actor_rollout_ref.model.external_lib=tests.e2e.envs.digit_completion \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=200 \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.optim.lr=1e-4 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=200 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=200 \
    actor_rollout_ref.rollout.name=hf \
    actor_rollout_ref.rollout.use_fire_sampling=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    critic.ppo_micro_batch_size_per_gpu=200 \
    critic.model.path=tests/e2e/arithmetic_sequence/model \
    critic.optim.lr=1e-3 \
    algorithm.kl_ctrl.kl_coef=0.005 \
    trainer.total_epochs=200 \
    trainer.experiment_name=arithmetic_sequences \
    trainer.logger=['console'] \
    trainer.n_gpus_per_node=1 \
    trainer.test_freq=1 \
    trainer.save_freq=110 | tee $OUTPUT_FILE;

python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE
rm -rf $OUTPUT_FILE
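
The dotted key=value arguments above are Hydra/OmegaConf-style overrides of verl/trainer/config/ppo_trainer.yaml; by the time the rollout worker runs, they have been merged into a nested config. A small illustration of how the FIRE-related overrides resolve, using OmegaConf directly (this is not how main_trainer.py parses them, just a sketch of the resulting structure):

    from omegaconf import OmegaConf

    overrides = OmegaConf.from_dotlist([
        "actor_rollout_ref.rollout.name=hf",
        "actor_rollout_ref.rollout.use_fire_sampling=True",
        "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
    ])
    print(OmegaConf.to_yaml(overrides))
    # actor_rollout_ref:
    #   rollout:
    #     name: hf
    #     use_fire_sampling: true
    #     tensor_model_parallel_size: 1

    # fsdp_workers.py (below) branches on exactly this flag:
    assert overrides.actor_rollout_ref.rollout.use_fire_sampling is True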
verl/trainer/config/ppo_trainer.yaml
@@ -64,6 +64,7 @@ actor_rollout_ref:
     temperature: 1.0
     top_k: -1 # 0 for hf rollout, -1 for vllm rollout
     top_p: 1
+    use_fire_sampling: False # https://arxiv.org/abs/2410.21236
     prompt_length: ${data.max_prompt_length}  # not use for opensource
     response_length: ${data.max_response_length} # for vllm rollout
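
The new key defaults to False, so FIRE sampling stays opt-in; fsdp_workers.py reads it as config.rollout.use_fire_sampling, and FIREvLLMRollout re-reads it defensively with a default. A minimal sketch of that lookup against the defaults above (OmegaConf is already a dependency of the rollout code):

    from omegaconf import OmegaConf

    rollout_cfg = OmegaConf.create({
        "temperature": 1.0,
        "top_k": -1,
        "top_p": 1,
        "use_fire_sampling": False,  # https://arxiv.org/abs/2410.21236
    })
    # mirrors `self.use_fire_sampling = config.get('use_fire_sampling', False)` in fire_vllm_rollout.py
    use_fire_sampling = rollout_cfg.get("use_fire_sampling", False)
    assert use_fire_sampling is False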
verl/workers/fsdp_workers.py
@@ -302,7 +302,11 @@ class ActorRolloutRefWorker(Worker):
             rollout_sharding_manager = BaseShardingManager()
             # TODO: a sharding manager that do nothing?
         elif self.config.rollout.name == 'vllm':
-            from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
+            if self.config.rollout.use_fire_sampling:
+                from verl.workers.rollout.vllm_rollout import FIREvLLMRollout as vLLMRollout
+                from verl.workers.rollout.vllm_rollout import vllm_mode
+            else:
+                from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
             from verl.workers.sharding_manager import FSDPVLLMShardingManager
             log_gpu_memory_usage('Before building vllm rollout', logger=None)
             local_path = copy_to_local(self.config.model.path)
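
The hunk above swaps FIREvLLMRollout in under the existing vLLMRollout name, so the construction code later in the method is untouched. Condensed into a standalone helper (hypothetical, not the worker's actual structure), the selection pattern looks like this:

    def select_rollout_cls(config):
        # Import the FIRE variant under the same name when the flag is set,
        # so downstream `vLLMRollout(...)` construction stays identical.
        if config.rollout.use_fire_sampling:
            from verl.workers.rollout.vllm_rollout import FIREvLLMRollout as vLLMRollout
        else:
            from verl.workers.rollout.vllm_rollout import vLLMRollout
        return vLLMRollout

The same aliasing keeps the FSDPVLLMShardingManager wiring below the imports unchanged.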
verl/workers/rollout/vllm_rollout/__init__.py
@@ -28,6 +28,7 @@ package_version = get_version(package_name)
 if package_version <= '0.6.3':
     vllm_mode = 'customized'
     from .vllm_rollout import vLLMRollout
+    from .fire_vllm_rollout import FIREvLLMRollout
 else:
     vllm_mode = 'spmd'
     from .vllm_rollout_spmd import vLLMRollout
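
Note that FIREvLLMRollout is only exported on the 'customized' path, i.e. for vllm <= 0.6.3; the SPMD rollout used with newer vLLM versions does not define it. Callers that set use_fire_sampling could guard for this, along these lines (a defensive sketch, not part of the commit):

    from verl.workers.rollout import vllm_rollout

    if getattr(vllm_rollout, "vllm_mode", None) == "customized":
        from verl.workers.rollout.vllm_rollout import FIREvLLMRollout
    else:
        raise NotImplementedError("use_fire_sampling currently requires the customized (vllm <= 0.6.3) rollout")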
verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied in different backends.
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
When working with Megatron:
- Use Megatron weight loader
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks hold all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that don't belong to this pp rank are freed.
"""
from typing import List
from contextlib import contextmanager
from omegaconf import DictConfig
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn

from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams

# TODO
# 1. support pp in vllm
# 2. passing tokenizer is not necessary? no encoding/decoding is happening here
# 3. simplify init logics


# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
    # remove the left padding in the prompt token_id
    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


class FIREvLLMRollout(vLLMRollout):

    def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
        """A vLLM rollout. It requires that the module is supported by vLLM.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig
            tokenizer: the task/model tokenizer
            model_hf_config: the huggingface config to initialize the generating model in vllm
            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
        """
        super().__init__(actor_module, config, tokenizer, model_hf_config, **kwargs)
        self.use_fire_sampling = config.get('use_fire_sampling', False)
        if self.use_fire_sampling:
            kwargs_0 = kwargs.copy()
            kwargs_0['temperature'] = 30
            kwargs_0['max_tokens'] = 1
            if 'top_k' not in kwargs_0 or kwargs_0['top_k'] <= 0:
                kwargs_0['top_k'] = 16
            kwargs['max_tokens'] -= 1
            self.sampling_params_0 = SamplingParams(**kwargs_0)

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # update sampling params
        old_sampling_params_args = {}
        if kwargs:
            for key, value in kwargs.items():
                if hasattr(self.sampling_params, key):
                    old_value = getattr(self.sampling_params, key)
                    old_sampling_params_args[key] = old_value
                    setattr(self.sampling_params, key, value)
        if self.use_fire_sampling:
            old_sampling_params_args_0 = {}
            if kwargs:
                for key, value in kwargs.items():
                    if hasattr(self.sampling_params_0, key):
                        old_value = getattr(self.sampling_params_0, key)
                        old_sampling_params_args_0[key] = old_value
                        setattr(self.sampling_params_0, key, value)
        yield
        # roll back to previous sampling params
        # if len(old_sampling_params_args):
        for key, value in old_sampling_params_args.items():
            setattr(self.sampling_params, key, value)
        if self.use_fire_sampling:
            for key, value in old_sampling_params_args_0.items():
                setattr(self.sampling_params_0, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # rebuild vllm cache engine
        if self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        # left-padded attention_mask
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids']

        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = idx.size(0)

        idx_list = []
        # parse idx from torch.Tensor to List[List[str]]
        for i in range(batch_size):
            idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))

        do_sample = prompts.meta_info.get('do_sample', True)
        if not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1  # if greedy, only 1 response
            }

        if not self.use_fire_sampling:
            # users can customize different sampling_params at different run
            with self.update_sampling_params(**kwargs):
                output = self.inference_engine.generate(
                    prompts=None,  # because we have already converted it to prompt token ids
                    sampling_params=self.sampling_params,
                    prompt_token_ids=idx_list,
                    use_tqdm=False)

            response = output[0].to(idx.device)  # (bs, response_length)
            log_probs = output[1].to(idx.device)  # (bs, response_length)
        else:
            with self.update_sampling_params(**kwargs):
                output_0 = self.inference_engine.generate(
                    prompts=None,  # because we have already converted it to prompt token ids
                    sampling_params=self.sampling_params_0,
                    prompt_token_ids=idx_list,
                    use_tqdm=False)
                new_idx_list = []
                for i in range(batch_size):
                    new_idx_list.append(idx_list[i] + output_0[0][i].tolist())
                output = self.inference_engine.generate(
                    prompts=None,  # because we have already converted it to prompt token ids
                    sampling_params=self.sampling_params,
                    prompt_token_ids=new_idx_list,
                    use_tqdm=False)

            response = torch.cat([output_0[0], output[0]], dim=1).to(idx.device)  # (bs, response_length)
            log_probs = torch.cat([output_0[1], output[1]], dim=1).to(idx.device)  # (bs, response_length)

        if response.shape[1] < self.config.response_length:
            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
            log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)

        if self.config.n > 1 and do_sample:
            idx = idx.repeat_interleave(self.config.n, dim=0)
            attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
            position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
            batch_size = batch_size * self.config.n
        seq = torch.cat([idx, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)

        response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                'prompts': idx,
                'responses': response,
                'input_ids': seq,  # here input_ids become the whole sentences
                # 'old_log_probs': log_probs, # we will recompute old log prob with actor
                'attention_mask': attention_mask,
                'position_ids': position_ids
            },
            batch_size=batch_size)

        # free vllm cache engine
        if self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch)
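
Reading generate_sequences above: in FIRE mode the engine is invoked twice, first with sampling_params_0 (max_tokens=1, temperature 30, top_k 16) and then with the regular params on the prompts extended by that one hot token; because __init__ decremented the regular max_tokens by one, the concatenated response keeps the same length budget as the non-FIRE path. A toy check of that shape bookkeeping with made-up tensors (no vLLM involved):

    import torch

    batch_size, response_length = 4, 32

    # phase-0 call: exactly one "hot" token per prompt
    out0_tokens = torch.randint(0, 1000, (batch_size, 1))
    out0_logprobs = torch.randn(batch_size, 1)

    # phase-1 call: the remaining response_length - 1 tokens (max_tokens was reduced by 1 in __init__)
    out1_tokens = torch.randint(0, 1000, (batch_size, response_length - 1))
    out1_logprobs = torch.randn(batch_size, response_length - 1)

    # mirrors torch.cat([output_0[...], output[...]], dim=1) in generate_sequences
    response = torch.cat([out0_tokens, out1_tokens], dim=1)
    log_probs = torch.cat([out0_logprobs, out1_logprobs], dim=1)
    assert response.shape == (batch_size, response_length)
    assert log_probs.shape == (batch_size, response_length)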
verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -238,4 +238,4 @@ class vLLMRollout(BaseRollout):
         if self.config.free_cache_engine:
             self.inference_engine.free_cache_engine()
 
-        return DataProto(batch=batch)
+        return DataProto(batch=batch)
\ No newline at end of file