Commit ca962f5c by nanziyuan

fix bugs.

parent 40af338d
 import argparse
+import datetime
 from itertools import chain
+from functools import partial
 import os
 from tqdm.contrib.concurrent import thread_map
 from threading import Lock
@@ -11,8 +13,7 @@ from codecritic.data.cov import (
     convert_sft_to_vot_prompt,
     convert_cov_to_cov_dataset,
 )
-from codecritic.data.utils import save_jsonl_dataset
-from codecritic.utils.json import load_json, load_jsonl
+from codecritic.utils.json import load_json, load_jsonl, save_jsonl
 client = OpenAI(
@@ -21,8 +22,7 @@ client = OpenAI(
 )

-def worker(args):
-    index, cov_prompt, lock = args
+def worker(index, cov_prompt, lock, tmp_path):
     completion = client.chat.completions.create(
         model="deepseek-coder",
         messages=cov_prompt["messages"],
@@ -32,7 +32,7 @@ def worker(args):
     content = completion.choices[0].message.content
     cov_prompt["messages"].append({"role": "assistant", "content": content})
     with lock:
-        with open("tmp.jsonl", "a", encoding="utf-8") as f:
+        with open(tmp_path, "a", encoding="utf-8") as f:
            cov_prompt["index"] = index
            f.write(json.dumps(cov_prompt) + '\n')
     return cov_prompt
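The rewritten `worker` takes `tmp_path` explicitly, so each run checkpoints to its own timestamped file instead of the hard-coded `tmp.jsonl`, while the append stays behind the shared lock. A minimal self-contained sketch of that checkpoint pattern, with the OpenAI call replaced by a stub (`fake_completion` is invented here):

```python
import json
from threading import Lock

def fake_completion(messages):
    # stand-in for client.chat.completions.create(...); not the real API call
    return "stub chain-of-verification"

def worker(index, cov_prompt, lock, tmp_path):
    content = fake_completion(cov_prompt["messages"])
    cov_prompt["messages"].append({"role": "assistant", "content": content})
    with lock:  # serialize writers so JSONL lines from threads never interleave
        with open(tmp_path, "a", encoding="utf-8") as f:
            cov_prompt["index"] = index
            f.write(json.dumps(cov_prompt) + "\n")
    return cov_prompt
```

Because the file is opened in append mode, a crashed run leaves a usable partial log whose records can be matched back up via the `index` field.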
@@ -40,29 +40,30 @@ def worker(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--preference_dataset", type=str)
-    parser.add_argument("--sft_dataset", type=str)
-    parser.add_argument("--output_dir", type=str)
+    parser.add_argument("--dataset", type=str)
+    parser.add_argument("--dataset_type", type=str, choices=["reward", "sft"])
+    parser.add_argument("--out", type=str)
+    parser.add_argument("--mode", type=str, choices=["train", "test"])
     args = parser.parse_args()

-    if args.preference_dataset:
-        preference_dataset = load_json(args.preference_dataset)
-        cov_prompts = [convert_preference_to_vot_prompt(x, splitter=False) for x in preference_dataset]
+    if args.dataset_type == "reward":
+        preference_dataset = load_json(args.dataset)  # TODO change to jsonl
+        cov_prompts = [convert_preference_to_vot_prompt(x, splitter=False, mode=args.mode) for x in preference_dataset]
         cov_prompts = list(chain(*cov_prompts))
-    elif args.sft_dataset:
-        sft_dataset = load_jsonl(args.sft_dataset)
-        cov_prompts = [convert_sft_to_vot_prompt(x, splitter=False) for x in sft_dataset]
+    elif args.dataset_type == "sft":
+        sft_dataset = load_jsonl(args.dataset)
+        cov_prompts = [convert_sft_to_vot_prompt(x, splitter=False, mode=args.mode) for x in sft_dataset]
     else:
         parser.error("preference_dataset or sft_dataset")

     lock = Lock()
+    tmp_path = f"log_{datetime.datetime.now().strftime('%y%m%d_%H%M')}.jsonl"
     total_len = len(cov_prompts)
-    thread_args = list(zip(range(total_len),
-                           cov_prompts,
-                           [lock] * total_len))
+    w = partial(worker, lock=lock, tmp_path=tmp_path)

-    covs = thread_map(worker, thread_args, max_workers=8)
+    covs = thread_map(w, range(total_len), cov_prompts, max_workers=8)

-    dataset = list(map(convert_cov_to_cov_dataset, covs))
-    save_jsonl_dataset(dataset, args.output_dir)
+    save_jsonl(covs, args.out + ".raw")
+    dataset = [convert_cov_to_cov_dataset(x, args.mode) for x in covs]
+    save_jsonl(dataset, args.out)
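The old dispatch zipped `(index, cov_prompt, lock)` tuples and unpacked them inside `worker`; the new one binds the fixed arguments with `functools.partial` and passes the varying ones as separate iterables, which `thread_map` consumes positionally like the built-in `map`. A runnable sketch under that reading (the prompt data and path are made up):

```python
from functools import partial
from threading import Lock
from tqdm.contrib.concurrent import thread_map

def worker(index, cov_prompt, lock, tmp_path):
    return {"index": index, "prompt": cov_prompt}

cov_prompts = [{"messages": [f"prompt-{i}"]} for i in range(4)]
w = partial(worker, lock=Lock(), tmp_path="log_demo.jsonl")  # illustrative path
# thread_map(fn, it1, it2) calls fn(it1[0], it2[0]), fn(it1[1], it2[1]), ...
covs = thread_map(w, range(len(cov_prompts)), cov_prompts, max_workers=2)
```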
codecritic/data/cov.py
@@ -2,7 +2,7 @@
 # Does reasoning really work? Let's verify step by step.

 from codecritic.data.code import extract_code, code_template
-from codecritic.data.utils import SPLITTER, mk_message, mk_messages
+from codecritic.data.utils import SPLITTER, mk_message
 from codecritic.data.verify import mk_critic_verify

 COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
@@ -35,10 +35,15 @@ CORRECT_PROMPT = "Your code is correct."
 INCORRECT_PROMPT = "Your code is incorrect."

-def mk_cov_prompt(is_correct, splitter):
-    prompt1 = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT
+def mk_cov_prompt(is_correct, splitter, mode):
+    if mode == "train":
+        anchor = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT
+    elif mode == "test":
+        anchor = ""
+    else:
+        raise ValueError(f"Invalid mode: {mode}. Expected 'train' or 'test'.")

-    turn1 = {"role": "user", "content": prompt1 + "\n" + COV_PROMPT + "\n" + COV_EXAMPLE}
+    turn1 = {"role": "user", "content": '\n'.join([anchor, COV_PROMPT, COV_EXAMPLE])}
     if splitter:
         turn2 = {
             "role": "assistant",
@@ -49,7 +54,7 @@ def mk_cov_prompt(is_correct, splitter):
     return [turn1]

-def convert_preference_to_vot_prompt(item, splitter=True):
+def convert_preference_to_vot_prompt(item, splitter, mode):
     message = item["messages"][0]["content"]
     chosen = item["chosen"]["content"]
     rejected = item["rejected"]["content"]
@@ -57,29 +62,26 @@ def convert_preference_to_vot_prompt(item, splitter=True):
     chosen = code_template.format(extract_code(chosen))
     rejected = code_template.format(extract_code(rejected))

-    messages1 = mk_message(message, chosen) + mk_cov_prompt(True, splitter)
-    messages2 = mk_message(message, rejected) + mk_cov_prompt(False, splitter)
-    return mk_messages(messages1), mk_messages(messages2)
+    messages1 = mk_message(message, chosen) + mk_cov_prompt(True, splitter, mode)
+    messages2 = mk_message(message, rejected) + mk_cov_prompt(False, splitter, mode)
+    return (
+        {"messages": messages1, "eval_result": True, "problem_id": item["problem_id"]},
+        {"messages": messages2, "eval_result": False, "problem_id": item["problem_id"]}
+    )

-def convert_sft_to_vot_prompt(item, splitter=True):
+def convert_sft_to_vot_prompt(item, splitter, mode):
     question = item["messages"][0]["content"]
     response = item["messages"][1]["content"]
     code = code_template.format(extract_code(response))
-    messages = mk_message(question, code) + mk_cov_prompt(item["eval_result"], splitter)
-    return mk_messages(messages)
+    messages = mk_message(question, code) + mk_cov_prompt(item["eval_result"], splitter, mode)
+    return {"messages": messages, "eval_result": item["eval_result"], "problem_id": item["problem_id"]}

-def convert_cov_to_cov_dataset(item):
-    user_content = item["messages"][2]["content"]
+def convert_cov_to_cov_dataset(item, mode):
     item["messages"][2]["content"] = COV_PROMPT
-    if CORRECT_PROMPT in user_content:
-        is_correct = True
-    elif INCORRECT_PROMPT in user_content:
-        is_correct = False
-    else:
-        raise ValueError("Invalid prompt")
-    item["messages"] += mk_critic_verify(is_correct)
-    item["eval_result"] = is_correct
+    if mode == "train":
+        item["messages"] += mk_critic_verify(item["eval_result"])
     return item
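With `eval_result` and `problem_id` now attached when the prompt is built, `convert_cov_to_cov_dataset` no longer has to re-derive correctness by searching the user turn for `CORRECT_PROMPT`/`INCORRECT_PROMPT`. Roughly what a finished train-mode item looks like (all values illustrative):

```python
item = {
    "messages": [
        {"role": "user", "content": "<problem statement>"},
        {"role": "assistant", "content": "<candidate solution code>"},
        {"role": "user", "content": "<COV_PROMPT, anchor stripped>"},
        {"role": "assistant", "content": "<step-by-step verification>"},
        # train mode only: mk_critic_verify(True) turns are appended here
    ],
    "eval_result": True,
    "problem_id": "p0001",  # made-up id format
}
```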
codecritic/data/utils.py
@@ -2,21 +2,16 @@
 # while llm should appear in even positions.

 from codecritic.utils.json import save_jsonl
 from pathlib import Path

-def mk_messages(messages):
-    return {"messages": messages}

 def mk_message(user, assistant):
     return [
         {"role": "user", "content": user},
         {"role": "assistant", "content": assistant},
     ]

+# TODO This function can be removed
 def save_jsonl_dataset(dataset, output_dir, split="train"):
     output_dir = Path(output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
......

codecritic/utils/json.py
 import json
+from pathlib import Path

+def ensure_parent(path):
+    path = Path(path)
+    path.parent.mkdir(parents=True, exist_ok=True)

 def load_jsonl(file_path):
     with open(file_path, "r", encoding="utf-8") as f:
@@ -12,11 +16,13 @@ def load_json(file_path):
 def save_jsonl(data, file_path):
+    ensure_parent(file_path)
     with open(file_path, "w", encoding="utf-8") as f:
         for item in data:
             f.write(json.dumps(item) + "\n")

 def save_json(data, file_path, indent=None):
+    ensure_parent(file_path)
     with open(file_path, "w", encoding="utf-8") as f:
         json.dump(data, f, indent=indent)
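`ensure_parent` means callers of `save_jsonl`/`save_json` no longer need to pre-create output directories. For example (the path below is made up):

```python
from codecritic.utils.json import save_jsonl

# Before this commit this raised FileNotFoundError when out/cov/ was missing;
# ensure_parent now creates the directory tree before writing.
save_jsonl([{"problem_id": "p0001"}], "out/cov/train.jsonl")
```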