# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
# pylint: disable=consider-using-enumerate,too-many-lines
"""
Template configuration space.

Each template function can be parametrized by a ConfigSpace.
The space is declared when we invoke the template function with ConfigSpace.
During evaluation, we pass in a ConfigEntity, which contains a specific
entity in the space. This entity contains deterministic parameters.
"""
from __future__ import absolute_import as _abs

import itertools
import functools
import math
from collections import namedtuple, OrderedDict
import numpy as np

from tvm import schedule, thread_axis
from tvm.autotvm.util import get_const_int

# Handle returned to template code for every output axis of a transform:
# `space` is the TransformSpace that produced it (None when not collecting),
# `index` is the position of the axis among the outputs of that transform.
Axis = namedtuple('Axis', ['space', 'index'])

# `long` only exists on Python 2; fall back to `int` elsewhere.
try:
    _long = long
except NameError:
    _long = int


class InstantiationError(ValueError):
    """Actively detected error in instantiating a template with a config,
    raised by cfg.raise_error
    e.g. too many unrolling, too many threads in a block
    """
    pass


class TransformSpace(object):
    """Base class for transform space
    TransformSpace is the node in the computation graph of axes

    Note
    ----
    We can regard our schedule code as a transformation graph of axes.
    Starting from raw axes in the definition of tvm.compute, we can transform these axes
    by some operators. The operator includes 'split', 'reorder' and 'annotate'.
    Each operator has some tunable parameters (e.g. the split factor).
    Then the tuning process is just to find good parameters of these op.

    So all the combinations of the parameters of these op form our search space.

    Naming convention:
    We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config ...)
    We call a specific entity in a space as XXXEntity.
    """
    def __init__(self):
        self.ins = []
        self.num_output = 0
        self.entities = []

    def __len__(self):
        """Return the number of entities in this space."""
        return len(self.entities)

    def __getitem__(self, index):
        """Get an entity of the space by index

        Parameters
        ----------
        index: int

        Returns
        -------
        transform entity
        """
        return self.entities[index]

    @staticmethod
    def get_num_output():
        """get number of output axes after this transform

        Returns
        -------
        n: int
            number of output axes
        """
        return 0


class VirtualAxis(TransformSpace):
    """Axis placeholder in template

    Parameters
    ----------
    var: int or tvm.schedule.IterVar
        If is int, return a virtual axis whose length is the provided argument.
        If is IterVar, return a virtual axis whose length is extracted from
        the IterVar's extent domain.
        A VirtualAxis is also accepted; its length is copied.
    name: str
        Name of the axis. When omitted, a name of the form 'axis_<n>' is used.
        For an IterVar the name is always taken from the IterVar itself.
    """
    # class-level counter used to build default axis names
    name_ct = 0

    def __init__(self, var, name=None):
        super(VirtualAxis, self).__init__()
        self.num_output = 1

        if name is None:
            name = 'axis_%d' % VirtualAxis.name_ct
            VirtualAxis.name_ct += 1

        self.name = name
        if isinstance(var, (int, _long)):
            self.length = var
        elif isinstance(var, schedule.IterVar):
            # the IterVar's own name overrides any name chosen above
            self.name = var.var.name
            if var.dom is None:
                # no domain attached: the length is recorded as -1
                self.length = -1
            else:
                self.length = get_const_int(var.dom.extent)
        elif isinstance(var, VirtualAxis):
            self.length = var.length
        else:
            raise RuntimeError("Invalid type of axis: " + str(type(var)))

    @staticmethod
    def get_num_output(var, name=None):
        """A virtual axis always yields exactly one output axis."""
        return 1

    def __repr__(self):
        return "vaxis(%s)" % self.name


def get_factors(n):
    """return all factors of an integer

    Parameters
    ----------
    n: int
        integer to factorize

    Returns
    -------
    factors: list
        List of all factors, sorted in ascending order
    """
    # an odd number has no even divisor, so only odd candidates are tried
    step = 2 if n % 2 else 1

    # every divisor i <= sqrt(n) is paired with its co-divisor n // i;
    # the set removes the duplicate produced when i == n // i
    ret = list(set(
        functools.reduce(
            list.__add__, ([i, n//i] for i in range(1, int(math.sqrt(n)) + 1, step)
                           if n % i == 0))))
    ret.sort()
    return ret


class SplitSpace(TransformSpace):
    """Split an axis for several times

    Parameters
    ----------
    axes: list
        Only the first element is used; it must expose a `length` attribute.
    policy: str
        'all' enumerates factor combinations, 'candidate' uses listed sizes.
    kwargs: dict
        'num_outputs' (required): number of axes after the split.
        'max_factor' (policy 'all', optional): upper bound of a usable factor.
        'filter' (policy 'all', optional): predicate applied to every SplitEntity.
        'candidate' (policy 'candidate'): list of size lists.
    """
    def __init__(self, axes, policy, **kwargs):
        super(SplitSpace, self).__init__()
        axis = axes[0]

        self.policy = policy
        self.entities = []

        if policy == 'all':
            num_outputs = kwargs["num_outputs"]
            max_factor = kwargs.get("max_factor", 1 << 31)
            fil = kwargs.get("filter", lambda x: True)

            length = axis.length
            factors = get_factors(length)
            factors = [x for x in factors if x <= max_factor]
            self.product = length
            self.num_outputs = num_outputs
            # every level except the outermost one draws from the same factor list
            self.factors = [factors] * (num_outputs-1)
            self._generate_space(0, [None] * (num_outputs - 1))
            self.entities = list(filter(fil, self.entities))
            self.num_output = num_outputs
        elif policy == 'candidate':
            self.product = axis.length
            self.num_outputs = kwargs["num_outputs"]
            for size in kwargs["candidate"]:
                assert len(size) == self.num_outputs
                # NOTE: the product of a candidate is not checked against self.product
                self.entities.append(SplitEntity(size))
            self.num_output = self.num_outputs
        else:
            raise RuntimeError("Invalid policy: " + policy)

    def _generate_space(self, now, tmp_stack):
        """Generate space by DFS

        Parameters
        ----------
        now: int
            index of the level whose factor is chosen in this call
        tmp_stack: list
            factors chosen so far; slot `now` is overwritten in place
        """
        if now == self.num_outputs - 1:
            chosen_product = np.prod(tmp_stack)
            # keep the combination only if the remaining outermost extent is integral
            if self.product % chosen_product == 0:
                first = int(self.product // int(chosen_product))
                # chosen factors are stored in reverse order after the outermost extent
                self.entities.append(SplitEntity([first] + tmp_stack[::-1]))
        else:
            for factor in self.factors[now]:
                tmp_stack[now] = factor
                self._generate_space(now + 1, tmp_stack)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """Return the number of output axes, i.e. kwargs['num_outputs']."""
        return kwargs["num_outputs"]

    def __repr__(self):
        return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" %
                (self.policy, self.product, self.num_outputs, len(self)))


class SplitEntity(object):
    """
    A split operation with detailed parameters
    that can apply to an axis

    Parameters
    ----------
    size: Array of int
        the size of every axis after split
        e.g. an axis of extent 128, we split it into 3 axes, a possible
        size is [4, 4, 8] (4x4x8 = 128)
    """
    def __init__(self, size):
        self.size = size

    def apply(self, sch, op, axis):
        """Apply split to an axis

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The stage to be applied
        axis: tvm.schedule.IterVar
            axis to split

        Returns
        -------
        axes : list of Axis
            The transformed axes.
        """
        ret = []
        for i in range(1, len(self.size)):
            # split off the outer part; the factor is the product of all inner sizes
            ax0, ax1 = sch[op].split(axis, int(np.prod(self.size[i:])))
            ret.append(ax0)
            axis = ax1
        return ret + [axis]

    def __repr__(self):
        return str(self.size)


class ReorderSpace(TransformSpace):
    """The parameter space for ordering an array of axes

    Parameters
    ----------
    axes: list
        axes to reorder
    policy: str
        one of 'identity', 'all', 'interval_all', 'candidate',
        'interleave', 'interleave_cuda'
    kwargs: dict
        'interval' for 'interval_all', 'candidate' for 'candidate',
        'spatial' and 'reduce' for the two interleave policies
    """
    def __init__(self, axes, policy, **kwargs):
        super(ReorderSpace, self).__init__()
        self.ins = axes
        self.policy = policy
        self.num_output = len(axes)

        if policy == 'identity':
            # materialize the permutation: a lazy `range` object is not
            # JSON-serializable on Python 3 and differs from the other policies
            self.entities = [ReorderEntity(list(range(len(axes))))]
        elif policy == 'all':
            self.entities = [
                ReorderEntity(x) for x in itertools.permutations(range(len(axes)))]
        elif policy == 'interval_all':
            # permute only the axes in [begin, end); the rest keep their position
            begin, end = kwargs['interval']
            sub_space = list(itertools.permutations(range(begin, end)))
            prefix, suffix = tuple(range(begin)), tuple(range(end, len(axes)))
            self.entities = [ReorderEntity(prefix + x + suffix) for x in sub_space]
        elif policy == 'candidate':
            candidate = kwargs["candidate"]
            for can in candidate:
                perm = [axes.index(x) for x in can]
                self.entities.append(ReorderEntity(perm))
        elif policy == 'interleave':
            spatial, reduce = kwargs['spatial'], kwargs['reduce']

            # translate every chain of axes into a chain of indices into `axes`
            spatial = [[axes.index(x) for x in ch] for ch in spatial]
            reduce = [[axes.index(x) for x in ch] for ch in reduce]

            # all but the last axis of each spatial chain go outside; the last
            # spatial axes are interleaved with the reduction chains inside
            outer_merged = self._merge_chain([x[:-1] for x in spatial])
            inner_merged = self._merge_chain([x[-1:] for x in spatial] + reduce)

            for o in outer_merged:
                for i in inner_merged:
                    self.entities.append(ReorderEntity(o + i))
        elif policy == 'interleave_cuda':
            spatial, reduce = kwargs['spatial'], kwargs['reduce']

            spatial = [[axes.index(x) for x in ch] for ch in spatial]
            reduce = [[axes.index(x) for x in ch] for ch in reduce]

            # the last axis of every spatial chain is always placed innermost
            outer_merged = self._merge_chain([x[:-1] for x in spatial])
            reduce_merged = self._merge_chain(reduce)
            inner_merged = [x[-1] for x in spatial]

            for o in outer_merged:
                for r in reduce_merged:
                    self.entities.append(ReorderEntity(o + r + inner_merged))
        else:
            raise RuntimeError("Invalid policy: " + policy)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """Reordering keeps the number of axes unchanged."""
        return len(axes)

    def __repr__(self):
        return "Reorder(policy=%s) len=%d" % (self.policy, len(self))

    def _merge_chain(self, chains):
        """generate all combinations of merge some chains

        Parameters
        ----------
        chains: list of list of int
            chains of axis indices; the order inside each chain is kept

        Returns
        -------
        merged: list of list of int
            the merged orders found by `_merge_dfs`
        """
        merged = []
        tmp_pt = [0] * len(chains)
        tmp_stack = []

        size = np.sum([len(x) for x in chains])
        self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
        return merged

    def _merge_dfs(self, chains, size, tmp_pt, tmp_stack, merged):
        """DFS helper of `_merge_chain`

        Parameters
        ----------
        chains: list of list of int
            chains being merged
        size: int
            total number of elements over all chains
        tmp_pt: list of int
            per-chain count of elements already consumed (modified in place
            and restored before returning)
        tmp_stack: list of int
            merged order built so far (modified in place and restored)
        merged: list
            output list; a copy of every complete order is appended to it
        """
        if np.sum(tmp_pt) == size:
            merged.append(list(tmp_stack))
            return
        else:
            for i in range(len(chains)):
                # use i == np.argmax(....) here to take spatial order into consideration
                # if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
                if (tmp_pt[i] < len(chains[i]) and
                        (i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))):
                    tmp_stack.append(chains[i][tmp_pt[i]])
                    tmp_pt[i] += 1
                    self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
                    tmp_pt[i] -= 1
                    tmp_stack.pop()


class ReorderEntity(object):
    """A reorder operation with detailed parameters that can apply to axes

    Parameters
    ----------
    perm: Array of int
        define the permutation
    """
    def __init__(self, perm):
        self.perm = perm

    def apply(self, sch, op, axes):
        """Apply reorder to an array of axes

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The stage to be applied
        axes: Array of tvm.schedule.IterVar
            axes to reorder

        Returns
        -------
        axes : list of Axis
            The transformed axes.
        """
        if len(axes) == len(self.perm):
            new_order = [axes[i] for i in self.perm]
        else:
            # length mismatch: permutation entries that do not index `axes` are skipped
            new_order = [axes[i] for i in self.perm if i < len(axes)]
        sch[op].reorder(*new_order)
        return new_order

    def __repr__(self):
        return str(self.perm)


class AnnotateSpace(TransformSpace):
    """The parameter space for annotating an array of axes

    Parameters
    ----------
    axes: list
        axes to annotate
    policy: str
        'bind_gpu', 'bind_gpu_virtual', 'locate_cache', or a '_'-joined
        combination of 'none', 'unroll', 'vec' where 'try' stands for 'none'
        (e.g. 'try_unroll_vec')
    kwargs: dict
        'num_anchor' is required by the 'locate_cache' policy
    """
    def __init__(self, axes, policy, **kwargs):
        super(AnnotateSpace, self).__init__()

        self.ins = axes
        self.policy = policy
        self.num_output = len(axes)

        if policy == 'bind_gpu':
            self.num_axis = len(axes)
            # leading surplus axes are fused; the trailing ones are bound to
            # block indices followed by thread indices
            if self.num_axis >= 6:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 6) +
                    ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
                     'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 4:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 4) +
                    ['blockIdx.y', 'blockIdx.x',
                     'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 2:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 2) +
                    ['blockIdx.x', 'threadIdx.x']))
            else:
                raise RuntimeError("Unhandled case in bind_gpu")
        elif policy == 'bind_gpu_virtual':
            self.num_axis = len(axes)
            # same layout as 'bind_gpu' with 'vthread' axes between blocks and threads
            if self.num_axis >= 9:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 9) +
                    ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
                     'vthread', 'vthread', 'vthread',
                     'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 6:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 6) +
                    ['blockIdx.y', 'blockIdx.x',
                     'vthread', 'vthread',
                     'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 3:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 3) +
                    ['blockIdx.x', 'vthread', 'threadIdx.x']))
            else:
                raise RuntimeError("Unhandled case in bind_gpu_virtual")
        elif policy == 'locate_cache':
            self.num_axis = len(axes)
            num_anchor = kwargs["num_anchor"]
            # builtin range keeps the anchors plain Python ints; numpy integer
            # scalars would not be JSON-serializable in ConfigEntity.to_json_dict
            self.anns = list(itertools.combinations(range(self.num_axis), num_anchor))
            self.entities = [AnnotateEntity(x) for x in self.anns]
        else:  # none, vec, unroll, try_vec, try_unroll, try_vec_unroll, ...
            anns = policy.replace('try', 'none').split('_')

            for ann in anns:
                if ann not in ['none', 'unroll', 'vec']:
                    raise RuntimeError("Invalid policy: " + policy)

            self.num_axis = len(axes)
            # every axis may take any of the annotations named by the policy
            self.anns = [anns] * self.num_axis
            self._generate_space(0, [""] * self.num_axis)

    def _generate_space(self, now, tmp_stack):
        """Generate space by DFS

        Parameters
        ----------
        now: int
            index of the axis whose annotation is chosen in this call
        tmp_stack: list of str
            annotations chosen so far; slot `now` is overwritten in place
        """
        if now == self.num_axis:
            # keep only the combinations with at most one vectorized axis
            vec_ct = tmp_stack.count('vec')
            if vec_ct == 0 or vec_ct == 1:
                self.entities.append(AnnotateEntity(list(tmp_stack)))
        else:
            for ann in self.anns[now]:
                tmp_stack[now] = ann
                self._generate_space(now + 1, tmp_stack)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """Annotation keeps the number of axes unchanged."""
        return len(axes)

    def __repr__(self):
        return "Annotate(policy=%s) len=%d" % (self.policy, len(self))


class AnnotateEntity(object):
    """An annotation operation with detailed parameters that can apply to axes

    Parameters
    ----------
    anns: Array of string
        The annotations of axes
    """
    def __init__(self, anns):
        self.anns = anns

    def apply(self, sch, op, axes, axis_lens=None,
              max_unroll=None, vec_size=None, cfg=None, source=None):
        """Apply annotation to an array of axes

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The stage to be applied
        axes: Array of tvm.schedule.IterVar
            axis to split
        axis_lens: Array of int, optional
            the length of axes
        max_unroll: int, optional
            maximum unroll step
        vec_size: Array of int, optional
            valid vector lanes for vectorization
        cfg: ConfigEntity, optional
            cfg for recording error
        source: Array of Array tensor, optional
            source tensor for attaching cache

        Returns
        -------
        axes : list of tvm.schedule.IterVar
            The transformed axes
        """
        if source is not None:  # special case : attach cache_read/cache_write
            # here self.anns holds axis indices: every tensor of a source group
            # is attached at the axis selected for that group
            for src, to in zip(source, self.anns):
                for t in src:
                    sch[t].compute_at(sch[op], axes[to])
        else:  # other cases
            for i, ann in enumerate(self.anns):
                if ann == 'none':
                    pass
                elif ann == 'unroll':
                    # the error is only recorded on cfg; the unroll is still applied
                    if max_unroll and axis_lens[i] > max_unroll:
                        cfg.raise_error("Too large factor for unrolling")
                    sch[op].unroll(axes[i])
                elif ann == 'vec':
                    # the error is only recorded on cfg; vectorize is still applied
                    if vec_size and axis_lens[i] not in vec_size:
                        cfg.raise_error("Wrong size of lanes in vectorization")
                    sch[op].vectorize(axes[i])
                elif ann == 'blockIdx.x':
                    sch[op].bind(axes[i], thread_axis('blockIdx.x'))
                elif ann == 'blockIdx.y':
                    sch[op].bind(axes[i], thread_axis('blockIdx.y'))
                elif ann == 'blockIdx.z':
                    sch[op].bind(axes[i], thread_axis('blockIdx.z'))
                elif ann == 'threadIdx.x':
                    sch[op].bind(axes[i], thread_axis('threadIdx.x'))
                elif ann == 'threadIdx.y':
                    sch[op].bind(axes[i], thread_axis('threadIdx.y'))
                elif ann == 'threadIdx.z':
                    sch[op].bind(axes[i], thread_axis('threadIdx.z'))
                elif ann == 'vthread':
                    sch[op].bind(axes[i], thread_axis("vthread"))
                elif ann == 'fuse':
                    # fuse this axis into the next one; the fused axis replaces
                    # the next entry of `axes` (the caller's list is modified)
                    assert i < len(axes) - 1
                    axes[i+1] = sch[op].fuse(axes[i], axes[i+1])
                else:
                    raise RuntimeError("Invalid annotation " + ann)
        return axes

    def __repr__(self):
        return str(self.anns)


class OtherOptionSpace(TransformSpace):
    """The parameter space for general option

    Parameters
    ----------
    axes: list
        unused
    policy: str
        unused
    kwargs: dict
        'candidate' (required): list of candidate values
    """
    def __init__(self, axes, policy, **kwargs):
        super(OtherOptionSpace, self).__init__()

        candidate = kwargs["candidate"]
        self.entities = [OtherOptionEntity(x) for x in candidate]

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        """A general option produces no output axis."""
        return 0

    def __repr__(self):
        return "OtherOption(%s) len=%d" % (self.entities, len(self))


class OtherOptionEntity(object):
    """The parameter entity for general option, with a detailed value"""
    def __init__(self, val):
        self.val = val

    def __repr__(self):
        return str(self.val)


class ConfigSpace(object):
    """The configuration space of a schedule. Pass it as config in template to
       collect transformation space and build transform graph of axes
    """
    def __init__(self):
        # private dict to provide sugar
        self.space_map = OrderedDict()    # name -> space
        # True while knob definitions are being collected into space_map
        self._collect = True
        # cached total size of the space, computed lazily by __len__
        self._length = None
        self._entity_map = OrderedDict()  # name -> entity
        self._constraints = []
        # error messages registered through raise_error
        self.errors = []
        self.template_key = None
        self.code_hash = None
        self.flop = 0
        self.is_fallback = False

    @staticmethod
    def axis(var):
        """get a virtual axis (axis placeholder)

        Parameters
        ----------
        var: int or tvm.schedule.IterVar
            If is int, return an axis whose length is the provided argument.
            If is IterVar, return an axis whose length is extracted from the
            IterVar's extent domain.
        """
        return VirtualAxis(var)

    reduce_axis = axis

    def define_split(self, name, axis, policy='all', **kwargs):
        """Define a new tunable knob which splits an axis into a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axis: tvm.schedule.IterVar
            axis to split
        policy: str
            name of policy.
            If is 'all', the tuner will try all divisible factors.
            If is 'candidate', try listed candidate.
        kwargs: dict
            extra arguments for policy
            see examples below for how to use filter

        Examples
        --------
        >>> # use custom candidates
        >>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])

        >>> # use a filter that only accepts the split scheme whose inner most tile is less than 4
        >>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4)
        """
        axes = [axis]
        return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)

    def define_reorder(self, name, axes, policy, **kwargs):
        """Define a new tunable knob which reorders a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axes: Array of tvm.schedule.IterVar
            axes to reorder
        policy: str
            name of policy
            If is 'identity', do an identity permutation.
            If is 'all', try all permutations.
            If is 'interval_all', try all permutations of an interval of axes.
            If is 'candidate', try listed candidate.
            If is 'interleave', interleave chains of spatial axes and chains of reduction axes.
        kwargs: dict
            extra arguments for policy
        """
        return self._add_new_transform(ReorderSpace, name, axes, policy, **kwargs)

    def define_annotate(self, name, axes, policy, **kwargs):
        """Define a new tunable knob which annotates a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axes: Array of tvm.schedule.IterVar
            axes to annotate
        policy: str
            name of policy
            If is 'unroll', unroll the axes.
            If is 'try_unroll', try to unroll the axes.
            If is 'try_unroll_vec', try to unroll or vectorize the axes.
            If is 'bind_gpu', bind the first few axes to gpu threads.
            If is 'locate_cache', choose n axes to attach shared/local cache.
        kwargs: dict
            extra arguments for policy
        """
        return self._add_new_transform(AnnotateSpace, name, axes, policy, **kwargs)

    def define_knob(self, name, candidate):
        """Define a tunable knob with a list of candidates

        Parameters
        ----------
        name: str
            name key of that option
        candidate: list
            list of candidates
        """
        return self._add_new_transform(OtherOptionSpace, name, [], None, candidate=candidate)

    def add_flop(self, flop):
        """Add float operation statistics for this tuning task

        Parameters
        ---------
        flop: int or float
            number of float operations
        """
        self.flop += flop

    def raise_error(self, msg):
        """register error in config
        Using this to actively detect error when scheduling.
        Otherwise these error will occur during runtime, which
        will cost more time.

        The message is only appended to `self.errors`; no exception is raised here.

        Parameters
        ----------
        msg: str
        """
        self.errors.append(msg)

    def valid(self):
        """Check whether the config meets all the constraints
        Note: This check should be called after instantiation of task,
              because the ConfigEntity/ConfigSpace collects errors during instantiation

        Returns
        -------
        valid: bool
            whether the config meets all the constraints
        """
        return not bool(self.errors)

    def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
        """Add a new transform space in template

        Parameters
        ----------
        space_class: class
            the TransformSpace subclass to instantiate
        name: str
            name of the knob
        axes: list
            input axes of the transform
        policy: str
            policy forwarded to `space_class`
        kwargs: dict
            extra arguments forwarded to `space_class`

        Returns
        -------
        axes: list of Axis
            one Axis per output of the transform
        """
        if self._collect:
            # convert schedule axis to space definition axis
            axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes]

            # add subspace (knob)
            space = space_class(axes, policy, **kwargs)
            self.space_map[name] = space
            # the first entity of the space is used as the default choice
            self._entity_map[name] = space[0]
            return [Axis(space, i) for i in range(space.num_output)]
        # not collecting: only report how many axes the transform would output
        return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]

    def __len__(self):
        """Return the size of the space: the product of the sizes of all knobs."""
        if self._length is None:
            self._length = int(np.prod([len(x) for x in self.space_map.values()]))
        return self._length

    def get(self, index):
        """Get a config entity with detailed parameters from this space

        Parameters
        ----------
        index: int
            index in the space

        Returns
        -------
        config: ConfigEntity
            the config whose knobs are selected by `index`
        """
        entities = OrderedDict()
        t = index
        # decode `index` as a mixed-radix number: the first knob is the lowest digit
        for name, space in self.space_map.items():
            entities[name] = space[t % len(space)]
            t //= len(space)
        ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints)
        return ret

    def __iter__(self):
        """Iterate over the names of the knobs."""
        return self._entity_map.__iter__()

    def __getitem__(self, name):
        """get the transform entity(knob) of this entity by name
           do not use this to get a ConfigEntity of this space (should use ConfigSpace.get instead)

        Parameters
        ----------
        name: str
            name of the transform
        """
        return self._entity_map[name]

    def __repr__(self):
        res = "ConfigSpace (len=%d, space_map=\n" % len(self)
        for i, (name, space) in enumerate(self.space_map.items()):
            res += "  %2d %s: %s\n" % (i, name, space)
        return res + ")"


# position of every annotation in the one-hot encoding built by
# ConfigEntity.get_flatten_feature
_ann_to_number = {
    'none': 0, 'vec': 1, 'unroll': 2,
    'blockIdx.x': 3, 'blockIdx.y': 4, 'blockIdx.z': 5,
    'threadIdx.x': 6, 'threadIdx.y': 7, 'threadIdx.z': 8,
    'vthread': 9, 'fuse': 10
}


class ConfigEntity(ConfigSpace):
    """A configuration with detailed parameters

    Parameters
    ----------
    index: int
        index of this config in space
    code_hash: str
        hash of schedule code
    template_key : str
        The specific template key
    entity_map: dict
        map name to transform entity
    constraints : list
        List of constraints
    """
    def __init__(self, index, code_hash, template_key, entity_map, constraints):
        super(ConfigEntity, self).__init__()
        self.index = index
        self.template_key = template_key
        # an entity never collects new knobs; define_* only report axis counts
        self._collect = False
        self._entity_map = entity_map
        self._space_map = None
        self._constraints = constraints
        self.code_hash = code_hash

    def get_flatten_feature(self):
        """ flatten entities to a numerical one-dimensional feature vector

        Returns
        -------
        fea: np.array
            one dimensional float32 array
        """
        fea = []
        for _, v in self._entity_map.items():
            if isinstance(v, SplitEntity):
                fea.extend(v.size)
            elif isinstance(v, ReorderEntity):
                # use a naive way: directly copy the permutation
                fea.extend(v.perm)
            elif isinstance(v, AnnotateEntity):
                # one-hot encoding
                for ann in v.anns:
                    tmp = [0] * len(_ann_to_number)
                    tmp[_ann_to_number[ann]] = 1
                    fea.extend(tmp)
            elif isinstance(v, OtherOptionEntity):
                fea.append(v.val)
        return np.array(fea, dtype=np.float32)

    def get_other_option(self):
        """
        Returns
        -------
        other_option: dict
            other tunable parameters (tunable parameters defined by `cfg.define_knob`)
        """
        # NOTE(review): the keys of the returned dict are the OtherOptionEntity
        # objects themselves, not the knob names -- confirm what callers expect.
        return {x: x.val for x in self._entity_map.values() if isinstance(x, OtherOptionEntity)}

    def to_json_dict(self):
        """convert to a json serializable dictionary

        Return
        ------
        json_dict: dict
            a json serializable dictionary

        Raises
        ------
        RuntimeError
            if an entity is of none of the four known entity types
        """
        ret = {}
        ret['i'] = int(self.index)
        ret['t'] = self.template_key
        ret['c'] = self.code_hash
        entity_map = []
        for k, v in self._entity_map.items():
            if isinstance(v, SplitEntity):
                entity_map.append((k, 'sp', v.size))
            elif isinstance(v, ReorderEntity):
                entity_map.append((k, 're', v.perm))
            elif isinstance(v, AnnotateEntity):
                entity_map.append((k, 'an', v.anns))
            elif isinstance(v, OtherOptionEntity):
                entity_map.append((k, 'ot', v.val))
            else:
                # str() is required: concatenating the entity object itself
                # would raise TypeError instead of this RuntimeError
                raise RuntimeError("Invalid entity instance: " + str(v))
        ret['e'] = entity_map
        return ret

    @staticmethod
    def from_json_dict(json_dict):
        """Build a ConfigEntity from json serializable dictionary

        Parameters
        ----------
        json_dict: dict
            Json serializable dictionary. This should be the return value
            of :any:`to_json_dict`.

        Returns
        -------
        config: ConfigEntity
            The corresponding config object

        Raises
        ------
        RuntimeError
            if a knob type tag is not one of 'sp', 're', 'an', 'ot'
        """
        index = json_dict["i"]
        code_hash = json_dict["c"]
        template_key = json_dict["t"]
        # constraints are not stored in the json dict, so they start empty
        constraints = []
        entity_map = OrderedDict()

        for item in json_dict["e"]:
            key, knob_type, knob_args = item
            if knob_type == 'sp':
                entity = SplitEntity(knob_args)
            elif knob_type == 're':
                entity = ReorderEntity(knob_args)
            elif knob_type == 'an':
                entity = AnnotateEntity(knob_args)
            elif knob_type == 'ot':
                entity = OtherOptionEntity(knob_args)
            else:
                raise RuntimeError("Invalid config knob type: " + knob_type)
            entity_map[str(key)] = entity

        return ConfigEntity(index, code_hash, template_key, entity_map, constraints)

    def __repr__(self):
        # [12:-1] drops the first 12 characters and the last one of the map's
        # repr; for an OrderedDict that is the "OrderedDict(" ... ")" wrapper
        return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
                                self.code_hash, self.index)


class FallbackConfigEntity(ConfigSpace):
    """The config entity created to support fallback"""

    def __init__(self):
        super(FallbackConfigEntity, self).__init__()
        self.is_fallback = True

    def fallback_split(self, name, constraints):
        """Fallback a split knob

        Parameters
        ----------
        name: str
            name of the knob
        constraints: List of int
            The maximum tile size for every dimension. Value `-1` means no constraint.

        Raises
        ------
        RuntimeError
            if no factor satisfies the constraint of some dimension

        Examples
        --------
        If you use cfg.define_split('tile_0', 128, num_outputs=3),
        Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4]

        If you use cfg.define_split('tile_0', 49, num_outputs=3),
        Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1]
        """
        space = self.space_map[name]
        assert isinstance(space, SplitSpace)
        assert len(constraints) == space.num_outputs

        # '-1' means no constraint
        constraints = [x if x != -1 else 1e10 for x in constraints]

        # the sizes of the currently selected entity are overwritten in place
        entity = self._entity_map[name]
        now = space.product

        # go from the innermost dimension outwards, each time taking the
        # largest factor of the remaining extent that fits the constraint
        for i in reversed(range(space.num_outputs)):
            factors = get_factors(now)

            find = len(factors) - 1
            for j, f in enumerate(factors):
                if f > constraints[i]:
                    find = j - 1
                    break

            if find >= 0:
                entity.size[i] = factors[find]
                now //= factors[find]
            else:
                raise RuntimeError("Cannot find feasible fallback split entity for node: " + name)

    def fallback_with_reference_log(self, ref_log):
        """A data driven fallback mechanism.
        We use tuned parameters from TopHub as reference data.
        For an unseen shape, we find the most similar tuned one from TopHub and
        mimic its parameters.

        Nothing is changed when no reference config gets a positive match score.

        Parameters
        ----------
        ref_log: List of (MeasureInput, MeasureResult)
            The reference log
        """
        knob_names = [x for x in self.space_map.keys() if
                      isinstance(self.space_map[x], SplitSpace)]

        # find best match config in reference data by matching tiling factors
        factor_list = []
        for knob_name in knob_names:
            factor_list.append(get_factors(self.space_map[knob_name].product))

        best_match_cfg = None
        best_match_score = 0
        for inp, _ in ref_log:
            match_score = 0
            for i, knob_name in enumerate(knob_names):
                factors = get_factors(int(np.prod(inp.config[knob_name].size)))
                # fraction of our factors that also divide the reference extent
                match_score += (float(len(set(factor_list[i]).intersection(factors))) /
                                len(factor_list[i]))

            if match_score > best_match_score:
                best_match_score, best_match_cfg = match_score, inp.config

        if best_match_cfg is None:
            return

        # mimic its tiling strategy
        for knob_name in knob_names:
            constraint = list(best_match_cfg[knob_name].size)
            # the outermost dimension is left unconstrained
            constraint[0] = -1
            self.fallback_split(knob_name, constraint)

        # copy other knobs
        for knob_name in self.space_map.keys():
            if not isinstance(self.space_map[knob_name], SplitSpace):
                self._entity_map[knob_name] = best_match_cfg[knob_name]

    def __setitem__(self, name, entity):
        """set the entity(knob) of by name

        Parameters
        ----------
        name: str
            name of the entity
        entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity
            value of the entity
        """
        self._entity_map[name] = entity

    def __repr__(self):
        # [12:-1] drops the "OrderedDict(" prefix and the ")" suffix of the map's repr
        return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)