Unverified commit 1facb9d2 by Guangming Sheng, committed by GitHub

[misc] feat: support different flash_attn versions with variable num returns (#100)

* add ci

* fix reward model and write more CI scripts

* support different flash_attn version with variable num returns

* update transformers rmpad workflow

* balance workload

* lint

* lint
parent a0e8ed2c
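Every hunk below applies the same fix: stop hard-coding the number of values returned by `flash_attn.bert_padding.unpad_input`. Older releases (such as the 2.5.8 pinned in the CI job below) return four values, while newer releases append an extra trailing value, so four-way tuple unpacking breaks on upgrade. The exact version boundary and the name of the extra return are our reading of the flash_attn release notes, not something stated in this commit. A minimal sketch of the tolerant unpacking pattern, with made-up shapes:

```python
# Minimal sketch of the version-tolerant unpacking adopted throughout this commit.
# Assumption: older flash_attn returns (x_unpad, indices, cu_seqlens, max_seqlen),
# and newer releases append one more value; `*rest` / `*_` absorbs whatever follows.
import torch
from flash_attn.bert_padding import pad_input, unpad_input

input_ids = torch.randint(0, 32000, (2, 8))            # (batch, seqlen), made up
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1, 1]])

# Only the first two returns are needed here; the tail is ignored, so the same
# line works whether unpad_input returns four or five values.
input_ids_rmpad, indices, *rest = unpad_input(input_ids.unsqueeze(-1), attention_mask)
cu_seqlens, max_seqlen_in_batch = rest[0], rest[1]      # still reachable when needed

# Round-trip back to the padded layout: padded positions come back as zeros.
repadded = pad_input(input_ids_rmpad, indices, 2, 8).squeeze(-1)
assert torch.equal(repadded, input_ids * attention_mask)
```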
@@ -18,7 +18,7 @@ on:
jobs:
e2e_digit_completion:
- runs-on: [self-hosted, l20-1]
+ runs-on: [self-hosted, l20-0]
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
@@ -30,10 +30,14 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- - name: Install the current repository and upgrade to latest transformers
+ - name: Install the current repository and upgrade to latest transformers/flash_attn
run: |
pip3 install -e .[test]
pip3 install --upgrade transformers
- - name: Running digit completon e2e training tests on 8 L20 GPUs
+ - name: Running digit completon e2e training tests on 8 L20 GPUs + flash_attn 2.5.8
run: |
pytest -s tests/model/test_transformer.py
+ - name: Running digit completon e2e training tests on 8 L20 GPUs + latest flash_attn
+ run: |
+ pip3 install --upgrade flash_attn --no-build-isolation
+ pytest -s tests/model/test_transformer.py
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForTokenClassification, AutoTokenizer
import torch
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForSequenceClassification
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
@@ -14,7 +13,6 @@ test_configs = [
GemmaConfig(num_hidden_layers=1),
Qwen2Config(num_hidden_layers=1)
]
# test_cases = ['deepseek-ai/deepseek-llm-7b-chat', 'Qwen/Qwen2-7B-Instruct']
def test_hf_casual_models():
@@ -37,8 +35,8 @@ def test_hf_casual_models():
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
- input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+ attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
@@ -53,7 +51,7 @@ def test_hf_casual_models():
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits
- origin_logits_rmpad, origin_logits_indices, _, _ = unpad_input(origin_logits, attention_mask)
+ origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)
logits_rmpad = logits_rmpad.squeeze(0)
log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
@@ -98,8 +96,8 @@ def test_hf_value_models():
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
- input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+ attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
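Several of the changed call sites are followed by the comment "unpad the position_ids to align the rotary". As context, a short sketch of what that unpadding looks like; the shapes and the way position_ids are built here are illustrative, and index_first_axis / rearrange are the helpers already imported from flash_attn.bert_padding in the test above:

```python
# Sketch of the position_ids unpadding referred to in the comments above: the
# packed forward still needs each token's own position so rotary embeddings see
# per-sample offsets rather than one running index over the whole packed batch.
import torch
from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input

input_ids = torch.randint(0, 1000, (2, 8))
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1],   # left-padded sample (made up)
                               [1, 1, 1, 1, 1, 1, 1, 1]])
position_ids = torch.clamp(attention_mask.cumsum(-1) - 1, min=0)  # one way to build them

input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1)                 # (1, total_nnz)

# Reuse the same indices so the packed position_ids line up with the packed tokens.
position_ids_rmpad = index_first_axis(
    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)
```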
@@ -323,8 +323,8 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
batch_size, sequence_length = input_ids.shape
# remove padding here
- input_ids, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1),
- attention_mask) # (total_nnz, 1)
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
+ attention_mask) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
@@ -581,8 +581,8 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
# In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
batch_size, sequence_length = input_ids.shape
# remove padding here
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1),
- attention_mask) # (total_nnz, 1)
+ input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
+ attention_mask) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
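In the Megatron rmpad models above, the change is slightly different: the sequence-parallel padding path still consumes cu_seqlens and max_seqlen_in_batch, so those names are kept and only any extra trailing returns from newer flash_attn are absorbed. A sketch of what those two values contain, with made-up lengths:

```python
# Sketch of the second unpacking style above: keep the first four returns and
# discard only the extras that newer flash_attn may append.
import torch
from flash_attn.bert_padding import unpad_input

input_ids = torch.randint(0, 32000, (4, 12))
attention_mask = (torch.arange(12) < torch.tensor([12, 9, 7, 5]).unsqueeze(-1)).int()

input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(
    input_ids.unsqueeze(dim=-1), attention_mask)

# cu_seqlens is the prefix sum of per-sample lengths, [0, 12, 21, 28, 33] here,
# and max_seqlen_in_batch is 12; both feed the varlen flash-attention kernels.
print(cu_seqlens.tolist(), int(max_seqlen_in_batch))
```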
@@ -152,8 +152,7 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas
from flash_attn.bert_padding import pad_input, unpad_input
batch_size, seqlen = input_ids.shape
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(-1),
- attention_mask=attention_mask)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(logits=logits_rmpad,
@@ -313,8 +313,7 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad
from flash_attn.bert_padding import pad_input, unpad_input
batch_size, seqlen = input_ids.shape
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(-1),
- attention_mask=attention_mask)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,)
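The two log-prob helpers above both roll the packed input ids by -1 so that position t is scored against token t+1 (the "response" slice is taken later, outside the visible hunk). A small sketch of that pattern; logprobs_from_logits is sketched as log_softmax + gather, which may differ from the exact helper in verl.utils.torch_functional:

```python
# Sketch of the next-token log-prob pattern above, on the packed (total_nnz,) layout.
import torch

def logprobs_from_logits_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for verl's logprobs_from_logits helper.
    logp = torch.log_softmax(logits.float(), dim=-1)          # (total_nnz, vocab)
    return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (total_nnz,)

total_nnz, vocab = 10, 32
logits_rmpad = torch.randn(total_nnz, vocab)
input_ids_rmpad = torch.randint(0, vocab, (total_nnz,))

# Shift labels left by one: position t is scored against token t+1; the wrapped
# last position is discarded when the response slice is extracted downstream.
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = logprobs_from_logits_sketch(logits_rmpad, input_ids_rmpad_rolled)
```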
@@ -55,8 +55,8 @@ class DataParallelPPOActor(BasePPOActor):
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
- input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+ attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
@@ -181,8 +181,8 @@ class DataParallelPPOActor(BasePPOActor):
if self.use_remove_padding:
full_response_mask = attention_mask.clone()
full_response_mask[:, :-response_length] = 0 # set the prompt part to zero
- full_response_mask_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
- full_response_mask.unsqueeze(-1), attention_mask=attention_mask)
+ full_response_mask_rmpad, *_ = unpad_input(full_response_mask.unsqueeze(-1),
+ attention_mask=attention_mask)
full_response_mask_rmpad = full_response_mask_rmpad.squeeze(-1) # (total_nnz)
entropy_loss = core_algos.compute_entropy_loss(logits, full_response_mask_rmpad) # (total_nnz,)
else:
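The entropy-loss branch above builds a response-only mask before unpadding it: the attention mask is cloned, the prompt portion is zeroed, and the result is flattened to the packed layout so it can weight per-token entropies. A small sketch with made-up shapes (compute_entropy_loss itself is not sketched here):

```python
# Sketch of the response-mask construction above; batch shape and padding are made up.
import torch
from flash_attn.bert_padding import unpad_input

batch, seqlen, response_length = 2, 10, 4
attention_mask = torch.ones(batch, seqlen, dtype=torch.int64)
attention_mask[0, :2] = 0                               # left-padded prompt example

full_response_mask = attention_mask.clone()
full_response_mask[:, :-response_length] = 0            # keep only the response tokens

# *_ again absorbs any extra returns from newer flash_attn versions.
full_response_mask_rmpad, *_ = unpad_input(full_response_mask.unsqueeze(-1),
                                           attention_mask=attention_mask)
full_response_mask_rmpad = full_response_mask_rmpad.squeeze(-1)   # (total_nnz,)
```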
@@ -55,8 +55,8 @@ class DataParallelPPOCritic(BasePPOCritic):
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
- input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+ attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
@@ -738,8 +738,8 @@ class RewardModelWorker(Worker):
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
- input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+ attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary