Unverified commit 0cfd548c by Yan Bai, committed by GitHub

megatron: Update megatron-lm to `core_r0.11.0` (#392)

# Support Megatron mcore 0.11

## Description
This PR introduces official support for Megatron mcore 0.11 with the
following updates:
- Upgraded Megatron to version `core_r0.11.0`
- Applied compatibility patch `patches/mcore_r0.11.patch`
- Removed legacy version support for a cleaner implementation

Special thanks to @chendong-1998 for:
- Original Megatron upgrade from 0.4 to 0.6 (#93f6a7e)

## Compatibility Notes
The current implementation requires careful handling because of conflicting
dependencies:
- `megatron-core==0.11.0` requires torch>=2.6
- `vllm==0.6.3` requires torch==2.4

Installation constraints:
1. Use vllm's torch dependency (torch 2.4) as the baseline.
2. Do NOT run a plain `pip install -e .` in the mcore directory (it would upgrade torch
to 2.6); install with `--no-deps` instead.
3. Apply the compatibility patch manually after installation.
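
For reference, a minimal install sketch that follows these constraints (mirroring the Dockerfile added in this PR; the patch path and patch target below are assumptions, adjust to your checkout):

```bash
# 1. Baseline: install vllm first so its torch==2.4 pin stays in place
pip3 install vllm==0.6.3

# 2. TransformerEngine and Megatron-LM core_r0.11.0, installed without dependencies
#    (--no-deps keeps pip from upgrading torch to 2.6)
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
git clone --single-branch --branch core_r0.11.0 \
    https://github.com/NVIDIA/Megatron-LM.git /opt/nvidia/Megatron-LM
pip3 install --no-deps -e /opt/nvidia/Megatron-LM

# 3. Apply the compatibility patch manually (assumed to target the Megatron-LM checkout)
cd /opt/nvidia/Megatron-LM && git apply /path/to/verl/patches/mcore_r0.11.patch
```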

## Testing
### Tested with `verl/examples/ppo_trainer/run_deepseek_megatron.sh`

![image](https://github.com/user-attachments/assets/e053c9b8-fdd7-47fc-aaeb-42cf85070056)

---------

Signed-off-by: chendong-1998 <chendong136@huawei.com>
Co-authored-by: chendong-1998 <chendong136@huawei.com>
Co-authored-by: gaoziyuan <gaoziyuan.955@bytedance.com>
Co-authored-by: Sion Gao <gaoziyuan19@mails.ucas.ac.cn>
parent 85768a5c
name: e2e_gsm8k_megatron
+# latest version: Megatron-LM core_r0.11.0 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0
on:
  # Trigger the workflow on push or pull request,
@@ -33,7 +34,7 @@ jobs:
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
-      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
+      image: whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-megatron0.11.0-v0.0.5
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -49,11 +50,10 @@ jobs:
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron (Deepseek)
        run: |
          ray stop --force
-          [ ! -d "$HOME/Megatron-LM" ] && git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM $HOME/Megatron-LM
-          export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
+          export PYTHONPATH=$PYTHONPATH:/opt/nvidia/Megatron-LM
          bash tests/e2e/run_deepseek_megatron.sh
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron (Qwen)
        run: |
          ray stop --force
-          export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
+          export PYTHONPATH=$PYTHONPATH:/opt/nvidia/Megatron-LM
          bash tests/e2e/run_qwen_megatron.sh
\ No newline at end of file
FROM verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
RUN cd /opt/nvidia && git clone --single-branch --branch core_r0.11.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
# only config pip index with https://pypi.tuna.tsinghua.edu.cn/simple if needed
# unset for now
RUN cd /opt/nvidia/Megatron-LM && pip3 install --no-deps -e .
\ No newline at end of file
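
For context, the Dockerfile above produces the CI image referenced in the workflow; a hypothetical build invocation (the Dockerfile path and tag are assumptions, not part of this PR):

```bash
# build the Megatron 0.11 CI image from the Dockerfile above (path/tag assumed)
docker build -f docker/Dockerfile.megatron \
    -t verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-megatron0.11.0 .
```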
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
+from packaging.version import Version
import torch
import time
from typing import Dict, Any, Callable, Optional
@@ -53,7 +55,7 @@ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params
    """
    import megatron
    from megatron.core import mpu
-    from megatron.utils import print_rank_0, unwrap_model
+    from megatron.training.utils import print_rank_0, unwrap_model
    from megatron.core.transformer.module import Float16Module
    from megatron.core import DistributedDataParallel as LocalDDP
    from torch.nn.parallel import DistributedDataParallel as torchDDP
...
@@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import megatron
-from megatron.core import mpu
-from megatron.utils import print_rank_0, unwrap_model
-from megatron.model import Float16Module
-from megatron.model import DistributedDataParallel as LocalDDP
+import importlib
+from packaging.version import Version
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
from typing import Optional
import torch.distributed as dist
+import megatron
from megatron import get_args
+from megatron.core import mpu
+from megatron.core.transformer.module import Float16Module
+from megatron.core.distributed import DistributedDataParallel as LocalDDP
+from megatron.training.utils import print_rank_0, unwrap_model
def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
@@ -77,7 +81,7 @@ def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtyp
    """Merge sharded parameters of a Megatron module into a merged checkpoint.
    Args:
-        wrapped_models (list of megatron.model.DistributedDataParallel):
+        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
            The local DDP wrapped megatron modules.
        dtype (str or None):
            The data type of state_dict. if None, the data type of the original parameters
...
@@ -53,7 +53,7 @@ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params
    """
    import megatron
    from megatron.core import mpu
-    from megatron.utils import print_rank_0, unwrap_model
+    from megatron.training.utils import print_rank_0, unwrap_model
    from megatron.core.transformer.module import Float16Module
    from megatron.core import DistributedDataParallel as LocalDDP
    from torch.nn.parallel import DistributedDataParallel as torchDDP
...
@@ -14,9 +14,9 @@
import megatron
from megatron.core import mpu
-from megatron.utils import print_rank_0, unwrap_model
-from megatron.model import Float16Module
-from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.training.utils import print_rank_0, unwrap_model
+from megatron.core.transformer.module import Float16Module
+from megatron.core.distributed import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
@@ -77,7 +77,7 @@ def merge_megatron_ckpt_llama(wrapped_models, config, is_value_model=False, dtyp
    """Merge sharded parameters of a Megatron module into a merged checkpoint.
    Args:
-        wrapped_modelss (list of megatron.model.DistributedDataParallel):
+        wrapped_modelss (list of megatron.core.distributed.DistributedDataParallel):
            The local DDP wrapped megatron modules.
        dtype (str or None):
            The data type of state_dict. if None, the data type of the original parameters
...
@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
+from packaging.version import Version
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
-from megatron.optimizer.distrib_optimizer import DistributedOptimizer
-from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler
-from megatron.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
-from megatron.optimizer import get_param_groups
-from verl.utils.megatron.optimizer_config import OptimizerConfig
+from megatron.core.optimizer import OptimizerConfig
+from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native
def get_megatron_optimizer(
@@ -33,60 +34,8 @@ def get_megatron_optimizer(
        overlap_param_gather=False  # add for verl
):
    # Base optimizer.
-    param_groups = get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr_mult)
-    if config.optimizer == 'adam':
-        optimizer = Adam(param_groups,
-                         lr=config.lr,
-                         weight_decay=config.weight_decay,
-                         betas=(config.adam_beta1, config.adam_beta2),
-                         eps=config.adam_eps)
-    elif config.optimizer == 'sgd':
-        optimizer = SGD(param_groups, lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum)
-    else:
-        raise Exception('{} optimizer is not supported.'.format(config.optimizer))
-    # Determine whether the params have main-grad field.
-    params_have_main_grad = True
-    # Mixed precision optimizer.
-    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
-    #   from the MixedPrecisionOptimizer, which manages any optimizer where
-    #   the model params and main params are distinct.
-    if config.fp16 or config.bf16 or config.use_distributed_optimizer:
-        # Grad scaler:
-        #    if loss-scale is provided, instantiate the constant scaler.
-        #    if we are using fp16 and loss-scale is not present, use a
-        #    dynamic scaler.
-        #    otherwise we are running in bf16 with no loss-scale so
-        #    leave it as None.
-        grad_scaler = None
-        # Constant loss scale.
-        if config.loss_scale:
-            grad_scaler = ConstantGradScaler(config.loss_scale)
-        # Dynamic loss scale.
-        else:
-            if config.fp16:
-                grad_scaler = DynamicGradScaler(initial_scale=config.initial_loss_scale,
-                                                min_scale=config.min_loss_scale,
-                                                growth_factor=2.0,
-                                                backoff_factor=0.5,
-                                                growth_interval=config.loss_scale_window,
-                                                hysteresis=config.hysteresis)
-        # Megatron optimizer.
-        if config.use_distributed_optimizer:
-            return DistributedOptimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
-                                        check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16, config.bf16,
-                                        config.params_dtype, grad_scaler, model, overlap_param_gather)
-        else:
-            return Float16OptimizerWithFloat16Params(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
-                                                     check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16,
-                                                     config.bf16, config.params_dtype, grad_scaler, model)
-    # FP32.
-    return FP32Optimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad, check_for_nan_in_loss_and_grad,
-                         params_have_main_grad, model)
+    return get_megatron_optimizer_native(config=config,
+                                         model_chunks=model,
+                                         no_weight_decay_cond=no_weight_decay_cond,
+                                         scale_lr_cond=scale_lr_cond,
+                                         lr_mult=lr_mult)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class OptimizerConfig:
"""Configuration for optimizer."""
##############
# General
##############
optimizer: str = 'adam'
"""Optimizer to use (one of Adam or SGD)."""
lr: Optional[float] = None
"""Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
iteration would be different.
"""
min_lr: Optional[float] = None
"""Minumum value for learning rate. The scheduler clip values below this threshold."""
decoupled_lr: Optional[float] = None
"""Separate learning rate for the input and output layer."""
decoupled_min_lr: Optional[float] = None
"""Minimum value for learning rate for the input and output layer. The scheduler clip values
below this threshold.
"""
weight_decay: float = 0.01
"""Weight decay coefficient for L2 regularization."""
##############
# Precision
##############
fp16: bool = False
"""If true, train with fp16 mixed precision training. Defaults to False."""
bf16: bool = False
"""If true, train with bf16 mixed precision training. Defaults to False."""
params_dtype: torch.dtype = torch.float32
"""dtype used when initializing the weights. Defaults to torch.float32."""
###############
# Loss scaling
###############
loss_scale: Optional[float] = None
"""Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
dynamic loss scaling is used.
"""
initial_loss_scale: float = 2**32
"""Initial loss-scale for dynamic loss scaling."""
min_loss_scale: float = 1.0
"""Minimum loss scale for dynamic loss scaling."""
loss_scale_window: float = 1000
"""Window over which to raise/lower dynamic scale."""
hysteresis: int = 2
"""Hysteresis for dynamic loss scaling."""
##############
# Optimizer
##############
# Adam
adam_beta1: float = 0.9
"""First coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_beta2: float = 0.999
"""Second coefficient for computing running averages of gradient and its square in Adam
optimizer.
"""
adam_eps: float = 1e-08
"""Term added to the denominator to improve numerical stability in Adam optimizer."""
# SGD.
sgd_momentum: float = 0.9
"""Momentum factor for SGD optimizer."""
#######################
# Distributed optimizer
#######################
use_distributed_optimizer: bool = False
"""Distribute optimizer state over data-parallel replicas."""
overlap_grad_reduce: bool = False
"""If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute in distributed optimizer."""
################
# Miscellaneous
################
clip_grad: float = 1.0
"""Gradient clipping based on global L2 norm."""
log_num_zeros_in_grad: bool = False
"""If true, calculate and log the number of zeros in gradient."""
barrier_with_L1_time: bool = False
"""If true, use barrier with level 1 time measurements."""
timers: Callable = None
"""Function to get timers."""
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities."""
+import importlib
+from packaging.version import Version
from typing import Any, Dict
import time
from omegaconf import DictConfig
@@ -23,12 +25,18 @@ import torch.nn as nn
import torch.nn.functional as F
from megatron.core import mpu, tensor_parallel
-from megatron.core.utils import get_model_config
+from megatron.core.utils import get_attr_wrapped_model
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import Float16Module
-# from megatron.core.distributed import DistributedDataParallelConfig
+from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
+from megatron.core import ModelParallelConfig
+from megatron.core.optimizer import OptimizerConfig
+
+
+def get_model_config(model):
+    return get_attr_wrapped_model(model, 'megatron_config', allow_none=False)
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
@@ -95,22 +103,28 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
            model_module.cuda(torch.cuda.current_device())
    # Fp16 conversion.
-    config = get_model_config(model[0])
+    config: ModelParallelConfig = get_model_config(model[0])
+    config.fp8 = None
+    tfconfig: TransformerConfig = convert_config(model[0].config, config)
    if config.fp16 or config.bf16:  # the ModelParallelConfig in GPTModel
        model = [Float16Module(config, model_module) for model_module in model]
    if wrap_with_ddp:
-        model = [
-            DDP(config=config,
-                module=model_chunk,
-                data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
-                accumulate_allreduce_grads_in_fp32=True,
-                overlap_grad_reduce=False,
-                use_distributed_optimizer=True,
-                disable_bucketing=(model_chunk_idx > 0)) for (model_chunk_idx, model_chunk) in enumerate(model)
-        ]
+        ddp_models = []
+        for model_chunk_idx, model_chunk in enumerate(model):
+            ddp_model = DDP(
+                config=tfconfig,
+                module=model_chunk,
+                disable_bucketing=(model_chunk_idx > 0),
+                ddp_config=DistributedDataParallelConfig(
+                    overlap_grad_reduce=False,
+                    use_distributed_optimizer=True,
+                    grad_reduce_in_fp32=True,  # [old] accumulate_allreduce_grads_in_fp32=True,
+                ))
+            ddp_models.append(ddp_model)
+        model = ddp_models
        # # Broadcast params from data parallel src rank to other data parallel ranks.
-        # if args.data_parallel_random_init:
+        # # if args.data_parallel_random_init:
        for model_module in model:
            model_module.broadcast_params()
    return model
@@ -139,7 +153,7 @@ from transformers import PretrainedConfig
def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
    print(f'megatron config {megatron_config}')
-    dt = PrecisionType.to_dtype(megatron_config['param_dtype'])
+    dt = PrecisionType.to_dtype(megatron_config.params_dtype)
    print(f'pipeline_dtype=megatron_config {dt}')
    transformer_config = TransformerConfig(
        num_layers=hf_config.num_hidden_layers,
@@ -158,12 +172,13 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC
        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),
        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),
        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),
-        pipeline_dtype=PrecisionType.to_dtype(megatron_config['param_dtype']),
-        params_dtype=PrecisionType.to_dtype(megatron_config['param_dtype']),
-        sequence_parallel=megatron_config['sequence_parallel_enabled'],
+        pipeline_dtype=dt,
+        params_dtype=dt,
+        sequence_parallel=True,
        variable_seq_lengths=True,
        masked_softmax_fusion=True,
-        bf16=PrecisionType.to_dtype(megatron_config['param_dtype']) is torch.bfloat16)
+        moe_token_dispatcher_type="alltoall",
+        bf16=dt is torch.bfloat16)
    if torch.distributed.get_rank() == 0:
        print(f'tensor_parallel_size={transformer_config.tensor_model_parallel_size} \n \
            pipeline_model_parallel_size={transformer_config.pipeline_model_parallel_size} \n \
@@ -177,11 +192,6 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC
    return transformer_config
-# from megatron.core.optimizer import OptimizerConfig
-from verl.utils.megatron.optimizer_config import OptimizerConfig
def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:
    config = OptimizerConfig(
        optimizer='adam',
@@ -195,12 +205,9 @@ def init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:
    return config
-from megatron.core import ModelParallelConfig
def init_model_parallel_config(config: DictConfig) -> ModelParallelConfig:
    # TODO(sgm): check how to disable megatron timers
-    timers = FakeTimers()
+    timers = None
    return ModelParallelConfig(tensor_model_parallel_size=config.get('tensor_model_parallel_size'),
                               pipeline_model_parallel_size=config.get('pipeline_model_parallel_size'),
                               virtual_pipeline_model_parallel_size=config.get('virtual_pipeline_model_parallel_size'),
@@ -212,17 +219,6 @@ def init_model_parallel_config(config: DictConfig) -> ModelParallelConfig:
                               timers=timers)
-class FakeTimers:
-    """Disable All Megatron Timing with FakeTimers"""
-    def __init__(self):
-        from megatron.timers import DummyTimer
-        self.dummy_timer = DummyTimer()
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.dummy_timer
def offload_megatron_param_and_grad(module_list: nn.ModuleList, offload_grad=False, hybrid_engine=None):
    if hybrid_engine is not None:
        pp_rank = mpu.get_pipeline_model_parallel_rank()
...
@@ -15,7 +15,7 @@
This file contains utilities to manipulate torch memory buffers
"""
-from typing import Dict, List
+from typing import Dict, List, Optional
import torch
from torch import nn
@@ -27,11 +27,14 @@ class MemoryBuffer:
    memory. It must have a unique type to support this behavior.
    """
-    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype):
+    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None):
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
-        self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device='cuda', requires_grad=False)
+        if source is not None:
+            self.data = source
+        else:
+            self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device='cuda', requires_grad=False)
    def zero(self):
        """Reset the buffer to zero."""
...
@@ -19,9 +19,9 @@ import torch
from typing import Union
-HALF_LIST = [16, "16", "fp16", "float16"]
-FLOAT_LIST = [32, "32", "fp32", "float32"]
-BFLOAT_LIST = ["bf16", "bfloat16"]
+HALF_LIST = [16, "16", "fp16", "float16", torch.float16]
+FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32]
+BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]
class PrecisionType(object):
...
@@ -19,22 +19,26 @@ In megatron actor, the differences are:
Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer
"""
+import importlib
from functools import partial
+from packaging.version import Version
from typing import Iterable, Dict
import torch
from torch import nn
import torch.distributed
# from megatron import get_args
-from megatron.optimizer import DistributedOptimizer
-from verl.utils.megatron.optimizer_config import OptimizerConfig
+from megatron.core.optimizer import OptimizerConfig
from megatron.core import parallel_state as mpu
from megatron.core import ModelParallelConfig
-from megatron.core.utils import get_model_config
+from verl.utils.megatron_utils import get_model_config
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.distributed import finalize_model_grads
# from megatron.core.optimizer import DistributedOptimizer
+from megatron.core.optimizer import DistributedOptimizer
from omegaconf import OmegaConf
from verl.utils.megatron.tensor_parallel import vocab_parallel_compute_entropy_loss, vocab_parallel_log_probs_from_logits
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
@@ -160,7 +164,7 @@ class MegatronPPOActor(BasePPOActor):
        response = data['responses']
        response_length = response.size(1)
        logits = output['logits']
-        logits = logits[:, -response_length - 1:-1]
+        logits = logits[:, -response_length - 1:-1].contiguous()
        log_probs = vocab_parallel_log_probs_from_logits(logits, response)
        return {'log_probs': log_probs}
@@ -275,8 +279,10 @@ class MegatronPPOActor(BasePPOActor):
            # compute policy loss
            logits = output.logits
-            logits = logits[:, -response_length - 1:-1]
+            logits = logits[:, -response_length - 1:-1].contiguous()
+            logits_back = logits.clone()
            log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
+            logits = logits_back
            pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
                                                                          log_prob=log_prob,
                                                                          advantages=advantages,
@@ -316,9 +322,7 @@ class MegatronPPOActor(BasePPOActor):
                data_iterator=batch_generator,
                model=self.actor_module,
                num_microbatches=n_micro_batch,
-                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                seq_length=batch_size * seq_len,  # no use when input_shapes was set
-                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                micro_batch_size=1,  # no use when input_shapes was set
                forward_only=forward_only,
            )
@@ -329,7 +333,6 @@ class MegatronPPOActor(BasePPOActor):
                model=self.actor_module,
                num_microbatches=n_micro_batch,
                seq_length=batch_size * seq_len,  # in use for pp = 1
-                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                micro_batch_size=1,  # in use for pp = 1
                forward_only=forward_only,
            )
@@ -355,14 +358,14 @@ class MegatronPPOActor(BasePPOActor):
        # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
        for chunk in self.actor_module:
            # if use distributed optimizer, zero grad buffer will be handled by optimizer
-            chunk.zero_grad_buffer(zero_buffer=(not self.actor_optimizer_config.use_distributed_optimizer))
+            chunk.zero_grad_buffer()
        metric_micro_batch = self.forward_backward_batch(data)
        for metric in metric_micro_batch:
            append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.
-        update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step(
-            self.megatron_config, self.megatron_config.timers)
+        update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step()
        if update_successful:
            # allgather already execute in optimizer.step in new megatron
            pass
...
@@ -15,7 +15,9 @@
Implement a multiprocess PPOCritic
"""
+import importlib
from functools import partial
+from packaging.version import Version
from typing import Iterable
import torch
@@ -31,11 +33,11 @@ from verl.utils.py_functional import append_to_dict
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.utils.megatron import sequence_parallel as sp_utils
-from verl.utils.megatron.optimizer_config import OptimizerConfig
-from megatron.optimizer import DistributedOptimizer
+from megatron.core.optimizer import OptimizerConfig
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
+from megatron.core.optimizer import DistributedOptimizer
class MegatronPPOCritic(BasePPOCritic):
@@ -185,9 +187,7 @@ class MegatronPPOCritic(BasePPOCritic):
                data_iterator=batch_generator,
                model=self.critic_module,
                num_microbatches=n_micro_batch,
-                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len,  # no use when input_shapes was set
-                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                micro_batch_size=1,  # no use when input_shapes was set
                forward_only=forward_only,
            )
@@ -198,7 +198,6 @@ class MegatronPPOCritic(BasePPOCritic):
                model=self.critic_module,
                num_microbatches=n_micro_batch,
                seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len,  # in use for pp = 1
-                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                micro_batch_size=1,  # in use for pp = 1
                forward_only=forward_only,
            )
@@ -213,12 +212,12 @@ class MegatronPPOCritic(BasePPOCritic):
        self.critic_optimizer.zero_grad()
        # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
        for chunk in self.critic_module:
-            chunk.zero_grad_buffer(zero_buffer=(not self.critic_optimizer_config.use_distributed_optimizer))
+            chunk.zero_grad_buffer()
        metric_micro_batch = self.forward_backward_batch(data)
-        update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step(
-            self.megatron_config, self.megatron_config.timers)
+        update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step()
        if update_successful:
            # allgather already execute in optimizer.step in new megatron
            pass
...
@@ -233,9 +233,7 @@ class MegatronRewardModel(BasePPORewardModel):
                data_iterator=batch_generator,
                model=self.reward_model_module,
                num_microbatches=n_micro_batch,
-                input_shapes=input_shapes,  # must set for flash-attn sequence packing
                seq_length=infer_batch_size * seq_len,  # no use when input_shapes was set
-                hidden_size=self.model_config.hidden_size,  # no use when input_shapes was set
                micro_batch_size=1,  # no use when input_shapes was set
                forward_only=True,
            )
@@ -246,7 +244,6 @@
                model=self.reward_model_module,
                num_microbatches=n_micro_batch,
                seq_length=infer_batch_size * seq_len,  # in use for pp = 1
-                hidden_size=self.model_config.hidden_size,  # in use for pp = 1
                micro_batch_size=1,  # in use for pp = 1
                forward_only=True,
            )
...
@@ -15,6 +15,8 @@
This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.
"""
+import importlib
+from packaging.version import Version
import torch
import torch.distributed as dist
@@ -81,11 +83,24 @@ class AllGatherPPModel:
    def _build_param_buffer(self, pp_rank):
        """Build the parameter buffer in each pp rank"""
-        model = self.pp_models[pp_rank]
-        weight_buffer_meta = get_weight_buffer_meta_from_module(model)
-        self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)
+        if pp_rank == self._pp_rank:
+            from verl.utils.memory_buffer import MemoryBuffer
+            # The code here is very hard-coded, based on the following assumptions:
+            # 1. `len(_this_rank_models) == 1`
+            # 2. `_this_rank_models[0]` is a instance of `DistributedDataParallel` and `use_distributed_optimizer=True`
+            # 3. Only bfloat16 data type is used in parameters
+            source = self._this_rank_models[0].buffers[0].param_data
+            self.memory_buffers[pp_rank] = {
+                torch.bfloat16: MemoryBuffer(source.numel(), source.numel(), torch.bfloat16, source)
+            }
+        else:
+            model = self.pp_models[pp_rank]
+            weight_buffer_meta = get_weight_buffer_meta_from_module(model)
+            self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)
    def _build_param_references(self, pp_rank, maintain_weight=False):
+        if pp_rank == self._pp_rank:
+            return
        model = self.pp_models[pp_rank]
        build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)
@@ -121,8 +136,9 @@ class AllGatherPPModel:
        global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)
        # NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models
-        for memory_buffer in self.memory_buffers[cur_pp_rank].values():
-            dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False)
+        for _, param in sorted(self.pp_models[cur_pp_rank].named_parameters()):
+            dist.broadcast(tensor=param.data, src=global_src, group=self.pp_group, async_op=False)
    def forward(self, *inputs, **kwargs):
        try:
...