Unverified Commit b46f55ec by hoshi-hiyouga Committed by GitHub

[feat] Initial support for VLMs, add Qwen2.5VL GRPO example (#386)

## What does this PR do?

This PR migrates the RL-on-VLMs feature from our
[EasyR1](https://github.com/hiyouga/EasyR1) fork back to veRL. We have
validated it with the Qwen2.5-VL 7B model on 8×H100 GPUs. The
configuration and data-processing script are provided with this PR for
easy reproduction.

## How to reproduce?

1. Download and preprocess the dataset

```bash
python3 examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k
```

2. Start GRPO training

```bash
bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh
```

## Dependencies

- vllm>=0.7.3
- transformers>=4.49.0
- [qwen-vl-utils](https://pypi.org/project/qwen-vl-utils/)
- [mathruler](https://pypi.org/project/mathruler/)
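
If helpful, here is a quick way to check that these packages are installed at the required versions (an illustrative snippet, not part of this PR):

```python
# Optional sanity check for the dependencies listed above.
import importlib.metadata as metadata

for package, minimum in [("vllm", "0.7.3"), ("transformers", "4.49.0"),
                         ("qwen-vl-utils", None), ("mathruler", None)]:
    try:
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        version = "not installed"
    suffix = f" (requires >= {minimum})" if minimum else ""
    print(f"{package}: {version}{suffix}")
```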

## Major Changes

### New dataflow for multimodal RL

In this PR, we introduce two new concepts in the dataflow:
`multi_modal_data` and `multi_modal_inputs`. The former holds the
multi-modal features required by the **rollout** worker (such as vLLM),
while the latter holds the multi-modal features required by the
**actor/critic** worker (such as an HF model). They are kept separate
because the rollout and actor workers have different data-format
requirements.

Taking Qwen2-VL + Hugging Face + vLLM as an example, the data structures are:

- **multi_modal_data**: `{"image": [PIL.Image, PIL.Image, ...]}`
- **multi_modal_inputs**: `{"pixel_values": torch.Tensor, "image_grid_thw": torch.Tensor}`

Both of them are converted to numpy object arrays and placed in the non-tensor
batch of `DataProto`.

This design can easily be extended to other modalities and VLMs because it is
model-agnostic.
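
To make the split concrete, here is a minimal sketch (not the exact veRL code) of how one sample's multimodal fields could be built; the checkpoint name and image path are placeholders:

```python
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
images = [Image.open("example.png").convert("RGB")]

# Rollout worker (vLLM) consumes the raw PIL images.
multi_modal_data = {"image": images}

# Actor/critic worker (HF model) consumes the processor outputs.
image_inputs = processor.image_processor(images, return_tensors="pt")
multi_modal_inputs = {key: val for key, val in image_inputs.items()}  # pixel_values, image_grid_thw

# Both are stored per-sample as numpy object arrays in the non-tensor batch of DataProto.
non_tensor_batch = {
    "multi_modal_data": np.array([multi_modal_data], dtype=object),
    "multi_modal_inputs": np.array([multi_modal_inputs], dtype=object),
}
```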

### Other changes

- Data
  - Support preprocessing the [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k) dataset.
  - Support `config.data.image_key`; the dataset field it points to should contain **a list of Pillow images**.

- Actor/Ref/Critic
  - Support `multi_modal_inputs`.
  - Process position ids to adapt to the m-rope.

- Rollout
  - Update the dtensor weight loader to adapt to the Qwen2-VL architecture in vLLM 0.7+.
  - Support `multi_modal_data`.
  - Use `raw_prompt_ids` as the vLLM inputs to **avoid unpadding** the input ids.

- Reward Manager
  - Add **mathruler** for more accurate math scores on the Geometry3k dataset.

- Models
  - Support calculating the position ids for the m-rope in Qwen2-VL.
  - Support removing padding in flash attention 2 for the m-rope (transformers itself **does not support it**).

- Sharding Manager
  - Support all-gathering the non-tensor batch.

- FSDP Workers / Checkpoint Merger
  - Support `AutoModelForVision2Seq` at model initialization (see the sketch below).

Note: Ulysses sequence parallelism is not supported yet. We will add it
in the next update.
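
As referenced in the list above, the `AutoModelForVision2Seq` switch boils down to checking whether the model config type is registered under that auto class. A minimal sketch, assuming transformers >= 4.49 (the checkpoint name is only an example):

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# VLM configs registered for vision-to-sequence generation use AutoModelForVision2Seq;
# everything else falls back to the usual causal-LM class.
if type(config) in AutoModelForVision2Seq._model_mapping.keys():
    model_cls = AutoModelForVision2Seq
else:
    model_cls = AutoModelForCausalLM

model = model_cls.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto")
```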

## Performance

We provide the estimated MFU of the language-model part on H100 GPUs.
These values underestimate the actual MFU because **we did not count
the FLOPs of the vision tower**.

- `remove_padding=False`: MFU ~7%
- `remove_padding=True`: MFU ~20%
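
For reference, here is a rough back-of-the-envelope sketch of how such an MFU figure can be derived; the 6 * N * tokens approximation, the peak-TFLOPs constant, and the example numbers are illustrative assumptions rather than the exact veRL `FlopsCounter` logic:

```python
# Rough MFU estimate for the language-model part only (vision-tower FLOPs ignored,
# which is why the values above underestimate the real utilization).
def estimate_mfu(num_params: float, tokens_per_step: float, step_time_s: float,
                 num_gpus: int = 8, peak_tflops_per_gpu: float = 989.0) -> float:
    # ~6 FLOPs per parameter per token for a dense forward + backward pass.
    achieved_tflops = 6 * num_params * tokens_per_step / step_time_s / 1e12
    return achieved_tflops / (peak_tflops_per_gpu * num_gpus)

# Example with made-up numbers: a 7B LM processing 2M tokens in a 60 s step on 8 H100s.
print(f"MFU ~ {estimate_mfu(7e9, 2e6, 60.0):.1%}")
```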

The training and test reward score curves are presented as follows.


![image](https://github.com/user-attachments/assets/ecb9fc27-8591-4c5b-ae4b-4ba77c6e30f9)

## Who can review?

@vermouth1992 @PeterSH6
parent a0a4d5fa
name: e2e_vlm_geo3k

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
      - v0.2.x
    paths:
      - "**/*.py"
      - .github/workflows/e2e_vlm_geo3k.yml
  pull_request:
    branches:
      - main
      - v0.2.x
    paths:
      - "**/*.py"
      - .github/workflows/e2e_vlm_geo3k.yml
      - "tests/e2e/*.sh"

# Declare permissions just read content.
permissions:
  contents: read

jobs:
  e2e_vlm_geo3k:
    runs-on: [self-hosted, l20-1]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
      HF_HOME: ${{ secrets.HF_HOME }}
    container:
      image: hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.3-rc1
      options: --gpus all --shm-size=40g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install --no-deps -e .[test,gpu,geo]
          pip3 install --no-deps mathruler
          python -c "import transformers; print(transformers.__version__)"
      - name: Prepare geo3k dataset
        run: |
          ray stop --force
          python3 examples/data_preprocess/geo3k.py
      - name: Running geo3k vlm e2e training tests on 8 L20 GPUs with rmpad using function rm
        run: |
          ray stop --force
          bash tests/e2e/run_qwen2vl_geo3k_function_rm.sh
......@@ -113,4 +113,7 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
*.swp
# ckpt
*.lock
\ No newline at end of file
*.lock
# data
*.parquet
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Geometry3k dataset to parquet format
"""
import os
import datasets

from verl.utils.hdfs_io import copy, makedirs
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='~/data/geo3k')
    parser.add_argument('--hdfs_dir', default=None)

    args = parser.parse_args()

    data_source = 'hiyouga/geometry3k'
    dataset = datasets.load_dataset(data_source)

    train_dataset = dataset['train']
    test_dataset = dataset['test']

    instruction_following = r"Please reason step by step, and put your final answer within \boxed{}."

    # add a row to each data item that represents a unique id
    def make_map_fn(split):

        def process_fn(example, idx):
            problem = example.pop('problem')
            prompt = problem + ' ' + instruction_following
            answer = example.pop('answer')
            images = example.pop('images')
            data = {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": prompt,
                }],
                "images": images,
                "ability": "math",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": answer
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                    'answer': answer,
                    "question": problem,
                }
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True, num_proc=8)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True, num_proc=8)

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_dir, dst=hdfs_dir)
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/geo3k/train.parquet \
data.val_files=$HOME/data/geo3k/test.parquet \
data.train_batch_size=512 \
data.max_prompt_length=1536 \
data.max_response_length=1536 \
data.image_key=images \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
......@@ -17,7 +17,7 @@ import re
import os
import torch
import argparse
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
from concurrent.futures import ThreadPoolExecutor
from torch.distributed._tensor import DTensor, Shard, Placement
......@@ -140,6 +140,8 @@ if __name__ == '__main__':
auto_model = AutoModelForTokenClassification
elif 'ForCausalLM' in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif 'ForConditionalGeneration' in config.architectures[0]:
auto_model = AutoModelForVision2Seq
else:
raise NotImplementedError(f'Unknown architecture {config["architectures"]}')
......
......@@ -43,11 +43,13 @@ install_requires = [
TEST_REQUIRES = ['pytest', 'yapf', 'py-spy']
PRIME_REQUIRES = ['pyext']
GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']
extras_require = {
'test': TEST_REQUIRES,
'prime': PRIME_REQUIRES,
'geo': GEO_REQUIRES,
'gpu': GPU_REQUIRES,
}
......
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/geo3k/train.parquet \
data.val_files=$HOME/data/geo3k/test.parquet \
data.train_batch_size=128 \
data.max_prompt_length=1536 \
data.max_response_length=1536 \
data.image_key=images \
actor_rollout_ref.model.path=Qwen/Qwen2-VL-2B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=True \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
algorithm.adv_estimator=grpo \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example_geo3k' \
trainer.experiment_name='qwen2vl_e2e_ci_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
......@@ -19,18 +19,25 @@ import torch.nn as nn
# Supported models using HF Rmpad
# TODO(sgm): HF may supported more than listed here, we should add more after testing
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config}
_MODELS_SUPPORT_RMPAD = {'llama', 'mistral', 'gemma', 'qwen2', 'qwen2_vl', 'qwen2_5_vl'}
def check_model_support_rmpad(model_type: str):
assert isinstance(model_type, str)
if not model_type in _REOVEPAD_MODELS.keys():
if not model_type in _MODELS_SUPPORT_RMPAD:
raise ValueError(f"Model architecture {model_type} is not supported for now. "
f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}."
f"RMPad supported architectures: {_MODELS_SUPPORT_RMPAD}."
f"Please set `use_remove_padding=False` in the model config.")
if model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope
from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward
Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward
print("Qwen2vl patch applied!")
# Supported models in Megatron-LM
# Architecture -> (module, class).
......
......@@ -203,7 +203,11 @@ def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) ->
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if "visual" in name:
continue
name = "language_model." + name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -216,6 +220,11 @@ def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) ->
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if "visual" in name:
name = name
else:
name = "language_model." + name
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
......@@ -355,6 +364,7 @@ __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
"Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
"DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader,
"Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
"Qwen2_5_VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
}
......
......@@ -10,6 +10,7 @@ data:
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
shuffle: True
image_key: images
actor_rollout_ref:
hybrid_engine: True
......
......@@ -46,8 +46,9 @@ def main_task(config, compute_score=None):
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer
from verl.utils import hf_tokenizer, hf_processor
tokenizer = hf_tokenizer(local_path)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
......@@ -117,6 +118,7 @@ def main_task(config, compute_score=None):
trainer = RayPPOTrainer(config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
......
......@@ -355,12 +355,14 @@ class RayPPOTrainer(object):
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None):
# assert torch.cuda.is_available(), 'cuda must be available on driver'
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
......@@ -492,7 +494,9 @@ class RayPPOTrainer(object):
# TODO: we have to make sure the batch size is divisible by the dp size
self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
image_key=self.config.data.get('image_key', 'images'),
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
......@@ -507,13 +511,16 @@ class RayPPOTrainer(object):
self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset,
batch_size=self.config.data.train_batch_size,
num_workers=8,
drop_last=True,
collate_fn=collate_fn,
sampler=sampler)
self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
image_key=self.config.data.get('image_key', 'images'),
max_prompt_length=self.config.data.max_prompt_length,
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
......@@ -523,6 +530,7 @@ class RayPPOTrainer(object):
# Validation datasets are sent to inference engines as a whole batch,
# which will schedule the memory themselves.
batch_size=len(self.val_dataset),
num_workers=8,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
......@@ -619,7 +627,17 @@ class RayPPOTrainer(object):
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])
if 'multi_modal_inputs' in test_batch.non_tensor_batch.keys():
test_gen_batch = test_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
)
else:
test_gen_batch = test_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids'],
)
test_gen_batch.meta_info = {
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
......@@ -880,7 +898,16 @@ class RayPPOTrainer(object):
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
if 'multi_modal_inputs' in batch.non_tensor_batch.keys():
gen_batch = batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
)
else:
gen_batch = batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids'],
)
with _timer('step', timing_raw):
# generate a batch
......
......@@ -13,6 +13,6 @@
# limitations under the License.
from . import tokenizer
from .tokenizer import *
from .tokenizer import hf_tokenizer, hf_processor
__all__ = tokenizer.__all__
\ No newline at end of file
......@@ -15,11 +15,11 @@ import os
import shutil
from filelock import FileLock
import tempfile
from typing import Union
import torch
import torch.distributed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, ProcessorMixin
import numpy as np
import random
......@@ -40,14 +40,15 @@ class BaseCheckpointManager:
"""
def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer):
lr_scheduler: torch.optim.lr_scheduler.LRScheduler, processing_class: Union[PreTrainedTokenizer,
ProcessorMixin]):
self.previous_global_step = None
self.previous_save_local_path = None
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.processing_class = processing_class
assert isinstance(self.model, FSDP)
self.rank = torch.distributed.get_rank()
......
......@@ -16,7 +16,7 @@ import ray
import os
import warnings
from typing import Union
import torch
import torch.distributed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
......@@ -24,7 +24,7 @@ from torch.distributed.fsdp import ShardedStateDictConfig, ShardedOptimStateDict
from verl.utils.fs import copy_to_local, is_non_local
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, ProcessorMixin
from .checkpoint_manager import BaseCheckpointManager
......@@ -41,12 +41,22 @@ class FSDPCheckpointManager(BaseCheckpointManager):
We save
- sharded model states and optimizer states
- full lr_scheduler states
- huggingface tokenizer and config for ckpt merge
- huggingface tokenizer/processor and config for ckpt merge
"""
def __init__(self, model: FSDP, optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler, tokenizer: PreTrainedTokenizer, *args, **kwargs):
super().__init__(model, optimizer, lr_scheduler, tokenizer)
def __init__(self,
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
**kwargs):
if processing_class is None:
assert "tokenizer" in kwargs, "tokenizer or processor must be provided"
warnings.warn("`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning)
processing_class = kwargs.pop("tokenizer")
super().__init__(model, optimizer, lr_scheduler, processing_class)
def load_checkpoint(self, path=None, del_local_after_load=False, *args, **kwargs):
if path is None:
......@@ -142,7 +152,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
hf_local_path = os.path.join(local_path, 'huggingface')
os.makedirs(hf_local_path, exist_ok=True)
self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
self.tokenizer.save_pretrained(hf_local_path)
self.processing_class.save_pretrained(hf_local_path)
torch.distributed.barrier()
......
......@@ -14,32 +14,29 @@
from omegaconf import ListConfig
import os
from typing import List, Union
from typing import List, Union, Optional
import copy
import pandas as pd
from collections import defaultdict
import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, ProcessorMixin
from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F
def collate_fn(data_list: list[dict]) -> dict:
tensors = {}
non_tensors = {}
tensors = defaultdict(list)
non_tensors = defaultdict(list)
for data in data_list:
for key, val in data.items():
if isinstance(val, torch.Tensor):
if key not in tensors:
tensors[key] = []
tensors[key].append(val)
else:
if key not in non_tensors:
non_tensors[key] = []
non_tensors[key].append(val)
for key, val in tensors.items():
......@@ -48,10 +45,31 @@ def collate_fn(data_list: list[dict]) -> dict:
for key, val in non_tensors.items():
non_tensors[key] = np.array(val, dtype=object)
output = {}
output.update(tensors)
output.update(non_tensors)
return output
return {**tensors, **non_tensors}
def process_image(image: dict, max_pixels: int = 2048 * 2048, min_pixels: int = 512 * 512):
    import math
    from io import BytesIO
    from PIL import Image

    if isinstance(image, dict):
        image = Image.open(BytesIO(image['bytes']))

    if (image.width * image.height) > max_pixels:
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
        image = image.resize((width, height), resample=Image.Resampling.NEAREST)

    if (image.width * image.height) < min_pixels:
        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
        width, height = int(image.width * resize_factor), int(image.height * resize_factor)
        image = image.resize((width, height), resample=Image.Resampling.NEAREST)

    if image.mode != 'RGB':
        image = image.convert('RGB')

    return image
class RLHFDataset(Dataset):
......@@ -62,7 +80,9 @@ class RLHFDataset(Dataset):
def __init__(self,
parquet_files: Union[str, List[str]],
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin] = None,
prompt_key='prompt',
image_key='images',
max_prompt_length=1024,
filter_prompts=True,
cache_dir='~/.cache/verl/rlhf',
......@@ -76,8 +96,10 @@ class RLHFDataset(Dataset):
self.original_parquet_files = copy.deepcopy(parquet_files) # use for resume
self.cache_dir = os.path.expanduser(cache_dir)
self.tokenizer = tokenizer
self.processor = processor
self.prompt_key = prompt_key
self.image_key = image_key
self.max_prompt_length = max_prompt_length
self.filter_prompts = filter_prompts
......@@ -132,12 +154,36 @@ class RLHFDataset(Dataset):
"""
Note that we also return the raw_input_ids so that it can be combined with other chat template
"""
row_dict = self.dataframe.iloc[item].to_dict()
row_dict: dict = self.dataframe.iloc[item].to_dict()
chat = row_dict.pop(self.prompt_key)
prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
if self.image_key in row_dict: # expand image token
raw_prompt = prompt_with_chat_template.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
row_dict['multi_modal_data'] = {'image': [process_image(image) for image in row_dict.pop(self.image_key)]}
image_inputs = self.processor.image_processor(row_dict['multi_modal_data']['image'], return_tensors='pt')
image_grid_thw = image_inputs['image_grid_thw']
row_dict['multi_modal_inputs'] = {key: val for key, val in image_inputs.items()}
if image_grid_thw is not None:
merge_length = self.processor.image_processor.merge_size**2
index = 0
while '<image>' in prompt_with_chat_template:
prompt_with_chat_template = prompt_with_chat_template.replace(
'<image>',
'<|vision_start|>' + '<|placeholder|>' * (image_grid_thw[index].prod() // merge_length) +
'<|vision_end|>',
1,
)
index += 1
prompt_with_chat_template = prompt_with_chat_template.replace('<|placeholder|>',
self.processor.image_token)
else:
raw_prompt = prompt_with_chat_template
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
......@@ -145,11 +191,22 @@ class RLHFDataset(Dataset):
left_pad=True,
truncation=self.truncation)
position_ids = compute_position_id_with_mask(attention_mask)
if self.image_key in row_dict:
from verl.models.transformers.qwen2_vl import get_rope_index
position_ids = get_rope_index(
self.processor,
input_ids=input_ids[0],
image_grid_thw=image_grid_thw,
attention_mask=attention_mask[0],
) # (3, seq_len)
else:
position_ids = compute_position_id_with_mask(attention_mask)
row_dict['input_ids'] = input_ids[0]
row_dict['attention_mask'] = attention_mask[0]
row_dict['position_ids'] = position_ids[0]
row_dict['raw_prompt_ids'] = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
# encode prompts without chat template
if self.return_raw_chat:
......
......@@ -13,9 +13,9 @@
# limitations under the License.
import torch
from transformers import PretrainedConfig, Qwen2Config, LlamaConfig
from transformers import PretrainedConfig
VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
def get_device_flops(unit="T"):
......@@ -59,18 +59,22 @@ class FlopsCounter:
"""
def __init__(self, config: PretrainedConfig):
if not isinstance(config, VALID_CONFIG_TYPE):
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. "
if not config.model_type in VALID_CONFIG_TYPE:
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {self.config.model_type}. "
f"MFU will always be zero.")
self.estimate_func = {"qwen2": self._estimate_qwen2_flops, 'llama': self._estimate_qwen2_flops}
self.estimate_func = {
'qwen2': self._estimate_qwen2_flops,
'llama': self._estimate_qwen2_flops,
'qwen2_vl': self._estimate_qwen2_flops,
'qwen2_5_vl': self._estimate_qwen2_flops
}
self.config = config
def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
return 0
def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
hidden_size = self.config.hidden_size
vocab_size = self.config.vocab_size
num_hidden_layers = self.config.num_hidden_layers
......
......@@ -30,6 +30,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N
elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']:
from . import prime_code
res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
elif data_source in ['hiyouga/geometry3k']:
from . import geo3k
res = geo3k.compute_score(solution_str, ground_truth)
else:
raise NotImplementedError
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mathruler.grader import extract_boxed_content, grade_answer
def compute_score(predict_str: str, ground_truth: str) -> float:
    answer = extract_boxed_content(predict_str)
    if grade_answer(answer, ground_truth):
        return 1.0  # correct answer

    return 0.0  # wrong answer
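
A quick usage illustration (not part of the diff): only the content of `\boxed{}` is extracted and graded against the ground truth. The import path below is an assumption about where the file lives in the package:

```python
from verl.utils.reward_score.geo3k import compute_score  # assumed module path

print(compute_score(r"The area is \boxed{12}.", "12"))  # 1.0
print(compute_score(r"The area is \boxed{13}.", "12"))  # 0.0
```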
......@@ -14,7 +14,7 @@
"""Utils for tokenization."""
import warnings
__all__ = ['hf_tokenizer']
__all__ = ['hf_tokenizer', 'hf_processor']
def set_pad_token_id(tokenizer):
......@@ -56,4 +56,25 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kw
tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
if correct_pad_token:
set_pad_token_id(tokenizer)
return tokenizer
\ No newline at end of file
return tokenizer
def hf_processor(name_or_path, **kwargs):
    """Create a huggingface processor to process multimodal data.

    Args:
        name_or_path (str): The name of the processor.

    Returns:
        transformers.ProcessorMixin: The pretrained processor.
    """
    from transformers import AutoProcessor
    try:
        processor = AutoProcessor.from_pretrained(name_or_path, **kwargs)
    except Exception:
        processor = None
    # Avoid load tokenizer, see:
    # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
    if processor is not None and "Processor" not in processor.__class__.__name__:
        processor = None
    return processor
......@@ -62,11 +62,19 @@ class DataParallelPPOActor(BasePPOActor):
log_probs: # (bs, response_len)
"""
response_length = micro_batch['responses'].size(-1)
multi_modal_inputs = {}
if 'multi_modal_inputs' in micro_batch:
for key in micro_batch['multi_modal_inputs'][0].keys():
multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']],
dim=0)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
input_ids = micro_batch['input_ids']
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
......@@ -74,8 +82,13 @@ class DataParallelPPOActor(BasePPOActor):
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
if position_ids.dim() == 3:
position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."),
indices).transpose(0, 1).unsqueeze(
1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# for compute the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
......@@ -94,6 +107,7 @@ class DataParallelPPOActor(BasePPOActor):
output = self.actor_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False) # prevent model thinks we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
......@@ -131,6 +145,7 @@ class DataParallelPPOActor(BasePPOActor):
output = self.actor_module(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False) # prevent model thinks we are generating
logits = output.logits
logits.div_(temperature)
......@@ -177,8 +192,13 @@ class DataParallelPPOActor(BasePPOActor):
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()
if use_dynamic_bsz:
if has_multi_modal_inputs:
num_micro_batches = data.batch.batch_size[0] // micro_batch_size
non_tensor_select_keys = ['multi_modal_inputs']
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
elif use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
......@@ -187,6 +207,9 @@ class DataParallelPPOActor(BasePPOActor):
log_probs_lst = []
for micro_batch in micro_batches:
if isinstance(micro_batch, DataProto):
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
with torch.no_grad():
_, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
log_probs_lst.append(log_probs)
......@@ -210,17 +233,27 @@ class DataParallelPPOActor(BasePPOActor):
if self.config.use_kl_loss:
select_keys.append('ref_log_prob')
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
dataloader = batch.split(self.config.ppo_mini_batch_size)
if has_multi_modal_inputs:
num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
non_tensor_select_keys = ['multi_modal_inputs']
dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
else:
dataloader = batch.split(self.config.ppo_mini_batch_size)
metrics = {}
for epoch in range(self.config.ppo_epochs):
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
if has_multi_modal_inputs:
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
elif self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
......@@ -231,7 +264,11 @@ class DataParallelPPOActor(BasePPOActor):
self.actor_optimizer.zero_grad()
for data in micro_batches:
data = data.cuda() # actor device is cpu when using offload
if isinstance(data, DataProto):
data = {**data.batch.cuda(), **data.non_tensor_batch}
else:
data = data.cuda() # actor device is cpu when using offload
responses = data['responses']
response_length = responses.size(1)
attention_mask = data['attention_mask']
......
......@@ -49,11 +49,19 @@ class DataParallelPPOCritic(BasePPOCritic):
def _forward_micro_batch(self, micro_batch):
response_length = micro_batch['responses'].size(-1)
multi_modal_inputs = {}
if 'multi_modal_inputs' in micro_batch:
for key in micro_batch['multi_modal_inputs'][0].keys():
multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']],
dim=0)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
input_ids = micro_batch['input_ids']
batch, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1)
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
......@@ -61,8 +69,13 @@ class DataParallelPPOCritic(BasePPOCritic):
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
if position_ids.dim() == 3:
position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."),
indices).transpose(0, 1).unsqueeze(
1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
......@@ -74,6 +87,7 @@ class DataParallelPPOCritic(BasePPOCritic):
output = self.critic_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False) # prevent model thinks we are generating
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
......@@ -92,6 +106,7 @@ class DataParallelPPOCritic(BasePPOCritic):
output = self.critic_module(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False) # prevent model thinks we are generating
values = output.logits
values = values[:, -response_length - 1:-1].squeeze(-1)
......@@ -113,8 +128,13 @@ class DataParallelPPOCritic(BasePPOCritic):
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
batch = data.select(batch_keys=select_keys).batch
use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()
if use_dynamic_bsz:
if has_multi_modal_inputs:
num_micro_batches = data.batch.batch_size[0] // micro_batch_size
non_tensor_select_keys = ['multi_modal_inputs']
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
elif use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
......@@ -123,6 +143,9 @@ class DataParallelPPOCritic(BasePPOCritic):
values_lst = []
for micro_batch in micro_batches:
if isinstance(micro_batch, DataProto):
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
with torch.no_grad():
values = self._forward_micro_batch(micro_batch)
values_lst.append(values)
......@@ -147,15 +170,25 @@ class DataParallelPPOCritic(BasePPOCritic):
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns']
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
dataloader = batch.split(self.config.ppo_mini_batch_size)
if has_multi_modal_inputs:
num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
non_tensor_select_keys = ['multi_modal_inputs']
dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
else:
dataloader = batch.split(self.config.ppo_mini_batch_size)
for epoch in range(self.config.ppo_epochs):
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
if has_multi_modal_inputs:
num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
elif self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
......@@ -165,7 +198,11 @@ class DataParallelPPOCritic(BasePPOCritic):
self.critic_optimizer.zero_grad()
for data in micro_batches:
data = data.cuda() # critic device is cpu when using offload
if isinstance(data, DataProto):
data = {**data.batch.cuda(), **data.non_tensor_batch}
else:
data = data.cuda() # critic device is cpu when using offload
input_ids = data['input_ids']
responses = data['responses']
attention_mask = data['attention_mask']
......
......@@ -27,7 +27,7 @@ from omegaconf import DictConfig, open_dict
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.utils import hf_tokenizer
from verl.utils import hf_tokenizer, hf_processor
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
......@@ -151,7 +151,7 @@ class ActorRolloutRefWorker(Worker):
role='actor'):
from verl.utils.model import print_model_size, update_model_config, get_generation_config
from verl.utils.torch_dtypes import PrecisionType
from transformers import AutoModelForCausalLM, AutoConfig
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForVision2Seq
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload
from torch import optim
......@@ -163,6 +163,7 @@ class ActorRolloutRefWorker(Worker):
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
torch_dtype = fsdp_config.get('model_dtype', None)
if torch_dtype is None:
......@@ -198,11 +199,16 @@ class ActorRolloutRefWorker(Worker):
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
actor_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
actor_module_class = AutoModelForVision2Seq
else:
actor_module_class = AutoModelForCausalLM
actor_module = actor_module_class.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
......@@ -400,10 +406,11 @@ class ActorRolloutRefWorker(Worker):
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
self.checkpoint_manager = FSDPCheckpointManager(model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
tokenizer=self.tokenizer)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer)
torch.cuda.empty_cache()
......@@ -641,6 +648,7 @@ class CriticWorker(Worker):
tokenizer_path = copy_to_local(config.model.tokenizer_path)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))
self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
......@@ -764,10 +772,11 @@ class CriticWorker(Worker):
critic_optimizer=self.critic_optimizer)
self.flops_counter = FlopsCounter(self.critic_model_config)
self.checkpoint_manager = FSDPCheckpointManager(model=self.critic_module,
optimizer=self.critic_optimizer,
lr_scheduler=self.critic_lr_scheduler,
tokenizer=self.tokenizer)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.critic_module,
optimizer=self.critic_optimizer,
lr_scheduler=self.critic_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer)
torch.cuda.empty_cache()
......
......@@ -24,6 +24,7 @@ When working with Megatron:
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that doesn't belong to this pp rank is freed.
"""
import numpy as np
from typing import List
from contextlib import contextmanager
from omegaconf import DictConfig
......@@ -31,7 +32,7 @@ import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn
from typing import Any, Union
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
from verl.workers.rollout.base import BaseRollout
......@@ -54,6 +55,13 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[in
return token_ids
def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    else:
        return np.repeat(value, repeats, axis=0)
class vLLMRollout(BaseRollout):
def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
......@@ -111,7 +119,7 @@ class vLLMRollout(BaseRollout):
kwargs = dict(
n=1,
logprobs=1, # can be set to 0 and let actor to recompute
logprobs=0, # can be set to 0 and let actor to recompute
max_tokens=config.response_length,
)
......@@ -161,10 +169,23 @@ class vLLMRollout(BaseRollout):
batch_size = idx.size(0)
idx_list = []
# parse idx from torch.Tensor to List[List[str]]
for i in range(batch_size):
idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))
non_tensor_batch = prompts.non_tensor_batch
if 'raw_prompt_ids' not in non_tensor_batch:
non_tensor_batch['raw_prompt_ids'] = np.array(
[_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)
if batch_size != len(non_tensor_batch['raw_prompt_ids']):
raise RuntimeError('vllm sharding manager is not work properly.')
if 'multi_modal_data' in non_tensor_batch:
vllm_inputs = []
for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'),
non_tensor_batch.pop('multi_modal_data')):
vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data})
else:
vllm_inputs = [{
'prompt_token_ids': raw_prompt_ids
} for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')]
do_sample = prompts.meta_info.get('do_sample', True)
if not do_sample:
......@@ -180,9 +201,8 @@ class vLLMRollout(BaseRollout):
# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
outputs = self.inference_engine.generate(
prompts=None, # because we have already convert it to prompt token id
prompts=vllm_inputs, # because we have already convert it to prompt token id
sampling_params=self.sampling_params,
prompt_token_ids=idx_list,
use_tqdm=False)
# TODO(sgm): disable logprob when recompute_log_prob is enable
......@@ -197,15 +217,21 @@ class vLLMRollout(BaseRollout):
max_length=self.config.response_length).to(idx.device)
if self.config.n > 1 and do_sample:
idx = idx.repeat_interleave(self.config.n, dim=0)
attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
idx = _repeat_interleave(idx, self.config.n)
attention_mask = _repeat_interleave(attention_mask, self.config.n)
position_ids = _repeat_interleave(position_ids, self.config.n)
batch_size = batch_size * self.config.n
if 'multi_modal_inputs' in non_tensor_batch.keys():
non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'],
self.config.n)
seq = torch.cat([idx, response], dim=-1)
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
if position_ids.dim() == 3: # qwen2vl mrope
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad
......@@ -232,4 +258,4 @@ class vLLMRollout(BaseRollout):
if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
self.inference_engine.free_cache_engine()
return DataProto(batch=batch)
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
......@@ -15,6 +15,7 @@
import os
import logging
import torch
import numpy as np
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
from torch.distributed.device_mesh import DeviceMesh
......@@ -126,17 +127,20 @@ class FSDPVLLMShardingManager(BaseShardingManager):
def preprocess_data(self, data: DataProto) -> DataProto:
# TODO: Current impl doesn't consider FSDP with torch micro-dp
tp_size = vllm_ps.get_tensor_model_parallel_world_size()
if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
data.batch = allgather_dict_tensors(data.batch.contiguous(),
size=vllm_ps.get_tensor_model_parallel_world_size(),
group=vllm_ps.get_tensor_model_parallel_group(),
dim=0)
group = vllm_ps.get_tensor_model_parallel_group()
else:
data.batch = allgather_dict_tensors(data.batch.contiguous(),
size=vllm_ps.get_tensor_model_parallel_world_size(),
group=vllm_ps.get_tensor_model_parallel_group().device_group,
dim=0)
group = vllm_ps.get_tensor_model_parallel_group().device_group
prev_device = data.batch.device
data.batch = data.batch.cuda(device=torch.cuda.current_device())
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=tp_size, group=group, dim=0)
data.batch = data.batch.to(prev_device)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(tp_size)]
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
return data
def postprocess_data(self, data: DataProto) -> DataProto:
......