Unverified Commit f8b4d085 by ZSL98 Committed by GitHub

[testing][rollout] feat: support integration of vllm>=0.7.0 (spmd-version) (#209)

This PR aims to integrate vllm>=0.7.0 while preserving:
**Backward compatibility**: vLLM 0.3.1, 0.4.2, 0.5.4, and 0.6.3 are still
supported.
**Forward compatibility**: future versions of vLLM (>=0.7.0) will be
supported without manual maintenance for each new release.

The readme for this beta version is located at docs/README_vllm0.7.md,
where users can find the installation instructions and related features.
The readme is copied below.

---
# Readme for verl(vllm>=0.7) version
## Installation

Note: This version of veRL supports **FSDP** for training and **vLLM**
for rollout. (Megatron-LM is not supported yet.)

```
# Create the conda environment
conda create -n verl python==3.10
conda activate verl

# Install verl
git clone https://github.com/volcengine/verl.git
cd verl
pip3 install -e .
# Install vLLM>=0.7
pip3 install vllm==0.7.0
# Install flash-attn
pip3 install flash-attn --no-build-isolation

```

For existing stable vLLM versions (<=0.7.2), you also need to apply a few
small patches manually to the installed vLLM package
(/path/to/site-packages/vllm) after the steps above:

- vllm/distributed/parallel_state.py: Remove the assertion below:

```
if (world_size
        != tensor_model_parallel_size * pipeline_model_parallel_size):
    raise RuntimeError(
        f"world_size ({world_size}) is not equal to "
        f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
        f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

```

- vllm/executor/uniproc_executor.py: change `local_rank = rank` to
`local_rank = int(os.environ["LOCAL_RANK"])`
- vllm/model_executor/model_loader/weight_utils.py: remove the
`torch.cuda.empty_cache()` in `pt_weights_iterator`

These modifications have already been merged into the main branch of
vLLM. To avoid modifying these files manually, you can directly build
vLLM from source.

## Features

### Use CUDA graphs

After installation, the examples that use FSDP as the training backend can
be run. By default, `enforce_eager` is set to True, which disables CUDA
graphs. To enable CUDA graphs and the sleep mode of vLLM>=0.7, add the
following lines to the bash script:

```
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \

```
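For reference, here is a minimal sketch (not part of this PR) of what these two overrides roughly correspond to at the vLLM engine level in vLLM>=0.7; the model name, tensor-parallel size, and memory fraction below are placeholders for illustration:

```
# Illustrative only: how the rollout flags roughly map onto vLLM>=0.7 options.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-7B-Instruct",  # placeholder model for illustration
    enforce_eager=False,             # enforce_eager=False lets vLLM capture CUDA graphs
    enable_sleep_mode=True,          # needed for llm.sleep()/llm.wake_up()
    tensor_parallel_size=2,          # placeholder parallel size
    gpu_memory_utilization=0.6,      # placeholder memory fraction
)

# With free_cache_engine=False, verl keeps the cache engine alive and relies on
# sleep/wake instead of rebuilding it around each rollout phase.
llm.sleep(level=1)  # offload weights and release KV cache between rollouts
llm.wake_up()       # restore before the next generation phase
```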

For a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh,
the rollout generation time is 115 seconds with vLLM 0.6.3 and 85 seconds
with vLLM 0.7.0. With CUDA graphs enabled, the generation time is further
reduced to 62 seconds.

**Note:** Currently, if `n` is greater than 1 in `SamplingParams` in
vLLM>=0.7, there is a potential stability issue in rollout generation time
(some iterations see bursts in generation time). We are working with the
vLLM team to investigate it.
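For clarity, the knob in question is the `n` field of vLLM's `SamplingParams`; a minimal sketch with placeholder values:

```
# Hypothetical values, shown only to point at the setting the note refers to.
from vllm import SamplingParams

sampling_params = SamplingParams(
    n=4,             # n > 1: multiple samples per prompt, where the time bursts were observed
    temperature=1.0,
    top_p=1.0,
    max_tokens=512,  # placeholder response length
)
```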

### Other features in vLLM

1. **num_scheduler_steps>1:** not supported yet (weight loading has not
been aligned with `MultiStepModelRunner`)
2. **Prefix caching:** not supported yet (vLLM sleep mode does not
support prefix caching)
3. **Chunked prefill:** supported

---------

Co-authored-by: zhangshulai <zhangshulai@bytedance.com>
parent 63f75138
@@ -40,3 +40,8 @@ jobs:
      run: |
        cd tests/rollout
        torchrun --standalone --nnodes=1 --nproc_per_node=8 $(which pytest) -s test_vllm_hf_loader.py
+    - name: Test the latest vLLM
+      run: |
+        pip3 install --upgrade vllm
+        cd tests/rollout
+        torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_vllm_spmd.py
@@ -89,6 +89,9 @@ Checkout this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/ex
## Performance Tuning Guide
The performance is essential for on-policy RL algorithm. We write a detailed performance tuning guide to allow people tune the performance. See [here](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) for more details.
+## vLLM v0.7 testing version
+We have released a testing version of veRL that supports vLLM>=0.7.0. Please refer to [this document](https://github.com/volcengine/verl/docs/README_vllm0.7.md) for installation guide and more information.
## Contribution Guide
Contributions from the community are welcome!
...
# Readme for verl(vllm>=0.7) version
## Installation
Note: This version of veRL supports **FSDP** for training and **vLLM** for rollout. (Megatron-LM is not supported yet.)
```
# Create the conda environment
conda create -n verl python==3.10
conda activate verl
# Install verl
git clone https://github.com/volcengine/verl.git
cd verl
pip3 install -e .
# Install vLLM>=0.7
pip3 install "vllm>=0.7.0"
# Install flash-attn
pip3 install flash-attn --no-build-isolation
```
For existing stable vLLM versions (<=0.7.2), you also need to apply a few small patches manually to the installed vLLM package (/path/to/site-packages/vllm) after the steps above:
- vllm/distributed/parallel_state.py: Remove the assertion below:
```
if (world_size
        != tensor_model_parallel_size * pipeline_model_parallel_size):
    raise RuntimeError(
        f"world_size ({world_size}) is not equal to "
        f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
        f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
```
- vllm/executor/uniproc_executor.py: change `local_rank = rank` to `local_rank = int(os.environ["LOCAL_RANK"])`
- vllm/model_executor/model_loader/weight_utils.py: remove the `torch.cuda.empty_cache()` in `pt_weights_iterator`
These modifications have already been merged into the main branch of vLLM. To avoid modifying these files manually, you can directly build vLLM from source.
## Features
### Use CUDA graphs
After installation, the examples that use FSDP as the training backend can be run. By default, `enforce_eager` is set to True, which disables CUDA graphs. To enable CUDA graphs and the sleep mode of vLLM>=0.7, add the following lines to the bash script:
```
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
```
For a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 115 seconds with vLLM 0.6.3 and 85 seconds with vLLM 0.7.0. With CUDA graphs enabled, the generation time is further reduced to 62 seconds.
**Note:** Currently, if `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential stability issue in rollout generation time (some iterations see bursts in generation time). We are working with the vLLM team to investigate it.
### Other features in vLLM
1. **num_scheduler_steps>1:** not supported yet (weight loading has not been aligned with `MultiStepModelRunner`)
2. **Prefix caching:** not supported yet (vLLM sleep mode does not support prefix caching)
3. **Chunked prefill:** supported
\ No newline at end of file
@@ -40,7 +40,7 @@ dependencies = [
    "ray>=2.38",
    "tensordict",
    "transformers<4.48",
-    "vllm<=0.6.3",
+    "vllm<=0.7.3",
    "peft",
    "liger-kernel",
    "pylatexenc",
...
@@ -12,7 +12,7 @@ pybind11
ray>=2.38
tensordict<0.6
transformers<4.48
-vllm<=0.6.3
+vllm
wandb
liger-kernel
pylatexenc
...
@@ -23,6 +23,9 @@ from verl.third_party.vllm import LLM
from vllm import SamplingParams
+import time
+import torch.distributed as dist

def main():
    assert torch.cuda.is_available(), 'CUDA must be present to run FSDP vLLM example'
@@ -112,10 +115,25 @@ def main():
                  enforce_eager=True,
                  dtype='bfloat16',
                  load_format='dummy_dtensor',
-                  gpu_memory_utilization=0.1,
+                  gpu_memory_utilization=0.8,
                  trust_remote_code=True)
+    # Warmup iterations
+    for _ in range(10):
+        torch.cuda.synchronize()
+        llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor')
+        torch.cuda.synchronize()
+        dist.barrier()
+    start_time = time.time()
    llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor')
+    torch.cuda.synchronize()
+    dist.barrier()
+    end_time = time.time()
+    # Calculate elapsed time
+    elapsed_time = end_time - start_time
+    print(f"Time taken: {elapsed_time:.6f} seconds")

    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()
...
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import transformers
from vllm import LLM, SamplingParams
from verl.utils.model import update_model_config
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import GenerationConfig
from verl.utils.torch_functional import pad_sequence_to_length
def levenshtein(s1, s2):
    m, n = len(s1), len(s2)
    # Initialize matrix of zeros
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    # Initialize first column and first row of the matrix
    for i in range(m + 1):
        dp[i][0] = i  # Deletion from s1 to empty string
    for j in range(n + 1):
        dp[0][j] = j  # Insertion to s1 from empty string
    # Compute the Levenshtein distance matrix
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if s1[i - 1] == s2[j - 1] else 1  # No cost if characters match
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # Deletion
                dp[i][j - 1] + 1,  # Insertion
                dp[i - 1][j - 1] + cost  # Substitution
            )
    return dp[m][n]


def are_lists_similar(a, b):
    if len(a) != len(b):
        print("The lists are of different lengths.")
        return False

    total_length = 0
    total_diff = 0

    for s1, s2 in zip(a, b):
        max_len = max(len(s1), len(s2))
        total_length += max_len
        diff = levenshtein(s1, s2)
        total_diff += diff
        print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n")

    percentage_difference = (total_diff / total_length) * 100
    print(f"Total difference: {percentage_difference:.2f}%")

    return percentage_difference <= 10


def test_vllm_spmd():
    assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.'

    # fill rollout config
    max_prompt_length = 16
    max_response_length = 16

    # Initialize model and token
    local_cache_path = '~/.cache/verl/rlhf'
    local_cache_path = os.path.expanduser(local_cache_path)
    hdfs_path = 'Qwen/Qwen2-7B-Instruct'
    from verl.utils.fs import copy_local_path_from_hdfs
    local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left')

    preencode_prompts = [
        "Who won the Champions League in 2019?",
        "The founder of Apple is",
        "What's your name",
    ]
    tokenizer.pad_token = tokenizer.eos_token
    prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
    input_ids = prompts['input_ids']
    attention_mask = prompts['attention_mask']

    input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
    attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)

    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
    actor_model.to(torch.bfloat16)

    actor_model_config = AutoConfig.from_pretrained(local_model_path)

    temperature = 0
    top_p = 1
    kwargs = dict(n=1,
                  temperature=temperature,
                  top_p=top_p,
                  max_tokens=max_response_length,
                  logprobs=1,
                  ignore_eos=True)
    sampling_params = SamplingParams(**kwargs)

    tensor_parallel_size = 4
    llm = LLM(model=local_model_path,
              enable_sleep_mode=True,
              tensor_parallel_size=tensor_parallel_size,
              distributed_executor_backend="external_launcher",
              dtype='bfloat16',
              gpu_memory_utilization=0.5)

    print('start generation')
    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    batch_size = input_ids.size(0)

    generation_config = GenerationConfig(do_sample=False)
    actor_model.cuda()
    output = actor_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_response_length,
        # max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        generation_config=generation_config,
        # renormalize_logits=True,
        output_scores=False,  # this is potentially very large
        return_dict_in_generate=True,
        use_cache=False)  # may OOM when use_cache = True
    seq = output.sequences
    response = seq[:, max_prompt_length:]

    hf_response_tokens = tokenizer.batch_decode(response)
    print(f'hf response: {hf_response_tokens}')

    outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False)
    vllm_response_tokens = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        vllm_response_tokens.append(generated_text)

    print(f'vllm response: {vllm_response_tokens}')
    assert are_lists_similar(hf_response_tokens, vllm_response_tokens), \
        f'Strings differ more than 10%:\n'
    print('Check Pass')


# if __name__ == "__main__":
#     test_vllm_spmd()
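As a quick illustration of the acceptance criterion used above (this snippet is not part of the new test file), `are_lists_similar` passes as long as the aggregate character-level edit distance between the two generation lists stays within 10%:

```
# Toy example of the similarity check used by test_vllm_spmd.
hf_out = ["The capital of France is Paris.", "2 + 2 = 4"]
vllm_out = ["The capital of France is Paris!", "2 + 2 = 4"]
# 1 differing character out of 40 total -> 2.5% difference, so this passes.
print(are_lists_similar(hf_out, vllm_out))  # True
```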
@@ -13,6 +13,7 @@
# limitations under the License.
from importlib.metadata import version, PackageNotFoundError
+from packaging import version as vs

def get_version(pkg):
@@ -24,6 +25,7 @@ def get_version(pkg):
package_name = 'vllm'
package_version = get_version(package_name)
+vllm_version = None
if package_version == '0.3.1':
    vllm_version = '0.3.1'
@@ -45,7 +47,14 @@ elif package_version == '0.6.3':
    from .vllm_v_0_6_3.llm import LLM
    from .vllm_v_0_6_3.llm import LLMEngine
    from .vllm_v_0_6_3 import parallel_state
+elif vs.parse(package_version) >= vs.parse('0.6.6.post2.dev252+g8027a724'):
+    # From 0.6.6.post2 on, vllm supports SPMD inference
+    # See https://github.com/vllm-project/vllm/pull/12071
+    from vllm import LLM
+    from vllm.distributed import parallel_state
+    from .vllm_spmd.dtensor_weight_loaders import load_dtensor_weights
else:
    raise ValueError(
-        f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.'
+        f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
    )
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -225,6 +225,20 @@ def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[Tens
    return tensors.split(batch_size)

+def pad_2d_list_to_length(response, pad_token_id, max_length=None):
+    """
+    pad a 2D list (e.g. responses, logprobs) to a 2D tensor.
+    """
+    response_length = max(len(sub_list) for sub_list in response)
+    if max_length is not None and max_length > response_length:
+        target_length = max_length
+    else:
+        target_length = response_length
+    padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response]
+    tensor = torch.tensor(padded_response)
+    return tensor

def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
    """
    pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
...
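A brief usage sketch for the `pad_2d_list_to_length` helper added above (illustrative only, not part of the diff); the import path assumes the module shown in the hunk, verl.utils.torch_functional:

```
# Pads a ragged 2D list of token ids into a rectangular tensor, right-padded with pad_token_id.
from verl.utils.torch_functional import pad_2d_list_to_length

responses = [[5, 6, 7], [8], [9, 10]]
padded = pad_2d_list_to_length(responses, pad_token_id=0, max_length=4)
# padded:
# tensor([[ 5,  6,  7,  0],
#         [ 8,  0,  0,  0],
#         [ 9, 10,  0,  0]])
assert padded.shape == (3, 4)
```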
@@ -300,13 +300,23 @@ class ActorRolloutRefWorker(Worker):
            rollout_sharding_manager = BaseShardingManager()
            # TODO: a sharding manager that do nothing?
        elif self.config.rollout.name == 'vllm':
-            from verl.workers.rollout.vllm_rollout import vLLMRollout
+            from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
            from verl.workers.sharding_manager import FSDPVLLMShardingManager
            log_gpu_memory_usage('Before building vllm rollout', logger=None)
-            rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
-                                  config=self.config.rollout,
-                                  tokenizer=self.tokenizer,
-                                  model_hf_config=self.actor_model_config)
+            local_path = copy_local_path_from_hdfs(self.config.model.path)
+            if vllm_mode == 'customized':
+                rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
+                                      config=self.config.rollout,
+                                      tokenizer=self.tokenizer,
+                                      model_hf_config=self.actor_model_config)
+            elif vllm_mode == 'spmd':
+                rollout = vLLMRollout(model_path=local_path,
+                                      config=self.config.rollout,
+                                      tokenizer=self.tokenizer,
+                                      model_hf_config=self.actor_model_config,
+                                      device_mesh=rollout_device_mesh)
+            else:
+                raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
            log_gpu_memory_usage('After building vllm rollout', logger=None)
            if torch.distributed.get_world_size() == 1:
                self.config.rollout.load_format = 'dummy_hf'
...
@@ -223,7 +223,7 @@ class ActorRolloutRefWorker(MegatronWorker):
    def _build_rollout(self):
        if self.config.rollout.name == 'vllm':
-            from verl.workers.rollout.vllm_rollout import vLLMRollout
+            from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
            from verl.workers.sharding_manager import MegatronVLLMShardingManager
            from verl.utils.model import normalize_pp_vpp_params
@@ -247,6 +247,7 @@ class ActorRolloutRefWorker(MegatronWorker):
            params = normalize_pp_vpp_params(params=params,
                                             num_hidden_layers=self.actor_model_config.num_hidden_layers,
                                             layer_name='layers')
+            assert vllm_mode == 'customized', "Support for vllm>=0.7 for Megatron-LM backend has not been implemented yet."
            rollout = vLLMRollout(actor_module=params,
                                  config=self.config.rollout,
                                  tokenizer=self.tokenizer,
...
@@ -12,4 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .vllm_rollout import vLLMRollout
\ No newline at end of file
+from importlib.metadata import version, PackageNotFoundError
+
+
+def get_version(pkg):
+    try:
+        return version(pkg)
+    except PackageNotFoundError:
+        return None
+
+
+package_name = 'vllm'
+package_version = get_version(package_name)
+
+if package_version <= '0.6.3':
+    vllm_mode = 'customized'
+    from .vllm_rollout import vLLMRollout
+else:
+    vllm_mode = 'spmd'
+    from .vllm_rollout_spmd import vLLMRollout
@@ -24,6 +24,7 @@ from verl.third_party.vllm import parallel_state as vllm_ps
from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
from verl.utils.debug import log_gpu_memory_usage
+from verl.third_party.vllm import vllm_version

from .base import BaseShardingManager
@@ -72,7 +73,17 @@ class FSDPVLLMShardingManager(BaseShardingManager):
        log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
        # Copy, not share memory
        load_format = 'hf' if self.full_params else 'dtensor'
-        self.inference_engine.sync_model_weights(params, load_format=load_format)
+        if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
+            self.inference_engine.sync_model_weights(params, load_format=load_format)
+        else:
+            self.inference_engine.wake_up()
+            # TODO(ZSL): deal with 'hf' format
+            if load_format == 'dtensor':
+                from verl.third_party.vllm import load_dtensor_weights
+                load_dtensor_weights(
+                    params, self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model)
+            else:
+                raise NotImplementedError(f'load_format {load_format} not implemented')
        log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
        del params
@@ -92,7 +103,11 @@ class FSDPVLLMShardingManager(BaseShardingManager):
    def __exit__(self, exc_type, exc_value, traceback):
        log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger)
-        self.inference_engine.offload_model_weights()
+        # TODO(ZSL): check this
+        if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
+            self.inference_engine.offload_model_weights()
+        else:
+            self.inference_engine.sleep(level=1)
        log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger)
        # self.module.to('cuda')
@@ -111,18 +126,29 @@ class FSDPVLLMShardingManager(BaseShardingManager):
    def preprocess_data(self, data: DataProto) -> DataProto:
        # TODO: Current impl doesn't consider FSDP with torch micro-dp
-        data.batch = allgather_dict_tensors(data.batch.contiguous(),
-                                            size=vllm_ps.get_tensor_model_parallel_world_size(),
-                                            group=vllm_ps.get_tensor_model_parallel_group(),
-                                            dim=0)
+        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
+            data.batch = allgather_dict_tensors(data.batch.contiguous(),
+                                                size=vllm_ps.get_tensor_model_parallel_world_size(),
+                                                group=vllm_ps.get_tensor_model_parallel_group(),
+                                                dim=0)
+        else:
+            data.batch = allgather_dict_tensors(data.batch.contiguous(),
+                                                size=vllm_ps.get_tensor_model_parallel_world_size(),
+                                                group=vllm_ps.get_tensor_model_parallel_group().device_group,
+                                                dim=0)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        # TODO: Current impl doesn't consider FSDP with torch micro-dp
-        broadcast_dict_tensor(data.batch,
-                              src=vllm_ps.get_tensor_model_parallel_src_rank(),
-                              group=vllm_ps.get_tensor_model_parallel_group())
+        local_world_size = vllm_ps.get_tensor_model_parallel_world_size()
+        src_rank = (torch.distributed.get_rank() // local_world_size) * local_world_size
+        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
+            broadcast_dict_tensor(data.batch, src=src_rank, group=vllm_ps.get_tensor_model_parallel_group())
+        else:
+            broadcast_dict_tensor(data.batch,
+                                  src=src_rank,
+                                  group=vllm_ps.get_tensor_model_parallel_group().device_group)
        dp_rank = torch.distributed.get_rank()
        dp_size = torch.distributed.get_world_size()  # not consider torch micro-dp
        tp_size = vllm_ps.get_tensor_model_parallel_world_size()
...