Commit 79b13ce2 by nzy

select preference pairs

parent 61491762
import argparse
from collections import defaultdict
from itertools import product, chain

from codecritic.utils.json import load_jsonl, save_jsonl


def mk_preference_pair(ds, pair):
    # Build a (prompt, chosen, rejected) record for reward-model training.
    task_id = pair["task_id"]
    chosen = ds[task_id][pair["chosen"]]
    rejected = ds[task_id][pair["rejected"]]
    return {
        "messages": chosen["messages"][:1],
        "chosen": chosen["messages"][1:],
        "rejected": rejected["messages"][1:],
        "meta_pairinfo": pair
    }
def mk_sft(ds, pair):
    # Build one SFT record per solution (chosen and rejected).
    dataset_name = pair["dataset"]
    task_id = pair["task_id"]
    chosen = ds[task_id][pair["chosen"]]
    rejected = ds[task_id][pair["rejected"]]
    # TODO add judgement response
    return [
        {
            "question": chosen["messages"][:1],
            "response": chosen["messages"][1:],
            "dataset": dataset_name,
            "task_id": task_id,
            "solution_id": pair["chosen"]  # store the id, not the full record
        },
        {
            "question": rejected["messages"][:1],
            "response": rejected["messages"][1:],
            "dataset": dataset_name,
            "task_id": task_id,
            "solution_id": pair["rejected"]  # store the id, not the full record
        }
    ]
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, help="path/to/dataset")
    parser.add_argument("--pairs", type=str, help="path/to/selected_pairs")
    parser.add_argument("--format", type=str, choices=["sft", "reward"], help="output format")
    parser.add_argument("--output", type=str, help="path/to/output")
    args = parser.parse_args()

    # Index solutions by task_id and solution_id so pairs can look them up.
    dataset = load_jsonl(args.dataset)
    ds = defaultdict(dict)
    for item in dataset:
        ds[item["task_id"]][item["solution_id"]] = item

    selected_pairs = load_jsonl(args.pairs)

    if args.format == "sft":
        sft_ds = list(chain.from_iterable([mk_sft(ds, pair) for pair in selected_pairs]))
        save_jsonl(sft_ds, args.output)
    elif args.format == "reward":
        reward_ds = [mk_preference_pair(ds, pair) for pair in selected_pairs]
        save_jsonl(reward_ds, args.output)
    else:
        raise NotImplementedError(f"Unknown format: {args.format}")
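For orientation, a minimal sketch of the record shapes this script appears to expect, inferred from how `mk_sft` and `mk_preference_pair` index the data; the concrete keys and values below are hypothetical, not taken from the real JSONL files.

```python
# Hypothetical example records (shapes inferred from the code above).
solution_record = {          # one line of --dataset
    "task_id": "task_0",
    "solution_id": 3,
    "messages": [
        {"role": "user", "content": "Write a function that adds two numbers."},
        {"role": "assistant", "content": "def add(a, b):\n    return a + b"},
    ],
}

pair_record = {              # one line of --pairs
    "dataset": "example_dataset",
    "task_id": "task_0",
    "chosen": 3,             # solution_id of the accepted solution
    "rejected": 7,           # solution_id of the rejected solution
    "similarity": 42.0,      # fuzzy-match score used during pair selection
}
```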
@@ -7,7 +7,6 @@ import re
from tqdm.contrib.concurrent import process_map
from rapidfuzz import fuzz
from codecritic.utils.json import load_jsonl, save_jsonl
@@ -75,48 +74,10 @@ def compute_pair_similarity(group):
    return results
def mk_sft(ds, pair):
    dataset_name = pair["dataset"]
    task_id = pair["task_id"]
    chosen = ds[task_id][pair["chosen"]]
    rejected = ds[task_id][pair["rejected"]]
    # TODO add judgement response
    return [
        {
            "question": chosen["messages"][:1],
            "response": chosen["messages"][1:],
            "dataset": dataset_name,
            "task_id": task_id,
            "solution_id": chosen
        },
        {
            "question": rejected["messages"][:1],
            "response": rejected["messages"][1:],
            "dataset": dataset_name,
            "task_id": task_id,
            "solution_id": rejected
        }
    ]
def mk_preference_pair(ds, pair):
    task_id = pair["task_id"]
    chosen = ds[task_id][pair["chosen"]]
    rejected = ds[task_id][pair["rejected"]]
    return {
        "messages": chosen["messages"][:1],
        "chosen": chosen["messages"][1:],
        "rejected": rejected["messages"][1:],
        "meta_pairinfo": pair
    }
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, help="path/to/dataset")
    parser.add_argument("--output", type=str, help="path/to/output")
    parser.add_argument("--format", type=str, choices=["sft", "reward"], help="output format")
    args = parser.parse_args()

    dataset = load_jsonl(args.dataset)
@@ -141,15 +102,16 @@ if __name__ == "__main__":
    for item in dataset:
        ds[item["task_id"]][item["solution_id"]] = item

    pairinfo = []
    sorted_pairinfo = sorted(pairinfo, key=lambda x: x["similarity"])

    task_groups = defaultdict(list)
    for item in pairinfo:
        task_groups[item["task_id"]].append(item)

    # Step 2: Select the 4 pairs with the smallest score for each task
    selected_pairs = []
    for task, items in task_groups.items():
        # Sort items for this task by score and select the top 4
        sorted_items = sorted(items, key=lambda x: x["similarity"])[:4]
        selected_pairs.extend(sorted_items)

    if args.format == "sft":
        sft_ds = list(chain.from_iterable([mk_sft(ds, pair) for pair in selected_pairs]))
        save_jsonl(sft_ds, args.output)
    elif args.format == "reward":
        reward_ds = [mk_preference_pair(ds, pair) for pair in selected_pairs]
        save_jsonl(reward_ds, args.output)
    else:
        raise NotImplementedError(f"Unknown format: {args.format}")
\ No newline at end of file
    save_jsonl(selected_pairs, args.output)
@@ -6,7 +6,6 @@ from transformers import AutoTokenizer
from vllm import SamplingParams
from codecritic.dataset.genrm_prompt import THINK_MESSAGE, JUDGE_MESSAGE, JUDGE_TOEKNS
from codecritic.dataset.legacy_genrm_prompt import COV_MESSAGE
from codecritic.utils.inference import generate_worker, score_worker
from codecritic.utils.parallel import model_map
from codecritic.utils.json import load_jsonl, save_jsonl
@@ -36,7 +35,7 @@ if __name__ == "__main__":
    if args.reasoning:
        for item in dataset:
            item["messages"].append(COV_MESSAGE)
            item["messages"].append(THINK_MESSAGE)

    sampling_params = SamplingParams(
        n=1,
......
from codecritic.dataset.code import extract_code, code_template
from codecritic.data.utils import SPLITTER, mk_message
from codecritic.dataset.genrm_prompt import mk_judge_response
COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
COV_EXAMPLE = """\
** Example RETURN FORMAT **
```python
def add_numbers(a, b):
return a + b
result = add_numbers(5, '10')
```
1. **Code:**
```python
def add_numbers(a, b):
return a + b
```
This defines a function `add_numbers` that takes two arguments and returns their sum. Correct.
2. **Code:**
```python
result = add_numbers(5, '10')
```
The second argument is a string (`'10'`), which will cause a TypeError when trying to add it to an integer. Incorrect.
"""
CORRECT_PROMPT = "Your code is correct."
INCORRECT_PROMPT = "Your code is incorrect."
COV_MESSAGE = {"role": "user", "content": COV_PROMPT}
def mk_cov_prompt(is_correct, splitter, mode):
    if mode == "train":
        anchor = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT
    elif mode == "test":
        anchor = ""
    else:
        raise ValueError(f"Invalid mode: {mode}. Expected 'train' or 'test'.")

    turn1 = {"role": "user", "content": '\n'.join([anchor, COV_PROMPT, COV_EXAMPLE])}

    if splitter:
        turn2 = {
            "role": "assistant",
            "content": "Here's a step-by-step verification of the code." + SPLITTER,
        }
        return [turn1, turn2]
    else:
        return [turn1]
def convert_preference_to_vot_prompt(item, splitter, mode):
    message = item["messages"][0]["content"]
    chosen = item["chosen"]["content"]
    rejected = item["rejected"]["content"]

    chosen = code_template.format(extract_code(chosen))
    rejected = code_template.format(extract_code(rejected))

    messages1 = mk_message(message, chosen) + mk_cov_prompt(True, splitter, mode)
    messages2 = mk_message(message, rejected) + mk_cov_prompt(False, splitter, mode)

    return (
        {"messages": messages1, "eval_result": True, "problem_id": item["problem_id"]},
        {"messages": messages2, "eval_result": False, "problem_id": item["problem_id"]}
    )
def convert_sft_to_vot_prompt(item, splitter, mode):
    question = item["messages"][0]["content"]
    response = item["messages"][1]["content"]
    code = code_template.format(extract_code(response))
    messages = mk_message(question, code) + mk_cov_prompt(item["eval_result"], splitter, mode)
    return {"messages": messages, "eval_result": item["eval_result"], "problem_id": item["problem_id"]}
def convert_cov_to_cov_dataset(item, mode):
    item["messages"][2]["content"] = COV_PROMPT
    if mode == "train":
        item["messages"] += mk_critic_verify(item["eval_result"])
    return item
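A small usage sketch of `mk_cov_prompt` (hypothetical call, not part of the module; `SPLITTER` is whatever `codecritic.data.utils` defines):

```python
# Hypothetical usage sketch for mk_cov_prompt.
turns = mk_cov_prompt(is_correct=True, splitter=False, mode="train")
# -> a single user turn whose content is
#    "Your code is correct." + "\n" + COV_PROMPT + "\n" + COV_EXAMPLE
# With a truthy `splitter`, an assistant turn whose content ends in SPLITTER
# is appended as well, giving [turn1, turn2] instead of [turn1].
```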