Unverified commit cfc976b7 by Guangming Sheng, committed by GitHub

[misc] fix issue in hf_weight_loader and fix typo in doc (#30)

* [fix] fix some bugs related to hf_weight_loader

* [doc] fix doc typo

* [ci] fix github action

* [ci] lint

* [doc] fix typo
parent c5a09641
@@ -24,6 +24,8 @@ jobs:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
ref: ${{ github.head_ref }} # Checkout the branch associated with the pull request
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
......
@@ -30,7 +30,7 @@ into two parts:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
parser.add_argument('--hdfs_dir', default='hdfs://haruna/home/byte_data_seed/lf_lq/user/zhangchi.usc1992/data/rlhf')
parser.add_argument('--hdfs_dir', default=None)
args = parser.parse_args()
......
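
The `--hdfs_dir` default changes from a hard-coded internal path to `None`, making the HDFS upload opt-in. A minimal sketch of how a `None` default is typically consumed downstream (the guard and the `verl.utils.hdfs_io` helpers are assumptions for illustration, not part of this diff):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
    parser.add_argument('--hdfs_dir', default=None)
    args = parser.parse_args()

    # ... preprocess the dataset and write the parquet files into args.local_dir ...

    if args.hdfs_dir is not None:
        # mirror the local output to HDFS only when a destination is given explicitly
        from verl.utils.hdfs_io import copy, makedirs  # assumed helper module
        makedirs(args.hdfs_dir)
        copy(src=args.local_dir, dst=args.hdfs_dir)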
@@ -37,7 +37,8 @@ if __name__ == '__main__':
# get the worker group using names
worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)
worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
ray_cls_with_init=cls_with_init_args)
batch_size = 16
sequence_length = 1024
......
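
Once re-attached, the detached workers can be driven like a freshly created group. A hedged sketch of the next step, reusing `batch_size`/`sequence_length` from above plus the `DataProto`/`TensorDict` pattern and input keys that appear in the training-method hunk below (the import paths, the random inputs, and the method name `train_model` are assumptions; the diff does not show them):

import torch
from tensordict import TensorDict
from verl import DataProto  # assumed export of verl.protocol.DataProto

input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64)
attention_mask = torch.ones_like(input_ids)
position_ids = torch.arange(sequence_length).unsqueeze(0).expand(batch_size, -1)
data = DataProto(batch=TensorDict({'input_ids': input_ids,
                                   'attention_mask': attention_mask,
                                   'position_ids': position_ids}, batch_size=batch_size))
output = worker_group.train_model(data)  # method name assumed; the later hunk only shows its body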
@@ -72,11 +72,11 @@ class Trainer(MegatronWorker):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
actor_model_config = LlamaConfig(vocab_size=256,
hidden_size=2048,
intermediate_size=5504,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16)
hidden_size=2048,
intermediate_size=5504,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16)
megatron_config = OmegaConf.create({
'sequence_parallel_enabled': True,
@@ -96,21 +96,18 @@ class Trainer(MegatronWorker):
# this_megatron_config = copy.deepcopy(megatron_config)
# this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
parallel_model = ParallelLlamaForCausalLMRmPadPP(config=actor_model_config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
parallel_model.cuda()
return parallel_model
actor_module = get_model(model_provider_func=megatron_actor_model_provider,
actor_module = get_model(model_provider_func=megatron_actor_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
actor_module = nn.ModuleList(actor_module)
optim_config = OmegaConf.create({
'lr': 1e-6,
'clip_grad': 1.0
})
optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0})
optim_config = init_megatron_optim_config(optim_config)
self.optimizer_config = optim_config
@@ -126,13 +123,15 @@ class Trainer(MegatronWorker):
position_ids = data.batch['position_ids']
self.optimizer.zero_grad()
self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
self.model.zero_grad_buffer(
zero_buffer=(not self.optimizer_config.use_distributed_optimizer
)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
# update for 1 iteration
output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
output.mean().backward()
update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
self.megatron_config, self.megatron_config.timers)
update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config,
self.megatron_config.timers)
return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0]))
@@ -142,11 +141,12 @@ if __name__ == '__main__':
resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool,
ray_cls_with_init=cls_with_init_args,
name_prefix='trainer',
detached=True,
)
worker_group = NVMegatronRayWorkerGroup(
resource_pool=resource_pool,
ray_cls_with_init=cls_with_init_args,
name_prefix='trainer',
detached=True,
)
worker_group.init_model()
......
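
With `name_prefix='trainer'` and `cls=Trainer`, the detached workers are registered as named Ray actors whose names match the `worker_names` used in the client hunk above. A small assumed sanity check (not part of this diff) that the actors are reachable before calling `from_detached`:

import ray  # assumes ray.init(address='auto') has already connected to the running cluster

# the names appear to be built as name_prefix + class name + '_<pool index>:<rank>'
for name in ['trainerTrainer_0:0', 'trainerTrainer_0:1']:
    handle = ray.get_actor(name)  # raises ValueError if no detached actor is registered under this name
    print(name, '->', handle)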
@@ -30,6 +30,8 @@ def update_hf_weight_loader():
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
assert isinstance(actor_weights, Dict)
with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
del actor_weights["lm_head.weight"]
vllm_model.load_weights(actor_weights.items())
for _, module in vllm_model.named_modules():
quant_method = getattr(module, "quant_method", None)
......
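
This is the `hf_weight_loader` fix referenced in the commit title: when the HF config sets `tie_word_embeddings`, the checkpoint's `lm_head.weight` is just an alias of the input embedding, and handing it to `vllm_model.load_weights` can trip on a parameter the vLLM model does not register separately, so the entry is dropped first. A standalone sketch of the same guard (the helper name and the `pop` variant are illustrative, not from the commit):

from typing import Dict

import torch.nn as nn

def drop_tied_lm_head(actor_weights: Dict, vllm_model: nn.Module) -> Dict:
    # with tied embeddings, the output head reuses the embedding matrix, so there is
    # no separate lm_head parameter for load_weights to fill
    if getattr(vllm_model.config, 'tie_word_embeddings', False):
        actor_weights.pop('lm_head.weight', None)
    return actor_weights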
@@ -16,7 +16,7 @@ import os
import logging
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
from verl.third_party.vllm import LLM
from verl.third_party.vllm import parallel_state as vllm_ps
@@ -42,7 +42,7 @@ class FSDPVLLMShardingManager(BaseShardingManager):
if full_params:
FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=ShardedStateDictConfig())
state_dict_config=FullStateDictConfig())
else:
FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
......
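
The bug here was a mismatched config class: `StateDictType.FULL_STATE_DICT` was paired with `ShardedStateDictConfig`. Each state-dict type expects its matching config class from `torch.distributed.fsdp.api`. A brief sketch of the corrected pairing (the `offload_to_cpu`/`rank0_only` options are common choices shown for illustration, not taken from this diff):

from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType, FullStateDictConfig, ShardedStateDictConfig

def use_full_state_dict(module: FSDP):
    # full (unsharded) state dict: gather parameters, typically on rank 0 and offloaded to CPU
    FSDP.set_state_dict_type(module,
                             state_dict_type=StateDictType.FULL_STATE_DICT,
                             state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True))

def use_sharded_state_dict(module: FSDP):
    # sharded state dict: every rank keeps only its own shard
    FSDP.set_state_dict_type(module,
                             state_dict_type=StateDictType.SHARDED_STATE_DICT,
                             state_dict_config=ShardedStateDictConfig())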
@@ -72,6 +72,8 @@ class vLLMRollout(BaseRollout):
"disable CUDA graph (enforce_eager = False) if free cache engine"
tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
assert tensor_parallel_size <= torch.distributed.get_world_size(), \
"tensor parallel size should be less than or equal to the world size"
if kwargs.get('train_tp', None) is not None:
# deployed with megatron
......
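
The added assertion turns a misconfigured tensor-parallel size into an immediate, explicit error rather than a harder-to-trace failure later in vLLM. A worked example of the guard with made-up numbers (assumes `torch.distributed` is already initialized):

import torch.distributed as dist

tensor_parallel_size = 8              # e.g. a rollout config asking for TP=8
world_size = dist.get_world_size()    # e.g. 4 when the job was launched on 4 GPUs
assert tensor_parallel_size <= world_size, \
    "tensor parallel size should be less than or equal to the world size"  # now fails fast here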