Unverified Commit 99d2c19b by Guangming Sheng Committed by GitHub

[misc] feat: remove @ray.remote on workers to allow inheritance (#61)

Co-authored-by: Haibin Lin <haibin.lin@bytedance.com>
parent 09568e60
......@@ -135,9 +135,9 @@ def main_task(config):
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ActorRolloutRefWorker,
Role.Critic: CriticWorker,
Role.RefPolicy: ActorRolloutRefWorker
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
}
# NOTE: initialze two resource pool
......@@ -173,7 +173,7 @@ def main_task(config):
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = RewardModelWorker
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = critic_pool_id
reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)
......
......@@ -14,7 +14,7 @@
"""
Generate responses given a dataset of prompts
"""
import ray
import numpy as np
import hydra
import os
......@@ -59,7 +59,7 @@ def main(config):
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
ray_cls_with_init = RayClassWithInitArgs(cls=ActorRolloutRefWorker, config=config, role='rollout')
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
wg.init_model()
......
......@@ -136,9 +136,9 @@ def main_task(config):
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ActorRolloutRefWorker,
Role.Critic: CriticWorker,
Role.RefPolicy: ActorRolloutRefWorker
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
}
global_pool_id = 'global_pool'
......@@ -164,7 +164,7 @@ def main_task(config):
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = RewardModelWorker
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)
......
......@@ -40,7 +40,6 @@ logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
@ray.remote
class ActorRolloutRefWorker(Worker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
......@@ -434,7 +433,6 @@ class ActorRolloutRefWorker(Worker):
offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
@ray.remote
class CriticWorker(Worker):
def __init__(self, config):
......@@ -642,7 +640,6 @@ class CriticWorker(Worker):
offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
@ray.remote
class RewardModelWorker(Worker):
"""
Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification.
......
......@@ -60,7 +60,6 @@ def set_random_seed(seed):
# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
@ray.remote
class ActorRolloutRefWorker(MegatronWorker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
......@@ -406,7 +405,6 @@ class ActorRolloutRefWorker(MegatronWorker):
pass
@ray.remote
class CriticWorker(MegatronWorker):
def __init__(self, config):
......@@ -575,7 +573,6 @@ class CriticWorker(MegatronWorker):
pass
@ray.remote
class RewardModelWorker(MegatronWorker):
"""
Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment