Unverified commit 3f6d45d9 by Lumeng Wu, committed by GitHub

fix: support transformers==4.50.0 (#704)

https://github.com/volcengine/verl/issues/703
parent 612823ae
@@ -56,6 +56,10 @@ jobs:
       - name: Running transformers ulysses tests on 8 L20 GPUs + latest transformers
         run: |
           torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
+      - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.49.0
+        run: |
+          pip3 install transformers==4.49.0
+          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
       - name: Running transformers ulysses tests on 8 L20 GPUs + transformers 4.48.0
         run: |
           pip3 install transformers==4.48.0
...
@@ -166,7 +166,7 @@ def llama_attn_forward(
     """
     Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
-    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
+    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.
     """
     from transformers.models.llama.modeling_llama import eager_attention_forward
     from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
...
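For context on what the patched llama_attn_forward has to do: Ulysses sequence parallelism shards the sequence dimension across ranks and uses an all-to-all before attention to trade the local sequence shard for a local head shard (and the inverse afterwards), so each rank attends over the full sequence for a subset of heads. Below is a minimal illustrative sketch of that exchange with plain torch.distributed; the helper name and tensor layout are assumptions for illustration, not verl's actual implementation.

import torch
import torch.distributed as dist

def seq_to_head_all_to_all(x: torch.Tensor, sp_group) -> torch.Tensor:
    """Illustrative only: trade a sequence shard for a head shard via all-to-all.

    x: [bsz, seq_len // sp, num_heads, head_dim] on every rank.
    Returns: [bsz, seq_len, num_heads // sp, head_dim].
    """
    sp_size = dist.get_world_size(sp_group)
    bsz, local_seq, num_heads, head_dim = x.shape
    assert num_heads % sp_size == 0, "num_heads must be divisible by the SP group size"
    # Split heads into sp_size groups; group i is sent to rank i.
    x = x.reshape(bsz, local_seq, sp_size, num_heads // sp_size, head_dim)
    x = x.permute(2, 0, 1, 3, 4).contiguous()       # [sp, bsz, local_seq, heads/sp, hd]
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=sp_group)  # swap sequence shards for head shards
    # Each rank now holds every sequence shard for its own subset of heads.
    out = out.permute(1, 0, 2, 3, 4).reshape(bsz, sp_size * local_seq, num_heads // sp_size, head_dim)
    return out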
@@ -21,7 +21,7 @@ def apply_monkey_patch_to_llama():
         from transformers.models.llama.modeling_llama import LlamaFlashAttention2
         from verl.models.transformers.llama import llama_flash_attn_forward
         LlamaFlashAttention2.forward = llama_flash_attn_forward
-    elif is_transformers_version_in_range("4.48.0", "4.49.0"):
+    elif is_transformers_version_in_range("4.48.0", "4.50.0"):
         from transformers.models.llama.modeling_llama import LlamaAttention
         from verl.models.transformers.llama import llama_attn_forward
         LlamaAttention.forward = llama_attn_forward
@@ -32,7 +32,7 @@ def apply_monkey_patch_to_qwen2():
         from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
         from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
         Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
-    elif is_transformers_version_in_range("4.48.0", "4.49.0"):
+    elif is_transformers_version_in_range("4.48.0", "4.50.0"):
         from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
         from verl.models.transformers.qwen2 import qwen2_attn_forward
         Qwen2Attention.forward = qwen2_attn_forward
@@ -47,9 +47,9 @@ from transformers import PretrainedConfig
 def apply_monkey_patch(config: PretrainedConfig, verbose=True):
-    if not is_transformers_version_in_range("4.45.0", "4.49.0"):
+    if not is_transformers_version_in_range("4.45.0", "4.50.0"):
         raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
-                             "Please install a version between 4.45.0 and 4.49.0 to use this ulysses feature.")
+                             "Please install a version between 4.45.0 and 4.50.0 to use this ulysses feature.")
     success_apply_monkey_patch = False
     if config.model_type in _PATCH_NAME_TO_FUNC:
         _PATCH_NAME_TO_FUNC[config.model_type]()
...
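A hedged usage sketch of the gate above: apply_monkey_patch takes the model's PretrainedConfig and swaps the attention forward at class level before the model is instantiated, so every layer picks up the Ulysses-aware forward. The import path and model name below are illustrative assumptions, not taken from this diff.

from transformers import AutoConfig, AutoModelForCausalLM

# Import path assumed for illustration; the diff only shows this module's contents.
from verl.models.transformers.monkey_patch import apply_monkey_patch

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
# Raises AssertionError unless 4.45.0 <= transformers <= 4.50.0 after this commit.
apply_monkey_patch(config)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct", config=config)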
@@ -157,7 +157,7 @@ def qwen2_attn_forward(
     """
     Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
-    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
+    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.
     """
     from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
     bsz, q_len, _ = hidden_states.shape
...
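Why the patch point moved at 4.48.0: the per-backend attention subclasses (LlamaFlashAttention2, Qwen2FlashAttention2) were folded into a single attention class whose forward resolves the backend through ALL_ATTENTION_FUNCTIONS, which is exactly what the patched forwards above import. A simplified sketch of that dispatch pattern, abridged for illustration rather than copied from upstream:

# Simplified sketch of the transformers >= 4.48 attention backend dispatch
# that the patched forwards mirror; argument handling is abridged.
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import eager_attention_forward

def resolve_attention_interface(config):
    # Default to the pure-PyTorch implementation.
    attention_interface = eager_attention_forward
    if config._attn_implementation != "eager":
        # e.g. "sdpa" or "flash_attention_2", looked up in the registry.
        attention_interface = ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
    return attention_interface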