Unverified commit 3fc3e2b7 by Joel, committed by GitHub

fix: remove redundant torch.cuda.empty_cache() (#575)

#556 made an effort to remove unnecessary `torch.cuda.empty_cache()` calls, but it
causes a CUDA OOM at vllm wake_up:
```text
  File "/opt/tiger/ray/session_2025-03-13_12-11-30_408315_2895/runtime_resources/working_dir_files/_ray_pkg_a64b690733067c5c/verl/workers/fsdp_workers.py", line 481, in generate_sequences
    with self.rollout_sharding_manager:
  File "/opt/tiger/ray/session_2025-03-13_12-11-30_408315_2895/runtime_resources/working_dir_files/_ray_pkg_a64b690733067c5c/verl/workers/sharding_manager/fsdp_vllm.py", line 82, in __enter__
    self.inference_engine.wake_up()
  File "/usr/local/lib/python3.11/dist-packages/vllm/entrypoints/llm.py", line 1244, in wake_up
    self.llm_engine.wake_up()
  File "/usr/local/lib/python3.11/dist-packages/vllm/engine/llm_engine.py", line 1859, in wake_up
    self.model_executor.wake_up()
  File "/usr/local/lib/python3.11/dist-packages/vllm/executor/executor_base.py", line 216, in wake_up
    self.collective_rpc("wake_up")
  File "/usr/local/lib/python3.11/dist-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/utils.py", line 2196, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/vllm/worker/worker.py", line 140, in wake_up
    allocator.wake_up()
  File "/usr/local/lib/python3.11/dist-packages/vllm/device_allocator/cumem.py", line 207, in wake_up
    create_and_map(handle)
  File "/usr/local/lib/python3.11/dist-packages/vllm/device_allocator/cumem.py", line 75, in create_and_map
    python_create_and_map(*allocation_handle)
RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
```
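
The underlying failure is that PyTorch's caching allocator is still holding freed blocks as reserved device memory when vllm's `CuMemAllocator` tries to map the sleeping weights and KV cache back onto the device. A rough diagnostic sketch (standard `torch.cuda` APIs only; the helper name is made up) that can be logged right before `wake_up()` to confirm this:

```python
import torch

def report_cuda_memory(tag: str) -> None:
    # Hypothetical debugging helper: free/total come from the CUDA driver, while
    # reserved counts blocks PyTorch's caching allocator is holding for reuse.
    # If reserved is far above allocated right before wake_up, that memory can be
    # returned via torch.cuda.empty_cache() so CuMemAllocator can map vllm's
    # weights and KV cache again.
    free, total = torch.cuda.mem_get_info()
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    print(f"[{tag}] free={free / 2**30:.1f} GiB, total={total / 2**30:.1f} GiB, "
          f"allocated={allocated / 2**30:.1f} GiB, reserved={reserved / 2**30:.1f} GiB")
```
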
This PR removes all redundant `torch.cuda.empty_cache()` calls in the FSDP workers
and only empties the cache before vllm wake_up and after vllm sleep, since
vllm has its own caching memory allocator,
[CuMemAllocator](https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103).
Outside the vllm scope, we should avoid emptying the cache so that PyTorch can keep
using its caching allocator to speed up memory allocations.
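
As a rough illustration of the intended pattern, here is a minimal sketch (hypothetical class and names, not the actual `FSDPVLLMShardingManager` code; it assumes a vllm engine created with `enable_sleep_mode=True`, so `wake_up()` and `sleep()` are available):

```python
import torch

class RolloutShardingScope:
    """Sketch: empty PyTorch's CUDA cache only at the vllm boundary."""

    def __init__(self, inference_engine):
        self.inference_engine = inference_engine

    def __enter__(self):
        # Before wake_up: release PyTorch's cached blocks so CuMemAllocator
        # can map vllm's weights and KV cache back onto the device.
        torch.cuda.empty_cache()
        self.inference_engine.wake_up()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # After sleep: drop whatever cache was built up during rollout.
        self.inference_engine.sleep(level=1)
        torch.cuda.empty_cache()
        # Everywhere else, leave the cache alone so PyTorch can reuse blocks
        # instead of paying for cudaMalloc/cudaFree on every allocation.
```
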

- [x] Cleanup FSDP worker torch.cuda.empty_cache()
- [ ] Cleanup Megatron worker torch.cuda.empty_cache()
parent 9bb02d27
```diff
@@ -247,8 +247,6 @@ class PRIMERewardModelWorker(Worker):
             lr_scheduler=self.reward_lr_scheduler,
             tokenizer=self.tokenizer)
-        torch.cuda.empty_cache()
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_rm_score(self, data: DataProto):
         data = data.to('cuda')
@@ -321,7 +319,6 @@ class PRIMERewardModelWorker(Worker):
             offload_fsdp_model_to_cpu(self.ref_module)
         if self._is_offload_optimizer:
             offload_fsdp_optimizer(optimizer=self.reward_optimizer)
-        torch.cuda.empty_cache()
         output = output.to('cpu')
         return output
```
```diff
@@ -411,8 +411,6 @@ class ActorRolloutRefWorker(Worker):
             lr_scheduler=self.actor_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer)
-        torch.cuda.empty_cache()
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_actor(self, data: DataProto):
         # Support all hardwares
@@ -498,7 +496,6 @@ class ActorRolloutRefWorker(Worker):
         output = output.to('cpu')
         # clear kv cache
-        torch.cuda.empty_cache()
         log_gpu_memory_usage('After recompute log prob', logger=logger)
         return output
@@ -561,7 +558,6 @@ class ActorRolloutRefWorker(Worker):
         if self.world_size > 1:
             self.ref_policy.actor_module._handle.reshard(True)
-        torch.cuda.empty_cache()
         return output
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
@@ -784,8 +780,6 @@ class CriticWorker(Worker):
             lr_scheduler=self.critic_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer)
-        torch.cuda.empty_cache()
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_values(self, data: DataProto):
@@ -981,7 +975,6 @@ class RewardModelWorker(Worker):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get('external_lib', None))
         self.reward_module = self._build_model(config=self.config)
-        torch.cuda.empty_cache()
 
     def _forward_micro_batch(self, micro_batch):
         from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange
@@ -1154,5 +1147,4 @@ class RewardModelWorker(Worker):
             self.reward_module._handle.reshard(True)
         output = output.to('cpu')
-        torch.cuda.empty_cache()
         return output
```
```diff
@@ -70,6 +70,15 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         self.gen_random_states = None
 
     def __enter__(self):
+        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
+        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
+        # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory
+        # to speed up memory allocations.
+        #
+        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
+        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
+        torch.cuda.empty_cache()
+
         log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
         params = self.module.state_dict()
         log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
@@ -89,7 +98,6 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)
         del params
-        torch.cuda.empty_cache()
         log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)
 
         # TODO: offload FSDP model weights
```