Commit efa812c2 by nzy

refactor: rename mk_critic_qa to mk_message

parent e999fff3
...@@ -7,7 +7,7 @@ from transformers import AutoTokenizer ...@@ -7,7 +7,7 @@ from transformers import AutoTokenizer
from codecritic.data.cov import COV_PROMPT from codecritic.data.cov import COV_PROMPT
from codecritic.utils.vllm import vllm_chatcomplete, vllm_score from codecritic.utils.vllm import vllm_chatcomplete, vllm_score
from codecritic.utils.json import load_jsonl, save_jsonl from codecritic.utils.json import load_jsonl, save_jsonl
from codecritic.utils.data import extract_code, code_template, mk_critic_qa, JUDGE_PROMPT, get_score_token_id from codecritic.utils.data import extract_code, code_template, mk_message, JUDGE_PROMPT, get_score_token_id
from codecritic.utils.metric import group_results, score_pass_at_k from codecritic.utils.metric import group_results, score_pass_at_k
...@@ -15,7 +15,7 @@ def preprocess_test_item(item): ...@@ -15,7 +15,7 @@ def preprocess_test_item(item):
question = item["messages"][0]["content"] question = item["messages"][0]["content"]
answer = item["messages"][1]["content"] answer = item["messages"][1]["content"]
code = code_template.format(extract_code(answer)) code = code_template.format(extract_code(answer))
item["messages"] = mk_critic_qa(question, code) item["messages"] = mk_message(question, code)
return item return item
......
...@@ -4,7 +4,7 @@ import argparse ...@@ -4,7 +4,7 @@ import argparse
from itertools import chain from itertools import chain
from utils import load_json, extract_code, code_template from utils import load_json, extract_code, code_template
from codecritic.utils.data import mk_critic_qa, mk_sft_item, mk_critic_verify, mk_sft_dataset_info, save_dataset, SPLITTER from codecritic.utils.data import mk_message, mk_sft_item, mk_critic_verify, mk_sft_dataset_info, save_dataset, SPLITTER
from codecritic.utils.vllm import vllm_chatcomplete from codecritic.utils.vllm import vllm_chatcomplete
COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue." COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
...@@ -51,8 +51,8 @@ def convert_preference_to_vot_prompt(item): ...@@ -51,8 +51,8 @@ def convert_preference_to_vot_prompt(item):
chosen = code_template.format(extract_code(chosen)) chosen = code_template.format(extract_code(chosen))
rejected = code_template.format(extract_code(rejected)) rejected = code_template.format(extract_code(rejected))
messages1 = mk_critic_qa(message, chosen) + mk_cov_prompt(True) messages1 = mk_message(message, chosen) + mk_cov_prompt(True)
messages2 = mk_critic_qa(message, rejected) + mk_cov_prompt(False) messages2 = mk_message(message, rejected) + mk_cov_prompt(False)
return mk_sft_item(messages1), mk_sft_item(messages2) return mk_sft_item(messages1), mk_sft_item(messages2)
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# This experiment aims to fairly compare these two approaches. # This experiment aims to fairly compare these two approaches.
import argparse import argparse
from codecritic.utils.json import load_json from codecritic.utils.json import load_json
from codecritic.utils.data import mk_critic_qa, mk_critic_verify, mk_sft_item, mk_sft_dataset_info, save_dataset from codecritic.utils.data import mk_message, mk_critic_verify, mk_sft_item, mk_sft_dataset_info, save_dataset
def convert_preference_to_sft(item): def convert_preference_to_sft(item):
...@@ -13,8 +13,8 @@ def convert_preference_to_sft(item): ...@@ -13,8 +13,8 @@ def convert_preference_to_sft(item):
chosen = item["chosen"]["content"] chosen = item["chosen"]["content"]
rejected = item["rejected"]["content"] rejected = item["rejected"]["content"]
messages1 = mk_critic_qa(message, chosen) + mk_critic_verify(True) messages1 = mk_message(message, chosen) + mk_critic_verify(True)
messages2 = mk_critic_qa(message, rejected) + mk_critic_verify(False) messages2 = mk_message(message, rejected) + mk_critic_verify(False)
return mk_sft_item(messages1), mk_sft_item(messages2) return mk_sft_item(messages1), mk_sft_item(messages2)
......
...@@ -36,11 +36,10 @@ def mk_sft_item(messages): ...@@ -36,11 +36,10 @@ def mk_sft_item(messages):
return {"messages": messages} return {"messages": messages}
def mk_critic_qa(instruction, code): def mk_message(user, assistant):
# Code should be enclosed in a markdown code block
return [ return [
{"role": "user", "content": instruction}, {"role": "user", "content": user},
{"role": "assistant", "content": code}, {"role": "assistant", "content": assistant},
] ]
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment