Commit 7c231b37 by nzy

step2 prepare_preference_dataset parallelize edit distance calculation

parent 99bd77f6
......@@ -21,4 +21,10 @@ minimal_test_path = ""
[preference_dataset.max_edit_distance]
metadata_path = ""
preference_dataset_path = ""
dataset_info_path = ""
[preference_dataset.min_edit_distance]
metadata_path = ""
preference_dataset_path = ""
dataset_info_path = ""
\ No newline at end of file
......@@ -2,7 +2,10 @@ from utils import load_jsonl, save_jsonl, save_json, extract_code, read_config
from utils_preference_dataset import mk_dataset_info, mk_preference_pair
from nltk.metrics.distance import edit_distance
from collections import defaultdict
from itertools import product
from itertools import product, chain
import multiprocessing
from tqdm.contrib.concurrent import process_map
def mk_problem_groups(train_dataset_path, n):
train_dataset = load_jsonl(train_dataset_path)
......@@ -35,36 +38,51 @@ def mk_problem_groups(train_dataset_path, n):
return problems
def max_edit_distance(problems, k, n):
def calculate_edit_distances_for_problem(problem):
    """Compute edit distances for all (correct, incorrect) code pairs of one problem.

    Args:
        problem: dict with keys "problem_id", "instruction", "correct_codes",
            "incorrect_codes" (see mk_problem_groups).

    Returns:
        List of tuples (distance, problem_id, instruction, (correct, incorrect)).
    """
    # Bind nltk's edit_distance locally under an alias: the module-level name
    # `edit_distance` is shadowed by the pair-selection function defined later
    # in this file, so using the bare name here would call the wrong function
    # (TypeError at runtime in the worker processes).
    from nltk.metrics.distance import edit_distance as levenshtein

    return [
        # distance measures transforming the incorrect code into the correct one
        (
            levenshtein(pair[1], pair[0]),
            problem["problem_id"],
            problem["instruction"],
            pair,
        )
        for pair in product(problem["correct_codes"], problem["incorrect_codes"])
    ]
def calculate_edit_distances(problems):
    """Fan the per-problem edit-distance computation out over all CPU cores.

    Args:
        problems: iterable of problem dicts (see mk_problem_groups).

    Returns:
        A single flat list of (distance, problem_id, instruction, pair) tuples
        aggregated across every problem.
    """
    worker_count = multiprocessing.cpu_count()
    per_problem_pairs = process_map(
        calculate_edit_distances_for_problem,
        problems,
        max_workers=worker_count,
        chunksize=32,
    )
    # Flatten the list-of-lists returned by the workers into one list.
    return [pair for pairs in per_problem_pairs for pair in pairs]
def edit_distance(all_pairs, k, n, is_max=True):
    """
    Select the top-k pairs with the maximum (or minimum) edit distance.

    Args:
        all_pairs: list of (distance, problem_id, instruction,
            (correct_code, incorrect_code)) tuples.
        k: maximum number of preference pairs to return.
        n: each problem can contribute no more than n pairs.
        is_max: when True rank by descending distance, otherwise ascending.

    Returns:
        (preference_pairs, pairs_metadata) where metadata entries carry the
        problem_id and the selected pair's edit distance.

    Each code snippet can be used only once per problem.

    NOTE(review): this function shadows the `edit_distance` imported from
    nltk at module level; renaming it (and its call sites in __main__) would
    avoid the collision.
    """
    # sorted() leaves the caller's list untouched -- __main__ reuses the same
    # list for both the max- and the min-distance selection.
    ranked_pairs = sorted(all_pairs, reverse=is_max, key=lambda x: x[0])
    code_usages = defaultdict(set)
    problem_contributions = defaultdict(int)
    preference_pairs, pairs_metadata = [], []
    for distance, problem_id, instr, pair in ranked_pairs:
        # `>= k` (not `> k`): the strict comparison admitted k+1 pairs.
        if len(preference_pairs) >= k:
            break
        is_code_used = (pair[0] in code_usages[problem_id]) or (
            pair[1] in code_usages[problem_id]
        )
        if not is_code_used and problem_contributions[problem_id] < n:
            code_usages[problem_id].update(pair)
            problem_contributions[problem_id] += 1
            preference_pairs.append(mk_preference_pair(instr, pair[0], pair[1]))
            pairs_metadata.append(dict(problem_id=problem_id, edit_distance=distance))
    return preference_pairs, pairs_metadata
......@@ -75,8 +93,38 @@ if __name__ == "__main__":
cfg["dataset"]["train_path"], cfg["sample"]["sampling_params"]["n"]
)
preference_pairs, metadata = max_edit_distance(problems, 10 * 1000, 5)
all_edit_distance_pairs = calculate_edit_distances(problems)
# Maximum distance
preference_pairs, metadata = edit_distance(
all_edit_distance_pairs, 10 * 1000, 5, is_max=True
)
dataset_info = mk_dataset_info("apps_max_edit_distance_prefrence")
save_jsonl(metadata, cfg["preference_dataset"]["max_edit_distance"]["metadata_path"])
save_jsonl(preference_pairs, cfg["preference_dataset"]["max_edit_distance"]["preference_dataset_path"])
save_json(dataset_info, cfg["preference_dataset"]["max_edit_distance"]["dataset_info_path"])
save_jsonl(
metadata, cfg["preference_dataset"]["max_edit_distance"]["metadata_path"]
)
save_jsonl(
preference_pairs,
cfg["preference_dataset"]["max_edit_distance"]["preference_dataset_path"],
)
save_json(
dataset_info,
cfg["preference_dataset"]["max_edit_distance"]["dataset_info_path"],
)
# Minimum distance
preference_pairs, metadata = edit_distance(
all_edit_distance_pairs, 10 * 1000, 5, is_max=False
)
dataset_info = mk_dataset_info("apps_min_edit_distance_prefrence")
save_jsonl(
metadata, cfg["preference_dataset"]["min_edit_distance"]["metadata_path"]
)
save_jsonl(
preference_pairs,
cfg["preference_dataset"]["min_edit_distance"]["preference_dataset_path"],
)
save_json(
dataset_info,
cfg["preference_dataset"]["min_edit_distance"]["dataset_info_path"],
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment