Commit c08b0931 by nanziyuan

Add support for multi-GPU inference for vLLM

parent b731aa9a
@@ -27,7 +27,7 @@ def append_prompt(item, content):
     return item
-def run_sft_model(model_path, test_path, apps_path, reason_prompt=None):
+def run_sft_model(model_path, test_path, apps_path, reason_prompt=None, model_gpu=1):
     home_path = Path(model_path).parent
     result_dir = home_path / "eval"
     result_dir.mkdir(exist_ok=True)
@@ -40,13 +40,13 @@ def run_sft_model(model_path, test_path, apps_path, reason_prompt=None):
     if reason_prompt:
         test_dataset = [append_prompt(x, COV_PROMPT) for x in test_dataset]
     sampling_params = dict(n=1, temperature=0.0, max_tokens=2048)
-    test_dataset = vllm_chatcomplete(model_path, test_dataset, sampling_params)
+    test_dataset = vllm_chatcomplete(model_path, test_dataset, sampling_params, model_gpu)
     # score
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     score_token = get_score_token_id(tokenizer)
     test_dataset = [append_prompt(x, JUDGE_PROMPT) for x in test_dataset]
-    results = vllm_score(model_path, test_dataset, score_token)
+    results = vllm_score(model_path, test_dataset, score_token, model_gpu)
     score_path = result_dir / "scores.jsonl"
     save_jsonl(results, score_path)
@@ -65,8 +65,9 @@ if __name__ == "__main__":
     parser.add_argument("--test", type=str)
     parser.add_argument("--apps", type=str)
     parser.add_argument("--reason", choices=["cov"])
+    parser.add_argument("--gpu", type=int, default=1, help="number of GPUs required per model replica (tensor-parallel size)")
     args = parser.parse_args()
     reason_prompts = {"cov": COV_PROMPT}
     reason_prompt = reason_prompts.get(args.reason, None)
-    run_sft_model(args.model, args.test, args.apps, reason_prompt)
+    run_sft_model(args.model, args.test, args.apps, reason_prompt, args.gpu)
from vllm import LLM, SamplingParams
import os
import multiprocessing
from itertools import chain, combinations
from functools import partial
import subprocess
import numpy as np
from codecritic.data.utils import SPLITTER

def get_distance(connection_type):
    """Map an `nvidia-smi topo -m` connection type to a numeric distance (lower is closer)."""
    if connection_type.startswith("NV"):   # NVLink (NV1, NV2, ...)
        return 1
    elif connection_type == "X":           # the GPU itself
        return 0
    elif connection_type == "PIX":         # at most a single PCIe bridge
        return 2
    elif connection_type == "PXB":         # multiple PCIe bridges, no host bridge
        return 3
    elif connection_type == "PHB":         # a PCIe host bridge
        return 4
    elif connection_type == "NODE":        # interconnect between host bridges within a NUMA node
        return 5
    elif connection_type == "SYS":         # SMP interconnect between NUMA nodes
        return 6
    else:
        raise RuntimeError(f"Unknown connection type: {connection_type}")

def get_gpu_topology():
    """
    Get the GPU topology using `nvidia-smi topo -m` and return a distance matrix.
    """
    try:
        result = subprocess.run(['nvidia-smi', 'topo', '-m'], stdout=subprocess.PIPE, text=True)
        topo_output = result.stdout
    except FileNotFoundError:
        raise RuntimeError("nvidia-smi not found. Make sure NVIDIA drivers are installed and nvidia-smi is in PATH.")

    # Parse the topology matrix
    matrix_str = topo_output.split('\n\n')[0]
    lines = matrix_str.splitlines()

    header = lines[0].split()
    gpu_num = sum([x.startswith("GPU") for x in header])
    for idx in range(gpu_num):
        assert header[idx].endswith(f"GPU{idx}"), header[idx]

    matrix = []
    for idx, line in enumerate(lines[1:1 + gpu_num]):
        assert line.startswith(f"GPU{idx}")
        matrix.append(line.split()[1:1 + gpu_num])

    # Convert to a numeric distance matrix (lower is better)
    distance_matrix = [[get_distance(e) for e in r] for r in matrix]
    return np.array(distance_matrix)
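
# Illustrative example (hypothetical machine, not real nvidia-smi output): on a 4-GPU
# box where GPU0/GPU1 and GPU2/GPU3 are NVLink pairs and the cross pairs only share a
# PCIe host bridge, get_gpu_topology() would return something like
#   [[0, 1, 4, 4],
#    [1, 0, 4, 4],
#    [4, 4, 0, 1],
#    [4, 4, 1, 0]]
# i.e. get_distance("X") = 0 on the diagonal, get_distance("NV...") = 1 inside a pair,
# and get_distance("PHB") = 4 across pairs.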

def comb_group(n, k):
    """Yield every partition of range(n) into groups of size k (order-insensitive)."""
    groups = []

    def helper(lst):
        if len(lst) == 0:
            yield groups.copy()
        else:
            head, *rest = lst
            for group in combinations(rest, k - 1):
                groups.append((head,) + group)
                yield from helper([x for x in rest if x not in group])
                groups.pop()

    yield from helper(list(range(n)))
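
# For example, comb_group(4, 2) yields the three ways to pair up four GPUs:
#   [(0, 1), (2, 3)], [(0, 2), (1, 3)], [(0, 3), (1, 2)]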

def allocate_gpu(model_required_gpus):
    cuda_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(',')
    gpu_num = len(cuda_devices)
    assert gpu_num % model_required_gpus == 0, \
        "the number of visible GPUs must be a multiple of model_required_gpus (the tensor-parallel size)"
    gpu_ids = [int(x) for x in cuda_devices]

    # Restrict the topology matrix to the visible GPUs.
    m = get_gpu_topology()[gpu_ids][:, gpu_ids]

    # Pre-compute the communication cost of every candidate group.
    cost_memory = dict()
    for group in combinations(range(gpu_num), model_required_gpus):
        indices = list(group)
        cost_memory[group] = np.sum(m[indices][:, indices])

    # Pick the partition of the GPUs into groups with the lowest total cost.
    min_cost, min_groups = float('inf'), []
    for groups in comb_group(len(m), model_required_gpus):
        cost = sum(cost_memory[group] for group in groups)
        if cost < min_cost:
            min_cost, min_groups = cost, groups

    return [[str(gpu_ids[x]) for x in group] for group in min_groups]
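
# Illustrative example (hypothetical): with CUDA_VISIBLE_DEVICES="0,1,2,3" and
# model_required_gpus=2, allocate_gpu(2) returns two groups such as
# [["0", "1"], ["2", "3"]], choosing the pairing with the smallest total topology
# distance so that NVLink-connected GPUs end up in the same tensor-parallel replica.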

def split_data(data, num):
    """
    The average chat length in the dataset is not uniformly distributed:
    sometimes the early chats are shorter and the later ones are longer.
    To give every GPU group roughly the same amount of work, the items are
    interleaved round-robin across the groups instead of being split into
    contiguous chunks.
    """
    groups = [[] for _ in range(num)]
    for i, item in enumerate(data):
        groups[i % num].append(item)
    return groups
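
# For example, split_data(["a", "b", "c", "d", "e"], 2) returns
# [["a", "c", "e"], ["b", "d"]].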

def vmap(worker, data, model_required_gpus):
    cuda_devices = allocate_gpu(model_required_gpus)
    group_num = len(cuda_devices)
    data_groups = split_data(data, group_num)
    args = list(zip(cuda_devices, data_groups))
    with multiprocessing.Pool(group_num) as pool:
        nested_results = pool.starmap(worker, args)
    return list(chain(*nested_results))

def generate_worker(cuda_device, prompts, model_path, sampling_params):
    # Pin this worker process to its GPU group before vLLM initializes CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_device)

    llm = LLM(model=model_path,
              seed=42,
              max_model_len=8 * 1024,
              swap_space=16,
              tensor_parallel_size=len(cuda_device))
    tokenizer = llm.get_tokenizer()
    stop_tokens = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    print(f"SUCCESS: load llm {model_path} on cuda {cuda_device}")

    vllm_sampling_params = SamplingParams(
        n=sampling_params['n'],
        temperature=sampling_params['temperature'],
        top_p=0.95,
        max_tokens=sampling_params['max_tokens'],
        stop_token_ids=stop_tokens
    )

    def messages_to_text(messages):
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        if SPLITTER in text:
            text = text.split(SPLITTER)[0]
        return text

    text_prompts = [messages_to_text(item["messages"]) for item in prompts]
    outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=True)

    results = []
    for item, output in zip(prompts, outputs):
        for response in output.outputs:
            generated_text = response.text
            # Copy the message dicts so completing a SPLITTER prefix does not mutate
            # the original item (which would corrupt subsequent samples when n > 1).
            messages = [dict(m) for m in item["messages"]]
            if SPLITTER in messages[-1]["content"]:
                message = messages.pop()
                raw_content = message["content"].split(SPLITTER)[0]
                message["content"] = raw_content + generated_text
            else:
                message = {"role": "assistant", "content": generated_text}
            messages.append(message)
            results.append({**item, "messages": messages})
    return results

def score_worker(cuda_device, prompts, model_path, score_token):
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_device)

    llm = LLM(model=model_path,
              seed=42,
              max_model_len=8 * 1024,
              swap_space=16,
              tensor_parallel_size=len(cuda_device))
    tokenizer = llm.get_tokenizer()
    print(f"SUCCESS: load llm {model_path} on cuda {cuda_device}")

    vllm_sampling_params = SamplingParams(
        n=1,
        temperature=0,
        max_tokens=5,
        logprobs=20
    )

    text_prompts = [tokenizer.apply_chat_template(item["messages"], tokenize=False, add_generation_prompt=True) for item in prompts]
    outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=False)

    results = []
    for item, output in zip(prompts, outputs):
        for response in output.outputs:
            # response.logprobs: list[dict[int, Logprob]] https://github.com/vllm-project/vllm/blob/main/vllm/sequence.py
            sample_logprobs = response.logprobs
            logprob = sample_logprobs[0].get(score_token)
            score = np.exp(logprob.logprob) if logprob else 0
            text = response.text
            results.append({**item, "score": score, "critic_text": text})
    return results

def vllm_chatcomplete(model_path, prompts, sampling_params, model_required_gpus=1):
    worker = partial(generate_worker, model_path=model_path, sampling_params=sampling_params)
    return vmap(worker, prompts, model_required_gpus)


def vllm_score(model_path, prompts, score_token, model_required_gpus=1):
    worker = partial(score_worker, model_path=model_path, score_token=score_token)
    return vmap(worker, prompts, model_required_gpus)
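
# Minimal usage sketch. The checkpoint path, the "problem_id" key, and the prompt below
# are hypothetical placeholders; CUDA_VISIBLE_DEVICES must be set (e.g. "0,1,2,3")
# before running, mirroring how run_sft_model calls these helpers.
if __name__ == "__main__":
    example_prompts = [
        {"problem_id": 0, "messages": [{"role": "user", "content": "Write a Python function that reverses a string."}]},
    ]
    example_sampling = dict(n=1, temperature=0.0, max_tokens=256)
    completions = vllm_chatcomplete(
        "/path/to/model",        # hypothetical checkpoint path
        example_prompts,
        example_sampling,
        model_required_gpus=2,   # tensor-parallel size of each replica
    )
    print(completions[0]["messages"][-1]["content"])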