Unverified Commit 13a87c76 by Lumeng Wu Committed by GitHub

feat: support loading reward function from an external file (#452)

parent 90109ffd
......@@ -90,3 +90,7 @@ jobs:
        run: |
          ray stop --force
          bash tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using customized reward function
        run: |
          ray stop --force
          bash tests/e2e/run_qwen_gsm8k_custom_function_rm.sh
......@@ -319,6 +319,18 @@ Reward Model
if ``naive``. If all verification functions are multiprocessing-safe, the reward
manager can be set to ``prime`` for parallel verification.
Customized Reward Function
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code:: yaml

    custom_reward_function:
      path: null
      name: compute_score

- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used.
- ``custom_reward_function.name`` (optional): The name of the reward function within the specified file. Defaults to ``compute_score``.
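For illustration, a reward file that works with the default ``name`` could look like the sketch below (the path and the scoring rule are hypothetical, not part of this change):

.. code:: python

    # /path/to/my_reward.py -- hypothetical file referenced by custom_reward_function.path
    def compute_score(data_source, solution_str, ground_truth, extra_info=None):
        # Toy rule for illustration only: full reward for an exact match, zero otherwise.
        return 1.0 if solution_str.strip() == ground_truth.strip() else 0.0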
Algorithm
~~~~~~~~~
......
......@@ -3,6 +3,7 @@ Implement Reward Function for Dataset
For each dataset, we need to implement a reward function or utilize a reward model to compute the rewards for the generated responses.
We already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.
You can also use customized reward functions.
Currently, we support reward functions for GSM8k and MATH datasets. For RLHF datasets (e.g.,
full_hh_rlhf) and Code Generation (e.g., APPS), we utilize reward model
......@@ -35,6 +36,10 @@ score for each response.
Reward Functions
----------------
Pre-implemented
~~~~~~~~~~~~~~~
We already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.
- In the `GSM8k example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/gsm8k.py>`_, we
......@@ -44,3 +49,21 @@ We already pre-implemented some reward functions in `reward_score directory <htt
the format is incorrect, score 0 points.
- In the `MATH example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math.py>`_, we follow
the implementation in `lm-evaluation-harness repository <https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py>`_.
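These pre-implemented scorers are plain Python functions. As a rough sketch, assuming the GSM8k module exposes ``compute_score(solution_str, ground_truth)``, calling one directly might look like:

.. code:: python

    # Sketch only: assumes verl.utils.reward_score.gsm8k exposes compute_score(solution_str, ground_truth).
    from verl.utils.reward_score import gsm8k

    # GSM8k responses are expected to end with the "#### <answer>" pattern.
    score = gsm8k.compute_score(solution_str="... step-by-step reasoning ... #### 72", ground_truth="72")
    print(score)  # full score for a correctly formatted, correct answer under this assumption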
Customized
~~~~~~~~~~
You can implement customized reward functions in a separate file and point the trainer to them via ``custom_reward_function.path`` and ``custom_reward_function.name``. For details on how to set these options, please refer to :ref:`config-explain-page`.
Your reward function should accept the parameters ``data_source``, ``solution_str``, ``ground_truth``, and ``extra_info``.
For example:

.. code:: python

    def my_reward_fn(data_source, solution_str, ground_truth, extra_info=None):
        return len(solution_str) / 100
If you are testing only a single customized reward function, you can simply name it ``compute_score`` and leave ``custom_reward_function.name`` unset.
To run multiple trials with different customized reward functions, modify both ``custom_reward_function.path`` and ``custom_reward_function.name`` for each trial.
For instance, you might create a single ``my_reward.py`` file that implements several reward functions. Then, for each trial, you only need to adjust ``custom_reward_function.name``, which makes it convenient to run multiple tests from scripts, as sketched below.
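A minimal sketch of such a file (the file name ``my_reward.py`` and the two function names are only examples):

.. code:: python

    # my_reward.py -- hypothetical file bundling several interchangeable reward functions.
    def length_score(data_source, solution_str, ground_truth, extra_info=None):
        # Toy baseline: longer responses receive a slightly higher reward.
        return len(solution_str) / 100

    def exact_match_score(data_source, solution_str, ground_truth, extra_info=None):
        # Toy alternative: full reward only for an exact match with the ground truth.
        return 1.0 if solution_str.strip() == ground_truth.strip() else 0.0

Each trial can then keep ``custom_reward_function.path`` pointing at this file and only switch ``custom_reward_function.name`` between ``length_score`` and ``exact_match_score``.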
\ No newline at end of file
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse


def check_congratulations_in_file(output_file):
    # Verify that the training log contains the message printed by the custom reward function.
    with open(output_file, 'r') as f:
        output = f.read()
    success_message = "Congratulations!!! You have called my_reward_function successfully!!!"
    assert success_message in output, f'Success message of my_reward_function not found in {output_file}'
    print("Check passes")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', required=True, type=str)
    args = parser.parse_args()
    check_congratulations_in_file(args.output_file)
#!/bin/bash
set -e -x
FILE="$(pwd)/my_reward_function.py"
rm -rf $FILE
cat <<EOF > "$FILE"
def my_reward_function(data_source, solution_str, ground_truth, extra_info=None):
print(f"Congratulations!!! You have called my_reward_function successfully!!!")
return 0.1
EOF
OUTPUT_FILE="$(pwd)/output_custom_reward.txt"
FUNCTION_NAME="my_reward_function"
rm -rf $OUTPUT_FILE
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
algorithm.kl_ctrl.kl_coef=0.001 \
custom_reward_function.path=$FILE \
custom_reward_function.name=$FUNCTION_NAME \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='qwen_e2e_ci_custom_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.default_local_dir=$HOME/ckpt/ \
trainer.total_training_steps=2 | tee "$OUTPUT_FILE"
python3 tests/e2e/check_custom_rwd_fn.py --output_file=$OUTPUT_FILE
rm -rf $FILE
rm -rf $OUTPUT_FILE
\ No newline at end of file
......@@ -138,6 +138,10 @@ reward_model:
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  max_length: null

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  gamma: 1.0
  lam: 1.0
......
......@@ -148,6 +148,10 @@ reward_model:
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
  reward_manager: naive

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  gamma: 1.0
  lam: 1.0
......
......@@ -20,21 +20,50 @@ import ray
import hydra
def get_custom_reward_fn(config):
    """Load a user-defined reward function from an external file, if one is configured."""
    import importlib.util
    import os

    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        # No custom reward function configured; fall back to the pre-implemented ones.
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

    # Import the user file as an ad-hoc module and look up the requested function.
    spec = importlib.util.spec_from_file_location("custom_module", file_path)
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        raise RuntimeError(f"Error loading module from '{file_path}': {e}")

    function_name = reward_fn_config.get("name")
    if not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")

    print(f"Using customized reward function '{function_name}' from '{file_path}'")
    return getattr(module, function_name)
@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
def main(config):
    run_ppo(config)


def run_ppo(config, compute_score=None):
def run_ppo(config) -> None:
    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})
    ray.get(main_task.remote(config, compute_score))
    ray.get(main_task.remote(config))


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
def main_task(config):
    from verl.utils.fs import copy_to_local
    # print initial config
    from pprint import pprint
......@@ -109,6 +138,8 @@ def main_task(config, compute_score=None):
        reward_manager_cls = PrimeRewardManager
    else:
        raise NotImplementedError

    compute_score = get_custom_reward_fn(config)
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)

    # Note that we always use function-based RM for validation
......