Unverified Commit 3140cc2f by Kinman Lei Committed by GitHub

[rollout]: fix incorrect response_attention_mask in vLLM rollout (#213)

This PR addresses issue https://github.com/volcengine/verl/issues/212.

The changes include:
- Read eos_token_id from generation_config to ensure alignment with vLLM.
- Modify the get_eos_mask function to accept both int and list types for the eos_token parameter.
parent 27484a7b
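
For context on the mismatch being fixed (the checkpoint name and token ids below are illustrative, not taken from the issue): some models list several stop tokens in generation_config.json, which is what vLLM honors when it stops generation, while tokenizer.eos_token_id is a single id, so a response_attention_mask built from the tokenizer alone can end at the wrong token.

from transformers import AutoTokenizer, GenerationConfig

model = "some-org/chat-model"  # hypothetical checkpoint
tok = AutoTokenizer.from_pretrained(model)
gen_cfg = GenerationConfig.from_pretrained(model)

print(tok.eos_token_id)      # a single int, e.g. 151643
print(gen_cfg.eos_token_id)  # may be a list, e.g. [151645, 151643], which vLLM uses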
@@ -16,12 +16,12 @@ Utilities to create common models from huggingface
 """
 import os
 import warnings
-from typing import Dict, Type
+from typing import Dict, Type, Optional
 import numpy as np
 import torch
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, MistralForSequenceClassification, GenerationConfig
 from verl.models.registry import ModelRegistry
@@ -55,6 +55,23 @@ def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, t
     return module_config
 
 
+def get_generation_config(
+    model: str,
+    trust_remote_code: bool = False,
+) -> Optional[GenerationConfig]:
+    try:
+        return GenerationConfig.from_pretrained(model)
+    except OSError:  # Not found
+        try:
+            config = get_huggingface_actor_config(
+                model,
+                trust_remote_code=trust_remote_code,
+            )
+            return GenerationConfig.from_model_config(config)
+        except OSError:  # Not found
+            return None
+
+
 def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
     """
...
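
A quick sketch of how the new helper is meant to be used (the local path below is hypothetical): it prefers an on-disk generation_config.json, falls back to deriving one from the model config, and returns None when neither is available so callers can fall back to the tokenizer.

from verl.utils.model import get_generation_config

gen_cfg = get_generation_config("/tmp/local_model", trust_remote_code=False)  # hypothetical path
if gen_cfg is not None:
    eos_token_id = gen_cfg.eos_token_id  # int or list of ints
else:
    eos_token_id = None  # caller falls back to tokenizer.eos_token_id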
@@ -138,13 +138,21 @@ def masked_whiten(values, mask, shift_mean=True):
     return whitened
 
 
-def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):
+def get_eos_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
     '''
-    e.g. end of sentence token=1
+    end of sentence token can be int or list: 1 or [1, 2]
+    e.g. eos_token=1
     response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
     eos_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]
     '''
-    eos_mask = response_id.eq(eos_token).long()
+    if isinstance(eos_token, int):
+        eos_token = [eos_token]
+
+    eos_mask = torch.zeros_like(response_id, dtype=torch.bool)
+    for token in eos_token:
+        eos_mask |= response_id.eq(token)
+
+    eos_mask = eos_mask.long()
     eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
     eos_mask = torch.logical_not(eos_mask).to(dtype)
     return eos_mask
...
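
A small sanity check of the new behaviour (toy values; the import path assumes get_eos_mask lives next to masked_whiten in verl.utils.torch_functional): with a list of EOS ids, the mask now cuts off at whichever id appears first.

import torch
from verl.utils.torch_functional import get_eos_mask  # module path assumed

response_id = torch.tensor([[0, 0, 2, 42, 3, 5, 1, 0, 0]])
print(get_eos_mask(response_id, eos_token=1))       # [[1, 1, 1, 1, 1, 1, 1, 0, 0]]
print(get_eos_mask(response_id, eos_token=[1, 2]))  # [[1, 1, 1, 0, 0, 0, 0, 0, 0]]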
@@ -147,7 +147,7 @@ class ActorRolloutRefWorker(Worker):
                               trust_remote_code=False,
                               use_liger=False,
                               role='actor'):
-        from verl.utils.model import print_model_size, update_model_config
+        from verl.utils.model import print_model_size, update_model_config, get_generation_config
         from verl.utils.torch_dtypes import PrecisionType
         from transformers import AutoModelForCausalLM, AutoConfig
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload
@@ -171,6 +171,8 @@ class ActorRolloutRefWorker(Worker):
         # override model kwargs
         actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
 
+        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
+
         if use_remove_padding:
             from verl.models.registry import check_model_support_rmpad
             check_model_support_rmpad(actor_model_config.model_type)
@@ -445,7 +447,14 @@ class ActorRolloutRefWorker(Worker):
                                                load_grad=self._is_offload_grad)
 
         prompts.batch = prompts.batch.cuda()
-        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+        meta_info = {
+            'eos_token_id':
+                self.generation_config.eos_token_id
+                if self.generation_config is not None else self.tokenizer.eos_token_id,
+            'pad_token_id':
+                self.generation_config.pad_token_id
+                if self.generation_config is not None else self.tokenizer.pad_token_id,
+        }
         prompts.meta_info.update(meta_info)
         with self.rollout_sharding_manager:
             log_gpu_memory_usage('After entering rollout sharding manager', logger=logger)
...
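
Downstream, the vLLM rollout derives response_attention_mask from this meta_info entry; a condensed sketch (not a verbatim excerpt of the rollout code) of how a list-valued eos_token_id now flows through:

# inside the rollout, after the response tokens have been sampled (condensed)
eos_token_id = prompts.meta_info['eos_token_id']  # int or list, set by the worker above
response_attention_mask = get_eos_mask(response_id=response,
                                       eos_token=eos_token_id,
                                       dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)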
@@ -135,9 +135,9 @@ class ActorRolloutRefWorker(MegatronWorker):
                               enable_gradient_checkpointing=False):
         from verl.utils.megatron.optimizer import get_megatron_optimizer
         from megatron.core.models.gpt.gpt_model import ModelType
-        from verl.utils.model import print_model_size, update_model_config
+        from verl.utils.model import print_model_size, update_model_config, get_generation_config
         from verl.utils.megatron_utils import get_model, init_megatron_optim_config
-        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GenerationConfig
 
         # Step 1: initialize the tokenizer
         local_path = copy_local_path_from_hdfs(model_path)
@@ -146,6 +146,8 @@ class ActorRolloutRefWorker(MegatronWorker):
         # Step 2: get the actor_model_config
         actor_model_config = AutoConfig.from_pretrained(local_path)
 
+        self.generation_config = get_generation_config(local_path)
+
         override_config_kwargs = {
             'bos_token_id': self.tokenizer.bos_token_id,
             'eos_token_id': self.tokenizer.eos_token_id,
@@ -352,7 +354,14 @@ class ActorRolloutRefWorker(MegatronWorker):
         assert self._is_rollout
 
         prompts.batch = prompts.batch.cuda()
-        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+        meta_info = {
+            'eos_token_id':
+                self.generation_config.eos_token_id
+                if self.generation_config is not None else self.tokenizer.eos_token_id,
+            'pad_token_id':
+                self.generation_config.pad_token_id
+                if self.generation_config is not None else self.tokenizer.pad_token_id,
+        }
         prompts.meta_info.update(meta_info)
         with self.sharding_manager:
             log_gpu_memory_usage('After entering sharding manager', logger=logger)
...
@@ -57,6 +57,8 @@ class NaiveRollout(BaseRollout):
         # used to construct attention_mask
         eos_token_id = prompts.meta_info['eos_token_id']
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
 
         batch_size = idx.size(0)
         prompt_length = idx.size(1)
@@ -90,7 +92,8 @@ class NaiveRollout(BaseRollout):
             attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)
 
-            prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
+            for token_id in eos_token_id:
+                prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool())
             prev_attention_mask.to(attention_mask.dtype)
 
             position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)
...
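
For reference, the per-id loop above could also be expressed in one step with torch.isin (a sketch, not part of this PR):

# mark generation as finished once any of the EOS ids has been produced
eos_ids = torch.tensor(eos_token_id, device=idx_next.device)
prev_attention_mask = torch.logical_and(~torch.isin(idx_next, eos_ids), prev_attention_mask.bool())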