Unverified Commit 8bf9b95a by Guangming Sheng Committed by GitHub

[misc] fix: only return old_log_prob and temp to fix union problem in old_log_prob (#137)

- As titled
parent a63f4a0f
@@ -468,11 +468,11 @@ class ActorRolloutRefWorker(Worker):
         # perform recompute log_prob
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data)
-            old_log_probs = self.actor.compute_log_prob(data=data)
-            data.batch['old_log_probs'] = old_log_probs
-            data = self.ulysses_sharding_manager.postprocess_data(data)
+            output = self.actor.compute_log_prob(data=data)
+            output = DataProto.from_dict(tensors={'old_log_probs': output},
+                                         meta_info={'temperature': self.config.rollout.temperature})
+            output = self.ulysses_sharding_manager.postprocess_data(output)
 
-        output = data.select(batch_keys=['old_log_probs'])
         output = output.to('cpu')
         # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment