Commit 3a9fce5b by nanziyuan

Merge branch 'main' of http://62.234.201.16/nzy/codecritic

parents d447a68a 5c43bbaa
import argparse
from itertools import chain

from codecritic.data.cov import (
    convert_preference_to_vot_prompt,
    convert_cov_to_cov_dataset,
)
from codecritic.utils.json import load_json
from codecritic.data.utils import save_jsonl_dataset
from codecritic.utils.vllm import vllm_chatcomplete


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--preference_dataset", type=str)
    parser.add_argument("--output_dir", type=str)
    args = parser.parse_args()

    preference_dataset = load_json(args.preference_dataset)
    cov_prompts = [convert_preference_to_vot_prompt(x) for x in preference_dataset]
    cov_prompts = list(chain(*cov_prompts))

    sampling_params = dict(n=1, temperature=0.0, max_tokens=2048)
    covs = vllm_chatcomplete(args.model, cov_prompts, sampling_params)

    dataset = list(map(convert_cov_to_cov_dataset, covs))
    save_jsonl_dataset(dataset, args.output_dir)
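# Illustrative invocation (script name and paths are hypothetical, not from the commit):
#   python gen_cov_vllm.py --model /path/to/model \
#       --preference_dataset preference.json --output_dir data/cov_vllm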
import argparse
from itertools import chain
import os

from tqdm import tqdm
from openai import OpenAI

from codecritic.data.cov import (
    convert_preference_to_vot_prompt,
    convert_cov_to_cov_dataset,
)
from codecritic.data.utils import save_jsonl_dataset
from codecritic.utils.json import load_json


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)  # NOTE: unused below; the API model is hard-coded to "deepseek-chat"
    parser.add_argument("--preference_dataset", type=str)
    parser.add_argument("--output_dir", type=str)
    args = parser.parse_args()

    preference_dataset = load_json(args.preference_dataset)
    cov_prompts = [convert_preference_to_vot_prompt(x) for x in preference_dataset]
    cov_prompts = list(chain(*cov_prompts))

    client = OpenAI(
        base_url="https://api.deepseek.com/",
        api_key=os.environ["DEEPSEEK_API_KEY"],
    )

    covs = []
    for cov_prompt in tqdm(cov_prompts):
        completion = client.chat.completions.create(
            model="deepseek-chat",
            messages=cov_prompt["messages"],
            temperature=0,
            max_tokens=2048,
        )
        content = completion.choices[0].message.content
        cov_prompt["messages"].append({"role": "assistant", "content": content})
        covs.append(cov_prompt)

    dataset = list(map(convert_cov_to_cov_dataset, covs))
    save_jsonl_dataset(dataset, args.output_dir)
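# Illustrative invocation (script name and paths are hypothetical); the script
# reads its API key from the DEEPSEEK_API_KEY environment variable:
#   DEEPSEEK_API_KEY=sk-... python gen_cov_deepseek.py \
#       --preference_dataset preference.json --output_dir data/cov_api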
import argparse
from pathlib import Path

from codecritic.utils.json import load_json
from codecritic.data.utils import save_jsonl_dataset
from codecritic.data.edit_distance import (
    mk_problem_groups,
    calculate_edit_distances,
    mk_edit_distance_dataset,
)


def str2bool(v):
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so parse the flag explicitly.
    return str(v).lower() in ("1", "true", "yes")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--is_max", type=str2bool, required=True)
    args = parser.parse_args()

    dataset_dir = Path(args.dataset_dir)
    train_path = dataset_dir / "train.jsonl"
    sampling_params = load_json(dataset_dir / "sampling_params.json")

    problems = mk_problem_groups(train_path, sampling_params["n"])
    all_edit_distance_pairs = calculate_edit_distances(problems)

    postfix = "max" if args.is_max else "min"
    dataset_name = f"apps_edit_distance_{postfix}"  # currently unused

    preference_pairs, metadata = mk_edit_distance_dataset(
        all_edit_distance_pairs, 10 * 1000, 5, is_max=args.is_max
    )
    save_jsonl_dataset(preference_pairs, args.output_dir)
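# Illustrative invocation (script name and paths are hypothetical); with the
# str2bool helper above, --is_max accepts an explicit true/false value:
#   python mk_edit_distance_dataset.py --dataset_dir data/apps \
#       --output_dir data/apps_edit_distance_max --is_max true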
@@ -7,8 +7,8 @@ from transformers import AutoTokenizer
 import pprint
 from pathlib import Path
+from codecritic.data.code import code_template, extract_code
 from codecritic.utils.json import load_jsonl, save_jsonl
-from codecritic.utils.data import extract_code, code_template
 from codecritic.utils.metric import group_results, score_pass_at_k
...
@@ -2,12 +2,15 @@ import argparse
 from pathlib import Path
 import pprint
+from codecritic.data.utils import mk_message
+from codecritic.data.verify import JUDGE_PROMPT
 from transformers import AutoTokenizer
+from codecritic.data.code import extract_code, code_template
 from codecritic.data.cov import COV_PROMPT
+from codecritic.data.verify import get_score_token_id
 from codecritic.utils.vllm import vllm_chatcomplete, vllm_score
 from codecritic.utils.json import load_jsonl, save_jsonl
-from codecritic.utils.data import extract_code, code_template, mk_message, JUDGE_PROMPT, get_score_token_id
 from codecritic.utils.metric import group_results, score_pass_at_k
...
@@ -4,8 +4,9 @@
 # 2. Using SFT (Supervised Fine-Tuning) directly
 # This experiment aims to fairly compare these two approaches.
 import argparse
+from codecritic.data.utils import mk_message, mk_messages, save_jsonl_dataset
 from codecritic.utils.json import load_json
-from codecritic.utils.data import mk_message, mk_critic_verify, mk_messages, save_jsonl_dataset
+from codecritic.data.verify import mk_critic_verify
 def convert_preference_to_sft(item):
...
import re

codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)

code_template = """```python
{}
```
"""


def extract_code(text: str):
    # Concatenate the bodies of all ```python ...``` blocks in the text;
    # return an empty string when no block is found.
    codes = [match.strip() for match in re.findall(codeblock_pattern, text)]
    if len(codes) > 0:
        code = "\n".join(codes)
        return code
    else:
        return ""
 # Additional Experiment:
 # Does reasoning really work? Let's verify step by step.
-import argparse
-from itertools import chain
-from codecritic.utils.json import load_json
-from codecritic.utils.data import (
-    extract_code,
-    code_template,
-    mk_message,
-    mk_messages,
-    mk_critic_verify,
-    save_jsonl_dataset,
-    SPLITTER,
-)
+from codecritic.data.code import extract_code, code_template
+from codecritic.data.utils import SPLITTER, mk_message, mk_messages
+from codecritic.data.verify import mk_critic_verify
 COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
@@ -80,23 +71,3 @@ def convert_cov_to_cov_dataset(item):
         raise ValueError("Invalid prompt")
     item["messages"] += mk_critic_verify(is_correct)
     return item
-
-if __name__ == "__main__":
-    from codecritic.utils.vllm import vllm_chatcomplete
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str)
-    parser.add_argument("--preference_dataset", type=str)
-    parser.add_argument("--output_dir", type=str)
-    args = parser.parse_args()
-
-    preference_dataset = load_json(args.preference_dataset)
-    cov_prompts = [convert_preference_to_vot_prompt(x) for x in preference_dataset]
-    cov_prompts = list(chain(*cov_prompts))
-
-    sampling_params = dict(n=1, temperature=0.0, max_tokens=2048)
-    covs = vllm_chatcomplete(args.model, cov_prompts, sampling_params)
-
-    dataset = list(map(convert_cov_to_cov_dataset, covs))
-    save_jsonl_dataset(dataset, args.output_dir)
-import argparse
-from pathlib import Path
-from codecritic.utils.json import load_json, load_jsonl
-from codecritic.utils.data import extract_code, mk_preference_pair, save_jsonl_dataset
+from codecritic.utils.json import load_jsonl
+from codecritic.data.code import extract_code, code_template
 from nltk.metrics.distance import edit_distance
 from collections import defaultdict
 from itertools import product, chain
 import multiprocessing
 from tqdm.contrib.concurrent import process_map
+
+
+def mk_preference_pair(instruction, chosen_code, rejected_code):
+    return {
+        "messages": [
+            {"role": "user", "content": instruction},
+        ],
+        "chosen": {"role": "assistant", "content": code_template.format(chosen_code)},
+        "rejected": {
+            "role": "assistant",
+            "content": code_template.format(rejected_code),
+        },
+    }
 def mk_problem_groups(train_dataset_path, n):
     train_dataset = load_jsonl(train_dataset_path)
@@ -86,27 +96,4 @@ def mk_edit_distance_dataset(all_pairs, k, n, is_max=True):
             preference_pairs.append(mk_preference_pair(instr, pair[0], pair[1]))
             pairs_metadata.append(dict(problem_id=problem_id, edit_distance=distance))
     return preference_pairs, pairs_metadata
\ No newline at end of file
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_dir", type=str)
-    parser.add_argument("--output_dir", type=str)
-    parser.add_argument("--is_max", type=bool, required=True)
-    args = parser.parse_args()
-
-    dataset_dir = Path(args.dataset_dir)
-    train_path = dataset_dir / "train.jsonl"
-    sampling_params = load_json(dataset_dir / "sampling_params.json")
-
-    problems = mk_problem_groups(train_path, sampling_params["n"])
-    all_edit_distance_pairs = calculate_edit_distances(problems)
-
-    postfix = "max" if args.is_max else "min"
-    dataset_name = f"apps_edit_distance_{postfix}"
-
-    preference_pairs, metadata = mk_edit_distance_dataset(
-        all_edit_distance_pairs, 10 * 1000, 5, is_max=args.is_max
-    )
-    save_jsonl_dataset(preference_pairs, args.output_dir)
# Note that the human and observation should appear in odd positions
# while llm should appear in even positions.
from codecritic.utils.json import save_jsonl
from pathlib import Path


def mk_messages(messages):
    return {"messages": messages}


def mk_message(user, assistant):
    return [
        {"role": "user", "content": user},
        {"role": "assistant", "content": assistant},
    ]
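# Illustration (not in the original file): mk_messages(mk_message(q, a)) yields
# {"messages": [user_turn, assistant_turn]}, so the human/observation turn sits
# at the 1st (odd) position and the llm turn at the 2nd (even) position.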


def save_jsonl_dataset(dataset, output_dir, split="train"):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    save_jsonl(dataset, output_dir / f"{split}.jsonl")


SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"
JUDGE_PROMPT = "Is the code correct (Yes/No)?"


def mk_critic_verify(answer=None):
    # answer: bool or None
    message = [{"role": "user", "content": JUDGE_PROMPT}]
    if answer is not None:
        response = "Yes" if answer else "No"
        message.append({"role": "assistant", "content": response})
    return message


def get_score_token_id(tokenizer, token_str="Yes"):
    # token_str must map to exactly one token in the tokenizer's vocabulary.
    score_tokens = tokenizer.encode(token_str, add_special_tokens=False)
    assert len(score_tokens) == 1
    return score_tokens[0]
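# Illustrative usage (assumes a Hugging Face tokenizer; not in the original file):
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#   yes_token = get_score_token_id(tokenizer)  # id of the single "Yes" token
# vllm_score presumably reads the probability of this token when judging code.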
@@ -9,7 +9,7 @@ from tqdm.contrib.concurrent import process_map
 from codecritic.sampling.apps_test import run_test
 from codecritic.utils.json import save_jsonl
-from codecritic.utils.data import extract_code
+from codecritic.data.code import extract_code
 TIMEOUT = 10
...
import re
from codecritic.utils.json import save_jsonl
from pathlib import Path

codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)

code_template = """```python
{}
```
"""


def extract_code(text: str):
    codes = [match.strip() for match in re.findall(codeblock_pattern, text)]
    if len(codes) > 0:
        code = "\n".join(codes)
        return code
    else:
        return ""


def mk_preference_pair(instruction, chosen_code, rejected_code):
    return {
        "messages": [
            {"role": "user", "content": instruction},
        ],
        "chosen": {"role": "assistant", "content": code_template.format(chosen_code)},
        "rejected": {
            "role": "assistant",
            "content": code_template.format(rejected_code),
        },
    }


# Note that the human and observation should appear in odd positions
# while llm should appear in even positions.
def mk_messages(messages):
    return {"messages": messages}


def mk_message(user, assistant):
    return [
        {"role": "user", "content": user},
        {"role": "assistant", "content": assistant},
    ]


JUDGE_PROMPT = "Is the code correct (Yes/No)?"


def mk_critic_verify(answer=None):
    # answer: bool or None
    message = [{"role": "user", "content": JUDGE_PROMPT}]
    if answer is not None:
        response = "Yes" if answer else "No"
        message.append({"role": "assistant", "content": response})
    return message


def save_jsonl_dataset(dataset, output_dir, split="train"):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    save_jsonl(dataset, output_dir / f"{split}.jsonl")


def get_score_token_id(tokenizer, token_str="Yes"):
    score_tokens = tokenizer.encode(token_str, add_special_tokens=False)
    assert len(score_tokens) == 1
    return score_tokens[0]


def mk_critic_reason(codedit, explanation):
    user_question = {
        "role": "user",
        "content": "Edit your code in diff format to fix any issues and explain the changes.",
    }
    llm_answer_content = f"""\
**Edited Code (in diff format):**
```diff
{codedit}
```

**Explanation:**
{explanation}
"""
    llm_answer = {"role": "assistant", "content": llm_answer_content}
    return [user_question, llm_answer]


SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"
@@ -5,7 +5,7 @@ import multiprocessing
 from itertools import chain
 from functools import partial
-from codecritic.utils.data import SPLITTER
+from codecritic.data.utils import SPLITTER
 import numpy as np
 def generate_worker(cuda_device, prompts, model_path, sampling_params):
...