Unverified Commit d36422be by zhou fan Committed by GitHub

feat: add support for ulysses sequence parallel for transformers >= 4.48 (#357)

Closes #312

Add support for Ulysses sequence parallelism (SP) for transformers >= 4.48.

I've tested transformers 4.45.0, 4.46.0, 4.47.0, 4.48.0, and 4.49.0
with sp=2, using the following script in my local environment:
```bash
#!/bin/bash

set -ex
VERSIONS=("4.45.0" "4.46.0" "4.47.0" "4.48.0" "4.49.0")

for version in "${VERSIONS[@]}"; do
    echo "Testing with Transformers version ${version}"
    echo "----------------------------------------"
    
    pip install "transformers==${version}"
    
    PYTHONPATH=./ torchrun --nproc_per_node=2 tests/model/test_transformers_ulysses.py
    
    echo "----------------------------------------"
    echo "Completed testing for version ${version}"
    echo ""
done
```
parent e53dcdb9
@@ -52,3 +52,22 @@ jobs:
run: |
pip3 install hf_transfer
torchrun --nproc_per_node=8 tests/checkpoint/test_fsdp_ckpt.py
- name: Running transformers ulysses tests on 8 L20 GPUs + latest transformers
run: |
torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
- name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.48.0
run: |
pip3 install transformers==4.48.0
torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
- name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.47.0
run: |
pip3 install transformers==4.47.0
torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
- name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.46.0
run: |
pip3 install transformers==4.46.0
torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
- name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.45.0
run: |
pip3 install transformers==4.45.0
torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
@@ -18,28 +18,22 @@ import torch.distributed
from torch.distributed import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
from verl.workers.sharding_manager import FSDPUlyssesShardingManager
from verl.models.transformers.llama import llama_flash_attn_forward
from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
from verl.protocol import DataProto
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange
from flash_attn.bert_padding import unpad_input, index_first_axis, rearrange
from transformers import LlamaConfig, Qwen2Config
from transformers import AutoModelForCausalLM
from verl.models.transformers.monkey_patch import apply_monkey_patch_to_llama, apply_monkey_patch_to_qwen2
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = {
'llama': (LlamaConfig(num_hidden_layers=2), LlamaFlashAttention2),
'qwen2': (Qwen2Config(num_hidden_layers=2), Qwen2FlashAttention2)
'llama': (LlamaConfig(num_hidden_layers=2), apply_monkey_patch_to_llama),
'qwen2': (Qwen2Config(num_hidden_layers=2), apply_monkey_patch_to_qwen2)
}
patches = {'llama': llama_flash_attn_forward, 'qwen2': qwen2_flash_attn_forward}
def sync_model_parameters_global(layer):
# synchronize weights
@@ -61,9 +55,9 @@ def test_hf_casual_fwd():
seqlen = 128
response_length = 127
for model_name, (config, attn) in test_configs.items():
for model_name, (config, apply_monkey_patch) in test_configs.items():
# patch before load
attn.forward = patches[model_name]
apply_monkey_patch()
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
@@ -139,9 +133,9 @@ def test_hf_casual_fwd_bwd():
seqlen = 128
response_length = 127
for model_name, (config, attn) in test_configs.items():
for model_name, (config, apply_monkey_patch) in test_configs.items():
# patch before load
attn.forward = patches[model_name]
apply_monkey_patch()
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
......
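For context, the updated test pattern applies the version-aware monkey patch before the model is built from its config. A minimal sketch of that flow, mirroring the test above (the `attn_implementation` value is an assumption here):

```python
# Minimal sketch of the updated test pattern: patch attention before the model is built.
# Mirrors the test above; attn_implementation="flash_attention_2" is an assumption.
import torch
from transformers import AutoModelForCausalLM, Qwen2Config
from verl.models.transformers.monkey_patch import apply_monkey_patch_to_qwen2

config = Qwen2Config(num_hidden_layers=2)
apply_monkey_patch_to_qwen2()  # swap in the Ulysses-aware attention forward for this transformers version
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(config=config,
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2")
```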
@@ -13,14 +13,14 @@
# limitations under the License.
import torch
from typing import Optional, List, Union, Tuple, Callable
from typing import Optional, Tuple, Callable
import sys
if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
@@ -41,8 +41,10 @@ def llama_flash_attn_forward(
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
adapt from transformers 4.47.1
"""
Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1].
"""
output_attentions = False
bsz, q_len, _ = hidden_states.size()
@@ -147,3 +149,73 @@ def llama_flash_attn_forward(
attn_weights = None
return attn_output, attn_weights, past_key_value
def llama_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
"""
from transformers.models.llama.modeling_llama import eager_attention_forward
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
bsz, q_len, _ = hidden_states.shape
query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
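The `gather_seq_scatter_heads` / `gather_heads_scatter_seq` helpers used above implement the Ulysses AlltoAll. A minimal sketch of the underlying communication pattern, not verl's exact implementation, assuming an initialized sequence-parallel process group and dimensions divisible by the SP size:

```python
# Sketch of the Ulysses AlltoAll behind gather_seq_scatter_heads / gather_heads_scatter_seq.
# Assumes torch.distributed is initialized, `group` is the Ulysses SP group, and the
# scattered dimension is divisible by the SP world size. Not verl's exact implementation.
import torch
import torch.distributed as dist

def _all_to_all(x: torch.Tensor, scatter_dim: int, gather_dim: int, group=None) -> torch.Tensor:
    sp_size = dist.get_world_size(group)
    # split the dimension to be scattered into one chunk per rank ...
    inputs = [t.contiguous() for t in torch.chunk(x, sp_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(sp_size)]
    dist.all_to_all(outputs, inputs, group=group)
    # ... and concatenate the received chunks along the dimension being gathered
    return torch.cat(outputs, dim=gather_dim)

# (bsz, n_head, seq_len/sp, head_dim) -> (bsz, n_head/sp, seq_len, head_dim)
def gather_seq_scatter_heads_sketch(x, seq_dim, head_dim, group=None):
    return _all_to_all(x, scatter_dim=head_dim, gather_dim=seq_dim, group=group)

# (bsz, seq_len, n_head/sp, head_dim) -> (bsz, seq_len/sp, n_head, head_dim)
def gather_heads_scatter_seq_sketch(x, seq_dim, head_dim, group=None):
    return _all_to_all(x, scatter_dim=seq_dim, gather_dim=head_dim, group=group)
```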
@@ -15,20 +15,27 @@
Apply monkey-patch function to models
"""
#### Open Source Models
#### transformers version < 4.48
def apply_monkey_patch_to_llama():
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from verl.models.transformers.llama import llama_flash_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward
if is_transformers_version_in_range("4.45.0", "4.47.1"):
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from verl.models.transformers.llama import llama_flash_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward
elif is_transformers_version_in_range("4.48.0", "4.49.0"):
from transformers.models.llama.modeling_llama import LlamaAttention
from verl.models.transformers.llama import llama_attn_forward
LlamaAttention.forward = llama_attn_forward
def apply_monkey_patch_to_qwen2():
from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
if is_transformers_version_in_range("4.45.0", "4.47.1"):
from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
elif is_transformers_version_in_range("4.48.0", "4.49.0"):
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
from verl.models.transformers.qwen2 import qwen2_attn_forward
Qwen2Attention.forward = qwen2_attn_forward
_PATCH_NAME_TO_FUNC = {
@@ -40,9 +47,9 @@ from transformers import PretrainedConfig
def apply_monkey_patch(config: PretrainedConfig, verbose=True):
if not is_transformers_version_in_range("4.45.0", "4.47.1"):
if not is_transformers_version_in_range("4.45.0", "4.49.0"):
raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
"Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature.")
"Please install a version between 4.45.0 and 4.49.0 to use this ulysses feature.")
success_apply_monkey_patch = False
if config.model_type in _PATCH_NAME_TO_FUNC:
_PATCH_NAME_TO_FUNC[config.model_type]()
......
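The `is_transformers_version_in_range` helper referenced above is not part of this hunk. A minimal sketch of such a check, assuming `packaging` is available (verl's actual helper may differ):

```python
# Hypothetical sketch of a version-range check like is_transformers_version_in_range;
# verl's actual helper may differ in details.
import transformers
from packaging import version

def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
    """Return True if the installed transformers version lies in [min_version, max_version]."""
    current = version.parse(transformers.__version__)
    return version.parse(min_version) <= current <= version.parse(max_version)
```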
@@ -13,12 +13,13 @@
# limitations under the License.
import torch
from typing import Optional, Tuple
from typing import Optional, Tuple, Callable
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.cache_utils import Cache
from transformers.utils import logging
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.processing_utils import Unpack
from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size
logger = logging.get_logger(__name__)
@@ -35,6 +36,11 @@ def qwen2_flash_attn_forward(
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
"""
Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1.
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
@@ -134,3 +140,82 @@ def qwen2_flash_attn_forward(
attn_weights = None
return attn_output, attn_weights, past_key_value
def qwen2_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
"""
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
bsz, q_len, _ = hidden_states.shape
hidden_shape = (bsz, q_len, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
# (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
sliding_window = None
if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and
self.layer_idx >= self.config.max_window_layers):
sliding_window = self.config.sliding_window
from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=sliding_window, # main diff with Llama
**kwargs,
)
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
# (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
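Taken together, downstream code selects the right patch by model type and installed transformers version through `apply_monkey_patch`. A minimal usage sketch (the model id is illustrative):

```python
# Minimal usage sketch: dispatch the Ulysses patch from the model config, then build the model.
# The model id is illustrative; apply_monkey_patch raises if transformers is outside [4.45.0, 4.49.0].
from transformers import AutoConfig, AutoModelForCausalLM
from verl.models.transformers.monkey_patch import apply_monkey_patch

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
apply_monkey_patch(config)  # picks the llama/qwen2 patch matching the installed version
model = AutoModelForCausalLM.from_config(config, attn_implementation="flash_attention_2")
```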