Unverified commit cd6cef60 by Guangming Sheng, committed via GitHub

[BREAKING][core] move single_controller into verl directory (#45)

* [BREAKING][core] move single_controller into verl directory

* fix blocking flag in fsdp workers
parent 1b24a3a8
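The change is a pure package relocation: every `single_controller.*` import now lives under `verl.single_controller.*`, with no behavior changes. A minimal before/after sketch of the migration, using module paths taken directly from the hunks below:

```python
# Before this commit: single_controller was a top-level package.
# from single_controller.base import Worker
# from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool

# After this commit: the same modules are imported from inside the verl package.
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
```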
......@@ -42,4 +42,4 @@ jobs:
pip install toml==0.10.2
- name: Running yapf
run: |
yapf -r -vv -d --style=./.style.yapf verl tests single_controller examples
yapf -r -vv -d --style=./.style.yapf verl tests examples
......@@ -180,7 +180,7 @@ pip3 install yapf
```
Then, make sure you are at the top level of the verl repo and run
```bash
yapf -ir -vv --style ./.style.yapf verl single_controller examples
yapf -ir -vv --style ./.style.yapf verl examples
```
......
......@@ -47,8 +47,8 @@ Implementation details:
.. code:: python
from single_controller.base import Worker
from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
import ray
@ray.remote
......@@ -75,7 +75,7 @@ API: compute reference log probability
.. code:: python
from single_controller.base import Worker
from verl.single_controller.base import Worker
import ray
@ray.remote
......@@ -93,7 +93,7 @@ API: Update actor model parameters
.. code:: python
from single_controller.base import Worker
from verl.single_controller.base import Worker
import ray
@ray.remote
......@@ -184,7 +184,7 @@ registered into the worker_group**
.. code:: python
from single_controller.base.decorator import register
from verl.single_controller.base.decorator import register
def dispatch_data(worker_group, data):
return data.chunk(worker_group.world_size)
......@@ -214,11 +214,11 @@ computation, and data collection.
Furthermore, the model parallelism size of each model is usually fixed,
including dp, tp, pp. So for these common distributed scenarios, we have
pre-implemented specific dispatch and collect methods,in `decorator.py <https://github.com/volcengine/verl/blob/main/single_controller/base/decorator.py>`_, which can be directly used to wrap the computations.
pre-implemented specific dispatch and collect methods,in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_, which can be directly used to wrap the computations.
.. code:: python
from single_controller.base.decorator import register, Dispatch
from verl.single_controller.base.decorator import register, Dispatch
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, data: DataProto) -> DataProto:
......
......@@ -49,13 +49,13 @@ Define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray import RayWorkerGroup
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM
else:
......
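For context, the branch above keys entirely off one config field. A minimal sketch of the fields it reads, with hypothetical values, built with OmegaConf as used elsewhere in the repo:

```python
from omegaconf import OmegaConf

# Hypothetical minimal config: only the fields the backend-selection branch consults.
config = OmegaConf.create({
    "actor_rollout_ref": {"actor": {"strategy": "fsdp"}},  # or "megatron"
    "critic": {"strategy": "fsdp"},                        # must match the actor strategy
})
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
```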
......@@ -40,7 +40,7 @@ We implement various APIs for each ``Worker`` class decorated by the
``@register(dispatch_mode=)``. These APIs can be called by the Ray
driver process. The data is correctly collected and dispatched following
the ``dispatch_mode`` of each function. The supported dispatch modes
(i.e., transfer protocols) can be found in `decorator.py <https://github.com/volcengine/verl/blob/main/single_controller/base/decorator.py>`_.
(i.e., transfer protocols) can be found in `decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>`_.
ActorRolloutRefWorker
^^^^^^^^^^^^^^^^^^^^^
......
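Since the doc hunks above only show the updated import paths, here is a minimal end-to-end sketch of the driver-side pattern they describe: a ``Worker`` subclass exposes a method through ``@register``, and the driver calls it on a ``RayWorkerGroup``, with dispatch and collection handled by the chosen ``dispatch_mode``. The class and method names (``GreetWorker``, ``greet``) are hypothetical; the imports follow the new ``verl.single_controller`` paths.

```python
import ray
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.ray import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup

@ray.remote
class GreetWorker(Worker):
    # Hypothetical worker: ONE_TO_ALL broadcasts the same args to every worker
    # and collects one result per worker back on the driver.
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def greet(self, msg):
        return f"{msg} received"

if __name__ == "__main__":
    ray.init()
    resource_pool = RayResourcePool([2], use_gpu=False)   # 2 worker processes on one node
    cls_with_init = RayClassWithInitArgs(cls=GreetWorker)
    worker_group = RayWorkerGroup(resource_pool, cls_with_init)
    print(worker_group.greet("hello"))  # one entry per worker, collected on the driver
    ray.shutdown()
```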
......@@ -232,8 +232,8 @@
},
"outputs": [],
"source": [
"from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n",
"from single_controller.base import Worker"
"from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n",
"from verl.single_controller.base import Worker"
]
},
{
......@@ -437,7 +437,7 @@
},
"outputs": [],
"source": [
"from single_controller.ray.decorator import register, Dispatch, Execute"
"from verl.single_controller.ray.decorator import register, Dispatch, Execute"
]
},
{
......@@ -518,7 +518,7 @@
},
"outputs": [],
"source": [
"from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute"
"from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute"
]
},
{
......@@ -723,10 +723,10 @@
},
"outputs": [],
"source": [
"from single_controller.ray.decorator import register, Dispatch, Execute\n",
"from single_controller.ray.megatron import NVMegatronRayWorkerGroup\n",
"from single_controller.base.megatron.worker import MegatronWorker\n",
"from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n",
"from verl.single_controller.ray.decorator import register, Dispatch, Execute\n",
"from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n",
"from verl.single_controller.base.megatron.worker import MegatronWorker\n",
"from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n",
"from omegaconf import OmegaConf\n",
"from megatron.core import parallel_state as mpu"
]
......
......@@ -121,13 +121,13 @@ def main_task(config):
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray import RayWorkerGroup
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
......
......@@ -16,7 +16,7 @@ A naive implementation of split placement example
"""
import os
from pprint import pprint
from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl import DataProto
from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls
from codetiming import Timer
......
......@@ -18,9 +18,9 @@ import os
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
from single_controller.base.decorator import register, Dispatch
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.worker import Worker
from verl.single_controller.base.decorator import register, Dispatch
@ray.remote
......
......@@ -19,8 +19,8 @@ import ray
import torch
from verl import DataProto
from single_controller.ray import RayClassWithInitArgs
from single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.single_controller.ray import RayClassWithInitArgs
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from tensordict import TensorDict
......
......@@ -25,10 +25,10 @@ import torch
from torch import nn
import ray
from single_controller.ray import RayClassWithInitArgs, RayResourcePool
from single_controller.ray.megatron import NVMegatronRayWorkerGroup
from single_controller.base.megatron.worker import MegatronWorker
from single_controller.ray.decorator import register, Dispatch
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.single_controller.ray.decorator import register, Dispatch
from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
......
......@@ -14,9 +14,9 @@
import ray
from single_controller.base import Worker
from single_controller.base.decorator import register, Dispatch
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from verl import DataProto
......
......@@ -15,10 +15,10 @@
In this test, we instantiate a data parallel worker with 8 GPUs
"""
from single_controller.base import Worker
from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
from single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.decorator import Dispatch, register
import ray
import torch
......
......@@ -18,9 +18,9 @@ import torch
from verl import DataProto
from tensordict import TensorDict
from single_controller.base.worker import Worker
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs
from single_controller.ray import RayWorkerGroup
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs
from verl.single_controller.ray import RayWorkerGroup
os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['NCCL_DEBUG'] = 'WARN'
......
......@@ -16,8 +16,8 @@ import time
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool
from single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool
from verl.single_controller.base.worker import Worker
@ray.remote
......
......@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
e2e test single_controller.ray
e2e test verl.single_controller.ray
"""
import os
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
@ray.remote
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from single_controller.remote import remote, RemoteBackend, SharedResourcePool
from single_controller.base.decorator import register, Dispatch
from single_controller.base.worker import Worker
from verl.single_controller.remote import remote, RemoteBackend, SharedResourcePool
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.base.worker import Worker
@remote(process_on_nodes=[3], use_gpu=True, name_prefix="actor", sharing=SharedResourcePool)
......
......@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
e2e test single_controller.ray
e2e test verl.single_controller.ray
"""
import torch
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute
def two_to_all_dispatch_fn(worker_group, *args, **kwargs):
......
......@@ -21,8 +21,8 @@ import torch
import torch.distributed
import ray
from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.base.worker import Worker
@ray.remote
......
......@@ -75,7 +75,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs):
"""
User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp
"""
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group,
MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}'
......@@ -104,7 +104,7 @@ def collect_megatron_compute(worker_group, output):
"""
Only collect the data from the tp=0 and pp=last and every dp ranks
"""
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = []
pp_size = worker_group.get_megatron_global_info().pp_size
......@@ -119,7 +119,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
"""
All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
"""
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
......@@ -162,7 +162,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
"""
treat pp as dp.
"""
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
pp_size = worker_group.pp_size
......@@ -210,7 +210,7 @@ def collect_megatron_pp_as_dp(worker_group, output):
"""
treat pp as dp. Only collect data on tp=0
"""
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_dp = []
for global_rank in range(worker_group.world_size):
......@@ -224,7 +224,7 @@ def collect_megatron_pp_only(worker_group, output):
"""
Only collect output of megatron pp. This is useful when examining weight names, as they are identical across tp/dp
"""
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output_in_pp = []
for global_rank in range(worker_group.world_size):
......@@ -235,7 +235,7 @@ def collect_megatron_pp_only(worker_group, output):
def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
pp_dp_size = worker_group.dp_size * worker_group.pp_size
......@@ -245,7 +245,7 @@ def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
def collect_megatron_pp_as_dp_data_proto(worker_group, output):
from verl.protocol import DataProto
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
assert isinstance(worker_group, MegatronWorkerGroup)
output = collect_megatron_pp_as_dp(worker_group, output)
......@@ -253,7 +253,7 @@ def collect_megatron_pp_as_dp_data_proto(worker_group, output):
def dispatch_dp_compute(worker_group, *args, **kwargs):
from single_controller.base.worker_group import WorkerGroup
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
......@@ -263,21 +263,21 @@ def dispatch_dp_compute(worker_group, *args, **kwargs):
def collect_dp_compute(worker_group, output):
from single_controller.base.worker_group import WorkerGroup
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
assert len(output) == worker_group.world_size
return output
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
from single_controller.base.worker_group import WorkerGroup
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
return splitted_args, splitted_kwargs
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
from single_controller.base.worker_group import WorkerGroup
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
assert type(args[0]) == FunctionType # NOTE: The first one args is a function!
......
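The dispatch/collect helpers in the decorator.py hunks above all follow the same pattern: dispatch splits the driver-side arguments across data-parallel ranks (replicating each chunk to that rank's tp/pp peers), and collect keeps only one tp/pp representative per dp rank. A framework-free conceptual sketch of that pattern; the rank layout and helper names are illustrative, not verl API:

```python
# Conceptual sketch only: dp-wise dispatch with tp/pp replication,
# and dp-wise collection, without any verl/Megatron dependencies.
def dispatch_by_dp(data_chunks, rank_infos):
    # Each global rank receives the chunk belonging to its dp rank;
    # all tp/pp peers of the same dp rank see identical data.
    return [data_chunks[info["dp_rank"]] for info in rank_infos]

def collect_from_dp(outputs, rank_infos, pp_size):
    # Keep one representative per dp rank: tp=0 and the last pp stage.
    return [
        out for out, info in zip(outputs, rank_infos)
        if info["tp_rank"] == 0 and info["pp_rank"] == pp_size - 1
    ]

# Example layout: 2 dp ranks x 2 tp ranks, pp_size = 1 -> 4 global ranks.
rank_infos = [
    {"dp_rank": 0, "tp_rank": 0, "pp_rank": 0},
    {"dp_rank": 0, "tp_rank": 1, "pp_rank": 0},
    {"dp_rank": 1, "tp_rank": 0, "pp_rank": 0},
    {"dp_rank": 1, "tp_rank": 1, "pp_rank": 0},
]
chunks = [["sample_a"], ["sample_b"]]        # one chunk per dp rank
per_rank_inputs = dispatch_by_dp(chunks, rank_infos)
per_rank_outputs = per_rank_inputs           # pretend each rank returns its input
print(collect_from_dp(per_rank_outputs, rank_infos, pp_size=1))  # [['sample_a'], ['sample_b']]
```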
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from single_controller.base.worker import Worker
from verl.single_controller.base.worker import Worker
class DPEngineWorker(Worker):
......
......@@ -14,7 +14,7 @@
import os
from dataclasses import dataclass
from single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
class MegatronWorker(Worker):
......
......@@ -15,7 +15,7 @@
from typing import Dict
from .worker import DistRankInfo, DistGlobalInfo
from single_controller.base import ResourcePool, WorkerGroup
from verl.single_controller.base import ResourcePool, WorkerGroup
class MegatronWorkerGroup(WorkerGroup):
......
......@@ -17,7 +17,7 @@ the class for Worker
import os
import socket
from dataclasses import dataclass
from single_controller.base.decorator import register, Dispatch
from verl.single_controller.base.decorator import register, Dispatch
@dataclass
......@@ -43,7 +43,7 @@ class WorkerHelper:
import ray
return ray._private.services.get_node_ip_address()
elif os.getenv("WG_BACKEND", None) == "torch_rpc":
from single_controller.torchrpc.k8s_client import get_ip_addr
from verl.single_controller.torchrpc.k8s_client import get_ip_addr
return get_ip_addr()
return None
......@@ -110,7 +110,7 @@ class Worker(WorkerHelper):
}
if os.getenv("WG_BACKEND", None) == "ray":
from single_controller.base.register_center.ray import create_worker_group_register_center
from verl.single_controller.base.register_center.ray import create_worker_group_register_center
self.register_center = create_worker_group_register_center(name=register_center_name,
info=rank_zero_info)
......
......@@ -20,7 +20,7 @@ import signal
import time
from typing import List, Any, Callable, Dict
from single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool:
......
......@@ -21,7 +21,7 @@ from ray.util.placement_group import placement_group, PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
from ray.experimental.state.api import get_actor
from single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker
from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker
__all__ = ['Worker']
......@@ -373,7 +373,7 @@ with code written in separate ray.Actors.
"""
from unittest.mock import patch
from single_controller.base.decorator import MAGIC_ATTR
from verl.single_controller.base.decorator import MAGIC_ATTR
import os
......
......@@ -19,7 +19,7 @@ import os
import ray
# compatiblity cern
from single_controller.base.decorator import *
from verl.single_controller.base.decorator import *
def maybe_remote(main):
......
......@@ -14,7 +14,7 @@
import ray
from single_controller.ray.base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
from verl.single_controller.ray.base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
@ray.remote
......
......@@ -17,8 +17,8 @@ from typing import Dict, Optional
import ray
from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
from single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
from single_controller.base.megatron.worker_group import MegatronWorkerGroup
from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup
# NOTE(sgm): for opensource megatron-core
......
......@@ -33,7 +33,7 @@ from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@hydra.main(config_path='config', config_name='generation', version_base=None)
......
......@@ -120,13 +120,13 @@ def main_task(config):
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray import RayWorkerGroup
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
......
......@@ -26,9 +26,9 @@ from omegaconf import OmegaConf, open_dict
import numpy as np
from codetiming import Timer
from single_controller.base import Worker
from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from single_controller.ray.base import create_colocated_worker_cls
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl import DataProto
from verl.trainer.ppo import core_algos
......
......@@ -23,8 +23,8 @@ import torch
import torch.distributed
from omegaconf import DictConfig, open_dict
from single_controller.base import Worker
from single_controller.base.decorator import register, Dispatch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer.ppo.actor import DataParallelPPOActor
......
......@@ -22,13 +22,13 @@ import torch
import torch.distributed
import torch.nn as nn
from omegaconf import DictConfig
from single_controller.base.megatron.worker import MegatronWorker
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.trainer.ppo.actor.megatron_actor import MegatronPPOActor
from verl.trainer.ppo.critic.megatron_critic import MegatronPPOCritic
from verl.trainer.ppo.hybrid_engine import AllGatherPPModel
from verl.trainer.ppo.reward_model.megatron.reward_model import MegatronRewardModel
from single_controller.base.decorator import register, Dispatch
from verl.single_controller.base.decorator import register, Dispatch
from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.debug import log_gpu_memory_usage
......