Unverified commit cb943be5 by Guangming Sheng, committed by GitHub

[misc] fix: validation batch repeat before feed into rollout (#614)

parent 6133ae92
@@ -505,6 +505,10 @@ class RayPPOTrainer(object):
         for test_data in self.val_dataloader:
             test_batch = DataProto.from_single_dict(test_data)
+            # repeat test batch
+            test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n,
+                                           interleave=True)
             # we only do validation on rule-based rm
             if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
                 return {}
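For context, here is a minimal, hypothetical sketch of the interleaved-repeat semantics the trainer relies on above. It is not verl's DataProto.repeat implementation; it only shows, on a plain tensor, what repeating each validation prompt n times in a row looks like, so that consecutive rollout outputs map back to the same original prompt (repeat_interleaved is an illustrative helper):

import torch

def repeat_interleaved(prompts: torch.Tensor, repeat_times: int) -> torch.Tensor:
    # [p0, p1, p2] -> [p0, p0, p1, p1, p2, p2] for repeat_times=2 (i.e. interleave=True)
    return prompts.repeat_interleave(repeat_times, dim=0)

prompts = torch.tensor([[101], [102], [103]])  # toy prompt ids, one row per prompt
print(repeat_interleaved(prompts, repeat_times=2))
# -> rows [101], [101], [102], [102], [103], [103]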
@@ -537,6 +541,7 @@ class RayPPOTrainer(object):
             # pad to be divisible by dp_size
             test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
             test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
             # unpad
             test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
             print('validation generation end')
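The pad/unpad pair above keeps the generation batch divisible by the rollout workers' data-parallel world size. Below is a hypothetical sketch of that contract; pad_to_divisor and unpad are illustrative helpers, not verl's pad_dataproto_to_divisor / unpad_dataproto, and recycling the leading rows as padding is an assumption:

import torch

def pad_to_divisor(batch: torch.Tensor, divisor: int):
    pad_size = (-batch.shape[0]) % divisor  # rows needed to reach a multiple of divisor
    if pad_size:
        batch = torch.cat([batch, batch[:pad_size]], dim=0)  # assumed: reuse leading rows as padding
    return batch, pad_size

def unpad(batch: torch.Tensor, pad_size: int) -> torch.Tensor:
    return batch[: batch.shape[0] - pad_size] if pad_size else batch

x = torch.arange(10).unsqueeze(1)                # a batch of 10 rows
padded, pad_size = pad_to_divisor(x, divisor=4)  # padded to 12 rows, pad_size == 2
assert padded.shape[0] == 12 and pad_size == 2
assert torch.equal(unpad(padded, pad_size), x)   # padding is removed after generation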
@@ -191,7 +191,7 @@ class vLLMRollout(BaseRollout):
                 'top_k': self.config.val_kwargs.top_k,
                 'top_p': self.config.val_kwargs.top_p,
                 'temperature': self.config.val_kwargs.temperature,
-                'n': self.config.val_kwargs.n,
+                'n': 1,  # if validate, already repeat in ray_trainer
             }
         # users can customize different sampling_params at different run
@@ -204,7 +204,7 @@ class vLLMRollout(BaseRollout):
                 'top_k': self.config.val_kwargs.top_k,
                 'top_p': self.config.val_kwargs.top_p,
                 'temperature': self.config.val_kwargs.temperature,
-                'n': self.config.val_kwargs.n,
+                'n': 1,  # if validate, already repeat in ray_trainer
             }
         # users can customize different sampling_params at different run
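Setting 'n' to 1 here matters because the trainer now repeats each validation prompt val_kwargs.n times before calling generate_sequences. A small back-of-the-envelope check of the sample counts (plain Python; the numbers are illustrative):

val_n = 4          # rollout.val_kwargs.n
num_prompts = 8    # validation prompts in the batch

# before this commit: the rollout asked vLLM for val_n completions per prompt
samples_old = num_prompts * val_n        # 32

# after this commit: the trainer repeats the batch, the rollout samples once per row
repeated_rows = num_prompts * val_n      # 32 rows reach the rollout
samples_new = repeated_rows * 1          # still 32 completions

# combining both (repeat in the trainer AND n=val_n in the rollout) would over-sample
samples_buggy = repeated_rows * val_n    # 128, i.e. val_n times too many
print(samples_old, samples_new, samples_buggy)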