Commit 048ea73a by nanziyuan

fix bugs

parent ee225d88
......@@ -35,27 +35,28 @@ if __name__ == "__main__":
ds[item["task_id"]][item["solution_id"]] = item
# Step1 Generate hints
hint_prompts = []
for pair in pairinfo:
task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]
prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.level)
hint_prompts.append(prompt)
# hint_prompts = []
# for pair in pairinfo:
# task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
# chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]
# prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.level)
# hint_prompts.append(prompt)
sampling_params = SamplingParams(
n=1,
temperature=0,
top_p=0.95,
max_tokens=2048,
)
# sampling_params = SamplingParams(
# n=1,
# temperature=0,
# top_p=0.95,
# max_tokens=2048,
# )
worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
hint_responses = model_map(worker, hint_prompts, args.tp)
# worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
# hint_responses = model_map(worker, hint_prompts, args.tp)
hints = [promptlib.postprocess_to_hint(x) for x in hint_responses]
# hints: {"dataset"..., "task_id": ..., "solution_id": ..., "hints": ...}
# hints = [promptlib.postprocess_to_hint(x) for x in hint_responses]
# # hints: {"dataset"..., "task_id": ..., "solution_id": ..., "hints": ...}
save_jsonl(hints, args.output + ".hints")
# save_jsonl(hints, args.output + ".hints")
hints = load_jsonl(args.output + ".hints")
hints_dict = defaultdict(dict)
for item in hints:
......@@ -86,6 +87,7 @@ if __name__ == "__main__":
worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
reason_responses = model_map(worker, reason_prompts, args.tp)
save_jsonl(reason_responses, args.output + ".reason")
# Step3 Verify reasoning results
# add prompt "correct the code based on the reasoning"
......
......@@ -152,10 +152,7 @@ def remove_hint(item):
def extract_conclusion_and_code(response):
# Extract conclusion
if 'Conclusion:' not in response:
conclusion = None
print("not found conclusion\n{}".format(response))
else:
try:
conclusion_line = [line for line in response.split('\n') if line.startswith('Conclusion:')][0]
conclusion_str = conclusion_line.split(': ')[1].strip().lower()
......@@ -166,6 +163,9 @@ def extract_conclusion_and_code(response):
else:
print("llm doesn't draw to a conclusion\n{}".format(response))
conclusion = None
except Exception as e:
print("not found conclusion\n{}\n{}".format(response, e))
conclusion = None
# Extract corrected code if conclusion is 'No'
corrected_code = ""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment