Commit f3dd6691 by nanziyuan

merge conflicts

parents 9af763c6 79b13ce2
......@@ -4,7 +4,7 @@
```
pip install scikit-learn
pip install rapidfuzz
pip install nltk
```
## Evaluation
......
import argparse
from collections import defaultdict
from itertools import chain

from codecritic.utils.json import load_jsonl, save_jsonl
def mk_preference_pair(ds, pair):
task_id = pair["task_id"]
chosen = ds[task_id][pair["chosen"]]
rejected = ds[task_id][pair["rejected"]]
return {
"messages": chosen["messages"][:1],
"chosen": chosen["messages"][1:],
"rejected": rejected["messages"][1:],
"meta_pairinfo": pair
}
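# Illustrative only: the record shape produced above (field values hypothetical):
# {
#   "messages": [<user prompt turn>],
#   "chosen":   [<assistant turns of the passing solution>],
#   "rejected": [<assistant turns of the failing solution>],
#   "meta_pairinfo": {"task_id": ..., "chosen": ..., "rejected": ..., "similarity": ...}
# }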
def mk_sft(ds, pair):
dataset_name = pair["dataset"]
task_id = pair["task_id"]
chosen = ds[task_id][pair["chosen"]]
rejected = ds[task_id][pair["rejected"]]
# TODO add judgement response
return [
{
"question": chosen["messages"][:1],
"response": chosen["messages"][1:],
"dataset": dataset_name,
"task_id": task_id,
"solution_id": chosen
},
{
"question": rejected["messages"][:1],
"response": rejected["messages"][1:],
"dataset": dataset_name,
"task_id": task_id,
"solution_id": rejected
}
]
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, help="path/to/dataset")
    parser.add_argument("--pairs", type=str, help="path/to/selected_pairs")
    parser.add_argument("--format", type=str, choices=["sft", "reward"], help="output format")
    parser.add_argument("--output", type=str, help="path/to/output")
    args = parser.parse_args()

    # index samples by task_id and solution_id so pairs can be looked up directly
    ds = defaultdict(dict)
    for item in load_jsonl(args.dataset):
        ds[item["task_id"]][item["solution_id"]] = item

    selected_pairs = load_jsonl(args.pairs)

    if args.format == "sft":
        sft_ds = list(chain.from_iterable(mk_sft(ds, pair) for pair in selected_pairs))
        save_jsonl(sft_ds, args.output)
    elif args.format == "reward":
        reward_ds = [mk_preference_pair(ds, pair) for pair in selected_pairs]
        save_jsonl(reward_ds, args.output)
    else:
        raise NotImplementedError(f"Unknown format: {args.format}")
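# Hypothetical invocation (script and file names are placeholders, not from the commit):
#   python mk_train_data.py --dataset samples.jsonl --pairs selected_pairs.jsonl \
#       --format reward --output reward_ds.jsonl

......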
import argparse
from collections import defaultdict
from itertools import product, chain
from pathlib import Path
import os
import re
from tqdm.contrib.concurrent import process_map
from rapidfuzz import fuzz
from codecritic.utils.json import load_jsonl, save_jsonl
def group_and_filter(dataset):
grouped = defaultdict(list)
for sample in dataset:
grouped[sample["task_id"]].append(sample)
    # skip tasks where every sample passed or every sample failed: no pairs possible
for task_id, group in grouped.items():
passes = {x["pass"] for x in group}
if len(passes) == 2:
yield group
# Precompile regular expressions
SINGLE_LINE_COMMENT_REGEX = re.compile(r'#.*')
MULTILINE_DOUBLE_QUOTE_REGEX = re.compile(r'^\s*""".*?"""\s*$', flags=re.DOTALL | re.MULTILINE)
MULTILINE_SINGLE_QUOTE_REGEX = re.compile(r"^\s*'''.*?'''\s*$", flags=re.DOTALL | re.MULTILINE)
def preprocess_code(code):
# Remove single-line comments
code = SINGLE_LINE_COMMENT_REGEX.sub('', code)
# Remove standalone docstrings (triple-quoted strings that are not part of an expression)
code = MULTILINE_DOUBLE_QUOTE_REGEX.sub('', code)
code = MULTILINE_SINGLE_QUOTE_REGEX.sub('', code)
# Remove blank lines
code = "\n".join([line for line in code.splitlines() if line.strip()])
return code
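# Illustrative only: given the (hypothetical) input
#   def f(x):
#       """doc"""
#       return x  # comment
# preprocess_code returns
#   def f(x):
#       return x
# Note the regexes are heuristic: a '#' inside a string literal is stripped as well.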
def compute_pair_similarity(group):
    # seed the incorrect set with '' so samples that are empty after preprocessing are dropped
    correct_code_set, incorrect_code_set = set(), {''}
correct_samples, incorrect_samples = [], []
assert len(group) > 0
dataset_name = group[0]["dataset"]
task_id = group[0]["task_id"]
for sample in group:
code = preprocess_code(sample["code"])
item = {
"solution_id": sample["solution_id"],
"code": code
}
if sample["pass"] and (code not in correct_code_set):
correct_samples.append(item)
elif (not sample["pass"]) and (code not in incorrect_code_set):
incorrect_samples.append(item)
results = []
for correct, incorrect in product(correct_samples, incorrect_samples):
score = fuzz.ratio(correct["code"], incorrect["code"])
results.append({
"dataset": dataset_name,
"task_id": task_id,
"chosen": correct["solution_id"],
"rejected": incorrect["solution_id"],
"similarity": score
})
return results
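# Note: rapidfuzz's fuzz.ratio returns a float in [0, 100], where higher means
# more similar. The selection step below keeps the lowest-scoring pairs, i.e.
# the chosen/rejected solutions whose code differs the most.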
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, help="path/to/dataset")
parser.add_argument("--output", type=str, help="path/to/output")
args = parser.parse_args()
    dataset = load_jsonl(args.dataset)

    cache_path = Path(args.dataset + ".pairinfo")
    if not cache_path.exists():
        # materialize the generator: process_map needs a sized iterable for its progress bar
        groups = list(group_and_filter(dataset))
        results = process_map(
            compute_pair_similarity,
            groups,
            max_workers=os.cpu_count(),
            chunksize=1,
        )
pairinfo = list(chain.from_iterable(results))
save_jsonl(pairinfo, cache_path)
else:
pairinfo = load_jsonl(cache_path)
print(f"load cached similarity information from {cache_path}")
    # select pairs: for each task, keep the 4 pairs with the lowest similarity,
    # i.e. the chosen/rejected solutions whose code differs the most
    task_groups = defaultdict(list)
    for item in pairinfo:
        task_groups[item["task_id"]].append(item)

    selected_pairs = []
    for task, items in task_groups.items():
        sorted_items = sorted(items, key=lambda x: x["similarity"])[:4]
        selected_pairs.extend(sorted_items)

    save_jsonl(selected_pairs, args.output)
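# Hypothetical invocation (file names are placeholders, not from the commit):
#   python select_pairs.py --dataset samples.jsonl --output selected_pairs.jsonl
# The first run writes a samples.jsonl.pairinfo cache next to the dataset;
# later runs reuse it.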
......@@ -5,8 +5,7 @@ import os
 from transformers import AutoTokenizer
 from vllm import SamplingParams
-from codecritic.dataset.genrm_prompt import JUDGE_MESSAGE, JUDGE_TOEKNS
-from codecritic.dataset.legacy_genrm_prompt import COV_MESSAGE
+from codecritic.dataset.genrm_prompt import THINK_MESSAGE, JUDGE_MESSAGE, JUDGE_TOEKNS
 from codecritic.utils.inference import generate_worker, score_worker
 from codecritic.utils.parallel import model_map
 from codecritic.utils.json import load_jsonl, save_jsonl
......@@ -36,7 +35,7 @@ if __name__ == "__main__":
     if args.reasoning:
         for item in dataset:
-            item["messages"].append(COV_MESSAGE)
+            item["messages"].append(THINK_MESSAGE)
     sampling_params = SamplingParams(
         n=1,
......
from codecritic.utils.json import load_jsonl
from codecritic.dataset.code import extract_code, code_template
from nltk.metrics.distance import edit_distance
from collections import defaultdict
from itertools import product, chain
import multiprocessing
from tqdm.contrib.concurrent import process_map
def mk_preference_pair(instruction, chosen_code, rejected_code):
return {
"messages": [
{"role": "user", "content": instruction},
],
"chosen": {"role": "assistant", "content": code_template.format(chosen_code)},
"rejected": {
"role": "assistant",
"content": code_template.format(rejected_code),
},
}
def mk_problem_groups(train_dataset_path, n):
train_dataset = load_jsonl(train_dataset_path)
assert len(train_dataset) % n == 0
problems = []
for i in range(len(train_dataset) // n):
problem = train_dataset[i * n : (i + 1) * n]
problem_id = problem[0]["problem_id"]
eval_results = [d["eval_result"] for d in problem]
        # keep only problems that have both passing and failing solutions
if True in eval_results and False in eval_results:
instruction = problem[0]["messages"][0]["content"]
correct_codes, incorrect_codes = [], []
for d in problem:
assert d["problem_id"] == problem_id, "dataset is not sorted"
code = extract_code(d["messages"][1]["content"])
if d["eval_result"]:
correct_codes.append(code)
else:
incorrect_codes.append(code)
problems.append(
dict(
problem_id=problem_id,
instruction=instruction,
correct_codes=correct_codes,
incorrect_codes=incorrect_codes,
)
)
return problems
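# Illustrative only: mk_problem_groups assumes the jsonl file holds exactly n
# consecutive rows per problem, already grouped by problem_id, e.g. for n = 2:
#   {"problem_id": 0, "eval_result": true,  "messages": [...]}
#   {"problem_id": 0, "eval_result": false, "messages": [...]}
#   {"problem_id": 1, ...}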
def calculate_edit_distances_for_problem(problem):
local_pairs = []
for pair in product(problem["correct_codes"], problem["incorrect_codes"]):
        # edit distance to transform the incorrect code (pair[1]) into the correct code (pair[0])
        distance = edit_distance(pair[1], pair[0])
local_pairs.append(
(distance, problem["problem_id"], problem["instruction"], pair)
)
return local_pairs
def calculate_edit_distances(problems):
cpu_num = multiprocessing.cpu_count()
results = process_map(
calculate_edit_distances_for_problem,
problems,
max_workers=cpu_num,
chunksize=32,
)
return list(chain.from_iterable(results))
def mk_edit_distance_dataset(all_pairs, k, n, is_max=True):
"""
Top-k pairs with the maximum/minimum edit distance.
Each problem can contribute no more than n pairs.
Each code snippet can be used only once.
"""
all_pairs.sort(reverse=is_max, key=lambda x: x[0])
code_usages = defaultdict(set)
problem_contributions = defaultdict(int)
preference_pairs, pairs_metadata = [], []
for distance, problem_id, instr, pair in all_pairs:
if len(preference_pairs) >= k:
break
is_code_used = (pair[0] in code_usages[problem_id]) or (
pair[1] in code_usages[problem_id]
)
if not is_code_used and problem_contributions[problem_id] < n:
code_usages[problem_id].update(pair)
problem_contributions[problem_id] += 1
preference_pairs.append(mk_preference_pair(instr, pair[0], pair[1]))
pairs_metadata.append(dict(problem_id=problem_id, edit_distance=distance))
return preference_pairs, pairs_metadata
\ No newline at end of file
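For context, a minimal sketch of how the edit-distance helpers above chain together; the file name and the `k`/`n` values are assumptions, not taken from the commit:

```python
# Hypothetical end-to-end usage of the edit-distance pipeline above.
problems = mk_problem_groups("train_samples.jsonl", n=16)  # assumes 16 samples per problem
all_pairs = calculate_edit_distances(problems)

# Keep the 1000 most dissimilar pairs, at most 4 per problem,
# each code snippet used at most once.
pairs, metadata = mk_edit_distance_dataset(all_pairs, k=1000, n=4, is_max=True)
```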