Unverified Commit 9bb02d27 by CajZella, committed by GitHub

[bugfix] PRIME: filter overlong prompts & fix incorrect padding side & use xformers (#570)

### Description
- fix the `filter_overlong_prompts` setting in PRIME

- fix incorrect padding side for Qwen in PRIME

- When I use the PRIME recipe to train Qwen-series models, I get
“*ValueError: You are attempting to perform batched generation with
padding_side='right' this may lead to unexpected behaviour for Flash
Attention version of Qwen2. Make sure to call tokenizer.padding_side =
'left' before tokenizing the input.*” These forward passes only score
existing tokens and never generate, and transformers only raises this error
when the KV cache is enabled, so I set `use_cache=False` when calling the
model to compute output logits (see the sketch at the end of this
description).

- fix CUDA error with vllm v0.6.3 

- When I run PRIME, I sometimes get *CUDA error: an illegal memory
access was encountered*. Following
https://github.com/vllm-project/vllm/issues/10389, I set
`VLLM_ATTENTION_BACKEND=XFORMERS`.
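
For context, a minimal sketch of the padding-side situation, outside the PRIME code: in transformers, Qwen2's right-padding check for Flash Attention only fires when the KV cache is in use, so a pure scoring pass with `use_cache=False` avoids it. The checkpoint name and batch below are illustrative, not from the recipe.

```python
# Sketch: score a right-padded batch with a Qwen2 model under Flash Attention 2.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2-0.5B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).cuda()

# Batched tokenization pads on the right by default, which is fine for scoring.
batch = tokenizer(["1 + 1 =", "a longer prompt: 2 + 2 ="],
                  padding=True, return_tensors="pt").to("cuda")

with torch.no_grad():
    # With the default use_cache=True this forward raises the ValueError quoted
    # above; with use_cache=False it is a plain forward pass and no cache is built.
    logits = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"],
                   use_cache=False).logits
```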
parent 79e072f1
@@ -97,7 +97,8 @@ class DataParallelPRIMERewardModel:
         else:
             rm_output_logits = self.reward_module(input_ids=micro_batch['input_ids'],
                                                   attention_mask=micro_batch['attention_mask'],
-                                                  position_ids=micro_batch['position_ids']).logits
+                                                  position_ids=micro_batch['position_ids'],
+                                                  use_cache=False).logits
         rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :],
                                                       dim=-1)  # (batch_size, seq_length, vocab_size)
         rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
@@ -124,7 +125,8 @@ class DataParallelPRIMERewardModel:
         else:
             ref_output_logits = self.ref_module(input_ids=micro_batch['input_ids'],
                                                 attention_mask=micro_batch['attention_mask'],
-                                                position_ids=micro_batch['position_ids']).logits
+                                                position_ids=micro_batch['position_ids'],
+                                                use_cache=False).logits
         ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :],
                                                        dim=-1)  # (batch_size, seq_length, vocab_size)
         ref_log_labels = ref_log_prob.gather(dim=-1,
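The two hunks above feed these logits into the same shift-and-gather pattern to obtain per-token log-probabilities. A self-contained sketch of that pattern (tensor names and shapes are illustrative, not from the repo):

```python
# Shift-and-gather: the logit at position t predicts token t+1, so drop the last
# logit and the first token before gathering each token's log-probability.
import torch

batch_size, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(batch_size, seq_len, vocab_size)         # model output
input_ids = torch.randint(vocab_size, (batch_size, seq_len))  # scored tokens

log_prob = torch.log_softmax(logits[:, :-1, :], dim=-1)        # (B, T-1, V)
token_log_prob = log_prob.gather(
    dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # (B, T-1)
```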
@@ -183,7 +183,8 @@ class RayPRIMETrainer(RayPPOTrainer):
                                          max_prompt_length=self.config.data.max_prompt_length,
                                          filter_prompts=True,
                                          return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                         truncation='error')
+                                         truncation='error',
+                                         filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False))
         # use sampler for better ckpt resume
         if self.config.data.shuffle:
             train_dataloader_generator = torch.Generator()
@@ -205,7 +206,8 @@ class RayPRIMETrainer(RayPPOTrainer):
                                        max_prompt_length=self.config.data.max_prompt_length,
                                        filter_prompts=True,
                                        return_raw_chat=self.config.data.get('return_raw_chat', False),
-                                       truncation='error')
+                                       truncation='error',
+                                       filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False))
         self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                          batch_size=len(self.val_dataset),
                                          shuffle=True,
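The `filter_overlong_prompts` flag forwarded above tells the dataset to drop samples whose tokenized prompt exceeds `max_prompt_length`, instead of failing later under `truncation='error'`. A minimal sketch of that behaviour with a hypothetical helper (not the actual verl implementation, which presumably performs an equivalent check inside `RLHFDataset`):

```python
# Hypothetical filter_overlong helper; illustrates the behaviour of the flag only.
from transformers import AutoTokenizer

def filter_overlong(samples, tokenizer, max_prompt_length):
    """Keep only samples whose prompt fits within max_prompt_length tokens."""
    return [s for s in samples
            if len(tokenizer(s["prompt"])["input_ids"]) <= max_prompt_length]

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")  # illustrative
data = [{"prompt": "1 + 1 = ?"}, {"prompt": "word " * 5000}]
print(len(filter_overlong(data, tokenizer, max_prompt_length=1024)))  # -> 1
```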
 set -x
+export VLLM_ATTENTION_BACKEND=XFORMERS
 gsm8k_train_path=$HOME/data/gsm8k/train.parquet
 gsm8k_test_path=$HOME/data/gsm8k/test.parquet
 math_train_path=$HOME/data/math/train.parquet
@@ -17,6 +19,7 @@ python3 -m recipe.prime.main_prime \
     data.val_batch_size=6312 \
     data.max_prompt_length=1024 \
     data.max_response_length=3072 \
+    data.filter_overlong_prompts=True \
     data.filter_accuracy=True \
     data.accuracy_lower_bound=0.2 \
     data.accuracy_upper_bound=0.8 \
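The `export VLLM_ATTENTION_BACKEND=XFORMERS` at the top of the script is the workaround from the vLLM issue linked above. When launching from Python rather than a shell script, the same pin can be set before vLLM spins up its engine; the model name below is illustrative:

```python
# Pin vLLM to the xFormers attention backend; vLLM reads this environment
# variable when selecting its attention backend at engine initialization.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"  # workaround per vllm issue 10389

from vllm import LLM  # import after setting the variable, to be safe

llm = LLM(model="Qwen/Qwen2-0.5B")  # illustrative model
```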