Commit 984c7815 by nanziyuan

*DONE* This is the commit used to run mk_cov_from_api.

parent 3a9fce5b
import argparse import argparse
from itertools import chain from itertools import chain
import os import os
from tqdm import tqdm from tqdm.contrib.concurrent import thread_map
from threading import Lock
from openai import OpenAI from openai import OpenAI
import json
from codecritic.data.cov import ( from codecritic.data.cov import (
convert_preference_to_vot_prompt, convert_preference_to_vot_prompt,
...@@ -12,34 +14,46 @@ from codecritic.data.utils import save_jsonl_dataset ...@@ -12,34 +14,46 @@ from codecritic.data.utils import save_jsonl_dataset
from codecritic.utils.json import load_json from codecritic.utils.json import load_json
# Module-level API client pointed at the DeepSeek endpoint; worker() reads it
# as a global. os.environ[...] raises KeyError if DEEPSEEK_API_KEY is unset.
client = OpenAI(
base_url="https://api.deepseek.com/",
api_key=os.environ["DEEPSEEK_API_KEY"]
)
def worker(args):
    """Request one completion and append the finished conversation to tmp.jsonl.

    Args:
        args: A ``(index, cov_prompt, lock)`` tuple. ``cov_prompt`` is a dict
            whose ``"messages"`` list is sent to the chat API; ``lock`` is
            held while the output file is written.

    Returns:
        The same ``cov_prompt`` dict, mutated in place: the assistant reply is
        appended to ``"messages"`` and ``"index"`` is set to ``index``.
    """
    position, prompt, file_lock = args
    history = prompt["messages"]

    response = client.chat.completions.create(
        model="deepseek-coder",
        messages=history,
        temperature=0,
        max_tokens=2048,
    )
    reply = response.choices[0].message.content
    history.append({"role": "assistant", "content": reply})

    # Hold the lock for the whole open/write so lines from concurrent
    # workers are never interleaved in the checkpoint file.
    with file_lock, open("tmp.jsonl", "a", encoding="utf-8") as sink:
        prompt["index"] = position
        sink.write(json.dumps(prompt) + '\n')

    return prompt
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)
parser.add_argument("--preference_dataset", type=str) parser.add_argument("--preference_dataset", type=str)
parser.add_argument("--output_dir", type=str) parser.add_argument("--output_dir", type=str)
args = parser.parse_args() args = parser.parse_args()
preference_dataset = load_json(args.preference_dataset) preference_dataset = load_json(args.preference_dataset)
cov_prompts = [convert_preference_to_vot_prompt(x) for x in preference_dataset] cov_prompts = [convert_preference_to_vot_prompt(x, splitter=False) for x in preference_dataset]
cov_prompts = list(chain(*cov_prompts)) cov_prompts = list(chain(*cov_prompts))
lock = Lock()
total_len = len(cov_prompts)
thread_args = list(zip(range(total_len),
cov_prompts,
[lock] * total_len))
client = OpenAI( covs = thread_map(worker, thread_args, max_workers=8)
base_url="https://api.deepseek.com/",
api_key=os.environ["DEEPSEEK_API_KEY"]
)
covs = []
for cov_prompt in tqdm(cov_prompts):
completion = client.chat.completions.create(
model="deepseek-chat",
messages=cov_prompt["messages"],
temperature=0,
max_tokens=2048
)
content = completion.choices[0].message.content
cov_prompt["messages"].append({"role": "assistant", "content": content})
covs.append(cov_prompt)
dataset = list(map(convert_cov_to_cov_dataset, covs)) dataset = list(map(convert_cov_to_cov_dataset, covs))
save_jsonl_dataset(dataset, args.output_dir) save_jsonl_dataset(dataset, args.output_dir)
\ No newline at end of file
...@@ -35,19 +35,21 @@ CORRECT_PROMPT = "Your code is correct." ...@@ -35,19 +35,21 @@ CORRECT_PROMPT = "Your code is correct."
INCORRECT_PROMPT = "Your code is incorrect." INCORRECT_PROMPT = "Your code is incorrect."
def mk_cov_prompt(is_correct): def mk_cov_prompt(is_correct, splitter):
prompt1 = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT prompt1 = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT
return [ turn1 = {"role": "user", "content": prompt1 + "\n" + COV_PROMPT + "\n" + COV_EXAMPLE}
{"role": "user", "content": prompt1 + "\n" + COV_PROMPT + "\n" + COV_EXAMPLE}, if splitter:
{ turn2 = {
"role": "assistant", "role": "assistant",
"content": "Here's a step-by-step verification of the code." + SPLITTER, "content": "Here's a step-by-step verification of the code." + SPLITTER,
}, }
] return [turn1, turn2]
else:
return [turn1]
def convert_preference_to_vot_prompt(item): def convert_preference_to_vot_prompt(item, splitter=True):
message = item["messages"][0]["content"] message = item["messages"][0]["content"]
chosen = item["chosen"]["content"] chosen = item["chosen"]["content"]
rejected = item["rejected"]["content"] rejected = item["rejected"]["content"]
...@@ -55,8 +57,8 @@ def convert_preference_to_vot_prompt(item): ...@@ -55,8 +57,8 @@ def convert_preference_to_vot_prompt(item):
chosen = code_template.format(extract_code(chosen)) chosen = code_template.format(extract_code(chosen))
rejected = code_template.format(extract_code(rejected)) rejected = code_template.format(extract_code(rejected))
messages1 = mk_message(message, chosen) + mk_cov_prompt(True) messages1 = mk_message(message, chosen) + mk_cov_prompt(True, splitter)
messages2 = mk_message(message, rejected) + mk_cov_prompt(False) messages2 = mk_message(message, rejected) + mk_cov_prompt(False, splitter)
return mk_messages(messages1), mk_messages(messages2) return mk_messages(messages1), mk_messages(messages2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment