# Additional experiment:
# Does reasoning really work? Let's verify step by step.
import argparse
from itertools import chain
from utils import load_json, extract_code, code_template

from codecritic.utils.data import mk_message, mk_sft_item, mk_critic_verify, mk_sft_dataset_info, save_dataset, SPLITTER
from codecritic.utils.vllm import vllm_chatcomplete

# Instruction asking the model to verify its own code step by step
# (COV — presumably "chain of verification"). It is used inside the sampling
# prompt built by mk_cov_prompt and, on its own, as the replacement user turn
# in convert_cov_to_cov_dataset.
COV_PROMPT = "Please verify your code step by step using Markdown code blocks. After each step, explain whether it's correct or not, and if not, explain the issue."

# Illustration of the expected answer layout: each step is a code snippet
# followed by a "Correct."/"Incorrect." explanation. It is appended to the
# sampling prompt by mk_cov_prompt; convert_cov_to_cov_dataset later replaces
# that whole user turn with COV_PROMPT, so the example does not appear in the
# final dataset.
COV_EXAMPLE = """\
** Example RETURN FORMAT **

```python
def add_numbers(a, b):
    return a + b

result = add_numbers(5, '10')
```

1. **Code:**
```python
def add_numbers(a, b):
    return a + b
```
**Explanation:** Correct. This defines a function `add_numbers` that takes two arguments and returns their sum.

2. **Code:**
```python
result = add_numbers(5, '10')
```
**Explanation:** Incorrect. The second argument is a string (`'10'`), which will cause a TypeError when trying to add it to an integer.
"""

# Label hints that open the verification prompt. convert_cov_to_cov_dataset
# recovers the label by substring search, testing CORRECT_PROMPT first, so
# CORRECT_PROMPT must never occur inside a prompt built with INCORRECT_PROMPT.
# (It currently does not: "Your code is correct." is not a substring of
# "Your code is incorrect.", nor of COV_PROMPT / COV_EXAMPLE.)
CORRECT_PROMPT = "Your code is correct."
INCORRECT_PROMPT = "Your code is incorrect."

def mk_cov_prompt(is_correct):
    """Build the follow-up turns that ask the model to verify its own code.

    Args:
        is_correct: Label of the code under review. If truthy, the user turn
            opens with CORRECT_PROMPT; otherwise it opens with INCORRECT_PROMPT.

    Returns:
        A two-element list of chat messages:
        1. A user turn containing the label hint, COV_PROMPT and COV_EXAMPLE,
           separated by newlines.
        2. A pre-filled assistant turn that ends with SPLITTER.
    """
    if is_correct:
        label_hint = CORRECT_PROMPT
    else:
        label_hint = INCORRECT_PROMPT

    user_turn = {
        "role": "user",
        "content": "\n".join([label_hint, COV_PROMPT, COV_EXAMPLE]),
    }
    # NOTE(review): SPLITTER presumably marks where generated text gets
    # attached to this pre-filled assistant turn — confirm in
    # codecritic.utils.data / vllm_chatcomplete.
    assistant_prefix = {
        "role": "assistant",
        "content": "Here's a step-by-step verification of the code." + SPLITTER,
    }
    return [user_turn, assistant_prefix]


def convert_preference_to_vot_prompt(item):
    """Expand one preference pair into two self-verification prompts.

    Args:
        item: Preference record. Reads ``item["messages"][0]["content"]``
            (the problem statement) and ``item["chosen"]["content"]`` /
            ``item["rejected"]["content"]`` (the two responses).

    Returns:
        Tuple ``(chosen_item, rejected_item)`` of SFT items. Each item contains
        the problem, the response's code re-wrapped in ``code_template``, and
        the verification turns from mk_cov_prompt. The chosen response is
        labelled correct and the rejected one incorrect.
    """
    question = item["messages"][0]["content"]
    raw_responses = (item["chosen"]["content"], item["rejected"]["content"])

    # Keep only the code from each response, re-wrapped in the shared template.
    formatted = [code_template.format(extract_code(resp)) for resp in raw_responses]

    return tuple(
        mk_sft_item(mk_message(question, code) + mk_cov_prompt(label))
        for code, label in zip(formatted, (True, False))
    )


def convert_cov_to_cov_dataset(item):
    """Turn a sampled verification conversation into a final SFT item.

    The label (correct / incorrect) is recovered from the hint that
    mk_cov_prompt placed in the third message. That message is then replaced
    by the bare COV_PROMPT, which drops the label hint and the format example.
    Finally, the turns from ``mk_critic_verify(label)`` are appended.

    ``item`` is modified in place and also returned.

    Args:
        item: Conversation record whose ``item["messages"][2]`` is the
            verification user turn built by mk_cov_prompt.

    Returns:
        The same ``item``, rewritten.

    Raises:
        ValueError: If the third message contains neither CORRECT_PROMPT nor
            INCORRECT_PROMPT. ``item`` is left untouched in that case.
    """
    user_content = item["messages"][2]["content"]

    # Determine the label *before* mutating ``item`` so that a bad record
    # raises without leaving the item half-rewritten.
    # CORRECT_PROMPT is checked first; it is not a substring of INCORRECT_PROMPT.
    if CORRECT_PROMPT in user_content:
        is_correct = True
    elif INCORRECT_PROMPT in user_content:
        is_correct = False
    else:
        raise ValueError(
            f"Invalid prompt: expected {CORRECT_PROMPT!r} or {INCORRECT_PROMPT!r} "
            f"in messages[2], got {user_content[:80]!r}"
        )

    # Presumably done so the training data does not reveal the label:
    # the hint and example are replaced with the plain instruction.
    item["messages"][2]["content"] = COV_PROMPT
    item["messages"] += mk_critic_verify(is_correct)
    return item


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate step-by-step self-verification SFT data from a preference dataset."
    )
    # All paths/names are mandatory: leaving one out would otherwise pass None
    # into load_json / vllm_chatcomplete / save_dataset and fail obscurely.
    parser.add_argument("--model", type=str, required=True,
                        help="Model passed to vllm_chatcomplete for sampling.")
    parser.add_argument("--preference_dataset", type=str, required=True,
                        help="Preference dataset file, read with load_json.")
    parser.add_argument("--llamafactory", type=str, required=True,
                        help="Output location passed to save_dataset (presumably the LLaMA-Factory root).")
    parser.add_argument("--dataset_name", type=str, required=True,
                        help="Name used to register the SFT dataset.")
    parser.add_argument("--max_tokens", type=int, default=2048,
                        help="Maximum tokens generated per verification (default: 2048).")
    args = parser.parse_args()

    preference_dataset = load_json(args.preference_dataset)
    # Each preference pair yields (chosen, rejected) prompts; flatten them into
    # one list so all prompts are sampled in a single batch.
    cov_prompts = list(
        chain.from_iterable(convert_preference_to_vot_prompt(x) for x in preference_dataset)
    )

    # Deterministic (temperature 0) decoding, one sample per prompt.
    sampling_params = dict(n=1, temperature=0.0, max_tokens=args.max_tokens)
    covs = vllm_chatcomplete(args.model, cov_prompts, sampling_params)
    dataset = [convert_cov_to_cov_dataset(cov) for cov in covs]

    dataset_info = mk_sft_dataset_info(args.dataset_name)
    save_dataset(args.llamafactory, dataset_info, dataset)
