Unverified Commit 16b9e493 by Guangming Sheng Committed by GitHub

[ci] feat: add test files for ray hybrid programming model (#23)

* [ci] update some tests for hybrid programming model

* [ci] update detached worker tests
parent 9665a060
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import sys
import os
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
from single_controller.base.decorator import register, Dispatch
@ray.remote
class TestActor(Worker):

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def foo(self, wait_time):
        # sleep, then exit with a non-zero code so the driver's aliveness check can catch the dead worker
        time.sleep(wait_time)
        sys.exit(1)
if __name__ == "__main__":
    wait_time = int(os.getenv("WAIT_TIME", "10"))

    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([8], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create worker group")
    wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="test")
    wg.start_worker_aliveness_check(1)
    time.sleep(1)

    print(time.time(), "start foo")
    _ = wg.foo(wait_time)
    print("foo started")

    print(time.time(),
          f"sleep 6x the wait time ({wait_time * 6}s); the aliveness check should abort this process "
          "long before the sleep finishes")
    time.sleep(wait_time * 6)

    ray.shutdown()
# Detached Worker
## How to run (only on a single node)
- Start a local ray cluster:
```bash
ray start --head --port=6379
```
- Run the server:
```bash
python3 server.py
```
- In another terminal, run the client:
```bash
python3 client.py
```
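- To tear down the local cluster when you are done:
```bash
ray stop --force
```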
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In client, we can get the server handler and send RPC request
"""
import ray
import torch
from verl import DataProto
from single_controller.ray import RayClassWithInitArgs
from single_controller.ray.megatron import NVMegatronRayWorkerGroup
from tensordict import TensorDict
from server import Trainer
def compute_position_id_with_mask(mask):
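    # Added for illustration: position ids are the cumulative count of attended tokens, clipped at 0,
    # e.g. a left-padded mask [0, 0, 1, 1] yields position ids [0, 0, 0, 1].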
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
if __name__ == '__main__':
    ray.init(address='auto', namespace='verl')

    # reattach to the detached worker group via the worker names printed by server.py
    worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
                                                          ray_cls_with_init=cls_with_init_args)

    batch_size = 16
    sequence_length = 1024

    # give the Trainer some data to train on
    input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device='cuda')
    attention_mask = torch.ones_like(input_ids)
    position_ids = compute_position_id_with_mask(attention_mask)

    data = DataProto(batch=TensorDict(
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }, batch_size=batch_size),
        meta_info={})

    output = worker_group.train_model(data)
    print(output)
#!/bin/bash
ray start --head --port=6379
python3 server.py
python3 client.py
ray stop --force
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Server starts a Trainer. Client sends data to the server to train.
"""
import os
os.environ['MEGATRON_USE_CUDA_TIMER'] = '0'
os.environ['MEGATRON_START_PROCESS_TIMER'] = 'False'
os.environ['NCCL_DEBUG'] = 'WARN'
import torch
from torch import nn
import ray
from single_controller.ray import RayClassWithInitArgs, RayResourcePool
from single_controller.ray.megatron import NVMegatronRayWorkerGroup
from single_controller.base.megatron.worker import MegatronWorker
from single_controller.ray.decorator import register, Dispatch
from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
from megatron.core import parallel_state as mpu
from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core import tensor_parallel
from verl.utils.megatron_utils import get_model, init_megatron_optim_config, init_model_parallel_config
from verl.utils.megatron.optimizer import get_megatron_optimizer
from transformers import LlamaConfig
from omegaconf import OmegaConf
from tensordict import TensorDict
@ray.remote
class Trainer(MegatronWorker):

    def __init__(self):
        super().__init__()

        if not torch.distributed.is_initialized():
            rank = int(os.environ['LOCAL_RANK'])
            torch.distributed.init_process_group(backend="nccl")
            torch.cuda.set_device(rank)

            os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
            mpu.initialize_model_parallel(
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=1,
                virtual_pipeline_model_parallel_size=None,
                pipeline_model_parallel_split_rank=None,
                use_sharp=False,
                context_parallel_size=1,
                expert_model_parallel_size=1,
                nccl_communicator_config_path=None,
            )
            tensor_parallel.model_parallel_cuda_manual_seed(10)
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        actor_model_config = LlamaConfig(vocab_size=256,
                                         hidden_size=2048,
                                         intermediate_size=5504,
                                         num_hidden_layers=24,
                                         num_attention_heads=16,
                                         num_key_value_heads=16)

        megatron_config = OmegaConf.create({
            'sequence_parallel_enabled': True,
            'param_dtype': 'bf16',
            'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
            'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
            'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
            'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
        })

        megatron_config = init_model_parallel_config(megatron_config)
        self.megatron_config = megatron_config

        def megatron_actor_model_provider(pre_process, post_process):
            # vpp is not supported yet because it hangs for an unknown reason; needs debugging
            vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()  # this is set inside get_model
            # this_megatron_config = copy.deepcopy(megatron_config)
            # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
            parallel_model = ParallelLlamaForCausalLMRmPadPP(config=actor_model_config,
                                                             megatron_config=megatron_config,
                                                             pre_process=pre_process,
                                                             post_process=post_process)
            parallel_model.cuda()
            return parallel_model

        actor_module = get_model(model_provider_func=megatron_actor_model_provider,
                                 model_type=ModelType.encoder_or_decoder,
                                 wrap_with_ddp=True)
        actor_module = nn.ModuleList(actor_module)

        optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0})
        optim_config = init_megatron_optim_config(optim_config)
        self.optimizer_config = optim_config
        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)

        self.model = actor_module[0]
        self.optimizer = actor_optimizer
    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
    def train_model(self, data: DataProto) -> DataProto:
        input_ids = data.batch['input_ids']
        attention_mask = data.batch['attention_mask']
        position_ids = data.batch['position_ids']

        self.optimizer.zero_grad()
        # zero the grad buffer unless the distributed optimizer owns it
        # (uses use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm)
        self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer))

        # update for 1 iteration
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
        output.mean().backward()

        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config,
                                                                              self.megatron_config.timers)

        return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0]))
if __name__ == '__main__':
    ray.init(address='auto', namespace='verl')

    # create a detached resource pool and worker group so the workers outlive this driver
    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool,
                                            ray_cls_with_init=cls_with_init_args,
                                            name_prefix='trainer',
                                            detached=True)
    worker_group.init_model()

    # the client reattaches to the detached workers by these names
    worker_names = worker_group.worker_names
    print(worker_names)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import subprocess
def test():
    wait_time = 10

    my_env = os.environ.copy()
    my_env["WAIT_TIME"] = str(wait_time)

    p = subprocess.Popen(["python3", "-u", "./check_worker_alive/main.py"], env=my_env, stdout=subprocess.PIPE)

    # wait until the subprocess reports that foo has been launched
    count = 0
    while b"foo started" not in p.stdout.read():
        time.sleep(1)
        count += 1
        if count > 40:
            raise RuntimeError("timeout for start foo in check_worker_alive/main.py")

    print(
        time.time(),
        f"wait 1.5x the wait time ({wait_time * 1.5}s) so the worker's exit signal reaches the subprocess, "
        "which is still well within the subprocess's own wait time")
    time.sleep(wait_time * 1.5)

    print(time.time(), "start checking")
    assert p.poll() is not None, f"process {p} still alive, expecting signal raised abort"
    assert p.returncode != 0, f"process {p} exit with code 0, expecting non-zero exit code"
    print("test passed")


if __name__ == "__main__":
    test()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from single_controller.base import Worker
from single_controller.base.decorator import register, Dispatch
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from verl import DataProto
@ray.remote
class Actor(Worker):

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch['a'] += self.rank
        return data


@ray.remote
class Critic(Worker):

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def sub(self, data: DataProto):
        data.batch['a'] -= self.config['b']
        return data


def test_colocated_workers():
    ray.init()

    import torch
    data = DataProto.from_dict({'a': torch.zeros(10)})

    # create separate workers on the same resource pool
    actor_cls = RayClassWithInitArgs(cls=Actor)
    critic_cls = RayClassWithInitArgs(cls=Critic, config={'b': 10})
    resource_pool = RayResourcePool(process_on_nodes=[2])

    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)

    expected_actor_output = actor_wg.add(data)
    expected_critic_output = critic_wg.sub(data)

    # create colocated workers that run both roles inside the same worker processes
    cls_dict = {'actor': actor_cls, 'critic': critic_cls}
    ray_cls_with_init = create_colocated_worker_cls(cls_dict)
    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())

    colocated_actor_wg = spawn_wg['actor']
    colocated_critic_wg = spawn_wg['critic']

    # the colocated worker groups must produce the same outputs as the separate ones
    actor_output = colocated_actor_wg.add(data)
    critic_output = colocated_critic_wg.sub(data)

    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)

    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In this test, we instantiate a data parallel worker with 8 GPUs
"""
from single_controller.base import Worker
from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
from single_controller.base.decorator import Dispatch, register
import ray
import torch
from torch import distributed as dist
from verl import DataProto
from verl.utils.ray_utils import parallel_put
from codetiming import Timer
import tensordict
@ray.remote
class DummyWorker(Worker):

    def __init__(self):
        super().__init__()
        dist.init_process_group()

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def do_nothing(self, data):
        for key in data.batch.keys():
            data.batch[key] += 1
        # consolidate the tensordict before returning so serialization stays cheap
        if tensordict.__version__ >= '0.5.0':
            data.batch = data.batch.consolidate()
        return data
def test_data_transfer():
    ray.init()

    # construct resource pool
    resource_pool = RayResourcePool([8])
    cls_with_init = RayClassWithInitArgs(cls=DummyWorker)

    # construct worker group
    wg = RayWorkerGroup(resource_pool, cls_with_init)

    # this is a realistic dataset size
    batch_size = 4096
    seqlen = 32768

    data_dict = {}
    for i in range(2):
        data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen))

    data = DataProto.from_dict(tensors=data_dict)
    print(data)

    # we manually split the data here and send one chunk to each worker
    data_list = data.chunk(wg.world_size)

    for i in range(wg.world_size):
        # consolidate is necessary
        if tensordict.__version__ >= '0.5.0':
            data_list[i].batch = data_list[i].batch.consolidate()

    with Timer(name='ray.pickle', initial_text=True):
        for i in range(wg.world_size):
            ray.cloudpickle.pickle.dumps(data_list[i])

    with Timer(name='raw.pickle', initial_text=True):
        import pickle
        for i in range(wg.world_size):
            pickle.dumps(data_list[i])

    # we put the chunks into the object store in advance
    with Timer(name='put', initial_text=True):
        # takes around 40 seconds
        data_list_ref = parallel_put(data_list)
        # for i in range(wg.world_size):
        #     data_list[i] = ray.put(data_list[i])

    with Timer(name='launch', initial_text=True):
        output_ref = wg.do_nothing(data_list_ref)

    with Timer(name='get', initial_text=True):
        # takes around 40 seconds
        output_lst = ray.get(output_ref)

    for input_data, output_data in zip(data_list, output_lst):
        for key in input_data.batch.keys():
            assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), \
                (input_data.batch[key], output_data.batch[key], key)

    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import ray
import torch
from verl import DataProto
from tensordict import TensorDict
from single_controller.base.worker import Worker
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs
from single_controller.ray import RayWorkerGroup
os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['NCCL_DEBUG'] = 'WARN'
@ray.remote
class ModelActor(Worker):

    def __init__(self):
        pass


class HackSelf():

    def __init__(self):
        pass


def get_aux_metrics(self, test_proto):
    sequence_ids = test_proto.batch["sequence_ids"]
    decode_count = []
    for i in range(sequence_ids.size(0)):
        decode_count.append(len(sequence_ids[i].tolist()))
    ret_proto = DataProto(batch=TensorDict(
        {
            "sequence_ids": sequence_ids,
            "decode_count": torch.tensor(decode_count)
        },
        batch_size=sequence_ids.size(0)))
    return ret_proto
def test():
    # construct model
    ray.init()

    # create 2 workers, each holding a GPU
    resource_pool = RayResourcePool([2], use_gpu=True, name_prefix='a')
    class_with_args = RayClassWithInitArgs(cls=ModelActor)
    shard_wg = RayWorkerGroup(resource_pool, class_with_args)

    test_bs = 8
    test_proto = DataProto(TensorDict(
        {
            "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
        },
        batch_size=test_bs),
        meta_info={"query_length": 1536})

    # shard the batch among the different ranks
    ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

    # compare against executing the same function on the driver
    hs = HackSelf()
    ret_proto2 = get_aux_metrics(hs, test_proto)

    torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])

    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool
from single_controller.base.worker import Worker
@ray.remote
class TestActor(Worker):
    # TODO: passing *args and **kwargs is bug-prone and not very convincing
    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_node_id(self):
        return ray.get_runtime_context().get_node_id()


def test():
    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([8], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create actor worker group")
    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor")
    print("create critic worker group")
    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_critic")
    print("create rm worker group")
    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm")
    print("create ref worker group")
    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref")

    # all four worker groups share the same 8-GPU pool
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]

    del actor_wg
    del critic_wg
    del rm_wg
    del ref_wg

    [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]
    print("wait 5s to remove the placement groups")
    time.sleep(5)

    # test single-node-multi-partition
    print("test single-node-multi-partition")
    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)

    assert rm_resource_pool.world_size == 4
    assert ref_resource_pool.world_size == 4
    assert total_resource_pool.world_size == 8

    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor")
    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic")
    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm")
    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref")

    # actor and critic span all 8 GPUs; rm gets the first 4 and ref the last 4
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]

    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
e2e test single_controller.ray
"""
import os
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
@ray.remote
class TestActor(Worker):

    def __init__(self) -> None:
        super().__init__()

    def getenv(self, key):
        val = os.getenv(key, f"{key} not set")
        return val


def test_basics():
    ray.init()

    # create 4 workers, each holding a GPU
    resource_pool = RayResourcePool([4], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestActor)

    worker_group = RayWorkerGroup(resource_pool=resource_pool,
                                  ray_cls_with_init=class_with_args,
                                  name_prefix="worker_group_basic")

    output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_WORLD_SIZE")
    assert output == ["4", "4", "4", "4"]

    output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_RANK")
    assert set(output) == set(["0", "1", "2", "3"])

    ray.shutdown()


if __name__ == '__main__':
    test_basics()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from single_controller.remote import remote, RemoteBackend, SharedResourcePool
from single_controller.base.decorator import register, Dispatch
from single_controller.base.worker import Worker
@remote(process_on_nodes=[3], use_gpu=True, name_prefix="actor", sharing=SharedResourcePool)
class Actor(Worker):
    ...


@remote(process_on_nodes=[3], use_gpu=True, name_prefix="critic", sharing=SharedResourcePool)
class Critic(Worker):
    ...


# Reward is placed on the resource pool already created for the "actor" role
@remote(process_on_nodes=[2], use_gpu=True, name_prefix="reward", sharing=SharedResourcePool.from_role("actor"))
class Reward(Worker):
    ...


# Ref may be placed on the pools of either the "actor" or the "critic" role
@remote(process_on_nodes=[2], use_gpu=True, name_prefix="ref", sharing=SharedResourcePool.from_role("actor", "critic"))
class Ref(Worker):
    ...


# SecRM may be placed on any existing pool
@remote(process_on_nodes=[1], use_gpu=True, name_prefix="sec_rm", sharing=SharedResourcePool.from_role("any"))
class SecRM(Worker):
    ...
def test():
    print("Remote.init_distributed")
    remote.init_distributed(backend=RemoteBackend.RAY)

    print("create actor worker group")
    actor = Actor()
    print("create critic worker group")
    critic = Critic()
    print("create rm worker group")
    reward = Reward()
    print("create ref worker group")
    ref = Ref()
    print("create sec_rm worker group")
    sec_rm = SecRM()

    actor_gpus = actor.execute_all_sync("get_cuda_visible_devices")
    critic_gpus = critic.execute_all_sync("get_cuda_visible_devices")
    reward_gpus = reward.execute_all_sync("get_cuda_visible_devices")
    ref_gpus = ref.execute_all_sync("get_cuda_visible_devices")
    sec_rm_gpus = sec_rm.execute_all_sync("get_cuda_visible_devices")

    # actor and critic live in separate resource pools, so their GPUs must be disjoint
    for gpu in actor_gpus:
        assert gpu not in critic_gpus, f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}"
    for gpu in critic_gpus:
        assert gpu not in actor_gpus, f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}"

    # reward shares the actor pool; ref and sec_rm may land on either the actor or critic pool
    for gpu in reward_gpus:
        assert gpu in actor_gpus, f"actor gpus = {actor_gpus}, reward gpus = {reward_gpus}"
    for gpu in ref_gpus:
        assert gpu in actor_gpus + critic_gpus, \
            f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}, ref gpus = {ref_gpus}"
    for gpu in sec_rm_gpus:
        assert gpu in actor_gpus + critic_gpus, \
            f"actor gpus = {actor_gpus}, critic gpus = {critic_gpus}, sec rm gpus = {sec_rm_gpus}"

    # for CI only
    import ray
    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
@ray.remote
class TestWorker:

    def __init__(self, rank, world_size, group_name):
        self.rank = rank
        self.world_size = world_size
        self.group_name = group_name
        self.communicator = None

    def init(self):
        from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray
        self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name)

    def test(self):
        if self.communicator is None:
            return None
        return self.communicator.rank_id()


def test_rvdz():
    ray.init()

    group_name = "test_group"
    world_size = 4

    workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)]

    # kick off rendezvous on every worker, then query the resulting NCCL ranks
    [worker.init.remote() for worker in workers]
    ranks = ray.get([worker.test.remote() for worker in workers])

    assert ranks == [0, 1, 2, 3], f"expecting [0, 1, 2, 3], got {ranks}"

    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
e2e test single_controller.ray
"""
import torch
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
def two_to_all_dispatch_fn(worker_group, *args, **kwargs):
    """
    Assume each input is a list of length 2. Duplicate the entries in an interleaved
    pattern so that every worker receives one element.
    """
    for arg in args:
        assert len(arg) == 2
        for i in range(worker_group.world_size - 2):
            arg.append(arg[i % 2])
    for k, v in kwargs.items():
        assert len(v) == 2
        for i in range(worker_group.world_size - 2):
            v.append(v[i % 2])
    return args, kwargs
@ray.remote
class TestActor(Worker):
    # TODO: passing *args and **kwargs is bug-prone and not very convincing
    def __init__(self, x) -> None:
        super().__init__()
        self._x = x

    def foo(self, y):
        return self._x + y

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
    def foo_rank_zero(self, x, y):
        return self._x + y + x

    @register(Dispatch.ONE_TO_ALL, blocking=False)
    def foo_one_to_all(self, x, y):
        return self._x + y + x

    @register(Dispatch.ALL_TO_ALL, blocking=False)
    def foo_all_to_all(self, x, y):
        return self._x + y + x

    @register(dispatch_mode={'dispatch_fn': two_to_all_dispatch_fn, 'collect_fn': collect_all_to_all})
    def foo_custom(self, x, y):
        return self._x + y + x
@ray.remote(num_gpus=0.1)
def remote_call_wg(worker_names):
    class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)
    worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args)
    print(worker_group.worker_names)

    output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])
    assert output_ref == [8, 10, 8, 10]

    output_ref = worker_group.foo_rank_zero(x=1, y=2)
    assert output_ref == 5

    return worker_group.worker_names


def add_one(data):
    data = data.to("cuda")
    data += 1
    data = data.to("cpu")
    return data
def test_basics():
    ray.init()

    # create 4 workers, each holding a GPU
    resource_pool = RayResourcePool([4], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)

    worker_group = RayWorkerGroup(resource_pool=resource_pool,
                                  ray_cls_with_init=class_with_args,
                                  name_prefix="worker_group_basic")

    print(worker_group.worker_names)

    # this waits for all the results
    output = worker_group.execute_all_sync("foo", y=3)
    assert output == [5, 5, 5, 5]

    # this is a list of object references; it won't block
    output_ref = worker_group.execute_all_async("foo", y=4)
    print(output_ref)
    assert ray.get(output_ref) == [6, 6, 6, 6]

    output_ref = worker_group.foo_one_to_all(x=1, y=2)
    assert ray.get(output_ref) == [5, 5, 5, 5]

    output_ref = worker_group.foo_all_to_all(x=[1, 2, 3, 4], y=[5, 6, 7, 8])
    assert ray.get(output_ref) == [8, 10, 12, 14]

    # attach to the same worker group from another Ray task by worker names
    print(ray.get(remote_call_wg.remote(worker_group.worker_names)))

    output = worker_group.execute_func_rank_zero(add_one, torch.ones(2, 2))
    torch.testing.assert_close(output, torch.ones(2, 2) + 1)

    ray.shutdown()


if __name__ == '__main__':
    test_basics()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['NCCL_DEBUG'] = 'WARN'
import torch
import torch.distributed
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
@ray.remote
class TestAllGatherActor(Worker):

    def __init__(self, size) -> None:
        super().__init__()
        self.size = size

    def init(self):
        torch.distributed.init_process_group()
        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device='cuda')
        self.tensor += self.rank

    def all_gather(self):
        world_size = self._world_size
        output = torch.zeros(size=(self.tensor.shape[0] * world_size,),
                             dtype=self.tensor.dtype,
                             device=self.tensor.device)
        torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
        return output


@ray.remote
class TestAllGatherActorV2(Worker):
    # same as TestAllGatherActor, but the process group is initialized in the constructor

    def __init__(self, size) -> None:
        super().__init__()
        self.size = size

        torch.distributed.init_process_group()
        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device='cuda')
        self.tensor += self.rank

    def all_gather(self):
        world_size = self._world_size
        output = torch.zeros(size=(self.tensor.shape[0] * world_size,),
                             dtype=self.tensor.dtype,
                             device=self.tensor.device)
        torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
        return output
def test_all_gather_torch():
    """
    In this test, we instantiate 4 GPUs in a group and test all_gather
    """
    ray.init()

    # create 4 workers, each holding a GPU
    resource_pool = RayResourcePool([4], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2)

    worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")

    worker_group.execute_all_sync('init')
    output = worker_group.execute_all_sync('all_gather')

    # every rank must see the same gathered tensor
    for i in range(1, len(output)):
        assert torch.all(output[i] == output[0])

    output = output[0].cpu()
    print(output)
    assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))

    ray.shutdown()


def test_all_gather_torch_v2():
    """
    In this test, we instantiate 4 GPUs in a group and test all_gather
    """
    ray.init()

    # create 4 workers, each holding a GPU
    resource_pool = RayResourcePool([4], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2)

    worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch")

    output = worker_group.execute_all_sync('all_gather')

    for i in range(1, len(output)):
        assert torch.all(output[i] == output[0])

    output = output[0].cpu()
    print(output)
    assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))

    ray.shutdown()