Commit b731aa9a by nanziyuan

refactor: extract vmap from vllm_complete and vllm_score

parent 1a4460f9
@@ -76,48 +76,54 @@ def comb_group(n, k):
     yield from helper(list(range(n)))
 
 
-def get_optimal_groups(matrix, index, k):
-    m = matrix[index][:, index]
+def allocate_gpu(model_required_gpus):
+    cuda_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(',')
+    gpu_num = len(cuda_devices)
+    assert gpu_num % model_required_gpus == 0, "gpus must be n * tensor_parallel"
+    gpu_ids = [int(x) for x in cuda_devices]
+    m = get_gpu_topology()[gpu_ids][:, gpu_ids]
     cost_memory = dict()
-    for group in combinations(range(len(m)), k):
+    for group in combinations(range(gpu_num), model_required_gpus):
         indices = list(group)
         cost_memory[group] = np.sum(m[indices][:, indices])
-    min_cost = float('inf')
-    min_groups = []
-    for groups in comb_group(len(m), k):
+    min_cost, min_groups = float('inf'), []
+    for groups in comb_group(len(m), model_required_gpus):
         cost = sum(cost_memory[group] for group in groups)
         if cost < min_cost:
-            min_cost = cost
-            min_groups = groups
-    return [[str(index[x]) for x in group] for group in min_groups]
-
-
-def allocate_gpu(model_required_gpus):
-    cuda_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(',')
-    print(cuda_devices)
-    assert len(cuda_devices) % model_required_gpus == 0, "gpus must be n * tensor_parallel"
-    if model_required_gpus > 1:
-        matrix = get_gpu_topology()
-        index = [int(x) for x in cuda_devices]
-        cuda_devices = get_optimal_groups(matrix, index, model_required_gpus)
-    else:
-        cuda_devices = [[x] for x in cuda_devices]
-    return cuda_devices
+            min_cost, min_groups = cost, groups
+    return [[str(gpu_ids[x]) for x in group] for group in min_groups]
 
 
-def data_split(data, num):
+def split_data(data, num):
+    """
+    The average length of chat in the dataset is not uniformly distributed.
+    Sometimes, the initial chats are shorter, while the later ones are longer.
+    To ensure that all GPUs have nearly the same execution time,
+    we intentionally shuffle the dataset.
+    """
     groups = [[] for _ in range(num)]
     for i, item in enumerate(data):
         groups[i % num].append(item)
     return groups
 
 
+def vmap(worker, data, model_required_gpus):
+    cuda_devices = allocate_gpu(model_required_gpus)
+    group_num = len(cuda_devices)
+    data_groups = split_data(data, group_num)
+    args = list(zip(cuda_devices, data_groups))
+    with multiprocessing.Pool(group_num) as pool:
+        nested_results = pool.starmap(worker, args)
+    return list(chain(*nested_results))
+
+
 def generate_worker(cuda_device, prompts, model_path, sampling_params):
     os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_device)
@@ -208,27 +214,11 @@ def score_worker(cuda_device, prompts, model_path, score_token):
 
 
 def vllm_chatcomplete(model_path, prompts, sampling_params, model_required_gpus=1):
-    cuda_devices = allocate_gpu(model_required_gpus)
-    group_num = len(cuda_devices)
-    data_groups = data_split(prompts, group_num)
-    args = list(zip(cuda_devices, data_groups))
-    worker_llm = partial(generate_worker, model_path=model_path, sampling_params=sampling_params)
-    with multiprocessing.Pool(group_num) as pool:
-        nested_results = pool.starmap(worker_llm, args)
-    return list(chain(*nested_results))
+    worker = partial(generate_worker, model_path=model_path, sampling_params=sampling_params)
+    return vmap(worker, prompts, model_required_gpus)
 
 
 def vllm_score(model_path, prompts, score_token, model_required_gpus=1):
-    cuda_devices = allocate_gpu(model_required_gpus)
-    group_num = len(cuda_devices)
-    data_groups = data_split(prompts, group_num)
-    args = list(zip(cuda_devices, data_groups))
-    worker_llm = partial(score_worker, model_path=model_path, score_token=score_token)
-    with multiprocessing.Pool(group_num) as pool:
-        nested_results = pool.starmap(worker_llm, args)
-    return list(chain(*nested_results))
+    worker = partial(score_worker, model_path=model_path, score_token=score_token)
+    return vmap(worker, prompts, model_required_gpus)
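
A minimal usage sketch of the refactored entry points (not part of the commit). It assumes the module is importable as llm_parallel (a hypothetical name), that sampling_params is a vLLM SamplingParams object, and that the model path, prompt, and score_token value are placeholders:

    from vllm import SamplingParams                          # real vLLM class
    from llm_parallel import vllm_chatcomplete, vllm_score   # hypothetical module name

    prompts = [[{"role": "user", "content": "Write a haiku about GPUs."}]]
    params = SamplingParams(temperature=0.8, max_tokens=128)

    # vllm_chatcomplete now only builds a worker with partial() and delegates GPU
    # allocation, data splitting, and the multiprocessing.Pool fan-out to vmap.
    completions = vllm_chatcomplete("/path/to/model", prompts, params, model_required_gpus=2)

    # vllm_score follows the same pattern with score_worker; the token here is illustrative.
    scores = vllm_score("/path/to/reward-model", prompts, score_token="<score>", model_required_gpus=2)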