Commit 9aeb08e2 by nzy

Fix Dead Loop and multiline input

parent 94c564ec
......@@ -15,7 +15,7 @@ class IOExample:
try:
assert_equal(astuple(self), astuple(__value))
return True
except AssertionError:
except Exception:
return False
else:
return False
......
......@@ -61,7 +61,10 @@ def anpl_check(anpl: ANPL, fun_name: str, show_err: bool=True) -> bool:
for io in anpl.funs[fun_name].gloden_ios:
ioc = anpl_trace(anpl, fun_name, io.inputs, fun_name)
if show_err and ioc.exception:
try:
traceback.print_exception(ioc.exception, limit=-1)
except Exception:
print(ioc.exception)
if ioc.crash or len(ioc.ios) < 1:
return False
# if len(ioc.ios) > 1:
......
......@@ -2,7 +2,7 @@ from anpl.synthesizer import raw_query, msg
from utils import sys_str, system_info, multiline_input, select_task, set_openai_key, rich_dumps, Logger
from rich.prompt import IntPrompt, Confirm, Prompt
import rich
from anpl.sandbox import import_module_from_string
from anpl.sandbox import import_module_from_string, timeout
import numpy as np
import time
import traceback
......@@ -72,11 +72,15 @@ while not is_correct:
try:
m = import_module_from_string(code)
inp_t = deepcopy(inp)
out = m.main(inp_t)
f = timeout(timeout=1)(m.main)
out = f(inp_t)
except Exception as e:
logger.log("system", "check", f"crash: {e}")
system_info("[red]Crash[/red]")
try:
traceback.print_exception(e, limit=-1)
except Exception:
print(e)
continue
if np.array_equal(out, real_out):
logger.log("system", "check", f"correct")
......
......@@ -11,7 +11,7 @@ import openai
import time
import os
import pickle
import sys
import prompt_toolkit
colors = ["#000000", "#0000FF", "#FF0000", "#008000", "#FFFF00",
"#808080", "#FFC0CB", "#FFA500", "#008080", "#800000"]
......@@ -61,7 +61,8 @@ def rich_dumps(obj):
def multiline_input():
buffer = sys.stdin.read()
system_info("Press [Esc] followed by [Enter] to accept input.")
buffer = prompt_toolkit.prompt(">", multiline=True, wrap_lines=False, mouse_support=True)
return buffer
sys_str = "[bold red]SYSTEM: [/bold red]"
......@@ -72,6 +73,9 @@ def print_anpl(anpl, for_user=False):
print(Syntax(anpl.to_python(for_user=for_user), "python", line_numbers=True))
def print_text_IOExamples(ios: list[IOExample]):
if len(ios) > 5:
system_info("Too many IOs. Only show the first 5 IO")
ios = ios[:5]
for i, io in enumerate(ios):
print(f"The {i}th IO exmaples is")
for k in io.inputs.keys():
......@@ -82,6 +86,9 @@ def print_text_IOExamples(ios: list[IOExample]):
def print_IOExamples(ios: list[IOExample]):
# TODO assume len of output is all the same
if len(ios) > 5:
system_info("Too many IOs. Only show the first 5 IO")
ios = ios[:5]
table = Table(title="Inputs & Output")
assert len(ios) > 0, "don't have IO examples"
fio = ios[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment