Unverified commit 5a94e14d by Chi Zhang, committed by GitHub

[perf] fix: set use_reentrant=False when enabling gradient checkpointing (#114)

- Set use_reentrant=False to avoid a duplicate all-gather in the backward pass when gradient checkpointing is enabled.
- Optimize the temperature scaling of the logits by using an in-place op.
- Fix the test-frequency logic so validation is skipped (rather than crashing) when trainer.test_freq is 0.
parent e8eb9e4e
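Note: the core change is the non-reentrant checkpointing mode passed to Hugging Face's gradient_checkpointing_enable. A minimal sketch of that call, assuming a transformers release that accepts gradient_checkpointing_kwargs (roughly >= 4.35) and using 'gpt2' purely as a placeholder model name:

# Sketch, not the repo's code: enable non-reentrant gradient checkpointing on a HF model.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('gpt2')  # placeholder model name
# Per the commit message, reentrant checkpointing can trigger a duplicate
# parameter all-gather in backward under FSDP; use_reentrant=False avoids it.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})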
@@ -611,7 +611,8 @@ class RayPPOTrainer(object):
                 metrics.update(actor_output_metrics)

                 # validate
-                if self.val_reward_fn is not None and self.global_steps % self.config.trainer.test_freq == 0:
+                if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
+                        self.global_steps % self.config.trainer.test_freq == 0:
                     with _timer('testing', timing_raw):
                         val_metrics: dict = self._validate()
                     metrics.update(val_metrics)
...
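Reviewer note on the trainer hunk above: with test_freq set to 0 (validation disabled), the old condition evaluated global_steps % 0 and raised ZeroDivisionError. The added test_freq > 0 check short-circuits before the modulo. A small self-contained illustration (function and argument names are placeholders, not the trainer's API):

def should_validate(val_reward_fn, test_freq, global_steps):
    # Short-circuit: the modulo is only evaluated when test_freq > 0.
    return val_reward_fn is not None and test_freq > 0 and global_steps % test_freq == 0

assert should_validate(object(), test_freq=0, global_steps=10) is False   # no ZeroDivisionError
assert should_validate(object(), test_freq=5, global_steps=10) is True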
@@ -92,7 +92,7 @@ class DataParallelPPOActor(BasePPOActor):
                                            use_cache=False)  # prevent model thinks we are generating
                 logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
-                logits_rmpad /= temperature
+                logits_rmpad.div_(temperature)

                 # compute entropy
                 entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
@@ -127,7 +127,8 @@ class DataParallelPPOActor(BasePPOActor):
                                        attention_mask=attention_mask,
                                        position_ids=position_ids,
                                        use_cache=False)  # prevent model thinks we are generating
-                logits = output.logits / temperature
+                logits = output.logits
+                logits.div_(temperature)
                 logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)
                 log_probs = logprobs_from_logits(logits, micro_batch['responses'])
                 entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
...
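Note on the two actor hunks above: div_ scales the logits in place instead of materializing a second (total_nnz, vocab_size) tensor, which is the dominant activation at this point. A toy illustration of the difference (shapes and temperature are arbitrary stand-ins):

import torch

temperature = 0.7
logits = torch.randn(4096, 32000)    # stand-in for output.logits

scaled_copy = logits / temperature   # out-of-place: allocates a second large tensor
logits.div_(temperature)             # in-place: reuses the existing storage
assert torch.allclose(scaled_copy, logits)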
@@ -164,7 +164,7 @@ class ActorRolloutRefWorker(Worker):
             actor_module.to(torch_dtype)

             if enable_gradient_checkpointing:
-                actor_module.gradient_checkpointing_enable()
+                actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
         torch.distributed.barrier()

         if self.rank == 0:
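For context on the hunk above: the reentrant/non-reentrant distinction lives in torch.utils.checkpoint, which gradient_checkpointing_enable wires up internally. A minimal sketch of both modes with a toy layer (standard PyTorch API, nothing repo-specific):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# Legacy reentrant mode: the forward is re-run inside a custom autograd Function
# during backward.
y_reentrant = checkpoint(layer, x, use_reentrant=True)

# Non-reentrant mode (what this commit selects): recomputation is driven by
# saved-tensor hooks and, per the commit message, avoids the duplicate
# parameter all-gather under FSDP.
y_non_reentrant = checkpoint(layer, x, use_reentrant=False)

(y_reentrant.sum() + y_non_reentrant.sum()).backward()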
@@ -212,7 +212,8 @@ class ActorRolloutRefWorker(Worker):
                                  sharding_strategy=sharding_strategy,  # zero3
                                  mixed_precision=mixed_precision,
                                  sync_module_states=True,
-                                 device_mesh=self.device_mesh)
+                                 device_mesh=self.device_mesh,
+                                 forward_prefetch=False)

         log_gpu_memory_usage('After Actor FSDP init', logger=logger)
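Side note on forward_prefetch: the commit passes it explicitly in each FSDP constructor; False also matches PyTorch's default, so this mainly documents the choice not to prefetch the next all-gather during forward. A hedged sketch of the wrapping call, assuming torch.distributed is already initialized (e.g. via torchrun) and CUDA is available; the wrapper function is a placeholder, not the repo's signature:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def wrap_fsdp(module: nn.Module) -> FSDP:
    return FSDP(module,
                device_id=torch.cuda.current_device(),
                sharding_strategy=ShardingStrategy.FULL_SHARD,  # zero3-style full sharding
                sync_module_states=True,
                forward_prefetch=False)  # do not prefetch the next layer's all-gather in forward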
@@ -575,7 +576,7 @@ class CriticWorker(Worker):
         critic_module.to(torch_dtype)

         if config.model.get('enable_gradient_checkpointing', False):
-            critic_module.gradient_checkpointing_enable()
+            critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

         if self.rank == 0:
             print_model_size(critic_module)
@@ -603,7 +604,8 @@ class CriticWorker(Worker):
                              device_id=torch.cuda.current_device(),
                              sharding_strategy=ShardingStrategy.FULL_SHARD,
                              mixed_precision=mixed_precision,
-                             sync_module_states=True)
+                             sync_module_states=True,
+                             forward_prefetch=False)

         log_gpu_memory_usage('After critic FSDP', logger=None)
@@ -806,7 +808,8 @@ class RewardModelWorker(Worker):
                                   device_id=torch.cuda.current_device(),
                                   sharding_strategy=ShardingStrategy.FULL_SHARD,  # zero3
                                   sync_module_states=True,
-                                  cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload))
+                                  cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload),
+                                  forward_prefetch=False)

         return reward_module
...