Unverified commit d36422be by zhou fan, committed by GitHub

feat: add support for ulysses sequence parallel for transformers >= 4.48 (#357)

Closes #312.

Add support for Ulysses SP for transformers >= 4.48.

I've tested transformers 4.45.0, 4.46.0, 4.47.0, 4.48.0, and 4.49.0 with sp=2, using the following script in my local environment:
```bash
#!/bin/bash

set -ex
VERSIONS=("4.45.0" "4.46.0" "4.47.0" "4.48.0" "4.49.0")

for version in "${VERSIONS[@]}"; do
    echo "Testing with Transformers version ${version}"
    echo "----------------------------------------"
    
    pip install "transformers==${version}"
    
    PYTHONPATH=./ torchrun --nproc_per_node=2 tests/model/test_transformers_ulysses.py
    
    echo "----------------------------------------"
    echo "Completed testing for version ${version}"
    echo ""
done
```
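For reference, here is a rough per-rank sketch of what the Ulysses SP test goes through before the forward-pass comparison. The mesh layout and the exact signatures of the verl helpers (`initialize_global_process_group`, `set_ulysses_sequence_parallel_group`) are assumptions for illustration, not copied from the test itself.

```python
# Hedged sketch of the per-rank test setup; helper signatures are assumed.
import torch
import torch.distributed
from torch.distributed import init_device_mesh
from transformers import AutoModelForCausalLM, LlamaConfig

from verl.utils.distributed import initialize_global_process_group
from verl.utils.ulysses import set_ulysses_sequence_parallel_group
from verl.models.transformers.monkey_patch import apply_monkey_patch_to_llama

sp_size = 2
initialize_global_process_group()  # torchrun provides RANK/WORLD_SIZE env vars
dp_size = torch.distributed.get_world_size() // sp_size
mesh = init_device_mesh("cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"))
set_ulysses_sequence_parallel_group(mesh["sp"].get_group())

apply_monkey_patch_to_llama()  # patch before the model is instantiated
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(
        config=LlamaConfig(num_hidden_layers=2),
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
# The test then runs a forward pass on sequence-sliced inputs and compares the
# gathered logits against a plain (non-SP) forward pass.
```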
parent e53dcdb9
@@ -52,3 +52,22 @@ jobs:
         run: |
           pip3 install hf_transfer
           torchrun --nproc_per_node=8 tests/checkpoint/test_fsdp_ckpt.py
+      - name: Running transformers ulysses tests on 8 L20 GPUs + latest transformers
+        run: |
+          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
+      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.48.0
+        run: |
+          pip3 install transformers==4.48.0
+          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
+      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.47.0
+        run: |
+          pip3 install transformers==4.47.0
+          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
+      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.46.0
+        run: |
+          pip3 install transformers==4.46.0
+          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
+      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.45.0
+        run: |
+          pip3 install transformers==4.45.0
+          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
@@ -18,28 +18,22 @@ import torch.distributed
 from torch.distributed import init_device_mesh
 from verl.utils.distributed import initialize_global_process_group
 from verl.utils.model import create_random_mask, compute_position_id_with_mask
-from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
 from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
 from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
 from verl.workers.sharding_manager import FSDPUlyssesShardingManager
-from verl.models.transformers.llama import llama_flash_attn_forward
-from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
 from verl.protocol import DataProto
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange
-from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
-from transformers.models.llama.modeling_llama import LlamaFlashAttention2
-from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
-from transformers import AutoModelForCausalLM, AutoModelForTokenClassification
+from flash_attn.bert_padding import unpad_input, index_first_axis, rearrange
+from transformers import LlamaConfig, Qwen2Config
+from transformers import AutoModelForCausalLM
+from verl.models.transformers.monkey_patch import apply_monkey_patch_to_llama, apply_monkey_patch_to_qwen2

 # TODO(sgm): add more models for test
 # we only need one scale for each model
 test_configs = {
-    'llama': (LlamaConfig(num_hidden_layers=2), LlamaFlashAttention2),
-    'qwen2': (Qwen2Config(num_hidden_layers=2), Qwen2FlashAttention2)
+    'llama': (LlamaConfig(num_hidden_layers=2), apply_monkey_patch_to_llama),
+    'qwen2': (Qwen2Config(num_hidden_layers=2), apply_monkey_patch_to_qwen2)
 }
-patches = {'llama': llama_flash_attn_forward, 'qwen2': qwen2_flash_attn_forward}


 def sync_model_parameters_global(layer):
     # synchronize weights
@@ -61,9 +55,9 @@ def test_hf_casual_fwd():
     seqlen = 128
     response_length = 127

-    for model_name, (config, attn) in test_configs.items():
+    for model_name, (config, apply_monkey_patch) in test_configs.items():
         # patch before load
-        attn.forward = patches[model_name]
+        apply_monkey_patch()
         with torch.device('cuda'):
             model = AutoModelForCausalLM.from_config(config=config,
                                                      torch_dtype=torch.bfloat16,
@@ -139,9 +133,9 @@ def test_hf_casual_fwd_bwd():
     seqlen = 128
     response_length = 127

-    for model_name, (config, attn) in test_configs.items():
+    for model_name, (config, apply_monkey_patch) in test_configs.items():
         # patch before load
-        attn.forward = patches[model_name]
+        apply_monkey_patch()
         with torch.device('cuda'):
             model = AutoModelForCausalLM.from_config(config=config,
                                                      torch_dtype=torch.bfloat16,
...
@@ -13,14 +13,14 @@
 # limitations under the License.

 import torch
-from typing import Optional, List, Union, Tuple, Callable
+from typing import Optional, Tuple, Callable
 import sys

 if sys.version_info >= (3, 11):
     from typing import Unpack
 else:
     from typing_extensions import Unpack

-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 from transformers.cache_utils import Cache
 from transformers.utils import logging
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
@@ -41,8 +41,10 @@ def llama_flash_attn_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
-    adapt from transformers 4.47.1
+    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
+
+    NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1].
     """
     output_attentions = False

     bsz, q_len, _ = hidden_states.size()
@@ -147,3 +149,73 @@ def llama_flash_attn_forward(
     attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def llama_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    attention_mask: Optional[torch.Tensor],
+    past_key_value: Optional[Cache] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """
+    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
+
+    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
+    """
+    from transformers.models.llama.modeling_llama import eager_attention_forward
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    bsz, q_len, _ = hidden_states.shape
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+    ########## AlltoAll for Ulysses ##########
+    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
+
+    if ulysses_sp_size > 1:
+        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
+        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
+        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
+
+    full_q_len = query_states.size(2)
+
+    cos, sin = position_embeddings
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_value is not None:
+        # sin and cos are specific to RoPE models; cache_position needed for the static cache
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    attention_interface: Callable = eager_attention_forward
+    if self.config._attn_implementation != "eager":
+        if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+            logger.warning_once(
+                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+        else:
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+    attn_output, attn_weights = attention_interface(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        dropout=0.0 if not self.training else self.attention_dropout,
+        scaling=self.scaling,
+        **kwargs,
+    )
+
+    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
+
+    ########## AlltoAll for Ulysses ##########
+    if ulysses_sp_size > 1:
+        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
+
+    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+    attn_output = self.o_proj(attn_output)
+    return attn_output, attn_weights
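The key idea in `llama_attn_forward` above is the pair of all-to-alls around attention: before attention each rank trades its sequence shard for a head shard, and afterwards it trades back. Below is a minimal, single-process, shape-level sketch of that layout swap; it only simulates what `gather_seq_scatter_heads` / `gather_heads_scatter_seq` do across ranks and is not verl's actual communication code.

```python
# Shape-only simulation of the Ulysses all-to-all with sp=2; no real communication.
import torch

sp = 2
bsz, n_heads, seqlen, head_dim = 1, 8, 128, 64

# Before attention, each rank holds all heads but only seqlen/sp tokens:
# (bsz, n_heads, seqlen/sp, head_dim)
per_rank = [torch.randn(bsz, n_heads, seqlen // sp, head_dim) for _ in range(sp)]

# gather_seq_scatter_heads(x, seq_dim=2, head_dim=1):
# after the all-to-all, each rank holds n_heads/sp heads over the full sequence.
full_seq = torch.cat(per_rank, dim=2)            # gather along the sequence dim
per_rank_attn = list(full_seq.chunk(sp, dim=1))  # scatter along the head dim
assert per_rank_attn[0].shape == (bsz, n_heads // sp, seqlen, head_dim)

# gather_heads_scatter_seq reverses the swap after attention, so each rank again
# holds the full head dimension over its local sequence shard.
```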
@@ -15,20 +15,27 @@
 Apply monkey-patch function to models
 """

-#### Open Source Models
-#### transformers version < 4.48
 def apply_monkey_patch_to_llama():
-    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
-    from verl.models.transformers.llama import llama_flash_attn_forward
-    LlamaFlashAttention2.forward = llama_flash_attn_forward
+    if is_transformers_version_in_range("4.45.0", "4.47.1"):
+        from transformers.models.llama.modeling_llama import LlamaFlashAttention2
+        from verl.models.transformers.llama import llama_flash_attn_forward
+        LlamaFlashAttention2.forward = llama_flash_attn_forward
+    elif is_transformers_version_in_range("4.48.0", "4.49.0"):
+        from transformers.models.llama.modeling_llama import LlamaAttention
+        from verl.models.transformers.llama import llama_attn_forward
+        LlamaAttention.forward = llama_attn_forward


 def apply_monkey_patch_to_qwen2():
-    from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
-    from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
-    Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
+    if is_transformers_version_in_range("4.45.0", "4.47.1"):
+        from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
+        from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
+        Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
+    elif is_transformers_version_in_range("4.48.0", "4.49.0"):
+        from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
+        from verl.models.transformers.qwen2 import qwen2_attn_forward
+        Qwen2Attention.forward = qwen2_attn_forward


 _PATCH_NAME_TO_FUNC = {
@@ -40,9 +47,9 @@ from transformers import PretrainedConfig

 def apply_monkey_patch(config: PretrainedConfig, verbose=True):
-    if not is_transformers_version_in_range("4.45.0", "4.47.1"):
+    if not is_transformers_version_in_range("4.45.0", "4.49.0"):
         raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
-                             "Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature.")
+                             "Please install a version between 4.45.0 and 4.49.0 to use this ulysses feature.")

     success_apply_monkey_patch = False
     if config.model_type in _PATCH_NAME_TO_FUNC:
         _PATCH_NAME_TO_FUNC[config.model_type]()
...
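With the widened version gate above, downstream code can keep calling `apply_monkey_patch` unchanged; the dispatch now covers both attention implementations. A small usage sketch (the call site is illustrative and not part of this diff):

```python
# Illustrative call site for the patched dispatcher; not part of this diff.
from transformers import LlamaConfig
from verl.models.transformers.monkey_patch import apply_monkey_patch

config = LlamaConfig(num_hidden_layers=2)
# Raises AssertionError unless transformers is within [4.45.0, 4.49.0].
# For < 4.48 it patches LlamaFlashAttention2.forward; for >= 4.48 it patches
# LlamaAttention.forward, as selected by is_transformers_version_in_range.
apply_monkey_patch(config)
```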
@@ -13,12 +13,13 @@
 # limitations under the License.

 import torch
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Callable
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 from transformers.cache_utils import Cache
 from transformers.utils import logging
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.processing_utils import Unpack

 from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size

 logger = logging.get_logger(__name__)
@@ -35,6 +36,11 @@ def qwen2_flash_attn_forward(
     cache_position: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
 ):
+    """
+    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
+
+    NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1.
+    """
     bsz, q_len, _ = hidden_states.size()

     query_states = self.q_proj(hidden_states)
@@ -134,3 +140,82 @@ def qwen2_flash_attn_forward(
     attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def qwen2_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    attention_mask: Optional[torch.Tensor],
+    past_key_value: Optional[Cache] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """
+    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
+
+    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
+    """
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    bsz, q_len, _ = hidden_states.shape
+    hidden_shape = (bsz, q_len, -1, self.head_dim)
+
+    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+    ########## AlltoAll for Ulysses ##########
+    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
+
+    if ulysses_sp_size > 1:
+        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
+        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
+        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
+        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
+
+    full_q_len = query_states.size(2)
+
+    cos, sin = position_embeddings
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_value is not None:
+        # sin and cos are specific to RoPE models; cache_position needed for the static cache
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    sliding_window = None
+    if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and
+            self.layer_idx >= self.config.max_window_layers):
+        sliding_window = self.config.sliding_window
+
+    from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward
+
+    attention_interface: Callable = eager_attention_forward
+    if self.config._attn_implementation != "eager":
+        if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+            logger.warning_once(
+                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+        else:
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+    attn_output, attn_weights = attention_interface(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        dropout=0.0 if not self.training else self.attention_dropout,
+        scaling=self.scaling,
+        sliding_window=sliding_window,  # main diff with Llama
+        **kwargs,
+    )
+
+    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
+
+    ########## AlltoAll for Ulysses ##########
+    if ulysses_sp_size > 1:
+        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)
+        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
+
+    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+    attn_output = self.o_proj(attn_output)
+    return attn_output, attn_weights