Commit 2ab59c74 by nzy

step2: prepare_preference_dataset — create a preference dataset based on maximum edit distance.

parent 348877f0
...@@ -12,6 +12,13 @@ max_new_tokens = 512 ...@@ -12,6 +12,13 @@ max_new_tokens = 512
[evaluate] [evaluate]
evaluate_result_path = "" evaluate_result_path = ""
[dataset]
train_path = "" train_path = ""
test_path = "" test_path = ""
minimal_test_path = "" minimal_test_path = ""
\ No newline at end of file
[preference_dataset.max_edit_distance]
metadata_path = ""
preference_dataset_path = ""
dataset_info_path = ""
\ No newline at end of file
--- ---
title: "Code Critic" title: "Code Critic"
format: typst format: typst
bibliography: refs.bib
--- ---
## Abstract ## Abstract
...@@ -25,8 +26,14 @@ python step1_sort_split_dataset.py ...@@ -25,8 +26,14 @@ python step1_sort_split_dataset.py
### Step2 Prepare preference code pairs ### Step2 Prepare preference code pairs
#### Preference Dataset for Reward Model
Select pairs based on maximum edit distance, as described in [@shen2024policyfiltrationrlhffinetune; @pal2024smaugfixingfailuremodes].
### Step3 Train ORM & Critic Model ### Step3 Train ORM & Critic Model
ORM training follows [@Ouyang2022TrainingLM]
### Step4 Evaluate ORM & Critic Model ### Step4 Evaluate ORM & Critic Model
......
@article{Ouyang2022TrainingLM,
title={Training language models to follow instructions with human feedback},
author={Long Ouyang and Jeff Wu and Xu Jiang and Diogo Almeida and Carroll L. Wainwright and Pamela Mishkin and Chong Zhang and Sandhini Agarwal and Katarina Slama and Alex Ray and John Schulman and Jacob Hilton and Fraser Kelton and Luke E. Miller and Maddie Simens and Amanda Askell and Peter Welinder and Paul Francis Christiano and Jan Leike and Ryan J. Lowe},
journal={ArXiv},
year={2022},
volume={abs/2203.02155},
url={https://arxiv.org/abs/2203.02155}
}
@misc{pal2024smaugfixingfailuremodes,
title={Smaug: Fixing Failure Modes of Preference Optimisation with DPO-Positive},
author={Arka Pal and Deep Karkhanis and Samuel Dooley and Manley Roberts and Siddartha Naidu and Colin White},
year={2024},
eprint={2402.13228},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2402.13228},
}
@misc{shen2024policyfiltrationrlhffinetune,
title={Policy Filtration in RLHF to Fine-Tune LLM for Code Generation},
author={Wei Shen and Chuheng Zhang},
year={2024},
eprint={2409.06957},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2409.06957},
}
\ No newline at end of file
...@@ -62,8 +62,8 @@ if __name__ == "__main__": ...@@ -62,8 +62,8 @@ if __name__ == "__main__":
cfg = read_config cfg = read_config
sort_and_split_dataset( sort_and_split_dataset(
cfg["evaluate"]["evaluate_result_path"], cfg["evaluate"]["evaluate_result_path"],
cfg["evaluate"]["train_path"], cfg["dataset"]["train_path"],
cfg["evaluate"]["test_path"], cfg["dataset"]["test_path"],
cfg["evaluate"]["minimal_test_path"], cfg["dataset"]["minimal_test_path"],
cfg["sample"]["sampling_params"]["n"] cfg["sample"]["sampling_params"]["n"]
) )
\ No newline at end of file
from utils import load_jsonl, save_jsonl, save_json, extract_code, read_config
from utils_preference_dataset import mk_dataset_info, mk_preference_pair
from nltk.metrics.distance import edit_distance
from collections import defaultdict
from itertools import product
def mk_problem_groups(train_dataset_path, n):
    """Group every n consecutive train samples into one per-problem record.

    The train file is assumed to be sorted so that the n samples of each
    problem are contiguous (asserted below). Problems where every sample
    passed or every sample failed are dropped, since they cannot yield a
    (correct, incorrect) preference pair.

    Returns a list of dicts with problem_id, instruction, correct_codes
    and incorrect_codes.
    """
    samples = load_jsonl(train_dataset_path)
    assert len(samples) % n == 0
    grouped = []
    for start in range(0, len(samples), n):
        chunk = samples[start : start + n]
        pid = chunk[0]["problem_id"]
        results = [entry["eval_result"] for entry in chunk]
        # Skip all-passed / all-failed problems: no pair can be formed.
        if not (True in results and False in results):
            continue
        instruction = chunk[0]["messages"][0]["content"]
        passed, failed = [], []
        for entry in chunk:
            assert entry["problem_id"] == pid, "dataset is not sorted"
            code = extract_code(entry["messages"][1]["content"])
            (passed if entry["eval_result"] else failed).append(code)
        grouped.append(
            dict(
                problem_id=pid,
                instruction=instruction,
                correct_codes=passed,
                incorrect_codes=failed,
            )
        )
    return grouped
def max_edit_distance(problems, k, n):
    """Select up to k preference pairs with the maximum edit distance.

    Every (correct, incorrect) code pair of each problem is scored by the
    edit distance of transforming the incorrect code into the correct one.
    Pairs are taken greedily from the highest distance down, subject to:
      * at most k pairs overall,
      * each problem contributes at most n pairs,
      * each code snippet is used in at most one selected pair
        (per problem).

    Returns (preference_pairs, pairs_metadata); pairs_metadata holds the
    problem_id of each selected pair, index-aligned with preference_pairs.
    """
    all_pairs = []
    for problem in problems:
        for correct, incorrect in product(problem["correct_codes"], problem["incorrect_codes"]):
            # Direction matters for documentation only; nltk's edit_distance
            # is symmetric, the intent is "incorrect -> correct".
            distance = edit_distance(incorrect, correct)
            all_pairs.append(
                (distance, problem["problem_id"], problem["instruction"], (correct, incorrect))
            )
    # Stable sort: ties keep input order, so output is deterministic.
    all_pairs.sort(reverse=True, key=lambda x: x[0])

    preference_pairs, pairs_metadata = [], []
    code_usages = defaultdict(set)
    problem_contributions = defaultdict(int)
    for distance, problem_id, instr, pair in all_pairs:
        # BUGFIX: was "> k", which collected k + 1 pairs before breaking.
        if len(preference_pairs) >= k:
            break
        is_code_used = (pair[0] in code_usages[problem_id]) or (pair[1] in code_usages[problem_id])
        if not is_code_used and problem_contributions[problem_id] < n:
            code_usages[problem_id].update(pair)
            problem_contributions[problem_id] += 1
            preference_pairs.append(mk_preference_pair(instr, pair[0], pair[1]))
            pairs_metadata.append(problem_id)
    return preference_pairs, pairs_metadata
if __name__ == "__main__":
    cfg = read_config()
    # Group the sampled train split into per-problem records; n is the
    # number of completions sampled per problem in the sampling step.
    problems = mk_problem_groups(
        cfg["dataset"]["train_path"], cfg["sample"]["sampling_params"]["n"]
    )
    # Keep at most 10k pairs overall, and at most 5 pairs per problem.
    preference_pairs, metadata = max_edit_distance(problems, 10 * 1000, 5)
    # NOTE(review): "prefrence" is misspelled; it propagates into the
    # dataset name and file_name, so confirm downstream consumers before
    # renaming.
    dataset_info = mk_dataset_info("apps_max_edit_distance_prefrence")
    save_jsonl(metadata, cfg["preference_dataset"]["max_edit_distance"]["metadata_path"])
    save_jsonl(preference_pairs, cfg["preference_dataset"]["max_edit_distance"]["preference_dataset_path"])
    save_json(dataset_info, cfg["preference_dataset"]["max_edit_distance"]["dataset_info_path"])
...@@ -27,6 +27,10 @@ def save_jsonl(data, file_path): ...@@ -27,6 +27,10 @@ def save_jsonl(data, file_path):
for item in data: for item in data:
f.write(json.dumps(item) + "\n") f.write(json.dumps(item) + "\n")
def save_json(data, file_path):
    """Serialize *data* as UTF-8 JSON to *file_path*.

    ensure_ascii=False keeps non-ASCII text readable in the output file
    (the file is already opened with utf-8 encoding) instead of
    \\uXXXX-escaping it; indent=2 makes the file diff-friendly.
    """
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL) codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)
......
def mk_dataset_info(dataset_name):
    """Build a dataset_info entry for a ShareGPT-format ranking
    (preference) dataset stored as ``<dataset_name>.json``."""
    columns = {
        "messages": "messages",
        "chosen": "chosen",
        "rejected": "rejected",
    }
    tags = {
        "role_tag": "role",
        "content_tag": "content",
        "user_tag": "user",
        "assistant_tag": "assistant",
        "system_tag": "system",
    }
    return {
        dataset_name: {
            "file_name": f"{dataset_name}.json",
            "formatting": "sharegpt",
            "ranking": True,
            "columns": columns,
            "tags": tags,
        }
    }
def mk_preference_pair(instruction, chosen_code, rejected_code):
    """Assemble one ShareGPT-style preference record: a single user turn
    plus a chosen and a rejected assistant answer."""
    user_turn = {"role": "user", "content": instruction}
    chosen = {"role": "assistant", "content": chosen_code}
    rejected = {"role": "assistant", "content": rejected_code}
    return {"messages": [user_turn], "chosen": chosen, "rejected": rejected}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment