"""Tag class for TVM operators.""" from ._ffi.base import _LIB_NAME try: from decorator import decorate except ImportError as err_msg: # Allow decorator to be missing in runtime if _LIB_NAME != "libtvm_runtime.so": raise err_msg class TagScope(object): """Tag scope object to set tag for operators, working as context manager and decorator both. See also tag_scope. """ current = None def __init__(self, tag): self._old_scope = None self.tag = tag def __enter__(self): if TagScope.current is not None: raise ValueError("nested op_tag is not allowed for now") self._old_scope = TagScope.current TagScope.current = self return self def __exit__(self, ptype, value, trace): assert self._old_scope is None TagScope.current = self._old_scope def __call__(self, fdecl): def tagged_fdecl(func, *args, **kwargs): with self: return func(*args, **kwargs) return decorate(fdecl, tagged_fdecl) def tag_scope(tag): """The operator tag scope. Parameters ---------- tag: str The tag name. Returns ------- tag_scope: TagScope The tag scope object, which can be used as decorator or context manger. Example ------- .. code-block:: python n = tvm.var('n') m = tvm.var('m') l = tvm.var('l') A = tvm.placeholder((n, l), name='A') B = tvm.placeholder((m, l), name='B') k = tvm.reduce_axis((0, l), name='k') with tvm.tag_scope(tag='matmul'): C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k)) # or use tag_scope as decorator @tvm.tag_scope(tag="conv") def compute_relu(data): return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i))) """ return TagScope(tag)