Commit 9381291e by nzy

refactor: fix import & remove unused code

parent a89b519a
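
The hunks below mostly rewrite imports and delete unused helpers. Assembled from the new-side lines of those hunks, the import paths this refactor establishes look like the following. This is a consolidated view for orientation only; the names listed per module are a representative subset of what the diff shows, and nothing beyond these imports is implied:

from codecritic.utils.json import load_json, save_json, load_jsonl, save_jsonl
from codecritic.utils.data import extract_code, code_template, mk_critic_qa, mk_sft_item, mk_critic_verify, JUDGE_PROMPT, SPLITTER
from codecritic.utils.metric import group_results, score_pass_at_k
from codecritic.utils.vllm import vllm_chatcomplete, vllm_score
from codecritic.data.cov import COV_PROMPT
from codecritic.sampling.apps_test import run_test
from codecritic.sampling.sample_apps import mk_sample_prompt
from codecritic.sampling.evaluate_code import evaluate
from codecritic.sampling.sort_split_dataset import sort_and_split_dataset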
@@ -7,8 +7,9 @@ from transformers import AutoTokenizer
 import pprint
 from pathlib import Path
-from utils import load_jsonl, save_jsonl, extract_code, code_template
-from utils_metric import group_results, score_pass_at_k
+from codecritic.utils.json import load_jsonl, save_jsonl
+from codecritic.utils.data import extract_code, code_template
+from codecritic.utils.metric import group_results, score_pass_at_k
 def get_rewards_from_server(server_url: str, messages: list[str]):
...
@@ -2,11 +2,11 @@ import argparse
 from pathlib import Path
 import pprint
-from step2_cov_dataset import COV_PROMPT
-from utils_vllm import vllm_chatcomplete, vllm_score
-from utils import load_jsonl, save_jsonl, extract_code, code_template
-from utils_dataset import mk_critic_qa, JUDGE_PROMPT, get_score_token_id
-from utils_metric import group_results, score_pass_at_k
+from codecritic.data.cov import COV_PROMPT
+from codecritic.utils.vllm import vllm_chatcomplete, vllm_score
+from codecritic.utils.json import load_jsonl, save_jsonl
+from codecritic.utils.data import extract_code, code_template, mk_critic_qa, JUDGE_PROMPT, get_score_token_id
+from codecritic.utils.metric import group_results, score_pass_at_k
 def preprocess_test_item(item):
...
 import argparse
 from pathlib import Path
-from utils import save_json, save_jsonl
-from utils_vllm import vllm_chatcomplete
-from step1_sample_apps import mk_sample_prompt
-from step1_evaluate_code import evaluate
-from step1_sort_split_dataset import sort_and_split_dataset
+from codecritic.utils.json import save_json, save_jsonl
+from codecritic.utils.vllm import vllm_chatcomplete
+from codecritic.sampling.sample_apps import mk_sample_prompt
+from codecritic.sampling.evaluate_code import evaluate
+from codecritic.sampling.sort_split_dataset import sort_and_split_dataset
 if __name__ == "__main__":
...
@@ -3,8 +3,9 @@
 import argparse
 from itertools import chain
 from utils import load_json, extract_code, code_template
-from utils_dataset import mk_critic_qa, mk_sft_item, mk_critic_verify, mk_sft_dataset_info, save_dataset, SPLITTER
-from utils_vllm import vllm_chatcomplete
+from codecritic.utils.data import mk_critic_qa, mk_sft_item, mk_critic_verify, mk_sft_dataset_info, save_dataset, SPLITTER
+from codecritic.utils.vllm import vllm_chatcomplete
 COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
...
 import argparse
 from pathlib import Path
-from utils import load_json, load_jsonl, save_json, save_jsonl, extract_code
-from utils_dataset import mk_preference_dataset_info, mk_preference_pair, save_dataset
+from codecritic.utils.json import load_json, load_jsonl, save_json
+from codecritic.utils.data import extract_code, mk_preference_dataset_info, mk_preference_pair, save_dataset
 from nltk.metrics.distance import edit_distance
 from collections import defaultdict
 from itertools import product, chain
...
@@ -4,8 +4,8 @@
 # 2. Using SFT (Supervised Fine-Tuning) directly
 # This experiment aims to fairly compare these two approaches.
 import argparse
-from utils import load_json
-from utils_dataset import mk_critic_qa, mk_critic_verify, mk_sft_item, mk_sft_dataset_info, save_dataset
+from codecritic.utils.json import load_json
+from codecritic.utils.data import mk_critic_qa, mk_critic_verify, mk_sft_item, mk_sft_dataset_info, save_dataset
 def convert_preference_to_sft(item):
...
@@ -7,7 +7,7 @@ import numpy as np
 from datasets import load_dataset
 from tqdm.contrib.concurrent import process_map
-from step1_apps_test import run_test
+from codecritic.sampling.apps_test import run_test
 from utils import extract_code, load_jsonl, save_jsonl
 TIMEOUT = 10
...
 from datasets import load_dataset
 import json
-from utils import save_jsonl
 from transformers import AutoTokenizer
...
-from utils import load_jsonl, save_jsonl
 def mk_key_for_sort(item):
     problem_id = item['problem_id']
     prefix, idx = problem_id.split('_')
@@ -53,5 +50,3 @@ def sort_and_split_dataset(dataset, n):
         minimal_test.extend(problem)
     return new_train, new_test, minimal_test
-from utils import load_json, save_json, code_template
+import re
+from codecritic.utils.json import load_json, save_json
 from transformers import AutoTokenizer
+codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)
+code_template = """```python
+{}
+```
+"""
-def mk_preference_dataset_info(dataset_name):
-    return {
-        dataset_name: {
-            "file_name": f"{dataset_name}.json",
-            "formatting": "sharegpt",
-            "ranking": True,
-            "columns": {
-                "messages": "messages",
-                "chosen": "chosen",
-                "rejected": "rejected",
-            },
-            "tags": {
-                "role_tag": "role",
-                "content_tag": "content",
-                "user_tag": "user",
-                "assistant_tag": "assistant",
-                "system_tag": "system",
-            },
-        }
-    }
+def extract_code(text: str):
+    codes = [match.strip() for match in re.findall(codeblock_pattern, text)]
+    if len(codes) > 0:
+        code = "\n".join(codes)
+        return code
+    else:
+        return ""
 def mk_preference_pair(instruction, chosen_code, rejected_code):
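
The hunk above relocates the Markdown code-block helpers into codecritic.utils.data and drops mk_preference_dataset_info from this spot. A self-contained sketch of what the relocated helpers do, with the definitions copied from the added lines; the sample response string at the end is illustrative only:

import re

codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)
code_template = """```python
{}
```
"""

def extract_code(text: str):
    # Pull every ```python ... ``` block out of a model response and join them.
    codes = [match.strip() for match in re.findall(codeblock_pattern, text)]
    if len(codes) > 0:
        code = "\n".join(codes)
        return code
    else:
        return ""

# Illustrative input, not taken from the repository:
response = "Fix:\n```python\ndef add(a, b):\n    return a + b\n```"
assert extract_code(response) == "def add(a, b):\n    return a + b"
wrapped = code_template.format("print('hi')")  # re-wraps a snippet as a fenced python block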
@@ -37,23 +31,6 @@ def mk_preference_pair(instruction, chosen_code, rejected_code):
     }
-def mk_sft_dataset_info(dataset_name):
-    return {
-        dataset_name: {
-            "file_name": f"{dataset_name}.json",
-            "formatting": "sharegpt",
-            "columns": {"messages": "messages"},
-            "tags": {
-                "role_tag": "role",
-                "content_tag": "content",
-                "user_tag": "user",
-                "assistant_tag": "assistant",
-                "system_tag": "system",
-            },
-        }
-    }
 # Note that the human and observation should appear in odd positions
 # while llm should appear in even positions.
 def mk_sft_item(messages):
@@ -69,6 +46,8 @@ def mk_critic_qa(instruction, code):
 JUDGE_PROMPT = "Is the code correct (Yes/No)?"
 def mk_critic_verify(answer=None):
     # answer: bool or none
     message = [{"role": "user", "content": JUDGE_PROMPT}]
@@ -99,7 +78,10 @@ def get_score_token_id(model_path, token_str="Yes"):
 def mk_critic_reason(codedit, explanation):
-    user_question = {"role": "user", "content": "Edit your code in diff format to fix any issues and explain the changes."}
+    user_question = {
+        "role": "user",
+        "content": "Edit your code in diff format to fix any issues and explain the changes.",
+    }
     llm_answer_content = f"""\
 **Edited Code (in diff format):**
 ```diff
@@ -112,4 +94,5 @@ def mk_critic_reason(codedit, explanation):
     llm_answer = {"role": "assistant", "content": llm_answer_content}
     return [user_question, llm_answer]
 SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"
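
The two hunks above only reflow the user-turn dict inside mk_critic_reason, but they also pin down its visible shape: a user request to edit the code in diff format, an assistant answer that opens with an "**Edited Code (in diff format):**" heading followed by a diff block, and a return value of [user_question, llm_answer]. A hedged sketch of that shape; how codedit and explanation are spliced into the template after the ```diff line is not shown in this commit, so that part is an assumption:

codedit = "-    return a - b\n+    return a + b"      # illustrative diff
explanation = "The helper should add, not subtract."  # illustrative explanation

user_question = {
    "role": "user",
    "content": "Edit your code in diff format to fix any issues and explain the changes.",
}
# Assumed continuation of the assistant template beyond the lines visible above:
llm_answer_content = f"**Edited Code (in diff format):**\n```diff\n{codedit}\n```\n\n{explanation}"
llm_answer = {"role": "assistant", "content": llm_answer_content}
critic_reason_turns = [user_question, llm_answer]     # mirrors `return [user_question, llm_answer]`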
 import json
-import re
 def load_jsonl(file_path):
@@ -21,19 +20,3 @@ def save_jsonl(data, file_path):
 def save_json(data, file_path, indent=None):
     with open(file_path, "w", encoding="utf-8") as f:
         json.dump(data, f, indent=indent)
-codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)
-code_template = """```python
-{}
-```
-"""
-def extract_code(text: str):
-    codes = [match.strip() for match in re.findall(codeblock_pattern, text)]
-    if len(codes) > 0:
-        code = "\n".join(codes)
-        return code
-    else:
-        return ""
@@ -5,7 +5,7 @@ import multiprocessing
 from itertools import chain
 from functools import partial
-from utils_dataset import SPLITTER
+from codecritic.utils.data import SPLITTER
 import numpy as np
 def generate_worker(cuda_device, prompts, model_path, sampling_params):
...