Commit 7e4c652e by nzy

remove llamafactory & fix remaining import statements after refactoring

parent efa812c2
...@@ -2,9 +2,17 @@ ...@@ -2,9 +2,17 @@
# Is reasoning really work? Let's verify step by step. # Is reasoning really work? Let's verify step by step.
import argparse import argparse
from itertools import chain from itertools import chain
from utils import load_json, extract_code, code_template
from codecritic.utils.data import mk_message, mk_sft_item, mk_critic_verify, mk_sft_dataset_info, save_dataset, SPLITTER from codecritic.utils.json import load_json
from codecritic.utils.data import (
extract_code,
code_template,
mk_message,
mk_messages,
mk_critic_verify,
save_jsonl_dataset,
SPLITTER,
)
from codecritic.utils.vllm import vllm_chatcomplete from codecritic.utils.vllm import vllm_chatcomplete
COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue." COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
...@@ -36,11 +44,17 @@ result = add_numbers(5, '10') ...@@ -36,11 +44,17 @@ result = add_numbers(5, '10')
CORRECT_PROMPT = "Your code is correct." CORRECT_PROMPT = "Your code is correct."
INCORRECT_PROMPT = "Your code is incorrect." INCORRECT_PROMPT = "Your code is incorrect."
def mk_cov_prompt(is_correct):
    """Build the two-message chain-of-verification (CoV) prompt suffix.

    Args:
        is_correct: Whether the code under review is known to be correct;
            selects CORRECT_PROMPT or INCORRECT_PROMPT as the lead-in line.

    Returns:
        A list of two chat messages: a user turn that states the verdict and
        asks for step-by-step verification (with the worked COV_EXAMPLE
        appended), followed by a pre-filled assistant turn ending in SPLITTER.
        NOTE(review): SPLITTER presumably marks where model generation should
        resume after the primed prefix — confirm against the SPLITTER users.
    """
    verdict = CORRECT_PROMPT if is_correct else INCORRECT_PROMPT
    return [
        {
            "role": "user",
            "content": verdict + "\n" + COV_PROMPT + "\n" + COV_EXAMPLE,
        },
        {
            "role": "assistant",
            "content": "Here's a step-by-step verification of the code." + SPLITTER,
        },
    ]
def convert_preference_to_vot_prompt(item): def convert_preference_to_vot_prompt(item):
...@@ -53,7 +67,7 @@ def convert_preference_to_vot_prompt(item): ...@@ -53,7 +67,7 @@ def convert_preference_to_vot_prompt(item):
messages1 = mk_message(message, chosen) + mk_cov_prompt(True) messages1 = mk_message(message, chosen) + mk_cov_prompt(True)
messages2 = mk_message(message, rejected) + mk_cov_prompt(False) messages2 = mk_message(message, rejected) + mk_cov_prompt(False)
return mk_sft_item(messages1), mk_sft_item(messages2) return mk_messages(messages1), mk_messages(messages2)
def convert_cov_to_cov_dataset(item): def convert_cov_to_cov_dataset(item):
...@@ -73,8 +87,7 @@ if __name__ == "__main__": ...@@ -73,8 +87,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str) parser.add_argument("--model", type=str)
parser.add_argument("--preference_dataset", type=str) parser.add_argument("--preference_dataset", type=str)
parser.add_argument("--llamafactory", type=str) parser.add_argument("--output_dir", type=str)
parser.add_argument("--dataset_name", type=str)
args = parser.parse_args() args = parser.parse_args()
preference_dataset = load_json(args.preference_dataset) preference_dataset = load_json(args.preference_dataset)
...@@ -85,5 +98,4 @@ if __name__ == "__main__": ...@@ -85,5 +98,4 @@ if __name__ == "__main__":
covs = vllm_chatcomplete(args.model, cov_prompts, sampling_params) covs = vllm_chatcomplete(args.model, cov_prompts, sampling_params)
dataset = list(map(convert_cov_to_cov_dataset, covs)) dataset = list(map(convert_cov_to_cov_dataset, covs))
dataset_info = mk_sft_dataset_info(args.dataset_name) save_jsonl_dataset(dataset, args.output_dir)
save_dataset(args.llamafactory, dataset_info, dataset)
import argparse import argparse
from pathlib import Path from pathlib import Path
from codecritic.utils.json import load_json, load_jsonl, save_json from codecritic.utils.json import load_json, load_jsonl
from codecritic.utils.data import extract_code, mk_preference_dataset_info, mk_preference_pair, save_dataset from codecritic.utils.data import extract_code, mk_preference_pair, save_jsonl_dataset
from nltk.metrics.distance import edit_distance from nltk.metrics.distance import edit_distance
from collections import defaultdict from collections import defaultdict
from itertools import product, chain from itertools import product, chain
...@@ -92,7 +92,7 @@ def mk_edit_distance_dataset(all_pairs, k, n, is_max=True): ...@@ -92,7 +92,7 @@ def mk_edit_distance_dataset(all_pairs, k, n, is_max=True):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--dataset_dir", type=str) parser.add_argument("--dataset_dir", type=str)
parser.add_argument("--llamafactory", type=str) parser.add_argument("--output_dir", type=str)
parser.add_argument("--is_max", type=bool, required=True) parser.add_argument("--is_max", type=bool, required=True)
args = parser.parse_args() args = parser.parse_args()
...@@ -109,6 +109,4 @@ if __name__ == "__main__": ...@@ -109,6 +109,4 @@ if __name__ == "__main__":
all_edit_distance_pairs, 10 * 1000, 5, is_max=args.is_max all_edit_distance_pairs, 10 * 1000, 5, is_max=args.is_max
) )
dataset_info = mk_preference_dataset_info(dataset_name) save_jsonl_dataset(preference_pairs, args.output_dir)
save_json(metadata, dataset_dir / f"{dataset_name}_metadata.json")
save_dataset(args.llamafactory, dataset_info, preference_pairs)
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# This experiment aims to fairly compare these two approaches. # This experiment aims to fairly compare these two approaches.
import argparse import argparse
from codecritic.utils.json import load_json from codecritic.utils.json import load_json
from codecritic.utils.data import mk_message, mk_critic_verify, mk_sft_item, mk_sft_dataset_info, save_dataset from codecritic.utils.data import mk_message, mk_critic_verify, mk_messages, save_jsonl_dataset
def convert_preference_to_sft(item): def convert_preference_to_sft(item):
...@@ -15,14 +15,13 @@ def convert_preference_to_sft(item): ...@@ -15,14 +15,13 @@ def convert_preference_to_sft(item):
messages1 = mk_message(message, chosen) + mk_critic_verify(True) messages1 = mk_message(message, chosen) + mk_critic_verify(True)
messages2 = mk_message(message, rejected) + mk_critic_verify(False) messages2 = mk_message(message, rejected) + mk_critic_verify(False)
return mk_sft_item(messages1), mk_sft_item(messages2) return mk_messages(messages1), mk_messages(messages2)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--preference_dataset", type=str) parser.add_argument("--preference_dataset", type=str)
parser.add_argument("--llamafactory", type=str) parser.add_argument("--output_dir", type=str)
parser.add_argument("--dataset_name", type=str)
args = parser.parse_args() args = parser.parse_args()
preference_dataset = load_json(args.preference_dataset) preference_dataset = load_json(args.preference_dataset)
...@@ -31,5 +30,4 @@ if __name__ == "__main__": ...@@ -31,5 +30,4 @@ if __name__ == "__main__":
for item in preference_dataset: for item in preference_dataset:
sft_dataset.extend(convert_preference_to_sft(item)) sft_dataset.extend(convert_preference_to_sft(item))
dataset_info = mk_sft_dataset_info(args.dataset_name) save_jsonl_dataset(sft_dataset, args.output_dir)
save_dataset(args.llamafactory, dataset_info, sft_dataset)
...@@ -8,7 +8,8 @@ from datasets import load_dataset ...@@ -8,7 +8,8 @@ from datasets import load_dataset
from tqdm.contrib.concurrent import process_map from tqdm.contrib.concurrent import process_map
from codecritic.sampling.apps_test import run_test from codecritic.sampling.apps_test import run_test
from utils import extract_code, load_jsonl, save_jsonl from codecritic.utils.json import save_jsonl
from codecritic.utils.data import extract_code
TIMEOUT = 10 TIMEOUT = 10
......
import re import re
from codecritic.utils.json import load_json, save_json from codecritic.utils.json import save_jsonl
from pathlib import Path
codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL) codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)
code_template = """```python code_template = """```python
...@@ -32,7 +33,7 @@ def mk_preference_pair(instruction, chosen_code, rejected_code): ...@@ -32,7 +33,7 @@ def mk_preference_pair(instruction, chosen_code, rejected_code):
# Note that the human and observation should appear in odd positions # Note that the human and observation should appear in odd positions
# while llm should appear in even positions. # while llm should appear in even positions.
def mk_messages(messages):
    """Wrap a chat-message list in the standard ``{"messages": ...}`` record.

    Args:
        messages: A list of chat message dicts (role/content pairs).

    Returns:
        A dict with the single key ``"messages"`` mapping to the input list.
    """
    return {"messages": messages}
...@@ -56,16 +57,10 @@ def mk_critic_verify(answer=None): ...@@ -56,16 +57,10 @@ def mk_critic_verify(answer=None):
return message return message
def save_jsonl_dataset(dataset, output_dir, split="train"):
    """Write a dataset to ``<output_dir>/<split>.jsonl``.

    Args:
        dataset: Sequence of records to serialize; passed straight to
            ``save_jsonl``.
        output_dir: Target directory; created (with parents) if missing.
        split: Split name used as the file stem. Defaults to ``"train"``.
    """
    output_dir = Path(output_dir)
    # Create the directory tree up front so save_jsonl never fails on a
    # missing parent; exist_ok makes repeated runs idempotent.
    output_dir.mkdir(parents=True, exist_ok=True)
    save_jsonl(dataset, output_dir / f"{split}.jsonl")
def get_score_token_id(tokenizer, token_str="Yes"): def get_score_token_id(tokenizer, token_str="Yes"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment