Unverified Commit cfc976b7 by Guangming Sheng Committed by GitHub

[misc] fix issue in hf_weight_loader and fix typo in doc (#30)

* [fix] fix some bugs related to hf_weight_loader

* [doc] fix doc typo

* [ci] fix github action

* [ci] lint

* [doc] fix typo
parent c5a09641
@@ -24,6 +24,8 @@ jobs:
        python-version: ["3.12"]
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+     with:
+       ref: ${{ github.head_ref }} # Checkout the branch associated with the pull request
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
...
@@ -30,7 +30,7 @@ into two parts:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
-    parser.add_argument('--hdfs_dir', default='hdfs://haruna/home/byte_data_seed/lf_lq/user/zhangchi.usc1992/data/rlhf')
+    parser.add_argument('--hdfs_dir', default=None)
     args = parser.parse_args()
...
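Not part of the diff, just context: a minimal sketch of how the now-optional --hdfs_dir is typically consumed, so the preprocessing script can run purely locally when no HDFS path is given. The maybe_copy_to_hdfs helper below is a hypothetical placeholder, not the script's actual utility.

import argparse


def maybe_copy_to_hdfs(local_dir: str, hdfs_dir) -> None:
    # Hypothetical helper: upload only when an HDFS destination was provided.
    if hdfs_dir is None:
        return  # purely local run, nothing to upload
    print(f"would copy {local_dir} -> {hdfs_dir}")  # e.g. via the project's own copy utility


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
    parser.add_argument('--hdfs_dir', default=None)  # no longer defaults to an internal cluster path
    args = parser.parse_args()
    maybe_copy_to_hdfs(args.local_dir, args.hdfs_dir)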
@@ -37,7 +37,8 @@ if __name__ == '__main__':
     # get the worker group using names
     worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
     cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
-    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)
+    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
+                                                          ray_cls_with_init=cls_with_init_args)
     batch_size = 16
     sequence_length = 1024
...
@@ -72,11 +72,11 @@ class Trainer(MegatronWorker):
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
         actor_model_config = LlamaConfig(vocab_size=256,
                                          hidden_size=2048,
                                          intermediate_size=5504,
                                          num_hidden_layers=24,
                                          num_attention_heads=16,
                                          num_key_value_heads=16)

         megatron_config = OmegaConf.create({
             'sequence_parallel_enabled': True,
@@ -96,21 +96,18 @@ class Trainer(MegatronWorker):
             # this_megatron_config = copy.deepcopy(megatron_config)
             # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
             parallel_model = ParallelLlamaForCausalLMRmPadPP(config=actor_model_config,
                                                              megatron_config=megatron_config,
                                                              pre_process=pre_process,
                                                              post_process=post_process)
             parallel_model.cuda()
             return parallel_model

         actor_module = get_model(model_provider_func=megatron_actor_model_provider,
                                  model_type=ModelType.encoder_or_decoder,
                                  wrap_with_ddp=True)
         actor_module = nn.ModuleList(actor_module)

-        optim_config = OmegaConf.create({
-            'lr': 1e-6,
-            'clip_grad': 1.0
-        })
+        optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0})
         optim_config = init_megatron_optim_config(optim_config)
         self.optimizer_config = optim_config
@@ -126,13 +123,15 @@ class Trainer(MegatronWorker):
         position_ids = data.batch['position_ids']

         self.optimizer.zero_grad()
-        self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer))  # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
+        self.model.zero_grad_buffer(
+            zero_buffer=(not self.optimizer_config.use_distributed_optimizer
+                        ))  # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
         # update for 1 iteration
         output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
         output.mean().backward()

-        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
-            self.megatron_config, self.megatron_config.timers)
+        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config,
+                                                                              self.megatron_config.timers)

         return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0]))
@@ -142,11 +141,12 @@ if __name__ == '__main__':
     resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
     cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
-    worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool,
-                                            ray_cls_with_init=cls_with_init_args,
-                                            name_prefix='trainer',
-                                            detached=True,
-                                            )
+    worker_group = NVMegatronRayWorkerGroup(
+        resource_pool=resource_pool,
+        ray_cls_with_init=cls_with_init_args,
+        name_prefix='trainer',
+        detached=True,
+    )

     worker_group.init_model()
...
@@ -30,6 +30,8 @@ def update_hf_weight_loader():
 def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
     assert isinstance(actor_weights, Dict)
     with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
+        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
+            del actor_weights["lm_head.weight"]
         vllm_model.load_weights(actor_weights.items())
     for _, module in vllm_model.named_modules():
         quant_method = getattr(module, "quant_method", None)
...
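Not part of the diff: a self-contained sketch of why the new tie_word_embeddings guard matters. ToyTiedLM and its load_weights below are hypothetical stand-ins, not the real vLLM model API; the point is that a tied lm_head.weight is only an alias of the input embedding, so the duplicate key has to be dropped before the state dict is handed to the loader.

from typing import Dict, Iterable, Tuple

import torch
import torch.nn as nn


class ToyTiedLM(nn.Module):
    # Toy model with tied word embeddings: lm_head shares its weight with embed_tokens.
    def __init__(self, vocab: int = 256, hidden: int = 32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # one Parameter, two names

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        # named_parameters() deduplicates tied tensors, so "lm_head.weight" is absent here;
        # a checkpoint that still carries that key would fail the lookup below.
        params = dict(self.named_parameters())
        for name, tensor in weights:
            params[name].data.copy_(tensor)


def load_hf_weights(actor_weights: Dict, model: ToyTiedLM, tie_word_embeddings: bool) -> None:
    if tie_word_embeddings and "lm_head.weight" in actor_weights:
        del actor_weights["lm_head.weight"]  # drop the alias, keep embed_tokens.weight
    model.load_weights(actor_weights.items())


if __name__ == '__main__':
    src, dst = ToyTiedLM(), ToyTiedLM()
    # state_dict() still lists both names for a tied weight, so the guard is what makes this pass.
    load_hf_weights(dict(src.state_dict()), dst, tie_word_embeddings=True)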
@@ -16,7 +16,7 @@ import os
 import logging
 import torch
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType
+from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
 from verl.third_party.vllm import LLM
 from verl.third_party.vllm import parallel_state as vllm_ps
@@ -42,7 +42,7 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         if full_params:
             FSDP.set_state_dict_type(self.module,
                                      state_dict_type=StateDictType.FULL_STATE_DICT,
-                                     state_dict_config=ShardedStateDictConfig())
+                                     state_dict_config=FullStateDictConfig())
         else:
             FSDP.set_state_dict_type(self.module,
                                      state_dict_type=StateDictType.SHARDED_STATE_DICT,
...
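Not part of the diff: a short sketch of the corrected pairing, assuming an already-wrapped FSDP module. FULL_STATE_DICT expects a FullStateDictConfig (the previous code passed a ShardedStateDictConfig there), while SHARDED_STATE_DICT keeps the ShardedStateDictConfig.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType


def configure_state_dict(fsdp_module: FSDP, full_params: bool) -> None:
    # Sketch of the pairing the fix restores; fsdp_module is assumed to be FSDP-wrapped already.
    if full_params:
        # Gather complete (unsharded) tensors when the state dict is produced.
        FSDP.set_state_dict_type(fsdp_module,
                                 state_dict_type=StateDictType.FULL_STATE_DICT,
                                 state_dict_config=FullStateDictConfig())
    else:
        # Keep per-rank shards in the state dict.
        FSDP.set_state_dict_type(fsdp_module,
                                 state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                 state_dict_config=ShardedStateDictConfig())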
@@ -72,6 +72,8 @@ class vLLMRollout(BaseRollout):
             "disable CUDA graph (enforce_eager = False) if free cache engine"
         tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
+        assert tensor_parallel_size <= torch.distributed.get_world_size(), \
+            "tensor parallel size should be less than or equal to the world size"
         if kwargs.get('train_tp', None) is not None:
             # deployed with megatron
...
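Not part of the diff: the new assert in isolation, as a tiny helper sketch (check_tensor_parallel_size is a hypothetical name). A tensor-parallel group can never be larger than the number of launched ranks, so the check makes a misconfiguration fail fast with a clear message.

import torch.distributed as dist


def check_tensor_parallel_size(tensor_parallel_size: int) -> None:
    # Requires torch.distributed to be initialized, as it is inside the rollout worker.
    world_size = dist.get_world_size()
    assert tensor_parallel_size <= world_size, \
        f"tensor parallel size ({tensor_parallel_size}) should be less than or equal to the world size ({world_size})"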