Commit 20c48579 by nzy

add algolr

parent f3dd6691
import argparse
from collections import defaultdict
from functools import partial
from vllm import SamplingParams
from datasets import load_dataset
from codecritic.utils.parallel import model_map
from codecritic.utils.inference import generate_worker
from codecritic.utils.json import load_jsonl, save_jsonl
import codecritic.evaluation.apps_eval as evaluation
import codecritic.dataset.algolr_prompt as promptlib
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="path/to/model")
parser.add_argument("--dataset", type=str, help="path/to/sample")
parser.add_argument("--pairinfo", type=str, help="path/to/pairinfo")
parser.add_argument("--apps", type=str, help="path/to/apps")
parser.add_argument("--output", type=str, help="path/to/score")
parser.add_argument("--hint_level", type=str, choices=["beginner"])
parser.add_argument(
"--tp", type=int, default=1, help="tensor parallel"
)
args = parser.parse_args()
# Step0 load dataset
dataset = load_jsonl(args.dataset)
pairinfo = load_jsonl(args.pairinfo)
ds = defaultdict(dict)
for item in dataset:
ds[item["task_id"]][item["solution_id"]] = item
    # Step 1: generate hints for the rejected solutions
    hint_prompts = []
    for pair in pairinfo:
        task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
        chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]
        prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.hint_level)
        hint_prompts.append(prompt)

    sampling_params = SamplingParams(
        n=1,
        temperature=0,
        top_p=0.95,
        max_tokens=2048,
    )
    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    hint_responses = model_map(worker, hint_prompts, args.tp)

    hints = [promptlib.postprocess_to_hint(x) for x in hint_responses]
    # hints: {"dataset": ..., "task_id": ..., "solution_id": ..., "hint": ...}
    save_jsonl(hints, args.output + ".hints")
    hints_dict = defaultdict(dict)
    for item in hints:
        hints_dict[item["task_id"]][item["solution_id"]] = item

    # Step 2: generate reasoning for both chosen and rejected solutions
    reason_prompts = []
    for pair in pairinfo:
        task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
        chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]

        CORRECT_HINT = "The code is correct."

        # chosen
        chosen_prompt = promptlib.process_to_reason_prompt(chosen, CORRECT_HINT)
        reason_prompts.append(chosen_prompt)

        # rejected
        rejected_hint = hints_dict[task_id][rejected_id]["hint"]
        rejected_prompt = promptlib.process_to_reason_prompt(rejected, rejected_hint)
        reason_prompts.append(rejected_prompt)

    sampling_params = SamplingParams(
        n=4,
        temperature=0.8,
        top_p=0.95,
        max_tokens=4096,
    )
    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    reason_responses = model_map(worker, reason_prompts, args.tp)
    # Step 3: verify the reasoning results.
    # Append a prompt asking the model to correct the code based on its reasoning.
    # The original solution_id is an int; rewrite it as "x_y" (str), where x is the
    # solution_id and y is the index of the sampled reasoning.
    reason_id_counter = defaultdict(lambda: defaultdict(int))

    for item in reason_responses:
        task_id = item["task_id"]
        solution_id = item["solution_id"]
        reason_id = reason_id_counter[task_id][solution_id]
        item["solution_id"] = f"{solution_id}_{reason_id}"
        reason_id_counter[task_id][solution_id] += 1

        promptlib.remove_hint(item)
        item["messages"].append({"role": "user", "content": promptlib.get_debug_prompt()})

    sampling_params = SamplingParams(
        n=1,
        temperature=0,
        top_p=0.95,
        max_tokens=2048,
    )
    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    verify_responses = model_map(worker, reason_responses, args.tp)
    print("verify response size: {}".format(len(verify_responses)))

    # Postprocess verify_responses: drop judgements that are inconsistent with the ground truth.
    verify_passed = []
    for item in verify_responses:
        conclusion, code = promptlib.extract_conclusion_and_code(item["messages"][-1]["content"])
        if conclusion == item["pass"]:
            item["code"] = code
            verify_passed.append(item)
    print("verify passed (judgement consistent) size: {}".format(len(verify_passed)))

    incorrects, corrects = [], []
    for item in verify_passed:
        if not item["pass"]:
            incorrects.append(item)
        else:
            corrects.append(item)

    # evaluation.evaluate expects a list of dicts: {"task_id": str, "solution_id": str (unique), "code": ...}
    apps = load_dataset(args.apps)
    fixed_incorrects = evaluation.evaluate(incorrects, apps)
    # Keep only the corrected solutions whose code now passes the tests.
    verify_passed = [x for x in fixed_incorrects if x["pass"]] + corrects
    print("verify passed (corrected code passes tests) size: {}".format(len(verify_passed)))
    # Step 4: remove hints and reformat to an SFT dataset
    # extract reasoning sets
    sft = []
    for item in verify_passed:
        line = {
            "dataset": item["dataset"],
            "task_id": item["task_id"],
            "solution_id": item["solution_id"],
            "question": item["messages"][:1],
            "response": item["messages"][1:2],
        }
        sft.append(line)
    print("Size of sft dataset: {}".format(len(sft)))
    save_jsonl(sft, args.output)
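For reference, a minimal sketch of the JSONL record shapes the script above appears to assume for --dataset and --pairinfo. Only the keys actually accessed above are shown; the values and the exact schema are assumptions, not part of the commit.

# Hypothetical --dataset record (one per sampled solution); field names inferred
# from the accesses above, so treat the exact schema as an assumption.
sample_record = {
    "dataset": "apps",
    "task_id": "apps/0001",
    "solution_id": 3,
    "pass": False,  # execution result read as ground truth in Step 3 (assumed to be carried through the pipeline)
    "code": "print(input())",
    "messages": [{"role": "user", "content": "<problem statement>"}],
}

# Hypothetical --pairinfo record pairing a passing and a failing solution of the same task.
pair_record = {"task_id": "apps/0001", "chosen": 7, "rejected": 3}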
import openai  # Assuming you're using OpenAI's API
import json

# Set your OpenAI API key
openai.api_key = "your_openai_api_key"

# Example input data
input_data = {
    "incorrect": """
def sum_even_numbers(numbers):
    total = 0
    for num in numbers:
        if num % 2 == 0:
            total += num
        return total  # Incorrect indentation: return is inside the loop
""",
    "correct": """
def sum_even_numbers(numbers):
    total = 0
    for num in numbers:
        if num % 2 == 0:
            total += num
    return total  # Correct indentation: return is outside the loop
""",
    "diff": """
-        return total  # Incorrect indentation: return is inside the loop
+    return total  # Correct indentation: return is outside the loop
""",
    "problem": "Sum all even numbers in a given list."
}
# Prompt template for generating all hint levels
HINT_GENERATION_PROMPT = """
You are an expert code reviewer. Your task is to generate hints for the following code snippet at three levels of detail: high, middle, and low.
**Problem**: {problem}
**Code**:
{code}
**Diff (if applicable)**:
{diff}
Provide your response in the following JSON format:
{{
    "high": "High-level hint (e.g., 'The code is incorrect.' or 'The code is correct.')",
    "middle": "Middle-level hint (e.g., 'The code contains a logic error related to the placement of the return statement.' or 'The code correctly implements the intended functionality.')",
    "low": "Low-level hint (e.g., 'The return statement is incorrectly placed inside the loop, causing the function to exit prematurely.' or 'The return statement is placed outside the loop, ensuring all numbers are processed.')"
}}
"""
# Unified prompt template
UNIFIED_PROMPT = """
You are an expert code reviewer. Your task is to analyze the following code snippet step-by-step. Follow this structure:
1. **Intended Functionality**: What is the code trying to do?
2. **Code Analysis**:
   - If the code is **incorrect**:
     - What is the error in the code?
     - Why does this error occur?
   - If the code is **correct**:
     - Why is this code correct?
     - Are there any potential improvements or edge cases to consider?
**Problem**: {problem}
**Hint (if applicable)**: {hint}
Here is the code:
{code}
Provide your response in a clear and structured format.
"""
# Prompt template for verifying consistency
VERIFICATION_PROMPT = """
You are an expert code reviewer. Your task is to verify whether the following analysis is consistent with the hint and the problem description.
**Problem**: {problem}
**Hint**: {hint}
**Analysis**:
{analysis}
Provide your response as either "Consistent" or "Inconsistent" and briefly explain why.
"""
# Function to call the LLM
def call_llm(prompt):
    response = openai.ChatCompletion.create(
        model="gpt-4",  # Use the appropriate model
        messages=[{"role": "user", "content": prompt}],
        max_tokens=500,
        temperature=0.7,
    )
    return response.choices[0].message.content.strip()

# Function to generate all hint levels using the LLM
def generate_all_hints(code, problem, diff, is_correct=False):
    prompt = HINT_GENERATION_PROMPT.format(
        problem=problem,
        code=code,
        diff=diff if not is_correct else "No diff provided."
    )
    response = call_llm(prompt)
    try:
        hints = json.loads(response)  # Parse the JSON response
        return hints
    except json.JSONDecodeError:
        # Fallback if the response is not valid JSON
        return {
            "high": "The code is incorrect." if not is_correct else "The code is correct.",
            "middle": "The code contains a logic error." if not is_correct else "The code correctly implements the intended functionality.",
            "low": "The code contains a detailed error." if not is_correct else "The code is correct and handles all edge cases."
        }
# Function to analyze code using the unified prompt
def analyze_code(code, problem, hint=None):
    prompt = UNIFIED_PROMPT.format(code=code, problem=problem, hint=hint if hint else "No hint provided.")
    analysis = call_llm(prompt)
    return analysis

# Function to verify consistency using the LLM
def verify_consistency(analysis, hint, problem):
    prompt = VERIFICATION_PROMPT.format(
        problem=problem,
        hint=hint,
        analysis=analysis
    )
    verification_result = call_llm(prompt)
    return verification_result

# Function to synthesize training data with different hint levels
def synthesize_training_data(input_data, hint_level="high"):
    incorrect_code = input_data["incorrect"]
    correct_code = input_data["correct"]
    diff = input_data["diff"]
    problem = input_data["problem"]

    # Generate all hint levels for incorrect and correct code
    incorrect_hints = generate_all_hints(incorrect_code, problem, diff, is_correct=False)
    correct_hints = generate_all_hints(correct_code, problem, diff, is_correct=True)

    # Select the appropriate hint level
    incorrect_hint = incorrect_hints[hint_level]
    correct_hint = correct_hints[hint_level]

    # Analyze incorrect code with hints
    incorrect_analysis = analyze_code(incorrect_code, problem, incorrect_hint)
    # Analyze correct code with hints
    correct_analysis = analyze_code(correct_code, problem, correct_hint)

    # Verify consistency between analysis and hints
    incorrect_verification = verify_consistency(incorrect_analysis, incorrect_hint, problem)
    correct_verification = verify_consistency(correct_analysis, correct_hint, problem)

    # Combine results into training data
    training_data = {
        "incorrect_code": incorrect_code,
        "incorrect_analysis": incorrect_analysis,
        "incorrect_hint": incorrect_hint,
        "incorrect_verification": incorrect_verification,
        "correct_code": correct_code,
        "correct_analysis": correct_analysis,
        "correct_hint": correct_hint,
        "correct_verification": correct_verification,
        "diff": diff,
        "problem": problem,
        "hint_level": hint_level,
        "all_incorrect_hints": incorrect_hints,  # Include all hint levels for reference
        "all_correct_hints": correct_hints,  # Include all hint levels for reference
    }
    return training_data
# Main function
def main():
    # Synthesize training data with different hint levels
    hint_levels = ["high", "middle", "low"]  # High, middle, and low hint levels
    training_data_list = []
    for hint_level in hint_levels:
        training_data = synthesize_training_data(input_data, hint_level)
        training_data_list.append(training_data)

    # Save training data to a JSON file
    with open("training_data.json", "w") as f:
        json.dump(training_data_list, f, indent=4)
    print("Training data synthesized and saved to 'training_data.json'.")

if __name__ == "__main__":
    main()
\ No newline at end of file
@@ -3,10 +3,10 @@ from collections import defaultdict
 from itertools import product, chain
 from pathlib import Path
 import os
-import re
 from tqdm.contrib.concurrent import process_map
 from rapidfuzz import fuzz
+from codecritic.dataset.code import preprocess_code
 from codecritic.utils.json import load_jsonl, save_jsonl
@@ -22,26 +22,6 @@ def group_and_filter(dataset):
         yield group
-# Precompile regular expressions
-SINGLE_LINE_COMMENT_REGEX = re.compile(r'#.*')
-MULTILINE_DOUBLE_QUOTE_REGEX = re.compile(r'^\s*""".*?"""\s*$', flags=re.DOTALL | re.MULTILINE)
-MULTILINE_SINGLE_QUOTE_REGEX = re.compile(r"^\s*'''.*?'''\s*$", flags=re.DOTALL | re.MULTILINE)
-def preprocess_code(code):
-    # Remove single-line comments
-    code = SINGLE_LINE_COMMENT_REGEX.sub('', code)
-    # Remove standalone docstrings (triple-quoted strings that are not part of an expression)
-    code = MULTILINE_DOUBLE_QUOTE_REGEX.sub('', code)
-    code = MULTILINE_SINGLE_QUOTE_REGEX.sub('', code)
-    # Remove blank lines
-    code = "\n".join([line for line in code.splitlines() if line.strip()])
-    return code
 def compute_pair_similarity(group):
     correct_code_set, incorrect_code_set = set(), {''}
     correct_samples, incorrect_samples = [], []
...
import re

import codecritic.dataset.code as codelib


def get_hint_prompt(question, incorrect_code, diff_edit, level):
    prompt = f"""
You are a helpful programming assistant. Your task is to analyze the following **question, incorrect code, and diff edit**. Think step by step and generate a hint that:
1. Explains **why the algorithm is incorrect** (algorithmic reasoning).
2. Provides **actionable guidance** on how to fix the code (code action).
Ensure the hint is clear, actionable, and appropriate for a **{level}-level** learner.
---
## Input
### Question
{question}
### Incorrect Code
```python
{incorrect_code}
```
### Diff Edit
```diff
{diff_edit}
```
---
## Output Format
Return your response in the following format:
### Hint
[Your hint here. Include both algorithmic reasoning and actionable guidance. Entirely in natural language.]
"""
    return prompt.strip()


def process_to_hint_prompt(chosen, rejected, level):
    question = chosen["messages"][0]["content"]
    # question = "\n".join(question.strip().splitlines()[1:-1])
    chosen_code = codelib.preprocess_code(chosen["code"])
    rejected_code = codelib.preprocess_code(rejected["code"])
    diff_str = codelib.diff_code(rejected_code, chosen_code)
    prompt = get_hint_prompt(question, rejected_code, diff_str, level)
    messages = [{"role": "user", "content": prompt}]
    return {
        "dataset": rejected["dataset"],
        "task_id": rejected["task_id"],
        "solution_id": rejected["solution_id"],
        "messages": messages
    }
hint_pattern = re.compile(r"### Hint\n(.*?)(?=\n###|$)", re.DOTALL | re.IGNORECASE)


def postprocess_to_hint(llm_response):
    messages = llm_response.pop("messages")
    response = messages[-1]["content"]
    hint_match = hint_pattern.search(response)
    if hint_match:
        hint = hint_match.group(1).strip()
    else:
        hint = "No hints."
        print("No hints in response: {}".format(response))
    llm_response["hint"] = hint
    return llm_response
def get_reason_prompt(question, code, hint):
    prompt = f"""
You are a code analysis assistant specialized in finding bugs.
When reviewing code, reason through its logic step by step:
- Analyze each part's purpose and implementation
- If you spot potential issues, explain:
  * Why it's a problem (based on the logic)
  * How to fix it
---
## Input
### Question
{question}
### Code
```python
{code}
```
"""
    hint_prompt = f"""
### Hint
{hint}
"""
    if hint is None:
        return prompt.strip()
    else:
        return (prompt + hint_prompt).strip()


def process_to_reason_prompt(item, hint):
    question = item["messages"][0]["content"]
    code = item["code"]
    prompt = get_reason_prompt(question, code, hint)
    messages = [{"role": "user", "content": prompt}]
    return {
        "dataset": item["dataset"],
        "task_id": item["task_id"],
        "solution_id": item["solution_id"],
        "messages": messages
    }


def get_debug_prompt():
    return """
Based on the analysis provided, please:
1. **Draw a conclusion**: State whether the original code is correct or not by answering "Yes" or "No".
   - Format: `Conclusion: <Yes/No>`
2. **If the code is not correct**, provide the corrected code.
---
## Output format
Conclusion: (Yes/No)
Corrected Code:
```python
[corrected code here]
```
""".strip()


def remove_hint(item):
    raw_prompt = item["messages"][0]["content"]
    hint_start = raw_prompt.find("### Hint")
    if hint_start != -1:
        truncated_string = raw_prompt[:hint_start].strip()
    else:
        truncated_string = raw_prompt
    item["messages"][0]["content"] = truncated_string
def extract_conclusion_and_code(response):
    # Extract the conclusion line, if any
    conclusion_lines = [line for line in response.split('\n') if line.startswith('Conclusion:')]
    conclusion_str = conclusion_lines[0].split(':', 1)[1].strip().lower() if conclusion_lines else ""
    if "yes" in conclusion_str:
        conclusion = True
    elif "no" in conclusion_str:
        conclusion = False
    else:
        print("llm doesn't draw a conclusion")
        conclusion = None

    # Extract the corrected code if the conclusion is 'No' (or missing)
    corrected_code = ""
    if not conclusion:
        corrected_code = codelib.extract_code(response)
    return conclusion, corrected_code
\ No newline at end of file
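For illustration, a minimal sketch (not part of the commit) of how a model reply following the output format requested by get_debug_prompt would be parsed by extract_conclusion_and_code. It assumes the codecritic package is importable; the reply text below is hypothetical.

from codecritic.dataset.algolr_prompt import extract_conclusion_and_code

# Hypothetical reply in the "Conclusion: ... / Corrected Code: ..." format.
reply = (
    "Conclusion: No\n"
    "Corrected Code:\n"
    "```python\n"
    "def add(a, b):\n"
    "    return a + b\n"
    "```"
)

conclusion, code = extract_conclusion_and_code(reply)
print(conclusion)  # False: the model judged the original code incorrect
print(code)        # the corrected snippet pulled from the ```python``` block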
 import re
+from difflib import unified_diff
 codeblock_pattern = re.compile(r"```python(.+?)```", flags=re.DOTALL)
 code_template = """```python
@@ -14,4 +14,34 @@ def extract_code(text: str):
         code = "\n".join(codes)
         return code
     else:
         return ""
\ No newline at end of file
+# Precompile regular expressions
+SINGLE_LINE_COMMENT_REGEX = re.compile(r'#.*')
+MULTILINE_DOUBLE_QUOTE_REGEX = re.compile(r'^\s*""".*?"""\s*$', flags=re.DOTALL | re.MULTILINE)
+MULTILINE_SINGLE_QUOTE_REGEX = re.compile(r"^\s*'''.*?'''\s*$", flags=re.DOTALL | re.MULTILINE)
+
+def preprocess_code(code):
+    # Remove single-line comments
+    code = SINGLE_LINE_COMMENT_REGEX.sub('', code)
+    # Remove standalone docstrings (triple-quoted strings that are not part of an expression)
+    code = MULTILINE_DOUBLE_QUOTE_REGEX.sub('', code)
+    code = MULTILINE_SINGLE_QUOTE_REGEX.sub('', code)
+    # Remove blank lines
+    code = "\n".join([line for line in code.splitlines() if line.strip()])
+    return code
+
+def diff_code(incorrect, correct):
+    diff = unified_diff(
+        incorrect.splitlines(keepends=True),
+        correct.splitlines(keepends=True),
+        fromfile="incorrect.py",
+        tofile="correct.py",
+    )
+    diff = ''.join(diff)
+    return diff
\ No newline at end of file
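Likewise, a small sketch (not part of the commit) exercising the two helpers added to codecritic/dataset/code.py; the sample solutions below are made up, and the snippet assumes the codecritic package is importable.

from codecritic.dataset.code import preprocess_code, diff_code

# Hypothetical solutions: the buggy one returns inside the loop.
wrong = "def total(xs):\n    s = 0\n    for x in xs:\n        s += x\n        return s  # bug\n"
right = "def total(xs):\n    s = 0\n    for x in xs:\n        s += x\n    return s\n"

# preprocess_code strips comments, standalone docstrings, and blank lines before
# the two versions are compared; diff_code then emits a unified diff between them.
print(diff_code(preprocess_code(wrong), preprocess_code(right)))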