import os

import numpy as np
from vllm import LLM, SamplingParams

# Sentinel that a caller may embed in the content of the final chat message to
# ask generate_worker to continue that message rather than start a new
# assistant turn: the rendered prompt is cut at the sentinel's first
# occurrence, and the generated text is appended to the content preceding it.
SPLITTER = "__I_wish_it_were_weekends_all_the_time.__"


def generate_worker(cuda_device, prompts, model_path, sampling_params):
    """Generate chat completions for ``prompts`` with a vLLM model.

    Sets ``CUDA_VISIBLE_DEVICES`` for the whole process before building the
    engine, and shards the model across all listed GPUs with tensor
    parallelism (presumably each worker runs in its own process -- confirm
    against the caller).

    A prompt whose last message content contains ``SPLITTER`` is treated as a
    partial message to be continued: the rendered prompt is cut at the
    sentinel, and the generated text is appended to the text preceding the
    sentinel instead of being added as a new assistant message.

    Args:
        cuda_device: Sequence of GPU ids (str or int) to make visible; its
            length is used as the tensor-parallel size.
        prompts: Sequence of dicts, each with a ``"messages"`` list of chat
            messages (``{"role": ..., "content": ...}``). All other keys are
            copied through to the results.
        model_path: Model name or path passed to ``vllm.LLM``.
        sampling_params: ``vllm.SamplingParams`` used for generation. With
            ``n > 1`` one result is produced per sampled completion.

    Returns:
        List of dicts, one per (prompt, completion) pair, in prompt order.
        Each is a shallow copy of the input item whose ``"messages"`` is a new
        list ending in the completed assistant message. The input ``prompts``
        are not modified.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, cuda_device))

    llm = LLM(
        model=model_path,
        seed=42,
        max_model_len=8 * 1024,
        swap_space=16,
        tensor_parallel_size=len(cuda_device),
    )

    tokenizer = llm.get_tokenizer()

    def messages_to_text(messages):
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Drop everything from the sentinel on (including whatever the chat
        # template appended after the partial message) so the model continues
        # it. Returns the full text unchanged when the sentinel is absent.
        return text.partition(SPLITTER)[0]

    text_prompts = [messages_to_text(item["messages"]) for item in prompts]

    outputs = llm.generate(text_prompts, sampling_params=sampling_params, use_tqdm=True)

    results = []
    for item, output in zip(prompts, outputs):
        for response in output.outputs:
            generated_text = response.text

            # New list so the caller's messages list is left untouched.
            messages = list(item["messages"])
            if messages and SPLITTER in messages[-1]["content"]:
                # Build a fresh dict instead of editing the popped one: it is
                # the same object as in item["messages"], and mutating it would
                # corrupt the caller's data and every later completion of this
                # prompt when sampling_params.n > 1.
                partial = messages.pop()
                prefix = partial["content"].partition(SPLITTER)[0]
                message = {**partial, "content": prefix + generated_text}
            else:
                message = {"role": "assistant", "content": generated_text}
            messages.append(message)

            results.append({**item, "messages": messages})

    return results


def score_worker(cuda_device, prompts, model_path, positive_token, negative_token=None):
    """Score chat prompts by the probability of given tokens at the first output position.

    Sets ``CUDA_VISIBLE_DEVICES`` for the whole process before building the
    engine, and shards the model across all listed GPUs with tensor
    parallelism. Each prompt is decoded greedily for up to 5 tokens with the
    top-20 logprobs requested. Only the first generated position's logprobs
    are used for scoring.

    Args:
        cuda_device: Sequence of GPU ids (str or int) to make visible; its
            length is used as the tensor-parallel size.
        prompts: Sequence of dicts, each with a ``"messages"`` list of chat
            messages. All keys are copied through to the results.
        model_path: Model name or path passed to ``vllm.LLM``.
        positive_token: Token id whose first-position probability is reported
            as ``"positive_score"``.
        negative_token: Optional token id reported as ``"negative_score"``;
            pass ``None`` to skip it. Token id 0 is a valid value.

    Returns:
        List of shallow copies of the input items, in prompt order, each with
        added keys ``"positive_score"`` (float), ``"negative_score"`` (float,
        or ``None`` when ``negative_token`` is ``None``) and
        ``"meta_score_response"`` (the greedily decoded text).
        A token that is not among the returned top-20 candidates scores 0.0.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, cuda_device))

    llm = LLM(
        model=model_path,
        seed=42,
        max_model_len=8 * 1024,
        swap_space=16,
        tensor_parallel_size=len(cuda_device),
    )

    tokenizer = llm.get_tokenizer()
    text_prompts = [
        tokenizer.apply_chat_template(
            item["messages"], tokenize=False, add_generation_prompt=True
        )
        for item in prompts
    ]

    # Greedy decoding; a few extra tokens are generated only so the decoded
    # text can be kept for inspection in "meta_score_response".
    sampling_params = SamplingParams(n=1, temperature=0, max_tokens=5, logprobs=20)
    outputs = llm.generate(text_prompts, sampling_params=sampling_params, use_tqdm=True)

    def token_prob(step_logprobs, token):
        # vLLM returns only the top-k candidates (plus the sampled token) per
        # step, so a token outside them is treated as probability 0.
        entry = step_logprobs.get(token)
        return float(np.exp(entry.logprob)) if entry is not None else 0.0

    results = []
    for item, output in zip(prompts, outputs):
        assert len(output.outputs) == 1, "The scorer must provide a single score."
        response = output.outputs[0]
        # response.logprobs: list[dict[int, Logprob]], one dict per generated
        # token (see vllm/sequence.py); index 0 is the first output position.
        first_step = response.logprobs[0]

        result = item.copy()
        result["positive_score"] = token_prob(first_step, positive_token)
        # Compare against None explicitly: token id 0 is valid but falsy.
        if negative_token is not None:
            result["negative_score"] = token_prob(first_step, negative_token)
        else:
            result["negative_score"] = None
        result["meta_score_response"] = response.text
        results.append(result)

    return results
