Commit 55cb6deb by nanziyuan

additional exp: fix bugs.

parent 04e7f84d
@@ -2,21 +2,45 @@
 # Does reasoning really work? Let's verify step by step.
 import argparse
 from itertools import chain
 from pathlib import Path
-from utils import load_json
-from utils_dataset import mk_critic_qa, mk_sft_item, mk_sft_dataset_info, save_dataset
+from utils import load_json, extract_code, code_template
+from utils_dataset import mk_critic_qa, mk_sft_item, mk_sft_dataset_info, save_dataset, SPLITTER
 from utils_vllm import vllm_chatcomplete
-COV_PROMPT = "Let's verify step by step."
+COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."
+COV_EXAMPLE = """\
+** Example RETURN FORMAT **
+```python
+def add_numbers(a, b):
+    return a + b
+result = add_numbers(5, '10')
+```
+1. **Code:**
+   ```python
+   def add_numbers(a, b):
+       return a + b
+   ```
+   **Explanation:** Correct. This defines a function `add_numbers` that takes two arguments and returns their sum.
+2. **Code:**
+   ```python
+   result = add_numbers(5, '10')
+   ```
+   **Explanation:** Incorrect. The second argument is a string (`'10'`), which will cause a TypeError when trying to add it to an integer.
+"""
 def mk_cov_prompt(is_correct):
     if is_correct:
-        prompt1 = "This code is correct."
+        prompt1 = "Your code is correct."
     else:
-        prompt1 = "This code is incorrect."
+        prompt1 = "Your code is incorrect."
-    return [{"role": "user", "content": prompt1 + " " + COV_PROMPT}]
+    return [{"role": "user", "content": prompt1 + "\n" + COV_PROMPT + '\n' + COV_EXAMPLE},
+            {"role": "assistant", "content": "Here's a step-by-step verification of the code." + SPLITTER}]
 def convert_preference_to_vot_prompt(item):
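The heart of this hunk is response pre-filling: `mk_cov_prompt` no longer returns a single user turn but a user turn plus a deliberately unfinished assistant turn ending in `SPLITTER`, so the model continues a fixed prefix instead of opening its own reply. A minimal sketch of the returned shape, with placeholders standing in for the full `COV_PROMPT` and `COV_EXAMPLE` strings:

```python
# Sketch of mk_cov_prompt(True)'s return value after this change; the
# <COV_PROMPT>/<COV_EXAMPLE> placeholders stand in for the strings above.
SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"

cov_messages = [
    {"role": "user", "content": "Your code is correct.\n<COV_PROMPT>\n<COV_EXAMPLE>"},
    # An unfinished assistant turn: the text before SPLITTER is a fixed prefix
    # the model will continue (see the utils_vllm changes below).
    {"role": "assistant", "content": "Here's a step-by-step verification of the code." + SPLITTER},
]
```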
@@ -24,6 +48,9 @@ def convert_preference_to_vot_prompt(item):
     chosen = item["chosen"]["content"]
     rejected = item["rejected"]["content"]
+    chosen = code_template.format(extract_code(chosen))
+    rejected = code_template.format(extract_code(rejected))
     messages1 = mk_critic_qa(message, chosen) + mk_cov_prompt(True)
     messages2 = mk_critic_qa(message, rejected) + mk_cov_prompt(False)
     return mk_sft_item(messages1), mk_sft_item(messages2)
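This hunk leans on two helpers imported from `utils` whose definitions are not part of the diff. Purely as an illustration of what they plausibly do (hypothetical stand-ins, not the real code): `extract_code` pulls the code block out of a raw chat completion and `code_template` re-wraps it, so chosen and rejected answers reach the critic in one uniform format:

```python
import re

# Hypothetical stand-ins for the utils helpers (the real definitions are not
# shown in this commit). extract_code grabs the first fenced code block from
# a completion; code_template re-wraps bare code in a uniform fence.
code_template = "```python\n{}\n```"

def extract_code(text: str) -> str:
    match = re.search(r"```(?:python)?\n(.*?)```", text, re.DOTALL)
    return match.group(1).strip() if match else text.strip()
```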
@@ -43,9 +70,10 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     preference_dataset = load_json(args.preference_dataset)
-    cov_prompts = list(chain(*convert_preference_to_vot_prompt(preference_dataset)))
+    cov_prompts = [convert_preference_to_vot_prompt(x) for x in preference_dataset]
+    cov_prompts = list(chain(*cov_prompts))
 
-    sampling_params = dict(n=1, temperature=0.8, max_tokens=2048)
+    sampling_params = dict(n=1, temperature=0.0, max_tokens=2048)
     covs = vllm_chatcomplete(args.model, cov_prompts, sampling_params)
     dataset = list(map(convert_cov_to_cov_dataset, covs))
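Two fixes land in `__main__`. First, the old call handed the entire dataset to `convert_preference_to_vot_prompt`, which expects a single preference item; the rewrite maps per item and then flattens the resulting (chosen, rejected) pairs. Second, `temperature` drops from 0.8 to 0.0, making the verification pass greedy and reproducible. The flattening step in isolation:

```python
from itertools import chain

# Each preference item yields a (chosen_prompt, rejected_prompt) pair;
# chain(*pairs) flattens the pairs into one list for batched generation.
pairs = [("chosen_1", "rejected_1"), ("chosen_2", "rejected_2")]
flat = list(chain(*pairs))
assert flat == ["chosen_1", "rejected_1", "chosen_2", "rejected_2"]
```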
@@ -86,7 +86,7 @@ def save_dataset(llamafactory_path, dataset_info, dataset):
     save_json(all_dataset_info, all_dataset_info_path, indent=4)
 
     assert len(dataset_info.keys()) == 1
-    dataset_name = dataset_info.keys()[0]
+    dataset_name = list(dataset_info.keys())[0]
     dataset_relative_path = dataset_info[dataset_name]["file_name"]
     save_json(dataset, f"{llamafactory_path}/data/{dataset_relative_path}")
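The one-line fix here is a standard Python 3 gotcha: `dict.keys()` returns a view object, which is iterable but not subscriptable, so indexing it raises a `TypeError` until it is materialized into a list:

```python
info = {"cov_dataset": {"file_name": "cov.json"}}

# info.keys()[0]              # TypeError: 'dict_keys' object is not subscriptable
name = list(info.keys())[0]   # works: "cov_dataset"
# next(iter(info)) is an equivalent fix that skips the intermediate list.
assert name == "cov_dataset"
```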
@@ -111,3 +111,5 @@ def mk_critic_reason(codedit, explanation):
     """
     llm_answer = {"role": "assistant", "content": llm_answer_content}
     return [user_question, llm_answer]
+
+SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"
@@ -5,7 +5,7 @@ import multiprocessing
 from itertools import chain
 from functools import partial
 from utils import load_jsonl, save_jsonl
+from utils_dataset import SPLITTER
 import numpy as np
 
 def generate_worker(cuda_device, prompts, model_path, sampling_params):
@@ -26,6 +26,7 @@ def generate_worker(cuda_device, prompts, model_path, sampling_params):
     )
     text_prompts = [tokenizer.apply_chat_template(item["messages"], tokenize=False, add_generation_prompt=True) for item in prompts]
+    text_prompts = [prompt.split(SPLITTER)[0] if SPLITTER in prompt else prompt for prompt in text_prompts]
     outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=True)
     result = []
@@ -34,7 +35,11 @@ def generate_worker(cuda_device, prompts, model_path, sampling_params):
         generated_text = response.text
         messages, newitem = item["messages"].copy(), item.copy()
-        messages.append({"role": "assistant", "content": generated_text})
+        if SPLITTER in messages[-1]["content"]:
+            raw_content = messages[-1]["content"].split(SPLITTER)[0]
+            messages[-1]["content"] = raw_content + generated_text
+        else:
+            messages.append({"role": "assistant", "content": generated_text})
         newitem["messages"] = messages
         result.append(newitem)
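Taken together, the two `utils_vllm` hunks give the sentinel a round trip: before generation, the templated prompt is cut at `SPLITTER`, which drops the chat template's end-of-turn tokens and leaves an open assistant turn ending in the hand-written prefix; after generation, the continuation is spliced back onto that prefix and the sentinel disappears from the saved messages. A self-contained sketch of the splice (note that `item["messages"].copy()` is a shallow copy, so the in-place edit of `messages[-1]` also touches the original message dict):

```python
# Round trip of the SPLITTER mechanism, without vLLM. The assistant turn is
# truncated at the sentinel before generation and re-joined afterwards.
SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"

messages = [
    {"role": "user", "content": "Your code is correct. Please verify ..."},
    {"role": "assistant", "content": "Here's a step-by-step verification of the code." + SPLITTER},
]

generated_text = "\n1. **Code:** ..."  # stand-in for response.text

# Splice the model's continuation onto the human-written prefix.
if SPLITTER in messages[-1]["content"]:
    raw_content = messages[-1]["content"].split(SPLITTER)[0]
    messages[-1]["content"] = raw_content + generated_text

assert SPLITTER not in messages[-1]["content"]
assert messages[-1]["content"].startswith("Here's a step-by-step")
```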