Commit da039794 by Cody Hao Yu Committed by Wuwei Lin

[AutoTVM] Enhance tuning space of split (#3949)

* Refine policies for define_split

- Rename policy "all" to "factors"
- Add policy "verbose" and "power2"

* Refine search space

* add doc
parent e35e1cc2
...@@ -165,6 +165,20 @@ def get_factors(n): ...@@ -165,6 +165,20 @@ def get_factors(n):
ret.sort() ret.sort()
return ret return ret
def get_pow2s(n):
    """Return all power-of-two numbers that are less than or equal to the integer.

    Parameters
    ----------
    n: int
        integer for reference, expected to be positive

    Returns
    -------
    pow2s: list
        List of all power-of-two numbers not exceeding n, in ascending
        order starting from 1

    Raises
    ------
    ValueError
        If n is zero or negative (the logarithm-based version raised the
        same exception type for such input).
    """
    if n <= 0:
        raise ValueError("get_pow2s expects a positive integer, got %s" % n)
    # Use exact integer arithmetic instead of floor(log2(n)): converting a
    # large integer to float can round it up to the next power of two
    # (e.g. 2**60 - 1 -> 2**60), which made the old result contain a value
    # greater than n. int(n).bit_length() - 1 is exactly floor(log2(n)).
    return [1 << x for x in range(int(n).bit_length())]
class SplitSpace(TransformSpace): class SplitSpace(TransformSpace):
"""Split an axis for several times""" """Split an axis for several times"""
...@@ -175,43 +189,49 @@ class SplitSpace(TransformSpace): ...@@ -175,43 +189,49 @@ class SplitSpace(TransformSpace):
self.policy = policy self.policy = policy
self.entities = [] self.entities = []
if policy == 'all': max_factor = kwargs.get("max_factor", 1 << 31)
num_outputs = kwargs["num_outputs"] fil = kwargs.get("filter", lambda x: True)
max_factor = kwargs.get("max_factor", 1 << 31) self.product = axis.length
fil = kwargs.get("filter", lambda x: True) self.num_output = kwargs.get("num_outputs", 0)
assert self.num_output > 0
length = axis.length
factors = get_factors(length) if policy == 'candidate':
factors = [x for x in factors if x <= max_factor]
# copy factors for every level
self.product = length
self.num_outputs = num_outputs
self.factors = [factors] * (num_outputs-1)
self._generate_space(0, [None] * (num_outputs - 1))
self.entities = list(filter(fil, self.entities))
self.num_output = num_outputs
elif policy == 'candidate':
self.product = axis.length
self.num_outputs = kwargs["num_outputs"]
for size in kwargs["candidate"]: for size in kwargs["candidate"]:
assert len(size) == self.num_outputs assert len(size) == self.num_output
# assert np.prod(size) == self.product
self.entities.append(SplitEntity(size)) self.entities.append(SplitEntity(size))
self.num_output = self.num_outputs
else: else:
raise RuntimeError("Invalid policy: " + policy) if policy == 'verbose':
# Include factors and power-of-twos. May generate tails.
divisibles = get_factors(self.product)
pow2s = get_pow2s(self.product)
factors = [x for x in list(set(divisibles) | set(pow2s)) if x <= max_factor]
elif policy == 'factors':
# Include divisible factors. Guarantee no tails.
factors = [x for x in get_factors(self.product) if x <= max_factor]
elif policy == 'power2':
# Include less, equal, and round-up power-of-two numbers. May generate tails.
factors = [x for x in get_pow2s(self.product) if x <= max_factor]
else:
raise RuntimeError("Invalid policy: %s" % policy)
def _generate_space(self, now, tmp_stack): # Enforce the product of all split factors equals to the axis length
no_tail = kwargs.get("no_tail", policy == 'factors')
# Generate split entity by enumerating candidate factors.
self.factors = factors
self._generate_space(0, [None] * (self.num_output - 1), enforce_no_tail=no_tail)
self.entities = list(filter(fil, self.entities))
def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
"""Generate space by DFS""" """Generate space by DFS"""
if now == self.num_outputs - 1: if now == self.num_output - 1:
size = np.prod(tmp_stack, dtype=np.int64) if not enforce_no_tail or self.product % np.prod(tmp_stack, dtype=np.int64) == 0:
if self.product % size == 0: self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
first = int(self.product // int(size))
self.entities.append(SplitEntity([first] + tmp_stack[::-1]))
else: else:
for factor in self.factors[now]: for factor in self.factors:
tmp_stack[now] = factor tmp_stack[now] = factor
self._generate_space(now + 1, tmp_stack) self._generate_space(now + 1, tmp_stack, enforce_no_tail)
@staticmethod @staticmethod
def get_num_output(axes, policy, **kwargs): def get_num_output(axes, policy, **kwargs):
...@@ -219,7 +239,7 @@ class SplitSpace(TransformSpace): ...@@ -219,7 +239,7 @@ class SplitSpace(TransformSpace):
def __repr__(self): def __repr__(self):
return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" % return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" %
(self.policy, self.product, self.num_outputs, len(self))) (self.policy, self.product, self.num_output, len(self)))
class SplitEntity(object): class SplitEntity(object):
...@@ -609,7 +629,7 @@ class ConfigSpace(object): ...@@ -609,7 +629,7 @@ class ConfigSpace(object):
reduce_axis = axis reduce_axis = axis
def define_split(self, name, axis, policy='all', **kwargs): def define_split(self, name, axis, policy='factors', **kwargs):
"""Define a new tunable knob which splits an axis into a list of axes """Define a new tunable knob which splits an axis into a list of axes
Parameters Parameters
...@@ -620,11 +640,22 @@ class ConfigSpace(object): ...@@ -620,11 +640,22 @@ class ConfigSpace(object):
axis to split axis to split
policy: str policy: str
name of policy. name of policy.
If is 'all', the tuner will try all divisible factors. If is 'factors', the tuner will try all divisible factors.
If is 'candidate', try listed candidate. If is 'power2', the tuner will try power-of-two factors less or equal to the length.
If is 'verbose', the tuner will try all candidates in above two policies.
If is 'candidate', try given candidates.
kwargs: dict kwargs: dict
extra arguments for policy extra arguments for policy
see examples below for how to use filter max_factor: int
the maximum split factor.
filter: function(SplitEntity) -> bool
see examples below for how to use filter.
num_outputs: int
the total number of axis after split.
no_tail: bool
should we only include divisible numbers as split factors.
candidate: list
(policy=candidate) manual candidate list.
Examples Examples
-------- --------
...@@ -632,7 +663,7 @@ class ConfigSpace(object): ...@@ -632,7 +663,7 @@ class ConfigSpace(object):
>>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]]) >>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
>>> # use a filter that only accepts the split scheme whose inner most tile is less then 4 >>> # use a filter that only accepts the split scheme whose inner most tile is less then 4
>>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4) >>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4)
""" """
axes = [axis] axes = [axis]
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs) return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)
...@@ -944,7 +975,7 @@ class FallbackConfigEntity(ConfigSpace): ...@@ -944,7 +975,7 @@ class FallbackConfigEntity(ConfigSpace):
""" """
space = self.space_map[name] space = self.space_map[name]
assert isinstance(space, SplitSpace) assert isinstance(space, SplitSpace)
assert len(constraints) == space.num_outputs assert len(constraints) == space.num_output
# '-1' means no constraint # '-1' means no constraint
constraints = [x if x != -1 else 1e10 for x in constraints] constraints = [x if x != -1 else 1e10 for x in constraints]
...@@ -952,7 +983,7 @@ class FallbackConfigEntity(ConfigSpace): ...@@ -952,7 +983,7 @@ class FallbackConfigEntity(ConfigSpace):
entity = self._entity_map[name] entity = self._entity_map[name]
now = space.product now = space.product
for i in reversed(range(space.num_outputs)): for i in reversed(range(space.num_output)):
factors = get_factors(now) factors = get_factors(now)
find = len(factors) - 1 find = len(factors) - 1
......
...@@ -82,11 +82,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh ...@@ -82,11 +82,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW) ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits) ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: x.size[-1] == 8) filter=lambda x: x.size[-1] == 8)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: x.size[-1] >= 2) filter=lambda x: x.size[-1] >= 2)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: x.size[-1] >= 2) filter=lambda x: x.size[-1] >= 2)
ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2, ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2,
filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16) filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16)
...@@ -278,13 +278,13 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec, ...@@ -278,13 +278,13 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
s[data_pad].compute_inline() s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis _, h, _, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h) oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh) s[data_vec].parallel(oh)
#### Schedule kernel packing #### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco) s[kernel_vec].parallel(oco)
......
...@@ -66,10 +66,10 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp ...@@ -66,10 +66,10 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp
x, y = cfg.axis(batch), cfg.axis(out_dim) x, y = cfg.axis(batch), cfg.axis(out_dim)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim) db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2, ko, ki = cfg.define_split('tile_k', k, num_outputs=2,
filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16) filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2) xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2, yo, yi = cfg.define_split('tile_y', y, num_outputs=2,
filter=lambda xx: xx.size[-1] == 8) filter=lambda xx: xx.size[-1] == 8)
cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki], cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
......
...@@ -254,11 +254,11 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits, ...@@ -254,11 +254,11 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16) filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16) filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16) filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
...@@ -358,11 +358,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits, ...@@ -358,11 +358,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16) filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16) filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16) filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
cfg.define_reorder("reorder_0", cfg.define_reorder("reorder_0",
......
...@@ -95,9 +95,9 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp ...@@ -95,9 +95,9 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
######## Search space ######## Search space
x, y = cfg.axis(X), cfg.axis(Y) x, y = cfg.axis(X), cfg.axis(Y)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K) db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2) ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2) yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2) xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi], cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
policy='candidate', candidate=[ policy='candidate', candidate=[
......
...@@ -99,7 +99,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec, ...@@ -99,7 +99,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
s[data_pad].compute_inline() s[data_pad].compute_inline()
_, _, h, _, _, _, _ = s[data_vec].op.axis _, _, h, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h) oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
if cfg["tile_ah"].size[1] == 1: if cfg["tile_ah"].size[1] == 1:
oaxis = oh oaxis = oh
...@@ -116,7 +116,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec, ...@@ -116,7 +116,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
##### Schedule Kenerl bitpacking ##### Schedule Kenerl bitpacking
co, _, _, _, _, _ = s[kernel_vec].op.axis co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
if cfg["tile_bco"].size[1] == 1: if cfg["tile_bco"].size[1] == 1:
oaxis = oco oaxis = oco
...@@ -185,13 +185,13 @@ def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec, ...@@ -185,13 +185,13 @@ def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
s[data_pad].compute_inline() s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis _, h, _, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h) oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh) s[data_vec].parallel(oh)
##### Schedule kernel packing ##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco) s[kernel_vec].parallel(oco)
......
...@@ -95,7 +95,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): ...@@ -95,7 +95,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
if _is_int8_hw_support(data.dtype, kernel.dtype, target): if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*kic_s assert ic == k_ic*k_ic_f*k_ic_s
else: else:
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk assert ic_chunk == k_ic_chunk
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment