Commit 862ebb89 by nzy

Warning for functions with the same name, print error, log update

parent 2e1673e0
...@@ -4,6 +4,7 @@ from copy import deepcopy ...@@ -4,6 +4,7 @@ from copy import deepcopy
import functools import functools
from .sandbox import import_module_from_string, timeout from .sandbox import import_module_from_string, timeout
from typing import Optional, Any from typing import Optional, Any
import traceback
class IOCollector: class IOCollector:
...@@ -11,6 +12,7 @@ class IOCollector: ...@@ -11,6 +12,7 @@ class IOCollector:
self.ios = [] self.ios = []
self.fun_name = fun_name self.fun_name = fun_name
self.crash = False self.crash = False
self.exception = None
# Add deco to the function # Add deco to the function
fun = getattr(module, fun_name, None) fun = getattr(module, fun_name, None)
...@@ -50,13 +52,16 @@ def anpl_trace(anpl: ANPL, fun_name: str, inputs: dict[str, Any], entry: Optiona ...@@ -50,13 +52,16 @@ def anpl_trace(anpl: ANPL, fun_name: str, inputs: dict[str, Any], entry: Optiona
f = timeout(timeout=1)(entry_point) f = timeout(timeout=1)(entry_point)
f(**inputs) f(**inputs)
return io return io
except Exception: except Exception as e:
io.exception = e
return io return io
def anpl_check(anpl: ANPL, fun_name: str) -> bool: def anpl_check(anpl: ANPL, fun_name: str, show_err: bool=True) -> bool:
assert len(anpl.funs[fun_name].gloden_ios) > 0 assert len(anpl.funs[fun_name].gloden_ios) > 0
for io in anpl.funs[fun_name].gloden_ios: for io in anpl.funs[fun_name].gloden_ios:
ioc = anpl_trace(anpl, fun_name, io.inputs, fun_name) ioc = anpl_trace(anpl, fun_name, io.inputs, fun_name)
if show_err and ioc.exception:
traceback.print_exception(ioc.exception, limit=-1)
if ioc.crash or len(ioc.ios) < 1: if ioc.crash or len(ioc.ios) < 1:
return False return False
# if len(ioc.ios) > 1: # if len(ioc.ios) > 1:
......
sk-<REDACTED-OLD-KEY> sk-<REDACTED-NEW-KEY>
\ No newline at end of file \ No newline at end of file
from rich.prompt import Confirm, Prompt, IntPrompt from rich.prompt import Confirm, Prompt, IntPrompt
from rich.progress import track from rich.progress import track
from utils import sys_str, system_info, code_input, value_input, print_anpl, print_IOExamples, select_task, set_openai_key, fun_select, print_text_IOExamples from utils import sys_str, system_info, code_input, value_input, print_anpl, print_IOExamples, select_task, set_openai_key, fun_select, print_text_IOExamples, Logger
from copy import deepcopy from copy import deepcopy
import time import time
...@@ -8,11 +8,12 @@ from anpl.anpl import IOExample, ANPL ...@@ -8,11 +8,12 @@ from anpl.anpl import IOExample, ANPL
from anpl.parser import ANPLParser from anpl.parser import ANPLParser
from anpl.synthesizer import fun_synthesis, batch_fun_synthesis from anpl.synthesizer import fun_synthesis, batch_fun_synthesis
from anpl.tracer import anpl_check, anpl_trace from anpl.tracer import anpl_check, anpl_trace
import traceback
set_openai_key() set_openai_key()
task_id, logger, input_grid, output_grid = select_task() task_id, input_grid, output_grid = select_task()
logger.log("system", "intro", "a") logger = Logger(task_id, "A")
parser = ANPLParser() parser = ANPLParser()
anpl = code_input(parser, logger) anpl = code_input(parser, logger)
...@@ -22,27 +23,31 @@ def syn_anpl(anpl: ANPL): ...@@ -22,27 +23,31 @@ def syn_anpl(anpl: ANPL):
for hole in track(holes, description="Synthesizing..."): for hole in track(holes, description="Synthesizing..."):
for i in range(5): for i in range(5):
res = fun_synthesis(anpl, hole, temp=i*0.1) res = fun_synthesis(anpl, hole, temp=i*0.1)
logger.log("gpt", "syn code", res) logger.log("gpt", "syn", res)
if res: if res:
newanpl = parser.try_parse(res, from_user=False) newanpl = parser.try_parse(res, from_user=False)
if newanpl: if newanpl:
logger.log("system", "parse_gpt", "success") logger.log("system", "syn", "info: gpt returns valid code")
if not hole.startswith("_hole"): if not hole.startswith("_hole"):
if hole in newanpl.funs: if hole in newanpl.funs:
newanpl.clean(hole) newanpl.clean(hole)
else: else:
logger.log("system", "parse_gpt", "do not synthesis the specific hole") logger.log("system", "syn", "error: do not synthesis the function with specific name")
continue continue
else: else:
if newanpl.entry in anpl.funs: if newanpl.entry in anpl.funs:
logger.log("system", "parse_gpt", "synthesis codes with a used name") if newanpl.entry != hole:
logger.log("system", "syn", "error: synthesis _hole with a used name")
system_info("[yellow]Warning[\yellow] Generated Function has the same name with some function before. Perhaps you have very similar sentences?")
else:
logger.log("system", "syn", "error: chatgpt do not give a new function name")
continue continue
anpl.fill_fun(hole, newanpl) anpl.fill_fun(hole, newanpl)
break break
logger.log("system", "parse_gpt", "error") logger.log("system", "parse_gpt", "error")
if len(anpl.get_holes()) > 0: if len(anpl.get_holes()) > 0:
logger.log("system", "syn_error", "cannot synthesis code") logger.log("system", "syn", "error: cannot synthesis code")
raise NotImplementedError("Cannot Synthesis") raise NotImplementedError("Cannot Synthesis")
def io_input(anpl: ANPL, name: str, logger): def io_input(anpl: ANPL, name: str, logger):
...@@ -86,13 +91,15 @@ while not is_correct: ...@@ -86,13 +91,15 @@ while not is_correct:
logger.log("user", "trace", f"{fun_name}") logger.log("user", "trace", f"{fun_name}")
ioc = anpl_trace(anpl, fun_name, anpl.funs[anpl.entry].gloden_ios[0].inputs) ioc = anpl_trace(anpl, fun_name, anpl.funs[anpl.entry].gloden_ios[0].inputs)
if ioc.crash: if ioc.crash:
logger.log("system", "trace_err", f"{fun_name}: crash") logger.log("system", "trace", f"{fun_name}: crash")
system_info("[red]ANPL crash in this function.[/red]") system_info("[red]ANPL crash in this function.[/red]")
traceback.print_exception(ioc.exception, limit=-1)
elif len(ioc.ios) == 0: elif len(ioc.ios) == 0:
logger.log("user", "trace_err", f"{fun_name}: crash before this function") logger.log("user", "trace", f"{fun_name}: crash before this function")
system_info("[red]ANPL crash before this function.[/red]") system_info("[red]ANPL crash before this function.[/red]")
traceback.print_exception(ioc.exception, limit=-1)
else: else:
logger.log("user", "trace_ok", f"{fun_name}: show io to user") logger.log("user", "trace", f"{fun_name}: show io to user")
system_info("[green]Visual IO[/green]") system_info("[green]Visual IO[/green]")
print_IOExamples(ioc.ios) print_IOExamples(ioc.ios)
system_info("[green]Textual IO[/green]") system_info("[green]Textual IO[/green]")
...@@ -103,14 +110,15 @@ while not is_correct: ...@@ -103,14 +110,15 @@ while not is_correct:
system_info(f"Please input your code for [italic yellow]{fun_name}[/italic yellow]") system_info(f"Please input your code for [italic yellow]{fun_name}[/italic yellow]")
newanpl = code_input(parser, logger) newanpl = code_input(parser, logger)
if newanpl.entry != fun_name: if newanpl.entry != fun_name:
logger.log("system", "edit_err", f"{fun_name} {newanpl.entry} is not match") logger.log("system", "edit", f"error: {fun_name} {newanpl.entry} is not match")
system_info("[red]Function name don't match.[/red]") system_info(f"[red]Function name don't match: {fun_name} {newanpl.entry}.[/red]")
continue continue
test_anpl = deepcopy(anpl) test_anpl = deepcopy(anpl)
test_anpl.fill_fun(fun_name, newanpl) test_anpl.fill_fun(fun_name, newanpl)
try: try:
syn_anpl(test_anpl) syn_anpl(test_anpl)
except NotImplementedError: except NotImplementedError:
system_info("[red]Cannot synthesis your code[/red]")
continue continue
anpl = test_anpl anpl = test_anpl
...@@ -124,32 +132,32 @@ while not is_correct: ...@@ -124,32 +132,32 @@ while not is_correct:
reses = batch_fun_synthesis(raw_test_anpl, fun_name, 10, 0.8) # The same config as CodeT reses = batch_fun_synthesis(raw_test_anpl, fun_name, 10, 0.8) # The same config as CodeT
for res in track(reses, description="Checking"): for res in track(reses, description="Checking"):
if res is None: if res is None:
logger.log("gpt", "resyn", "nothing") logger.log("gpt", "resyn", "error: gpt return nothing")
continue continue
logger.log("gpt", "resyn", res) logger.log("gpt", "resyn", res)
newanpl = parser.try_parse(res, from_user=False) newanpl = parser.try_parse(res, from_user=False)
if newanpl is None: if newanpl is None:
logger.log("system", "resyn_err", "gpt return wrong python") logger.log("system", "resyn", "error: gpt return wrong python code")
continue continue
if fun_name not in newanpl.funs: if fun_name not in newanpl.funs:
logger.log("system", "resyn_err", "gpt doesn't synthesis hole") logger.log("system", "resyn", "error: gpt doesn't synthesis hole")
continue continue
newanpl.clean(fun_name) newanpl.clean(fun_name)
test_anpl = deepcopy(raw_test_anpl) test_anpl = deepcopy(raw_test_anpl)
test_anpl.fill_fun(fun_name, newanpl) test_anpl.fill_fun(fun_name, newanpl)
if anpl_check(test_anpl, fun_name): if anpl_check(test_anpl, fun_name, show_err=False):
logger.log("system", "resyn_ok", "code pass user's io") logger.log("system", "resyn", "info: code pass user's io")
anpl = test_anpl anpl = test_anpl
find_correct_anpl = True find_correct_anpl = True
break break
if find_correct_anpl: if find_correct_anpl:
logger.log("system", "resyn_ok", "resynthesis correct function") logger.log("system", "resyn", "info: correct")
system_info("[green]Function Correct[/green].") system_info("[green]Function Correct[/green].")
else: else:
logger.log("system", "resyn_fail", "cannot resynthesis correct function") logger.log("system", "resyn", "info: Resyn failed. Cannot resynthesis correct function")
system_info("[red]Cannot synthesis correct function.[/red].") system_info("[red]Cannot synthesis correct function.[/red].")
else: else:
logger.log("user", "remove_io", f"{fun_name}") logger.log("user", "remove_io", f"{fun_name}")
...@@ -157,13 +165,13 @@ while not is_correct: ...@@ -157,13 +165,13 @@ while not is_correct:
system_info(f"Here is all IO Examples of {fun_name}.") system_info(f"Here is all IO Examples of {fun_name}.")
print_IOExamples(ios) print_IOExamples(ios)
idx = IntPrompt.ask("Which io would you like to remove? -1 to return") idx = IntPrompt.ask("Which io would you like to remove? -1 to return")
logger.log("user", "remove_io_exit", "nothing") logger.log("user", "remove_io", "exit")
if idx != -1: if idx != -1:
if idx not in range(0, len(ios)): if idx not in range(0, len(ios)):
logger.log("system", "remove_io_err", f"{fun_name}: {idx}") logger.log("system", "remove_io", f"error: {fun_name}: {idx}")
system_info(f"[red]{fun_name} doesn't have the {idx}th IO [/red].") system_info(f"[red]{fun_name} doesn't have the {idx}th IO [/red].")
else: else:
logger.log("system", "remove_io_info", f"{fun_name}: {idx}") logger.log("system", "remove_io", f"info: {fun_name}: {idx}")
ios.pop(idx) ios.pop(idx)
continue continue
...@@ -173,10 +181,7 @@ while not is_correct: ...@@ -173,10 +181,7 @@ while not is_correct:
if is_correct: if is_correct:
system_info("[green]ANPL CORRECT[/green], and here is the code") system_info("[green]ANPL CORRECT[/green], and here is the code")
print_anpl(anpl, for_user=False) print_anpl(anpl, for_user=False)
import pickle logger.save(anpl)
time_str = time.strftime("%Y%m%d_%H%M%S")
with open(f"./log/task{task_id}_{time_str}.pkl", "wb") as f:
pickle.dump(anpl, f)
else: else:
system_info("Good luck next time.") system_info("Good luck next time.")
logger.log("system", "exit", str(is_correct)) logger.log("system", "exit", str(is_correct))
from anpl.synthesizer import raw_query, msg from anpl.synthesizer import raw_query, msg
from utils import sys_str, system_info, multiline_input, select_task, set_openai_key, rich_dumps from utils import sys_str, system_info, multiline_input, select_task, set_openai_key, rich_dumps, Logger
from rich.prompt import IntPrompt, Confirm, Prompt from rich.prompt import IntPrompt, Confirm, Prompt
import rich import rich
from anpl.sandbox import import_module_from_string from anpl.sandbox import import_module_from_string
import numpy as np import numpy as np
import time import time
import traceback
history = [] history = []
def print_msg(message): def print_msg(message):
role, text = message["role"], message["content"] role, text = message["role"], message["content"]
print(f"{role}:") rich.print(f"[blue]{role}[/blue]:")
print(text) print(text)
def print_history(): def print_history():
...@@ -18,8 +19,8 @@ def print_history(): ...@@ -18,8 +19,8 @@ def print_history():
print_msg(message) print_msg(message)
set_openai_key() set_openai_key()
task_id, logger, inp, real_out = select_task() task_id, inp, real_out = select_task()
logger.log("system", "intro", "b") logger = Logger(task_id, "B")
is_correct = False is_correct = False
while not is_correct: while not is_correct:
...@@ -73,20 +74,21 @@ while not is_correct: ...@@ -73,20 +74,21 @@ while not is_correct:
except Exception as e: except Exception as e:
logger.log("system", "check", f"crash: {e}") logger.log("system", "check", f"crash: {e}")
system_info("[red]Crash[/red]") system_info("[red]Crash[/red]")
print(e) traceback.print_exception(e, limit=-1)
continue continue
if np.array_equal(out, real_out): if np.array_equal(out, real_out):
logger.log("system", "check", f"correct") logger.log("system", "check", f"correct")
system_info("[green] Code CORRECT [/green]") system_info("[green]Code CORRECT[/green]")
time_str = time.strftime("%Y%m%d_%H%M%S") logger.save(code)
with open(f"./log/btask{task_id}_{time_str}.py", "w") as f:
f.write(code)
is_correct = True is_correct = True
else: else:
logger.log("system", "check", f"wrong") logger.log("system", "check", f"wrong")
system_info("[red] Code WRONG [/red]") system_info("[red]Code WRONG[/red]")
print("The output is") rich.print("The output is")
rich.print("[green]Visual Output[/green]")
rich.print(rich_dumps(out)) rich.print(rich_dumps(out))
rich.print("[green]Textual Output[/green]")
print(" ".join(out.__repr__().split()))
else: else:
quit_time = time.time() quit_time = time.time()
if quit_time - logger.start_time < 30 * 60: if quit_time - logger.start_time < 30 * 60:
......
...@@ -9,6 +9,8 @@ from rich.prompt import Prompt, IntPrompt ...@@ -9,6 +9,8 @@ from rich.prompt import Prompt, IntPrompt
from numpy import array from numpy import array
import openai import openai
import time import time
import os
import pickle
colors = ["#000000", "#0000FF", "#FF0000", "#008000", "#FFFF00", colors = ["#000000", "#0000FF", "#FF0000", "#008000", "#FFFF00",
"#808080", "#FFC0CB", "#FFA500", "#008080", "#800000"] "#808080", "#FFC0CB", "#FFA500", "#008080", "#800000"]
...@@ -105,14 +107,12 @@ def code_input(parser, logger): ...@@ -105,14 +107,12 @@ def code_input(parser, logger):
while anpl is None: while anpl is None:
user_input = multiline_input() user_input = multiline_input()
user_input = user_input
logger.log("user", "enter code", user_input) logger.log("user", "enter code", user_input)
anpl = parser.try_parse(user_input) anpl = parser.try_parse(user_input)
# system_info(f"[green]ANPL successfully parsed[/green]")
if anpl is None: if anpl is None:
logger.log("system", "parse_user", "error") logger.log("system", "parser", "user enter wrong code")
system_info("[red]Your code is not correct. Please try again.[/red]") system_info("[red]Your code is not correct. Please try again.[/red]")
logger.log("system", "parse_user", "success") logger.log("system", "parser", "user enter correct code")
return anpl return anpl
def value_input(param, logger): def value_input(param, logger):
...@@ -149,11 +149,14 @@ def set_openai_key(): ...@@ -149,11 +149,14 @@ def set_openai_key():
class Logger: class Logger:
def __init__(self, task_id): def __init__(self, task_id, system_name):
timestr = time.strftime("%Y%m%d_%H%M%S") timestr = time.strftime("%m%d%H%M%S")
self.system_name = system_name
self.start_time = time.time() self.start_time = time.time()
self.file_name = f"./log/task_{task_id}_{timestr}.log" self.folder_path = f"./log/task{task_id}_{system_name}"
self.log("system", "start", "nothing") if not os.path.exists(self.folder_path):
os.makedirs(self.folder_path)
self.file_name = f"{self.folder_path}/task_{system_name}_{task_id}_{timestr}.log"
def log(self, role, action, content): def log(self, role, action, content):
s = {"role": role, "action": action, "content": content, "time": time.time()} s = {"role": role, "action": action, "content": content, "time": time.time()}
...@@ -161,6 +164,17 @@ class Logger: ...@@ -161,6 +164,17 @@ class Logger:
f.write(json.dumps(s)) f.write(json.dumps(s))
f.write("\n") f.write("\n")
def save(self, object):
    """Write the final artifact of a session into this logger's folder.

    For system "A" the object is pickled to ``task<id>_<time>.pkl``; for
    system "B" the object (text) is written to ``btask<id>_<time>.py``.

    Args:
        object: the value to persist (passed to ``pickle.dump`` for "A",
            to ``f.write`` for "B").

    Raises:
        NotImplementedError: if ``self.system_name`` is neither "A" nor "B".
    """
    # The constructor stores folder_path and system_name but not task_id,
    # so reading self.task_id directly raises AttributeError. Use the
    # attribute when present; otherwise recover the id from the folder
    # name, which is built as "task<id>_<system_name>".
    task_id = getattr(self, "task_id", None)
    if task_id is None:
        folder_name = os.path.basename(os.path.normpath(self.folder_path))
        task_id = folder_name.removeprefix("task").removesuffix(f"_{self.system_name}")

    timestr = time.strftime("%m%d%H%M%S")
    if self.system_name == "A":
        with open(f"{self.folder_path}/task{task_id}_{timestr}.pkl", "wb") as f:
            pickle.dump(object, f)
    elif self.system_name == "B":
        with open(f"{self.folder_path}/btask{task_id}_{timestr}.py", "w") as f:
            f.write(object)
    else:
        raise NotImplementedError("Unknown System")
def select_task(): def select_task():
task_id = IntPrompt.ask(sys_str + "Which problem do you want to solve?") task_id = IntPrompt.ask(sys_str + "Which problem do you want to solve?")
...@@ -171,5 +185,5 @@ def select_task(): ...@@ -171,5 +185,5 @@ def select_task():
input_grid = np.array(data["input"]) input_grid = np.array(data["input"])
output_grid = np.array(data["output"]) output_grid = np.array(data["output"])
return task_id, Logger(task_id), input_grid, output_grid return task_id, input_grid, output_grid
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment