from vllm import LLM, SamplingParams

import os
import multiprocessing
from itertools import chain
from functools import partial

from utils import load_jsonl, save_jsonl


def worker(cuda_device, prompts, model_path, sampling_params):
    """Generate chat completions for one shard of prompts in this process.

    Args:
        cuda_device: String written to ``CUDA_VISIBLE_DEVICES`` for this
            process before the model is loaded.
        prompts: List of dict items, each with a ``"messages"`` list that is
            rendered through the tokenizer's chat template.
        model_path: Model identifier or path passed to ``LLM(model=...)``.
        sampling_params: Mapping with the keys ``'n'``, ``'temperature'`` and
            ``'max_tokens'``; an optional ``'top_p'`` overrides the default
            of 0.95.

    Returns:
        A flat list with one shallow copy of the source item per generated
        response. In each copy ``"messages"`` is a new list holding the
        original messages followed by an assistant message with the
        generated text. The input items are not modified.
    """
    # Nothing to generate: skip the (expensive) model load entirely.
    if not prompts:
        return []

    # Restrict this process to its assigned device; this has to happen
    # before the LLM below is constructed.
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device

    llm = LLM(model=model_path, seed=42, max_model_len=8 * 1024, swap_space=16)

    tokenizer = llm.get_tokenizer()
    candidate_stop_ids = (
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    )
    # A tokenizer may report no id for either token; drop those entries
    # rather than handing None to SamplingParams.
    stop_tokens = [token_id for token_id in candidate_stop_ids if token_id is not None]
    print(f"SUCCESS: load llm {model_path} on cuda {cuda_device}")

    vllm_sampling_params = SamplingParams(
        n=sampling_params['n'],
        temperature=sampling_params['temperature'],
        top_p=sampling_params.get('top_p', 0.95),
        max_tokens=sampling_params['max_tokens'],
        stop_token_ids=stop_tokens,
    )

    text_prompts = [
        tokenizer.apply_chat_template(
            item["messages"], tokenize=False, add_generation_prompt=True
        )
        for item in prompts
    ]
    outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=True)

    result = []
    for item, output in zip(prompts, outputs):
        # One output record per sampled response (n of them per prompt).
        for response in output.outputs:
            newitem = item.copy()
            newitem["messages"] = [
                *item["messages"],
                {"role": "assistant", "content": response.text},
            ]
            result.append(newitem)

    return result


def vllm_inference(model_path, prompt_path, output_path, sampling_params):
    """Run generation over a JSONL prompt file, one worker process per GPU.

    Prompts are loaded with ``load_jsonl``, dealt round-robin across the
    devices listed in ``CUDA_VISIBLE_DEVICES``, processed by ``worker`` in a
    multiprocessing pool, and the combined results are written with
    ``save_jsonl``.

    Results are concatenated shard by shard, so the output order follows the
    round-robin sharding rather than the order of the input file.

    Args:
        model_path: Model identifier or path forwarded to each worker.
        prompt_path: Path of the JSONL file with the prompt items.
        output_path: Path the generated items are saved to.
        sampling_params: Sampling settings forwarded unchanged to each worker.

    Raises:
        KeyError: If ``CUDA_VISIBLE_DEVICES`` is not set.
        ValueError: If ``CUDA_VISIBLE_DEVICES`` is set but names no device.
    """
    prompts = list(load_jsonl(prompt_path))

    # Respect the GPU allocation given to this job (e.g. by Slurm). Tolerate
    # whitespace and stray commas such as "0, 1" or "0,1,".
    cuda_devices = [
        device.strip()
        for device in os.environ["CUDA_VISIBLE_DEVICES"].split(',')
        if device.strip()
    ]
    if not cuda_devices:
        raise ValueError("CUDA_VISIBLE_DEVICES is set but lists no devices")

    # Never start more workers than there are prompts: a worker with an
    # empty shard would still load the full model for nothing.
    worker_num = min(len(cuda_devices), len(prompts))

    # Round-robin split: shard i takes prompts i, i + worker_num, ...
    sub_prompts = [prompts[i::worker_num] for i in range(worker_num)]

    results = []
    if worker_num:
        worker_llm = partial(worker, model_path=model_path, sampling_params=sampling_params)
        with multiprocessing.Pool(worker_num) as pool:
            nested_results = pool.starmap(worker_llm, zip(cuda_devices, sub_prompts))
        results = list(chain.from_iterable(nested_results))

    print(f"size of dataset: {len(results)}")
    save_jsonl(results, output_path)
