Unverified Commit 3140cc2f by Kinman Lei Committed by GitHub

[rollout]: fix incorrect response_attention_mask in vLLM rollout (#213)

This PR addresses issue https://github.com/volcengine/verl/issues/212.

The changes include:
- read eos_token_id from generation_config to ensure alignment with vLLM (see the sketch below the commit header).
- modify the get_eos_mask function to accept both int and list types for the eos_token parameter.
parent 27484a7b
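For context on the mismatch (a hypothetical illustration; the model name and token ids below are assumptions, not taken from issue #212): some checkpoints list several stop tokens in generation_config.json while the tokenizer exposes a single eos_token_id, so a response mask built from the tokenizer alone can disagree with what vLLM actually stops on.

```python
# Illustrative only: model name and ids are examples, not taken from the issue.
from transformers import AutoTokenizer, GenerationConfig

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
gen_cfg = GenerationConfig.from_pretrained(model_name)

print(tokenizer.eos_token_id)  # a single int (exact value depends on the checkpoint revision)
print(gen_cfg.eos_token_id)    # may be a list, e.g. [128001, 128009]
```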
@@ -16,12 +16,12 @@ Utilities to create common models from huggingface
 """
 import os
 import warnings
-from typing import Dict, Type
+from typing import Dict, Type, Optional
 import numpy as np
 import torch
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification, GenerationConfig
 from verl.models.registry import ModelRegistry
@@ -55,6 +55,23 @@ def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, t
     return module_config
 
 
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool = False,
+) -> Optional[GenerationConfig]:
+    try:
+        return GenerationConfig.from_pretrained(model)
+    except OSError:  # Not found
+        try:
+            config = get_huggingface_actor_config(
+                model,
+                trust_remote_code=trust_remote_code,
+            )
+            return GenerationConfig.from_model_config(config)
+        except OSError:  # Not found
+            return None
+
+
 def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
     """
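A rough usage sketch of the new get_generation_config helper (the local path is a placeholder); it tries generation_config.json first, then a GenerationConfig derived from the model config, and returns None if neither is available so callers can fall back to the tokenizer:

```python
from verl.utils.model import get_generation_config

gen_cfg = get_generation_config("/path/to/local_model", trust_remote_code=False)
if gen_cfg is not None:
    # generation_config.json (or the model config) was found; eos_token_id may be an int or a list
    eos_token_id = gen_cfg.eos_token_id
else:
    # no generation settings available; the workers fall back to tokenizer.eos_token_id
    eos_token_id = None
```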
@@ -138,13 +138,21 @@ def masked_whiten(values, mask, shift_mean=True):
     return whitened
 
 
-def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):
+def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
     '''
-    e.g. end of sentence token=1
+    end of sentence token can be int or list: 1 or [1, 2]
+    e.g. eos_token=1
     response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
     eos_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]
     '''
-    eos_mask = response_id.eq(eos_token).long()
+    if isinstance(eos_token, int):
+        eos_token = [eos_token]
+
+    eos_mask = torch.zeros_like(response_id, dtype=torch.bool)
+    for token in eos_token:
+        eos_mask |= response_id.eq(token)
+
+    eos_mask = eos_mask.long()
     eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
     eos_mask = torch.logical_not(eos_mask).to(dtype)
     return eos_mask
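A quick sanity check of the updated get_eos_mask (the import path is an assumption; the values mirror the docstring example):

```python
import torch
from verl.utils.torch_functional import get_eos_mask  # module path assumed

response_id = torch.tensor([[0, 0, 2, 42, 3, 5, 1, 0, 0]])

# single eos token (old behavior, unchanged)
get_eos_mask(response_id, eos_token=1)        # -> [[1, 1, 1, 1, 1, 1, 1, 0, 0]]

# multiple eos tokens: the mask now ends at the first occurrence of any of them
get_eos_mask(response_id, eos_token=[1, 42])  # -> [[1, 1, 1, 1, 0, 0, 0, 0, 0]]
```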
@@ -147,7 +147,7 @@ class ActorRolloutRefWorker(Worker):
                             trust_remote_code=False,
                             use_liger=False,
                             role='actor'):
-        from verl.utils.model import print_model_size, update_model_config
+        from verl.utils.model import print_model_size, update_model_config, get_generation_config
         from verl.utils.torch_dtypes import PrecisionType
         from transformers import AutoModelForCausalLM, AutoConfig
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload
@@ -171,6 +171,8 @@ class ActorRolloutRefWorker(Worker):
         # override model kwargs
         actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
 
+        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
+
         if use_remove_padding:
             from verl.models.registry import check_model_support_rmpad
             check_model_support_rmpad(actor_model_config.model_type)
@@ -445,7 +447,14 @@ class ActorRolloutRefWorker(Worker):
                                           load_grad=self._is_offload_grad)
         prompts.batch = prompts.batch.cuda()
-        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+        meta_info = {
+            'eos_token_id':
+                self.generation_config.eos_token_id
+                if self.generation_config is not None else self.tokenizer.eos_token_id,
+            'pad_token_id':
+                self.generation_config.pad_token_id
+                if self.generation_config is not None else self.tokenizer.pad_token_id,
+        }
         prompts.meta_info.update(meta_info)
         with self.rollout_sharding_manager:
             log_gpu_memory_usage('After entering rollout sharding manager', logger=logger)
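The same fallback written out as a standalone sketch for clarity; build_rollout_meta_info is a hypothetical helper, not part of the PR. The point is that the eos_token_id forwarded to the rollout may now be an int or a list, which the updated get_eos_mask and the naive rollout below both handle:

```python
def build_rollout_meta_info(generation_config, tokenizer):
    # prefer the ids declared in generation_config (what vLLM aligns to, per the PR
    # description); fall back to the tokenizer when no generation config was found
    return {
        'eos_token_id': generation_config.eos_token_id
        if generation_config is not None else tokenizer.eos_token_id,
        'pad_token_id': generation_config.pad_token_id
        if generation_config is not None else tokenizer.pad_token_id,
    }
```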
@@ -135,9 +135,9 @@ class ActorRolloutRefWorker(MegatronWorker):
                             enable_gradient_checkpointing=False):
         from verl.utils.megatron.optimizer import get_megatron_optimizer
         from megatron.core.models.gpt.gpt_model import ModelType
-        from verl.utils.model import print_model_size, update_model_config
+        from verl.utils.model import print_model_size, update_model_config, get_generation_config
         from verl.utils.megatron_utils import get_model, init_megatron_optim_config
-        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GenerationConfig
 
         # Step 1: initialize the tokenizer
         local_path = copy_local_path_from_hdfs(model_path)
@@ -146,6 +146,8 @@ class ActorRolloutRefWorker(MegatronWorker):
         # Step 2: get the actor_model_config
         actor_model_config = AutoConfig.from_pretrained(local_path)
+        self.generation_config = get_generation_config(local_path)
+
         override_config_kwargs = {
             'bos_token_id': self.tokenizer.bos_token_id,
             'eos_token_id': self.tokenizer.eos_token_id,
@@ -352,7 +354,14 @@ class ActorRolloutRefWorker(MegatronWorker):
         assert self._is_rollout
         prompts.batch = prompts.batch.cuda()
-        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+        meta_info = {
+            'eos_token_id':
+                self.generation_config.eos_token_id
+                if self.generation_config is not None else self.tokenizer.eos_token_id,
+            'pad_token_id':
+                self.generation_config.pad_token_id
+                if self.generation_config is not None else self.tokenizer.pad_token_id,
+        }
         prompts.meta_info.update(meta_info)
         with self.sharding_manager:
             log_gpu_memory_usage('After entering sharding manager', logger=logger)
@@ -57,6 +57,8 @@ class NaiveRollout(BaseRollout):
         # used to construct attention_mask
         eos_token_id = prompts.meta_info['eos_token_id']
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
 
         batch_size = idx.size(0)
         prompt_length = idx.size(1)
@@ -90,7 +92,8 @@ class NaiveRollout(BaseRollout):
             attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)
 
-            prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
+            for token_id in eos_token_id:
+                prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool())
             prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)
 
             position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
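A toy example of the per-step masking in the loop above (shapes and token ids are invented): once a sampled token matches any eos id, that row's prev_attention_mask becomes False, and because it is carried to the next step, the remaining response positions stay masked out, which is what fixes the response_attention_mask.

```python
import torch

eos_token_id = [1, 2]                          # may now be a list of stop tokens
prev_attention_mask = torch.ones(4, 1, dtype=torch.bool)
idx_next = torch.tensor([[5], [1], [2], [7]])  # tokens sampled at this step

for token_id in eos_token_id:
    prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask)

print(prev_attention_mask.squeeze(-1))  # tensor([ True, False, False,  True])
```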