Unverified commit 0cc2bdad by Chi Zhang, committed by GitHub

[misc] feat: add allgather method to dataproto (#497)

- Add an allgather helper (`all_gather_data_proto`) for `DataProto` (see the usage sketch below)
- Add tests
- Replace the existing raw allgather code with this function
parent 4a291fa7
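For reference, a minimal usage sketch of the new helper (not part of the commit). It mirrors the test added below; the file name `allgather_demo.py` and the single-node, 2-GPU NCCL launch are assumptions for illustration only.

# Usage sketch; assumption: launched as `torchrun --nproc-per-node=2 allgather_demo.py` on a 2-GPU node.
import os

import numpy as np
import torch
import torch.distributed as dist

from verl.protocol import DataProto, all_gather_data_proto

dist.init_process_group(backend='nccl')
rank = dist.get_rank()
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

# Each rank contributes a 2-row tensor batch plus a non-tensor (numpy object) field.
data = DataProto.from_dict(
    tensors={'obs': torch.full((2, 2), rank)},
    non_tensors={'labels': np.array([f'rank{rank}'] * 2, dtype=object)},
)

# In place, like torch.distributed.all_gather: afterwards every rank holds the
# concatenation of all ranks' batches along dim 0.
all_gather_data_proto(data=data, process_group=dist.group.WORLD)
print(rank, data.batch['obs'].shape, data.non_tensor_batch['labels'])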
@@ -72,3 +72,6 @@ jobs:
        run: |
          pip3 install transformers==4.45.0
          torchrun --nproc_per_node=8 -m pytest tests/model/test_transformers_ulysses.py
+      - name: Run distributed test
+        run: |
+          bash tests/distributed/run_all.sh
New file: tests/distributed/run_all.sh

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env bash
set -e -x
torchrun --nproc-per-node=4 --standalone tests/distributed/test_tensor_dict.py
\ No newline at end of file
New file: tests/distributed/test_tensor_dict.py

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['NCCL_DEBUG'] = 'WARN'
from verl.protocol import all_gather_data_proto, DataProto
from verl.utils.distributed import initialize_global_process_group
import torch
import torch.distributed
import numpy as np

def test_all_gather_data_proto():
    device_mesh = torch.distributed.device_mesh.init_device_mesh('cuda', mesh_shape=[2, 2], mesh_dim_names=['dp', 'tp'])

    global_rank = torch.distributed.get_rank()

    obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]])

    labels = ['a', 'b'] if global_rank % 2 == 0 else ['b', 'a']
    labels = np.array(labels, dtype=object)
    data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})

    all_gather_data_proto(data=data, process_group=device_mesh.get_group('dp'))

    if global_rank == 0:
        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda')
        expected_labels = ['a', 'b', 'a', 'b']
    elif global_rank == 1:
        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda')
        expected_labels = ['b', 'a', 'b', 'a']
    elif global_rank == 2:
        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda')
        expected_labels = ['a', 'b', 'a', 'b']
    elif global_rank == 3:
        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda')
        expected_labels = ['b', 'a', 'b', 'a']

    torch.testing.assert_close(data.batch['obs'], expected_obs, atol=0, rtol=0)
    assert (data.non_tensor_batch['labels'] == expected_labels).all()
    assert data.meta_info == {'info': 'test_info'}


if __name__ == '__main__':
    local_rank, rank, world_size = initialize_global_process_group()
    test_all_gather_data_proto()
verl/protocol.py:

@@ -644,3 +644,21 @@ class DataProtoFuture:
        if self.dispatch_fn is not None:
            output = self.dispatch_fn(output)  # split in batch dim, select using dp
        return output
+
+
+from verl.utils.torch_functional import allgather_dict_tensors
+import torch.distributed
+
+
+def all_gather_data_proto(data: DataProto, process_group):
+    # Note that this is an inplace operator just like torch.distributed.all_gather
+    group_size = torch.distributed.get_world_size(group=process_group)
+    assert isinstance(data, DataProto)
+    prev_device = data.batch.device
+    data.batch = data.batch.cuda(device=torch.cuda.current_device())
+    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0)
+    data.batch = data.batch.to(prev_device)
+    # all gather non_tensor_batch
+    all_non_tensor_batch = [None for _ in range(group_size)]
+    torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group)
+    data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
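The non-tensor path above is torch.distributed.all_gather_object followed by a per-key np.concatenate. Below is a standalone sketch of just that step (not from the commit; the gloo/CPU backend and the 2-process launch are assumptions for illustration), independent of DataProto.

# Illustration of the non-tensor gather pattern used in all_gather_data_proto.
# Assumption: launched as `torchrun --nproc-per-node=2 gather_object_demo.py` (CPU, gloo backend).
import numpy as np
import torch.distributed as dist

dist.init_process_group(backend='gloo')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Per-rank dict of numpy object arrays, mirroring DataProto.non_tensor_batch.
non_tensor_batch = {'labels': np.array([f'r{rank}-0', f'r{rank}-1'], dtype=object)}

# Gather every rank's dict, then concatenate per key so all ranks end up with the same merged arrays.
gathered = [None for _ in range(world_size)]
dist.all_gather_object(gathered, non_tensor_batch)
merged = {k: np.concatenate([d[k] for d in gathered]) for k in non_tensor_batch}

print(rank, merged['labels'])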
@@ -19,6 +19,7 @@ from .base import BaseShardingManager
from torch.distributed.device_mesh import DeviceMesh
from verl.utils.torch_functional import allgather_dict_tensors
+from verl.protocol import all_gather_data_proto
from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
import numpy as np

@@ -63,16 +64,7 @@ class FSDPUlyssesShardingManager(BaseShardingManager):
            sp_size = self.device_mesh['sp'].size()
            group = self.device_mesh['sp'].get_group()

-            prev_device = data.batch.device
-            data.batch = data.batch.cuda(device=torch.cuda.current_device())
-            data.batch = allgather_dict_tensors(data.batch.contiguous(), size=sp_size, group=group, dim=0)
-            data.batch = data.batch.to(prev_device)
-            # all gather non_tensor_batch
-            all_non_tensor_batch = [None for _ in range(sp_size)]
-            torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
-            data.non_tensor_batch = {
-                k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch
-            }
+            all_gather_data_proto(data=data, process_group=group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
...
@@ -24,6 +24,7 @@ from verl.third_party.vllm import LLM
from verl.third_party.vllm import parallel_state as vllm_ps
from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
+from verl.protocol import all_gather_data_proto
from verl.utils.debug import log_gpu_memory_usage
from verl.third_party.vllm import vllm_version

@@ -134,14 +135,7 @@ class FSDPVLLMShardingManager(BaseShardingManager):
        else:
            group = vllm_ps.get_tensor_model_parallel_group().device_group

-        prev_device = data.batch.device
-        data.batch = data.batch.cuda(device=torch.cuda.current_device())
-        data.batch = allgather_dict_tensors(data.batch.contiguous(), size=tp_size, group=group, dim=0)
-        data.batch = data.batch.to(prev_device)
-        # all gather non_tensor_batch
-        all_non_tensor_batch = [None for _ in range(tp_size)]
-        torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
-        data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
+        all_gather_data_proto(data=data, process_group=group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
...