Commit d20b66c9 by nanziyuan

step1: evaluate_code running twice & fix small bugs

parent 0bd552ba
......@@ -67,8 +67,10 @@ def test_generation(args, debug=False):
print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
finally:
assert isinstance(curr_res, list)
problem_result = np.asarray(curr_res)
code_sample["eval_result"] = curr_res
code_sample["eval_result"] = bool(np.all(problem_result > 0))
code_sample["testcase"] = curr_res
return code_sample
......@@ -83,7 +85,7 @@ def get_apps_item(item, apps):
def evaluate_code_samples(code_samples, apps):
args = [(get_apps_item(sample, apps), sample) for sample in code_samples]
cpu_num = multiprocessing.cpu_count()
chunksize = len(code_samples) // (cpu_num * 5)
chunksize = max(len(code_samples) // (cpu_num * 5), 1)
results = process_map(
test_generation, args, max_workers=cpu_num, chunksize=chunksize
)
......@@ -102,29 +104,36 @@ def evaluate_incorrect_code_samples_again(results, apps, loop_num):
"""
maybe_incorrect_lst, correct_lst = [], []
for item in results:
if any(x in item["eval_result"] for x in (-1, -2)):
if any(x in item["testcase"] for x in (-1, -2)):
maybe_incorrect_lst.append(item)
else:
correct_lst.append(item)
print(f"maybe incorrect lst size: {len(maybe_incorrect_lst)}")
for _ in range(loop_num):
if len(maybe_incorrect_lst) == 0:
break
new_results = evaluate_code_samples(maybe_incorrect_lst, apps)
for i, (old_item, new_item) in enumerate(zip(maybe_incorrect_lst, new_results)):
old_eval, new_eval = old_item["eval_results"], new_item["eval_results"]
print(f"maybe incorrect lst size: {len(maybe_incorrect_lst)}")
check_lst = []
for i in range(len(new_results)):
old_item, new_item = maybe_incorrect_lst[i], new_results[i]
old_eval, new_eval = old_item["eval_result"], new_item["eval_result"]
if old_eval == new_eval:
item = maybe_incorrect_lst.pop(i)
correct_lst.append(item)
correct_lst.append(old_item)
else:
maybe_incorrect_lst[i] = new_item
check_lst.append(new_item)
print(old_item["problem_id"], old_eval, new_item["problem_id"], new_eval)
maybe_incorrect_lst = check_lst
if len(results) != len(correct_lst):
save_jsonl(maybe_incorrect_lst, "debug.jsonl")
# raise ValueError("cannot correctly evaluate codes")
print("cannot correctly evalute code. see debug.jsonl")
if len(maybe_incorrect_lst) < 5:
correct_lst.extend(maybe_incorrect_lst)
assert len(results) == len(correct_lst), "cannot correctly evaluate codes" + str(
maybe_incorrect_lst
)
return correct_lst
......@@ -132,7 +141,11 @@ def evaluate(code_sample_path, dataset_path, output_path):
code_samples = load_jsonl(code_sample_path)
apps = load_dataset(dataset_path)
results = evaluate_code_samples(code_samples, apps)
results = evaluate_incorrect_code_samples_again(results, apps, 5)
for item in results:
item["testcase"] = item["eval_result"]
item["eval_result"] = bool(np.all(np.asarray(item["testcase"]) > 0))
results = evaluate_incorrect_code_samples_again(results, apps, 10)
save_jsonl(results, output_path)
......
......@@ -72,7 +72,7 @@ def mk_edit_distance_dataset(all_pairs, k, n, is_max=True):
problem_contributions = defaultdict(int)
preference_pairs, pairs_metadata = [], []
for distance, problem_id, instr, pair in all_pairs:
if len(preference_pairs) > k:
if len(preference_pairs) >= k:
break
is_code_used = (pair[0] in code_usages[problem_id]) or (
......@@ -100,10 +100,10 @@ if __name__ == "__main__":
all_edit_distance_pairs, 10 * 1000, 5, is_max=True
)
dataset_info = mk_dataset_info("apps_max_edit_distance_prefrence")
save_jsonl(
save_json(
metadata, cfg["preference_dataset"]["max_edit_distance"]["metadata_path"]
)
save_jsonl(
save_json(
preference_pairs,
cfg["preference_dataset"]["max_edit_distance"]["preference_dataset_path"],
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment