Unverified commit e8eb9e4e by Guangming Sheng, committed by GitHub

[misc][Long Context] feat: support ulysses for long context training (#109)

parent 594d80ad
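Background for reviewers (not part of the diff): DeepSpeed-Ulysses style sequence parallelism keeps activations sharded along the sequence dimension and switches to a head-sharded layout only for attention, using an all-to-all in each direction. The sketch below illustrates the sequence-to-head exchange, assuming a sequence-parallel process group is already initialized; the helper name `seq_to_head_all2all` is illustrative and is not the API added by this PR (the PR uses `gather_seq_scatter_heads` / `gather_heads_scatter_seq` from `verl.utils.ulysses`).

```python
# Illustrative sketch of the Ulysses sequence<->head all-to-all; assumes torch.distributed
# is initialized and `group` is the sequence-parallel process group.
import torch
import torch.distributed as dist

def seq_to_head_all2all(x: torch.Tensor, group=None) -> torch.Tensor:
    """(bsz, n_head, seq_len/sp, head_dim) -> (bsz, n_head/sp, seq_len, head_dim)."""
    sp = dist.get_world_size(group=group)
    bsz, n_head, local_seq, head_dim = x.shape
    assert n_head % sp == 0, "attention heads must be divisible by the sp size"
    hc = n_head // sp
    # chunk the heads so that head chunk j (dim 0) is sent to rank j
    x = x.reshape(bsz, sp, hc, local_seq, head_dim).permute(1, 0, 2, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    # out[j] now holds this rank's head chunk for sequence chunk j; stitch the full sequence back
    return out.permute(1, 2, 0, 3, 4).reshape(bsz, hc, sp * local_seq, head_dim)
```

The head-to-sequence exchange applied after attention is the mirror image of this transform.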
@@ -56,3 +56,7 @@ jobs:
run: |
ray stop --force
bash tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh
- name: Running gsm8k e2e with rmpad using model rm and ulysses sp=2
run: |
ray stop --force
bash tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh
set -x
python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=64 \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.ulysses_sequence_parallel_size=2 \
critic.model.use_remove_padding=True \
critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=64 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='deepseek_llm_7b_function_rm_sp2' \
trainer.n_gpus_per_node=8 \
+trainer.val_before_train=False \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_epochs=15 $@
@@ -27,7 +27,8 @@ actor_rollout_ref:
clip_ratio: 0.2
entropy_coeff: 0.0
ppo_epochs: 1
-shuffle: True
+shuffle: False
+ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-4
fsdp_config:
@@ -44,6 +45,7 @@ actor_rollout_ref:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
micro_batch_size: 200
+ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
name: hf
temperature: 1.0
@@ -59,7 +61,7 @@ actor_rollout_ref:
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
-tensor_model_parallel_size: 2
+tensor_model_parallel_size: 1
max_num_batched_tokens: 8192
max_num_seqs: 1024
log_prob_micro_batch_size: 200
@@ -89,6 +91,7 @@ critic:
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: 200
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+ulysses_sequence_parallel_size: 1 # sp size
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
@@ -112,6 +115,7 @@ reward_model:
min_num_params: 0
micro_batch_size: 8
max_length: null
+ulysses_sequence_parallel_size: 1 # sp size
algorithm:
gamma: 1.0
......
@@ -34,6 +34,7 @@ python3 -m verl.trainer.main_ppo \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
+trainer.val_before_train=False \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='qwen_e2e_ci_function_rm' \
trainer.n_gpus_per_node=8 \
......
@@ -42,6 +42,7 @@ python3 -m verl.trainer.main_ppo \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
+trainer.val_before_train=False \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \
trainer.n_gpus_per_node=8 \
......
@@ -41,6 +41,7 @@ python3 -m verl.trainer.main_ppo \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
+trainer.val_before_train=False \
trainer.logger=['console'] \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \
......
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS # vllm + qwen2 with flash_attn has some issues
python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.ulysses_sequence_parallel_size=2 \
critic.model.use_remove_padding=True \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.ulysses_sequence_parallel_size=2 \
reward_model.model.path=Qwen/Qwen2.5-0.5B \
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
+trainer.val_before_train=False \
trainer.logger=['console'] \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm_sp2' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
@@ -28,7 +28,8 @@ def check_model_support_rmpad(model_type: str):
assert isinstance(model_type, str)
if not model_type in _REOVEPAD_MODELS.keys():
raise ValueError(f"Model architecture {model_type} is not supported for now. "
-                 f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}")
+                 f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}."
+                 f"Please set `use_remove_padding=False` in the model config.")

# Supported models in Megatron-LM
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from typing import Optional, List, Union, Tuple, Unpack, Callable
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
logger = logging.get_logger(__name__)
def llama_flash_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
adapted from transformers 4.47.1
"""
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# trade off: repeat first and then all to all
# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
# (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2) # full seq length
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory.")
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
full_q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
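A note on the patched forward above: because the all-to-all scatters heads across the sequence-parallel ranks (and `repeat_kv` is not applied before the exchange), both the query heads and the key/value heads need to split evenly across `ulysses_sequence_parallel_size`. A hypothetical pre-flight check along these lines can catch an incompatible pairing before training starts; the function name below is illustrative, not part of this patch.

```python
# Hypothetical sanity check: verify a model's head counts are compatible with the sp size.
from transformers import AutoConfig

def check_ulysses_compatible(model_name_or_path: str, sp_size: int) -> None:
    cfg = AutoConfig.from_pretrained(model_name_or_path)
    num_kv_heads = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads)
    assert cfg.num_attention_heads % sp_size == 0, "query heads not divisible by sp size"
    assert num_kv_heads % sp_size == 0, "key/value heads not divisible by sp size"

# e.g. check_ulysses_compatible("deepseek-ai/deepseek-llm-7b-chat", sp_size=2)
```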
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply monkey-patch function to models
"""
#### Open Source Models
#### transformers version < 4.48
def apply_monkey_patch_to_llama():
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from verl.models.transformers.llama import llama_flash_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward
def apply_monkey_patch_to_qwen2():
from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
_PATCH_NAME_TO_FUNC = {
'llama': apply_monkey_patch_to_llama,
'qwen2': apply_monkey_patch_to_qwen2,
}
from transformers import PretrainedConfig
def apply_monkey_patch(config: PretrainedConfig, verbose=True):
if not is_transformers_version_in_range("4.45.0", "4.47.1"):
raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
"Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature.")
success_apply_monkey_patch = False
if config.model_type in _PATCH_NAME_TO_FUNC:
_PATCH_NAME_TO_FUNC[config.model_type]()
success_apply_monkey_patch = True
if success_apply_monkey_patch and verbose:
print(f'Applying monkey patch to model {config.model_type}')
elif not success_apply_monkey_patch:
raise NotImplementedError(f'Ulysses for model {config.model_type} is not implemented, \
please set `ulysses_sequence_parallel_size=1`')
return success_apply_monkey_patch
from functools import lru_cache
from packaging import version
import importlib.metadata
@lru_cache()
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
try:
# Get the installed version of the transformers library
transformers_version = importlib.metadata.version("transformers")
except importlib.metadata.PackageNotFoundError:
raise ModuleNotFoundError("The `transformers` package is not installed.")
# Check if the version is within the specified range
return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)
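A usage sketch for the patch entry point above, assuming the standard Hugging Face loading path; the exact call site inside the verl FSDP workers may differ.

```python
# Illustrative usage: apply the Ulysses monkey patch before instantiating the model,
# so every flash-attention layer picks up the sequence-parallel forward.
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "Qwen/Qwen2.5-0.5B"
config = AutoConfig.from_pretrained(model_path)
apply_monkey_patch(config)  # raises if the model type or transformers version is unsupported
model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2")
```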
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from typing import Optional, Tuple
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
logger = logging.get_logger(__name__)
def qwen2_flash_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
# (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2) # full seq length
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory.")
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and
self.layer_idx >= self.config.max_window_layers):
sliding_window = self.config.sliding_window
else:
sliding_window = None
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
full_q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
# use full_q_len to reshape
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
@@ -26,7 +26,8 @@ actor_rollout_ref:
clip_ratio: 0.2
entropy_coeff: 0.001
ppo_epochs: 1
-shuffle: True
+shuffle: False
+ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
@@ -47,6 +48,7 @@ actor_rollout_ref:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size: 128
+ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
name: vllm
temperature: 1.0
@@ -95,6 +97,7 @@ critic:
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: 64
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+ulysses_sequence_parallel_size: 1 # sp size
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
@@ -113,6 +116,7 @@ reward_model:
param_offload: False
micro_batch_size: 64
max_length: null
+ulysses_sequence_parallel_size: 1 # sp size
algorithm:
gamma: 1.0
......
@@ -150,8 +150,8 @@ def create_random_mask(input_ids: torch.Tensor,
Returns:
"""
-assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token < 1.
+assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.
-assert max_ratio_of_left_padding > 0 and max_ratio_of_left_padding < 1.
+assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.
assert min_ratio_of_valid_token <= max_ratio_of_valid_token
batch_size, sequence_length = input_ids.shape
......
@@ -25,6 +25,8 @@ from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, log_probs_from_logits_all_rmpad
+from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
+import verl.utils.torch_functional as verl_F
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis

@@ -45,8 +47,15 @@ class DataParallelPPOActor(BasePPOActor):
        self.actor_optimizer = actor_optimizer
        self.use_remove_padding = self.config.get('use_remove_padding', False)
        print(f'Actor use_remove_padding={self.use_remove_padding}')
+       self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
+       self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

    def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
+       """
+       Returns:
+           entropy: # (bs, response_len)
+           log_probs: # (bs, response_len)
+       """
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            input_ids = micro_batch['input_ids']
@@ -62,29 +71,68 @@ class DataParallelPPOActor(BasePPOActor):
                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                      indices).transpose(0, 1)

+               # for compute the log_prob
+               input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

+               # pad and slice the inputs if sp > 1
+               if self.use_ulysses_sp:
+                   input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
+                                                                                                position_ids_rmpad, \
+                                                                                                sp_size=self.ulysses_sequence_parallel_size)
+                   input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
+                                                                               self.ulysses_sequence_parallel_size)

+               input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.actor_module(input_ids=input_ids_rmpad,
                                           attention_mask=None,
                                           position_ids=position_ids_rmpad,
                                           use_cache=False)  # prevent model thinks we are generating
                logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
                logits_rmpad /= temperature
-               log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
-                                                           logits_rmpad=logits_rmpad,
-                                                           indices=indices,
-                                                           batch_size=batch_size,
-                                                           seqlen=seqlen,
-                                                           response_length=response_length)  # (batch, seqlen)
-               logits = logits_rmpad
-           else:
+               # compute entropy
+               entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)

+               # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
+               log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)

+               # gather log_prob if sp > 1
+               if self.use_ulysses_sp:
+                   # gather and unpad for the ulysses sp
+                   log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
+                   entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
+                                                           gather_dim=0,
+                                                           unpad_dim=0,
+                                                           padding_size=pad_size)
+               # pad back to (bsz, seqlen)
+               full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1),
+                                        indices=indices,
+                                        batch=batch_size,
+                                        seqlen=seqlen)
+               full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1),
+                                          indices=indices,
+                                          batch=batch_size,
+                                          seqlen=seqlen)

+               # only return response part:
+               entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
+               log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)

+           else:  # not using rmpad and no ulysses sp
                output = self.actor_module(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           position_ids=position_ids,
                                           use_cache=False)  # prevent model thinks we are generating
                logits = output.logits / temperature
-               logits = logits[:, -response_length - 1:-1]
+               logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)
                log_probs = logprobs_from_logits(logits, micro_batch['responses'])
-           return logits, log_probs
+               entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)

+           return entropy, log_probs

    def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
        """Make minibatch iterator for updating the actor
@@ -94,7 +142,7 @@ class DataParallelPPOActor(BasePPOActor):
        data = data.select(batch_keys=select_keys)
        return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
                                  epochs=self.config.ppo_epochs,
-                                 dataloader_kwargs={'shuffle': self.config.shuffle})
+                                 dataloader_kwargs={'shuffle': False})  # TODO: hardcode to False

    def _optimizer_step(self):
        assert self.config.grad_clip is not None
@@ -170,23 +218,17 @@ class DataParallelPPOActor(BasePPOActor):
                clip_ratio = self.config.clip_ratio
                entropy_coeff = self.config.entropy_coeff

-               logits, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)
+               # all return: (bsz, response_length)
+               entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)

                pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
                                                                              log_prob=log_prob,
                                                                              advantages=advantages,
                                                                              eos_mask=response_mask,
                                                                              cliprange=clip_ratio)
-               # compute entropy loss
-               if self.use_remove_padding:
-                   full_response_mask = attention_mask.clone()
-                   full_response_mask[:, :-response_length] = 0  # set the prompt part to zero
-                   full_response_mask_rmpad, *_ = unpad_input(full_response_mask.unsqueeze(-1),
-                                                              attention_mask=attention_mask)
-                   full_response_mask_rmpad = full_response_mask_rmpad.squeeze(-1)  # (total_nnz)
-                   entropy_loss = core_algos.compute_entropy_loss(logits, full_response_mask_rmpad)  # (total_nnz,)
-               else:
-                   entropy_loss = core_algos.compute_entropy_loss(logits, response_mask)
+               # compute entropy loss from entropy
+               entropy_loss = verl_F.masked_mean(entropy, response_mask)

                # compute policy loss
                policy_loss = pg_loss - entropy_loss * entropy_coeff
......
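For reference, the entropy path introduced above replaces `core_algos.compute_entropy_loss` with a per-token entropy returned by `_forward_micro_batch` and a masked mean over response tokens. Below is a minimal sketch of the two helpers, assuming `verl_F.entropy_from_logits` computes the categorical entropy of the softmax distribution; the real implementations live in `verl/utils/torch_functional.py` and may differ in detail.

```python
# Minimal sketch of the helpers used by the new entropy path (assumed semantics).
import torch
import torch.nn.functional as F

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # H = logsumexp(logits) - sum(softmax(logits) * logits), avoids materializing log-probs
    probs = F.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)

def masked_mean(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # average only over valid (response) tokens
    return (values * mask).sum() / (mask.sum() + eps)
```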
@@ -28,6 +28,7 @@ from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
+from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis

@@ -46,6 +47,8 @@ class DataParallelPPOCritic(BasePPOCritic):
        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
+       self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)

    def _forward_micro_batch(self, micro_batch):
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
@@ -62,6 +65,13 @@ class DataParallelPPOCritic(BasePPOCritic):
                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                      indices).transpose(0, 1)

+               # pad and slice the inputs if sp > 1
+               if self.ulysses_sequence_parallel_size > 1:
+                   input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
+                                                                                                position_ids_rmpad, \
+                                                                                                sp_size=self.ulysses_sequence_parallel_size)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.critic_module(input_ids=input_ids_rmpad,
                                            attention_mask=None,
@@ -70,6 +80,13 @@ class DataParallelPPOCritic(BasePPOCritic):
                values_rmpad = output.logits
                values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)

+               # gather output if sp > 1
+               if self.ulysses_sequence_parallel_size > 1:
+                   values_rmpad = gather_outpus_and_unpad(values_rmpad,
+                                                          gather_dim=0,
+                                                          unpad_dim=0,
+                                                          padding_size=pad_size)

                # pad it back
                values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
                values = values[:, -response_length - 1:-1]
@@ -87,7 +104,7 @@ class DataParallelPPOCritic(BasePPOCritic):
        data = data.select(batch_keys=select_keys)
        return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
                                  epochs=self.config.ppo_epochs,
-                                 dataloader_kwargs={'shuffle': self.config.shuffle})
+                                 dataloader_kwargs={'shuffle': False})  # TODO: hardcode to False

    def _optimizer_step(self):
        assert self.config.grad_clip is not None
......
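The critic follows the same pad-slice-gather pattern as the actor. Below is a sketch of what `ulysses_pad_and_slice_inputs` is expected to do, inferred from how it is used above (the name `pad_and_slice` and the details are illustrative): pad the packed token dimension to a multiple of `sp_size`, then keep only the local rank's contiguous slice; `gather_outpus_and_unpad` reverses the operation on the outputs.

```python
# Illustrative sketch of padding and slicing packed (rmpad) inputs for Ulysses SP.
import torch
import torch.nn.functional as F

def pad_and_slice(ids: torch.Tensor, sp_size: int, sp_rank: int):
    # ids: (1, total_nnz) packed token ids with padding already removed
    total_nnz = ids.size(1)
    pad_size = (-total_nnz) % sp_size          # extra tokens needed to divide evenly
    if pad_size > 0:
        ids = F.pad(ids, (0, pad_size), value=0)
    local = ids.size(1) // sp_size
    return ids[:, sp_rank * local:(sp_rank + 1) * local], pad_size
```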
@@ -15,6 +15,7 @@
from verl.utils.import_utils import is_vllm_available, is_megatron_core_available
from .base import BaseShardingManager
+from .fsdp_ulysses import FSDPUlyssesShardingManager

AllGatherPPModel = None
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a sharding manager that supports data resharding when using FSDP + Ulysses sequence parallelism
"""
from typing import Optional
from .base import BaseShardingManager
import random
from torch.distributed.device_mesh import DeviceMesh
from verl.utils.torch_functional import allgather_dict_tensors
from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
import numpy as np
import torch
import torch.distributed
from verl import DataProto
class FSDPUlyssesShardingManager(BaseShardingManager):
"""
Sharding manager to support data resharding when using FSDP + Ulysses
"""
def __init__(self, device_mesh: DeviceMesh):
super().__init__()
self.device_mesh = device_mesh
self.seed_offset = 12345
def __enter__(self):
if self.device_mesh is not None:
# We have a global SP group
# so we have to change to use model-specific sp group
self.prev_sp_group = get_ulysses_sequence_parallel_group()
set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group())
# TODO: check how to set seed for each model
def __exit__(self, exc_type, exc_value, traceback):
# restore random states
if self.device_mesh is not None:
# revert to previous sp group
set_ulysses_sequence_parallel_group(self.prev_sp_group)
# TODO: check how to set seed for each model
def preprocess_data(self, data: DataProto) -> DataProto:
"""
AllGather data from the SP region.
The data is first sharded along the FSDP dimension (DP_COMPUTE dispatch); in Ulysses,
every rank of an SP group must see the same batch.
"""
if self.device_mesh is not None:
sp_size = self.device_mesh['sp'].size()
group = self.device_mesh['sp'].get_group()
prev_device = data.batch.device
data.batch = data.batch.cuda(device=torch.cuda.current_device())
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=sp_size, group=group, dim=0)
data.batch = data.batch.to(prev_device)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(sp_size)]
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
data.non_tensor_batch = {
k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch
}
return data
def postprocess_data(self, data: DataProto) -> DataProto:
"""
Split the data to follow FSDP partition
"""
if self.device_mesh is not None:
sp_size = self.device_mesh['sp'].size()
sp_rank = self.device_mesh['sp'].get_local_rank()
data = data.chunk(chunks=sp_size)[sp_rank]
return data
\ No newline at end of file
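An illustrative construction of the device mesh this manager expects, based on the 'dp'/'sp' mesh-dimension names used here; the mesh shape and the entry point in the verl workers are assumptions, not part of the diff.

```python
# Illustrative: build a 2-D (dp, sp) device mesh and hand it to the sharding manager.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

sp_size = 2  # ulysses_sequence_parallel_size
dp_size = dist.get_world_size() // sp_size
mesh = init_device_mesh("cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"))
sharding_manager = FSDPUlyssesShardingManager(device_mesh=mesh)
```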
@@ -17,6 +17,7 @@ import logging
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
+from torch.distributed.device_mesh import DeviceMesh
from verl.third_party.vllm import LLM
from verl.third_party.vllm import parallel_state as vllm_ps
@@ -32,10 +33,16 @@ logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))

class FSDPVLLMShardingManager(BaseShardingManager):

-   def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_params: bool = False):
+   def __init__(self,
+                module: FSDP,
+                inference_engine: LLM,
+                model_config,
+                full_params: bool = False,
+                device_mesh: DeviceMesh = None):
        self.module = module
        self.inference_engine = inference_engine
        self.model_config = model_config
+       self.device_mesh = device_mesh
        # Full params
        self.full_params = full_params
@@ -48,6 +55,17 @@ class FSDPVLLMShardingManager(BaseShardingManager):
                                           state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                           state_dict_config=ShardedStateDictConfig())

+       # Note that torch_random_states may be different on each dp rank
+       self.torch_random_states = torch.cuda.get_rng_state()
+       # get a random rng states
+       if self.device_mesh is not None:
+           gen_dp_rank = self.device_mesh['dp'].get_local_rank()
+           torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
+           self.gen_random_states = torch.cuda.get_rng_state()
+           torch.cuda.set_rng_state(self.torch_random_states)
+       else:
+           self.gen_random_states = None

    def __enter__(self):
        log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
        params = self.module.state_dict()
@@ -67,6 +85,11 @@ class FSDPVLLMShardingManager(BaseShardingManager):
        # if torch.distributed.get_rank() == 0:
        #     print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')

+       # important: need to manually set the random states of each tp to be identical.
+       if self.device_mesh is not None:
+           self.torch_random_states = torch.cuda.get_rng_state()
+           torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger)
        self.inference_engine.offload_model_weights()
@@ -81,6 +104,11 @@ class FSDPVLLMShardingManager(BaseShardingManager):
        # add empty cache after each compute
        torch.cuda.empty_cache()

+       # restore random states
+       if self.device_mesh is not None:
+           self.gen_random_states = torch.cuda.get_rng_state()
+           torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        # TODO: Current impl doesn't consider FSDP with torch micro-dp
        data.batch = allgather_dict_tensors(data.batch.contiguous(),
......