Commit da039794 by Cody Hao Yu Committed by Wuwei Lin

[AutoTVM] Enhance tuning space of split (#3949)

* Refine policies for define_split

- Rename policy "all" to "factors"
- Add policy "verbose" and "power2"

* Refine search space

* add doc
parent e35e1cc2
......@@ -165,6 +165,20 @@ def get_factors(n):
ret.sort()
return ret
def get_pow2s(n):
"""return all power-of-two numbers that are less or equal than the integer
Parameters
----------
n: int
integer for reference
Returns
-------
factors: list
List of all power-of-two numbers
"""
return [2**x for x in range(math.floor(math.log2(n)) + 1)]
class SplitSpace(TransformSpace):
"""Split an axis for several times"""
......@@ -175,43 +189,49 @@ class SplitSpace(TransformSpace):
self.policy = policy
self.entities = []
if policy == 'all':
num_outputs = kwargs["num_outputs"]
max_factor = kwargs.get("max_factor", 1 << 31)
fil = kwargs.get("filter", lambda x: True)
length = axis.length
factors = get_factors(length)
factors = [x for x in factors if x <= max_factor]
# copy factors for every level
self.product = length
self.num_outputs = num_outputs
self.factors = [factors] * (num_outputs-1)
self._generate_space(0, [None] * (num_outputs - 1))
self.entities = list(filter(fil, self.entities))
self.num_output = num_outputs
elif policy == 'candidate':
self.product = axis.length
self.num_outputs = kwargs["num_outputs"]
max_factor = kwargs.get("max_factor", 1 << 31)
fil = kwargs.get("filter", lambda x: True)
self.product = axis.length
self.num_output = kwargs.get("num_outputs", 0)
assert self.num_output > 0
if policy == 'candidate':
for size in kwargs["candidate"]:
assert len(size) == self.num_outputs
# assert np.prod(size) == self.product
assert len(size) == self.num_output
self.entities.append(SplitEntity(size))
self.num_output = self.num_outputs
else:
raise RuntimeError("Invalid policy: " + policy)
if policy == 'verbose':
# Include factors and power-of-twos. May generate tails.
divisibles = get_factors(self.product)
pow2s = get_pow2s(self.product)
factors = [x for x in list(set(divisibles) | set(pow2s)) if x <= max_factor]
elif policy == 'factors':
# Include divisible factors. Guarantee no tails.
factors = [x for x in get_factors(self.product) if x <= max_factor]
elif policy == 'power2':
# Include less, equal, and round-up power-of-two numbers. May generate tails.
factors = [x for x in get_pow2s(self.product) if x <= max_factor]
else:
raise RuntimeError("Invalid policy: %s" % policy)
def _generate_space(self, now, tmp_stack):
# Enforce the product of all split factors equals to the axis length
no_tail = kwargs.get("no_tail", policy == 'factors')
# Generate split entity by enumerating candidate factors.
self.factors = factors
self._generate_space(0, [None] * (self.num_output - 1), enforce_no_tail=no_tail)
self.entities = list(filter(fil, self.entities))
def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
"""Generate space by DFS"""
if now == self.num_outputs - 1:
size = np.prod(tmp_stack, dtype=np.int64)
if self.product % size == 0:
first = int(self.product // int(size))
self.entities.append(SplitEntity([first] + tmp_stack[::-1]))
if now == self.num_output - 1:
if not enforce_no_tail or self.product % np.prod(tmp_stack, dtype=np.int64) == 0:
self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
else:
for factor in self.factors[now]:
for factor in self.factors:
tmp_stack[now] = factor
self._generate_space(now + 1, tmp_stack)
self._generate_space(now + 1, tmp_stack, enforce_no_tail)
@staticmethod
def get_num_output(axes, policy, **kwargs):
......@@ -219,7 +239,7 @@ class SplitSpace(TransformSpace):
def __repr__(self):
return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" %
(self.policy, self.product, self.num_outputs, len(self)))
(self.policy, self.product, self.num_output, len(self)))
class SplitEntity(object):
......@@ -609,7 +629,7 @@ class ConfigSpace(object):
reduce_axis = axis
def define_split(self, name, axis, policy='all', **kwargs):
def define_split(self, name, axis, policy='factors', **kwargs):
"""Define a new tunable knob which splits an axis into a list of axes
Parameters
......@@ -620,11 +640,22 @@ class ConfigSpace(object):
axis to split
policy: str
name of policy.
If is 'all', the tuner will try all divisible factors.
If is 'candidate', try listed candidate.
If is 'factors', the tuner will try all divisible factors.
If is 'power2', the tuner will try power-of-two factors less or equal to the length.
If is 'verbose', the tuner will try all candidates in above two policies.
If is 'candidate', try given candidates.
kwargs: dict
extra arguments for policy
see examples below for how to use filter
max_factor: int
the maximum split factor.
filter: function(int) -> bool
see examples below for how to use filter.
num_outputs: int
the total number of axis after split.
no_tail: bool
should we only include divisible numbers as split factors.
candidate: list
(policy=candidate) manual candidate list.
Examples
--------
......@@ -632,7 +663,7 @@ class ConfigSpace(object):
>>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
>>> # use a filter that only accepts the split scheme whose inner most tile is less then 4
>>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4)
>>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4)
"""
axes = [axis]
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)
......@@ -944,7 +975,7 @@ class FallbackConfigEntity(ConfigSpace):
"""
space = self.space_map[name]
assert isinstance(space, SplitSpace)
assert len(constraints) == space.num_outputs
assert len(constraints) == space.num_output
# '-1' means no constraint
constraints = [x if x != -1 else 1e10 for x in constraints]
......@@ -952,7 +983,7 @@ class FallbackConfigEntity(ConfigSpace):
entity = self._entity_map[name]
now = space.product
for i in reversed(range(space.num_outputs)):
for i in reversed(range(space.num_output)):
factors = get_factors(now)
find = len(factors) - 1
......
......@@ -82,11 +82,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: x.size[-1] == 8)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2,
filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16)
......@@ -278,13 +278,13 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh)
#### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco)
......
......@@ -66,10 +66,10 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp
x, y = cfg.axis(batch), cfg.axis(out_dim)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2,
ko, ki = cfg.define_split('tile_k', k, num_outputs=2,
filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2,
xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, num_outputs=2,
filter=lambda xx: xx.size[-1] == 8)
cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
......
......@@ -254,11 +254,11 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
......@@ -358,11 +358,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
cfg.define_reorder("reorder_0",
......
......@@ -95,9 +95,9 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
######## Search space
x, y = cfg.axis(X), cfg.axis(Y)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
policy='candidate', candidate=[
......
......@@ -99,7 +99,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
s[data_pad].compute_inline()
_, _, h, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
if cfg["tile_ah"].size[1] == 1:
oaxis = oh
......@@ -116,7 +116,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
##### Schedule Kenerl bitpacking
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
if cfg["tile_bco"].size[1] == 1:
oaxis = oco
......@@ -185,13 +185,13 @@ def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
s[data_pad].compute_inline()
_, h, _, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh)
##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco)
......
......@@ -95,7 +95,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*kic_s
assert ic == k_ic*k_ic_f*k_ic_s
else:
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment