Commit b787ffa3 by tqchen Committed by Tianqi Chen

[REFACTOR][PY] Establish tvm.tir

- Move related files into the corresponding location as in C++
- Keep the top-level TVM API backward compatible to make minimum changes in topi
parent a6c42b34
......@@ -33,36 +33,40 @@ from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import ndarray as nd
# tvm.error
from . import error
# tvm.ir
from .ir import IRModule
from .ir import transform
from .ir import container
from . import ir
# tvm.tir
from . import tir
# tvm.target
from . import target
# others
from . import tensor
from . import arith
from . import expr
from . import stmt
from . import make
from . import ir_pass
from . import schedule
from . import ir_builder
from . import target
from . import generic
from . import hybrid
from . import testing
from . import error
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
# backward compact for topi, to be removed later
from .tir import expr, stmt, ir_builder, ir_pass, generic
from .tir.op import *
from . import intrin
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
......
......@@ -227,7 +227,7 @@ def args_to_workload(x, topi_compute_func=None):
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
'primitive types or tvm.expr.Var only' % type(x))
'primitive types or tvm.tir.Var only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
def template(func):
......
......@@ -26,17 +26,19 @@ import tvm.runtime
from tvm.runtime import Object, ndarray
from tvm.ir import container
from tvm.target import codegen
from tvm.tir import expr
from tvm.tir import ir_pass
from tvm.tir import Stmt
from tvm.tir.stmt import LoweredFunc
from . import target as _target
from . import api
from . import _api_internal
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import target as _target
from . import make
from .stmt import LoweredFunc
class DumpIR(object):
......@@ -61,7 +63,7 @@ class DumpIR(object):
def dump(*args, **kwargs):
"""dump function"""
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)):
if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
......
......@@ -15,9 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs
from .. import api as _api, intrin as _intrin
import tvm
from .. import api as _api
def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
......@@ -46,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
return _api.extern(
(n, m),
[lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
),
name="C",
......@@ -78,7 +77,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
return _api.extern(
(b, n, m),
[lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cblas.batch_matmul"
if not iterative
else "tvm.contrib.cblas.batch_matmul_iterative",
......
......@@ -15,10 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to cuBLAS libraries."""
from __future__ import absolute_import as _abs
import tvm
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS
......@@ -44,7 +42,7 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cublas.matmul",
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
......@@ -73,6 +71,6 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(b, n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cublas.batch_matmul",
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
......@@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to cuBLASlt libraries."""
from __future__ import absolute_import as _abs
import tvm
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS
......@@ -46,6 +45,6 @@ def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cublaslt.matmul",
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
......@@ -18,8 +18,8 @@
# pylint: disable-msg=C0103
import ctypes
import numpy as np
import tvm
from .. import api as _api
from .. import intrin as _intrin
from .. import get_global_func as _get_global_func
# algos can be read from cudnn.h
......@@ -365,7 +365,7 @@ def conv_forward(x,
if dims == 4:
return _api.extern(
oshape, [x, w],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cudnn.conv2d.forward",
conv_mode,
tensor_format,
......@@ -383,7 +383,7 @@ def conv_forward(x,
return _api.extern(
oshape, [x, w],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.cudnn.conv3d.forward",
conv_mode,
tensor_format,
......
......@@ -18,8 +18,8 @@
# pylint: disable-msg=C0103
import ctypes
import numpy as np
import tvm
from .. import api as _api
from .. import intrin as _intrin
from .. import get_global_func as _get_global_func
......@@ -113,7 +113,7 @@ def conv2d_forward(x,
return _api.extern(
list(oshape), [x, w],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.miopen.conv2d.forward",
conv_mode,
data_type,
......
......@@ -15,9 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to MPS libraries."""
from __future__ import absolute_import as _abs
import tvm
from .. import api as _api
from .. import intrin as _intrin
# pylint: disable=C0103,W0612
......@@ -50,7 +49,7 @@ def matmul(lhs, rhs, transa=False, transb=False):
n = c
return _api.extern(
(m, n), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
name="C")
......@@ -82,6 +81,6 @@ def conv2d(data, weight, pad='SAME', stride=1):
return _api.extern(
(n, ho, wo, co), [data, weight],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
name="C")
......@@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to NNPACK libraries."""
import tvm
import tvm._ffi
from .. import api as _api
from .. import intrin as _intrin
def is_available():
......@@ -46,7 +45,7 @@ def fully_connected_inference(lhs, rhs, nthreads=1):
m = rhs.shape[0]
return _api.extern(
(m, ), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.nnpack.fully_connected_inference",
ins[0], ins[1], outs[0], nthreads), name="C")
......@@ -110,7 +109,7 @@ def convolution_inference(
return _api.extern(
(batch, output_channels, output_height, output_width),
[data, kernel, bias] if bias is not None else [data, kernel],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.nnpack.convolution_inference",
ins[0],
ins[1],
......@@ -163,7 +162,7 @@ def convolution_inference_without_weight_transform(
return _api.extern(
(batch, output_channels, output_height, output_width),
[data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.nnpack.convolution_inference_without_weight_transform",
ins[0],
ins[1],
......@@ -198,7 +197,7 @@ def convolution_inference_weight_transform(
return _api.extern(
(output_channels, input_channels, transform_tile_size, transform_tile_size),
[kernel],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.nnpack.convolution_inference_weight_transform",
ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
......
......@@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to random library."""
import tvm
import tvm._ffi
from .. import api as _api
from .. import intrin as _intrin
def randint(low, high, size, dtype='int32'):
......@@ -39,7 +38,7 @@ def randint(low, high, size, dtype='int32'):
A tensor with specified size and dtype
"""
assert 'int' in dtype, "the type of randint output must be int or uint"
return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.random.randint", int(low), int(high), outs[0]), dtype=dtype)
......@@ -67,7 +66,7 @@ def uniform(low, high, size):
out : Tensor
A tensor with specified size and dtype.
"""
return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32')
......@@ -91,7 +90,7 @@ def normal(loc, scale, size):
out : Tensor
A tensor with specified size and dtype
"""
return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
......
......@@ -15,10 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to rocBLAS libraries."""
from __future__ import absolute_import as _abs
import tvm
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with rocBLAS
......@@ -43,6 +41,6 @@ def matmul(lhs, rhs, transa=False, transb=False):
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.rocblas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
......@@ -14,117 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Generic opertors in TVM.
We follow the numpy naming convention for this interface
(e.g., tvm.generic.multitply ~ numpy.multiply).
The default implementation is used by tvm.ExprOp.
"""
# pylint: disable=unused-argument
from . import make as _make
#Operator precedence used when overloading.
__op_priority__ = 0
def add(lhs, rhs):
"""Generic add operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of add operaton.
"""
return _make._OpAdd(lhs, rhs)
def subtract(lhs, rhs):
"""Generic subtract operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of subtract operaton.
"""
return _make._OpSub(lhs, rhs)
def multiply(lhs, rhs):
"""Generic multiply operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of multiply operaton.
"""
return _make._OpMul(lhs, rhs)
def divide(lhs, rhs):
"""Generic divide operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make._OpDiv(lhs, rhs)
def floordiv(lhs, rhs):
"""Generic floordiv operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make._OpFloorDiv(lhs, rhs)
def cast(src, dtype):
"""Generic cast operator.
Parameters
----------
src : object
The source operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make._cast(dtype, src)
"""Generic operators."""
# pylint:disable=unused-wildcard-import, wildcard-import
from .tir.generic import *
......@@ -17,16 +17,17 @@
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
from tvm.tir import ir_pass
from tvm.tir import call_pure_intrin
from tvm.tir.stmt import For
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from .. import target as _tgt
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
from ..intrin import call_pure_intrin
#pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin
LOOP_INTRIN = {
'range' : For.Serial,
......@@ -69,15 +70,15 @@ def bind(func_id, args):
def _math_intrin(func_id, args):
# pylint: disable=import-outside-toplevel
from .. import intrin
return getattr(intrin, func_id)(*args)
import tvm.tir.op
return getattr(tvm.tir.op, func_id)(*args)
sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
def _min_max(func_id, args):
_internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
return getattr(_make, func_id.title())(args[0], args[1])
return getattr(_expr, func_id.title())(args[0], args[1])
min = max = _min_max #pylint: disable=invalid-name
......@@ -127,7 +128,7 @@ def len(func_id, args):
def _cast(func_id, args):
_internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \
"Only one expression can be cast")
return _make.Cast(func_id, args[0])
return _expr.Cast(func_id, args[0])
float16 = float32 = float64 = _cast #pylint: disable=invalid-name
int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
......
......@@ -24,7 +24,11 @@ import types
import numbers
from enum import Enum
from tvm.ir.container import Array
from tvm.ir import Array, Range
import tvm.tir
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.tir import ir_pass as _ir_pass
from .util import _internal_assert
from . import calls
......@@ -35,12 +39,7 @@ from ..api import any as _any
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import make as _make
from .. import stmt as _stmt
from .. import api as _api
from .. import ir_pass as _ir_pass
def concat_list_to_block(lst):
......@@ -79,13 +78,13 @@ class Symbol(Enum):
def _floordiv(x, y):
if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
return _api.floordiv(x, y)
return tvm.tir.floordiv(x, y)
return operator.floordiv(x, y)
def _floormod(x, y):
if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
return _api.floormod(x, y)
return tvm.tir.floormod(x, y)
return operator.mod(x, y)
......@@ -208,11 +207,11 @@ class HybridParser(ast.NodeVisitor):
if _scope == 'global':
body = self.wrap_up_binds(body)
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
_domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_true = _api.convert(True)
body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
for elem in to_pop:
self.symbols.pop(elem)
......@@ -223,7 +222,7 @@ class HybridParser(ast.NodeVisitor):
def wrap_up_binds(self, body):
for _, iter_var in self.binds.items():
ext = iter_var.dom.extent
body = _make.AttrStmt(iter_var, 'thread_extent', ext, body)
body = tvm.tir.AttrStmt(iter_var, 'thread_extent', ext, body)
self.binds = {}
return body
......@@ -271,7 +270,7 @@ class HybridParser(ast.NodeVisitor):
return entry if isinstance(node.ctx, ast.Load) else None
if ty is Symbol.BufferVar:
if isinstance(node.ctx, ast.Load):
return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
return tvm.tir.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
_expr.Call.Halide, entry.op, entry.value_index)
return entry, [_api.const(0, 'int32')]
# Do I need any assertion here?
......@@ -304,10 +303,10 @@ class HybridParser(ast.NodeVisitor):
args = [_api.const(0, 'int32')]
_internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
value = HybridParser._binop_maker[type(node.op)](read, rhs)
return _make.Provide(buf.op, 0, value, args)
return tvm.tir.Provide(buf.op, 0, value, args)
def visit_Assign(self, node):
......@@ -358,13 +357,13 @@ class HybridParser(ast.NodeVisitor):
lhs = self.visit(lhs_)
if lhs is not None:
buf, args = lhs
return _make.Provide(buf.op, 0, rhs, args)
return tvm.tir.Provide(buf.op, 0, rhs, args)
return util.make_nop()
lhs, args = self.visit(lhs)
_internal_assert(isinstance(lhs, Tensor), \
"An array access's LHS is expected to be a expr.Call!")
res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args)
return res
......@@ -391,8 +390,8 @@ class HybridParser(ast.NodeVisitor):
arr = arr[i.value]
return arr
if isinstance(node.ctx, ast.Load):
return _make.Call(arr.dtype, arr.name, args,
_expr.Call.Halide, arr.op, arr.value_index)
return tvm.tir.Call(arr.dtype, arr.name, args,
_expr.Call.Halide, arr.op, arr.value_index)
return arr, args
def visit_With(self, node):
......@@ -426,14 +425,14 @@ class HybridParser(ast.NodeVisitor):
else_body = visit_list_to_block(self.visit, node.orelse)
else:
else_body = None
return _make.IfThenElse(cond, if_body, else_body)
return tvm.tir.IfThenElse(cond, if_body, else_body)
def visit_IfExp(self, node):
cond = self.visit(node.test)
if_body = self.visit(node.body)
else_body = self.visit(node.orelse)
return _make.Select(cond, if_body, else_body)
return tvm.tir.Select(cond, if_body, else_body)
def visit_Compare(self, node):
......@@ -543,7 +542,7 @@ class HybridParser(ast.NodeVisitor):
else:
_internal_assert(not isinstance(for_type, tuple), \
"Micro expansion should be handled before!")
res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
res = tvm.tir.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
self.symbols.pop(_name)
return res
......@@ -580,7 +579,7 @@ class HybridParser(ast.NodeVisitor):
def visit_Assert(self, node):
test = self.visit(node.test)
mesg = _api.convert(self.visit(node.msg))
return _make.AssertStmt(test, mesg, util.make_nop())
return tvm.tir.AssertStmt(test, mesg, util.make_nop())
def parse_python(src, args, symbols, closure_vars):
......
......@@ -22,12 +22,13 @@ import logging
import sys
import numpy
from tvm._ffi.base import numeric_types
from tvm.ir.container import Array
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from .._ffi.base import numeric_types
from ..tensor import Tensor
......@@ -46,7 +47,7 @@ def _internal_assert(cond, err):
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def make_nop():
"""Returns a 'no operation' node in HalideIR."""
return _make.Evaluate(_api.const(0, dtype='int32'))
return _stmt.Evaluate(_api.const(0, dtype='int32'))
def is_docstring(node):
......@@ -77,10 +78,10 @@ def replace_io(body, rmap):
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Provide(buf.op, op.value_index, op.value, op.args)
return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
if isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Call(buf.dtype, buf.name, op.args, \
return _expr.Call(buf.dtype, buf.name, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
return None
......
......@@ -24,7 +24,7 @@ from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs
from .attrs import Attrs, make_node
from .container import Array, Map
from . import transform
......@@ -18,6 +18,7 @@
import tvm._ffi
from tvm.runtime import Object
import tvm.runtime._ffi_node_api
from . import _ffi_api
......@@ -91,3 +92,40 @@ class Attrs(Object):
def __getitem__(self, item):
return self.__getattr__(item)
def make_node(type_key, **kwargs):
"""Make a new IR node by its type key and fields
Parameters
----------
type_key : str
The type key of the node.
**kwargs : dict
The fields of the node.
Returns
-------
node : Node
The corresponding IR Node
Note
----
If the created node is instance of AttrsNode, then
the creator function will also run bound checks and
default value setup as supported by Attrs.
Example
-------
The following code constructs a IntImm object
.. code-block:: python
x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.tir.IntImm)
assert x.value == 10
"""
args = [type_key]
for k, v in kwargs.items():
args += [k, v]
return tvm.runtime._ffi_node_api.MakeNode(*args)
......@@ -53,7 +53,7 @@ class Node(Object):
return _ffi_api.AsText(self, show_meta_data, annotate)
def __str__(self):
return self.astext(show_meta_data=False)
return _ffi_api.PrettyPrint(self)
@tvm._ffi.register_object("relay.SourceName")
......
......@@ -99,3 +99,23 @@ class Range(Node):
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
"""
@staticmethod
def make_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
Parameters
----------
min_value : PrimExpr
The minimum value of the range.
extent : PrimExpr
The extent of the range.
Returns
-------
rng : Range
The constructed range.
"""
return _ffi_api.range_by_min_extent(min_value, extent)
......@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""namespace of IR node builder make function
This namespace is used for developers. While you do not see any declarations.
......@@ -23,19 +24,22 @@ Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
"""
import tvm._ffi
import tvm.ir
from tvm.ir import make_node as node
from tvm.tir import Call
def range_by_min_extent(min_value, extent):
def make_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
Parameters
----------
min_value : Expr
min_value : PrimExpr
The minimum value of the range.
extent : Expr
extent : PrimExpr
The extent of the range.
Returns
......@@ -43,45 +47,6 @@ def range_by_min_extent(min_value, extent):
rng : Range
The constructed range.
"""
return _range_by_min_extent(min_value, extent)
def node(type_key, **kwargs):
"""Make a new DSL node by its type key and fields
Parameters
----------
type_key : str
The type key of the node.
**kwargs : dict
The fields of the node.
Returns
-------
node : Node
The corresponding DSL Node
Note
----
If the created node is instance of AttrsNode, then
the creator function will also run bound checks and
default value setup as supported by Attrs.
Example
-------
The following code constructs a IntImm object
.. code-block:: python
x = tvm.make.node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.expr.IntImm)
assert x.value == 10
"""
args = [type_key]
for k, v in kwargs.items():
args += [k, v]
return _Node(*args)
return tvm.ir.Range.make_by_min_extent(min_value, extent)
tvm._ffi._init_api("tvm.make")
......@@ -509,7 +509,7 @@ class ParseTreeToRelayIR(RelayVisitor):
_, type_params = zip(*type_params)
self.exit_var_scope()
attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None
attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
return expr.Function(var_list, body, ret_type, type_params, attrs)
@spanify
......
......@@ -181,11 +181,11 @@ class VMCompiler(object):
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
dev_type = tvm.tir.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
dev_type = tvm.tir.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
......
......@@ -932,7 +932,7 @@ def _shape():
def _impl(inputs, attr, params):
is_symbolic_shape = False
for axis in attr['_input_shapes'][inputs[0]]:
if not isinstance(axis, (int, tvm.expr.IntImm)):
if not isinstance(axis, (int, tvm.tir.IntImm)):
is_symbolic_shape = True
break
......
......@@ -557,7 +557,7 @@ def split_shape_func(attrs, inputs, _):
"""
Shape function for split op.
"""
if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)):
if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)):
indices_or_sections = get_const_int(attrs.indices_or_sections)
else:
indices_or_sections = get_const_tuple(attrs.indices_or_sections)
......
......@@ -14,126 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""The computation schedule api of TVM."""
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import expr as _expr, Buffer
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
@tvm._ffi.register_object
class Buffer(Object):
"""Symbolic data buffer in TVM.
Buffer provide a way to represent data layout
specialization of data structure in TVM.
Do not construct directly, use :any:`decl_buffer` instead.
See the documentation of :any:`decl_buffer` for more details.
See Also
--------
decl_buffer : Declare a buffer
"""
READ = 1
WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
ptress when interacting with external functions.
Parameters
----------
access_mask : int
The access pattern MASK. Indicate whether the
access will read or write to the data content.
ptr_type : str, optional
The data type of the result pointer. Do not specify
unless we want to cast pointer to specific type.
content_lanes: int, optional
The number of lanes for the data type. This value
is greater than one for vector types.
offset: Expr, optional
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
Examples
--------
.. code-block:: python
import tvm.schedule.Buffer
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
"""
if isinstance(access_mask, string_types):
mask = 0
for value in access_mask:
if value == "r":
mask = mask | Buffer.READ
elif value == "w":
mask = mask | Buffer.WRITE
else:
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
offset = convert(offset)
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
content_lanes, offset)
def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index.
Parameters
----------
begin : Array of Expr
The beginning index in unit of Buffer.dtype
dtype : str
The data type to be loaded,
can be vector type which have lanes that is multiple of Buffer.dtype
Returns
-------
load : Expr
The corresponding load expression.
"""
begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
dtype = dtype if dtype else self.dtype
return _api_internal._BufferVLoad(self, begin, dtype)
def vstore(self, begin, value):
"""Generate a Stmt that store value into begin index.
Parameters
----------
begin : Array of Expr
The beginning index in unit of Buffer.dtype
value : Expr
The value to be stored.
Returns
-------
store : Stmt
The corresponding store stmt.
"""
begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
return _api_internal._BufferVStore(self, begin, value)
@tvm._ffi.register_object
......
......@@ -60,3 +60,4 @@ from .generic_func import GenericFunc
from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
from . import datatype
from . import codegen
from .intrin import register_intrin_rule
......@@ -19,7 +19,7 @@ import tvm._ffi
import tvm.runtime._ffi_api
from tvm.runtime import convert, DataType
from tvm.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
def register(type_name, type_code):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Target dependent intrinsic registration."""
import tvm._ffi
from tvm.tir import call_pure_extern
# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
"""Register an intrinsic function generation rule.
Intrinsic generation rules are callback functions for
code generator to get device specific calls.
This function simply translates to.
:code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`
TVM may already pre-register intrinsic rules in the backend.
However, user can use this function to change the intrinsic translation
behavior or add new intrinsic rules during runtime.
Parameters
----------
target : str
The name of codegen target.
intrin : str
The name of the intrinsic.
f : function, optional
The function to be registered.
override: boolean optional
Whether override existing entry.
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
The following code registers exp expansion rule for opencl.
.. code-block:: python
register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
"""
return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
def _rule_float_suffix(op):
"""Intrinsic rule: Add float suffix if it is float32.
This is an example intrinsic generation rule.
Parameters
----------
op : PrimExpr
The call expression of original intrinsic.
Returns
-------
ret : PrimExpr
The translated intrinsic rule.
Return same op if no translation is possible.
See Also
--------
register_intrin_rule : The registeration function for intrin rule.
"""
if op.dtype == "float32":
return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
if op.dtype == "float64":
return call_pure_extern(op.dtype, op.name, *op.args)
return op
def _rule_float_direct(op):
"""Intrinsic rule: Directly call pure extern function for floats.
This is an example intrinsic generation rule.
Parameters
----------
op : PrimExpr
The call expression of original intrinsic.
Returns
-------
ret : PrimExpr
The translated intrinsic rule.
Return same op if no translation is possible.
See Also
--------
register_intrin_rule : The registeration function for intrin rule.
"""
if str(op.dtype).startswith("float"):
return call_pure_extern(op.dtype, op.name, *op.args)
return None
# opencl pattern for exp
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
......@@ -19,10 +19,9 @@
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr
from . import _api_internal
from . import make as _make
from . import expr as _expr
class TensorSlice(ObjectGeneric, _expr.ExprOp):
......@@ -74,7 +73,7 @@ class Tensor(Object, _expr.ExprOp):
else:
raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.op.name,
return _expr.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
......@@ -207,136 +206,3 @@ class HybridOp(Operation):
def axis(self):
"""Represent the IterVar axis, also defined when it is a HybridOp"""
return self.__getattr__("axis")
@tvm._ffi.register_object
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
Do not construct directly, use :any:`layout` instead.
See the documentation of :any:`layout` for more details.
See Also
--------
layout : Declare a layout
"""
def __len__(self):
return _api_internal._LayoutNdim(self)
def __contains__(self, axis):
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Layout index out of range")
return _api_internal._LayoutGetItem(self, index)
def index_of(self, axis):
"""Get the index of an axis
Parameters
----------
axis : str
The axis name, need to be [a-z,A-Z]
Returns
-------
index : int
The index of the axis, -1 if not found.
"""
return _api_internal._LayoutIndexOf(self, axis)
def factor_of(self, axis):
"""Get the factor size of the subordinate axis.
Parameters
----------
axis : str
The axis name, need to be [a-z,A-Z]
Returns
-------
factor : int
the size of the subordinate-axis of axis (if axis is a primal-axis),
or the size of axis itself (if axis is a subordinate-axis).
Return -1 if axis is not in the layout.
"""
return _api_internal._LayoutFactorOf(self, axis)
@tvm._ffi.register_object
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
Do not construct directly, use :any:`bijective_layout` instead.
See the documentation of :any:`bijective_layout` for more details.
See Also
--------
bijective_layout : Declare a bijective layout converter
"""
def forward_index(self, index):
"""Given the indices of the src-layout, infer the dst index.
Parameters
----------
index: Array of Expr
The indices in src-layout.
Returns
-------
dst_index: Array of Expr
The inferred indices in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardIndex(self, index)
def backward_index(self, index):
"""Given the indices of the dst-layout, infer the src index.
Parameters
----------
index: Array of Expr
The indices in dst-layout.
Returns
-------
src_index: Array of Expr
The inferred indices in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardIndex(self, index)
def forward_shape(self, shape):
"""Given the shape of the src-layout, infer the dst shape.
Parameters
----------
shape: Array of Expr
The shape in src-layout.
Returns
-------
dst_shape: Array of Expr
The inferred shape in dst-layout.
"""
return _api_internal._BijectiveLayoutForwardShape(self, shape)
def backward_shape(self, shape):
"""Given the shape of the dst-layout, infer the src shape.
Parameters
----------
shape: Array of Expr
The shape in dst-layout.
Returns
-------
src_shape: Array of Expr
The inferred shape in src-layout.
"""
return _api_internal._BijectiveLayoutBackwardShape(self, shape)
......@@ -18,11 +18,12 @@
import tvm._ffi
from tvm.runtime import Object
from tvm.ir import Range
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from . import _api_internal
from . import api as _api
from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
......@@ -39,7 +40,7 @@ def _get_region(tslice):
begin = idx.var
else:
begin = idx
region.append(_make.range_by_min_extent(begin, 1))
region.append(Range.make_by_min_extent(begin, 1))
return region
@tvm._ffi.register_object
......@@ -136,7 +137,7 @@ def decl_tensor_intrin(op,
scalar_params = []
if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)):
body = [body]
body = [_make.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
body = [_stmt.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import, redefined-builtin
"""Namespace for Tensor-level IR"""
from tvm.ir import PrimExpr
from .buffer import Buffer, decl_buffer
from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, min_value, max_value
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from . import ir_builder
from . import ir_pass
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir"""
import tvm._ffi
tvm._ffi._init_api("tir", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Abstraction for array data structures."""
from numbers import Integral
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import PrimExpr
from . import _ffi_api
@tvm._ffi.register_object
class Buffer(Object):
"""Symbolic data buffer in TVM.
Buffer provide a way to represent data layout
specialization of data structure in TVM.
Do not construct directly, use :py:func:`~decl_buffer` instead.
See the documentation of :py:func:`decl_buffer` for more details.
See Also
--------
decl_buffer : Declare a buffer
"""
READ = 1
WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
ptress when interacting with external functions.
Parameters
----------
access_mask : int
The access pattern MASK. Indicate whether the
access will read or write to the data content.
ptr_type : str, optional
The data type of the result pointer. Do not specify
unless we want to cast pointer to specific type.
content_lanes: int, optional
The number of lanes for the data type. This value
is greater than one for vector types.
offset: Expr, optional
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
Examples
--------
.. code-block:: python
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
"""
if isinstance(access_mask, string_types):
mask = 0
for value in access_mask:
if value == "r":
mask = mask | Buffer.READ
elif value == "w":
mask = mask | Buffer.WRITE
else:
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
offset = convert(offset)
return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type,
content_lanes, offset)
def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index.
Parameters
----------
begin : Array of Expr
The beginning index in unit of Buffer.dtype
dtype : str
The data type to be loaded,
can be vector type which have lanes that is multiple of Buffer.dtype
Returns
-------
load : Expr
The corresponding load expression.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
dtype = dtype if dtype else self.dtype
return _ffi_api.BufferVLoad(self, begin, dtype)
def vstore(self, begin, value):
"""Generate a Stmt that store value into begin index.
Parameters
----------
begin : Array of Expr
The beginning index in unit of Buffer.dtype
value : Expr
The value to be stored.
Returns
-------
store : Stmt
The corresponding store stmt.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
return _ffi_api.BufferVStore(self, begin, value)
def decl_buffer(shape,
dtype=None,
name="buffer",
data=None,
strides=None,
elem_offset=None,
scope="",
data_alignment=-1,
offset_factor=0,
buffer_type=""):
"""Declare a new symbolic buffer.
Normally buffer is created automatically during lower and build.
This is only needed if user want to specify their own buffer layout.
See the note below for detailed discussion on usage of buffer.
Parameters
----------
shape : tuple of Expr
The shape of the buffer.
dtype : str, optional
The data type of the buffer.
name : str, optional
The name of the buffer.
data : Var, optional
The data pointer in the buffer.
strides: array of Expr
The stride of the buffer.
elem_offset: Expr, optional
The beginning offset of the array to data.
In terms of number of elements of dtype.
scope: str, optional
The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory.
data_alignment: int, optional
The alignment of data pointer in bytes.
If -1 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional
The factor of elem_offset field, when set,
elem_offset is required to be multiple of offset_factor.
If 0 is pssed, the alignment will be set to 1.
if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
buffer_type: str, optional, {"", "auto_broadcast"}
auto_broadcast buffer allows one to implement broadcast computation
without considering whether dimension size equals to one.
TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
Returns
-------
buffer : Buffer
The created buffer
Example
-------
Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation,
.. code-block:: python
m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
A = tvm.placeholder((m0, m1, m2), name='A')
B = tvm.placeholder((n0, n1, n2), name='B')
C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = tvm.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
Note
----
Buffer data structure reflects the DLTensor structure in dlpack.
While DLTensor data structure is very general, it is usually helpful
to create function that only handles specific case of data structure
and make compiled function benefit from it.
If user pass strides and elem_offset is passed as None
when constructing the function, then the function will be specialized
for the DLTensor that is compact and aligned.
If user pass a fully generic symbolic array to the strides,
then the resulting function becomes fully generic.
"""
# pylint: disable=import-outside-toplevel
from .expr import Var
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
dtype = "float32" if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = Var('%s_elem_offset' % name, shape_dtype)
if data is None:
data = Var(name, "handle")
return _ffi_api.Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor, buffer_type)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Data layout."""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
@tvm._ffi.register_object
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
See Also
--------
layout : Declare a layout
"""
def __len__(self):
return _ffi_api.LayoutNdim(self)
def __contains__(self, axis):
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Layout index out of range")
return _ffi_api.LayoutGetItem(self, index)
def index_of(self, axis):
"""Get the index of an axis
Parameters
----------
axis : str
The axis name, need to be [a-z,A-Z]
Returns
-------
index : int
The index of the axis, -1 if not found.
"""
return _ffi_api.LayoutIndexOf(self, axis)
def factor_of(self, axis):
"""Get the factor size of the subordinate axis.
Parameters
----------
axis : str
The axis name, need to be [a-z,A-Z]
Returns
-------
factor : int
the size of the subordinate-axis of axis (if axis is a primal-axis),
or the size of axis itself (if axis is a subordinate-axis).
Return -1 if axis is not in the layout.
"""
return _ffi_api.LayoutFactorOf(self, axis)
@tvm._ffi.register_object
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
Do not construct directly, use :any:`bijective_layout` instead.
See the documentation of :any:`bijective_layout` for more details.
Parameters
----------
src_layout : str or Layout
source layout.
dst_layout : str or Layout
destination layout.
See Also
--------
bijective_layout : Declare a layout
"""
def forward_index(self, index):
"""Given the indices of the src-layout, infer the dst index.
Parameters
----------
index: Array of Expr
The indices in src-layout.
Returns
-------
dst_index: Array of Expr
The inferred indices in dst-layout.
"""
return _ffi_api.BijectiveLayoutForwardIndex(self, index)
def backward_index(self, index):
"""Given the indices of the dst-layout, infer the src index.
Parameters
----------
index: Array of Expr
The indices in dst-layout.
Returns
-------
src_index: Array of Expr
The inferred indices in src-layout.
"""
return _ffi_api.BijectiveLayoutBackwardIndex(self, index)
def forward_shape(self, shape):
"""Given the shape of the src-layout, infer the dst shape.
Parameters
----------
shape: Array of Expr
The shape in src-layout.
Returns
-------
dst_shape: Array of Expr
The inferred shape in dst-layout.
"""
return _ffi_api.BijectiveLayoutForwardShape(self, shape)
def backward_shape(self, shape):
"""Given the shape of the dst-layout, infer the src shape.
Parameters
----------
shape: Array of Expr
The shape in dst-layout.
Returns
-------
src_shape: Array of Expr
The inferred shape in src-layout.
"""
return _ffi_api.BijectiveLayoutBackwardShape(self, shape)
def layout(layout_str):
"""Create a layout node from a string.
Parameters
----------
layout_str : str
A layout representation is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).
Returns
-------
layout : Layout
The created layout
"""
return _ffi_api.Layout(layout_str)
def bijective_layout(src_layout, dst_layout):
"""Create a bijective layout mapping.
Parameters
----------
src_layout : str or Layout
source layout.
dst_layout : str or Layout
destination layout.
Returns
-------
bijective_layout : BijectiveLayout
The created bijective layout
"""
if isinstance(src_layout, str):
src_layout = layout(src_layout)
if isinstance(dst_layout, str):
dst_layout = layout(dst_layout)
return _ffi_api.BijectiveLayout(src_layout, dst_layout)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Generic opertors in TVM.
We follow the numpy naming convention for this interface
(e.g., tvm.generic.multitply ~ numpy.multiply).
The default implementation is used by tvm.ExprOp.
"""
# pylint: disable=unused-argument
from . import _ffi_api
# Operator precedence used when overloading.
__op_priority__ = 0
def add(lhs, rhs):
"""Generic add operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of add operaton.
"""
return _ffi_api._OpAdd(lhs, rhs)
def subtract(lhs, rhs):
"""Generic subtract operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of subtract operaton.
"""
return _ffi_api._OpSub(lhs, rhs)
def multiply(lhs, rhs):
"""Generic multiply operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of multiply operaton.
"""
return _ffi_api._OpMul(lhs, rhs)
def divide(lhs, rhs):
"""Generic divide operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _ffi_api._OpDiv(lhs, rhs)
def floordiv(lhs, rhs):
"""Generic floordiv operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _ffi_api._OpFloorDiv(lhs, rhs)
def cast(src, dtype):
"""Generic cast operator.
Parameters
----------
src : object
The source operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _ffi_api._cast(dtype, src)
......@@ -16,15 +16,13 @@
# under the License.
"""Developer API of IR node builder make function."""
from tvm._ffi.base import string_types
from tvm.runtime import ObjectGeneric, DataType
from tvm.runtime import ObjectGeneric, DataType, convert, const
from tvm.ir import container as _container
from . import api as _api
from . import stmt as _stmt
from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from .expr import Call as _Call
class WithScope(object):
"""Auxiliary scope with"""
......@@ -53,7 +51,7 @@ class BufferVar(ObjectGeneric):
.. code-block:: python
# The following code generate IR for x[0] = x[
ib = tvm.ir_builder.create()
ib = tvm.tir.ir_builder.create()
x = ib.pointer("float32")
x[0] = x[10] + 1
......@@ -78,19 +76,19 @@ class BufferVar(ObjectGeneric):
def __getitem__(self, index):
t = DataType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
return _make.Load(self._content_type, self._buffer_var, index)
index = _expr.Ramp(index * t.lanes, 1, t.lanes)
return _expr.Load(self._content_type, self._buffer_var, index)
def __setitem__(self, index, value):
value = _api.convert(value)
value = convert(value)
if value.dtype != self._content_type:
raise ValueError(
"data type does not match content type %s vs %s" % (
value.dtype, self._content_type))
t = DataType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_make.Store(self._buffer_var, value, index))
index = _expr.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_stmt.Store(self._buffer_var, value, index))
class IRBuilder(object):
......@@ -117,7 +115,7 @@ class IRBuilder(object):
"""Pop sequence from stack"""
seq = self._seq_stack.pop()
if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0))
seq.append(_stmt.Evaluate(0))
seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
ret_seq = [seq[-1]]
......@@ -138,7 +136,7 @@ class IRBuilder(object):
The statement to be emitted or callable that build stmt given body.
"""
if isinstance(stmt, _expr.Call):
stmt = _make.Evaluate(stmt)
stmt = _stmt.Evaluate(stmt)
assert isinstance(stmt, _stmt.Stmt) or callable(stmt)
self._seq_stack[-1].append(stmt)
......@@ -167,10 +165,10 @@ class IRBuilder(object):
x[i] = x[i - 1] + 1
"""
if isinstance(node, string_types):
node = _make.StringImm(node)
node = _expr.StringImm(node)
if isinstance(value, string_types):
value = _make.StringImm(value)
self.emit(lambda x: _make.AttrStmt(node, attr_key, value, x))
value = _expr.StringImm(value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
"""Create a for iteration scope.
......@@ -211,7 +209,7 @@ class IRBuilder(object):
name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3)
self.nidx += 1
self._seq_stack.append([])
loop_var = _api.var(name, dtype=dtype)
loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else _pass.Simplify(end - begin)
def _exit_cb():
if for_type == "serial":
......@@ -224,7 +222,7 @@ class IRBuilder(object):
for_type_id = 3
else:
raise ValueError("Unknown for_type")
self.emit(_make.For(
self.emit(_stmt.For(
loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
return WithScope(loop_var, _exit_cb)
......@@ -253,7 +251,7 @@ class IRBuilder(object):
"""
self._seq_stack.append([])
def _exit_cb():
self.emit(_make.IfThenElse(cond, self._pop_seq(), None))
self.emit(_stmt.IfThenElse(cond, self._pop_seq(), None))
return WithScope(None, _exit_cb)
def else_scope(self):
......@@ -286,7 +284,7 @@ class IRBuilder(object):
self._seq_stack[-1].pop()
self._seq_stack.append([])
def _exit_cb():
self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
self.emit(_stmt.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
return WithScope(None, _exit_cb)
def new_scope(self):
......@@ -326,13 +324,13 @@ class IRBuilder(object):
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _api.var(name, dtype="handle")
buffer_var = _expr.Var(name, dtype="handle")
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
if scope:
self.scope_attr(buffer_var, "storage_scope", scope)
self.emit(lambda x: _make.Allocate(
buffer_var, dtype, shape, _api.const(1, dtype="uint1"), x))
self.emit(lambda x: _stmt.Allocate(
buffer_var, dtype, shape, const(1, dtype="uint1"), x))
return BufferVar(self, buffer_var, dtype)
def pointer(self, content_type, name="ptr"):
......@@ -351,7 +349,7 @@ class IRBuilder(object):
ptr : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _api.var(name, dtype="handle")
buffer_var = _expr.Var(name, dtype="handle")
return BufferVar(self, buffer_var, content_type)
def buffer_ptr(self, buf):
......@@ -380,7 +378,8 @@ class IRBuilder(object):
expr : Expr
The expression will likely tag.
"""
return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0)
return _expr.Call(expr.dtype, "likely", [expr],
_expr.Call.PureIntrinsic, None, 0)
def get(self):
"""Return the builded IR.
......
......@@ -25,4 +25,4 @@ You can read "include/tvm/tir/ir_pass.h" for the function signature and
"""
import tvm._ffi
tvm._ffi._init_api("tvm.ir_pass")
tvm._ffi._init_api("tvm.ir_pass", __name__)
......@@ -25,18 +25,19 @@ Each statement node have subfields that can be visited from python side.
x = tvm.var("n")
a = tvm.var("array", tvm.handle)
st = tvm.make.Store(a, x + 1, 1)
assert isinstance(st, tvm.stmt.Store)
st = tvm.tir.stmt.Store(a, x + 1, 1)
assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a)
"""
import tvm._ffi
from tvm.runtime import Object
from . import make as _make
from . import _ffi_api
class Stmt(Object):
pass
"""Base class of all the statements."""
@tvm._ffi.register_object
class LetStmt(Stmt):
......@@ -47,7 +48,7 @@ class LetStmt(Stmt):
var : Var
The variable in the binding.
value : Expr
value : PrimExpr
The value in to be binded.
body : Stmt
......@@ -55,7 +56,7 @@ class LetStmt(Stmt):
"""
def __init__(self, var, value, body):
self.__init_handle_by_constructor__(
_make.LetStmt, var, value, body)
_ffi_api.LetStmt, var, value, body)
@tvm._ffi.register_object
......@@ -64,10 +65,10 @@ class AssertStmt(Stmt):
Parameters
----------
condition : Expr
condition : PrimExpr
The assert condition.
message : Expr
message : PrimExpr
The error message.
body : Stmt
......@@ -75,7 +76,7 @@ class AssertStmt(Stmt):
"""
def __init__(self, condition, message, body):
self.__init_handle_by_constructor__(
_make.AssertStmt, condition, message, body)
_ffi_api.AssertStmt, condition, message, body)
@tvm._ffi.register_object
......@@ -95,7 +96,7 @@ class ProducerConsumer(Stmt):
"""
def __init__(self, func, is_producer, body):
self.__init_handle_by_constructor__(
_make.ProducerConsumer, func, is_producer, body)
_ffi_api.ProducerConsumer, func, is_producer, body)
@tvm._ffi.register_object
......@@ -107,10 +108,10 @@ class For(Stmt):
loop_var : Var
The loop variable.
min_val : Expr
min_val : PrimExpr
The begining value.
extent : Expr
extent : PrimExpr
The length of the loop.
for_type : int
......@@ -134,7 +135,7 @@ class For(Stmt):
device_api,
body):
self.__init_handle_by_constructor__(
_make.For, loop_var, min_val, extent,
_ffi_api.For, loop_var, min_val, extent,
for_type, device_api, body)
......@@ -147,18 +148,19 @@ class Store(Stmt):
buffer_var : Var
The buffer Variable.
value : Expr
value : PrimExpr
The value we want to store.
index : Expr
index : PrimExpr
The index in the store expression.
predicate : Expr
predicate : PrimExpr
The store predicate.
"""
def __init__(self, buffer_var, value, index, predicate):
def __init__(self, buffer_var, value, index, predicate=None):
args = [] if predicate is None else [predicate]
self.__init_handle_by_constructor__(
_make.Store, buffer_var, value, index, predicate)
_ffi_api.Store, buffer_var, value, index, *args)
@tvm._ffi.register_object
......@@ -173,7 +175,7 @@ class Provide(Stmt):
value_index : int
The output value index
value : Expr
value : PrimExpr
The value to be stored.
args : list of Expr
......@@ -181,7 +183,7 @@ class Provide(Stmt):
"""
def __init__(self, func, value_index, value, args):
self.__init_handle_by_constructor__(
_make.Provide, func, value_index, value, args)
_ffi_api.Provide, func, value_index, value, args)
@tvm._ffi.register_object
......@@ -199,7 +201,7 @@ class Allocate(Stmt):
extents : list of Expr
The extents of the allocate
condition : Expr
condition : PrimExpr
The condition.
body : Stmt
......@@ -212,7 +214,7 @@ class Allocate(Stmt):
condition,
body):
self.__init_handle_by_constructor__(
_make.Allocate, buffer_var, dtype,
_ffi_api.Allocate, buffer_var, dtype,
extents, condition, body)
......@@ -228,7 +230,7 @@ class AttrStmt(Stmt):
attr_key : str
Attribute type key.
value : Expr
value : PrimExpr
The value of the attribute
body : Stmt
......@@ -236,7 +238,7 @@ class AttrStmt(Stmt):
"""
def __init__(self, node, attr_key, value, body):
self.__init_handle_by_constructor__(
_make.AttrStmt, node, attr_key, value, body)
_ffi_api.AttrStmt, node, attr_key, value, body)
@tvm._ffi.register_object
......@@ -250,7 +252,7 @@ class Free(Stmt):
"""
def __init__(self, buffer_var):
self.__init_handle_by_constructor__(
_make.Free, buffer_var)
_ffi_api.Free, buffer_var)
@tvm._ffi.register_object
......@@ -271,7 +273,7 @@ class Realize(Stmt):
bounds : list of range
The bound of realize
condition : Expr
condition : PrimExpr
The realize condition.
body : Stmt
......@@ -285,7 +287,7 @@ class Realize(Stmt):
condition,
body):
self.__init_handle_by_constructor__(
_make.Realize, func, value_index, dtype,
_ffi_api.Realize, func, value_index, dtype,
bounds, condition, body)
......@@ -300,7 +302,7 @@ class SeqStmt(Stmt):
"""
def __init__(self, seq):
self.__init_handle_by_constructor__(
_make.SeqStmt, seq)
_ffi_api.SeqStmt, seq)
def __getitem__(self, i):
return self.seq[i]
......@@ -315,7 +317,7 @@ class IfThenElse(Stmt):
Parameters
----------
condition : Expr
condition : PrimExpr
The expression
then_case : Stmt
......@@ -326,7 +328,7 @@ class IfThenElse(Stmt):
"""
def __init__(self, condition, then_case, else_case):
self.__init_handle_by_constructor__(
_make.IfThenElse, condition, then_case, else_case)
_ffi_api.IfThenElse, condition, then_case, else_case)
@tvm._ffi.register_object
......@@ -335,12 +337,12 @@ class Evaluate(Stmt):
Parameters
----------
value : Expr
value : PrimExpr
The expression to be evalued.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.Evaluate, value)
_ffi_api.Evaluate, value)
@tvm._ffi.register_object
......@@ -363,7 +365,7 @@ class Prefetch(Stmt):
"""
def __init__(self, func, value_index, dtype, bounds):
self.__init_handle_by_constructor__(
_make.Prefetch, func, value_index, dtype, bounds)
_ffi_api.Prefetch, func, value_index, dtype, bounds)
@tvm._ffi.register_object
......@@ -417,6 +419,3 @@ def stmt_list(stmt):
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]
_make.stmt_list = stmt_list
......@@ -30,50 +30,50 @@
namespace tvm {
namespace tir {
TVM_REGISTER_GLOBAL("_Var")
TVM_REGISTER_GLOBAL("tir.Var")
.set_body_typed([](std::string s, DataType t) {
return Var(s, t);
});
TVM_REGISTER_GLOBAL("_SizeVar")
TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
});
TVM_REGISTER_GLOBAL("make.abs")
TVM_REGISTER_GLOBAL("tir.abs")
.set_body_typed(tvm::abs);
TVM_REGISTER_GLOBAL("make.isnan")
TVM_REGISTER_GLOBAL("tir.isnan")
.set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("make.floor")
TVM_REGISTER_GLOBAL("tir.floor")
.set_body_typed(tvm::floor);
TVM_REGISTER_GLOBAL("make.ceil")
TVM_REGISTER_GLOBAL("tir.ceil")
.set_body_typed(tvm::ceil);
TVM_REGISTER_GLOBAL("make.round")
TVM_REGISTER_GLOBAL("tir.round")
.set_body_typed(tvm::round);
TVM_REGISTER_GLOBAL("make.nearbyint")
TVM_REGISTER_GLOBAL("tir.nearbyint")
.set_body_typed(tvm::nearbyint);
TVM_REGISTER_GLOBAL("make.trunc")
TVM_REGISTER_GLOBAL("tir.trunc")
.set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("make._cast")
TVM_REGISTER_GLOBAL("tir._cast")
.set_body_typed(tvm::cast);
TVM_REGISTER_GLOBAL("make._range_by_min_extent")
TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_GLOBAL("make.SeqStmt")
TVM_REGISTER_GLOBAL("tir.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
return SeqStmt(std::move(seq));
});
TVM_REGISTER_GLOBAL("make.For")
TVM_REGISTER_GLOBAL("tir.For")
.set_body_typed([](
Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
......@@ -85,7 +85,7 @@ TVM_REGISTER_GLOBAL("make.For")
body);
});
TVM_REGISTER_GLOBAL("make.Load")
TVM_REGISTER_GLOBAL("tir.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DataType t = args[0];
if (args.size() == 3) {
......@@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("make.Load")
}
});
TVM_REGISTER_GLOBAL("make.Store")
TVM_REGISTER_GLOBAL("tir.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PrimExpr value = args[1];
if (args.size() == 3) {
......@@ -105,10 +105,10 @@ TVM_REGISTER_GLOBAL("make.Store")
}
});
TVM_REGISTER_GLOBAL("make.Realize")
TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make);
TVM_REGISTER_GLOBAL("make.Call")
TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](
DataType type, std::string name,
Array<PrimExpr> args, int call_type,
......@@ -122,12 +122,12 @@ TVM_REGISTER_GLOBAL("make.Call")
value_index);
});
TVM_REGISTER_GLOBAL("make.CommReducer")
TVM_REGISTER_GLOBAL("tir.CommReducer")
.set_body_typed(CommReducerNode::make);
// make from two arguments
#define REGISTER_MAKE(NodeName) \
TVM_REGISTER_GLOBAL("make."#NodeName) \
TVM_REGISTER_GLOBAL("tir."#NodeName) \
.set_body_typed(NodeName ## Node::make); \
......@@ -172,7 +172,7 @@ REGISTER_MAKE(Evaluate);
// overloaded, needs special handling
// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([](
Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
){
......@@ -180,14 +180,14 @@ TVM_REGISTER_GLOBAL("make.Allocate")
});
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body_typed([](PrimExpr a, PrimExpr b) { \
return (Func(a, b)); \
})
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
......@@ -228,7 +228,7 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("make._OpIfThenElse")
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
return if_then_else(cond, true_value, false_value);
});
......
......@@ -34,10 +34,10 @@
namespace tvm {
TVM_REGISTER_GLOBAL("_min_value")
TVM_REGISTER_GLOBAL("tir.min_value")
.set_body_typed(min_value);
TVM_REGISTER_GLOBAL("_max_value")
TVM_REGISTER_GLOBAL("tir.max_value")
.set_body_typed(max_value);
TVM_REGISTER_GLOBAL("Range")
......@@ -49,66 +49,6 @@ TVM_REGISTER_GLOBAL("Range")
}
});
namespace tir {
TVM_REGISTER_GLOBAL("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 10);
auto buffer_type = args[9].operator std::string();
BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
*ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
args[5], args[6], args[7], args[8], type);
});
TVM_REGISTER_GLOBAL("_BufferAccessPtr")
.set_body_method(&Buffer::access_ptr);
TVM_REGISTER_GLOBAL("_BufferVLoad")
.set_body_method(&Buffer::vload);
TVM_REGISTER_GLOBAL("_BufferVStore")
.set_body_method(&Buffer::vstore);
TVM_REGISTER_GLOBAL("_Layout")
.set_body_typed(LayoutNode::make);
TVM_REGISTER_GLOBAL("_LayoutIndexOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("_LayoutFactorOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("_LayoutNdim")
.set_body_typed([](Layout layout) -> int {
return layout.ndim();
});
TVM_REGISTER_GLOBAL("_LayoutGetItem")
.set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
TVM_REGISTER_GLOBAL("_BijectiveLayout")
.set_body_typed(BijectiveLayoutNode::make);
TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardIndex")
.set_body_method(&BijectiveLayout::ForwardIndex);
TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardIndex")
.set_body_method(&BijectiveLayout::BackwardIndex);
TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape")
.set_body_method(&BijectiveLayout::ForwardShape);
TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir
namespace te {
TVM_REGISTER_GLOBAL("_Tensor")
.set_body_typed(TensorNode::make);
......
......@@ -71,7 +71,7 @@ IntImm::IntImm(DataType dtype, int64_t value) {
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("make.IntImm")
TVM_REGISTER_GLOBAL("ir.IntImm")
.set_body_typed([](DataType dtype, int64_t value) {
return IntImm(dtype, value);
});
......@@ -97,7 +97,7 @@ FloatImm::FloatImm(DataType dtype, double value) {
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("make.FloatImm")
TVM_REGISTER_GLOBAL("ir.FloatImm")
.set_body_typed([](DataType dtype, double value) {
return FloatImm(dtype, value);
});
......
......@@ -304,6 +304,6 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr")
TVM_REGISTER_GLOBAL("node.NodeListAttrNames")
.set_body(NodeListAttrNames);
TVM_REGISTER_GLOBAL("make._Node")
TVM_REGISTER_GLOBAL("node.MakeNode")
.set_body(MakeNode);
} // namespace tvm
......@@ -906,7 +906,9 @@ static const char* kSemVer = "v0.0.4";
// - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR)
std::string PrettyPrint(const ObjectRef& node) {
return AsText(node, false, nullptr);
Doc doc;
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
}
std::string AsText(const ObjectRef& node,
......@@ -918,6 +920,10 @@ std::string AsText(const ObjectRef& node,
return doc.str();
}
TVM_REGISTER_GLOBAL("ir.PrettyPrint")
.set_body_typed(PrettyPrint);
TVM_REGISTER_GLOBAL("ir.AsText")
.set_body_typed(AsText);
} // namespace tvm
......@@ -20,6 +20,7 @@
/*!
* \file buffer.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/expr.h>
......@@ -460,5 +461,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_REGISTER_NODE_TYPE(BufferNode);
TVM_REGISTER_GLOBAL("tir.Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 10);
auto buffer_type = args[9].operator std::string();
BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
*ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
args[5], args[6], args[7], args[8], type);
});
TVM_REGISTER_GLOBAL("tir.BufferAccessPtr")
.set_body_method(&Buffer::access_ptr);
TVM_REGISTER_GLOBAL("tir.BufferVLoad")
.set_body_method(&Buffer::vload);
TVM_REGISTER_GLOBAL("tir.BufferVStore")
.set_body_method(&Buffer::vstore);
} // namespace tir
} // namespace tvm
......@@ -21,6 +21,7 @@
* \file src/lang/data_layout.cc
* \brief Data Layout expression.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/ir_pass.h>
#include <cctype>
......@@ -371,5 +372,44 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")";
});
TVM_REGISTER_GLOBAL("tir.Layout")
.set_body_typed(LayoutNode::make);
TVM_REGISTER_GLOBAL("tir.LayoutIndexOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("tir.LayoutNdim")
.set_body_typed([](Layout layout) -> int {
return layout.ndim();
});
TVM_REGISTER_GLOBAL("tir.LayoutGetItem")
.set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
.set_body_typed(BijectiveLayoutNode::make);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
.set_body_method(&BijectiveLayout::ForwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
.set_body_method(&BijectiveLayout::BackwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
.set_body_method(&BijectiveLayout::ForwardShape);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir
} // namespace tvm
......@@ -24,7 +24,7 @@ def test_reduce_prims():
n = tvm.size_var('n')
m = tvm.size_var('m')
A = tvm.placeholder((n, m), name='A')
R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R')
R = tvm.compute((n, ), lambda i: tvm.tir.Select((i > 1), 1, 0), name='R')
k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
# schedule
......@@ -232,8 +232,8 @@ def test_rfactor_elemwise_threads():
def test_argmax():
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
......@@ -279,8 +279,8 @@ def test_argmax():
def test_rfactor_argmax():
def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
......
......@@ -82,10 +82,10 @@ def test_compile_tuple_dup():
def test_compile_full():
# Shape calculations can happen in int64. The test checks that full operator
# can handle when shapes are not int32
shape = (tvm.expr.IntImm('int32', 1),
tvm.expr.IntImm('int64', 16),
tvm.expr.IntImm('int64', 16),
tvm.expr.IntImm('int32', 64))
shape = (tvm.tir.IntImm('int32', 1),
tvm.tir.IntImm('int64', 16),
tvm.tir.IntImm('int64', 16),
tvm.tir.IntImm('int32', 64))
output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
f = relay.Function([], output)
mod = tvm.IRModule.from_expr(f)
......
......@@ -41,7 +41,7 @@ def test_basic_build():
}
# build
targets = {
tvm.expr.IntImm("int32", ctx.device_type): tgt
tvm.tir.IntImm("int32", ctx.device_type): tgt
}
g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params)
......
......@@ -77,9 +77,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol):
func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.expr.StringImm(compiler))
func = func.set_attribute("ExternalSymbol", tvm.expr.StringImm(ext_symbol))
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm(compiler))
func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
return func
......
......@@ -307,7 +307,7 @@ def get_synthetic_lib():
subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2,
gcc_input3], relay.copy(gcc_input0))
subgraph0 = subgraph0.set_attribute(
"Primitive", tvm.expr.IntImm("int32", 1))
"Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph0
subgraph0_ret = relay.Call(subgraph0, [x, w0, w1, w2])
......@@ -320,7 +320,7 @@ def get_synthetic_lib():
subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6,
gcc_input7], relay.copy(gcc_input4))
subgraph1 = subgraph1.set_attribute(
"Primitive", tvm.expr.IntImm("int32", 1))
"Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph1
subgraph1_ret = relay.Call(subgraph1, [x, w3, w4, w5])
......
......@@ -17,7 +17,7 @@
""" test ir"""
import tvm
from tvm import relay
from tvm.expr import *
from tvm.tir.expr import *
from tvm.relay import op
from tvm.relay.analysis import graph_equal
import numpy as np
......@@ -110,7 +110,7 @@ def test_type_relation():
num_inputs = 2
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
......
......@@ -69,7 +69,7 @@ type List[A] {
"""
def roundtrip(expr):
x = relay.fromtext(str(expr))
x = relay.fromtext(expr.astext())
assert_graph_equal(x, expr)
......@@ -343,7 +343,7 @@ def test_func():
# attributes
assert parses_as(
"fn (n=5) { () }",
relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5)))
relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
)
......
......@@ -630,8 +630,8 @@ def test_upsampling_infer_type():
y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
"method=\"BINLINEAR\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)),
tvm.expr.Cast("int32", tvm.round(w*scale))),
assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(h*scale)),
tvm.tir.Cast("int32", tvm.round(w*scale))),
"float32")
n, c = tvm.size_var("n"), tvm.size_var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
......@@ -647,9 +647,9 @@ def test_upsampling3d_infer_type():
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(d*scale)),
tvm.expr.Cast("int32", tvm.round(h*scale)),
tvm.expr.Cast("int32", tvm.round(w*scale))),
assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(d*scale)),
tvm.tir.Cast("int32", tvm.round(h*scale)),
tvm.tir.Cast("int32", tvm.round(w*scale))),
"float32")
n, c = tvm.size_var("n"), tvm.size_var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
......
......@@ -517,7 +517,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
alpha_shape = (data[axis],)
assert zz.args[1].checked_type == relay.TensorType(alpha_shape, "float32")
if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or not alpha:
if all(isinstance(v, tvm.tir.Var) == 1 for v in data) or not alpha:
return
func = relay.Function([x, y], z)
......
......@@ -154,7 +154,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
assert zz.checked_type == relay.ty.TensorType(output, out_type)
if all(isinstance(v, tvm.expr.Var) == 1 for v in data):
if all(isinstance(v, tvm.tir.Var) == 1 for v in data):
return
func = relay.Function([x], z)
......
......@@ -160,9 +160,9 @@ def test_type_relation_alpha_equal():
broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
......@@ -322,7 +322,7 @@ def test_multi_node_subgraph():
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a"))
func0 = func0.set_attribute("FuncName", tvm.tir.StringImm("a"))
x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
......@@ -332,7 +332,7 @@ def test_multi_node_subgraph():
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b"))
func1 = func1.set_attribute("FuncName", tvm.tir.StringImm("b"))
assert not alpha_equal(func0, func1)
......@@ -413,9 +413,9 @@ def test_call_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((), "int8")
......
......@@ -303,11 +303,11 @@ def test_extern_ccompiler_default_ops():
add = x0 + y0
# Function that uses C compiler
func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler",
tvm.expr.StringImm("ccompiler"))
tvm.tir.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol",
tvm.expr.StringImm("ccompiler_0"))
tvm.tir.StringImm("ccompiler_0"))
add_call = relay.Call(func, [x, y])
# Function that uses default compiler. Ops are fused in this function.
p0 = relay.var("p0", shape=(8, 8))
......@@ -316,7 +316,7 @@ def test_extern_ccompiler_default_ops():
concat = relay.concatenate([log, exp], axis=0)
fused_func = relay.Function([p0], concat)
fused_func = fused_func.set_attribute("Primitive",
tvm.expr.IntImm("int32", 1))
tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
mod = tvm.IRModule()
......
......@@ -65,7 +65,7 @@ def test_tuple_type():
def test_type_relation():
func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast')
attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4))
attrs = tvm.ir.make_node('attrs.TestAttrs', name='attr', padding=(3,4))
tp = TypeVar('tp')
tf = FuncType([], TupleType([]), [], [])
tt = TensorType([1, 2, 3], 'float32')
......
......@@ -151,9 +151,9 @@ def test_reduce_combiner_simplify():
prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
sum_or_prod = comm_reducer(
lambda x, y: tvm.expr.Select(dummy < 0,
lambda x, y: tvm.tir.Select(dummy < 0,
x + y, x*y),
lambda t0: tvm.expr.Select(dummy < 0,
lambda t0: tvm.tir.Select(dummy < 0,
tvm.const(0, t0), tvm.const(1, t0)))
sum_and_prod = comm_reducer(
lambda x, y: (x[0] + y[0],
......@@ -199,7 +199,7 @@ def test_reduce_combiner_simplify():
assert tvm.ir_pass.Equal(lhs, rhs)
# Test that components with side effects are not removed
side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0)
ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
......@@ -211,7 +211,7 @@ def test_reduce_simplify():
k = tvm.reduce_axis((0, 10), name="k")
j = tvm.reduce_axis((-5, 3), name="j")
A = tvm.placeholder((10,), name='A')
ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]),
ck.verify(tvm.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]),
tvm.sum(k + j, [k, j]))
ck.verify(tvm.sum(A[3], []), A[3])
# The rule below is not typical, removed for now
......@@ -235,23 +235,23 @@ def test_simplify_if_then_else():
tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
x), y)
expected = tvm.if_then_else(
tvm.expr.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)),
tvm.tir.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.tir.LE(24512, tmod(((x*4) + y) - 4, 24528)),
tmod(((x*4) + y) - 4, 16),
x), y)
ck.verify(res, expected)
ck.verify(res2, expected)
# can only simplify if condition
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
res = tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
expected = tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
ck.verify(res, ck.analyzer.canonical_simplify(expected))
res = tvm.expr.Select(x >= 10,
res = tvm.tir.Select(x >= 10,
tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
expected = tvm.expr.Select(x >= 10, x, 0)
expected = tvm.tir.Select(x >= 10, x, 0)
ck.verify(res, ck.analyzer.canonical_simplify(expected))
res = tvm.expr.Select(x >= 10,
res = tvm.tir.Select(x >= 10,
tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
ck.verify(res, 0)
......
......@@ -228,7 +228,7 @@ def test_select_bound():
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(
tvm.expr.Select(x > 1, (y < 0).astype("int32"), y + 1))
tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1))
assert bd.min_value == 0
assert bd.max_value == 11
......
......@@ -19,7 +19,7 @@ import tvm
def assert_expr_equal(a, b):
res = tvm.ir_pass.Simplify(a - b)
equal = isinstance(res, tvm.expr.IntImm) and res.value == 0
equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
if not equal:
raise ValueError("{} and {} are not equal".format(a, b))
......
......@@ -23,14 +23,14 @@ def test_domain_touched():
m = tvm.var('m')
a = tvm.placeholder((n, m), name = 'a')
b = tvm.placeholder((n, m), name = 'b')
ir = tvm.make.For(
ir = tvm.tir.For(
i, 0, n, 0, 0,
tvm.make.For(j, 0, m, 0, 0,
tvm.make.Provide(
tvm.tir.For(j, 0, m, 0, 0,
tvm.tir.Provide(
a.op,
0,
tvm.make.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
tvm.make.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
[i, j]
)
)
......@@ -51,7 +51,7 @@ def test_domain_touched():
assert a_domain_rw[0].min.value == -1
assert a_domain_rw[0].extent.value == 101
assert a_domain_rw[1].min.value == -1
assert isinstance(a_domain_rw[1].extent, tvm.expr.Add)
assert isinstance(a_domain_rw[1].extent, tvm.tir.Add)
assert a_domain_rw[1].extent.a.name == 'm'
assert a_domain_rw[1].extent.b.value == 1
......
......@@ -41,7 +41,7 @@ def test_vector():
base = 10
stride = 3
lanes = 2
s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
s = tvm.arith.intset_vector(tvm.tir.Ramp(base, stride, lanes))
assert s.min_value.value == base
assert s.max_value.value == base + stride * lanes - 1
......@@ -99,7 +99,7 @@ def test_max_min():
def test_select():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1),
ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1),
{x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
......
......@@ -84,7 +84,7 @@ def test_min_max_select():
assert m.coeff == 3
assert m.base == 1
m = analyzer.modular_set(tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2))
m = analyzer.modular_set(tvm.tir.Select(x > 0, x * 3 + 1, y * 9 + 2))
assert m.coeff == 1
assert m.base == 0
......
......@@ -25,9 +25,9 @@ def test_stmt_simplify():
with ib.if_scope(i < 12):
A[i] = C[i]
body = tvm.stmt.LetStmt(n, 10, ib.get())
body = tvm.tir.LetStmt(n, 10, ib.get())
body = tvm.ir_pass.CanonicalSimplify(body)
assert isinstance(body.body, tvm.stmt.Store)
assert isinstance(body.body, tvm.tir.Store)
def test_thread_extent_simplify():
......@@ -42,9 +42,9 @@ def test_thread_extent_simplify():
ib.scope_attr(ty, "thread_extent", 1)
with ib.if_scope(tx + ty < 12):
A[tx] = C[tx + ty]
body = tvm.stmt.LetStmt(n, 10, ib.get())
body = tvm.tir.LetStmt(n, 10, ib.get())
body = tvm.ir_pass.CanonicalSimplify(body)
assert isinstance(body.body.body.body, tvm.stmt.Store)
assert isinstance(body.body.body.body, tvm.tir.Store)
def test_basic_likely_elimination():
......
......@@ -185,19 +185,19 @@ def test_cuda_shuffle():
def my_vectorize(stmt):
def vectorizer(op):
if op.for_type == tvm.stmt.For.Vectorized:
if op.for_type == tvm.tir.For.Vectorized:
four = tvm.const(4, 'int32')
idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
idx = tvm.tir.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
all_ones = tvm.const(1, 'int32x4')
store = op.body
value = store.value
new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones)
new_a = tvm.tir.Load('int32x4', value.a.buffer_var, idx, all_ones)
bs, ids = [], []
for i in range(4):
bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
bs.append(tvm.tir.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
ids.append(tvm.const(3 - i, 'int32'))
new_b = tvm.make.Shuffle(bs, ids)
return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None
return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
......
......@@ -29,9 +29,9 @@ def test_llvm_intrin():
tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
0, 3, 1
]
ib.emit(tvm.make.Evaluate(
tvm.make.Call(
"int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
ib.emit(tvm.tir.Evaluate(
tvm.tir.Call(
"int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
body = ib.get()
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
......@@ -643,14 +643,14 @@ def test_llvm_shuffle():
def vectorizer(op):
store = op.body
idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
idx = tvm.tir.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
all_ones = tvm.const(1, 'int32x8')
value = store.value
b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
b_idx = tvm.tir.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
new_a = tvm.tir.Load('int32x8', value.a.buffer_var, idx, all_ones)
new_b = tvm.tir.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
value = new_a + new_b
return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
......
......@@ -40,7 +40,7 @@ def test_opencl_ternary_expression():
true_value = tvm.const(1, dtype=dtype)
false_value = tvm.const(3, dtype=dtype)
max_lhs = tvm.const(2, dtype=dtype)
max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value)
max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value)
C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
......
......@@ -26,7 +26,7 @@ def test_static_callback():
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
cp = tvm.thread_axis((0, 1), "cop")
finit = tvm.make.StringImm("TVMBackendRunOnce")
finit = tvm.tir.StringImm("TVMBackendRunOnce")
ib.scope_attr(cp, "coproc_uop_scope", finit)
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
......
......@@ -34,7 +34,7 @@ def test_stack_vm_basic():
n = tvm.size_var('n')
Ab = tvm.decl_buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
stmt = tvm.tir.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm")
......@@ -75,7 +75,7 @@ def test_stack_vm_cond():
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
with ib.if_scope(tvm.make.EQ(i, 4)):
with ib.if_scope(tvm.tir.EQ(i, 4)):
A[i + 1] = A[i] + 1
with ib.else_scope():
A[i + 1] = A[i] + 2
......
......@@ -31,7 +31,7 @@ def test_vector_comparison():
A = tvm.placeholder(n, dtype=dtype, name='A')
B = tvm.compute(
A.shape,
lambda i: tvm.expr.Select(
lambda i: tvm.tir.Select(
A[i] >= 0, A[i] + tvm.const(1, dtype),
tvm.const(0, dtype)), name='B')
s = tvm.create_schedule(B.op)
......
......@@ -18,7 +18,7 @@
import tvm
from ctypes import *
import topi
import tvm.ir_pass as ir_pass
import tvm.tir.ir_pass as ir_pass
import numpy as np
tgt = "llvm"
......@@ -126,7 +126,7 @@ def test_bfloat_add_and_cast_FloatImm():
Z = topi.cast(
topi.add(
topi.cast(X, dtype="custom[bfloat]16"),
tvm.expr.FloatImm("custom[bfloat]16", 1.5)),
tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
dtype="float")
s = tvm.create_schedule([Z.op])
......
......@@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
def tvm_val_2_py_val(val):
val = tvm.ir_pass.Substitute(val, var_dict)
val = tvm.ir_pass.Simplify(val)
assert isinstance(val, (tvm.expr.IntImm,))
assert isinstance(val, (tvm.tir.IntImm,))
return val.value
ctx = tvm.context(target, 0)
......@@ -46,14 +46,14 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
shape = [tvm_val_2_py_val(j) for j in i.shape]
emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
nd_args.append(tvm.nd.array(emu_args[-1], ctx))
elif isinstance(i, tvm.expr.Var):
elif isinstance(i, tvm.tir.Var):
emu_args.append(tvm_val_2_py_val(i))
nd_args.append(emu_args[-1])
else:
assert isinstance(i, list)
emu_args.append(numpy.array(i))
compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))] + \
(outs if isinstance(outs, list) else [outs])
module = tvm.build(sch,
compile_args,
......@@ -76,7 +76,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
for nd, np in zip(out_tensors, ref_data):
tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))]
module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))]
module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
h_module = tvm.hybrid.build(sch, module_args, module_outs)
......@@ -111,32 +111,32 @@ def test_outer_product():
return
#Check for i in (0, n)
assert isinstance(ir, tvm.stmt.For)
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'i'
assert ir.min.value == 0
assert ir.extent.name == 'n'
ibody = ir.body
assert isinstance(ibody, tvm.stmt.For)
assert isinstance(ibody, tvm.tir.For)
#Check for j in (0, m)
assert ibody.loop_var.name == 'j'
assert ibody.min.value == 0
assert ibody.extent.name == 'm'
#Check loop body
jblock = ibody.body
assert isinstance(jblock, tvm.stmt.SeqStmt)
assert isinstance(jblock, tvm.tir.SeqStmt)
jbody = jblock[0]
assert isinstance(jbody, tvm.stmt.AssertStmt)
assert isinstance(jbody.message, tvm.expr.StringImm)
assert isinstance(jbody, tvm.tir.AssertStmt)
assert isinstance(jbody.message, tvm.tir.StringImm)
assert jbody.message.value == "index out of range!"
jbody = jblock[1]
assert isinstance(jbody, tvm.stmt.Provide)
assert isinstance(jbody, tvm.tir.Provide)
assert jbody.func.name == 'c'
assert len(jbody.args) == 2
assert jbody.args[0].name == 'i'
assert jbody.args[1].name == 'j'
assert isinstance(jbody.value, tvm.expr.Mul)
assert isinstance(jbody.value, tvm.tir.Mul)
mul = jbody.value
assert isinstance(mul.a, tvm.expr.Call)
assert isinstance(mul.a, tvm.tir.Call)
assert mul.a.name == 'a'
assert mul.b.name == 'b'
......@@ -177,21 +177,21 @@ def test_fanout():
return
#Check for i in (0, n-3)
assert isinstance(ir, tvm.stmt.For)
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'i'
assert ir.min.value == 0
assert tvm.ir_pass.Equal(ir.extent, n - 3)
#Check loopbody
ibody = ir.body
assert isinstance(ibody, tvm.stmt.AttrStmt)
assert isinstance(ibody, tvm.tir.AttrStmt)
abody = ibody.body
assert isinstance(abody, tvm.stmt.Realize)
assert isinstance(abody, tvm.tir.Realize)
assert abody.bounds[0].min.value == 0
assert abody.bounds[0].extent.value == 1
assert abody.func.name == 'sigma'
#Check i loop body
rbody = abody.body
assert isinstance(rbody[0], tvm.stmt.Provide)
assert isinstance(rbody[0], tvm.tir.Provide)
assert rbody[0].func.name == 'sigma'
assert len(rbody[0].args) == 1
assert rbody[0].args[0].value == 0
......@@ -201,13 +201,13 @@ def test_fanout():
assert jloop.min.value == 0
assert jloop.extent.value == 3
jbody = jloop.body
assert isinstance(jbody, tvm.stmt.Provide)
assert isinstance(jbody, tvm.tir.Provide)
assert len(jbody.args) == 1
assert jbody.args[0].value == 0
assert jbody.func.name == 'sigma'
assert isinstance(jbody.value, tvm.expr.Add)
assert isinstance(jbody.value, tvm.tir.Add)
value = jbody.value
assert isinstance(value.a, tvm.expr.Call)
assert isinstance(value.a, tvm.tir.Call)
assert value.a.name == 'sigma'
assert len(value.a.args) == 1
assert value.a.args[0].value == 0
......@@ -215,17 +215,17 @@ def test_fanout():
assert len(value.b.args) == 1
assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
divide= rbody[2]
assert isinstance(divide, tvm.stmt.Provide)
assert isinstance(divide, tvm.tir.Provide)
assert len(divide.args) == 1
assert divide.args[0].value == 0
value = divide.value
assert isinstance(value, tvm.expr.Mul)
assert isinstance(value, tvm.tir.Mul)
assert value.a.name == 'sigma'
assert len(value.a.args) == 1
assert value.a.args[0].value == 0
assert abs(value.b.value - (1 / 3.0)) < 1e-5
write = rbody[3]
assert isinstance(write, tvm.stmt.Provide)
assert isinstance(write, tvm.tir.Provide)
assert write.func.name == 'b'
assert write.value.name == 'sigma'
assert len(write.value.args) == 1
......@@ -260,9 +260,9 @@ def test_looptype():
iloop = ir[0]
jloop = ir[1]
kloop = ir[2]
assert iloop.for_type == tvm.stmt.For.Parallel
assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled
assert iloop.for_type == tvm.tir.For.Parallel
assert jloop.for_type == tvm.tir.For.Vectorized
assert kloop.for_type == tvm.tir.For.Unrolled
func, ins, outs = run_and_check(looptype, [a, b, c])
run_and_check(func, ins, outs=outs)
......@@ -364,7 +364,7 @@ def test_bind():
c = foo(a)
s = tvm.create_schedule(c.op)
ir = tvm.lower(s, [a, c], simple_mode=True)
assert not isinstance(ir, tvm.stmt.AttrStmt)
assert not isinstance(ir, tvm.tir.AttrStmt)
func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
......@@ -729,20 +729,20 @@ def test_schedule():
sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.stmt.ProducerConsumer)
assert isinstance(ir, tvm.tir.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.stmt.AttrStmt)
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'i.inner'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'i.outer'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'j.outer.outer'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'j.outer.inner'
ir = ir.body
func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
......@@ -752,11 +752,11 @@ def test_schedule():
sch = tvm.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1])
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.stmt.ProducerConsumer)
assert isinstance(ir, tvm.tir.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.stmt.AttrStmt)
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'i.j.fused'
func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
......
......@@ -28,14 +28,14 @@ def test_for():
body = ib.get()
print(body)
assert isinstance(body, tvm.stmt.AttrStmt)
assert isinstance(body, tvm.tir.AttrStmt)
body = body.body
assert isinstance(body, tvm.stmt.Allocate)
assert isinstance(body, tvm.tir.Allocate)
body = body.body
assert isinstance(body, tvm.stmt.For)
assert isinstance(body, tvm.tir.For)
body = body.body
assert isinstance(body, tvm.stmt.SeqStmt)
assert isinstance(body[1], tvm.stmt.For)
assert isinstance(body, tvm.tir.SeqStmt)
assert isinstance(body[1], tvm.tir.For)
def test_if():
ib = tvm.ir_builder.create()
......@@ -50,11 +50,11 @@ def test_if():
body = ib.get()
assert A == A
assert isinstance(body, tvm.stmt.For)
assert isinstance(body, tvm.tir.For)
body = body.body
assert isinstance(body, tvm.stmt.IfThenElse)
assert isinstance(body.condition, tvm.expr.EQ)
assert isinstance(body.then_case.index, tvm.expr.Var)
assert isinstance(body, tvm.tir.IfThenElse)
assert isinstance(body.condition, tvm.tir.EQ)
assert isinstance(body.then_case.index, tvm.tir.Var)
assert body.else_case.index.value == 0
def test_prefetch():
......@@ -64,10 +64,10 @@ def test_prefetch():
with ib.for_range(0, n, name="i") as i:
ib.emit(
tvm.make.Prefetch(
tvm.tir.Prefetch(
A.op, A.value_index, A.dtype,
[tvm.make.range_by_min_extent(i+1, 2),
tvm.make.range_by_min_extent(0, 20)]))
[tvm.ir.Range.make_by_min_extent(i+1, 2),
tvm.ir.Range.make_by_min_extent(0, 20)]))
body = ib.get()
assert body.body.bounds[0].extent.value == 2
......
......@@ -22,7 +22,7 @@ def test_const():
x = tvm.const(1, "int32")
print(x.dtype)
assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm)
assert isinstance(x, tvm.tir.IntImm)
def test_scalar_dtype_inference():
......@@ -45,47 +45,47 @@ def test_make():
x = tvm.const(1, "int32")
y = tvm.var("x")
z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)
assert isinstance(tvm.max(x, y), tvm.tir.Max)
assert isinstance(tvm.min(x, y), tvm.tir.Min)
def test_ir():
x = tvm.const(1, "int32")
y = tvm.make.IntImm('int32', 1)
y = tvm.tir.IntImm('int32', 1)
z = x + y
stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate)
stmt = tvm.tir.Evaluate(z)
assert isinstance(stmt, tvm.tir.Evaluate)
def test_ir2():
x = tvm.var("n")
a = tvm.var("array", tvm.handle)
st = tvm.make.Store(a, x + 1, 1)
assert isinstance(st, tvm.stmt.Store)
st = tvm.tir.Store(a, x + 1, 1)
assert isinstance(st, tvm.tir.Store)
assert(st.buffer_var == a)
def test_let():
x = tvm.var('x')
y = tvm.var('y')
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1));
stmt = tvm.tir.LetStmt(
x, 10, tvm.tir.Evaluate(x + 1));
def test_cast():
x = tvm.var('x', dtype="float32")
y = x.astype("int32")
z = x.astype("float32x4")
assert isinstance(y, tvm.expr.Cast)
assert isinstance(z, tvm.expr.Broadcast)
assert isinstance(y, tvm.tir.Cast)
assert isinstance(z, tvm.tir.Broadcast)
assert z.lanes == 4
def test_attr():
x = tvm.var('x')
y = tvm.var('y')
stmt = tvm.make.AttrStmt(
y, "stride", 10, tvm.make.Evaluate(x + 1));
stmt = tvm.tir.AttrStmt(
y, "stride", 10, tvm.tir.Evaluate(x + 1));
assert stmt.node == y
a = tvm.convert(1)
......@@ -105,9 +105,9 @@ def test_basic():
def test_stmt():
x = tvm.make.Evaluate(0)
tvm.make.For(tvm.var('i'), 0, 1,
tvm.stmt.For.Serial, 0,
x = tvm.tir.Evaluate(0)
tvm.tir.For(tvm.var('i'), 0, 1,
tvm.tir.For.Serial, 0,
x)
......@@ -207,7 +207,7 @@ def test_equality():
def test_equality_string_imm():
x = 'a'
y = tvm.make.StringImm(x)
y = tvm.tir.StringImm(x)
x == y.value
x == y
......
......@@ -17,50 +17,50 @@
import tvm
def test_expr_constructor():
x = tvm.expr.Var("xx", "float32")
assert isinstance(x, tvm.expr.Var)
x = tvm.tir.Var("xx", "float32")
assert isinstance(x, tvm.tir.Var)
assert x.name == "xx"
x = tvm.expr.Reduce(None, [1],
x = tvm.tir.Reduce(None, [1],
[tvm.api._IterVar((0, 1), "x", 2)],
None, 0)
assert isinstance(x, tvm.expr.Reduce)
assert isinstance(x, tvm.tir.Reduce)
assert x.combiner == None
assert x.value_index == 0
x = tvm.expr.FloatImm("float32", 1.0)
assert isinstance(x, tvm.expr.FloatImm)
x = tvm.tir.FloatImm("float32", 1.0)
assert isinstance(x, tvm.tir.FloatImm)
assert x.value == 1.0
assert x.dtype == "float32"
x = tvm.expr.IntImm("int64", 2)
assert isinstance(x, tvm.expr.IntImm)
x = tvm.tir.IntImm("int64", 2)
assert isinstance(x, tvm.tir.IntImm)
assert x.value == 2
assert x.dtype == "int64"
x = tvm.expr.StringImm("xyza")
assert isinstance(x, tvm.expr.StringImm)
x = tvm.tir.StringImm("xyza")
assert isinstance(x, tvm.tir.StringImm)
assert x.value == "xyza"
x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1))
assert isinstance(x, tvm.expr.Cast)
x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1))
assert isinstance(x, tvm.tir.Cast)
assert x.dtype == "float32"
assert x.value.value == 1
a = tvm.const(1.0, dtype="float32")
b = tvm.var("x", dtype="float32")
for cls in [tvm.expr.Add,
tvm.expr.Sub,
tvm.expr.Mul,
tvm.expr.Div,
tvm.expr.Mod,
tvm.expr.Min,
tvm.expr.Max,
tvm.expr.LT,
tvm.expr.LE,
tvm.expr.GT,
tvm.expr.GE]:
for cls in [tvm.tir.Add,
tvm.tir.Sub,
tvm.tir.Mul,
tvm.tir.Div,
tvm.tir.Mod,
tvm.tir.Min,
tvm.tir.Max,
tvm.tir.LT,
tvm.tir.LE,
tvm.tir.GT,
tvm.tir.GE]:
x = cls(a, b)
assert isinstance(x, cls)
assert x.a == a
......@@ -70,58 +70,58 @@ def test_expr_constructor():
a = tvm.convert(tvm.var("x") > 1)
b = tvm.convert(tvm.var("x") == 1)
for cls in [tvm.expr.And,
tvm.expr.Or]:
for cls in [tvm.tir.And,
tvm.tir.Or]:
x = cls(a, b)
assert isinstance(x, cls)
assert x.a == a
assert x.b.same_as(b)
x = tvm.expr.Not(a)
assert isinstance(x, tvm.expr.Not)
x = tvm.tir.Not(a)
assert isinstance(x, tvm.tir.Not)
assert x.a == a
x = tvm.expr.Select(a, a, b)
assert isinstance(x, tvm.expr.Select)
x = tvm.tir.Select(a, a, b)
assert isinstance(x, tvm.tir.Select)
assert x.true_value == a
assert x.false_value == b
assert x.condition == a
buffer_var = tvm.var("x", dtype="handle")
x = tvm.expr.Load("float32", buffer_var, 1, a)
assert isinstance(x, tvm.expr.Load)
x = tvm.tir.Load("float32", buffer_var, 1, a)
assert isinstance(x, tvm.tir.Load)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
assert x.index.value == 1
assert x.predicate == a
x = tvm.expr.Ramp(1, 2, 10)
assert isinstance(x, tvm.expr.Ramp)
x = tvm.tir.Ramp(1, 2, 10)
assert isinstance(x, tvm.tir.Ramp)
assert x.base.value == 1
assert x.stride.value == 2
assert x.lanes == 10
x = tvm.expr.Broadcast(a, 10)
assert isinstance(x, tvm.expr.Broadcast)
x = tvm.tir.Broadcast(a, 10)
assert isinstance(x, tvm.tir.Broadcast)
assert x.value == a
assert x.lanes == 10
x = tvm.expr.Shuffle([a], [0])
assert isinstance(x, tvm.expr.Shuffle)
x = tvm.tir.Shuffle([a], [0])
assert isinstance(x, tvm.tir.Shuffle)
assert x.vectors[0] == a
assert x.indices[0].value == 0
x = tvm.expr.Call("float32", "xyz", [a], tvm.expr.Call.Extern, None, 0)
assert isinstance(x, tvm.expr.Call)
x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern, None, 0)
assert isinstance(x, tvm.tir.Call)
assert x.dtype == "float32"
assert x.name == "xyz"
assert x.args[0] == a
assert x.call_type == tvm.expr.Call.Extern
assert x.call_type == tvm.tir.Call.Extern
assert x.func == None
assert x.value_index == 0
v = tvm.var("aa")
x = tvm.expr.Let(v, 1, v)
x = tvm.tir.Let(v, 1, v)
assert x.var == v
assert x.value.value == 1
assert x.body == v
......@@ -130,75 +130,75 @@ def test_expr_constructor():
def test_stmt_constructor():
v = tvm.var("aa")
buffer_var = tvm.var("buf", dtype="handle")
nop = tvm.stmt.Evaluate(1)
x = tvm.stmt.LetStmt(v, 1, tvm.stmt.Evaluate(1))
assert isinstance(x, tvm.stmt.LetStmt)
nop = tvm.tir.Evaluate(1)
x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1))
assert isinstance(x, tvm.tir.LetStmt)
assert x.var == v
assert x.value.value == 1
assert isinstance(x.body, tvm.stmt.Evaluate)
assert isinstance(x.body, tvm.tir.Evaluate)
x = tvm.stmt.AttrStmt(v == 1, "xx", 1, tvm.stmt.Evaluate(1))
assert isinstance(x, tvm.stmt.AttrStmt)
x = tvm.tir.AttrStmt(v == 1, "xx", 1, tvm.tir.Evaluate(1))
assert isinstance(x, tvm.tir.AttrStmt)
assert x.value.value == 1
x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"),
x = tvm.tir.AssertStmt(tvm.const(1, "uint1"),
tvm.convert("hellow"),
nop)
assert isinstance(x, tvm.stmt.AssertStmt)
assert isinstance(x, tvm.tir.AssertStmt)
assert x.body == nop
x = tvm.stmt.ProducerConsumer(None, True, nop)
assert isinstance(x, tvm.stmt.ProducerConsumer)
x = tvm.tir.ProducerConsumer(None, True, nop)
assert isinstance(x, tvm.tir.ProducerConsumer)
assert x.body == nop
x = tvm.stmt.For(tvm.var("x"), 0, 10, 0, 0, nop)
assert isinstance(x, tvm.stmt.For)
x = tvm.tir.For(tvm.var("x"), 0, 10, 0, 0, nop)
assert isinstance(x, tvm.tir.For)
assert x.min.value == 0
assert x.extent.value == 10
assert x.body == nop
x = tvm.stmt.Store(buffer_var, 1, 10, tvm.const(1, "uint1"))
assert isinstance(x, tvm.stmt.Store)
x = tvm.tir.Store(buffer_var, 1, 10, tvm.const(1, "uint1"))
assert isinstance(x, tvm.tir.Store)
assert x.buffer_var == buffer_var
assert x.index.value == 10
assert x.value.value == 1
tensor = tvm.placeholder((), dtype="float32")
x = tvm.stmt.Provide(tensor.op, 0, 10, [])
assert isinstance(x, tvm.stmt.Provide)
x = tvm.tir.Provide(tensor.op, 0, 10, [])
assert isinstance(x, tvm.tir.Provide)
assert x.value_index == 0
assert x.value.value == 10
x = tvm.stmt.Allocate(buffer_var, "float32", [10],
x = tvm.tir.Allocate(buffer_var, "float32", [10],
tvm.const(1, "uint1"), nop)
assert isinstance(x, tvm.stmt.Allocate)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
assert x.body == nop
x = tvm.stmt.AttrStmt(buffer_var, "xyz", 1, nop)
assert isinstance(x, tvm.stmt.AttrStmt)
x = tvm.tir.AttrStmt(buffer_var, "xyz", 1, nop)
assert isinstance(x, tvm.tir.AttrStmt)
assert x.node == buffer_var
assert x.attr_key == "xyz"
assert x.body == nop
x = tvm.stmt.Free(buffer_var)
assert isinstance(x, tvm.stmt.Free)
x = tvm.tir.Free(buffer_var)
assert isinstance(x, tvm.tir.Free)
assert x.buffer_var == buffer_var
x = tvm.stmt.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop)
assert isinstance(x, tvm.stmt.Realize)
x = tvm.tir.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop)
assert isinstance(x, tvm.tir.Realize)
assert x.body == nop
x = tvm.stmt.IfThenElse(tvm.const(1, "uint1"),
tvm.stmt.Evaluate(11),
x = tvm.tir.IfThenElse(tvm.const(1, "uint1"),
tvm.tir.Evaluate(11),
nop)
assert isinstance(x, tvm.stmt.IfThenElse)
assert isinstance(x, tvm.tir.IfThenElse)
assert x.then_case.value.value == 11
assert x.else_case == nop
x = tvm.stmt.Prefetch(None, 1, "float32", [])
assert isinstance(x, tvm.stmt.Prefetch)
x = tvm.tir.Prefetch(None, 1, "float32", [])
assert isinstance(x, tvm.tir.Prefetch)
assert x.value_index == 1
......
......@@ -69,7 +69,7 @@ def test_map_save_load_json():
def test_in_container():
arr = tvm.convert(['a', 'b', 'c'])
assert 'a' in arr
assert tvm.make.StringImm('a') in arr
assert tvm.tir.StringImm('a') in arr
assert 'd' not in arr
def test_ndarray_container():
......
......@@ -20,9 +20,9 @@ import tvm
from topi.util import get_const_tuple
def test_layout():
layout = tvm.layout("NCHW16c")
layout = tvm.tir.layout("NCHW16c")
assert layout is not None
assert isinstance(layout, tvm.tensor.Layout)
assert isinstance(layout, tvm.tir.Layout)
assert layout.factor_of("c") == 16
assert layout.factor_of("C") == 16
......@@ -63,7 +63,7 @@ def test_bilayout_convertible():
def test_bilayout_shape():
bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
assert isinstance(bilayout, tvm.tensor.BijectiveLayout)
assert isinstance(bilayout, tvm.tir.BijectiveLayout)
dst_shape = bilayout.forward_shape((1, 32, 7, 7))
assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
......
......@@ -29,7 +29,7 @@ def test_const_fold():
def check(f, *args):
x = f(*[tvm.const(x, "int32") for x in args])
y = f(*args)
if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y):
if not isinstance(x, (tvm.tir.IntImm,)) or x.value != int(y):
raise ValueError("check error: %s vs %s " % (x, y))
tmod = tvm.truncmod
......@@ -56,7 +56,7 @@ def test_const_fold2():
assert tmod(x, 1).value == 0
assert (x * 1).same_as(x)
assert (1 * x).same_as(x)
assert isinstance(tdiv(1, x), tvm.expr.Div)
assert isinstance(tdiv(1, x), tvm.tir.Div)
def test_const_fold3():
# Test that using ints with logic operations is forbidden
......@@ -92,17 +92,17 @@ def test_const_fold4():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
tdiv = tvm.truncdiv
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
assert isinstance(x2, tvm.tir.IntImm) and x2.value == 9
x3 = tdiv(x2, 3)
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3
x4 = x3 + 0.55
assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6
x5 = tvm.ceil(x4)
assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4
x6 = x5.astype('int')
assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4, "x6={}".format(x6)
assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6)
y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
assert isinstance(y, tvm.tir.IntImm) and y.value == 6
def test_binary_dtype_match():
......
......@@ -31,7 +31,7 @@ def test_make_smap():
# save load json
x = tvm.const(1, "int32")
y = tvm.const(10, "int32")
z = tvm.expr.Add(x, y)
z = tvm.tir.Add(x, y)
smap = tvm.convert({"z": z, "x": x})
json_str = tvm.ir.save_json(tvm.convert([smap]))
arr = tvm.ir.load_json(json_str)
......@@ -40,11 +40,11 @@ def test_make_smap():
def test_make_node():
x = tvm.make.node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.expr.IntImm)
x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.tir.IntImm)
assert x.value == 10
A = tvm.placeholder((10, ), name='A')
AA = tvm.make.node("Tensor",
AA = tvm.ir.make_node("Tensor",
shape=A.shape,
dtype=A.dtype,
op=A.op,
......@@ -55,25 +55,25 @@ def test_make_node():
def test_make_attrs():
try:
x = tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx")
x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("unknown_key") != -1
try:
x = tvm.make.node("attrs.TestAttrs", axis=100, name="xx")
x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("upper bound") != -1
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4))
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
assert dattr.name.value == "xyz"
......@@ -104,7 +104,7 @@ def test_env_func():
assert y(1) == 2
assert y.func(1) == 2
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
......
......@@ -240,7 +240,7 @@ def test_tensor_intrin_scalar_params():
C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, C], simple_mode=True)
assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.body.value.args) == 5
assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
......
......@@ -128,7 +128,7 @@ def test_tensor_compute1():
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
assert isinstance(stmt.body.body, tvm.tir.Evaluate)
def test_tensor_compute2():
M = 2048
......@@ -171,8 +171,8 @@ def test_tensor_compute2():
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body.body[0], tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body.body[1].body, tvm.tir.Evaluate)
def test_tensor_scan():
m = tvm.size_var("m")
......@@ -259,7 +259,7 @@ def test_tuple_with_different_deps():
stmt = tvm.schedule.ScheduleOps(sch, bounds)
def get_B1_realize(x):
if isinstance(x, tvm.stmt.Realize) and \
if isinstance(x, tvm.tir.Realize) and \
x.func == B1.op and x.value_index == 1:
ret.append(x)
ret = []
......
......@@ -29,8 +29,8 @@ def test_operator_type_and_tags():
B1 = B[0]
B2 = B[0,0]
assert isinstance(k + n, tvm.expr.PrimExpr)
assert isinstance(n + n, tvm.expr.PrimExpr)
assert isinstance(k + n, tvm.tir.PrimExpr)
assert isinstance(n + n, tvm.tir.PrimExpr)
assert isinstance(k + A, tvm.tensor.Tensor)
assert isinstance(A + k, tvm.tensor.Tensor)
assert isinstance(n + A, tvm.tensor.Tensor)
......@@ -53,11 +53,11 @@ def test_operator_type_and_tags():
assert (B + A).op.tag == topi.tag.BROADCAST
assert (B + B).op.tag == topi.tag.BROADCAST
assert isinstance(k + B2, tvm.expr.PrimExpr)
assert isinstance(B2 + k, tvm.expr.PrimExpr)
assert isinstance(n + B2, tvm.expr.PrimExpr)
assert isinstance(B2 + n, tvm.expr.PrimExpr)
assert isinstance(B2 + B2, tvm.expr.PrimExpr)
assert isinstance(k + B2, tvm.tir.PrimExpr)
assert isinstance(B2 + k, tvm.tir.PrimExpr)
assert isinstance(n + B2, tvm.tir.PrimExpr)
assert isinstance(B2 + n, tvm.tir.PrimExpr)
assert isinstance(B2 + B2, tvm.tir.PrimExpr)
assert isinstance(B2 + A, tvm.tensor.Tensor)
assert isinstance(A + B2, tvm.tensor.Tensor)
assert isinstance(B2 + B, tvm.tensor.Tensor)
......
......@@ -17,15 +17,15 @@
import tvm
def test_attrs_equal():
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
z = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4,1))
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1))
assert tvm.ir_pass.AttrsEqual(x, y)
assert not tvm.ir_pass.AttrsEqual(x, z)
dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert not tvm.ir_pass.AttrsEqual(dattr, x)
dattr2 = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert tvm.ir_pass.AttrsEqual(dattr, dattr2)
assert tvm.ir_pass.AttrsEqual({"x": x}, {"x": y})
......@@ -42,8 +42,8 @@ def test_attrs_equal():
def test_attrs_hash():
fhash = tvm.ir_pass.AttrsHash
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
assert fhash({"x": x}) == fhash({"x": y})
assert fhash({"x": x}) != fhash({"x": [y, 1]})
assert fhash({"x": [x, 1]}) == fhash({"x": [y, 1]})
......
......@@ -31,16 +31,16 @@ def test_simplify():
def test_verify_ssa():
x = tvm.var('x')
y = tvm.var()
z = tvm.make.Evaluate(x + y)
z = tvm.tir.Evaluate(x + y)
assert(tvm.ir_pass.VerifySSA(z))
def test_convert_ssa():
x = tvm.var('x')
y = tvm.var()
let1 = tvm.make.Let(x, 1, x + 1)
let2 = tvm.make.Let(x, 1, x + y)
z = tvm.make.Evaluate(let1 + let2)
let1 = tvm.tir.Let(x, 1, x + 1)
let2 = tvm.tir.Let(x, 1, x + y)
z = tvm.tir.Evaluate(let1 + let2)
assert(not tvm.ir_pass.VerifySSA(z))
z_ssa = tvm.ir_pass.ConvertSSA(z)
assert(tvm.ir_pass.VerifySSA(z_ssa))
......
......@@ -166,12 +166,12 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
def test_in_bounds_const_loop_partition_ir():
def check_attr_stmt (x):
if isinstance(x, tvm.stmt.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
if isinstance(x, tvm.tir.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
return True
return False
def check_branch_stmt (x):
if isinstance(x, tvm.stmt.IfThenElse):
if isinstance(x, tvm.tir.IfThenElse):
return True
return False
......@@ -183,7 +183,7 @@ def test_in_bounds_const_loop_partition_ir():
assert (count == nums)
def collect_branch_stmt (x):
if isinstance(x, tvm.stmt.IfThenElse):
if isinstance(x, tvm.tir.IfThenElse):
branch_collector.append(x)
n = 21
......
......@@ -20,8 +20,8 @@ def test_for():
dev_type = tvm.var("dev_type")
def device_context(dev_id):
ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
return tvm.make.Call(
"handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0)
return tvm.tir.Call(
"handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic, None, 0)
ib = tvm.ir_builder.create()
n = tvm.var("n")
......
......@@ -33,7 +33,7 @@ def test_decorate_device():
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt1 = tvm.ir_pass.Simplify(stmt)
stmt2 = tvm.ir_pass.DecorateDeviceScope(stmt1)
assert isinstance(stmt2, tvm.stmt.AttrStmt)
assert isinstance(stmt2, tvm.tir.AttrStmt)
assert stmt2.attr_key == "device_scope"
assert stmt1 == stmt2.body
......
......@@ -24,19 +24,19 @@ def verify_structure(stmt, expected_struct):
struct = {}
def _extract_vars(op):
global var_list
if isinstance(op, tvm.expr.Var):
if isinstance(op, tvm.tir.Var):
var_list.append(op.name)
def _visit(op):
key = op
if isinstance(op, tvm.stmt.IfThenElse):
if isinstance(op, tvm.tir.IfThenElse):
global var_list
tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars)
val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))]
var_list.clear()
elif isinstance(op, tvm.stmt.For):
elif isinstance(op, tvm.tir.For):
val = [(op.body,), ("For", op.loop_var.name)]
elif isinstance(op, tvm.stmt.AttrStmt):
elif isinstance(op, tvm.tir.AttrStmt):
val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))]
else:
return
......@@ -61,9 +61,9 @@ def test_basic():
with ib.for_range(0, m, "j") as j:
with ib.for_range(0, n, "k") as k:
with ib.if_scope(ib.likely(i < 2)):
ib.emit(tvm.make.Evaluate(m))
ib.emit(tvm.tir.Evaluate(m))
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
ib.emit(tvm.tir.Evaluate(n))
stmt = ib.get()
new_stmt = tvm.ir_pass.HoistIfThenElse(stmt)
......@@ -82,7 +82,7 @@ def test_no_else():
with ib.for_range(0, m, "j") as j:
with ib.for_range(0, n, "k") as k:
with ib.if_scope(ib.likely(i < 2)):
ib.emit(tvm.make.Evaluate(m))
ib.emit(tvm.tir.Evaluate(m))
stmt = ib.get()
new_stmt = tvm.ir_pass.HoistIfThenElse(stmt)
......
......@@ -33,7 +33,7 @@ def test_copy2d():
assert dst.strides[1].value == 1
assert src.strides[0] == l
assert tuple(src.shape) == (m, l)
return tvm.make.Evaluate(0)
return tvm.tir.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def test_copy_pad():
......@@ -57,7 +57,7 @@ def test_copy_pad():
assert pad_after[0].value == 1
assert pad_after[1].value == 0
assert pad_value.value == 1.0
return tvm.make.Evaluate(0)
return tvm.tir.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def test_single_point_test():
......@@ -76,7 +76,7 @@ def test_single_point_test():
assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.make.Evaluate(0)
return tvm.tir.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def assert_expr_equal(a, b):
......@@ -109,7 +109,7 @@ def test_copy_pad_split():
assert_expr_equal(pad_before[0], rpad_before)
assert_expr_equal(pad_after[0], rpad_after)
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.make.Evaluate(0)
return tvm.tir.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
......
......@@ -37,13 +37,13 @@ def test_double_buffer():
stmt = ib.get()
stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2)
stmt = tvm.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.stmt.Allocate)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.ir_pass.ThreadSync(f, "shared")
count = [0]
def count_sync(op):
if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
count[0] += 1
tvm.ir_pass.PostOrderVisit(f.body, count_sync)
assert count[0] == 4
......
......@@ -20,7 +20,7 @@ def test_inline():
m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.tir.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
print(stmt)
......@@ -39,11 +39,11 @@ def test_inline2():
m = tvm.size_var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
stmt = tvm.tir.Evaluate(tvm.exp(T[10]) + 11 * T[100])
stmt = tvm.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
def check(op):
if isinstance(op, tvm.expr.Call):
if isinstance(op, tvm.tir.Call):
assert op.func != T.op
tvm.ir_pass.PostOrderVisit(stmt, check)
......
......@@ -32,12 +32,12 @@ def test_ir_transform():
return None
def postorder(op):
assert isinstance(op, tvm.expr.Call)
assert isinstance(op, tvm.tir.Call)
if op.name == "TestA":
return tvm.call_extern("int32", "TestB", op.args[0] + 1)
return op
body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
stmt_list = tvm.make.stmt_list(body.body.body)
stmt_list = tvm.tir.stmt_list(body.body.body)
assert stmt_list[0].value.args[0].name == "TestB"
assert stmt_list[1].value.value == 0
......
......@@ -20,7 +20,7 @@ def test_coproc_lift():
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
value = tvm.make.StringImm("xxx")
value = tvm.tir.StringImm("xxx")
A = ib.allocate("float32", n, name="A", scope="global")
with ib.for_range(0, n, name="i") as i:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment