Unverified commit 6872dbef by Guangming Sheng, committed by GitHub

[misc] fix: load and offload in compute log prob (#208)

- Load the FSDP parameters onto the GPU before computing log probs and offload them again afterwards when parameter offloading is enabled
- Relevant: https://github.com/volcengine/verl/issues/181
parent 89ba48e7
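
For context, the change follows a simple load → compute → offload pattern around the log-prob forward pass. Below is a minimal sketch of that pattern using the worker attributes and helper names that appear in the diff; the wrapper function itself and the import path are assumptions, not verl's actual code.

```python
import torch

# Import path assumed; the helper names come from the diff below.
from verl.utils.fsdp_utils import load_fsdp_param_and_grad, offload_fsdp_param_and_grad


def compute_log_prob_with_offload(worker, data):
    """Hypothetical wrapper showing the pattern this commit applies."""
    if worker._is_offload_param:
        # Bring the FSDP parameters back onto the GPU before the forward pass.
        load_fsdp_param_and_grad(module=worker.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=worker._is_offload_grad)
    try:
        return worker.actor.compute_log_prob(data=data.to('cuda'))
    finally:
        if worker._is_offload_param:
            # Send the parameters back to CPU so the rollout engine can reuse the GPU.
            offload_fsdp_param_and_grad(module=worker.actor_module_fsdp,
                                        offload_grad=worker._is_offload_grad)
        torch.cuda.empty_cache()  # release cached blocks back to the CUDA driver
```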
@@ -470,6 +470,10 @@ class ActorRolloutRefWorker(Worker):
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_log_prob(self, data: DataProto):
         assert self._is_actor
+        if self._is_offload_param:
+            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
+                                     device_id=torch.cuda.current_device(),
+                                     load_grad=self._is_offload_grad)
         data = data.to('cuda')
         # we should always recompute old_log_probs when it is HybridEngine
         data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
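
`load_fsdp_param_and_grad` is called but not shown in this diff. Here is a minimal sketch of what such a helper plausibly does, assuming parameters are offloaded to CPU tensor by tensor; this is an illustration, not verl's actual implementation:

```python
import torch


def load_fsdp_param_and_grad(module, device_id, load_grad=False):
    # Move every parameter (and optionally its gradient) from CPU
    # back onto the target CUDA device ahead of the forward pass.
    for _, param in module.named_parameters():
        param.data = param.data.to(device_id, non_blocking=True)
        if load_grad and param.grad is not None:
            param.grad.data = param.grad.data.to(device_id, non_blocking=True)
    torch.cuda.synchronize()  # ensure async host-to-device copies have finished
```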
@@ -491,7 +495,13 @@ class ActorRolloutRefWorker(Worker):
         if self.world_size > 1:
             self.actor.actor_module._handle.reshard(True)
+        if self._is_offload_param:
+            # NOTE(sgm): the grad is already on CPU, so only the param is offloaded here
+            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
+        # clear the kv cache
+        torch.cuda.empty_cache()
+        log_gpu_memory_usage('After compute_log_prob', logger=logger)
         return output
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
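
The offload path is symmetric; again an assumed sketch rather than the library's code. The `torch.cuda.empty_cache()` call in the diff matters because moving tensors to CPU only releases the memory into PyTorch's caching allocator, not back to the device:

```python
import torch


def offload_fsdp_param_and_grad(module, offload_grad=False):
    # Move parameters (and optionally grads) to CPU, freeing GPU memory
    # for the rollout engine between log-prob computations.
    for _, param in module.named_parameters():
        param.data = param.data.cpu()
        if offload_grad and param.grad is not None:
            param.grad.data = param.grad.data.cpu()
    torch.cuda.empty_cache()  # return the freed blocks so other engines can allocate them
```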