Commit 7e4c652e by nzy

remove llamafactory & fix remaining import statements after refactoring

parent efa812c2
@@ -2,9 +2,17 @@
 # Does reasoning really work? Let's verify step by step.
 import argparse
 from itertools import chain
-from utils import load_json, extract_code, code_template
-from codecritic.utils.data import mk_message, mk_sft_item, mk_critic_verify, mk_sft_dataset_info, save_dataset, SPLITTER
+from codecritic.utils.json import load_json
+from codecritic.utils.data import (
+    extract_code,
+    code_template,
+    mk_message,
+    mk_messages,
+    mk_critic_verify,
+    save_jsonl_dataset,
+    SPLITTER,
+)
 from codecritic.utils.vllm import vllm_chatcomplete

 COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
@@ -36,11 +44,17 @@ result = add_numbers(5, '10')
 CORRECT_PROMPT = "Your code is correct."
 INCORRECT_PROMPT = "Your code is incorrect."


 def mk_cov_prompt(is_correct):
     prompt1 = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT
-    return [{"role": "user", "content": prompt1 + "\n" + COV_PROMPT + '\n' + COV_EXAMPLE},
-            {"role": "assistant", "content": "Here's a step-by-step verification of the code." + SPLITTER}]
+    return [
+        {"role": "user", "content": prompt1 + "\n" + COV_PROMPT + "\n" + COV_EXAMPLE},
+        {
+            "role": "assistant",
+            "content": "Here's a step-by-step verification of the code." + SPLITTER,
+        },
+    ]


 def convert_preference_to_vot_prompt(item):
@@ -53,7 +67,7 @@ def convert_preference_to_vot_prompt(item):
     messages1 = mk_message(message, chosen) + mk_cov_prompt(True)
     messages2 = mk_message(message, rejected) + mk_cov_prompt(False)
-    return mk_sft_item(messages1), mk_sft_item(messages2)
+    return mk_messages(messages1), mk_messages(messages2)


 def convert_cov_to_cov_dataset(item):
@@ -73,8 +87,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str)
     parser.add_argument("--preference_dataset", type=str)
-    parser.add_argument("--llamafactory", type=str)
-    parser.add_argument("--dataset_name", type=str)
+    parser.add_argument("--output_dir", type=str)
     args = parser.parse_args()

     preference_dataset = load_json(args.preference_dataset)
@@ -85,5 +98,4 @@ if __name__ == "__main__":
     covs = vllm_chatcomplete(args.model, cov_prompts, sampling_params)
     dataset = list(map(convert_cov_to_cov_dataset, covs))
-    dataset_info = mk_sft_dataset_info(args.dataset_name)
-    save_dataset(args.llamafactory, dataset_info, dataset)
+    save_jsonl_dataset(dataset, args.output_dir)
 import argparse
 from pathlib import Path

-from codecritic.utils.json import load_json, load_jsonl, save_json
-from codecritic.utils.data import extract_code, mk_preference_dataset_info, mk_preference_pair, save_dataset
+from codecritic.utils.json import load_json, load_jsonl
+from codecritic.utils.data import extract_code, mk_preference_pair, save_jsonl_dataset

 from nltk.metrics.distance import edit_distance
 from collections import defaultdict
 from itertools import product, chain
@@ -92,7 +92,7 @@ def mk_edit_distance_dataset(all_pairs, k, n, is_max=True):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset_dir", type=str)
-    parser.add_argument("--llamafactory", type=str)
+    parser.add_argument("--output_dir", type=str)
     parser.add_argument("--is_max", type=bool, required=True)
     args = parser.parse_args()
@@ -109,6 +109,4 @@ if __name__ == "__main__":
         all_edit_distance_pairs, 10 * 1000, 5, is_max=args.is_max
     )
-    dataset_info = mk_preference_dataset_info(dataset_name)
-    save_json(metadata, dataset_dir / f"{dataset_name}_metadata.json")
-    save_dataset(args.llamafactory, dataset_info, preference_pairs)
+    save_jsonl_dataset(preference_pairs, args.output_dir)
@@ -5,7 +5,7 @@
 # This experiment aims to fairly compare these two approaches.
 import argparse
 from codecritic.utils.json import load_json
-from codecritic.utils.data import mk_message, mk_critic_verify, mk_sft_item, mk_sft_dataset_info, save_dataset
+from codecritic.utils.data import mk_message, mk_critic_verify, mk_messages, save_jsonl_dataset


 def convert_preference_to_sft(item):
@@ -15,14 +15,13 @@ def convert_preference_to_sft(item):
     messages1 = mk_message(message, chosen) + mk_critic_verify(True)
     messages2 = mk_message(message, rejected) + mk_critic_verify(False)
-    return mk_sft_item(messages1), mk_sft_item(messages2)
+    return mk_messages(messages1), mk_messages(messages2)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--preference_dataset", type=str)
-    parser.add_argument("--llamafactory", type=str)
-    parser.add_argument("--dataset_name", type=str)
+    parser.add_argument("--output_dir", type=str)
     args = parser.parse_args()

     preference_dataset = load_json(args.preference_dataset)
@@ -31,5 +30,4 @@ if __name__ == "__main__":
     for item in preference_dataset:
         sft_dataset.extend(convert_preference_to_sft(item))
-    dataset_info = mk_sft_dataset_info(args.dataset_name)
-    save_dataset(args.llamafactory, dataset_info, sft_dataset)
+    save_jsonl_dataset(sft_dataset, args.output_dir)
@@ -8,7 +8,8 @@ from datasets import load_dataset
 from tqdm.contrib.concurrent import process_map

 from codecritic.sampling.apps_test import run_test
-from utils import extract_code, load_jsonl, save_jsonl
+from codecritic.utils.json import save_jsonl
+from codecritic.utils.data import extract_code

 TIMEOUT = 10
 import re

-from codecritic.utils.json import load_json, save_json
+from codecritic.utils.json import save_jsonl
+from pathlib import Path

 codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)

 code_template = """```python
@@ -32,7 +33,7 @@ def mk_preference_pair(instruction, chosen_code, rejected_code):
 # Note that the human and observation turns should appear at odd positions,
 # while the llm turns should appear at even positions.
-def mk_sft_item(messages):
+def mk_messages(messages):
     return {"messages": messages}
@@ -56,16 +57,10 @@ def mk_critic_verify(answer=None):
     return message


-def save_dataset(llamafactory_path, dataset_info, dataset):
-    all_dataset_info_path = f"{llamafactory_path}/data/dataset_info.json"
-    all_dataset_info = load_json(all_dataset_info_path)
-    all_dataset_info |= dataset_info
-    save_json(all_dataset_info, all_dataset_info_path, indent=4)
-
-    assert len(dataset_info.keys()) == 1
-    dataset_name = list(dataset_info.keys())[0]
-    dataset_relative_path = dataset_info[dataset_name]["file_name"]
-    save_json(dataset, f"{llamafactory_path}/data/{dataset_relative_path}")
+def save_jsonl_dataset(dataset, output_dir, split="train"):
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    save_jsonl(dataset, output_dir / f"{split}.jsonl")


 def get_score_token_id(tokenizer, token_str="Yes"):
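For reference, a minimal, self-contained sketch of the plain-JSONL saving path that this commit adopts in place of the LLaMA-Factory `dataset_info.json` registration. The helper bodies mirror `mk_messages` and `save_jsonl_dataset` as they appear in the diff above; the `json.dumps`-per-line loop is an assumed stand-in for the repo's `save_jsonl`, and the record contents and output path are illustrative only.

```python
import json
from pathlib import Path


# Mirrored from the diff above (sketch, not imported from the codecritic package).
def mk_messages(messages):
    return {"messages": messages}


def save_jsonl_dataset(dataset, output_dir, split="train"):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Assumed stand-in for codecritic.utils.json.save_jsonl:
    # write one JSON object per line.
    with open(output_dir / f"{split}.jsonl", "w") as f:
        for item in dataset:
            f.write(json.dumps(item) + "\n")


# Illustrative record: a critic-verify style two-turn exchange.
dataset = [
    mk_messages([
        {"role": "user", "content": "Is the code correct? Answer Yes or No."},
        {"role": "assistant", "content": "Yes"},
    ])
]
save_jsonl_dataset(dataset, "data/critic_sft")  # writes data/critic_sft/train.jsonl
```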