Unverified Commit cfc976b7 by Guangming Sheng Committed by GitHub

[misc] fix issue in hf_weight_loader and fix typo in doc (#30)

* [fix] fix some bugs related to hf_weight_loader

* [doc] fix doc typo

* [ci] fix github action

* [ci] lint

* [doc] fix typo
parent c5a09641
...@@ -24,6 +24,8 @@ jobs: ...@@ -24,6 +24,8 @@ jobs:
python-version: ["3.12"] python-version: ["3.12"]
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
ref: ${{ github.head_ref }} # Checkout the branch associated with the pull request
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with: with:
......
...@@ -30,7 +30,7 @@ into two parts: ...@@ -30,7 +30,7 @@ into two parts:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='/opt/tiger/gsm8k') parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
parser.add_argument('--hdfs_dir', default='hdfs://haruna/home/byte_data_seed/lf_lq/user/zhangchi.usc1992/data/rlhf') parser.add_argument('--hdfs_dir', default=None)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -37,7 +37,8 @@ if __name__ == '__main__': ...@@ -37,7 +37,8 @@ if __name__ == '__main__':
# get the worker group using names # get the worker group using names
worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1'] worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
cls_with_init_args = RayClassWithInitArgs(cls=Trainer) cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args) worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
ray_cls_with_init=cls_with_init_args)
batch_size = 16 batch_size = 16
sequence_length = 1024 sequence_length = 1024
......
...@@ -107,10 +107,7 @@ class Trainer(MegatronWorker): ...@@ -107,10 +107,7 @@ class Trainer(MegatronWorker):
wrap_with_ddp=True) wrap_with_ddp=True)
actor_module = nn.ModuleList(actor_module) actor_module = nn.ModuleList(actor_module)
optim_config = OmegaConf.create({ optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0})
'lr': 1e-6,
'clip_grad': 1.0
})
optim_config = init_megatron_optim_config(optim_config) optim_config = init_megatron_optim_config(optim_config)
self.optimizer_config = optim_config self.optimizer_config = optim_config
...@@ -126,13 +123,15 @@ class Trainer(MegatronWorker): ...@@ -126,13 +123,15 @@ class Trainer(MegatronWorker):
position_ids = data.batch['position_ids'] position_ids = data.batch['position_ids']
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm self.model.zero_grad_buffer(
zero_buffer=(not self.optimizer_config.use_distributed_optimizer
)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
# update for 1 iteration # update for 1 iteration
output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
output.mean().backward() output.mean().backward()
update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step( update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config,
self.megatron_config, self.megatron_config.timers) self.megatron_config.timers)
return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0])) return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0]))
...@@ -142,7 +141,8 @@ if __name__ == '__main__': ...@@ -142,7 +141,8 @@ if __name__ == '__main__':
resource_pool = RayResourcePool(process_on_nodes=[2], detached=True) resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
cls_with_init_args = RayClassWithInitArgs(cls=Trainer) cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool, worker_group = NVMegatronRayWorkerGroup(
resource_pool=resource_pool,
ray_cls_with_init=cls_with_init_args, ray_cls_with_init=cls_with_init_args,
name_prefix='trainer', name_prefix='trainer',
detached=True, detached=True,
......
...@@ -30,6 +30,8 @@ def update_hf_weight_loader(): ...@@ -30,6 +30,8 @@ def update_hf_weight_loader():
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
assert isinstance(actor_weights, Dict) assert isinstance(actor_weights, Dict)
with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
del actor_weights["lm_head.weight"]
vllm_model.load_weights(actor_weights.items()) vllm_model.load_weights(actor_weights.items())
for _, module in vllm_model.named_modules(): for _, module in vllm_model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
......
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
import logging import logging
import torch import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig
from verl.third_party.vllm import LLM from verl.third_party.vllm import LLM
from verl.third_party.vllm import parallel_state as vllm_ps from verl.third_party.vllm import parallel_state as vllm_ps
...@@ -42,7 +42,7 @@ class FSDPVLLMShardingManager(BaseShardingManager): ...@@ -42,7 +42,7 @@ class FSDPVLLMShardingManager(BaseShardingManager):
if full_params: if full_params:
FSDP.set_state_dict_type(self.module, FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=ShardedStateDictConfig()) state_dict_config=FullStateDictConfig())
else: else:
FSDP.set_state_dict_type(self.module, FSDP.set_state_dict_type(self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_type=StateDictType.SHARDED_STATE_DICT,
......
...@@ -72,6 +72,8 @@ class vLLMRollout(BaseRollout): ...@@ -72,6 +72,8 @@ class vLLMRollout(BaseRollout):
"disable CUDA graph (enforce_eager = False) if free cache engine" "disable CUDA graph (enforce_eager = False) if free cache engine"
tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1) tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
assert tensor_parallel_size <= torch.distributed.get_world_size(), \
"tensor parallel size should be less than or equal to the world size"
if kwargs.get('train_tp', None) is not None: if kwargs.get('train_tp', None) is not None:
# deployed with megatron # deployed with megatron
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment