Unverified Commit f7e183e4 by Joel, committed by GitHub

fix: remove redundant broadcast in fsdp vllm postprocess (#577)

Remove the redundant broadcast in the FSDP vLLM sharding manager's postprocess step, since the vLLM output on each TP rank should already be identical.
parent 3fc3e2b7
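In short: preprocess_data already all-gathers the rollout batch across the TP group, and vLLM generation is replicated on every TP rank, so postprocess_data only needs to take this rank's chunk instead of broadcasting from the TP source rank. Below is a minimal, single-process sketch of that round trip; it does not touch torch.distributed or vLLM, and fake_generate plus the list-of-ranks layout are stand-ins used purely for illustration.

```python
# Toy, single-process sketch of the TP data flow this change assumes.
# Nothing here calls torch.distributed or vLLM; it only illustrates why the
# post-generation broadcast is redundant when every TP rank sees the same input.
import torch

tp_size = 4

# 1) Each TP rank starts with its own shard of prompts (one row of ids per rank here).
local_prompts = [torch.arange(3) + 10 * r for r in range(tp_size)]

# 2) preprocess_data: all-gather across the TP group, so every rank now holds
#    the identical, concatenated prompt batch.
gathered = torch.cat(local_prompts, dim=0)
per_rank_input = [gathered.clone() for _ in range(tp_size)]

# 3) Generation is replicated across TP ranks: identical input produces
#    identical output on every rank, so no broadcast is needed afterwards.
def fake_generate(batch: torch.Tensor) -> torch.Tensor:
    return batch * 2  # deterministic stand-in for the rollout

per_rank_output = [fake_generate(x) for x in per_rank_input]
assert all(torch.equal(per_rank_output[0], out) for out in per_rank_output)

# 4) postprocess_data: each rank simply takes its own chunk back.
for tp_rank in range(tp_size):
    local_out = per_rank_output[tp_rank].chunk(tp_size)[tp_rank]
    assert torch.equal(local_out, fake_generate(local_prompts[tp_rank]))

print("chunking by tp_rank recovers each rank's slice without a broadcast")
```

When tp_size == 1 there is nothing to gather or split, which is why both hooks in the diff below return the data unchanged in that case.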
@@ -58,6 +58,9 @@ class FSDPVLLMShardingManager(BaseShardingManager):
                                   state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                   state_dict_config=ShardedStateDictConfig())
 
+        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
+        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
+
         # Note that torch_random_states may be different on each dp rank
         self.torch_random_states = torch.cuda.get_rng_state()
         # get a random rng states
@@ -135,8 +138,11 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         torch.cuda.set_rng_state(self.torch_random_states)
 
     def preprocess_data(self, data: DataProto) -> DataProto:
+        """All gather across tp group to make each rank has identical input."""
+        if self.tp_size == 1:
+            return data
+
         # TODO: Current impl doesn't consider FSDP with torch micro-dp
-        tp_size = vllm_ps.get_tensor_model_parallel_world_size()
         if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
             group = vllm_ps.get_tensor_model_parallel_group()
         else:
@@ -146,20 +152,8 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         return data
 
     def postprocess_data(self, data: DataProto) -> DataProto:
-        # TODO: Current impl doesn't consider FSDP with torch micro-dp
-        local_world_size = vllm_ps.get_tensor_model_parallel_world_size()
-        src_rank = (torch.distributed.get_rank() // local_world_size) * local_world_size
-        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
-            broadcast_dict_tensor(data.batch, src=src_rank, group=vllm_ps.get_tensor_model_parallel_group())
-        else:
-            broadcast_dict_tensor(data.batch,
-                                  src=src_rank,
-                                  group=vllm_ps.get_tensor_model_parallel_group().device_group)
-        dp_rank = torch.distributed.get_rank()
-        dp_size = torch.distributed.get_world_size()  # not consider torch micro-dp
-        tp_size = vllm_ps.get_tensor_model_parallel_world_size()
-        if tp_size > 1:
-            # TODO: shall we build a micro_dp group for vllm when integrating with vLLM?
-            local_prompts = data.chunk(chunks=tp_size)
-            data = local_prompts[dp_rank % tp_size]
-        return data
+        """Get chunk data of this tp rank since we do all gather in preprocess."""
+        if self.tp_size == 1:
+            return data
+
+        return data.chunk(chunks=self.tp_size)[self.tp_rank]