Commit 0c079f4d by nzy

step4: refactor. merge two to one

parent 1a8e110a
@@ -7,7 +7,7 @@ from utils import load_json
 from utils_dataset import mk_critic_qa, mk_sft_item, mk_sft_dataset_info, save_dataset
 from utils_vllm import vllm_chatcomplete
-verify_prompt = "Let's verify step by step."
+COV_PROMPT = "Let's verify step by step."
 def mk_cov_prompt(is_correct):
@@ -16,7 +16,7 @@ def mk_cov_prompt(is_correct):
     else:
         prompt1 = "This code is incorrect."
-    return [{"role": "user", "content": prompt1 + " " + verify_prompt}]
+    return [{"role": "user", "content": prompt1 + " " + COV_PROMPT}]
 def convert_preference_to_vot_prompt(item):
@@ -30,7 +30,7 @@ def convert_preference_to_vot_prompt(item):
 def convert_cov_to_cov_dataset(item):
-    item["messages"][2]["content"] = verify_prompt
+    item["messages"][2]["content"] = COV_PROMPT
     return item
...
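For reference, a minimal sketch of how the renamed COV_PROMPT is assumed to be used by mk_cov_prompt. The is_correct branch is not visible in this hunk, so its wording ("This code is correct.") is an assumption, not taken from the commit.

COV_PROMPT = "Let's verify step by step."

def mk_cov_prompt(is_correct):
    # Only the else-branch wording appears in the hunk; the correct-branch text is assumed.
    prompt1 = "This code is correct." if is_correct else "This code is incorrect."
    return [{"role": "user", "content": prompt1 + " " + COV_PROMPT}]

# mk_cov_prompt(False) ->
# [{"role": "user", "content": "This code is incorrect. Let's verify step by step."}]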
import argparse
from pathlib import Path
import pprint

from utils_vllm import vllm_chatcomplete, vllm_score
from utils import load_jsonl, save_jsonl, save_json, extract_code
from utils_metric import group_results, score_pass_at_k
from utils_dataset import (
    mk_critic_verify,
    get_score_token_id,
    code_template,
    mk_critic_reason,
    mk_critic_qa,
)


def preprocess_test_item(item):
    question = item["messages"][0]["content"]
    answer = item["messages"][1]["content"]
    code = code_template.format(extract_code(answer))
    critic_reason_prompt = mk_critic_reason("", "")[:1]
    item["messages"] = mk_critic_qa(question, code) + critic_reason_prompt
    return item


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--test", type=str)
    parser.add_argument("--apps", type=str)
    args = parser.parse_args()

    home_path = Path(args.model).parent
    result_dir = home_path / "eval"
    result_dir.mkdir(exist_ok=True)

    prompt_path = result_dir / "prompt.jsonl"
    raw_test_dataset = load_jsonl(args.test)
    test_dataset = [preprocess_test_item(item) for item in raw_test_dataset]
    save_jsonl(test_dataset, prompt_path)

    reason_path = result_dir / "reason.jsonl"
    sampling_params = dict(n=1, temperature=0.0, max_tokens=2048)
    save_json(sampling_params, result_dir / "sampling_params.json")
    vllm_chatcomplete(args.model, prompt_path, reason_path, sampling_params)

    score_token = get_score_token_id(args.model)
    reason_results = load_jsonl(reason_path)
    score_prompts = []
    for item in reason_results:
        item["messages"] += mk_critic_verify()
        score_prompts.append(item)
    score_prompt_path = result_dir / "score_prompt.jsonl"
    save_jsonl(score_prompts, score_prompt_path)

    score_path = result_dir / "score.jsonl"
    vllm_score(args.model, score_prompt_path, score_path, score_token)

    results = load_jsonl(score_path)
    groups = group_results(results, args.apps)
    eval_results = [score_pass_at_k(groups, k, "critic") for k in range(1, 16)]
    result_path = result_dir / "result.jsonl"
    save_jsonl(eval_results, result_path)
    pprint.pp(eval_results)
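group_results and score_pass_at_k are not defined in this commit, so the sketch below is only a hedged guess at what the pass@k numbers mean: it uses the standard unbiased estimator pass@k = 1 - C(n-c, k) / C(n, k) for c correct out of n scored samples. The helper name pass_at_k and the estimator choice are assumptions, not the repository's actual implementation.

from math import comb

def pass_at_k(n, c, k):
    # n: scored samples per problem, c: samples judged correct, k: selection budget
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# e.g. 15 samples with 3 judged correct: pass@5 = 1 - C(12,5)/C(15,5) ≈ 0.74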
 import argparse
 from pathlib import Path
 import pprint
-from utils_vllm import vllm_score
+from step2_cov_dataset import COV_PROMPT
+from utils_vllm import vllm_chatcomplete, vllm_score
 from utils import load_jsonl, save_jsonl, extract_code, code_template
-from utils_dataset import mk_critic_qa, mk_critic_verify, get_score_token_id
+from utils_dataset import mk_critic_qa, JUDGE_PROMPT, get_score_token_id
 from utils_metric import group_results, score_pass_at_k
@@ -11,36 +13,56 @@ def preprocess_test_item(item):
     question = item["messages"][0]["content"]
     answer = item["messages"][1]["content"]
     code = code_template.format(extract_code(answer))
-    item["messages"] = mk_critic_qa(question, code) + mk_critic_verify()
+    item["messages"] = mk_critic_qa(question, code)
     return item
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str)
-    parser.add_argument("--test", type=str)
-    parser.add_argument("--apps", type=str)
-    args = parser.parse_args()
-    home_path = Path(args.model).parent
+def append_prompt(item, content):
+    item["messages"].append({"role": "user", "content": content})
+    return item
+def run_sft_model(model_path, test_path, apps_path, reason_prompt=None):
+    home_path = Path(model_path).parent
     result_dir = home_path / "eval"
     result_dir.mkdir(exist_ok=True)
     # preprocess prompt
-    prompt_path = result_dir / "prompt.jsonl"
-    raw_test_dataset = load_jsonl(args.test)
+    raw_test_dataset = load_jsonl(test_path)
     test_dataset = [preprocess_test_item(item) for item in raw_test_dataset]
-    save_jsonl(test_dataset, prompt_path)
+    # reason
+    if reason_prompt:
+        test_dataset = [append_prompt(x, COV_PROMPT) for x in test_dataset]
+        sampling_params = dict(n=1, temperature=0.0, max_tokens=2048)
+        reason_path = result_dir / "reason.jsonl"
+        test_dataset = vllm_chatcomplete(
+            model_path, test_dataset, reason_path, sampling_params
+        )
     # score
     score_path = result_dir / "scores.jsonl"
-    score_token = get_score_token_id(args.model)
-    vllm_score(args.model, prompt_path, score_path, score_token)
+    score_token = get_score_token_id(model_path)
+    test_dataset = [append_prompt(x, JUDGE_PROMPT) for x in test_dataset]
+    results = vllm_score(model_path, test_dataset, score_path, score_token)
     # compute pass@k
     eval_result_path = result_dir / "passk.jsonl"
-    results = load_jsonl(score_path)
-    groups = group_results(results, args.apps)
-    eval_results = [score_pass_at_k(groups, k, "sft-orm") for k in range(1, 16)]
+    # results = load_jsonl(score_path)
+    groups = group_results(results, apps_path)
+    eval_results = [score_pass_at_k(groups, k, home_path.stem) for k in range(1, 16)]
     save_jsonl(eval_results, eval_result_path)
     pprint.pp(eval_results)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str)
+    parser.add_argument("--test", type=str)
+    parser.add_argument("--apps", type=str)
+    parser.add_argument("--reason", choices=["cov"])
+    args = parser.parse_args()
+    reason_prompts = {"cov": COV_PROMPT}
+    reason_prompt = reason_prompts.get(args.reason, None)
+    run_sft_model(args.model, args.test, args.apps, reason_prompt)
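To make the merged flow concrete, here is a small self-contained sketch of how the message list evolves for one test item when --reason cov is passed. Only append_prompt, COV_PROMPT, and JUDGE_PROMPT mirror the code above; the mk_critic_qa output shape and the model turns are stand-ins, not taken from this commit.

COV_PROMPT = "Let's verify step by step."
JUDGE_PROMPT = "Is the code correct (Yes/No)?"

def append_prompt(item, content):
    item["messages"].append({"role": "user", "content": content})
    return item

# Stand-in for mk_critic_qa(question, code): assumed to be a user/assistant pair.
item = {"messages": [
    {"role": "user", "content": "<problem statement>"},
    {"role": "assistant", "content": "<candidate code>"},
]}

# --reason cov: ask for verification, let the model reason, then ask for the judgement.
item = append_prompt(item, COV_PROMPT)
item["messages"].append({"role": "assistant", "content": "<model verification>"})
item = append_prompt(item, JUDGE_PROMPT)  # vllm_score then reads the Yes/No token

# Without --reason, JUDGE_PROMPT is appended directly after the question/code pair.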
@@ -68,9 +68,10 @@ def mk_critic_qa(instruction, code):
     ]
+JUDGE_PROMPT = "Is the code correct (Yes/No)?"
 def mk_critic_verify(answer=None):
     # answer: bool or none
-    message = [{"role": "user", "content": "Is the code correct (Yes/No)?"}]
+    message = [{"role": "user", "content": JUDGE_PROMPT}]
     if answer is not None:
         response = "Yes" if answer else "No"
         message.append({"role": "assistant", "content": response})
...
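For clarity, the refactored helper and its two output shapes. The trailing return statement is cut off by the hunk above and is assumed here.

JUDGE_PROMPT = "Is the code correct (Yes/No)?"

def mk_critic_verify(answer=None):
    # answer: bool or None
    message = [{"role": "user", "content": JUDGE_PROMPT}]
    if answer is not None:
        message.append({"role": "assistant", "content": "Yes" if answer else "No"})
    return message  # assumed; the return line sits outside the hunk shown above

print(mk_critic_verify())             # judgement question only, as appended before scoring
print(mk_critic_verify(answer=True))  # question plus the "Yes" label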