Commit e999fff3 by nzy

refactor: get_score_token

parent b5fbb4c5
......@@ -2,6 +2,8 @@ import argparse
from pathlib import Path
import pprint
from transformers import AutoTokenizer
from codecritic.data.cov import COV_PROMPT
from codecritic.utils.vllm import vllm_chatcomplete, vllm_score
from codecritic.utils.json import load_jsonl, save_jsonl
......@@ -38,7 +40,8 @@ def run_sft_model(model_path, test_path, apps_path, reason_prompt=None):
test_dataset = vllm_chatcomplete(model_path, test_dataset, sampling_params)
# score
score_token = get_score_token_id(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
score_token = get_score_token_id(tokenizer)
test_dataset = [append_prompt(x, JUDGE_PROMPT) for x in test_dataset]
results = vllm_score(model_path, test_dataset, score_token)
score_path = result_dir / "scores.jsonl"
......
import re
from codecritic.utils.json import load_json, save_json
from transformers import AutoTokenizer
codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)
code_template = """```python
......@@ -70,8 +69,7 @@ def save_dataset(llamafactory_path, dataset_info, dataset):
save_json(dataset, f"{llamafactory_path}/data/{dataset_relative_path}")
def get_score_token_id(model_path, token_str="Yes"):
tokenizer = AutoTokenizer.from_pretrained(model_path)
def get_score_token_id(tokenizer, token_str="Yes"):
    """Return the id of the single token that ``token_str`` encodes to.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a sequence of token ids.
        token_str: Text whose token id is wanted.

    Returns:
        The only token id produced by encoding ``token_str`` without
        special tokens.

    Raises:
        AssertionError: If ``token_str`` does not encode to exactly one token.
    """
    score_tokens = tokenizer.encode(token_str, add_special_tokens=False)
    # Explicit check instead of an `assert` statement: asserts are removed
    # under `python -O`, which would let a multi-token encoding slip through
    # and silently return only its first id. AssertionError is kept so the
    # failure type seen by existing callers does not change.
    if len(score_tokens) != 1:
        raise AssertionError(
            f"expected {token_str!r} to encode to exactly one token, "
            f"got {len(score_tokens)}: {list(score_tokens)!r}"
        )
    return score_tokens[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment