Unverified Commit 54574690 by Hongpeng Guo Committed by GitHub

[Config] Providing an option to turn off `torch.compile` in actor (#554)

## Summary

Adds a config option to disable the `torch.compile` call used in `dp_actor.py`.

## Usage

Add the following override to the driver or CLI scripts to turn off
`torch.compile`:
```bash
+actor_rollout_ref.actor.use_torch_compile=False
```
Otherwise, `torch.compile` is used by default.
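
For reference, the flag gates the compile call in `dp_actor.py` roughly as in the sketch below. This is a minimal, self-contained illustration of the pattern, not the actual actor code: the `EntropyHelper` class and the stand-in `entropy_from_logits` here are hypothetical, while the real actor applies the same conditional to `verl_F.entropy_from_logits` inside `DataParallelPPOActor` (see the diff).

```python
import torch


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Stand-in for verl_F.entropy_from_logits: per-token entropy of a categorical distribution."""
    log_probs = torch.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)


class EntropyHelper:
    def __init__(self, config: dict):
        # Compile the entropy kernel only when the flag is on; fall back to eager mode otherwise.
        self.compute_entropy_from_logits = (
            torch.compile(entropy_from_logits, dynamic=True)
            if config.get('use_torch_compile', True)  # compiled by default
            else entropy_from_logits
        )


# With use_torch_compile=False the plain eager function is used directly.
helper = EntropyHelper({'use_torch_compile': False})
print(helper.compute_entropy_from_logits(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4])
```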

## Related Issue

#354 #245

---------

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
parent d2db7252
@@ -84,6 +84,7 @@ Actor/Rollout/Reference Policy
clip_ratio: 0.2
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
@@ -176,6 +177,8 @@ Actor/Rollout/Reference Policy
- ``actor_rollout_ref.actor.clip_ratio``: PPO clip ratio
- ``actor_rollout_ref.actor.use_torch_compile``: Whether to use torch compile in actor
- ``actor_rollout_ref.actor.entropy_coeff``: The weight of entropy when
calculating PPO loss
@@ -26,6 +26,7 @@ actor_rollout_ref:
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: False
use_torch_compile: True # False to disable torch compile
clip_ratio: 0.2
entropy_coeff: 0.001
ppo_epochs: 1
@@ -33,6 +33,7 @@ actor_rollout_ref:
clip_ratio: 0.2
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
@@ -53,7 +53,10 @@ class DataParallelPPOActor(BasePPOActor):
         self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
         self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
 
-        self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+        self.compute_entropy_from_logits = (
+            torch.compile(verl_F.entropy_from_logits, dynamic=True)
+            if self.config.get('use_torch_compile', True)  # use torch compile by default
+            else verl_F.entropy_from_logits)
 
     def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
         """