import re
from codecritic.utils.json import load_json, save_json
from transformers import AutoTokenizer

# Matches a ```python fenced block and captures everything between the opening
# tag and the closing fence (re.S lets "." span newlines).
codeblock_pattern = re.compile(r"```python(.+?)```", re.S)

# Wraps a code string in a ```python fenced block followed by a newline.
code_template = (
    "```python\n"
    "{}\n"
    "```\n"
)


def extract_code(text: str):
    """Collect the contents of all ```python fenced blocks found in *text*.

    Each captured block is stripped of surrounding whitespace and the blocks
    are joined with a single newline. An empty string is returned when *text*
    contains no matching block.
    """
    snippets = (found.strip() for found in codeblock_pattern.findall(text))
    # Joining zero snippets already yields "", so no separate empty case is needed.
    return "\n".join(snippets)


def mk_preference_pair(instruction, chosen_code, rejected_code):
    """Build one preference record from a prompt and two candidate answers.

    Returns a dict with a single-turn user conversation under "messages" and
    the preferred / dispreferred assistant replies under "chosen" and
    "rejected". Both code strings are wrapped with ``code_template``.
    """

    def _assistant_reply(code):
        # Wrap the raw code with the module-level template before storing it.
        return {"role": "assistant", "content": code_template.format(code)}

    prompt_turns = [{"role": "user", "content": instruction}]
    chosen_reply = _assistant_reply(chosen_code)
    rejected_reply = _assistant_reply(rejected_code)
    return {
        "messages": prompt_turns,
        "chosen": chosen_reply,
        "rejected": rejected_reply,
    }


# Ordering convention for the conversation: human and observation turns belong
# in odd positions, llm turns in even positions.
def mk_sft_item(messages):
    """Wrap a list of chat messages into a single SFT record.

    The list is stored as-is (not copied) under the "messages" key; the
    ordering convention above is not checked here.
    """
    sft_item = dict(messages=messages)
    return sft_item


def mk_critic_qa(instruction, code):
    """Return a two-turn conversation: the user instruction and the code answer.

    *code* is used verbatim as the assistant content, so the caller is
    expected to have enclosed it in a markdown code block already.
    """
    question = dict(role="user", content=instruction)
    answer = dict(role="assistant", content=code)
    return [question, answer]


# Question used by mk_critic_verify as the user turn; the assistant reply it
# pairs with this prompt is either "Yes" or "No".
JUDGE_PROMPT = "Is the code correct (Yes/No)?"


def mk_critic_verify(answer=None):
    """Build the verification turn(s) asking whether the code is correct.

    With ``answer=None`` only the user question (``JUDGE_PROMPT``) is
    returned. Otherwise an assistant turn is appended whose content is "Yes"
    when *answer* is truthy and "No" when it is falsy.
    """
    turns = [{"role": "user", "content": JUDGE_PROMPT}]
    if answer is None:
        # No label supplied: leave the question unanswered.
        return turns

    verdict = "Yes" if answer else "No"
    turns.append({"role": "assistant", "content": verdict})
    return turns


def save_dataset(llamafactory_path, dataset_info, dataset):
    """Register a dataset in ``dataset_info.json`` and write its data file.

    Args:
        llamafactory_path: Directory whose ``data/`` subfolder holds
            ``dataset_info.json`` and receives the dataset file.
        dataset_info: Mapping with exactly one entry, ``name -> description``,
            where the description contains a ``"file_name"`` key giving the
            dataset path relative to ``data/``.
        dataset: The object written to the dataset file via ``save_json``.

    Raises:
        AssertionError: If *dataset_info* does not contain exactly one entry.
        KeyError: If the entry has no ``"file_name"`` key.
    """
    # Validate and resolve everything from the input *before* touching disk, so
    # a malformed dataset_info cannot leave the registry updated without a
    # matching dataset file.
    assert len(dataset_info) == 1, "dataset_info must contain exactly one dataset"
    dataset_name = next(iter(dataset_info))
    dataset_relative_path = dataset_info[dataset_name]["file_name"]

    # Merge the new entry into the shared registry (overwrites a same-named entry).
    all_dataset_info_path = f"{llamafactory_path}/data/dataset_info.json"
    all_dataset_info = load_json(all_dataset_info_path)
    all_dataset_info |= dataset_info
    save_json(all_dataset_info, all_dataset_info_path, indent=4)

    save_json(dataset, f"{llamafactory_path}/data/{dataset_relative_path}")


def get_score_token_id(model_path, token_str="Yes"):
    """Return the id of the single token that *token_str* encodes to.

    The tokenizer is loaded from *model_path* with
    ``AutoTokenizer.from_pretrained`` and *token_str* is encoded without
    special tokens. The encoding is asserted to be exactly one token long.
    """
    encoded_ids = AutoTokenizer.from_pretrained(model_path).encode(
        token_str, add_special_tokens=False
    )
    # The score is read from one token, so a multi-token encoding is rejected.
    assert len(encoded_ids) == 1
    return encoded_ids[0]


def mk_critic_reason(codedit, explanation):
    """Build the two-turn exchange asking for, and giving, a fix with reasons.

    The user turn is a fixed request to edit the code in diff format. The
    assistant turn embeds *codedit* inside a ```diff fenced block followed by
    an "Explanation" section containing *explanation*.
    """
    answer_text = (
        "**Edited Code (in diff format):**\n"
        "```diff\n"
        f"{codedit}\n"
        "```\n"
        "\n"
        "**Explanation:**\n"
        f"{explanation}\n"
    )
    request_text = (
        "Edit your code in diff format to fix any issues and explain the changes."
    )
    return [
        {"role": "user", "content": request_text},
        {"role": "assistant", "content": answer_text},
    ]


# Sentinel string; not referenced by any function visible in this file.
# NOTE(review): presumably imported elsewhere as a delimiter unlikely to occur
# in real text — confirm against callers.
SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"
