Commit ca962f5c by nanziyuan

fix bugs.

parent 40af338d
import argparse import argparse
import datetime
from itertools import chain from itertools import chain
from functools import partial
import os import os
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
from threading import Lock from threading import Lock
...@@ -11,8 +13,7 @@ from codecritic.data.cov import ( ...@@ -11,8 +13,7 @@ from codecritic.data.cov import (
convert_sft_to_vot_prompt, convert_sft_to_vot_prompt,
convert_cov_to_cov_dataset, convert_cov_to_cov_dataset,
) )
from codecritic.data.utils import save_jsonl_dataset from codecritic.utils.json import load_json, load_jsonl, save_jsonl
from codecritic.utils.json import load_json, load_jsonl
client = OpenAI( client = OpenAI(
...@@ -21,8 +22,7 @@ client = OpenAI( ...@@ -21,8 +22,7 @@ client = OpenAI(
) )
def worker(args): def worker(index, cov_prompt, lock, tmp_path):
index, cov_prompt, lock = args
completion = client.chat.completions.create( completion = client.chat.completions.create(
model="deepseek-coder", model="deepseek-coder",
messages=cov_prompt["messages"], messages=cov_prompt["messages"],
...@@ -32,7 +32,7 @@ def worker(args): ...@@ -32,7 +32,7 @@ def worker(args):
content = completion.choices[0].message.content content = completion.choices[0].message.content
cov_prompt["messages"].append({"role": "assistant", "content": content}) cov_prompt["messages"].append({"role": "assistant", "content": content})
with lock: with lock:
with open("tmp.jsonl", "a", encoding="utf-8") as f: with open(tmp_path, "a", encoding="utf-8") as f:
cov_prompt["index"] = index cov_prompt["index"] = index
f.write(json.dumps(cov_prompt) + '\n') f.write(json.dumps(cov_prompt) + '\n')
return cov_prompt return cov_prompt
...@@ -40,29 +40,30 @@ def worker(args): ...@@ -40,29 +40,30 @@ def worker(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--preference_dataset", type=str) parser.add_argument("--dataset", type=str)
parser.add_argument("--sft_dataset", type=str) parser.add_argument("--dataset_type", type=str, choices=["reward", "sft"])
parser.add_argument("--output_dir", type=str) parser.add_argument("--out", type=str)
parser.add_argument("--mode", type=str, choices=["train", "test"])
args = parser.parse_args() args = parser.parse_args()
if args.preference_dataset: if args.dataset_type == "reward":
preference_dataset = load_json(args.preference_dataset) preference_dataset = load_json(args.dataset) # TODO change to jsonl
cov_prompts = [convert_preference_to_vot_prompt(x, splitter=False) for x in preference_dataset] cov_prompts = [convert_preference_to_vot_prompt(x, splitter=False, mode=args.mode) for x in preference_dataset]
cov_prompts = list(chain(*cov_prompts)) cov_prompts = list(chain(*cov_prompts))
elif args.sft_dataset: elif args.dataset_type == "sft":
sft_dataset = load_jsonl(args.sft_dataset) sft_dataset = load_jsonl(args.dataset)
cov_prompts = [convert_sft_to_vot_prompt(x, splitter=False) for x in sft_dataset] cov_prompts = [convert_sft_to_vot_prompt(x, splitter=False, mode=args.mode) for x in sft_dataset]
else: else:
parser.error("preference_dataset or sft_dataset") parser.error("preference_dataset or sft_dataset")
lock = Lock() lock = Lock()
tmp_path = f"log_{datetime.datetime.now().strftime('%y%m%d_%H%M')}.jsonl"
total_len = len(cov_prompts) total_len = len(cov_prompts)
thread_args = list(zip(range(total_len), w = partial(worker, lock=lock, tmp_path=tmp_path)
cov_prompts,
[lock] * total_len))
covs = thread_map(worker, thread_args, max_workers=8) covs = thread_map(w, range(total_len), cov_prompts, max_workers=8)
dataset = list(map(convert_cov_to_cov_dataset, covs)) save_jsonl(covs, args.out + ".raw")
save_jsonl_dataset(dataset, args.output_dir) dataset = [convert_cov_to_cov_dataset(x, args.mode) for x in covs]
save_jsonl(dataset, args.out)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Is reasoning really work? Let's verify step by step. # Is reasoning really work? Let's verify step by step.
from codecritic.data.code import extract_code, code_template from codecritic.data.code import extract_code, code_template
from codecritic.data.utils import SPLITTER, mk_message, mk_messages from codecritic.data.utils import SPLITTER, mk_message
from codecritic.data.verify import mk_critic_verify from codecritic.data.verify import mk_critic_verify
COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue." COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
...@@ -35,10 +35,15 @@ CORRECT_PROMPT = "Your code is correct." ...@@ -35,10 +35,15 @@ CORRECT_PROMPT = "Your code is correct."
INCORRECT_PROMPT = "Your code is incorrect." INCORRECT_PROMPT = "Your code is incorrect."
def mk_cov_prompt(is_correct, splitter): def mk_cov_prompt(is_correct, splitter, mode):
prompt1 = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT if mode == "train":
anchor = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT
elif mode == "test":
anchor = ""
else:
raise ValueError(f"Invalid mode: {mode}. Expected 'train' or 'test'.")
turn1 = {"role": "user", "content": prompt1 + "\n" + COV_PROMPT + "\n" + COV_EXAMPLE} turn1 = {"role": "user", "content": '\n'.join([anchor, COV_PROMPT, COV_EXAMPLE])}
if splitter: if splitter:
turn2 = { turn2 = {
"role": "assistant", "role": "assistant",
...@@ -49,7 +54,7 @@ def mk_cov_prompt(is_correct, splitter): ...@@ -49,7 +54,7 @@ def mk_cov_prompt(is_correct, splitter):
return [turn1] return [turn1]
def convert_preference_to_vot_prompt(item, splitter=True): def convert_preference_to_vot_prompt(item, splitter, mode):
message = item["messages"][0]["content"] message = item["messages"][0]["content"]
chosen = item["chosen"]["content"] chosen = item["chosen"]["content"]
rejected = item["rejected"]["content"] rejected = item["rejected"]["content"]
...@@ -57,29 +62,26 @@ def convert_preference_to_vot_prompt(item, splitter=True): ...@@ -57,29 +62,26 @@ def convert_preference_to_vot_prompt(item, splitter=True):
chosen = code_template.format(extract_code(chosen)) chosen = code_template.format(extract_code(chosen))
rejected = code_template.format(extract_code(rejected)) rejected = code_template.format(extract_code(rejected))
messages1 = mk_message(message, chosen) + mk_cov_prompt(True, splitter) messages1 = mk_message(message, chosen) + mk_cov_prompt(True, splitter, mode)
messages2 = mk_message(message, rejected) + mk_cov_prompt(False, splitter) messages2 = mk_message(message, rejected) + mk_cov_prompt(False, splitter, mode)
return mk_messages(messages1), mk_messages(messages2) return (
{"messages": messages1, "eval_result": True, "problem_id": item["problem_id"]},
{"messages": messages2, "eval_result": False, "problem_id": item["problem_id"]}
)
def convert_sft_to_vot_prompt(item, splitter=True): def convert_sft_to_vot_prompt(item, splitter, mode):
question = item["messages"][0]["content"] question = item["messages"][0]["content"]
response = item["messages"][1]["content"] response = item["messages"][1]["content"]
code = code_template.format(extract_code(response)) code = code_template.format(extract_code(response))
messages = mk_message(question, code) + mk_cov_prompt(item["eval_result"], splitter) messages = mk_message(question, code) + mk_cov_prompt(item["eval_result"], splitter, mode)
return mk_messages(messages) return {"messages": messages, "eval_result": item["eval_result"], "problem_id": item["problem_id"]}
def convert_cov_to_cov_dataset(item): def convert_cov_to_cov_dataset(item, mode):
user_content = item["messages"][2]["content"]
item["messages"][2]["content"] = COV_PROMPT item["messages"][2]["content"] = COV_PROMPT
if CORRECT_PROMPT in user_content:
is_correct = True if mode == "train":
elif INCORRECT_PROMPT in user_content: item["messages"] += mk_critic_verify(item["eval_result"])
is_correct = False
else:
raise ValueError("Invalid prompt")
item["messages"] += mk_critic_verify(is_correct)
item["eval_result"] = is_correct
return item return item
...@@ -2,21 +2,16 @@ ...@@ -2,21 +2,16 @@
# while llm should appear in even positions. # while llm should appear in even positions.
from codecritic.utils.json import save_jsonl from codecritic.utils.json import save_jsonl
from pathlib import Path from pathlib import Path
def mk_messages(messages):
return {"messages": messages}
def mk_message(user, assistant): def mk_message(user, assistant):
return [ return [
{"role": "user", "content": user}, {"role": "user", "content": user},
{"role": "assistant", "content": assistant}, {"role": "assistant", "content": assistant},
] ]
# TODO This function can be removed
def save_jsonl_dataset(dataset, output_dir, split="train"): def save_jsonl_dataset(dataset, output_dir, split="train"):
output_dir = Path(output_dir) output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
......
import json import json
from pathlib import Path
def ensure_parent(path):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
def load_jsonl(file_path): def load_jsonl(file_path):
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
...@@ -12,11 +16,13 @@ def load_json(file_path): ...@@ -12,11 +16,13 @@ def load_json(file_path):
def save_jsonl(data, file_path): def save_jsonl(data, file_path):
ensure_parent(file_path)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
for item in data: for item in data:
f.write(json.dumps(item) + "\n") f.write(json.dumps(item) + "\n")
def save_json(data, file_path, indent=None): def save_json(data, file_path, indent=None):
ensure_parent(file_path)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=indent) json.dump(data, f, indent=indent)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment