Unverified Commit fb3793ab by Chujie Zheng Committed by GitHub

validate `use_remove_padding` when applying sequence parallelism (#153)

parent 679798cd
......@@ -86,4 +86,16 @@ def validate_config(config):
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy == 'fsdp':
if config.actor_rollout_ref.actor.ulysses_sequence_parallel_size > 1 or \
config.actor_rollout_ref.ref.ulysses_sequence_parallel_size > 1:
assert config.actor_rollout_ref.model.use_remove_padding, \
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
if config.critic.strategy == 'fsdp':
if config.critic.ulysses_sequence_parallel_size > 1:
assert config.critic.model.use_remove_padding, \
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
print("[validate_config] All configuration checks passed successfully!")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment