Unverified commit 884a7273 by Zhiqi Lin, committed by GitHub

[perf] feat: support meta device init and parallel load for fsdp (#123)

This PR supports:
- meta device init (which keeps shared/tied parameters shared)
- parallel pre-trained weight initialization for FSDP from a Hugging Face checkpoint
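
A rough sketch of how the two pieces compose (illustrative only: the model class, checkpoint path, and FSDP arguments are placeholders rather than the exact code in this repo, and `torch.distributed` is assumed to be initialized):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoConfig, AutoModelForCausalLM

    local_path = "path/to/hf_checkpoint"  # placeholder

    # 1. Build the model on the meta device (no parameter storage yet, tied weights stay tied).
    with meta_device_init():
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(local_path))

    # 2. Each rank loads its slice of the safetensors shards straight onto its GPU.
    shard_states = parallel_load_safetensors(local_path)

    # 3. FSDP materializes and broadcasts the weights module by module via param_init_fn.
    model = FSDP(model,
                 param_init_fn=parallel_init_module_fn(model, shard_states),
                 device_id=torch.cuda.current_device())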

---------

Co-authored-by: zhiqi.0 <zhiqi.0@bytedance.com>
parent cd52d8b3
@@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import functools
import json
import math
import itertools
import os
from contextlib import contextmanager

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from transformers.trainer_pt_utils import get_module_class_from_name
import torch
import torch.nn as nn
import torch.distributed as dist


def init_fn(x: torch.nn.Module):
@@ -120,3 +128,173 @@ def load_fsdp_optimizer(optimizer, device_id):
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()

@contextmanager
def meta_device_init():
    """
    Create model parameters with the meta device.
    Note that buffers in the model will still be initialized on the default device (e.g., CPU),
    since buffers can be non-persistent and filled with expected values that can
    NOT be captured on the meta device.
    """
    device = torch.device("meta")
    old_register_parameter = nn.Module.register_parameter
    registered = set()

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        # skip shared parameters: they were already registered (and moved to meta) the first time
        if param is not None and param not in registered:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
            registered.add(module._parameters[name])

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        registered.clear()
        nn.Module.register_parameter = old_register_parameter
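
A minimal usage sketch of the context manager (illustrative; the model class and checkpoint path are placeholders): parameters come out on the meta device, while buffers keep their real, default-device values.

    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("path/to/hf_checkpoint")  # placeholder path
    with meta_device_init():
        model = AutoModelForCausalLM.from_config(config)

    assert all(p.is_meta for p in model.parameters())   # no real parameter storage yet
    assert all(not b.is_meta for b in model.buffers())  # buffers were initialized normally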

def parallel_load_safetensors(filepath):
    """
    Parallel-load safetensors from a Hugging Face checkpoint.
    A Hugging Face checkpoint directory contains:
    - config.json: a json file with the model configuration
    - model.safetensors.index.json: a json index mapping each parameter/buffer to its safetensors shard
    - model-0000x-of-0000x.safetensors: binary shards with the parameters & buffers
    Or (when the model is small):
    - model.safetensors: a single binary file with all parameters and buffers
    Each rank owns a subset of the shards and loads them directly into GPU memory.
    """
    from safetensors.torch import load_file

    safetensors2param = {}

    index_file = os.path.join(filepath, "model.safetensors.index.json")
    if os.path.exists(index_file):
        index = json.load(open(index_file, "rb"))
        for param_name, filename in index["weight_map"].items():
            safetensors2param.setdefault(filename, []).append(param_name)
    else:
        # in this case, the model is small and we can load it all at once
        param_file = os.path.join(filepath, "model.safetensors")
        assert os.path.exists(param_file), f"Cannot find {param_file}"
        states = load_file(param_file)
        for param_name in states:
            safetensors2param.setdefault("model.safetensors", []).append(param_name)
        del states

    total_files = len(safetensors2param)
    ckpt_chunks = sorted(safetensors2param.keys())
    world_size = dist.get_world_size()
    size = int(math.ceil(total_files / world_size))
    ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)]

    shard_states = {}
    device = torch.cuda.current_device()
    for rank, files in enumerate(ckpt_chunks):
        if rank == dist.get_rank():
            for file in files:
                file = os.path.join(filepath, file)
                states = load_file(file, device=device)
                # print(f"rank {rank} loading {file}...")
                shard_states.update(states)
        else:
            for file in files:
                for param_name in safetensors2param[file]:
                    shard_states[param_name] = rank
    return shard_states
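
For illustration (checkpoint path and parameter names are placeholders; `torch.distributed` must already be initialized), the returned mapping mixes real tensors with owning-rank indices:

    shard_states = parallel_load_safetensors("path/to/hf_checkpoint")
    # On each rank, entries loaded locally are CUDA tensors, while entries owned by
    # other ranks are just the owning rank's index, e.g. on rank 0:
    #   {
    #       "model.embed_tokens.weight": <tensor on cuda:0>,   # loaded by this rank
    #       "model.layers.31.mlp.up_proj.weight": 3,           # held by rank 3
    #       ...
    #   }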

def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]):
    """
    Generate a function that initializes sub-modules of `module` with `shard_states`
    loaded from a Hugging Face checkpoint.

    Args:
        module (torch.nn.Module): the global module to be initialized
        shard_states (Dict[str, torch.nn.Parameter]): the sharded states from the Hugging Face checkpoint

    Returns:
        init_fn (Callable): a function that initializes sub-modules of `module` with `shard_states`
    """
    state2fqn = {}
    for name, state in itertools.chain(module.named_parameters(remove_duplicate=False),
                                       module.named_buffers(remove_duplicate=False)):
        state2fqn.setdefault(state, []).append(name)
    # keep only the states that appear under more than one name, i.e. shared (tied) states
    shared = {s for s, names in state2fqn.items() if len(names) > 1}
    materialized_states = {}

    @torch.no_grad()
    def create_and_sync_state(param_name, state, is_param):
        assert param_name in shard_states, f"{param_name} not loaded"
        device = torch.cuda.current_device()
        if is_param:
            param = torch.nn.Parameter(torch.empty_like(state.data, device=device),
                                       requires_grad=state.requires_grad)
        else:  # buffer
            param = torch.empty_like(state.data, device=device)
        loaded = shard_states[param_name]
        if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)):
            # this rank holds the state: copy it in and broadcast it to the other ranks
            # NOTE: loaded.dtype can differ from param.dtype
            param.data.copy_(loaded.data)
            dist.broadcast(param.data, src=dist.get_rank())
        else:
            assert isinstance(loaded, int)  # the rank that holds the state
            dist.broadcast(param.data, src=loaded)
        shard_states.pop(param_name)
        del loaded
        return param

    def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
        param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False))
        # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0])
        for name, state in param_and_buffers:
            if not state.is_meta:
                continue
            is_param = name in sub_mod._parameters
            fqn = state2fqn[state].pop(0)
            # non-persistent buffers are not saved in the state dict, so we can safely skip them
            if (not is_param) and fqn not in shard_states:
                if state.is_meta:
                    raise RuntimeError(
                        f"found a non-persistent buffer ({fqn}) initialized on the meta device. "
                        "Such a buffer is not saved in the checkpoint, so the user must initialize it on a CPU / GPU device.")
                continue
            # a shared state is materialized the first time it is encountered and reused afterwards
            if state in shared:
                if state not in materialized_states:
                    materialized_states[state] = create_and_sync_state(fqn, state, is_param)
                else:
                    if fqn in shard_states:
                        shard_states.pop(fqn)
                materialize_state = materialized_states[state]
            # a non-shared state is created (and broadcast) directly
            else:
                materialize_state = create_and_sync_state(fqn, state, is_param)
            if is_param:
                sub_mod._parameters[name] = materialize_state
            else:
                sub_mod._buffers[name] = materialize_state
        if recurse:
            for module in sub_mod.children():
                init_fn(module, recurse=True)
        # for debug
        # if len(shard_states) == 0: print("clear")
        return sub_mod

    return init_fn
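
A sketch of how the returned `init_fn` is expected to be consumed (FSDP arguments other than `param_init_fn` are illustrative; `model` was built under `meta_device_init()` and `shard_states` comes from `parallel_load_safetensors`):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    fsdp_model = FSDP(model,
                      param_init_fn=parallel_init_module_fn(model, shard_states),
                      device_id=torch.cuda.current_device())
    # FSDP calls param_init_fn on submodules whose parameters are still on the meta device,
    # so each checkpointed parameter/buffer is materialized on GPU and broadcast once.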