Commit 1a4460f9 by nanziyuan

refactor: replace stack with generator for 3x performance improvement

parent 030769db
...@@ -5,7 +5,6 @@ import multiprocessing ...@@ -5,7 +5,6 @@ import multiprocessing
from itertools import chain, combinations from itertools import chain, combinations
from functools import partial from functools import partial
import subprocess import subprocess
import unicodedata
import numpy as np import numpy as np
...@@ -62,39 +61,19 @@ def get_gpu_topology(): ...@@ -62,39 +61,19 @@ def get_gpu_topology():
def comb_group(n, k): def comb_group(n, k):
comb_stack, lst_stack = [], [] groups = []
def comb_pivot(lst): def helper(lst):
""" if len(lst) == 0:
lst should be sorted yield groups.copy()
"""
pivot = lst[0]
for other in combinations(lst[1:], k-1):
yield (pivot,) + other
def fill_stack():
lst = [x for x in range(n) if x not in list(chain(*lst_stack))]
stack_len = len(comb_stack)
for _ in range((n // k) - stack_len):
new_comb = comb_pivot(lst)
new_group = next(new_comb)
comb_stack.append(new_comb)
lst_stack.append(new_group)
lst = [x for x in lst if x not in new_group]
fill_stack()
yield lst_stack.copy()
while len(comb_stack) > 0:
new_group = next(comb_stack[-1], None)
if new_group:
lst_stack[-1] = new_group
fill_stack()
yield lst_stack.copy()
else: else:
comb_stack.pop() head, *rest = lst
lst_stack.pop() for group in combinations(rest, k-1):
groups.append((head,) + group)
yield from helper([x for x in rest if x not in group])
groups.pop()
yield from helper(list(range(n)))
def get_optimal_groups(matrix, index, k): def get_optimal_groups(matrix, index, k):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment