Unverified commit 5a94e14d authored by Chi Zhang, committed by GitHub

[perf] fix: set use_reentrant=False when enable gradient checkpointing (#114)

- Set use_reentrant=False to avoid a duplicate allgather in the backward pass when
  gradient checkpointing is enabled.
- Optimize the temperature computation by using an in-place op.
- Fix the testing logic.
parent e8eb9e4e
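
The first bullet's use_reentrant flag ultimately flows down to torch.utils.checkpoint. As a rough, standalone illustration (toy layer and input, not verl code): the reentrant implementation replays the checkpointed segment through a custom autograd Function in the backward pass, which under FSDP re-triggers the parameter allgather (the duplicate collective the commit message refers to), while the non-reentrant mode records recomputation via saved-tensor hooks and avoids it.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy module and input for illustration only (not verl code).
layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# use_reentrant=False selects the non-reentrant (saved-tensor-hook) mode that
# this commit passes down through gradient_checkpointing_enable().
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```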
@@ -611,7 +611,8 @@ class RayPPOTrainer(object):
                 metrics.update(actor_output_metrics)

                 # validate
-                if self.val_reward_fn is not None and self.global_steps % self.config.trainer.test_freq == 0:
+                if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
+                        self.global_steps % self.config.trainer.test_freq == 0:
                     with _timer('testing', timing_raw):
                         val_metrics: dict = self._validate()
                         metrics.update(val_metrics)
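
The added test_freq > 0 guard lets a non-positive trainer.test_freq disable periodic validation instead of failing on the modulo. A minimal restatement of the new condition (the helper name is hypothetical, not from the repo):

```python
def should_validate(has_val_fn: bool, test_freq: int, global_steps: int) -> bool:
    # test_freq <= 0 now means "never run periodic validation"; previously,
    # `global_steps % 0` would raise ZeroDivisionError.
    return has_val_fn and test_freq > 0 and global_steps % test_freq == 0
```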
@@ -92,7 +92,7 @@ class DataParallelPPOActor(BasePPOActor):
                                        use_cache=False)  # prevent model thinks we are generating
                 logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
-                logits_rmpad /= temperature
+                logits_rmpad.div_(temperature)

                 # compute entropy
                 entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
@@ -127,7 +127,8 @@ class DataParallelPPOActor(BasePPOActor):
                                        attention_mask=attention_mask,
                                        position_ids=position_ids,
                                        use_cache=False)  # prevent model thinks we are generating
-            logits = output.logits / temperature
+            logits = output.logits
+            logits.div_(temperature)
             logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)
             log_probs = logprobs_from_logits(logits, micro_batch['responses'])
             entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
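
Both actor-side changes replace the out-of-place division with Tensor.div_, which rescales the logits buffer in place instead of allocating a second tensor of the same vocab-sized shape. A standalone sketch with made-up shapes:

```python
import torch

temperature = 0.7
logits = torch.randn(2, 16, 1024)  # hypothetical (bsz, seqlen, vocab_size) logits

scaled = logits / temperature  # out-of-place: allocates a second tensor of the same size
logits.div_(temperature)       # in-place: reuses the existing buffer
```

Dividing by a scalar is safe to do in place here because its backward only needs the scalar, not the original logits values.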
@@ -164,7 +164,7 @@ class ActorRolloutRefWorker(Worker):
         actor_module.to(torch_dtype)

         if enable_gradient_checkpointing:
-            actor_module.gradient_checkpointing_enable()
+            actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
         torch.distributed.barrier()

         if self.rank == 0:
@@ -212,7 +212,8 @@ class ActorRolloutRefWorker(Worker):
                                  sharding_strategy=sharding_strategy,  # zero3
                                  mixed_precision=mixed_precision,
                                  sync_module_states=True,
-                                 device_mesh=self.device_mesh)
+                                 device_mesh=self.device_mesh,
+                                 forward_prefetch=False)

         log_gpu_memory_usage('After Actor FSDP init', logger=logger)
@@ -575,7 +576,7 @@ class CriticWorker(Worker):
         critic_module.to(torch_dtype)

         if config.model.get('enable_gradient_checkpointing', False):
-            critic_module.gradient_checkpointing_enable()
+            critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

         if self.rank == 0:
             print_model_size(critic_module)
@@ -603,7 +604,8 @@ class CriticWorker(Worker):
                                    device_id=torch.cuda.current_device(),
                                    sharding_strategy=ShardingStrategy.FULL_SHARD,
                                    mixed_precision=mixed_precision,
-                                   sync_module_states=True)
+                                   sync_module_states=True,
+                                   forward_prefetch=False)

         log_gpu_memory_usage('After critic FSDP', logger=None)
@@ -806,7 +808,8 @@ class RewardModelWorker(Worker):
                                   device_id=torch.cuda.current_device(),
                                   sharding_strategy=ShardingStrategy.FULL_SHARD,  # zero3
                                   sync_module_states=True,
-                                  cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload))
+                                  cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload),
+                                  forward_prefetch=False)
         return reward_module