Unverified commit bbda2e57 authored by Chi Zhang, committed by GitHub

[ppo] fix: fix minibatch size when n > 1 for megatron worker (#370)

parent 656accb0
......@@ -333,7 +333,7 @@ class RayWorkerGroup(WorkerGroup):
        return ray.get(self.execute_all_async(method_name, *args, **kwargs))

    def execute_all_async(self, method_name: str, *args, **kwargs):
        # Here, we assume that if all arguments in args and kwargs are lists, and their lengths match len(self._workers),
        # we'll distribute each element in these lists to the corresponding worker
        # print(f"execute_all_async: method {method_name}({args}, {kwargs})")
        length = len(self._workers)
......
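For context, execute_all_async calls the named method on every worker in the group; when every positional and keyword argument is a list whose length equals the number of workers, worker i receives the i-th element of each list instead of the whole argument. Below is a minimal sketch of that dispatch rule, not the verl implementation: the plain `workers` list and direct `getattr` call stand in for Ray actor handles and remote execution.

    def execute_all_async_sketch(workers, method_name, *args, **kwargs):
        # Sketch only: distribute list arguments element-wise when their
        # lengths match the number of workers, otherwise broadcast as-is.
        length = len(workers)
        sliceable = (
            all(isinstance(a, list) and len(a) == length for a in args)
            and all(isinstance(v, list) and len(v) == length for v in kwargs.values())
        )
        results = []
        for i, worker in enumerate(workers):
            if sliceable:
                sliced_args = tuple(a[i] for a in args)
                sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
                results.append(getattr(worker, method_name)(*sliced_args, **sliced_kwargs))
            else:
                results.append(getattr(worker, method_name)(*args, **kwargs))
        return results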
......@@ -111,6 +111,7 @@ class ActorRolloutRefWorker(MegatronWorker):
        # normalize config
        if self._is_actor and self._is_rollout:
            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
            self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
            if self.config.actor.get('ppo_micro_batch_size', None):
                self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
......
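The added line accounts for the rollout sampling n responses per prompt: ppo_mini_batch_size is configured in prompts, so it is first multiplied by rollout.n to count generated sequences and only then divided by the Megatron data-parallel world size to get the per-rank mini-batch. A worked example with assumed numbers (256 prompts, n = 4, and 8 data-parallel ranks are illustrative values, not taken from the commit):

    # Sketch of the normalization above with hypothetical config values.
    ppo_mini_batch_size = 256              # prompts per PPO mini-batch (global)
    rollout_n = 4                          # responses sampled per prompt (rollout.n)
    dp_world_size = 8                      # mpu.get_data_parallel_world_size()

    ppo_mini_batch_size *= rollout_n       # 256 prompts -> 1024 sequences
    ppo_mini_batch_size //= dp_world_size  # 1024 // 8 = 128 sequences per DP rank
    print(ppo_mini_batch_size)             # 128

Without the multiplication by rollout.n, the mini-batch would be sized as if each prompt produced a single sequence, so for n > 1 each PPO update would see only 1/n of the intended batch.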