Commit ce862ce8 by Kyle Corbitt (committed by GitHub)

Allow users to pass in custom compute_score function (#162)

This is a follow-up to https://github.com/volcengine/verl/issues/151

## Motivation

Currently, in order to add a custom score function you need to fork verl
and update the `_select_rm_score_fn` function to define your logic. This
makes it harder to use verl as part of a larger application while staying
up to date with upstream improvements.

It would be convenient to allow end users to directly pass in a reward
function they wish to use, without requiring them to clone/fork verl to
do so.

## Design

In this PR I slightly modify `main_ppo.py` to allow users to import a
new function, `run_ppo`. `run_ppo` behaves very similarly to the existing
`main`, with the important addition of a new `compute_score` argument.
If passed in, this argument is used to compute the score of every
generation; this is the change that lets users plug in their own reward
logic without forking verl.

The `compute_score` function has a similar shape to the existing
`compute_score` functions for gsm8k and math. However, I have added a new
`data_source` parameter so that the user can compute the score
differently, if desired, depending on the dataset or task.
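
For illustration, a user-supplied `compute_score` might branch on `data_source`. This is only a sketch; the scoring rules below are placeholders, not something this PR defines:

```python
def my_compute_score(data_source, solution_str, ground_truth):
    # Placeholder dispatch: the scoring rules here are illustrative only.
    if data_source == "openai/gsm8k":
        # Hypothetical exact-match scoring for one dataset.
        return 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
    # Hypothetical fallback: credit if the ground truth appears verbatim.
    return 1.0 if ground_truth in solution_str else 0.0
```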

## Example Usage

This is a sample script showing how you can use the new functionality. I
have tested that this works.

```python
from verl.trainer.main_ppo import run_ppo
from omegaconf import OmegaConf


def custom_compute_score(data_source, solution_str, ground_truth):
    """Dummy compute_score function that rewards generations of exactly 20 characters :)"""
    # Negative distance from 20 characters, so the reward peaks at exactly 20.
    return -abs(len(solution_str) - 20)


config = OmegaConf.load("vendor/verl/verl/trainer/config/ppo_trainer.yaml")

# Update config as needed
config.data.train_files = "path/to/train.parquet"
config.data.val_files = "path/to/test.parquet"
# ...

run_ppo(config, custom_compute_score)
```

## Breaking changes

There are no breaking changes in this PR. It is still possible to call
`python -m verl.trainer.main_ppo ...` as before, although if you want to
pass in a custom `compute_score` you will need to use the `run_ppo` entry
point described above.

## Possible future work

It would be great to move to [structured
configs](https://omegaconf.readthedocs.io/en/2.1_branch/structured_config.html)
as well, since they would give us type-safe, autocompletable
configurations from Python. I considered adding those changes here, but
they would be much more extensive and I'm not sure whether there's
interest from the project.
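
For reference, a structured config with OmegaConf looks roughly like the sketch below. The field names are illustrative only and are not verl's actual config schema:

```python
from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class DataConfig:
    # Illustrative fields only; not verl's real schema.
    train_files: str = "path/to/train.parquet"
    val_files: str = "path/to/test.parquet"


@dataclass
class PPOTrainerConfig:
    data: DataConfig = field(default_factory=DataConfig)


# Merging a YAML file against the structured schema validates keys and
# types, and gives IDE-friendly attribute access on the result.
schema = OmegaConf.structured(PPOTrainerConfig)
config = OmegaConf.merge(schema, OmegaConf.load("ppo_trainer.yaml"))
```
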
Parent commit: 41b7c583
Diff of `verl/trainer/main_ppo.py`:

```diff
@@ -21,11 +21,11 @@ from verl.utils.reward_score import gsm8k, math
 from verl.trainer.ppo.ray_trainer import RayPPOTrainer


-def _select_rm_score_fn(data_source):
+def _default_compute_score(data_source, solution_str, ground_truth):
     if data_source == 'openai/gsm8k':
-        return gsm8k.compute_score
+        return gsm8k.compute_score(solution_str, ground_truth)
     elif data_source == 'lighteval/MATH':
-        return math.compute_score
+        return math.compute_score(solution_str, ground_truth)
     else:
         raise NotImplementedError
@@ -34,9 +34,10 @@ class RewardManager():
     """The reward manager.
     """

-    def __init__(self, tokenizer, num_examine) -> None:
+    def __init__(self, tokenizer, num_examine, compute_score=None) -> None:
         self.tokenizer = tokenizer
         self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
+        self.compute_score = compute_score or _default_compute_score

     def __call__(self, data: DataProto):
         """We will expand this function gradually based on the available datasets"""
@@ -69,11 +70,13 @@ class RewardManager():
             ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']

-            # select rm_score
             data_source = data_item.non_tensor_batch['data_source']
-            compute_score_fn = _select_rm_score_fn(data_source)

-            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)
+            score = self.compute_score(
+                data_source=data_source,
+                solution_str=sequences_str,
+                ground_truth=ground_truth,
+            )
             reward_tensor[i, valid_response_length - 1] = score

             if data_source not in already_print_data_sources:
@@ -92,15 +95,19 @@ import hydra
 @hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
 def main(config):
+    run_ppo(config)
+
+
+def run_ppo(config, compute_score=None):
     if not ray.is_initialized():
         # this is for local ray cluster
         ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})

-    ray.get(main_task.remote(config))
+    ray.get(main_task.remote(config, compute_score))


 @ray.remote
-def main_task(config):
+def main_task(config, compute_score=None):
     from verl.utils.fs import copy_local_path_from_hdfs
     from transformers import AutoTokenizer
@@ -167,10 +174,10 @@ def main_task(config):
         role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
         mapping[Role.RewardModel] = global_pool_id

-    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0)
+    reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)

     # Note that we always use function-based RM for validation
-    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1)
+    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)

     resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
```