Unverified Commit 333e6d62 by Junrong Lin Committed by GitHub

[rollout] feat: add SGLang as rollout engine to verl (#490)

#22. WIP, will add more details tomorrow :)

---------

Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
parent 3b18b0eb
name: e2e_sglang_gsm8k
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
- v0.2.x
paths:
- "**/*.py"
- .github/workflows/e2e_sglang_gsm8k.yml
pull_request:
branches:
- main
- v0.2.x
paths:
- "**/*.py"
- "verl/trainer/config/*.yaml"
- .github/workflows/e2e_sglang_gsm8k.yml
- "tests/e2e/*.sh"
# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
# Declare permissions just read content.
permissions:
contents: read
jobs:
e2e_sglang_gsm8k:
runs-on: [self-hosted, l20-1]
timeout-minutes: 40 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1"
HF_HUB_ENABLE_HF_TRANSFER: 1
container:
image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.3.post3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install hf_transfer
pip3 install -e .[test,gpu,sglang] --no-deps
- name: Prepare gsm8k dataset
run: |
ray stop --force
python3 examples/data_preprocess/gsm8k.py
- name: Running GSM8K e2e training tests on 8 L20 GPUs with rmpad, function-based RM, and checkpoint saving
run: |
ray stop --force
bash tests/e2e/run_qwen_gsm8k_function_rm.sh sglang
......@@ -93,6 +93,7 @@ celerybeat-schedule
# virtualenv
venv/
.venv/
ENV/
# Spyder project settings
......@@ -122,4 +123,5 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
# local logs
logs
log
\ No newline at end of file
log
outputs
......@@ -10,7 +10,7 @@ Requirements
verl supports various backends. Currently, the following configurations are available:
- **FSDP** and **Megatron-LM** (optional) for training.
- **vLLM** and **TGI** for rollout generation, **SGLang** support coming soon.
- **SGLang**, **vLLM** and **TGI** for rollout generation.
Training backends
------------------
......@@ -19,6 +19,25 @@ We recommend using **FSDP** backend to investigate, research and prototype diffe
For users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support Megatron-LM v0.4 [1]_. The guide for using Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`.
Install verl-SGLang from scratch
-------------------------------------
**SGLang largely supports the research and inference workload at xAI. For the verl-SGLang installation, ignore the version conflicts that pip reports with vllm. SGLang provides a native API for RLHF, so not a single line of code needs to be patched.**
The following steps provide a quick installation guide for verl-SGLang.
.. code:: bash
# Create a virtual environment and use uv for quick installation
python3 -m venv ~/.python/verl-sglang && source ~/.python/verl-sglang/bin/activate
python3 -m pip install --upgrade pip && python3 -m pip install --upgrade uv
# Install verl-SGLang
git clone https://github.com/volcengine/verl verl-sglang && cd verl-sglang
python3 -m uv pip install .
# Install the latest stable version of sglang with verl support; currently the latest version is 0.4.3.post3
# For SGLang installation, you can also refer to https://docs.sglang.ai/start/install.html
python3 -m uv pip install "sglang[all]==0.4.3.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
Install from docker image
-------------------------
......@@ -73,6 +92,7 @@ Image and tag: ``whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-
git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM
export PYTHONPATH=$PYTHONPATH:$(pwd)/Megatron-LM
Install from custom environment
---------------------------------
......
......@@ -57,6 +57,7 @@ test = [
]
prime = ["pyext"]
gpu = ["liger-kernel", "flash-attn"]
sglang = ["sglang[all]==0.4.3.post3"]
# URLs
[project.urls]
......
......@@ -17,5 +17,5 @@ ray[default]
tensordict<0.6
torchdata
transformers
vllm<=0.6.3
# vllm==0.6.3.post1
wandb
set -x
ENGINE=${1:-vllm}
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
......@@ -17,7 +17,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.name=$ENGINE \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
......@@ -36,5 +36,5 @@ python3 -m verl.trainer.main_ppo \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=1 \
trainer.default_local_dir=$HOME/ckpt/ \
trainer.total_training_steps=1 $@
trainer.default_local_dir=$HOME/$ENGINE/ckpt/ \
trainer.total_training_steps=1
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.entrypoints.verl_engine import VerlEngine
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GenerationConfig
from verl.utils.torch_functional import pad_sequence_to_length
def levenshtein(s1, s2):
m, n = len(s1), len(s2)
# Initialize matrix of zeros
dp = [[0] * (n + 1) for _ in range(m + 1)]
# Initialize first column and first row of the matrix
for i in range(m + 1):
dp[i][0] = i # Deletion from s1 to empty string
for j in range(n + 1):
dp[0][j] = j # Insertion to s1 from empty string
# Compute the Levenshtein distance matrix
for i in range(1, m + 1):
for j in range(1, n + 1):
cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match
dp[i][j] = min(
dp[i - 1][j] + 1, # Deletion
dp[i][j - 1] + 1, # Insertion
dp[i - 1][j - 1] + cost # Substitution
)
return dp[m][n]
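# Illustrative check (added example, not part of the original test): the classic
# "kitten" -> "sitting" pair requires three edits, so the distance should be 3.
assert levenshtein("kitten", "sitting") == 3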
def are_lists_similar(a, b):
if len(a) != len(b):
print("The lists are of different lengths.")
return False
total_length = 0
total_diff = 0
for s1, s2 in zip(a, b):
max_len = max(len(s1), len(s2))
total_length += max_len
diff = levenshtein(s1, s2)
total_diff += diff
print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n")
percentage_difference = (total_diff / total_length) * 100
print(f"Total difference: {percentage_difference:.2f}%")
return percentage_difference <= 10
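# Illustrative check (added example): one extra character over a 12-character string
# is about an 8.3% difference, which falls within the 10% tolerance.
assert are_lists_similar(["hello world"], ["hello world!"])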
def initialize_global_process_group(timeout_second=36000):
from datetime import timedelta
import torch.distributed
# NOTE: modified to pass backend=None so that both nccl and gloo process groups are created
# torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.distributed.is_initialized():
torch.cuda.set_device(local_rank)
return local_rank, rank, world_size
def test_sglang_spmd():
assert torch.cuda.device_count() >= 2, 'At least 2 GPUs are required to run tp+dp tests.'
initialize_global_process_group()
# fill rollout config
max_prompt_length = 16
max_response_length = 16
# Initialize model and tokenizer
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'Qwen/Qwen2-7B-Instruct'
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left')
preencode_prompts = [
"Who won the Champions League in 2019?",
"The founder of Apple is",
"What's your name",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
input_ids = prompts['input_ids']
attention_mask = prompts['attention_mask']
input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)
actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
actor_model.to(torch.bfloat16)
sampling_params = dict(n=1,
temperature=0,
top_p=1,
top_k=-1,
max_new_tokens=max_response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
skip_special_tokens=True,
spaces_between_special_tokens=True,
ignore_eos=False)
tensor_parallel_size = 4
device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
inference_device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
if k in os.environ:
del os.environ[k]
print('building sglang rollout engine')
llm = VerlEngine(model_path=local_model_path,
dtype="bfloat16",
mem_fraction_static=0.5,
device_mesh_cpu=inference_device_mesh_cpu["tp"],
base_gpu_id=0,
gpu_id_step=1)
llm.release_memory_occupation()
print("start generation")
input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
batch_size = input_ids.size(0)
generation_config = GenerationConfig(do_sample=False)
actor_model.cuda()
output = actor_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_response_length,
# max_length=max_length,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
generation_config=generation_config,
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=False) # may OOM when use_cache = True
seq = output.sequences
response = seq[:, max_prompt_length:]
hf_response_tokens = tokenizer.batch_decode(response)
print(f"hf response: {hf_response_tokens}")
print(f"{sampling_params=}")
idx_list = []
batch_size = input_ids.shape[0]
pad_token_id = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)
for i in range(batch_size):
idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params)
sglang_response_tokens = []
for output in outputs:
print(f"{output=}")
generated_text = output["text"]
sglang_response_tokens.append(generated_text)
print(f"sglang response: {sglang_response_tokens}")
assert are_lists_similar(hf_response_tokens, sglang_response_tokens), \
    "HF and SGLang responses differ by more than 10%"
print("Check Pass")
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
# remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
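# Illustrative check (added example): a left-padded prompt keeps only the tokens from
# the first non-pad position onward, e.g. [0, 0, 5, 6, 7] with pad_token_id=0 -> [5, 6, 7].
assert _pre_process_inputs(0, torch.tensor([0, 0, 5, 6, 7])) == [5, 6, 7]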
......@@ -465,6 +465,7 @@ def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
for key, user_defined_cls in cls_dict.items():
user_defined_cls = _unwrap_ray_remote(user_defined_cls)
# directly instantiate the class without remote
# in the worker class, e.g. <verl.single_controller.base.worker.Worker>, __init__ returns immediately when DISABLE_WORKER_INIT == 1
with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}):
self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()),
**init_args_dict[key].get('kwargs', {}))
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
......@@ -14,6 +14,7 @@
from importlib.metadata import version, PackageNotFoundError
from packaging import version as vs
from verl.utils.import_utils import is_sglang_available
def get_version(pkg):
......@@ -59,6 +60,7 @@ elif vs.parse(package_version) >= vs.parse('0.7.0'):
from vllm import LLM
from vllm.distributed import parallel_state
else:
raise ValueError(
f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
)
if not is_sglang_available():
raise ValueError(
f'vllm version {package_version} is not supported and SGLang is also not found. Currently supported vllm versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
)
......@@ -16,6 +16,7 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o
"""
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
import os
import ray
import hydra
......@@ -54,7 +55,9 @@ def main(config):
def run_ppo(config) -> None:
# TODO(linjunrong.ocss884): this env var is left to work around SGLang's conflict with ray device
# isolation; it will be resolved in the future
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={
......
......@@ -38,6 +38,15 @@ def is_vllm_available():
return False
@cache
def is_sglang_available():
try:
import sglang
return True
except ImportError:
return False
def import_external_libs(external_libs=None):
if external_libs is None:
return
......
......@@ -78,7 +78,7 @@ class ActorRolloutRefWorker(Worker):
self.config = config
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
torch.distributed.init_process_group()
# build device mesh for FSDP
world_size = torch.distributed.get_world_size()
......@@ -302,21 +302,18 @@ class ActorRolloutRefWorker(Worker):
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}'
rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])
if self.config.rollout.name == 'hf':
rollout_name = self.config.rollout.name
if rollout_name == 'hf':
from verl.workers.rollout import HFRollout
from verl.workers.sharding_manager import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
rollout_sharding_manager = BaseShardingManager()
# TODO: a sharding manager that does nothing?
elif self.config.rollout.name == 'vllm':
if self.config.rollout.use_fire_sampling:
from verl.workers.rollout.vllm_rollout import FIREvLLMRollout as vLLMRollout
from verl.workers.rollout.vllm_rollout import vllm_mode
else:
from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
elif rollout_name == 'vllm':
from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
from verl.workers.sharding_manager import FSDPVLLMShardingManager
log_gpu_memory_usage('Before building vllm rollout', logger=None)
log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None)
local_path = copy_to_local(self.config.model.path)
if vllm_mode == 'customized':
rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
......@@ -331,7 +328,7 @@ class ActorRolloutRefWorker(Worker):
device_mesh=rollout_device_mesh)
else:
raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
log_gpu_memory_usage('After building vllm rollout', logger=None)
log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = 'dummy_hf'
rollout_sharding_manager = FSDPVLLMShardingManager(module=self.actor_module_fsdp,
......@@ -341,6 +338,30 @@ class ActorRolloutRefWorker(Worker):
device_mesh=rollout_device_mesh)
log_gpu_memory_usage('After building sharding manager', logger=None)
elif rollout_name == 'sglang':
from verl.workers.rollout.sglang_rollout import SGLangRollout
# NOTE(linjunrong): Due to the recent fp8 support in SGLang, importing any symbol related to SGLang's model_runner now checks the CUDA device capability.
# However, due to verl's setup, the main Ray process cannot see any CUDA device, which can lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager, so we import it here via its absolute path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None)
rollout = SGLangRollout(actor_module=self.config.model.path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config)
log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = 'dummy_hf'
rollout_sharding_manager = FSDPSGLangShardingManager(module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params='hf' in self.config.rollout.load_format,
device_mesh=rollout_device_mesh)
log_gpu_memory_usage('After building sharding manager', logger=None)
return rollout, rollout_sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
......@@ -490,7 +511,6 @@ class ActorRolloutRefWorker(Worker):
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
log_gpu_memory_usage('After rollout generation', logger=logger)
output = self.rollout_sharding_manager.postprocess_data(output)
......
......@@ -813,4 +813,4 @@ class RewardModelWorker(MegatronWorker):
data.batch = data.batch.cuda()
output = self.rm.compute_reward(data)
output = output.to('cpu')
return output
return output
\ No newline at end of file
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sglang_rollout import SGLangRollout
......@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.import_utils import is_vllm_available, is_megatron_core_available
from verl.utils.import_utils import (
is_vllm_available,
is_sglang_available,
is_megatron_core_available,
)
from .base import BaseShardingManager
from .fsdp_ulysses import FSDPUlyssesShardingManager
......@@ -31,3 +35,13 @@ if is_vllm_available():
from .fsdp_vllm import FSDPVLLMShardingManager
else:
FSDPVLLMShardingManager = None
# NOTE(linjunrong): Due to the recent fp8 support in SGLang, importing any symbol related to SGLang's model_runner now checks the CUDA device capability.
# However, due to verl's setup, the main Ray process cannot see any CUDA device, which can lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import SGLangShardingManager; users need to import it via its absolute path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
# if is_sglang_available():
# from .fsdp.fsdp_sglang import FSDPSGLangShardingManager
# else:
# FSDPSGLangShardingManager = None
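# Example of the deferred absolute import described above (mirrors the usage in
# fsdp_workers.py; it must happen inside the worker, not at module import time):
#     from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager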
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
from torch.distributed.device_mesh import DeviceMesh
from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
from verl.utils.debug import log_gpu_memory_usage
from sglang.srt.entrypoints.verl_engine import VerlEngine
from .base import BaseShardingManager
from verl.third_party.sglang import parallel_state as sglang_ps
# from vllm.distributed import parallel_state as sglang_ps
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
class FSDPSGLangShardingManager(BaseShardingManager):
def __init__(self,
module: FSDP,
inference_engine: VerlEngine,
model_config,
full_params: bool = False,
device_mesh: DeviceMesh = None):
self.module = module
self.inference_engine = inference_engine
self.model_config = model_config
self.device_mesh = device_mesh
# Full params
self.full_params = full_params
if full_params:
FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig())
else:
FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig())
# Note that torch_random_states may be different on each dp rank
self.torch_random_states = torch.cuda.get_rng_state()
# get a random rng states
if self.device_mesh is not None:
gen_dp_rank = self.device_mesh['dp'].get_local_rank()
torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
else:
self.gen_random_states = None
def __enter__(self):
log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
params = self.module.state_dict()
log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
# Copy, not share memory
load_format = None if self.full_params else 'dtensor'
self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)
log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
del params
torch.cuda.empty_cache()
log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)
# TODO: offload FSDP model weights
# self.module.cpu()
# torch.cuda.empty_cache()
# if torch.distributed.get_rank() == 0:
# print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.gen_random_states)
def __exit__(self, exc_type, exc_value, traceback):
log_gpu_memory_usage('Before SGLang offload in sharding manager', logger=logger)
self.inference_engine.release_memory_occupation()
log_gpu_memory_usage('After SGLang offload in sharding manager', logger=logger)
# self.module.to('cuda')
# if torch.distributed.get_rank() == 0:
# print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
self.module.train()
# add empty cache after each compute
torch.cuda.empty_cache()
# restore random states
if self.device_mesh is not None:
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
def preprocess_data(self, data: DataProto) -> DataProto:
# TODO: Current impl doesn't consider FSDP with torch micro-dp
data.batch = allgather_dict_tensors(data.batch.contiguous(),
size=self.device_mesh["infer_tp"].mesh.size()[0],
group=self.device_mesh["infer_tp"].get_group(),
dim=0)
return data
def postprocess_data(self, data: DataProto) -> DataProto:
# TODO: Current impl doesn't consider FSDP with torch micro-dp
global_rank = self.device_mesh.get_rank()
tp_rank = self.device_mesh["infer_tp"].get_local_rank()
tp_size = self.device_mesh["infer_tp"].mesh.size()[0]
src_rank = global_rank // tp_size * tp_size
broadcast_dict_tensor(data.batch, src=src_rank, group=self.device_mesh["infer_tp"].get_group())
if tp_size > 1:
local_prompts = data.chunk(chunks=tp_size)
data = local_prompts[tp_rank]
return data
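# Usage sketch (illustrative, assuming `rollout` and `rollout_sharding_manager` were
# built as in ActorRolloutRefWorker above and that generation runs inside the
# sharding-manager context):
#
#     with rollout_sharding_manager:
#         prompts = rollout_sharding_manager.preprocess_data(prompts)
#         output = rollout.generate_sequences(prompts=prompts)
#         output = rollout_sharding_manager.postprocess_data(output)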