from vllm import LLM, SamplingParams

import os
import multiprocessing
from itertools import chain
from functools import partial

from codecritic.data.utils import SPLITTER
import numpy as np

def generate_worker(cuda_device, prompts, model_path, sampling_params):
    """Sample chat completions for ``prompts`` with vLLM on a single GPU.

    Meant to run in its own process: ``CUDA_VISIBLE_DEVICES`` is overwritten before
    the engine is constructed.

    Args:
        cuda_device: Device id string written to ``CUDA_VISIBLE_DEVICES``.
        prompts: Items carrying a ``"messages"`` chat list. Each rendered prompt is cut
            at ``SPLITTER`` (if present) so the model continues from that point.
        model_path: Model name/path passed to ``vllm.LLM``.
        sampling_params: Dict with ``"n"``, ``"temperature"`` and ``"max_tokens"``.

    Returns:
        A list with one shallow copy of each item per sampled response (prompts visited
        in order). Each copy gets a fresh ``"messages"`` list. If the last message
        contains ``SPLITTER``, its pre-``SPLITTER`` content is completed with the
        generated text; otherwise the text is appended as a new assistant message. The
        input items are never modified.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device

    llm = LLM(model=model_path, seed=42, max_model_len=8 * 1024, swap_space=16)

    tokenizer = llm.get_tokenizer()
    # "<|eot_id|>" is a Llama-3-style end-of-turn marker. Tokenizers without it (or
    # without an EOS id) may report None, which is not a usable stop id, so drop it.
    stop_tokens = [
        token_id
        for token_id in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>"))
        if token_id is not None
    ]
    print(f"SUCCESS: load llm {model_path} on cuda {cuda_device}")

    vllm_sampling_params = SamplingParams(
        n=sampling_params['n'],
        temperature=sampling_params['temperature'],
        top_p=0.95,
        max_tokens=sampling_params['max_tokens'],
        stop_token_ids=stop_tokens
    )

    text_prompts = [
        tokenizer.apply_chat_template(item["messages"], tokenize=False, add_generation_prompt=True)
        for item in prompts
    ]
    # Truncate at SPLITTER so generation continues a partially written message;
    # partition()[0] is the whole string when SPLITTER is absent.
    text_prompts = [prompt.partition(SPLITTER)[0] for prompt in text_prompts]
    outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=True)

    result = []
    for item, output in zip(prompts, outputs):
        for response in output.outputs:
            generated_text = response.text

            messages = item["messages"].copy()
            last_message = messages[-1]
            if SPLITTER in last_message["content"]:
                raw_content = last_message["content"].partition(SPLITTER)[0]
                # Replace the dict rather than editing it in place: the list copy is
                # shallow, so mutating it would corrupt item["messages"] and hide
                # SPLITTER from the remaining samples of this prompt.
                messages[-1] = {**last_message, "content": raw_content + generated_text}
            else:
                messages.append({"role": "assistant", "content": generated_text})

            newitem = item.copy()
            newitem["messages"] = messages
            result.append(newitem)

    return result


def score_worker(cuda_device, prompts, model_path, score_token):
    """Score each prompt by the probability of ``score_token`` at the first generated step.

    Meant to run in its own process: ``CUDA_VISIBLE_DEVICES`` is overwritten before
    the engine is constructed.

    Args:
        cuda_device: Device id string written to ``CUDA_VISIBLE_DEVICES``.
        prompts: Items carrying a ``"messages"`` chat list.
        model_path: Model name/path passed to ``vllm.LLM``.
        score_token: Token id looked up in the first generated position's logprobs.

    Returns:
        A list with one shallow copy of each item, adding:
        ``"score"``: ``exp(logprob)`` of ``score_token``, or ``0`` when it is not among
        the returned candidates.
        ``"critic_text"``: the greedy completion text.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device

    llm = LLM(model=model_path, seed=42, max_model_len=8 * 1024, swap_space=16)

    tokenizer = llm.get_tokenizer()
    print(f"SUCCESS: load llm {model_path} on cuda {cuda_device}")

    # Greedy, short decode. logprobs=20 asks vLLM for the 20 most likely candidates
    # (plus the sampled token) at every generated position.
    vllm_sampling_params = SamplingParams(
        n=1,
        temperature=0,
        max_tokens=5,
        logprobs=20
    )

    text_prompts = [
        tokenizer.apply_chat_template(item["messages"], tokenize=False, add_generation_prompt=True)
        for item in prompts
    ]
    outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=False)

    result = []
    for item, output in zip(prompts, outputs):
        for response in output.outputs:
            # response.logprobs: list[dict[int, Logprob]] https://github.com/vllm-project/vllm/blob/main/vllm/sequence.py
            sample_logprobs = response.logprobs
            # Guard against an empty/missing logprob list instead of raising.
            first_step = sample_logprobs[0] if sample_logprobs else {}
            logprob = first_step.get(score_token)
            # NOTE(review): an earlier comment claimed the model emits
            # ['\n', 'Yes'/'No', '\n', <EOT>], which would put the verdict at position 1,
            # not 0. Confirm against the chat template in use.
            newitem = item.copy()
            newitem["score"] = np.exp(logprob.logprob) if logprob is not None else 0
            newitem["critic_text"] = response.text
            result.append(newitem)
    return result


def _shard_round_robin(items, num_shards):
    """Deal ``items`` into ``num_shards`` lists; item ``i`` goes to shard ``i % num_shards``."""
    return [items[shard::num_shards] for shard in range(num_shards)]


def _unshard_round_robin(shard_sizes, shard_results):
    """Reassemble round-robin worker output into the original prompt order.

    ``shard_results[s]`` holds the flattened results for the ``shard_sizes[s]``
    prompts of shard ``s``, in that shard's prompt order, with the same number of
    results per prompt (``n`` samples for generation, one for scoring).
    """
    per_prompt_groups = []
    for size, results in zip(shard_sizes, shard_results):
        if size == 0:
            per_prompt_groups.append([])
            continue
        per_prompt, leftover = divmod(len(results), size)
        if leftover:
            raise ValueError(
                f"worker returned {len(results)} results for {size} prompts; "
                "expected the same number of results per prompt"
            )
        per_prompt_groups.append(
            [results[j * per_prompt:(j + 1) * per_prompt] for j in range(size)]
        )

    ordered = []
    for j in range(max(shard_sizes, default=0)):
        for groups in per_prompt_groups:
            if j < len(groups):
                ordered.extend(groups[j])
    return ordered


def _run_sharded(worker, prompts):
    """Run ``worker(cuda_device, shard)`` once per visible GPU; return results in prompt order."""
    # Respect the slurm's gpu allocation
    cuda_devices = [
        device.strip()
        for device in os.environ["CUDA_VISIBLE_DEVICES"].split(',')
        if device.strip()
    ]
    if not cuda_devices:
        raise RuntimeError("CUDA_VISIBLE_DEVICES is empty; no GPU available for vLLM")

    prompts = list(prompts)
    shards = _shard_round_robin(prompts, len(cuda_devices))
    # Round-robin leaves any empty shards at the end; skip them instead of loading a
    # full model on a GPU that has nothing to do.
    jobs = [(device, shard) for device, shard in zip(cuda_devices, shards) if shard]
    if not jobs:
        return []

    # maxtasksperchild=1: a worker pins CUDA_VISIBLE_DEVICES and loads a model, which
    # cannot be reliably undone in-process, so a process must never take a second shard.
    with multiprocessing.Pool(len(jobs), maxtasksperchild=1) as pool:
        shard_results = pool.starmap(worker, jobs)

    return _unshard_round_robin([len(shard) for _, shard in jobs], shard_results)


def vllm_chatcomplete(model_path, prompts, sampling_params):
    """Generate completions for ``prompts`` across every GPU in ``CUDA_VISIBLE_DEVICES``.

    Prompts are dealt round-robin to one ``generate_worker`` process per GPU. The
    flattened results (``sampling_params['n']`` per prompt) are returned in the
    original prompt order.
    """
    worker_llm = partial(generate_worker, model_path=model_path, sampling_params=sampling_params)
    return _run_sharded(worker_llm, prompts)


def vllm_score(model_path, prompts, score_token):
    """Score ``prompts`` with ``score_token`` across every GPU in ``CUDA_VISIBLE_DEVICES``.

    Prompts are dealt round-robin to one ``score_worker`` process per GPU. One scored
    item per prompt is returned, in the original prompt order.
    """
    worker_llm = partial(score_worker, model_path=model_path, score_token=score_token)
    return _run_sharded(worker_llm, prompts)
