Unverified Commit cd6cef60 by Guangming Sheng Committed by GitHub

[BREAKING][core] move single_controller into verl directory (#45)

* [BREAKING][core] move single_controller into verl directory

* fix blocking flag in fsdp workers
parent 1b24a3a8
...@@ -42,4 +42,4 @@ jobs: ...@@ -42,4 +42,4 @@ jobs:
pip install toml==0.10.2 pip install toml==0.10.2
- name: Running yapf - name: Running yapf
run: | run: |
yapf -r -vv -d --style=./.style.yapf verl tests single_controller examples yapf -r -vv -d --style=./.style.yapf verl tests examples
...@@ -180,7 +180,7 @@ pip3 install yapf ...@@ -180,7 +180,7 @@ pip3 install yapf
``` ```
Then, make sure you are at top level of verl repo and run Then, make sure you are at top level of verl repo and run
```bash ```bash
yapf -ir -vv --style ./.style.yapf verl single_controller examples yapf -ir -vv --style ./.style.yapf verl examples
``` ```
......
...@@ -47,8 +47,8 @@ Implementation details: ...@@ -47,8 +47,8 @@ Implementation details:
.. code:: python .. code:: python
from single_controller.base import Worker from verl.single_controller.base import Worker
from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
import ray import ray
@ray.remote @ray.remote
...@@ -75,7 +75,7 @@ API: compute reference log probability ...@@ -75,7 +75,7 @@ API: compute reference log probability
.. code:: python .. code:: python
from single_controller.base import Worker from verl.single_controller.base import Worker
import ray import ray
@ray.remote @ray.remote
...@@ -93,7 +93,7 @@ API: Update actor model parameters ...@@ -93,7 +93,7 @@ API: Update actor model parameters
.. code:: python .. code:: python
from single_controller.base import Worker from verl.single_controller.base import Worker
import ray import ray
@ray.remote @ray.remote
...@@ -184,7 +184,7 @@ registered into the worker_group** ...@@ -184,7 +184,7 @@ registered into the worker_group**
.. code:: python .. code:: python
from single_controller.base.decorator import register from verl.single_controller.base.decorator import register
def dispatch_data(worker_group, data): def dispatch_data(worker_group, data):
return data.chunk(worker_group.world_size) return data.chunk(worker_group.world_size)
...@@ -214,11 +214,11 @@ computation, and data collection. ...@@ -214,11 +214,11 @@ computation, and data collection.
Furthermore, the model parallelism size of each model is usually fixed, Furthermore, the model parallelism size of each model is usually fixed,
including dp, tp, pp. So for these common distributed scenarios, we have including dp, tp, pp. So for these common distributed scenarios, we have
pre-implemented specific dispatch and collect methods in `decorator.py <https://github.com/volcengine/verl/blob/main/single_controller/base/decorator.py>`_, which can be directly used to wrap the computations. pre-implemented specific dispatch and collect methods in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_, which can be directly used to wrap the computations.
.. code:: python .. code:: python
from single_controller.base.decorator import register, Dispatch from verl.single_controller.base.decorator import register, Dispatch
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, data: DataProto) -> DataProto: def generate_sequences(self, data: DataProto) -> DataProto:
......
...@@ -49,13 +49,13 @@ Define worker classes ...@@ -49,13 +49,13 @@ Define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray import RayWorkerGroup from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM
else: else:
......
...@@ -40,7 +40,7 @@ We implement various APIs for each ``Worker`` class decorated by the ...@@ -40,7 +40,7 @@ We implement various APIs for each ``Worker`` class decorated by the
``@register(dispatch_mode=)`` . These APIs can be called by the ray ``@register(dispatch_mode=)`` . These APIs can be called by the ray
driver process. The data can be correctly collected and dispatched following driver process. The data can be correctly collected and dispatched following
the ``dispatch_mode`` on each function. The supported dispatch modes the ``dispatch_mode`` on each function. The supported dispatch modes
(i.e., transfer protocols) can be found in `decorator.py <https://github.com/volcengine/verl/blob/main/single_controller/base/decorator.py>`_. (i.e., transfer protocols) can be found in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_.
ActorRolloutRefWorker ActorRolloutRefWorker
^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^
......
...@@ -232,8 +232,8 @@ ...@@ -232,8 +232,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n", "from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n",
"from single_controller.base import Worker" "from verl.single_controller.base import Worker"
] ]
}, },
{ {
...@@ -437,7 +437,7 @@ ...@@ -437,7 +437,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from single_controller.ray.decorator import register, Dispatch, Execute" "from verl.single_controller.ray.decorator import register, Dispatch, Execute"
] ]
}, },
{ {
...@@ -518,7 +518,7 @@ ...@@ -518,7 +518,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute" "from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute"
] ]
}, },
{ {
...@@ -723,10 +723,10 @@ ...@@ -723,10 +723,10 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from single_controller.ray.decorator import register, Dispatch, Execute\n", "from verl.single_controller.ray.decorator import register, Dispatch, Execute\n",
"from single_controller.ray.megatron import NVMegatronRayWorkerGroup\n", "from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n",
"from single_controller.base.megatron.worker import MegatronWorker\n", "from verl.single_controller.base.megatron.worker import MegatronWorker\n",
"from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n", "from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n",
"from omegaconf import OmegaConf\n", "from omegaconf import OmegaConf\n",
"from megatron.core import parallel_state as mpu" "from megatron.core import parallel_state as mpu"
] ]
......
...@@ -121,13 +121,13 @@ def main_task(config): ...@@ -121,13 +121,13 @@ def main_task(config):
if config.actor_rollout_ref.actor.strategy == 'fsdp': if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray import RayWorkerGroup from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron': elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup ray_worker_group_cls = NVMegatronRayWorkerGroup
else: else:
......
...@@ -16,7 +16,7 @@ A naive implementation of the split placement example ...@@ -16,7 +16,7 @@ A naive implementation of the split placement example
""" """
import os import os
from pprint import pprint from pprint import pprint
from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl import DataProto from verl import DataProto
from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls
from codetiming import Timer from codetiming import Timer
......
...@@ -18,9 +18,9 @@ import os ...@@ -18,9 +18,9 @@ import os
import ray import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
from single_controller.base.decorator import register, Dispatch from verl.single_controller.base.decorator import register, Dispatch
@ray.remote @ray.remote
......
...@@ -19,8 +19,8 @@ import ray ...@@ -19,8 +19,8 @@ import ray
import torch import torch
from verl import DataProto from verl import DataProto
from single_controller.ray import RayClassWithInitArgs from verl.single_controller.ray import RayClassWithInitArgs
from single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from tensordict import TensorDict from tensordict import TensorDict
......
...@@ -25,10 +25,10 @@ import torch ...@@ -25,10 +25,10 @@ import torch
from torch import nn from torch import nn
import ray import ray
from single_controller.ray import RayClassWithInitArgs, RayResourcePool from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool
from single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from single_controller.base.megatron.worker import MegatronWorker from verl.single_controller.base.megatron.worker import MegatronWorker
from single_controller.ray.decorator import register, Dispatch from verl.single_controller.ray.decorator import register, Dispatch
from verl import DataProto from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
import ray import ray
from single_controller.base import Worker from verl.single_controller.base import Worker
from single_controller.base.decorator import register, Dispatch from verl.single_controller.base.decorator import register, Dispatch
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from verl import DataProto from verl import DataProto
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
In this test, we instantiate a data parallel worker with 8 GPUs In this test, we instantiate a data parallel worker with 8 GPUs
""" """
from single_controller.base import Worker from verl.single_controller.base import Worker
from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
from single_controller.base.decorator import Dispatch, register from verl.single_controller.base.decorator import Dispatch, register
import ray import ray
import torch import torch
......
...@@ -18,9 +18,9 @@ import torch ...@@ -18,9 +18,9 @@ import torch
from verl import DataProto from verl import DataProto
from tensordict import TensorDict from tensordict import TensorDict
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs
from single_controller.ray import RayWorkerGroup from verl.single_controller.ray import RayWorkerGroup
os.environ['RAY_DEDUP_LOGS'] = '0' os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['NCCL_DEBUG'] = 'WARN' os.environ['NCCL_DEBUG'] = 'WARN'
......
...@@ -16,8 +16,8 @@ import time ...@@ -16,8 +16,8 @@ import time
import ray import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
@ray.remote @ray.remote
......
...@@ -12,14 +12,14 @@ ...@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
e2e test single_controller.ray e2e test verl.single_controller.ray
""" """
import os import os
import ray import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
@ray.remote @ray.remote
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from single_controller.remote import remote, RemoteBackend, SharedResourcePool from verl.single_controller.remote import remote, RemoteBackend, SharedResourcePool
from single_controller.base.decorator import register, Dispatch from verl.single_controller.base.decorator import register, Dispatch
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
@remote(process_on_nodes=[3], use_gpu=True, name_prefix="actor", sharing=SharedResourcePool) @remote(process_on_nodes=[3], use_gpu=True, name_prefix="actor", sharing=SharedResourcePool)
......
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
e2e test single_controller.ray e2e test verl.single_controller.ray
""" """
import torch import torch
import ray import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
def two_to_all_dispatch_fn(worker_group, *args, **kwargs): def two_to_all_dispatch_fn(worker_group, *args, **kwargs):
......
...@@ -21,8 +21,8 @@ import torch ...@@ -21,8 +21,8 @@ import torch
import torch.distributed import torch.distributed
import ray import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
@ray.remote @ray.remote
......
...@@ -75,7 +75,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs): ...@@ -75,7 +75,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs):
""" """
User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp
""" """
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, assert isinstance(worker_group,
MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}' MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}'
...@@ -104,7 +104,7 @@ def collect_megatron_compute(worker_group, output): ...@@ -104,7 +104,7 @@ def collect_megatron_compute(worker_group, output):
""" """
Only collect the data from the tp=0 and pp=last and every dp ranks Only collect the data from the tp=0 and pp=last and every dp ranks
""" """
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup) assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = [] output_in_dp = []
pp_size = worker_group.get_megatron_global_info().pp_size pp_size = worker_group.get_megatron_global_info().pp_size
...@@ -119,7 +119,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): ...@@ -119,7 +119,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
""" """
All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
""" """
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup) assert isinstance(worker_group, MegatronWorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
...@@ -162,7 +162,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): ...@@ -162,7 +162,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
""" """
treat pp as dp. treat pp as dp.
""" """
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup) assert isinstance(worker_group, MegatronWorkerGroup)
pp_size = worker_group.pp_size pp_size = worker_group.pp_size
...@@ -210,7 +210,7 @@ def collect_megatron_pp_as_dp(worker_group, output): ...@@ -210,7 +210,7 @@ def collect_megatron_pp_as_dp(worker_group, output):
""" """
treat pp as dp. Only collect data on tp=0 treat pp as dp. Only collect data on tp=0
""" """
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup) assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = [] output_in_dp = []
for global_rank in range(worker_group.world_size): for global_rank in range(worker_group.world_size):
...@@ -224,7 +224,7 @@ def collect_megatron_pp_only(worker_group, output): ...@@ -224,7 +224,7 @@ def collect_megatron_pp_only(worker_group, output):
""" """
Only collect output of megatron pp. This is useful when examining weight names as they are identical in tp/dp Only collect output of megatron pp. This is useful when examining weight names as they are identical in tp/dp
""" """
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup) assert isinstance(worker_group, MegatronWorkerGroup)
output_in_pp = [] output_in_pp = []
for global_rank in range(worker_group.world_size): for global_rank in range(worker_group.world_size):
...@@ -235,7 +235,7 @@ def collect_megatron_pp_only(worker_group, output): ...@@ -235,7 +235,7 @@ def collect_megatron_pp_only(worker_group, output):
def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup) assert isinstance(worker_group, MegatronWorkerGroup)
pp_dp_size = worker_group.dp_size * worker_group.pp_size pp_dp_size = worker_group.dp_size * worker_group.pp_size
...@@ -245,7 +245,7 @@ def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): ...@@ -245,7 +245,7 @@ def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
def collect_megatron_pp_as_dp_data_proto(worker_group, output): def collect_megatron_pp_as_dp_data_proto(worker_group, output):
from verl.protocol import DataProto from verl.protocol import DataProto
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup) assert isinstance(worker_group, MegatronWorkerGroup)
output = collect_megatron_pp_as_dp(worker_group, output) output = collect_megatron_pp_as_dp(worker_group, output)
...@@ -253,7 +253,7 @@ def collect_megatron_pp_as_dp_data_proto(worker_group, output): ...@@ -253,7 +253,7 @@ def collect_megatron_pp_as_dp_data_proto(worker_group, output):
def dispatch_dp_compute(worker_group, *args, **kwargs): def dispatch_dp_compute(worker_group, *args, **kwargs):
from single_controller.base.worker_group import WorkerGroup from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup) assert isinstance(worker_group, WorkerGroup)
for arg in args: for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
...@@ -263,21 +263,21 @@ def dispatch_dp_compute(worker_group, *args, **kwargs): ...@@ -263,21 +263,21 @@ def dispatch_dp_compute(worker_group, *args, **kwargs):
def collect_dp_compute(worker_group, output): def collect_dp_compute(worker_group, output):
from single_controller.base.worker_group import WorkerGroup from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup) assert isinstance(worker_group, WorkerGroup)
assert len(output) == worker_group.world_size assert len(output) == worker_group.world_size
return output return output
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
from single_controller.base.worker_group import WorkerGroup from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup) assert isinstance(worker_group, WorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs) splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
return splitted_args, splitted_kwargs return splitted_args, splitted_kwargs
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
from single_controller.base.worker_group import WorkerGroup from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup) assert isinstance(worker_group, WorkerGroup)
assert type(args[0]) == FunctionType # NOTE: The first one args is a function! assert type(args[0]) == FunctionType # NOTE: The first one args is a function!
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from single_controller.base.worker import Worker from verl.single_controller.base.worker import Worker
class DPEngineWorker(Worker): class DPEngineWorker(Worker):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
class MegatronWorker(Worker): class MegatronWorker(Worker):
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from typing import Dict from typing import Dict
from .worker import DistRankInfo, DistGlobalInfo from .worker import DistRankInfo, DistGlobalInfo
from single_controller.base import ResourcePool, WorkerGroup from verl.single_controller.base import ResourcePool, WorkerGroup
class MegatronWorkerGroup(WorkerGroup): class MegatronWorkerGroup(WorkerGroup):
......
...@@ -17,7 +17,7 @@ the class for Worker ...@@ -17,7 +17,7 @@ the class for Worker
import os import os
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
from single_controller.base.decorator import register, Dispatch from verl.single_controller.base.decorator import register, Dispatch
@dataclass @dataclass
...@@ -43,7 +43,7 @@ class WorkerHelper: ...@@ -43,7 +43,7 @@ class WorkerHelper:
import ray import ray
return ray._private.services.get_node_ip_address() return ray._private.services.get_node_ip_address()
elif os.getenv("WG_BACKEND", None) == "torch_rpc": elif os.getenv("WG_BACKEND", None) == "torch_rpc":
from single_controller.torchrpc.k8s_client import get_ip_addr from verl.single_controller.torchrpc.k8s_client import get_ip_addr
return get_ip_addr() return get_ip_addr()
return None return None
...@@ -110,7 +110,7 @@ class Worker(WorkerHelper): ...@@ -110,7 +110,7 @@ class Worker(WorkerHelper):
} }
if os.getenv("WG_BACKEND", None) == "ray": if os.getenv("WG_BACKEND", None) == "ray":
from single_controller.base.register_center.ray import create_worker_group_register_center from verl.single_controller.base.register_center.ray import create_worker_group_register_center
self.register_center = create_worker_group_register_center(name=register_center_name, self.register_center = create_worker_group_register_center(name=register_center_name,
info=rank_zero_info) info=rank_zero_info)
......
...@@ -20,7 +20,7 @@ import signal ...@@ -20,7 +20,7 @@ import signal
import time import time
from typing import List, Any, Callable, Dict from typing import List, Any, Callable, Dict
from single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool: class ResourcePool:
......
...@@ -21,7 +21,7 @@ from ray.util.placement_group import placement_group, PlacementGroup ...@@ -21,7 +21,7 @@ from ray.util.placement_group import placement_group, PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
from ray.experimental.state.api import get_actor from ray.experimental.state.api import get_actor
from single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker
__all__ = ['Worker'] __all__ = ['Worker']
...@@ -373,7 +373,7 @@ with code written in separate ray.Actors. ...@@ -373,7 +373,7 @@ with code written in separate ray.Actors.
""" """
from unittest.mock import patch from unittest.mock import patch
from single_controller.base.decorator import MAGIC_ATTR from verl.single_controller.base.decorator import MAGIC_ATTR
import os import os
......
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
import ray import ray
# compatibility concern # compatibility concern
from single_controller.base.decorator import * from verl.single_controller.base.decorator import *
def maybe_remote(main): def maybe_remote(main):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import ray import ray
from single_controller.ray.base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs from verl.single_controller.ray.base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
@ray.remote @ray.remote
......
...@@ -17,8 +17,8 @@ from typing import Dict, Optional ...@@ -17,8 +17,8 @@ from typing import Dict, Optional
import ray import ray
from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
from single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
from single_controller.base.megatron.worker_group import MegatronWorkerGroup from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
# NOTE(sgm): for opensource megatron-core # NOTE(sgm): for opensource megatron-core
......
...@@ -33,7 +33,7 @@ from verl import DataProto ...@@ -33,7 +33,7 @@ from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.fs import copy_local_path_from_hdfs
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs from verl.utils.hdfs_io import makedirs
from single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@hydra.main(config_path='config', config_name='generation', version_base=None) @hydra.main(config_path='config', config_name='generation', version_base=None)
......
...@@ -120,13 +120,13 @@ def main_task(config): ...@@ -120,13 +120,13 @@ def main_task(config):
if config.actor_rollout_ref.actor.strategy == 'fsdp': if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray import RayWorkerGroup from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron': elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup ray_worker_group_cls = NVMegatronRayWorkerGroup
else: else:
......
...@@ -26,9 +26,9 @@ from omegaconf import OmegaConf, open_dict ...@@ -26,9 +26,9 @@ from omegaconf import OmegaConf, open_dict
import numpy as np import numpy as np
from codetiming import Timer from codetiming import Timer
from single_controller.base import Worker from verl.single_controller.base import Worker
from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from single_controller.ray.base import create_colocated_worker_cls from verl.single_controller.ray.base import create_colocated_worker_cls
from verl import DataProto from verl import DataProto
from verl.trainer.ppo import core_algos from verl.trainer.ppo import core_algos
......
...@@ -23,8 +23,8 @@ import torch ...@@ -23,8 +23,8 @@ import torch
import torch.distributed import torch.distributed
from omegaconf import DictConfig, open_dict from omegaconf import DictConfig, open_dict
from single_controller.base import Worker from verl.single_controller.base import Worker
from single_controller.base.decorator import register, Dispatch from verl.single_controller.base.decorator import register, Dispatch
import verl.utils.torch_functional as verl_F import verl.utils.torch_functional as verl_F
from verl import DataProto from verl import DataProto
from verl.trainer.ppo.actor import DataParallelPPOActor from verl.trainer.ppo.actor import DataParallelPPOActor
......
...@@ -22,13 +22,13 @@ import torch ...@@ -22,13 +22,13 @@ import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
from omegaconf import DictConfig from omegaconf import DictConfig
from single_controller.base.megatron.worker import MegatronWorker from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.trainer.ppo.actor.megatron_actor import MegatronPPOActor from verl.trainer.ppo.actor.megatron_actor import MegatronPPOActor
from verl.trainer.ppo.critic.megatron_critic import MegatronPPOCritic from verl.trainer.ppo.critic.megatron_critic import MegatronPPOCritic
from verl.trainer.ppo.hybrid_engine import AllGatherPPModel from verl.trainer.ppo.hybrid_engine import AllGatherPPModel
from verl.trainer.ppo.reward_model.megatron.reward_model import MegatronRewardModel from verl.trainer.ppo.reward_model.megatron.reward_model import MegatronRewardModel
from single_controller.base.decorator import register, Dispatch from verl.single_controller.base.decorator import register, Dispatch
from verl import DataProto from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.debug import log_gpu_memory_usage from verl.utils.debug import log_gpu_memory_usage
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment