Unverified Commit 656accb0 by 湛露先生 Committed by GitHub

rollout: Fix navive_rollout class names. (#361)

Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
parent 8eb22b50
......@@ -16,20 +16,15 @@ Megatron Reward Model.
"""
from tensordict import TensorDict
from functools import partial
from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits
import torch
import torch
import torch.distributed
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.utils.torch_functional import pad_sequence_to_length
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.workers.reward_model.base import BasePPORewardModel
from verl.utils.megatron import sequence_parallel as sp_utils
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
......
......@@ -30,7 +30,7 @@ from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits
from ..base import BaseRollout
__all__ = ['NativeRollout']
__all__ = ['NaiveRollout']
class NaiveRollout(BaseRollout):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment