Unverified Commit a339f6ff by G.O.D Committed by GitHub

skip special tokens (#715)

Special tokens should be skipped when decoding here, just as TRL does:
https://github.com/huggingface/trl/blob/fc2b041b58f6fbe766dceaec819bc5a8f9d209da/trl/trainer/grpo_trainer.py#L597


If `skip_special_tokens=False`, the completion

```
<think>...</think><answer>....</answer>
```

will be decoded as something like
```
<think>...</think><answer>....</answer><|im_end|><|endoftext|>
```

which causes a typical `format_reward_func` pattern such as the following to fail to match:

```python
r"^<think>.*?</think>\s*<answer>.*?</answer>$"
```
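The mismatch can be reproduced without a tokenizer. The sketch below is illustrative only: `format_reward_func` is a hypothetical minimal reward that returns 1.0 on a match and 0.0 otherwise, and the two strings stand in for what decoding might return with and without `skip_special_tokens=True`.

```python
import re

# Pattern quoted in the commit message.
FORMAT_PATTERN = r"^<think>.*?</think>\s*<answer>.*?</answer>$"


def format_reward_func(completion: str) -> float:
    """Hypothetical format reward: 1.0 if the completion matches, else 0.0."""
    return 1.0 if re.match(FORMAT_PATTERN, completion) else 0.0


# Stand-in for a completion decoded with skip_special_tokens=True.
clean = "<think>...</think><answer>....</answer>"
# Stand-in for the same completion decoded with special tokens kept.
with_special = clean + "<|im_end|><|endoftext|>"

clean_reward = format_reward_func(clean)
special_reward = format_reward_func(with_special)
print(clean_reward, special_reward)  # prints "1.0 0.0"
```

Because the pattern is anchored with `$`, any trailing special token prevents a match, so an otherwise well-formatted completion would receive no format reward.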
parent c523a314
@@ -43,8 +43,8 @@ class NaiveRewardManager:
 valid_response_ids = response_ids[:valid_response_length]
 # decode
-prompt_str = self.tokenizer.decode(valid_prompt_ids)
-response_str = self.tokenizer.decode(valid_response_ids)
+prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
+response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
 ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
@@ -88,8 +88,8 @@ class NaiveRewardManager:
 valid_response_ids = response_ids[:valid_response_length]
 # decode
-prompt_str = self.tokenizer.decode(valid_prompt_ids)
-response_str = self.tokenizer.decode(valid_response_ids)
+prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
+response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
 ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']