Commit 98fc5c65 by ZhangXiaoyun

modify

parents 8dd09ef8 612823ae
@@ -55,7 +55,7 @@ jobs:
          torchrun --standalone --nnodes=1 --nproc_per_node=8 $(which pytest) -s test_vllm_hf_loader.py
      - name: Test the latest vLLM
        run: |
-         pip3 install --upgrade vllm
+         pip3 install --upgrade vllm==0.7.3
          cd tests/rollout
          torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_vllm_spmd.py
      - name: Run Qwen 0.5B generation test
...
To be completed
\ No newline at end of file
# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
FROM nvcr.io/nvidia/pytorch:24.08-py3
# uninstall nv-pytorch fork
RUN pip3 uninstall -y pytorch-quantization \
pytorch-triton torch torch-tensorrt torchvision \
xgboost transformer_engine flash_attn apex megatron-core
# Define environments
ENV MAX_JOBS=32
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
ENV DEBIAN_FRONTEND=noninteractive
ENV NODE_OPTIONS=""
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
# Define installation arguments
ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
# Set apt source
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
{ \
echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
} > /etc/apt/sources.list
# Install systemctl
RUN apt-get update && \
apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
apt-get clean
# Install tini
RUN apt-get update && \
apt-get install -y tini && \
apt-get clean
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
python -m pip install --upgrade pip
# Install torch-2.6.0 + vllm-0.8.1
RUN pip install --no-cache-dir vllm==0.8.1 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \
    "transformers>=4.49.0" accelerate datasets peft hf-transfer \
    ray codetiming hydra-core pandas "pyarrow>=15.0.0" pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \
    pytest yapf py-spy pyext pre-commit ruff
# Install flash_attn-2.7.4.post1
RUN pip uninstall -y transformer-engine flash-attn && \
wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# Fix cv2
RUN pip uninstall -y pynvml nvidia-ml-py && \
    pip install --no-cache-dir "nvidia-ml-py>=12.560.30" opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \
    pip install -U "optree>=0.13.0"
# Upgrading to vLLM >= 0.8
## Installation
Note: veRL with vLLM 0.8+ currently supports **FSDP** as the training backend and **vLLM** for rollout.
```bash
# Create the conda environment
conda create -n verl python=3.10
conda activate verl
# Install verl
git clone https://github.com/volcengine/verl.git
cd verl
pip3 install -e .
# Install vLLM 0.8.1 (the latest stable release at the time of writing)
pip3 install vllm==0.8.1
# Install flash-attn
pip3 install flash-attn --no-build-isolation
```
We provide a pre-built Docker image for veRL with vLLM 0.8.0. You can pull it directly with the following command:
```bash
docker pull hiyouga/verl:ngc-th2.6.0-cu120-vllm0.8.0
```
## Features
vLLM 0.8+ supports CUDA graphs and the V1 engine by default in veRL. To enable these features, add the following options to your launch script:
```bash
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
```
and also **remove** (e.g. with `unset`) the following environment variable if it is set:
```bash
export VLLM_ATTENTION_BACKEND=XFORMERS
```
@@ -74,6 +74,7 @@ verl is fast with:
   perf/perf_tuning
   README_vllm0.7.md
+  README_vllm0.8.md

.. toctree::
   :maxdepth: 1
...
@@ -35,7 +35,6 @@ dependencies = [
    "datasets",
    "dill",
    "hydra-core",
-   "math-verify",
    "numpy",
    "pandas",
    "peft",
@@ -58,6 +57,7 @@ test = [
prime = ["pyext"]
gpu = ["liger-kernel", "flash-attn"]
sglang = ["sglang[all]==0.4.3.post3"]
+math = ["math-verify"]  # Add math-verify as an optional dependency

# URLs
[project.urls]
...
@@ -6,7 +6,6 @@ dill
# flash-attn
hydra-core
liger-kernel
-math-verify[antlr4_9_3]
numpy
pandas
peft
...
@@ -27,7 +27,6 @@ install_requires = [
    'datasets',
    'dill',
    'hydra-core',
-   'math-verify',
    'numpy',
    'pandas',
    'peft',
@@ -46,12 +45,14 @@ TEST_REQUIRES = ['pytest', 'yapf', 'py-spy']
PRIME_REQUIRES = ['pyext']
GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']
+MATH_REQUIRES = ['math-verify']  # Add math-verify as an optional dependency

extras_require = {
    'test': TEST_REQUIRES,
    'prime': PRIME_REQUIRES,
    'geo': GEO_REQUIRES,
    'gpu': GPU_REQUIRES,
+    'math': MATH_REQUIRES,
}

from pathlib import Path
...
@@ -24,7 +24,8 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
-from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
+from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \
+    get_ulysses_sequence_parallel_world_size, validate_ulysses_config

logger = logging.get_logger(__name__)
@@ -68,6 +69,8 @@ def llama_flash_attn_forward(
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
    if ulysses_sp_size > 1:
+        validate_ulysses_config(self.num_heads, ulysses_sp_size)
+
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
@@ -177,6 +180,8 @@ def llama_attn_forward(
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
    if ulysses_sp_size > 1:
+        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)
+
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
...
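The inline comments above describe the Ulysses all-to-all: each rank trades a full set of heads over a sequence shard for a shard of heads over the full sequence. A plain-tensor shape check (a sketch of the layout arithmetic only, not verl's `gather_seq_scatter_heads` implementation) makes the trade explicit:

```python
import torch

# Sketch: Ulysses layout before/after the all-to-all, with n = sp world size.
bsz, n_head, seq_len, head_dim, n = 2, 8, 16, 4, 4

before = torch.randn(bsz, n_head, seq_len // n, head_dim)  # (bsz, n_head, seq_len/n, head_dim)
after_shape = (bsz, n_head // n, seq_len, head_dim)        # (bsz, n_head/n, seq_len, head_dim)

# The exchange only re-shards existing elements, so element counts must match,
# and it is only well-defined when n_head % n == 0 -- exactly the condition
# that the new validate_ulysses_config helper asserts.
assert before.numel() == torch.Size(after_shape).numel()
assert n_head % n == 0
```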
@@ -20,7 +20,8 @@ from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.processing_utils import Unpack
-from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
+from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \
+    get_ulysses_sequence_parallel_world_size, validate_ulysses_config

logger = logging.get_logger(__name__)
@@ -55,6 +56,8 @@ def qwen2_flash_attn_forward(
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
    if ulysses_sp_size > 1:
+        validate_ulysses_config(self.num_heads, ulysses_sp_size)
+
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
@@ -168,6 +171,8 @@ def qwen2_attn_forward(
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
    if ulysses_sp_size > 1:
+        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)
+
        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
...
@@ -18,7 +18,8 @@ import torch
import os
from transformers.utils import is_flash_attn_greater_or_equal
from transformers.modeling_flash_attention_utils import _flash_attention_forward
-from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
+from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \
+    get_ulysses_sequence_parallel_world_size, validate_ulysses_config

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -236,6 +237,8 @@ def ulysses_flash_attn_forward(
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
    if ulysses_sp_size > 1:
+        validate_ulysses_config(self.num_heads, ulysses_sp_size)
+
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
...
@@ -51,6 +51,10 @@ def get_custom_reward_fn(config):

@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
def main(config):
+    run_ppo(config)
+
+
+def run_ppo(config) -> None:
    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
    # isolation, will solve in the future
    os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '')
...
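Splitting `run_ppo` out of the hydra-decorated `main` lets callers launch training without going through the CLI. Below is a minimal sketch of such a programmatic call using hydra's compose API; the `config` path assumes verl's `verl/trainer/config/ppo_trainer.yaml` layout and is resolved relative to the calling module:

```python
from hydra import compose, initialize

# Build the same config object that `@hydra.main` would construct, then hand
# it to run_ppo directly (uncomment once running inside the verl repo).
with initialize(config_path="config", version_base=None):
    cfg = compose(config_name="ppo_trainer")
    # from verl.trainer.main_ppo import run_ppo
    # run_ppo(cfg)
```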
@@ -19,12 +19,16 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N
        from . import gsm8k
        res = gsm8k.compute_score(solution_str, ground_truth)
    elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']:
-        # from . import math
-        # res = math.compute_score(solution_str, ground_truth)
-        # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy
-        from . import math_verify
-        res = math_verify.compute_score(solution_str, ground_truth)
+        from . import math
+        res = math.compute_score(solution_str, ground_truth)
+        # [Optional] Math-Verify Integration
+        # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).
+        # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.
+        # To use it, override the `compute_score` function with the following implementation:
+        # from . import math_verify
+        # res = math_verify.compute_score(solution_str, ground_truth)
    elif data_source in [
        'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
        'numina_olympiads'
...
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from math_verify.metric import math_metric
-from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
+try:
+    from math_verify.metric import math_metric
+    from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
+except ImportError:
+    print("To use Math-Verify, please install it first by running `pip install math-verify`.")

def compute_score(model_output: str, ground_truth: str) -> bool:
@@ -28,6 +31,6 @@ def compute_score(model_output: str, ground_truth: str) -> bool:
    try:
        ret_score, _ = verify_func([ground_truth_boxed], [model_output])
    except Exception as e:
-        print(e)
+        pass

    return ret_score
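With the optional dependency installed, the scorer can be exercised directly. A small usage sketch, assuming the `verl.utils.reward_score.math_verify` module path implied by the dispatch hunk above; the boxed answer string is illustrative:

```python
# Requires `pip install math-verify`.
from verl.utils.reward_score.math_verify import compute_score

# Expected to score as a match (truthy) when the two expressions agree.
print(compute_score("The answer is \\boxed{\\frac{1}{2}}.", "1/2"))
```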
@@ -58,7 +58,10 @@ class Tracking(object):
            swanlab.login(SWANLAB_API_KEY)  # NOTE: previous login information will be overwritten
            swanlab.init(project=project_name,
                         experiment_name=experiment_name,
-                        config=config,
+                        config={
+                            "FRAMEWORK": "veRL",
+                            **config
+                        },
                         logdir=SWANLAB_LOG_DIR,
                         mode=SWANLAB_MODE)
            self.logger["swanlab"] = swanlab
...
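The change above just prepends a framework tag to the logged config. One subtlety worth noting (a two-line sketch, not verl code): later keys win in a dict literal, so a `FRAMEWORK` key already present in `config` would override the tag.

```python
config = {"lr": 1e-6}                        # stand-in for the trainer config
merged = {"FRAMEWORK": "veRL", **config}     # -> {'FRAMEWORK': 'veRL', 'lr': 1e-06}

# Later keys override earlier ones in dict literals:
assert {**{"FRAMEWORK": "veRL"}, **{"FRAMEWORK": "other"}}["FRAMEWORK"] == "other"
```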
@@ -286,3 +286,9 @@ def ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor,
    # we don't need to slice position ids
    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
    return input_ids_rmpad, position_ids_rmpad, pad_size
+
+
+def validate_ulysses_config(num_heads, ulysses_sequence_size):
+    if ulysses_sequence_size > 1:
+        assert num_heads % ulysses_sequence_size == 0, \
+            f"num_heads ({num_heads}) must be divisible by ulysses sequence size ({ulysses_sequence_size})"
@@ -103,12 +103,6 @@ class SGLangRollout(BaseRollout):
        super().__init__()
        self.config = config
-        # TODO(linjunrong.ocss884): this substitution is left for resolving SGLang conflict with ray devices
-        # isolation, will solve in the future
-        del os.environ["CUDA_VISIBLE_DEVICES"]
-        if os.environ["ENSURE_CUDA_VISIBLE_DEVICES"]:
-            os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["ENSURE_CUDA_VISIBLE_DEVICES"]

        assert not (not config.enforce_eager and
                    config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"
@@ -142,16 +136,17 @@ class SGLangRollout(BaseRollout):
        # device_mesh_device = init_device_mesh("cuda", **device_mesh_kwargs)

        # get tp_rank of this process in this tp group
-        global_rank = device_mesh_cpu.get_rank()
-        tp_size = device_mesh_cpu["tp"].mesh.size()[0]
-        src_rank = global_rank // tp_size * tp_size
+        visible_devices = [None] * device_mesh_cpu.size(1)
+        torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"],
+                                            device_mesh_cpu.get_group("tp"))
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible_devices)

        self.inference_engine = VerlEngine(
            model_path=actor_module,
            dtype=config.dtype,
            mem_fraction_static=config.gpu_memory_utilization,
            device_mesh_cpu=device_mesh_cpu["tp"],
-            base_gpu_id=src_rank,
+            base_gpu_id=0,
            gpu_id_step=1,
            # NOTE(Chenyang): if you want to debug the sglang engine
            # please set the following parameters
...
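The replacement drops the old `src_rank` arithmetic: every rank gathers the per-rank `CUDA_VISIBLE_DEVICES` strings across its tp group and adopts the union, so `VerlEngine` can always start from `base_gpu_id=0`. A standalone sketch of that gather pattern (the gloo backend and per-rank device strings are stand-ins for the real Ray/SGLang setup):

```python
# Run with: torchrun --nproc_per_node=2 gather_visible_devices.py
import os

import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Stand-in: pretend each rank was pinned to the single GPU id equal to its rank.
local_visible = os.environ.get("CUDA_VISIBLE_DEVICES", str(rank))

gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local_visible)  # every rank receives every string, rank-ordered

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gathered)
print(f"rank {rank}: CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")

dist.destroy_process_group()
```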