Unverified Commit 6d7d3707 by Mingjie LIU Committed by GitHub

[feat] support mfu calculation for megatron_workers (#475)

Calculate MFU (model FLOPs utilization) in the actor and critic update steps when using Megatron workers.
parent b0e7a942
......@@ -33,10 +33,13 @@ from verl import DataProto
from verl.utils.fs import copy_to_local
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.model import load_megatron_model_weights
from verl.utils.flops_counter import FlopsCounter
from verl.utils.megatron_utils import init_model_parallel_config
from verl.utils.megatron_utils import offload_megatron_param_and_grad, load_megatron_param_and_grad
from verl.utils import hf_tokenizer
from codetiming import Timer
from megatron.core import parallel_state as mpu
from megatron.core import ModelParallelConfig
......@@ -333,6 +336,9 @@ class ActorRolloutRefWorker(MegatronWorker):
actor_optimizer=None,
actor_optimizer_config=None)
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
......@@ -344,7 +350,12 @@ class ActorRolloutRefWorker(MegatronWorker):
log_gpu_memory_usage('Before update policy', logger=logger)
dataloader = self.actor.make_minibatch_iterator(data=data)
metrics = self.actor.update_policy(dataloader=dataloader)
with Timer(name='update_policy', logger=None) as timer:
metrics = self.actor.update_policy(dataloader=dataloader)
delta_time = timer.last
global_num_tokens = data.meta_info['global_token_num']
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics['mfu/actor'] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
log_gpu_memory_usage('After update policy', logger=logger)
......@@ -577,6 +588,7 @@ class CriticWorker(MegatronWorker):
critic_module=critic_module,
critic_optimizer=critic_optimizer,
critic_optimizer_config=critic_optimizer_config)
self.flops_counter = FlopsCounter(critic_model_config)
@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
......@@ -590,7 +602,12 @@ class CriticWorker(MegatronWorker):
def update_critic(self, data: DataProto):
data = data.to('cuda')
dataloader = self.critic.make_minibatch_iterator(data)
metrics = self.critic.update_critic(dataloader=dataloader)
with Timer(name='update_critic', logger=None) as timer:
metrics = self.critic.update_critic(dataloader=dataloader)
delta_time = timer.last
global_num_tokens = data.meta_info['global_token_num']
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
output = DataProto(batch=None, meta_info={'metrics': metrics})
output = output.to('cpu')
return output
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment