Unverified Commit 13a87c76 by Lumeng Wu Committed by GitHub

feat: support loading reward function from an external file (#452)

parent 90109ffd
...@@ -90,3 +90,7 @@ jobs:
      run: |
        ray stop --force
        bash tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh
    - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using a customized reward function
      run: |
        ray stop --force
        bash tests/e2e/run_qwen_gsm8k_custom_function_rm.sh
...@@ -319,6 +319,18 @@ Reward Model

if ``naive``. If all verification functions are multiprocessing-safe, the reward
manager can be set to ``prime`` for parallel verification.
Customized Reward Function
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: yaml

   custom_reward_function:
     path: null
     name: compute_score

- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used.
- ``custom_reward_function.name`` (optional): The name of the reward function within the specified file. Defaults to ``compute_score``.
Algorithm
~~~~~~~~~

...
...@@ -3,6 +3,7 @@ Implement Reward Function for Dataset

For each dataset, we need to implement a reward function or utilize a reward model to compute the rewards for the generated responses.
We already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.
You can also use customized reward functions.

Currently, we support reward functions for GSM8k and MATH datasets. For RLHF datasets (e.g.,
full_hh_rlhf) and Code Generation (e.g., APPS), we utilize reward model
...@@ -35,6 +36,10 @@ score for each response.

Reward Functions
----------------
Pre-implemented
~~~~~~~~~~~~~~~
We already pre-implemented some reward functions in `reward_score directory <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score>`_.

- In the `GSM8k example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/gsm8k.py>`_, we

...@@ -44,3 +49,21 @@ We already pre-implemented some reward functions in `reward_score directory <htt

  the format is incorrect, score 0 points.
- In the `MATH example <https://github.com/volcengine/verl/blob/main/verl/utils/reward_score/math.py>`_, we follow
  the implementation in `lm-evaluation-harness repository <https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py>`_.
Customized
~~~~~~~~~~

You can implement customized reward functions in a separate file and specify them using ``custom_reward_function.path`` and ``custom_reward_function.name``. For details on these two settings, please refer to :ref:`config-explain-page`.

Your reward function should take ``data_source``, ``solution_str``, ``ground_truth``, and ``extra_info`` as parameters. For example:

.. code:: python

   def my_reward_fn(data_source, solution_str, ground_truth, extra_info=None):
       return len(solution_str) / 100

If you are testing only a single customized reward function, you can simply name it ``compute_score`` and leave ``custom_reward_function.name`` unset.

To run multiple trials with different customized reward functions, modify both ``custom_reward_function.path`` and ``custom_reward_function.name`` for each trial.

For instance, you might create a single ``my_reward.py`` file and implement multiple reward functions within it. This way, for different trials, you only need to adjust ``custom_reward_function.name``, making it more convenient to run multiple tests from scripts, as in the sketch below.
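As a minimal sketch of this pattern (the file name ``my_reward.py`` and both function names are hypothetical, for illustration only), such a file could look like:

.. code:: python

   # my_reward.py -- hypothetical example bundling several reward functions,
   # selected per trial via custom_reward_function.name.

   def length_reward(data_source, solution_str, ground_truth, extra_info=None):
       # Reward longer responses, capped at 1.0.
       return min(len(solution_str) / 100, 1.0)


   def exact_match_reward(data_source, solution_str, ground_truth, extra_info=None):
       # Binary reward: 1.0 if the ground truth appears verbatim in the response.
       return 1.0 if ground_truth in solution_str else 0.0

A trial could then keep ``custom_reward_function.path`` fixed and switch only ``custom_reward_function.name``, e.g. from ``length_reward`` to ``exact_match_reward``.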
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse


def check_congratulations_in_file(output_file):
    """Check that the success message printed by my_reward_function appears in the training log."""
    with open(output_file, 'r') as f:
        output = f.read()
    success_message = "Congratulations!!! You have called my_reward_function successfully!!!"
    assert success_message in output, f'Success message of my_reward_function not found in {output_file}'
    print("Check passes")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', required=True, type=str)
    args = parser.parse_args()
    check_congratulations_in_file(args.output_file)
#!/bin/bash
set -e -x
FILE="$(pwd)/my_reward_function.py"
rm -rf "$FILE"
cat <<EOF > "$FILE"
def my_reward_function(data_source, solution_str, ground_truth, extra_info=None):
    print("Congratulations!!! You have called my_reward_function successfully!!!")
    return 0.1
EOF
OUTPUT_FILE="$(pwd)/output_custom_reward.txt"
FUNCTION_NAME="my_reward_function"
rm -rf "$OUTPUT_FILE"
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2.5-0.5B \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=4 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.kl_ctrl.kl_coef=0.001 \
    custom_reward_function.path="$FILE" \
    custom_reward_function.name="$FUNCTION_NAME" \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='verl_example_gsm8k' \
    trainer.experiment_name='qwen_e2e_ci_custom_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.default_local_dir=$HOME/ckpt/ \
    trainer.total_training_steps=2 | tee "$OUTPUT_FILE"
python3 tests/e2e/check_custom_rwd_fn.py --output_file="$OUTPUT_FILE"
rm -rf "$FILE"
rm -rf "$OUTPUT_FILE"
...@@ -138,6 +138,10 @@ reward_model:
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  max_length: null

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  gamma: 1.0
  lam: 1.0
...
...@@ -148,6 +148,10 @@ reward_model:
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
  reward_manager: naive

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  gamma: 1.0
  lam: 1.0
...
...@@ -20,21 +20,50 @@ import ray
import hydra import hydra
def get_custom_reward_fn(config):
    import importlib.util
    import os

    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

    # Load the user-provided file as an ad-hoc module.
    spec = importlib.util.spec_from_file_location("custom_module", file_path)
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e

    function_name = reward_fn_config.get("name")
    if not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")

    print(f"using customized reward function '{function_name}' from '{file_path}'")
    return getattr(module, function_name)
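# Usage sketch (hypothetical values, not part of this commit): with
# custom_reward_function.path=my_reward.py and custom_reward_function.name=my_reward_fn
# set in the Hydra config, get_custom_reward_fn(config) loads my_reward.py and
# returns its my_reward_fn callable. If path is unset, it returns None and the
# reward manager falls back to the pre-implemented reward functions.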
@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
def main(config):
    run_ppo(config)


def run_ppo(config) -> None:
    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})

    ray.get(main_task.remote(config))


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
def main_task(config):
    from verl.utils.fs import copy_to_local
    # print initial config
    from pprint import pprint
...@@ -109,6 +138,8 @@ def main_task(config, compute_score=None):

        reward_manager_cls = PrimeRewardManager
    else:
        raise NotImplementedError

    compute_score = get_custom_reward_fn(config)
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)

    # Note that we always use function-based RM for validation
...