Unverified commit d4a00ef0 by Yuyang Ding, committed by GitHub

Add Math-Verify Support (#545)

# Description

Related issues: https://github.com/volcengine/verl/issues/287 and
https://github.com/volcengine/verl/issues/295.

This PR introduces support for
[Math-Verify](https://github.com/huggingface/Math-Verify) as a new
rule-based reward scorer. Math-Verify parses the gold and predicted
answers (plain expressions or LaTeX) and checks them for mathematical
equivalence rather than relying on string matching, which noticeably
improves evaluation accuracy (0.803 → 0.8338 on MATH with
Qwen2.5-Math-7B-Instruct; see the test below).

# Key changes

- Added `math-verify` to the installation dependencies (see the install note after this list).
- Introduced `reward_score/math_verify.py` and updated
`reward_score/__init__.py`.
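
If you already have verl installed, the new dependency can also be pulled in manually; this assumes the PyPI package name matches the entry added to the dependency lists in the diff below:

```bash
pip install math-verify
```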

# Test

Comparison between the existing scorer in `math.py` and the newly added
`math_verify.py`, using Qwen2.5-Math-7B-Instruct on the MATH test set:

```
# Use scorer in math.py (original)
{'val/test_score/DigitalLearningGmbH/MATH-lighteval': 0.803}

# Use scorer in math_verify.py (newly added)
{'val/test_score/DigitalLearningGmbH/MATH-lighteval': 0.8338}
```

Test script (with `trainer.total_epochs=0` and `val_before_train=True`, this runs validation only, no training):

```bash
set -x

# Data Process
python examples/data_preprocess/math_dataset.py --local_dir /workspace/datasets/math

# Evaluation
export CUDA_VISIBLE_DEVICES=4,5,6,7
export VLLM_ATTENTION_BACKEND=XFORMERS

math_train_path=/workspace/datasets/math/train.parquet
math_test_path=/workspace/datasets/math/test.parquet

python3 -m verl.trainer.main_ppo \
    data.train_files="$math_train_path" \
    data.val_files="$math_test_path" \
    data.max_prompt_length=2048 \
    data.max_response_length=2048 \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-Math-7B-Instruct \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=1 \
    actor_rollout_ref.rollout.temperature=0 \
    trainer.logger=['console'] \
    trainer.project_name='test-math-verify' \
    trainer.experiment_name='test-math-verify' \
    +trainer.val_before_train=True \
    trainer.n_gpus_per_node=4 \
    trainer.nnodes=1 \
    trainer.total_epochs=0 \
    data.train_batch_size=1024 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    algorithm.adv_estimator=grpo $@
```
parent 1d12fe31
`pyproject.toml`:

```diff
@@ -35,6 +35,7 @@ dependencies = [
     "datasets",
     "dill",
     "hydra-core",
+    "math-verify",
     "numpy",
     "pandas",
     "peft",
```

`setup.py`:

```diff
@@ -27,6 +27,7 @@ install_requires = [
     'datasets',
     'dill',
     'hydra-core',
+    'math-verify',
     'numpy',
     'pandas',
     'peft',
```

`reward_score/__init__.py`:

```diff
@@ -19,8 +19,12 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
         from . import gsm8k
         res = gsm8k.compute_score(solution_str, ground_truth)
     elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']:
-        from . import math
-        res = math.compute_score(solution_str, ground_truth)
+        # from . import math
+        # res = math.compute_score(solution_str, ground_truth)
+
+        # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy
+        from . import math_verify
+        res = math_verify.compute_score(solution_str, ground_truth)
     elif data_source in [
         'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
         'numina_olympiads'
```
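
For reference, the updated dispatcher can be exercised directly. A minimal sketch, assuming `_default_compute_score` is importable as `verl.utils.reward_score._default_compute_score` (adjust the import path to wherever `reward_score` lives in your install):

```python
from verl.utils.reward_score import _default_compute_score

# MATH-style data sources now route to math_verify.compute_score.
score = _default_compute_score(
    data_source="DigitalLearningGmbH/MATH-lighteval",
    solution_str="... model response ending in \\boxed{\\frac{1}{2}} ...",
    ground_truth="\\frac{1}{2}",
)
print(score)  # expected 1.0 for a matching answer
```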
New file `reward_score/math_verify.py`:

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig


def compute_score(model_output: str, ground_truth: str) -> float:
    # Gold answers are extracted as LaTeX; predictions may be either a plain
    # expression or LaTeX (e.g. a \boxed{...} answer).
    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )
    ret_score = 0.

    # Wrap the ground truth in \boxed{} format for verification
    ground_truth_boxed = "\\boxed{" + ground_truth + "}"
    try:
        ret_score, _ = verify_func([ground_truth_boxed], [model_output])
    except Exception as e:
        # Math-Verify raises on inputs it cannot parse; score those as 0
        print(e)
    return ret_score
```
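
A quick sanity check of the new scorer (hypothetical snippet; the inputs are illustrative and the import path assumes the package resolves as `verl.utils.reward_score`):

```python
from verl.utils.reward_score import math_verify

# Textually different but mathematically equivalent answers should verify.
print(math_verify.compute_score("The answer is \\boxed{\\frac{1}{2}}", "\\frac{1}{2}"))  # expected 1.0
print(math_verify.compute_score("The answer is 0.5", "\\frac{1}{2}"))                    # expected 1.0
print(math_verify.compute_score("The answer is \\boxed{3}", "\\frac{1}{2}"))             # expected 0.0
```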