Commit b787ffa3 by tqchen Committed by Tianqi Chen

[REFACTOR][PY] Establish tvm.tir

- Move related files into the corresponding location as in C++
- Keep the top-level TVM API backward compatible to minimize the changes needed in topi
parent a6c42b34
@@ -33,36 +33,40 @@ from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
 from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
 from .runtime import ndarray as nd
+# tvm.error
+from . import error
 # tvm.ir
 from .ir import IRModule
 from .ir import transform
 from .ir import container
 from . import ir
+# tvm.tir
+from . import tir
+# tvm.target
+from . import target
 # others
 from . import tensor
 from . import arith
-from . import expr
-from . import stmt
 from . import make
-from . import ir_pass
 from . import schedule
-from . import ir_builder
-from . import target
-from . import generic
 from . import hybrid
 from . import testing
-from . import error
 from .api import *
-from .intrin import *
 from .tensor_intrin import decl_tensor_intrin
 from .schedule import create_schedule
 from .build_module import build, lower, build_config
 from .tag import tag_scope
+# backward compat for topi, to be removed later
+from .tir import expr, stmt, ir_builder, ir_pass, generic
+from .tir.op import *
+from . import intrin
 # Contrib initializers
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
...
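The block above keeps every pre-refactor spelling importable. A minimal sanity check, assuming a TVM build that includes this commit; the aliases bind the same module objects under the old names:

    import tvm

    # tvm.expr, tvm.stmt, tvm.ir_pass are now aliases of the tvm.tir modules,
    # so existing topi code keeps working unchanged.
    assert tvm.expr is tvm.tir.expr
    assert tvm.stmt is tvm.tir.stmt
    assert tvm.ir_pass is tvm.tir.ir_pass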
@@ -227,7 +227,7 @@ def args_to_workload(x, topi_compute_func=None):
         workload = 0
     else:
         raise RuntimeError('Do not support type "%s" in argument. Consider to use'
-                           'primitive types or tvm.expr.Var only' % type(x))
+                           'primitive types or tvm.tir.Var only' % type(x))
     return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
 def template(func):
...
@@ -26,17 +26,19 @@ import tvm.runtime
 from tvm.runtime import Object, ndarray
 from tvm.ir import container
 from tvm.target import codegen
+from tvm.tir import expr
+from tvm.tir import ir_pass
+from tvm.tir import Stmt
+from tvm.tir.stmt import LoweredFunc
+from . import target as _target
 from . import api
 from . import _api_internal
 from . import tensor
 from . import schedule
-from . import expr
-from . import ir_pass
-from . import stmt as _stmt
-from . import target as _target
 from . import make
-from .stmt import LoweredFunc
 class DumpIR(object):
@@ -61,7 +63,7 @@ class DumpIR(object):
         def dump(*args, **kwargs):
             """dump function"""
             retv = func(*args, **kwargs)
-            if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)):
+            if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
                 return retv
             fname = func.func_name if hasattr(func, 'func_name') else func.__name__
             pname = str(self._pass_id) + "_" + fname + "_ir.cc"
...
@@ -15,9 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to BLAS libraries."""
-from __future__ import absolute_import as _abs
-from .. import api as _api, intrin as _intrin
+import tvm
+from .. import api as _api
 def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
@@ -46,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
     return _api.extern(
         (n, m),
         [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
         ),
         name="C",
@@ -78,7 +77,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
     return _api.extern(
         (b, n, m),
         [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cblas.batch_matmul"
             if not iterative
             else "tvm.contrib.cblas.batch_matmul_iterative",
...
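Every contrib wrapper in this commit makes the same substitution: the extern body now calls tvm.tir.call_packed instead of the old _intrin.call_packed. A minimal sketch of the resulting pattern, assuming a build with this commit (shapes and names below are illustrative only):

    import tvm

    n = m = 128  # hypothetical sizes
    A = tvm.placeholder((n, m), name="A")
    B = tvm.placeholder((m, n), name="B")
    # Extern op whose body is a packed call into the registered cblas routine.
    C = tvm.extern(
        (n, n), [A, B],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], False, False),
        name="C")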
@@ -15,10 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to cuBLAS libraries."""
-from __future__ import absolute_import as _abs
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     """Create an extern op that compute matrix mult of A and rhs with cuBLAS
@@ -44,7 +42,7 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     dtype = dtype if dtype is not None else lhs.dtype
     return _api.extern(
         (n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cublas.matmul",
             ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
@@ -73,6 +71,6 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
     dtype = dtype if dtype is not None else lhs.dtype
     return _api.extern(
         (b, n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cublas.batch_matmul",
             ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
@@ -15,10 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to cuBLASlt libraries."""
-from __future__ import absolute_import as _abs
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
     """Create an extern op that compute matrix mult of A and rhs with cuBLAS
@@ -46,6 +45,6 @@ def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
     dtype = dtype if dtype is not None else lhs.dtype
     return _api.extern(
         (n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cublaslt.matmul",
             ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
@@ -18,8 +18,8 @@
 # pylint: disable-msg=C0103
 import ctypes
 import numpy as np
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 from .. import get_global_func as _get_global_func
 # algos can be read from cudnn.h
@@ -365,7 +365,7 @@ def conv_forward(x,
     if dims == 4:
         return _api.extern(
             oshape, [x, w],
-            lambda ins, outs: _intrin.call_packed(
+            lambda ins, outs: tvm.tir.call_packed(
                 "tvm.contrib.cudnn.conv2d.forward",
                 conv_mode,
                 tensor_format,
@@ -383,7 +383,7 @@ def conv_forward(x,
     return _api.extern(
         oshape, [x, w],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.cudnn.conv3d.forward",
             conv_mode,
             tensor_format,
...
@@ -18,8 +18,8 @@
 # pylint: disable-msg=C0103
 import ctypes
 import numpy as np
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 from .. import get_global_func as _get_global_func
@@ -113,7 +113,7 @@ def conv2d_forward(x,
     return _api.extern(
         list(oshape), [x, w],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.miopen.conv2d.forward",
             conv_mode,
             data_type,
...
@@ -15,9 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to MPS libraries."""
-from __future__ import absolute_import as _abs
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 # pylint: disable=C0103,W0612
@@ -50,7 +49,7 @@ def matmul(lhs, rhs, transa=False, transb=False):
     n = c
     return _api.extern(
         (m, n), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
         name="C")
@@ -82,6 +81,6 @@ def conv2d(data, weight, pad='SAME', stride=1):
     return _api.extern(
         (n, ho, wo, co), [data, weight],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
         name="C")
@@ -15,10 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to NNPACK libraries."""
+import tvm
 import tvm._ffi
 from .. import api as _api
-from .. import intrin as _intrin
 def is_available():
@@ -46,7 +45,7 @@ def fully_connected_inference(lhs, rhs, nthreads=1):
     m = rhs.shape[0]
     return _api.extern(
         (m, ), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.fully_connected_inference",
             ins[0], ins[1], outs[0], nthreads), name="C")
@@ -110,7 +109,7 @@ def convolution_inference(
     return _api.extern(
         (batch, output_channels, output_height, output_width),
         [data, kernel, bias] if bias is not None else [data, kernel],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.convolution_inference",
             ins[0],
             ins[1],
@@ -163,7 +162,7 @@ def convolution_inference_without_weight_transform(
     return _api.extern(
         (batch, output_channels, output_height, output_width),
         [data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.convolution_inference_without_weight_transform",
             ins[0],
             ins[1],
@@ -198,7 +197,7 @@ def convolution_inference_weight_transform(
     return _api.extern(
         (output_channels, input_channels, transform_tile_size, transform_tile_size),
         [kernel],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.nnpack.convolution_inference_weight_transform",
             ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
...
@@ -15,10 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to random library."""
+import tvm
 import tvm._ffi
 from .. import api as _api
-from .. import intrin as _intrin
 def randint(low, high, size, dtype='int32'):
@@ -39,7 +38,7 @@ def randint(low, high, size, dtype='int32'):
         A tensor with specified size and dtype
     """
     assert 'int' in dtype, "the type of randint output must be int or uint"
-    return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
+    return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
         "tvm.contrib.random.randint", int(low), int(high), outs[0]), dtype=dtype)
@@ -67,7 +66,7 @@ def uniform(low, high, size):
     out : Tensor
         A tensor with specified size and dtype.
     """
-    return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
+    return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
         "tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32')
@@ -91,7 +90,7 @@ def normal(loc, scale, size):
     out : Tensor
         A tensor with specified size and dtype
     """
-    return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
+    return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed(
         "tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
...
@@ -15,10 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """External function interface to rocBLAS libraries."""
-from __future__ import absolute_import as _abs
+import tvm
 from .. import api as _api
-from .. import intrin as _intrin
 def matmul(lhs, rhs, transa=False, transb=False):
     """Create an extern op that compute matrix mult of A and rhs with rocBLAS
@@ -43,6 +41,6 @@ def matmul(lhs, rhs, transa=False, transb=False):
     m = rhs.shape[0] if transb else rhs.shape[1]
     return _api.extern(
         (n, m), [lhs, rhs],
-        lambda ins, outs: _intrin.call_packed(
+        lambda ins, outs: tvm.tir.call_packed(
             "tvm.contrib.rocblas.matmul",
             ins[0], ins[1], outs[0], transa, transb), name="C")
@@ -14,117 +14,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Generic opertors in TVM.
-We follow the numpy naming convention for this interface
-(e.g., tvm.generic.multitply ~ numpy.multiply).
-The default implementation is used by tvm.ExprOp.
-"""
-# pylint: disable=unused-argument
-from . import make as _make
-#Operator precedence used when overloading.
-__op_priority__ = 0
-def add(lhs, rhs):
-    """Generic add operator.
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of add operaton.
-    """
-    return _make._OpAdd(lhs, rhs)
-def subtract(lhs, rhs):
-    """Generic subtract operator.
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of subtract operaton.
-    """
-    return _make._OpSub(lhs, rhs)
-def multiply(lhs, rhs):
-    """Generic multiply operator.
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of multiply operaton.
-    """
-    return _make._OpMul(lhs, rhs)
-def divide(lhs, rhs):
-    """Generic divide operator.
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of divide operaton.
-    """
-    return _make._OpDiv(lhs, rhs)
-def floordiv(lhs, rhs):
-    """Generic floordiv operator.
-    Parameters
-    ----------
-    lhs : object
-        The left operand.
-    rhs : object
-        The right operand.
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of divide operaton.
-    """
-    return _make._OpFloorDiv(lhs, rhs)
-def cast(src, dtype):
-    """Generic cast operator.
-    Parameters
-    ----------
-    src : object
-        The source operand.
-    Returns
-    -------
-    op : tvm.Expr
-        The result Expr of divide operaton.
-    """
-    return _make._cast(dtype, src)
+"""Generic operators."""
+# pylint:disable=unused-wildcard-import, wildcard-import
+from .tir.generic import *
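The deleted module body now lives in tvm.tir.generic, and tvm/generic.py becomes a wildcard re-export, so both spellings resolve to the same functions. A small sketch, assuming a build with this commit:

    import tvm

    a = tvm.var("a")
    # Same numpy-style names, now served from tvm.tir.generic.
    b = tvm.generic.add(a, 1)
    c = tvm.generic.multiply(b, 2)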
@@ -17,16 +17,17 @@
 """Intrinsics of TVM-Python Hybrid Script for Python compilation time
 semantic support."""
 from tvm.ir.container import Array
+from tvm import target as _tgt
+from tvm.tir import expr as _expr
+from tvm.tir import ir_pass
+from tvm.tir import call_pure_intrin
+from tvm.tir.stmt import For
 from .. import api as _api
-from .. import expr as _expr
-from .. import make as _make
-from .. import target as _tgt
-from .. import ir_pass
-from ..stmt import For
 from .util import _internal_assert
-from ..intrin import call_pure_intrin
-#pylint: disable=redefined-builtin
+# pylint: disable=redefined-builtin
 LOOP_INTRIN = {
     'range' : For.Serial,
@@ -69,15 +70,15 @@ def bind(func_id, args):
 def _math_intrin(func_id, args):
     # pylint: disable=import-outside-toplevel
-    from .. import intrin
-    return getattr(intrin, func_id)(*args)
+    import tvm.tir.op
+    return getattr(tvm.tir.op, func_id)(*args)
 sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
 def _min_max(func_id, args):
     _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
-    return getattr(_make, func_id.title())(args[0], args[1])
+    return getattr(_expr, func_id.title())(args[0], args[1])
 min = max = _min_max #pylint: disable=invalid-name
@@ -127,7 +128,7 @@ def len(func_id, args):
 def _cast(func_id, args):
     _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \
                      "Only one expression can be cast")
-    return _make.Cast(func_id, args[0])
+    return _expr.Cast(func_id, args[0])
 float16 = float32 = float64 = _cast #pylint: disable=invalid-name
 int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
...
@@ -24,7 +24,11 @@ import types
 import numbers
 from enum import Enum
-from tvm.ir.container import Array
+from tvm.ir import Array, Range
+import tvm.tir
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
+from tvm.tir import ir_pass as _ir_pass
 from .util import _internal_assert
 from . import calls
@@ -35,12 +39,7 @@ from ..api import any as _any
 from ..tensor import Tensor, Operation
 from .. import _api_internal as _tvm_internal
-from .. import expr as _expr
-from .. import make as _make
-from .. import stmt as _stmt
 from .. import api as _api
-from .. import ir_pass as _ir_pass
 def concat_list_to_block(lst):
@@ -79,13 +78,13 @@ class Symbol(Enum):
 def _floordiv(x, y):
     if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
-        return _api.floordiv(x, y)
+        return tvm.tir.floordiv(x, y)
     return operator.floordiv(x, y)
 def _floormod(x, y):
     if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
-        return _api.floormod(x, y)
+        return tvm.tir.floormod(x, y)
     return operator.mod(x, y)
@@ -208,11 +207,11 @@ class HybridParser(ast.NodeVisitor):
         if _scope == 'global':
             body = self.wrap_up_binds(body)
-        _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
+        _domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
         _dtype = _buf.dtype
         _true = _api.convert(True)
-        body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
-        body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
+        body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body)
+        body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
         for elem in to_pop:
             self.symbols.pop(elem)
@@ -223,7 +222,7 @@ class HybridParser(ast.NodeVisitor):
     def wrap_up_binds(self, body):
         for _, iter_var in self.binds.items():
             ext = iter_var.dom.extent
-            body = _make.AttrStmt(iter_var, 'thread_extent', ext, body)
+            body = tvm.tir.AttrStmt(iter_var, 'thread_extent', ext, body)
         self.binds = {}
         return body
@@ -271,7 +270,7 @@ class HybridParser(ast.NodeVisitor):
             return entry if isinstance(node.ctx, ast.Load) else None
         if ty is Symbol.BufferVar:
             if isinstance(node.ctx, ast.Load):
-                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
+                return tvm.tir.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
                                   _expr.Call.Halide, entry.op, entry.value_index)
             return entry, [_api.const(0, 'int32')]
         # Do I need any assertion here?
@@ -304,10 +303,10 @@ class HybridParser(ast.NodeVisitor):
             args = [_api.const(0, 'int32')]
         _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
-        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
+        read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
         value = HybridParser._binop_maker[type(node.op)](read, rhs)
-        return _make.Provide(buf.op, 0, value, args)
+        return tvm.tir.Provide(buf.op, 0, value, args)
     def visit_Assign(self, node):
@@ -358,13 +357,13 @@ class HybridParser(ast.NodeVisitor):
             lhs = self.visit(lhs_)
             if lhs is not None:
                 buf, args = lhs
-                return _make.Provide(buf.op, 0, rhs, args)
+                return tvm.tir.Provide(buf.op, 0, rhs, args)
             return util.make_nop()
         lhs, args = self.visit(lhs)
         _internal_assert(isinstance(lhs, Tensor), \
                          "An array access's LHS is expected to be a expr.Call!")
-        res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
+        res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args)
         return res
@@ -391,8 +390,8 @@ class HybridParser(ast.NodeVisitor):
                 arr = arr[i.value]
             return arr
         if isinstance(node.ctx, ast.Load):
-            return _make.Call(arr.dtype, arr.name, args,
+            return tvm.tir.Call(arr.dtype, arr.name, args,
                               _expr.Call.Halide, arr.op, arr.value_index)
         return arr, args
     def visit_With(self, node):
@@ -426,14 +425,14 @@ class HybridParser(ast.NodeVisitor):
             else_body = visit_list_to_block(self.visit, node.orelse)
         else:
             else_body = None
-        return _make.IfThenElse(cond, if_body, else_body)
+        return tvm.tir.IfThenElse(cond, if_body, else_body)
     def visit_IfExp(self, node):
         cond = self.visit(node.test)
         if_body = self.visit(node.body)
         else_body = self.visit(node.orelse)
-        return _make.Select(cond, if_body, else_body)
+        return tvm.tir.Select(cond, if_body, else_body)
     def visit_Compare(self, node):
@@ -543,7 +542,7 @@ class HybridParser(ast.NodeVisitor):
         else:
             _internal_assert(not isinstance(for_type, tuple), \
                              "Micro expansion should be handled before!")
-            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
+            res = tvm.tir.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
         self.symbols.pop(_name)
         return res
@@ -580,7 +579,7 @@ class HybridParser(ast.NodeVisitor):
     def visit_Assert(self, node):
         test = self.visit(node.test)
         mesg = _api.convert(self.visit(node.msg))
-        return _make.AssertStmt(test, mesg, util.make_nop())
+        return tvm.tir.AssertStmt(test, mesg, util.make_nop())
 def parse_python(src, args, symbols, closure_vars):
...
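The parser now routes floor division and modulo through the tir op namespace instead of the top-level api module. The same helpers are available directly, assuming a build with this commit:

    import tvm

    x = tvm.var("x")
    q = tvm.tir.floordiv(x, 4)  # was _api.floordiv
    r = tvm.tir.floormod(x, 4)  # was _api.floormod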
@@ -22,12 +22,13 @@ import logging
 import sys
 import numpy
+from tvm._ffi.base import numeric_types
 from tvm.ir.container import Array
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
 from .. import api as _api
-from .. import make as _make
-from .. import expr as _expr
-from .. import stmt as _stmt
-from .._ffi.base import numeric_types
 from ..tensor import Tensor
@@ -46,7 +47,7 @@ def _internal_assert(cond, err):
 # Useful constants. In avoid of runtime dependences, we use function calls to return them.
 def make_nop():
     """Returns a 'no operation' node in HalideIR."""
-    return _make.Evaluate(_api.const(0, dtype='int32'))
+    return _stmt.Evaluate(_api.const(0, dtype='int32'))
 def is_docstring(node):
@@ -77,10 +78,10 @@ def replace_io(body, rmap):
     def replace(op):
         if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
             buf = rmap[op.func]
-            return _make.Provide(buf.op, op.value_index, op.value, op.args)
+            return _stmt.Provide(buf.op, op.value_index, op.value, op.args)
         if isinstance(op, _expr.Call) and op.func in rmap.keys():
             buf = rmap[op.func]
-            return _make.Call(buf.dtype, buf.name, op.args, \
+            return _expr.Call(buf.dtype, buf.name, op.args, \
                               _expr.Call.Halide, buf.op, buf.value_index)
         return None
...
@@ -24,7 +24,7 @@ from .type_relation import TypeCall, TypeRelation
 from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
 from .adt import Constructor, TypeData
 from .module import IRModule
-from .attrs import Attrs
+from .attrs import Attrs, make_node
 from .container import Array, Map
 from . import transform
@@ -18,6 +18,7 @@
 import tvm._ffi
 from tvm.runtime import Object
+import tvm.runtime._ffi_node_api
 from . import _ffi_api
@@ -91,3 +92,40 @@ class Attrs(Object):
     def __getitem__(self, item):
         return self.__getattr__(item)
+def make_node(type_key, **kwargs):
+    """Make a new IR node by its type key and fields
+    Parameters
+    ----------
+    type_key : str
+        The type key of the node.
+    **kwargs : dict
+        The fields of the node.
+    Returns
+    -------
+    node : Node
+        The corresponding IR Node
+    Note
+    ----
+    If the created node is an instance of AttrsNode, then
+    the creator function will also run bound checks and
+    default value setup as supported by Attrs.
+    Example
+    -------
+    The following code constructs an IntImm object
+    .. code-block:: python
+       x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
+       assert isinstance(x, tvm.tir.IntImm)
+       assert x.value == 10
+    """
+    args = [type_key]
+    for k, v in kwargs.items():
+        args += [k, v]
+    return tvm.runtime._ffi_node_api.MakeNode(*args)
@@ -53,7 +53,7 @@ class Node(Object):
         return _ffi_api.AsText(self, show_meta_data, annotate)
     def __str__(self):
-        return self.astext(show_meta_data=False)
+        return _ffi_api.PrettyPrint(self)
 @tvm._ffi.register_object("relay.SourceName")
...
@@ -99,3 +99,23 @@ class Range(Node):
     You do not need to create a Range explicitly.
     Python lists and tuples will be converted automatically to a Range in API functions.
     """
+    @staticmethod
+    def make_by_min_extent(min_value, extent):
+        """Construct a Range by min and extent.
+        This constructs a range in [min_value, min_value + extent)
+        Parameters
+        ----------
+        min_value : PrimExpr
+            The minimum value of the range.
+        extent : PrimExpr
+            The extent of the range.
+        Returns
+        -------
+        rng : Range
+            The constructed range.
+        """
+        return _ffi_api.range_by_min_extent(min_value, extent)
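Range.make_by_min_extent replaces the old tvm.make.range_by_min_extent helper. A minimal sketch, assuming a build with this commit:

    import tvm

    # Half-open range [0, 16), as the docstring above describes.
    r = tvm.ir.Range.make_by_min_extent(0, 16)
    assert r.min.value == 0 and r.extent.value == 16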
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-import
 """namespace of IR node builder make function
 This namespace is used for developers. While you do not see any declarations.
@@ -23,19 +24,22 @@ Each api is a PackedFunc that can be called in a positional argument manner.
 You can use make function to build the IR node.
 """
 import tvm._ffi
+import tvm.ir
+from tvm.ir import make_node as node
+from tvm.tir import Call
-def range_by_min_extent(min_value, extent):
+def make_by_min_extent(min_value, extent):
     """Construct a Range by min and extent.
     This constructs a range in [min_value, min_value + extent)
     Parameters
     ----------
-    min_value : Expr
+    min_value : PrimExpr
         The minimum value of the range.
-    extent : Expr
+    extent : PrimExpr
         The extent of the range.
     Returns
@@ -43,45 +47,6 @@ def range_by_min_extent(min_value, extent):
     rng : Range
         The constructed range.
     """
-    return _range_by_min_extent(min_value, extent)
+    return tvm.ir.Range.make_by_min_extent(min_value, extent)
-def node(type_key, **kwargs):
-    """Make a new DSL node by its type key and fields
-    Parameters
-    ----------
-    type_key : str
-        The type key of the node.
-    **kwargs : dict
-        The fields of the node.
-    Returns
-    -------
-    node : Node
-        The corresponding DSL Node
-    Note
-    ----
-    If the created node is instance of AttrsNode, then
-    the creator function will also run bound checks and
-    default value setup as supported by Attrs.
-    Example
-    -------
-    The following code constructs a IntImm object
-    .. code-block:: python
-       x = tvm.make.node("IntImm", dtype="int32", value=10)
-       assert isinstance(x, tvm.expr.IntImm)
-       assert x.value == 10
-    """
-    args = [type_key]
-    for k, v in kwargs.items():
-        args += [k, v]
-    return _Node(*args)
 tvm._ffi._init_api("tvm.make")
@@ -509,7 +509,7 @@ class ParseTreeToRelayIR(RelayVisitor):
             _, type_params = zip(*type_params)
         self.exit_var_scope()
-        attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None
+        attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
         return expr.Function(var_list, body, ret_type, type_params, attrs)
     @spanify
...
@@ -181,11 +181,11 @@ class VMCompiler(object):
             raise ValueError("Target is not set in env or passed as argument.")
         tgts = {}
         if isinstance(target, (str, tvm.target.Target)):
-            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
+            dev_type = tvm.tir.IntImm("int32", tvm.nd.context(str(target)).device_type)
             tgts[dev_type] = tvm.target.create(target)
         elif isinstance(target, dict):
             for dev, tgt in target.items():
-                dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
+                dev_type = tvm.tir.IntImm("int32", tvm.nd.context(dev).device_type)
                 tgts[dev_type] = tvm.target.create(tgt)
         else:
             raise TypeError("target is expected to be str, tvm.target.Target, " +
...
@@ -932,7 +932,7 @@ def _shape():
     def _impl(inputs, attr, params):
         is_symbolic_shape = False
         for axis in attr['_input_shapes'][inputs[0]]:
-            if not isinstance(axis, (int, tvm.expr.IntImm)):
+            if not isinstance(axis, (int, tvm.tir.IntImm)):
                 is_symbolic_shape = True
                 break
...
@@ -557,7 +557,7 @@ def split_shape_func(attrs, inputs, _):
     """
    Shape function for split op.
     """
-    if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)):
+    if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)):
         indices_or_sections = get_const_int(attrs.indices_or_sections)
     else:
         indices_or_sections = get_const_tuple(attrs.indices_or_sections)
...
@@ -14,126 +14,17 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-import
 """The computation schedule api of TVM."""
 import tvm._ffi
 from tvm._ffi.base import string_types
 from tvm.runtime import Object, convert
 from tvm.ir import container as _container
+from tvm.tir import expr as _expr, Buffer
 from . import _api_internal
 from . import tensor as _tensor
-from . import expr as _expr
-@tvm._ffi.register_object
-class Buffer(Object):
-    """Symbolic data buffer in TVM.
-    Buffer provide a way to represent data layout
-    specialization of data structure in TVM.
-    Do not construct directly, use :any:`decl_buffer` instead.
-    See the documentation of :any:`decl_buffer` for more details.
-    See Also
-    --------
-    decl_buffer : Declare a buffer
-    """
-    READ = 1
-    WRITE = 2
-    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
-        """Get an access pointer to the head of buffer.
-        This is the recommended method to get buffer data
-        ptress when interacting with external functions.
-        Parameters
-        ----------
-        access_mask : int
-            The access pattern MASK. Indicate whether the
-            access will read or write to the data content.
-        ptr_type : str, optional
-            The data type of the result pointer. Do not specify
-            unless we want to cast pointer to specific type.
-        content_lanes: int, optional
-            The number of lanes for the data type. This value
-            is greater than one for vector types.
-        offset: Expr, optional
-            The offset of pointer. We can use it to offset by
-            the number of elements from the address of ptr.
-        Examples
-        --------
-        .. code-block:: python
-           import tvm.schedule.Buffer
-           # Get access ptr for read
-           buffer.access_ptr("r")
-           # Get access ptr for read/write with bitmask
-           buffer.access_ptr(Buffer.READ | Buffer.WRITE)
-           # Get access ptr for read/write with str flag
-           buffer.access_ptr("rw")
-           # Get access ptr for read with offset
-           buffer.access_ptr("r", offset = 100)
-        """
-        if isinstance(access_mask, string_types):
-            mask = 0
-            for value in access_mask:
-                if value == "r":
-                    mask = mask | Buffer.READ
-                elif value == "w":
-                    mask = mask | Buffer.WRITE
-                else:
-                    raise ValueError("Unknown access_mask %s" % access_mask)
-            access_mask = mask
-        offset = convert(offset)
-        return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
-                                              content_lanes, offset)
-    def vload(self, begin, dtype=None):
-        """Generate an Expr that loads dtype from begin index.
-        Parameters
-        ----------
-        begin : Array of Expr
-            The beginning index in unit of Buffer.dtype
-        dtype : str
-            The data type to be loaded,
-            can be vector type which have lanes that is multiple of Buffer.dtype
-        Returns
-        -------
-        load : Expr
-            The corresponding load expression.
-        """
-        begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
-        dtype = dtype if dtype else self.dtype
-        return _api_internal._BufferVLoad(self, begin, dtype)
-    def vstore(self, begin, value):
-        """Generate a Stmt that store value into begin index.
-        Parameters
-        ----------
-        begin : Array of Expr
-            The beginning index in unit of Buffer.dtype
-        value : Expr
-            The value to be stored.
-        Returns
-        -------
-        store : Stmt
-            The corresponding store stmt.
-        """
-        begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
-        return _api_internal._BufferVStore(self, begin, value)
 @tvm._ffi.register_object
...
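The Buffer class removed here reappears under tvm.tir (see the new tvm/tir/buffer.py near the end of this diff) and is re-imported into schedule.py for compatibility. A minimal sketch of the relocated API, assuming a build with this commit:

    import tvm

    # Declare a symbolic buffer at its new home and take an access pointer.
    buf = tvm.tir.decl_buffer((16, 16), "float32", name="B")
    ptr = buf.access_ptr("rw")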
@@ -60,3 +60,4 @@ from .generic_func import GenericFunc
 from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
 from . import datatype
 from . import codegen
+from .intrin import register_intrin_rule
@@ -19,7 +19,7 @@ import tvm._ffi
 import tvm.runtime._ffi_api
 from tvm.runtime import convert, DataType
-from tvm.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
+from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
 def register(type_name, type_code):
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Target dependent intrinsic registration."""
import tvm._ffi
from tvm.tir import call_pure_extern
# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
"""Register an intrinsic function generation rule.
Intrinsic generation rules are callback functions for
code generator to get device specific calls.
This function simply translates to
:code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`.
TVM may already pre-register intrinsic rules in the backend.
However, user can use this function to change the intrinsic translation
behavior or add new intrinsic rules during runtime.
Parameters
----------
target : str
The name of codegen target.
intrin : str
The name of the intrinsic.
f : function, optional
The function to be registered.
override : boolean, optional
Whether override existing entry.
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
The following code registers exp expansion rule for opencl.
.. code-block:: python
register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
"""
return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
def _rule_float_suffix(op):
"""Intrinsic rule: Add float suffix if it is float32.
This is an example intrinsic generation rule.
Parameters
----------
op : PrimExpr
The call expression of original intrinsic.
Returns
-------
ret : PrimExpr
The translated intrinsic rule.
Return same op if no translation is possible.
See Also
--------
register_intrin_rule : The registration function for intrin rule.
"""
if op.dtype == "float32":
return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
if op.dtype == "float64":
return call_pure_extern(op.dtype, op.name, *op.args)
return op
def _rule_float_direct(op):
"""Intrinsic rule: Directly call pure extern function for floats.
This is an example intrinsic generation rule.
Parameters
----------
op : PrimExpr
The call expression of original intrinsic.
Returns
-------
ret : PrimExpr
The translated intrinsic rule.
Return same op if no translation is possible.
See Also
--------
register_intrin_rule : The registration function for intrin rule.
"""
if str(op.dtype).startswith("float"):
return call_pure_extern(op.dtype, op.name, *op.args)
return None
# opencl pattern for exp
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
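The two sample rules above show the callback contract: take the original intrinsic Call and return a translated PrimExpr, or the op itself when no translation applies. A hedged sketch of a user-registered rule; the "__expf" extern name is illustrative, not something this commit registers:

    import tvm
    from tvm.tir import call_pure_extern

    def my_exp_rule(op):
        # Map float32 exp onto a hypothetical device extern; otherwise keep op.
        if op.dtype == "float32":
            return call_pure_extern(op.dtype, "__expf", *op.args)
        return op

    tvm.target.register_intrin_rule("cuda", "exp", my_exp_rule, override=True)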
@@ -19,10 +19,9 @@
 import tvm._ffi
 from tvm.runtime import Object, ObjectGeneric, convert_to_object
+from tvm.tir import expr as _expr
 from . import _api_internal
-from . import make as _make
-from . import expr as _expr
 class TensorSlice(ObjectGeneric, _expr.ExprOp):
@@ -74,7 +73,7 @@ class Tensor(Object, _expr.ExprOp):
         else:
             raise ValueError("The indices must be expression")
-        return _make.Call(self.dtype, self.op.name,
+        return _expr.Call(self.dtype, self.op.name,
                           args, _expr.Call.Halide,
                           self.op, self.value_index)
@@ -207,136 +206,3 @@ class HybridOp(Operation):
     def axis(self):
         """Represent the IterVar axis, also defined when it is a HybridOp"""
         return self.__getattr__("axis")
-@tvm._ffi.register_object
-class Layout(Object):
-    """Layout is composed of upper cases, lower cases and numbers,
-    where upper case indicates a primal axis and
-    the corresponding lower case with factor size indicates the subordinate axis.
-    For example, NCHW16c can describe a 5-D tensor of
-    [batch_size, channel, height, width, channel_block].
-    Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
-    Do not construct directly, use :any:`layout` instead.
-    See the documentation of :any:`layout` for more details.
-    See Also
-    --------
-    layout : Declare a layout
-    """
-    def __len__(self):
-        return _api_internal._LayoutNdim(self)
-    def __contains__(self, axis):
-        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
-    def __getitem__(self, index):
-        if index >= len(self):
-            raise IndexError("Layout index out of range")
-        return _api_internal._LayoutGetItem(self, index)
-    def index_of(self, axis):
-        """Get the index of an axis
-        Parameters
-        ----------
-        axis : str
-            The axis name, need to be [a-z,A-Z]
-        Returns
-        -------
-        index : int
-            The index of the axis, -1 if not found.
-        """
-        return _api_internal._LayoutIndexOf(self, axis)
-    def factor_of(self, axis):
-        """Get the factor size of the subordinate axis.
-        Parameters
-        ----------
-        axis : str
-            The axis name, need to be [a-z,A-Z]
-        Returns
-        -------
-        factor : int
-            the size of the subordinate-axis of axis (if axis is a primal-axis),
-            or the size of axis itself (if axis is a subordinate-axis).
-            Return -1 if axis is not in the layout.
-        """
-        return _api_internal._LayoutFactorOf(self, axis)
-@tvm._ffi.register_object
-class BijectiveLayout(Object):
-    """Bijective mapping for two layouts (src-layout and dst-layout).
-    It provides shape and index conversion between each other.
-    Do not construct directly, use :any:`bijective_layout` instead.
-    See the documentation of :any:`bijective_layout` for more details.
-    See Also
-    --------
-    bijective_layout : Declare a bijective layout converter
-    """
-    def forward_index(self, index):
-        """Given the indices of the src-layout, infer the dst index.
-        Parameters
-        ----------
-        index: Array of Expr
-            The indices in src-layout.
-        Returns
-        -------
-        dst_index: Array of Expr
-            The inferred indices in dst-layout.
-        """
-        return _api_internal._BijectiveLayoutForwardIndex(self, index)
-    def backward_index(self, index):
-        """Given the indices of the dst-layout, infer the src index.
-        Parameters
-        ----------
-        index: Array of Expr
-            The indices in dst-layout.
-        Returns
-        -------
-        src_index: Array of Expr
-            The inferred indices in src-layout.
-        """
-        return _api_internal._BijectiveLayoutBackwardIndex(self, index)
-    def forward_shape(self, shape):
-        """Given the shape of the src-layout, infer the dst shape.
-        Parameters
-        ----------
-        shape: Array of Expr
-            The shape in src-layout.
-        Returns
-        -------
-        dst_shape: Array of Expr
-            The inferred shape in dst-layout.
-        """
-        return _api_internal._BijectiveLayoutForwardShape(self, shape)
-    def backward_shape(self, shape):
-        """Given the shape of the dst-layout, infer the src shape.
-        Parameters
-        ----------
-        shape: Array of Expr
-            The shape in dst-layout.
-        Returns
-        -------
-        src_shape: Array of Expr
-            The inferred shape in src-layout.
-        """
-        return _api_internal._BijectiveLayoutBackwardShape(self, shape)
@@ -18,11 +18,12 @@
 import tvm._ffi
 from tvm.runtime import Object
+from tvm.ir import Range
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
 from . import _api_internal
 from . import api as _api
-from . import expr as _expr
-from . import stmt as _stmt
-from . import make as _make
 from . import tensor as _tensor
 from . import schedule as _schedule
 from .build_module import current_build_config
@@ -39,7 +40,7 @@ def _get_region(tslice):
             begin = idx.var
         else:
             begin = idx
-        region.append(_make.range_by_min_extent(begin, 1))
+        region.append(Range.make_by_min_extent(begin, 1))
     return region
 @tvm._ffi.register_object
@@ -136,7 +137,7 @@ def decl_tensor_intrin(op,
     scalar_params = []
     if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)):
         body = [body]
-    body = [_make.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
+    body = [_stmt.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
     if len(body) < 3:
         body += [None] * (3 - len(body))
     return _api_internal._TensorIntrin(
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import, redefined-builtin
"""Namespace for Tensor-level IR"""
from tvm.ir import PrimExpr
from .buffer import Buffer, decl_buffer
from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, min_value, max_value
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from . import ir_builder
from . import ir_pass
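A quick sanity check of the new namespace (a hedged sketch; it assumes the constructors keep their current signatures after the move):
.. code-block:: python
    import tvm
    # tvm.tir now mirrors the C++ include/tvm/tir layout.
    x = tvm.tir.Var("x", "int32")        # moved from tvm.expr.Var
    s = tvm.tir.Select(x > 0, x, -x)     # expr nodes re-exported at the tir top level
    m = tvm.tir.min_value("int32")       # op helpers moved from tvm.intrin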
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir"""
import tvm._ffi
tvm._ffi._init_api("tir", __name__)
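For reference, _init_api("tir", __name__) attaches every packed function registered under the "tir." prefix on the C++ side to this module. A small sketch against the layout globals registered later in this commit:
.. code-block:: python
    from tvm.tir import _ffi_api
    # TVM_REGISTER_GLOBAL("tir.Layout") in data_layout.cc surfaces as _ffi_api.Layout
    layout = _ffi_api.Layout("NCHW16c")
    assert _ffi_api.LayoutNdim(layout) == 5   # N, C, H, W, 16c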
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Abstraction for array data structures."""
from numbers import Integral
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import PrimExpr
from . import _ffi_api
@tvm._ffi.register_object
class Buffer(Object):
"""Symbolic data buffer in TVM.
Buffer provides a way to represent the data layout
specialization of a data structure in TVM.
Do not construct directly, use :py:func:`~decl_buffer` instead.
See the documentation of :py:func:`decl_buffer` for more details.
See Also
--------
decl_buffer : Declare a buffer
"""
READ = 1
WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
"""Get an access pointer to the head of buffer.
This is the recommended way to get the buffer data
address when interacting with external functions.
Parameters
----------
access_mask : int
The access pattern mask. Indicates whether the
access will read or write the data content.
ptr_type : str, optional
The data type of the result pointer. Do not specify
unless you want to cast the pointer to a specific type.
content_lanes: int, optional
The number of lanes for the data type. This value
is greater than one for vector types.
offset: Expr, optional
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
Examples
--------
.. code-block:: python
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
"""
if isinstance(access_mask, string_types):
mask = 0
for value in access_mask:
if value == "r":
mask = mask | Buffer.READ
elif value == "w":
mask = mask | Buffer.WRITE
else:
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
offset = convert(offset)
return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type,
content_lanes, offset)
def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index.
Parameters
----------
begin : Array of Expr
The beginning index in unit of Buffer.dtype
dtype : str
The data type to be loaded; it can be a vector type
whose lanes are a multiple of those of Buffer.dtype.
Returns
-------
load : Expr
The corresponding load expression.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
dtype = dtype if dtype else self.dtype
return _ffi_api.BufferVLoad(self, begin, dtype)
def vstore(self, begin, value):
"""Generate a Stmt that store value into begin index.
Parameters
----------
begin : Array of Expr
The beginning index in unit of Buffer.dtype
value : Expr
The value to be stored.
Returns
-------
store : Stmt
The corresponding store stmt.
"""
begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
return _ffi_api.BufferVStore(self, begin, value)
def decl_buffer(shape,
dtype=None,
name="buffer",
data=None,
strides=None,
elem_offset=None,
scope="",
data_alignment=-1,
offset_factor=0,
buffer_type=""):
"""Declare a new symbolic buffer.
Normally buffers are created automatically during lower and build.
This is only needed if users want to specify their own buffer layout.
See the note below for detailed discussion on usage of buffer.
Parameters
----------
shape : tuple of Expr
The shape of the buffer.
dtype : str, optional
The data type of the buffer.
name : str, optional
The name of the buffer.
data : Var, optional
The data pointer in the buffer.
strides: array of Expr
The stride of the buffer.
elem_offset: Expr, optional
The beginning offset of the array to data.
In terms of number of elements of dtype.
scope: str, optional
The storage scope of the buffer, if not global.
If scope equals empty string, it means it is global memory.
data_alignment: int, optional
The alignment of data pointer in bytes.
If -1 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, optional
The factor of the elem_offset field; when set,
elem_offset is required to be a multiple of offset_factor.
If 0 is passed, the alignment will be set to 1.
If non-zero is passed, we will create a Var for elem_offset if elem_offset is None.
buffer_type: str, optional, {"", "auto_broadcast"}
auto_broadcast buffer allows one to implement broadcast computation
without considering whether the dimension size equals one.
TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
Returns
-------
buffer : Buffer
The created buffer
Example
-------
Here is an example of how an auto-broadcast buffer can be used to define a symbolic broadcast operation:
.. code-block:: python
m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
A = tvm.placeholder((m0, m1, m2), name='A')
B = tvm.placeholder((n0, n1, n2), name='B')
C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = tvm.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
fadd(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
Note
----
The Buffer data structure reflects the DLTensor structure in dlpack.
While the DLTensor data structure is very general, it is usually helpful
to create a function that only handles a specific case of data structure
so that the compiled function can benefit from it.
If strides and elem_offset are both passed as None
when constructing the function, then the function will be specialized
for DLTensors that are compact and aligned.
If a fully generic symbolic array is passed for strides,
then the resulting function becomes fully generic.
"""
# pylint: disable=import-outside-toplevel
from .expr import Var
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
dtype = "float32" if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = Var('%s_elem_offset' % name, shape_dtype)
if data is None:
data = Var(name, "handle")
return _ffi_api.Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor, buffer_type)
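A minimal usage sketch of the buffer API above; it assumes the backward-compatible top-level tvm.placeholder alias this commit keeps:
.. code-block:: python
    import tvm
    A = tvm.placeholder((16, 16), name="A")
    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", data_alignment=64)
    ptr = Ab.access_ptr("rw")          # same as Buffer.READ | Buffer.WRITE
    ld = Ab.vload((0, 0))              # scalar load at the buffer origin
    st = Ab.vstore((0, 0), ld + 1.0)   # Stmt storing back the incremented value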
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Data layout."""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
@tvm._ffi.register_object
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
See Also
--------
layout : Declare a layout
"""
def __len__(self):
return _ffi_api.LayoutNdim(self)
def __contains__(self, axis):
return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Layout index out of range")
return _ffi_api.LayoutGetItem(self, index)
def index_of(self, axis):
"""Get the index of an axis
Parameters
----------
axis : str
The axis name; must be one of [a-z, A-Z].
Returns
-------
index : int
The index of the axis, -1 if not found.
"""
return _ffi_api.LayoutIndexOf(self, axis)
def factor_of(self, axis):
"""Get the factor size of the subordinate axis.
Parameters
----------
axis : str
The axis name; must be one of [a-z, A-Z].
Returns
-------
factor : int
the size of the subordinate-axis of axis (if axis is a primal-axis),
or the size of axis itself (if axis is a subordinate-axis).
Return -1 if axis is not in the layout.
"""
return _ffi_api.LayoutFactorOf(self, axis)
@tvm._ffi.register_object
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
Do not construct directly, use :any:`bijective_layout` instead.
See the documentation of :any:`bijective_layout` for more details.
Parameters
----------
src_layout : str or Layout
source layout.
dst_layout : str or Layout
destination layout.
See Also
--------
bijective_layout : Declare a layout
"""
def forward_index(self, index):
"""Given the indices of the src-layout, infer the dst index.
Parameters
----------
index: Array of Expr
The indices in src-layout.
Returns
-------
dst_index: Array of Expr
The inferred indices in dst-layout.
"""
return _ffi_api.BijectiveLayoutForwardIndex(self, index)
def backward_index(self, index):
"""Given the indices of the dst-layout, infer the src index.
Parameters
----------
index: Array of Expr
The indices in dst-layout.
Returns
-------
src_index: Array of Expr
The inferred indices in src-layout.
"""
return _ffi_api.BijectiveLayoutBackwardIndex(self, index)
def forward_shape(self, shape):
"""Given the shape of the src-layout, infer the dst shape.
Parameters
----------
shape: Array of Expr
The shape in src-layout.
Returns
-------
dst_shape: Array of Expr
The inferred shape in dst-layout.
"""
return _ffi_api.BijectiveLayoutForwardShape(self, shape)
def backward_shape(self, shape):
"""Given the shape of the dst-layout, infer the src shape.
Parameters
----------
shape: Array of Expr
The shape in dst-layout.
Returns
-------
src_shape: Array of Expr
The inferred shape in src-layout.
"""
return _ffi_api.BijectiveLayoutBackwardShape(self, shape)
def layout(layout_str):
"""Create a layout node from a string.
Parameters
----------
layout_str : str
A layout representation is composed of uppercase letters, lowercase letters, and numbers,
where an uppercase letter indicates a primal axis and
the corresponding lowercase letter with a factor size indicates the subordinate axis.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block].
Here subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).
Returns
-------
layout : Layout
The created layout
"""
return _ffi_api.Layout(layout_str)
def bijective_layout(src_layout, dst_layout):
"""Create a bijective layout mapping.
Parameters
----------
src_layout : str or Layout
source layout.
dst_layout : str or Layout
destination layout.
Returns
-------
bijective_layout : BijectiveLayout
The created bijective layout
"""
if isinstance(src_layout, str):
src_layout = layout(src_layout)
if isinstance(dst_layout, str):
dst_layout = layout(dst_layout)
return _ffi_api.BijectiveLayout(src_layout, dst_layout)
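Round-tripping a shape through the bijection above (hedged: plain Python ints are assumed to be converted to Expr by the FFI layer):
.. code-block:: python
    import tvm
    bj = tvm.tir.bijective_layout("NCHW", "NCHW16c")
    dst = bj.forward_shape((1, 32, 7, 7))   # -> (1, 2, 7, 7, 16)
    src = bj.backward_shape(dst)            # recovers (1, 32, 7, 7)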
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Generic opertors in TVM.
We follow the numpy naming convention for this interface
(e.g., tvm.generic.multitply ~ numpy.multiply).
The default implementation is used by tvm.ExprOp.
"""
# pylint: disable=unused-argument
from . import _ffi_api
# Operator precedence used when overloading.
__op_priority__ = 0
def add(lhs, rhs):
"""Generic add operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of add operation.
"""
return _ffi_api._OpAdd(lhs, rhs)
def subtract(lhs, rhs):
"""Generic subtract operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of subtract operation.
"""
return _ffi_api._OpSub(lhs, rhs)
def multiply(lhs, rhs):
"""Generic multiply operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of multiply operation.
"""
return _ffi_api._OpMul(lhs, rhs)
def divide(lhs, rhs):
"""Generic divide operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operation.
"""
return _ffi_api._OpDiv(lhs, rhs)
def floordiv(lhs, rhs):
"""Generic floordiv operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of floordiv operation.
"""
return _ffi_api._OpFloorDiv(lhs, rhs)
def cast(src, dtype):
"""Generic cast operator.
Parameters
----------
src : object
The source operand.
dtype : str
The target data type.
Returns
-------
op : tvm.Expr
The result Expr of cast operation.
"""
return _ffi_api._cast(dtype, src)
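These are the entry points tvm.ExprOp dispatches to; a small hedged example (tvm.var is the backward-compatible alias for tvm.tir.Var):
.. code-block:: python
    import tvm
    from tvm.tir import generic
    x = tvm.var("x")
    y = generic.add(x, 1)            # the node x + 1 builds via tvm.ExprOp
    z = generic.cast(y, "float32")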
...@@ -16,15 +16,13 @@ ...@@ -16,15 +16,13 @@
# under the License. # under the License.
"""Developer API of IR node builder make function.""" """Developer API of IR node builder make function."""
from tvm._ffi.base import string_types from tvm._ffi.base import string_types
from tvm.runtime import ObjectGeneric, DataType from tvm.runtime import ObjectGeneric, DataType, convert, const
from tvm.ir import container as _container from tvm.ir import container as _container
from . import api as _api
from . import stmt as _stmt from . import stmt as _stmt
from . import expr as _expr from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass from . import ir_pass as _pass
from .expr import Call as _Call
class WithScope(object): class WithScope(object):
"""Auxiliary scope with""" """Auxiliary scope with"""
...@@ -53,7 +51,7 @@ class BufferVar(ObjectGeneric): ...@@ -53,7 +51,7 @@ class BufferVar(ObjectGeneric):
.. code-block:: python .. code-block:: python
# The following code generate IR for x[0] = x[ # The following code generate IR for x[0] = x[
ib = tvm.ir_builder.create() ib = tvm.tir.ir_builder.create()
x = ib.pointer("float32") x = ib.pointer("float32")
x[0] = x[10] + 1 x[0] = x[10] + 1
...@@ -78,19 +76,19 @@ class BufferVar(ObjectGeneric): ...@@ -78,19 +76,19 @@ class BufferVar(ObjectGeneric):
def __getitem__(self, index): def __getitem__(self, index):
t = DataType(self._content_type) t = DataType(self._content_type)
if t.lanes > 1: if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes) index = _expr.Ramp(index * t.lanes, 1, t.lanes)
return _make.Load(self._content_type, self._buffer_var, index) return _expr.Load(self._content_type, self._buffer_var, index)
def __setitem__(self, index, value): def __setitem__(self, index, value):
value = _api.convert(value) value = convert(value)
if value.dtype != self._content_type: if value.dtype != self._content_type:
raise ValueError( raise ValueError(
"data type does not match content type %s vs %s" % ( "data type does not match content type %s vs %s" % (
value.dtype, self._content_type)) value.dtype, self._content_type))
t = DataType(self._content_type) t = DataType(self._content_type)
if t.lanes > 1: if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes) index = _expr.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_make.Store(self._buffer_var, value, index)) self._builder.emit(_stmt.Store(self._buffer_var, value, index))
class IRBuilder(object): class IRBuilder(object):
...@@ -117,7 +115,7 @@ class IRBuilder(object): ...@@ -117,7 +115,7 @@ class IRBuilder(object):
"""Pop sequence from stack""" """Pop sequence from stack"""
seq = self._seq_stack.pop() seq = self._seq_stack.pop()
if not seq or callable(seq[-1]): if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0)) seq.append(_stmt.Evaluate(0))
seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x))) seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
ret_seq = [seq[-1]] ret_seq = [seq[-1]]
...@@ -138,7 +136,7 @@ class IRBuilder(object): ...@@ -138,7 +136,7 @@ class IRBuilder(object):
The statement to be emitted or callable that build stmt given body. The statement to be emitted or callable that build stmt given body.
""" """
if isinstance(stmt, _expr.Call): if isinstance(stmt, _expr.Call):
stmt = _make.Evaluate(stmt) stmt = _stmt.Evaluate(stmt)
assert isinstance(stmt, _stmt.Stmt) or callable(stmt) assert isinstance(stmt, _stmt.Stmt) or callable(stmt)
self._seq_stack[-1].append(stmt) self._seq_stack[-1].append(stmt)
...@@ -167,10 +165,10 @@ class IRBuilder(object): ...@@ -167,10 +165,10 @@ class IRBuilder(object):
x[i] = x[i - 1] + 1 x[i] = x[i - 1] + 1
""" """
if isinstance(node, string_types): if isinstance(node, string_types):
node = _make.StringImm(node) node = _expr.StringImm(node)
if isinstance(value, string_types): if isinstance(value, string_types):
value = _make.StringImm(value) value = _expr.StringImm(value)
self.emit(lambda x: _make.AttrStmt(node, attr_key, value, x)) self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
"""Create a for iteration scope. """Create a for iteration scope.
...@@ -211,7 +209,7 @@ class IRBuilder(object): ...@@ -211,7 +209,7 @@ class IRBuilder(object):
name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3) name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3)
self.nidx += 1 self.nidx += 1
self._seq_stack.append([]) self._seq_stack.append([])
loop_var = _api.var(name, dtype=dtype) loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else _pass.Simplify(end - begin) extent = end if begin == 0 else _pass.Simplify(end - begin)
def _exit_cb(): def _exit_cb():
if for_type == "serial": if for_type == "serial":
...@@ -224,7 +222,7 @@ class IRBuilder(object): ...@@ -224,7 +222,7 @@ class IRBuilder(object):
for_type_id = 3 for_type_id = 3
else: else:
raise ValueError("Unknown for_type") raise ValueError("Unknown for_type")
self.emit(_make.For( self.emit(_stmt.For(
loop_var, begin, extent, for_type_id, 0, self._pop_seq())) loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
return WithScope(loop_var, _exit_cb) return WithScope(loop_var, _exit_cb)
...@@ -253,7 +251,7 @@ class IRBuilder(object): ...@@ -253,7 +251,7 @@ class IRBuilder(object):
""" """
self._seq_stack.append([]) self._seq_stack.append([])
def _exit_cb(): def _exit_cb():
self.emit(_make.IfThenElse(cond, self._pop_seq(), None)) self.emit(_stmt.IfThenElse(cond, self._pop_seq(), None))
return WithScope(None, _exit_cb) return WithScope(None, _exit_cb)
def else_scope(self): def else_scope(self):
...@@ -286,7 +284,7 @@ class IRBuilder(object): ...@@ -286,7 +284,7 @@ class IRBuilder(object):
self._seq_stack[-1].pop() self._seq_stack[-1].pop()
self._seq_stack.append([]) self._seq_stack.append([])
def _exit_cb(): def _exit_cb():
self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq())) self.emit(_stmt.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
return WithScope(None, _exit_cb) return WithScope(None, _exit_cb)
def new_scope(self): def new_scope(self):
...@@ -326,13 +324,13 @@ class IRBuilder(object): ...@@ -326,13 +324,13 @@ class IRBuilder(object):
buffer : BufferVar buffer : BufferVar
The buffer var representing the buffer. The buffer var representing the buffer.
""" """
buffer_var = _api.var(name, dtype="handle") buffer_var = _expr.Var(name, dtype="handle")
if not isinstance(shape, (list, tuple, _container.Array)): if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape] shape = [shape]
if scope: if scope:
self.scope_attr(buffer_var, "storage_scope", scope) self.scope_attr(buffer_var, "storage_scope", scope)
self.emit(lambda x: _make.Allocate( self.emit(lambda x: _stmt.Allocate(
buffer_var, dtype, shape, _api.const(1, dtype="uint1"), x)) buffer_var, dtype, shape, const(1, dtype="uint1"), x))
return BufferVar(self, buffer_var, dtype) return BufferVar(self, buffer_var, dtype)
def pointer(self, content_type, name="ptr"): def pointer(self, content_type, name="ptr"):
...@@ -351,7 +349,7 @@ class IRBuilder(object): ...@@ -351,7 +349,7 @@ class IRBuilder(object):
ptr : BufferVar ptr : BufferVar
The buffer var representing the buffer. The buffer var representing the buffer.
""" """
buffer_var = _api.var(name, dtype="handle") buffer_var = _expr.Var(name, dtype="handle")
return BufferVar(self, buffer_var, content_type) return BufferVar(self, buffer_var, content_type)
def buffer_ptr(self, buf): def buffer_ptr(self, buf):
...@@ -380,7 +378,8 @@ class IRBuilder(object): ...@@ -380,7 +378,8 @@ class IRBuilder(object):
expr : Expr expr : Expr
The expression will likely tag. The expression will likely tag.
""" """
return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0) return _expr.Call(expr.dtype, "likely", [expr],
_expr.Call.PureIntrinsic, None, 0)
def get(self): def get(self):
"""Return the builded IR. """Return the builded IR.
......
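Putting the relocated builder together, a sketch assuming allocate keeps its (dtype, shape, name, scope) signature:
.. code-block:: python
    import tvm
    ib = tvm.tir.ir_builder.create()
    n = tvm.var("n")
    x = ib.allocate("float32", n, name="x", scope="global")
    with ib.for_range(0, n, name="i") as i:
        with ib.if_scope(i > 0):
            x[i] = x[i - 1] + 1.0
    stmt = ib.get()   # the built Stmt tree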
...@@ -25,4 +25,4 @@ You can read "include/tvm/tir/ir_pass.h" for the function signature and ...@@ -25,4 +25,4 @@ You can read "include/tvm/tir/ir_pass.h" for the function signature and
""" """
import tvm._ffi import tvm._ffi
tvm._ffi._init_api("tvm.ir_pass") tvm._ffi._init_api("tvm.ir_pass", __name__)
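The passes keep their names under the relocated module; a minimal hedged check:
.. code-block:: python
    import tvm
    n = tvm.var("n")                      # top-level alias kept for topi compatibility
    e = tvm.tir.ir_pass.Simplify(n + 0)   # the same Simplify ir_builder uses above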
...@@ -25,18 +25,19 @@ Each statement node have subfields that can be visited from python side. ...@@ -25,18 +25,19 @@ Each statement node have subfields that can be visited from python side.
x = tvm.var("n") x = tvm.var("n")
a = tvm.var("array", tvm.handle) a = tvm.var("array", tvm.handle)
st = tvm.make.Store(a, x + 1, 1) st = tvm.tir.stmt.Store(a, x + 1, 1)
assert isinstance(st, tvm.stmt.Store) assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a) assert(st.buffer_var == a)
""" """
import tvm._ffi import tvm._ffi
from tvm.runtime import Object from tvm.runtime import Object
from . import make as _make from . import _ffi_api
class Stmt(Object): class Stmt(Object):
pass """Base class of all the statements."""
@tvm._ffi.register_object @tvm._ffi.register_object
class LetStmt(Stmt): class LetStmt(Stmt):
...@@ -47,7 +48,7 @@ class LetStmt(Stmt): ...@@ -47,7 +48,7 @@ class LetStmt(Stmt):
var : Var var : Var
The variable in the binding. The variable in the binding.
value : Expr value : PrimExpr
The value in to be binded. The value to be bound.
body : Stmt body : Stmt
...@@ -55,7 +56,7 @@ class LetStmt(Stmt): ...@@ -55,7 +56,7 @@ class LetStmt(Stmt):
""" """
def __init__(self, var, value, body): def __init__(self, var, value, body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.LetStmt, var, value, body) _ffi_api.LetStmt, var, value, body)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -64,10 +65,10 @@ class AssertStmt(Stmt): ...@@ -64,10 +65,10 @@ class AssertStmt(Stmt):
Parameters Parameters
---------- ----------
condition : Expr condition : PrimExpr
The assert condition. The assert condition.
message : Expr message : PrimExpr
The error message. The error message.
body : Stmt body : Stmt
...@@ -75,7 +76,7 @@ class AssertStmt(Stmt): ...@@ -75,7 +76,7 @@ class AssertStmt(Stmt):
""" """
def __init__(self, condition, message, body): def __init__(self, condition, message, body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.AssertStmt, condition, message, body) _ffi_api.AssertStmt, condition, message, body)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -95,7 +96,7 @@ class ProducerConsumer(Stmt): ...@@ -95,7 +96,7 @@ class ProducerConsumer(Stmt):
""" """
def __init__(self, func, is_producer, body): def __init__(self, func, is_producer, body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.ProducerConsumer, func, is_producer, body) _ffi_api.ProducerConsumer, func, is_producer, body)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -107,10 +108,10 @@ class For(Stmt): ...@@ -107,10 +108,10 @@ class For(Stmt):
loop_var : Var loop_var : Var
The loop variable. The loop variable.
min_val : Expr min_val : PrimExpr
The begining value. The beginning value.
extent : Expr extent : PrimExpr
The length of the loop. The length of the loop.
for_type : int for_type : int
...@@ -134,7 +135,7 @@ class For(Stmt): ...@@ -134,7 +135,7 @@ class For(Stmt):
device_api, device_api,
body): body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.For, loop_var, min_val, extent, _ffi_api.For, loop_var, min_val, extent,
for_type, device_api, body) for_type, device_api, body)
...@@ -147,18 +148,19 @@ class Store(Stmt): ...@@ -147,18 +148,19 @@ class Store(Stmt):
buffer_var : Var buffer_var : Var
The buffer Variable. The buffer Variable.
value : Expr value : PrimExpr
The value we want to store. The value we want to store.
index : Expr index : PrimExpr
The index in the store expression. The index in the store expression.
predicate : Expr predicate : PrimExpr
The store predicate. The store predicate.
""" """
def __init__(self, buffer_var, value, index, predicate): def __init__(self, buffer_var, value, index, predicate=None):
args = [] if predicate is None else [predicate]
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Store, buffer_var, value, index, predicate) _ffi_api.Store, buffer_var, value, index, *args)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -173,7 +175,7 @@ class Provide(Stmt): ...@@ -173,7 +175,7 @@ class Provide(Stmt):
value_index : int value_index : int
The output value index The output value index
value : Expr value : PrimExpr
The value to be stored. The value to be stored.
args : list of Expr args : list of Expr
...@@ -181,7 +183,7 @@ class Provide(Stmt): ...@@ -181,7 +183,7 @@ class Provide(Stmt):
""" """
def __init__(self, func, value_index, value, args): def __init__(self, func, value_index, value, args):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Provide, func, value_index, value, args) _ffi_api.Provide, func, value_index, value, args)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -199,7 +201,7 @@ class Allocate(Stmt): ...@@ -199,7 +201,7 @@ class Allocate(Stmt):
extents : list of Expr extents : list of Expr
The extents of the allocate The extents of the allocate
condition : Expr condition : PrimExpr
The condition. The condition.
body : Stmt body : Stmt
...@@ -212,7 +214,7 @@ class Allocate(Stmt): ...@@ -212,7 +214,7 @@ class Allocate(Stmt):
condition, condition,
body): body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Allocate, buffer_var, dtype, _ffi_api.Allocate, buffer_var, dtype,
extents, condition, body) extents, condition, body)
...@@ -228,7 +230,7 @@ class AttrStmt(Stmt): ...@@ -228,7 +230,7 @@ class AttrStmt(Stmt):
attr_key : str attr_key : str
Attribute type key. Attribute type key.
value : Expr value : PrimExpr
The value of the attribute The value of the attribute
body : Stmt body : Stmt
...@@ -236,7 +238,7 @@ class AttrStmt(Stmt): ...@@ -236,7 +238,7 @@ class AttrStmt(Stmt):
""" """
def __init__(self, node, attr_key, value, body): def __init__(self, node, attr_key, value, body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.AttrStmt, node, attr_key, value, body) _ffi_api.AttrStmt, node, attr_key, value, body)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -250,7 +252,7 @@ class Free(Stmt): ...@@ -250,7 +252,7 @@ class Free(Stmt):
""" """
def __init__(self, buffer_var): def __init__(self, buffer_var):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Free, buffer_var) _ffi_api.Free, buffer_var)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -271,7 +273,7 @@ class Realize(Stmt): ...@@ -271,7 +273,7 @@ class Realize(Stmt):
bounds : list of range bounds : list of range
The bound of realize The bound of realize
condition : Expr condition : PrimExpr
The realize condition. The realize condition.
body : Stmt body : Stmt
...@@ -285,7 +287,7 @@ class Realize(Stmt): ...@@ -285,7 +287,7 @@ class Realize(Stmt):
condition, condition,
body): body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Realize, func, value_index, dtype, _ffi_api.Realize, func, value_index, dtype,
bounds, condition, body) bounds, condition, body)
...@@ -300,7 +302,7 @@ class SeqStmt(Stmt): ...@@ -300,7 +302,7 @@ class SeqStmt(Stmt):
""" """
def __init__(self, seq): def __init__(self, seq):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.SeqStmt, seq) _ffi_api.SeqStmt, seq)
def __getitem__(self, i): def __getitem__(self, i):
return self.seq[i] return self.seq[i]
...@@ -315,7 +317,7 @@ class IfThenElse(Stmt): ...@@ -315,7 +317,7 @@ class IfThenElse(Stmt):
Parameters Parameters
---------- ----------
condition : Expr condition : PrimExpr
The expression The expression
then_case : Stmt then_case : Stmt
...@@ -326,7 +328,7 @@ class IfThenElse(Stmt): ...@@ -326,7 +328,7 @@ class IfThenElse(Stmt):
""" """
def __init__(self, condition, then_case, else_case): def __init__(self, condition, then_case, else_case):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.IfThenElse, condition, then_case, else_case) _ffi_api.IfThenElse, condition, then_case, else_case)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -335,12 +337,12 @@ class Evaluate(Stmt): ...@@ -335,12 +337,12 @@ class Evaluate(Stmt):
Parameters Parameters
---------- ----------
value : Expr value : PrimExpr
The expression to be evalued. The expression to be evaluated.
""" """
def __init__(self, value): def __init__(self, value):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Evaluate, value) _ffi_api.Evaluate, value)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -363,7 +365,7 @@ class Prefetch(Stmt): ...@@ -363,7 +365,7 @@ class Prefetch(Stmt):
""" """
def __init__(self, func, value_index, dtype, bounds): def __init__(self, func, value_index, dtype, bounds):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Prefetch, func, value_index, dtype, bounds) _ffi_api.Prefetch, func, value_index, dtype, bounds)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -417,6 +419,3 @@ def stmt_list(stmt): ...@@ -417,6 +419,3 @@ def stmt_list(stmt):
if isinstance(stmt, ProducerConsumer): if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body) return stmt_list(stmt.body)
return [stmt] return [stmt]
_make.stmt_list = stmt_list
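With the constructors rerouted to _ffi_api, statement nodes can be built directly; a sketch using the stmt_seq helper exported from tvm.tir:
.. code-block:: python
    import tvm
    a = tvm.var("a")
    ev = tvm.tir.Evaluate(a + 1)
    seq = tvm.tir.stmt_seq(ev, tvm.tir.Evaluate(0))
    assert isinstance(seq, tvm.tir.SeqStmt)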
...@@ -30,50 +30,50 @@ ...@@ -30,50 +30,50 @@
namespace tvm { namespace tvm {
namespace tir { namespace tir {
TVM_REGISTER_GLOBAL("_Var") TVM_REGISTER_GLOBAL("tir.Var")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
return Var(s, t); return Var(s, t);
}); });
TVM_REGISTER_GLOBAL("_SizeVar") TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t); return SizeVar(s, t);
}); });
TVM_REGISTER_GLOBAL("make.abs") TVM_REGISTER_GLOBAL("tir.abs")
.set_body_typed(tvm::abs); .set_body_typed(tvm::abs);
TVM_REGISTER_GLOBAL("make.isnan") TVM_REGISTER_GLOBAL("tir.isnan")
.set_body_typed(tvm::isnan); .set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("make.floor") TVM_REGISTER_GLOBAL("tir.floor")
.set_body_typed(tvm::floor); .set_body_typed(tvm::floor);
TVM_REGISTER_GLOBAL("make.ceil") TVM_REGISTER_GLOBAL("tir.ceil")
.set_body_typed(tvm::ceil); .set_body_typed(tvm::ceil);
TVM_REGISTER_GLOBAL("make.round") TVM_REGISTER_GLOBAL("tir.round")
.set_body_typed(tvm::round); .set_body_typed(tvm::round);
TVM_REGISTER_GLOBAL("make.nearbyint") TVM_REGISTER_GLOBAL("tir.nearbyint")
.set_body_typed(tvm::nearbyint); .set_body_typed(tvm::nearbyint);
TVM_REGISTER_GLOBAL("make.trunc") TVM_REGISTER_GLOBAL("tir.trunc")
.set_body_typed(tvm::trunc); .set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("make._cast") TVM_REGISTER_GLOBAL("tir._cast")
.set_body_typed(tvm::cast); .set_body_typed(tvm::cast);
TVM_REGISTER_GLOBAL("make._range_by_min_extent") TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
.set_body_typed(Range::make_by_min_extent); .set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_GLOBAL("make.SeqStmt") TVM_REGISTER_GLOBAL("tir.SeqStmt")
.set_body_typed([](Array<Stmt> seq) { .set_body_typed([](Array<Stmt> seq) {
return SeqStmt(std::move(seq)); return SeqStmt(std::move(seq));
}); });
TVM_REGISTER_GLOBAL("make.For") TVM_REGISTER_GLOBAL("tir.For")
.set_body_typed([]( .set_body_typed([](
Var loop_var, PrimExpr min, PrimExpr extent, Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) { int for_type, int device_api, Stmt body) {
...@@ -85,7 +85,7 @@ TVM_REGISTER_GLOBAL("make.For") ...@@ -85,7 +85,7 @@ TVM_REGISTER_GLOBAL("make.For")
body); body);
}); });
TVM_REGISTER_GLOBAL("make.Load") TVM_REGISTER_GLOBAL("tir.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
DataType t = args[0]; DataType t = args[0];
if (args.size() == 3) { if (args.size() == 3) {
...@@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("make.Load") ...@@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("make.Load")
} }
}); });
TVM_REGISTER_GLOBAL("make.Store") TVM_REGISTER_GLOBAL("tir.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
PrimExpr value = args[1]; PrimExpr value = args[1];
if (args.size() == 3) { if (args.size() == 3) {
...@@ -105,10 +105,10 @@ TVM_REGISTER_GLOBAL("make.Store") ...@@ -105,10 +105,10 @@ TVM_REGISTER_GLOBAL("make.Store")
} }
}); });
TVM_REGISTER_GLOBAL("make.Realize") TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make); .set_body_typed(RealizeNode::make);
TVM_REGISTER_GLOBAL("make.Call") TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([]( .set_body_typed([](
DataType type, std::string name, DataType type, std::string name,
Array<PrimExpr> args, int call_type, Array<PrimExpr> args, int call_type,
...@@ -122,12 +122,12 @@ TVM_REGISTER_GLOBAL("make.Call") ...@@ -122,12 +122,12 @@ TVM_REGISTER_GLOBAL("make.Call")
value_index); value_index);
}); });
TVM_REGISTER_GLOBAL("make.CommReducer") TVM_REGISTER_GLOBAL("tir.CommReducer")
.set_body_typed(CommReducerNode::make); .set_body_typed(CommReducerNode::make);
// make from two arguments // make from two arguments
#define REGISTER_MAKE(NodeName) \ #define REGISTER_MAKE(NodeName) \
TVM_REGISTER_GLOBAL("make."#NodeName) \ TVM_REGISTER_GLOBAL("tir."#NodeName) \
.set_body_typed(NodeName ## Node::make); \ .set_body_typed(NodeName ## Node::make); \
...@@ -172,7 +172,7 @@ REGISTER_MAKE(Evaluate); ...@@ -172,7 +172,7 @@ REGISTER_MAKE(Evaluate);
// overloaded, needs special handling // overloaded, needs special handling
// has default args // has default args
TVM_REGISTER_GLOBAL("make.Allocate") TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([]( .set_body_typed([](
Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
){ ){
...@@ -180,14 +180,14 @@ TVM_REGISTER_GLOBAL("make.Allocate") ...@@ -180,14 +180,14 @@ TVM_REGISTER_GLOBAL("make.Allocate")
}); });
// operator overloading, smarter than make // operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \ #define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \ TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body_typed([](PrimExpr a, PrimExpr b) { \ .set_body_typed([](PrimExpr a, PrimExpr b) { \
return (Func(a, b)); \ return (Func(a, b)); \
}) })
#define REGISTER_MAKE_BIT_OP(Node, Func) \ #define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \ TVM_REGISTER_GLOBAL("tir."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \ bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \ bool rhs_is_int = args[1].type_code() == kDLInt; \
...@@ -228,7 +228,7 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|); ...@@ -228,7 +228,7 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>); REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("make._OpIfThenElse") TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { .set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
return if_then_else(cond, true_value, false_value); return if_then_else(cond, true_value, false_value);
}); });
......
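The renamed globals stay reachable by their full packed-function names from Python; a hedged check against the tir.abs registration above:
.. code-block:: python
    import tvm
    fabs = tvm.get_global_func("tir.abs")      # previously "make.abs"
    print(fabs(tvm.const(-1.5, "float32")))    # constant-folds to 1.5f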
...@@ -34,10 +34,10 @@ ...@@ -34,10 +34,10 @@
namespace tvm { namespace tvm {
TVM_REGISTER_GLOBAL("_min_value") TVM_REGISTER_GLOBAL("tir.min_value")
.set_body_typed(min_value); .set_body_typed(min_value);
TVM_REGISTER_GLOBAL("_max_value") TVM_REGISTER_GLOBAL("tir.max_value")
.set_body_typed(max_value); .set_body_typed(max_value);
TVM_REGISTER_GLOBAL("Range") TVM_REGISTER_GLOBAL("Range")
...@@ -49,66 +49,6 @@ TVM_REGISTER_GLOBAL("Range") ...@@ -49,66 +49,6 @@ TVM_REGISTER_GLOBAL("Range")
} }
}); });
namespace tir {
TVM_REGISTER_GLOBAL("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 10);
auto buffer_type = args[9].operator std::string();
BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
*ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
args[5], args[6], args[7], args[8], type);
});
TVM_REGISTER_GLOBAL("_BufferAccessPtr")
.set_body_method(&Buffer::access_ptr);
TVM_REGISTER_GLOBAL("_BufferVLoad")
.set_body_method(&Buffer::vload);
TVM_REGISTER_GLOBAL("_BufferVStore")
.set_body_method(&Buffer::vstore);
TVM_REGISTER_GLOBAL("_Layout")
.set_body_typed(LayoutNode::make);
TVM_REGISTER_GLOBAL("_LayoutIndexOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("_LayoutFactorOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("_LayoutNdim")
.set_body_typed([](Layout layout) -> int {
return layout.ndim();
});
TVM_REGISTER_GLOBAL("_LayoutGetItem")
.set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
TVM_REGISTER_GLOBAL("_BijectiveLayout")
.set_body_typed(BijectiveLayoutNode::make);
TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardIndex")
.set_body_method(&BijectiveLayout::ForwardIndex);
TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardIndex")
.set_body_method(&BijectiveLayout::BackwardIndex);
TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape")
.set_body_method(&BijectiveLayout::ForwardShape);
TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir
namespace te { namespace te {
TVM_REGISTER_GLOBAL("_Tensor") TVM_REGISTER_GLOBAL("_Tensor")
.set_body_typed(TensorNode::make); .set_body_typed(TensorNode::make);
......
...@@ -71,7 +71,7 @@ IntImm::IntImm(DataType dtype, int64_t value) { ...@@ -71,7 +71,7 @@ IntImm::IntImm(DataType dtype, int64_t value) {
data_ = std::move(node); data_ = std::move(node);
} }
TVM_REGISTER_GLOBAL("make.IntImm") TVM_REGISTER_GLOBAL("ir.IntImm")
.set_body_typed([](DataType dtype, int64_t value) { .set_body_typed([](DataType dtype, int64_t value) {
return IntImm(dtype, value); return IntImm(dtype, value);
}); });
...@@ -97,7 +97,7 @@ FloatImm::FloatImm(DataType dtype, double value) { ...@@ -97,7 +97,7 @@ FloatImm::FloatImm(DataType dtype, double value) {
data_ = std::move(node); data_ = std::move(node);
} }
TVM_REGISTER_GLOBAL("make.FloatImm") TVM_REGISTER_GLOBAL("ir.FloatImm")
.set_body_typed([](DataType dtype, double value) { .set_body_typed([](DataType dtype, double value) {
return FloatImm(dtype, value); return FloatImm(dtype, value);
}); });
......
...@@ -304,6 +304,6 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr") ...@@ -304,6 +304,6 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr")
TVM_REGISTER_GLOBAL("node.NodeListAttrNames") TVM_REGISTER_GLOBAL("node.NodeListAttrNames")
.set_body(NodeListAttrNames); .set_body(NodeListAttrNames);
TVM_REGISTER_GLOBAL("make._Node") TVM_REGISTER_GLOBAL("node.MakeNode")
.set_body(MakeNode); .set_body(MakeNode);
} // namespace tvm } // namespace tvm
...@@ -906,7 +906,9 @@ static const char* kSemVer = "v0.0.4"; ...@@ -906,7 +906,9 @@ static const char* kSemVer = "v0.0.4";
// - relay_text_printer.cc (specific printing logics for relay) // - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR) // - tir_text_printer.cc (specific printing logics for TIR)
std::string PrettyPrint(const ObjectRef& node) { std::string PrettyPrint(const ObjectRef& node) {
return AsText(node, false, nullptr); Doc doc;
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
} }
std::string AsText(const ObjectRef& node, std::string AsText(const ObjectRef& node,
...@@ -918,6 +920,10 @@ std::string AsText(const ObjectRef& node, ...@@ -918,6 +920,10 @@ std::string AsText(const ObjectRef& node,
return doc.str(); return doc.str();
} }
TVM_REGISTER_GLOBAL("ir.PrettyPrint")
.set_body_typed(PrettyPrint);
TVM_REGISTER_GLOBAL("ir.AsText") TVM_REGISTER_GLOBAL("ir.AsText")
.set_body_typed(AsText); .set_body_typed(AsText);
} // namespace tvm } // namespace tvm
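The newly exposed ir.PrettyPrint global prints without the metadata that AsText attaches; a hedged usage sketch on a trivial Relay function:
.. code-block:: python
    import tvm
    from tvm import relay
    pretty = tvm.get_global_func("ir.PrettyPrint")
    f = relay.Function([], relay.const(1))
    print(pretty(f))     # plain text form, no version header
    print(f.astext())    # AsText path, with annotations and header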
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
/*! /*!
* \file buffer.cc * \file buffer.cc
*/ */
#include <tvm/runtime/registry.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
...@@ -460,5 +461,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -460,5 +461,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
TVM_REGISTER_NODE_TYPE(BufferNode); TVM_REGISTER_NODE_TYPE(BufferNode);
TVM_REGISTER_GLOBAL("tir.Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 10);
auto buffer_type = args[9].operator std::string();
BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
*ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
args[5], args[6], args[7], args[8], type);
});
TVM_REGISTER_GLOBAL("tir.BufferAccessPtr")
.set_body_method(&Buffer::access_ptr);
TVM_REGISTER_GLOBAL("tir.BufferVLoad")
.set_body_method(&Buffer::vload);
TVM_REGISTER_GLOBAL("tir.BufferVStore")
.set_body_method(&Buffer::vstore);
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \file src/lang/data_layout.cc * \file src/lang/data_layout.cc
* \brief Data Layout expression. * \brief Data Layout expression.
*/ */
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h> #include <tvm/tir/data_layout.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <cctype> #include <cctype>
...@@ -371,5 +372,44 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -371,5 +372,44 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "BijectiveLayout(" << b->src_layout.name() p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")"; << "->" << b->dst_layout.name() << ")";
}); });
TVM_REGISTER_GLOBAL("tir.Layout")
.set_body_typed(LayoutNode::make);
TVM_REGISTER_GLOBAL("tir.LayoutIndexOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("tir.LayoutNdim")
.set_body_typed([](Layout layout) -> int {
return layout.ndim();
});
TVM_REGISTER_GLOBAL("tir.LayoutGetItem")
.set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
.set_body_typed(BijectiveLayoutNode::make);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
.set_body_method(&BijectiveLayout::ForwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
.set_body_method(&BijectiveLayout::BackwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
.set_body_method(&BijectiveLayout::ForwardShape);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -24,7 +24,7 @@ def test_reduce_prims(): ...@@ -24,7 +24,7 @@ def test_reduce_prims():
n = tvm.size_var('n') n = tvm.size_var('n')
m = tvm.size_var('m') m = tvm.size_var('m')
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R') R = tvm.compute((n, ), lambda i: tvm.tir.Select((i > 1), 1, 0), name='R')
k = tvm.reduce_axis((0, m)) k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B') B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
# schedule # schedule
...@@ -232,8 +232,8 @@ def test_rfactor_elemwise_threads(): ...@@ -232,8 +232,8 @@ def test_rfactor_elemwise_threads():
def test_argmax(): def test_argmax():
def fcombine(x, y): def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs return lhs, rhs
def fidentity(t0, t1): def fidentity(t0, t1):
...@@ -279,8 +279,8 @@ def test_argmax(): ...@@ -279,8 +279,8 @@ def test_argmax():
def test_rfactor_argmax(): def test_rfactor_argmax():
def fcombine(x, y): def fcombine(x, y):
lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs return lhs, rhs
def fidentity(t0, t1): def fidentity(t0, t1):
......
...@@ -82,10 +82,10 @@ def test_compile_tuple_dup(): ...@@ -82,10 +82,10 @@ def test_compile_tuple_dup():
def test_compile_full(): def test_compile_full():
# Shape calculations can happen in int64. The test checks that full operator # Shape calculations can happen in int64. The test checks that full operator
# can handle when shapes are not int32 # can handle when shapes are not int32
shape = (tvm.expr.IntImm('int32', 1), shape = (tvm.tir.IntImm('int32', 1),
tvm.expr.IntImm('int64', 16), tvm.tir.IntImm('int64', 16),
tvm.expr.IntImm('int64', 16), tvm.tir.IntImm('int64', 16),
tvm.expr.IntImm('int32', 64)) tvm.tir.IntImm('int32', 64))
output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32') output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
f = relay.Function([], output) f = relay.Function([], output)
mod = tvm.IRModule.from_expr(f) mod = tvm.IRModule.from_expr(f)
......
...@@ -41,7 +41,7 @@ def test_basic_build(): ...@@ -41,7 +41,7 @@ def test_basic_build():
} }
# build # build
targets = { targets = {
tvm.expr.IntImm("int32", ctx.device_type): tgt tvm.tir.IntImm("int32", ctx.device_type): tgt
} }
g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params) g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params)
......
...@@ -77,9 +77,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -77,9 +77,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol): def set_external_func_attr(func, compiler, ext_symbol):
func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1)) func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.expr.StringImm(compiler)) func = func.set_attribute("Compiler", tvm.tir.StringImm(compiler))
func = func.set_attribute("ExternalSymbol", tvm.expr.StringImm(ext_symbol)) func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
return func return func
......
...@@ -307,7 +307,7 @@ def get_synthetic_lib(): ...@@ -307,7 +307,7 @@ def get_synthetic_lib():
subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2, subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2,
gcc_input3], relay.copy(gcc_input0)) gcc_input3], relay.copy(gcc_input0))
subgraph0 = subgraph0.set_attribute( subgraph0 = subgraph0.set_attribute(
"Primitive", tvm.expr.IntImm("int32", 1)) "Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph0 # Call subgraph0
subgraph0_ret = relay.Call(subgraph0, [x, w0, w1, w2]) subgraph0_ret = relay.Call(subgraph0, [x, w0, w1, w2])
...@@ -320,7 +320,7 @@ def get_synthetic_lib(): ...@@ -320,7 +320,7 @@ def get_synthetic_lib():
subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6, subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6,
gcc_input7], relay.copy(gcc_input4)) gcc_input7], relay.copy(gcc_input4))
subgraph1 = subgraph1.set_attribute( subgraph1 = subgraph1.set_attribute(
"Primitive", tvm.expr.IntImm("int32", 1)) "Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph1 # Call subgraph1
subgraph1_ret = relay.Call(subgraph1, [x, w3, w4, w5]) subgraph1_ret = relay.Call(subgraph1, [x, w3, w4, w5])
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
""" test ir""" """ test ir"""
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.expr import * from tvm.tir.expr import *
from tvm.relay import op from tvm.relay import op
from tvm.relay.analysis import graph_equal from tvm.relay.analysis import graph_equal
import numpy as np import numpy as np
...@@ -110,7 +110,7 @@ def test_type_relation(): ...@@ -110,7 +110,7 @@ def test_type_relation():
num_inputs = 2 num_inputs = 2
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = relay.TypeRelation(func, args, num_inputs, attrs) tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args assert tr.args == args
......
...@@ -69,7 +69,7 @@ type List[A] { ...@@ -69,7 +69,7 @@ type List[A] {
""" """
def roundtrip(expr): def roundtrip(expr):
x = relay.fromtext(str(expr)) x = relay.fromtext(expr.astext())
assert_graph_equal(x, expr) assert_graph_equal(x, expr)
...@@ -343,7 +343,7 @@ def test_func(): ...@@ -343,7 +343,7 @@ def test_func():
# attributes # attributes
assert parses_as( assert parses_as(
"fn (n=5) { () }", "fn (n=5) { () }",
relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5))) relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
) )
......
@@ -630,8 +630,8 @@ def test_upsampling_infer_type():
    y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
    "method=\"BINLINEAR\"" in y.astext()
    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)),
-                                                tvm.expr.Cast("int32", tvm.round(w*scale))),
+    assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(h*scale)),
+                                                tvm.tir.Cast("int32", tvm.round(w*scale))),
                                                "float32")
    n, c = tvm.size_var("n"), tvm.size_var("c")
    x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
@@ -647,9 +647,9 @@ def test_upsampling3d_infer_type():
    y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
    yy = run_infer_type(y)
-    assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(d*scale)),
-                                                tvm.expr.Cast("int32", tvm.round(h*scale)),
-                                                tvm.expr.Cast("int32", tvm.round(w*scale))),
+    assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(d*scale)),
+                                                tvm.tir.Cast("int32", tvm.round(h*scale)),
+                                                tvm.tir.Cast("int32", tvm.round(w*scale))),
                                                "float32")
    n, c = tvm.size_var("n"), tvm.size_var("c")
    x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
...
@@ -517,7 +517,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
    alpha_shape = (data[axis],)
    assert zz.args[1].checked_type == relay.TensorType(alpha_shape, "float32")
-    if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or not alpha:
+    if all(isinstance(v, tvm.tir.Var) == 1 for v in data) or not alpha:
        return
    func = relay.Function([x, y], z)
...
@@ -154,7 +154,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
    out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
    assert zz.checked_type == relay.ty.TensorType(output, out_type)
-    if all(isinstance(v, tvm.expr.Var) == 1 for v in data):
+    if all(isinstance(v, tvm.tir.Var) == 1 for v in data):
        return
    func = relay.Function([x], z)
...
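Note: type checks follow the same pattern — symbolic shape variables are now canonically tvm.tir.Var. A small sketch (tvm.expr stays importable as a backward-compatibility alias at this commit, so both checks pass):

    import tvm

    n = tvm.var("n")                     # symbolic dimension
    assert isinstance(n, tvm.tir.Var)    # new canonical location
    assert isinstance(n, tvm.expr.Var)   # legacy alias re-exported for topi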
@@ -160,9 +160,9 @@ def test_type_relation_alpha_equal():
    broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")
-    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
+    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
    tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
    same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
@@ -322,7 +322,7 @@ def test_multi_node_subgraph():
    p00 = relay.subtract(z00, w01)
    q00 = relay.multiply(p00, w02)
    func0 = relay.Function([x0, w00, w01, w02], q00)
-    func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a"))
+    func0 = func0.set_attribute("FuncName", tvm.tir.StringImm("a"))
    x1 = relay.var('x1', shape=(10, 10))
    w10 = relay.var('w10', shape=(10, 10))
@@ -332,7 +332,7 @@ def test_multi_node_subgraph():
    p10 = relay.subtract(z10, w11)
    q10 = relay.multiply(p10, w12)
    func1 = relay.Function([x1, w10, w11, w12], q10)
-    func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b"))
+    func1 = func1.set_attribute("FuncName", tvm.tir.StringImm("b"))
    assert not alpha_equal(func0, func1)
@@ -413,9 +413,9 @@ def test_call_alpha_equal():
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")
-    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
+    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4))
    tt1 = relay.TensorType((1, 2, 3), "float32")
    tt2 = relay.TensorType((), "int8")
...
@@ -303,11 +303,11 @@ def test_extern_ccompiler_default_ops():
    add = x0 + y0
    # Function that uses C compiler
    func = relay.Function([x0, y0], add)
-    func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
+    func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
    func = func.set_attribute("Compiler",
-                              tvm.expr.StringImm("ccompiler"))
+                              tvm.tir.StringImm("ccompiler"))
    func = func.set_attribute("ExternalSymbol",
-                              tvm.expr.StringImm("ccompiler_0"))
+                              tvm.tir.StringImm("ccompiler_0"))
    add_call = relay.Call(func, [x, y])
    # Function that uses default compiler. Ops are fused in this function.
    p0 = relay.var("p0", shape=(8, 8))
@@ -316,7 +316,7 @@ def test_extern_ccompiler_default_ops():
    concat = relay.concatenate([log, exp], axis=0)
    fused_func = relay.Function([p0], concat)
    fused_func = fused_func.set_attribute("Primitive",
-                                          tvm.expr.IntImm("int32", 1))
+                                          tvm.tir.IntImm("int32", 1))
    fused_call = relay.Call(fused_func, [add_call])
    main = relay.Function([x, y], fused_call)
    mod = tvm.IRModule()
...
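Note: Relay function attributes are TIR immediates, so their constructors move with the namespace. A hedged sketch of the new spelling (these tests reassign the result of set_attribute, so treat it as returning a new function rather than mutating in place):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10, 10))
    func = relay.Function([x], x)
    func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
    func = func.set_attribute("Compiler", tvm.tir.StringImm("ccompiler"))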
@@ -65,7 +65,7 @@ def test_tuple_type():
def test_type_relation():
    func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast')
-    attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4))
+    attrs = tvm.ir.make_node('attrs.TestAttrs', name='attr', padding=(3,4))
    tp = TypeVar('tp')
    tf = FuncType([], TupleType([]), [], [])
    tt = TensorType([1, 2, 3], 'float32')
...
@@ -151,9 +151,9 @@ def test_reduce_combiner_simplify():
    prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
    sum_or_prod = comm_reducer(
-        lambda x, y: tvm.expr.Select(dummy < 0,
+        lambda x, y: tvm.tir.Select(dummy < 0,
                                     x + y, x*y),
-        lambda t0: tvm.expr.Select(dummy < 0,
+        lambda t0: tvm.tir.Select(dummy < 0,
                                   tvm.const(0, t0), tvm.const(1, t0)))
    sum_and_prod = comm_reducer(
        lambda x, y: (x[0] + y[0],
@@ -199,7 +199,7 @@ def test_reduce_combiner_simplify():
    assert tvm.ir_pass.Equal(lhs, rhs)
    # Test that components with side effects are not removed
-    side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
+    side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0)
    ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
              sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
    ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
@@ -211,7 +211,7 @@ def test_reduce_simplify():
    k = tvm.reduce_axis((0, 10), name="k")
    j = tvm.reduce_axis((-5, 3), name="j")
    A = tvm.placeholder((10,), name='A')
-    ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]),
+    ck.verify(tvm.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]),
              tvm.sum(k + j, [k, j]))
    ck.verify(tvm.sum(A[3], []), A[3])
    # The rule below is not typical, removed for now
@@ -235,23 +235,23 @@ def test_simplify_if_then_else():
                 tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
                 x), y)
    expected = tvm.if_then_else(
-        tvm.expr.LE(466036, (x * 4 + y)),
-        tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)),
+        tvm.tir.LE(466036, (x * 4 + y)),
+        tvm.if_then_else(tvm.tir.LE(24512, tmod(((x*4) + y) - 4, 24528)),
                         tmod(((x*4) + y) - 4, 16),
                         x), y)
    ck.verify(res, expected)
    ck.verify(res2, expected)
    # can only simplify if condition
-    res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
-    expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
+    res = tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
+    expected = tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
    ck.verify(res, ck.analyzer.canonical_simplify(expected))
-    res = tvm.expr.Select(x >= 10,
+    res = tvm.tir.Select(x >= 10,
                          tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
-    expected = tvm.expr.Select(x >= 10, x, 0)
+    expected = tvm.tir.Select(x >= 10, x, 0)
    ck.verify(res, ck.analyzer.canonical_simplify(expected))
-    res = tvm.expr.Select(x >= 10,
+    res = tvm.tir.Select(x >= 10,
                          tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
    ck.verify(res, 0)
...
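Note: tvm.tir.Select is the eager ternary used throughout these simplifier tests; unlike tvm.if_then_else it may evaluate both branches, which is why the tests pair it with branch expressions that are safe on both sides. A small sketch:

    import tvm

    x = tvm.var("x")
    # Select(condition, true_value, false_value); both values may be evaluated.
    e = tvm.tir.Select(x >= 0, x + 1, x - 1)
    print(tvm.ir_pass.Simplify(e))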
@@ -228,7 +228,7 @@ def test_select_bound():
    analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
    bd = analyzer.const_int_bound(
-        tvm.expr.Select(x > 1, (y < 0).astype("int32"), y + 1))
+        tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1))
    assert bd.min_value == 0
    assert bd.max_value == 11
...
@@ -19,7 +19,7 @@ import tvm
def assert_expr_equal(a, b):
    res = tvm.ir_pass.Simplify(a - b)
-    equal = isinstance(res, tvm.expr.IntImm) and res.value == 0
+    equal = isinstance(res, tvm.tir.IntImm) and res.value == 0
    if not equal:
        raise ValueError("{} and {} are not equal".format(a, b))
...
@@ -23,14 +23,14 @@ def test_domain_touched():
    m = tvm.var('m')
    a = tvm.placeholder((n, m), name = 'a')
    b = tvm.placeholder((n, m), name = 'b')
-    ir = tvm.make.For(
+    ir = tvm.tir.For(
            i, 0, n, 0, 0,
-            tvm.make.For(j, 0, m, 0, 0,
-                tvm.make.Provide(
+            tvm.tir.For(j, 0, m, 0, 0,
+                tvm.tir.Provide(
                    a.op,
                    0,
-                    tvm.make.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
-                    tvm.make.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
+                    tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
+                    tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
                    [i, j]
                )
            )
@@ -51,7 +51,7 @@ def test_domain_touched():
    assert a_domain_rw[0].min.value == -1
    assert a_domain_rw[0].extent.value == 101
    assert a_domain_rw[1].min.value == -1
-    assert isinstance(a_domain_rw[1].extent, tvm.expr.Add)
+    assert isinstance(a_domain_rw[1].extent, tvm.tir.Add)
    assert a_domain_rw[1].extent.a.name == 'm'
    assert a_domain_rw[1].extent.b.value == 1
...
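Note: raw IR construction moves the same way — the helpers that lived under tvm.make are now the node classes themselves under tvm.tir, with unchanged constructor signatures. A minimal sketch of the loop shape built in the hunk above:

    import tvm

    n = tvm.var("n")
    i = tvm.var("i")
    # For(loop_var, min, extent, for_type, device_api, body)
    loop = tvm.tir.For(i, 0, n, 0, 0, tvm.tir.Evaluate(0))
    assert isinstance(loop, tvm.tir.For)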
@@ -41,7 +41,7 @@ def test_vector():
    base = 10
    stride = 3
    lanes = 2
-    s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
+    s = tvm.arith.intset_vector(tvm.tir.Ramp(base, stride, lanes))
    assert s.min_value.value == base
    assert s.max_value.value == base + stride * lanes - 1
@@ -99,7 +99,7 @@ def test_max_min():
def test_select():
    ck = IntSetChecker()
    x, y = tvm.var("x"), tvm.var("y")
-    ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1),
+    ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1),
              {x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
...
@@ -84,7 +84,7 @@ def test_min_max_select():
    assert m.coeff == 3
    assert m.base == 1
-    m = analyzer.modular_set(tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2))
+    m = analyzer.modular_set(tvm.tir.Select(x > 0, x * 3 + 1, y * 9 + 2))
    assert m.coeff == 1
    assert m.base == 0
...
@@ -25,9 +25,9 @@ def test_stmt_simplify():
    with ib.if_scope(i < 12):
        A[i] = C[i]
-    body = tvm.stmt.LetStmt(n, 10, ib.get())
+    body = tvm.tir.LetStmt(n, 10, ib.get())
    body = tvm.ir_pass.CanonicalSimplify(body)
-    assert isinstance(body.body, tvm.stmt.Store)
+    assert isinstance(body.body, tvm.tir.Store)

def test_thread_extent_simplify():
@@ -42,9 +42,9 @@ def test_thread_extent_simplify():
    ib.scope_attr(ty, "thread_extent", 1)
    with ib.if_scope(tx + ty < 12):
        A[tx] = C[tx + ty]
-    body = tvm.stmt.LetStmt(n, 10, ib.get())
+    body = tvm.tir.LetStmt(n, 10, ib.get())
    body = tvm.ir_pass.CanonicalSimplify(body)
-    assert isinstance(body.body.body.body, tvm.stmt.Store)
+    assert isinstance(body.body.body.body, tvm.tir.Store)

def test_basic_likely_elimination():
...
@@ -185,19 +185,19 @@ def test_cuda_shuffle():
    def my_vectorize(stmt):
        def vectorizer(op):
-            if op.for_type == tvm.stmt.For.Vectorized:
+            if op.for_type == tvm.tir.For.Vectorized:
                four = tvm.const(4, 'int32')
-                idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
+                idx = tvm.tir.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
                all_ones = tvm.const(1, 'int32x4')
                store = op.body
                value = store.value
-                new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones)
+                new_a = tvm.tir.Load('int32x4', value.a.buffer_var, idx, all_ones)
                bs, ids = [], []
                for i in range(4):
-                    bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
+                    bs.append(tvm.tir.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
                    ids.append(tvm.const(3 - i, 'int32'))
-                new_b = tvm.make.Shuffle(bs, ids)
-                return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+                new_b = tvm.tir.Shuffle(bs, ids)
+                return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
            return None
        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
...
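Note: the vectorizer above rewrites a scalar loop into wide loads and stores; its building blocks (Ramp for the index vector, Load/Store with an all-lanes predicate) now live in tvm.tir. A reduced sketch with a hypothetical buffer variable buf:

    import tvm

    buf = tvm.var("buf", dtype="handle")             # hypothetical buffer var
    idx = tvm.tir.Ramp(tvm.const(0, "int32"), tvm.const(1, "int32"), 4)
    ones = tvm.const(1, "int32x4")                   # all-lanes-on predicate
    vec = tvm.tir.Load("int32x4", buf, idx, ones)    # one 4-lane load
    st = tvm.tir.Store(buf, vec, idx, ones)          # vectorized store back
    assert isinstance(st, tvm.tir.Store)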
@@ -29,9 +29,9 @@ def test_llvm_intrin():
        tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
        0, 3, 1
    ]
-    ib.emit(tvm.make.Evaluate(
-        tvm.make.Call(
-            "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
+    ib.emit(tvm.tir.Evaluate(
+        tvm.tir.Call(
+            "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
    fcode = tvm.build(func, None, "llvm")
@@ -643,14 +643,14 @@ def test_llvm_shuffle():
    def vectorizer(op):
        store = op.body
-        idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
+        idx = tvm.tir.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
        all_ones = tvm.const(1, 'int32x8')
        value = store.value
-        b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
-        new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
-        new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
+        b_idx = tvm.tir.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
+        new_a = tvm.tir.Load('int32x8', value.a.buffer_var, idx, all_ones)
+        new_b = tvm.tir.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
        value = new_a + new_b
-        return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+        return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
    return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
...
@@ -40,7 +40,7 @@ def test_opencl_ternary_expression():
    true_value = tvm.const(1, dtype=dtype)
    false_value = tvm.const(3, dtype=dtype)
    max_lhs = tvm.const(2, dtype=dtype)
-    max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value)
+    max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value)
    C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
    s = tvm.create_schedule(C.op)
    s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
...
@@ -26,7 +26,7 @@ def test_static_callback():
    ib = tvm.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    cp = tvm.thread_axis((0, 1), "cop")
-    finit = tvm.make.StringImm("TVMBackendRunOnce")
+    finit = tvm.tir.StringImm("TVMBackendRunOnce")
    ib.scope_attr(cp, "coproc_uop_scope", finit)
    with ib.for_range(0, n, "i", for_type="parallel") as i:
        A[i] = A[i] + 1
...
@@ -34,7 +34,7 @@ def test_stack_vm_basic():
    n = tvm.size_var('n')
    Ab = tvm.decl_buffer((n, ), tvm.float32)
-    stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
+    stmt = tvm.tir.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
    fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
    fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm")
@@ -75,7 +75,7 @@ def test_stack_vm_cond():
    ib = tvm.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    with ib.for_range(0, n - 1, "i") as i:
-        with ib.if_scope(tvm.make.EQ(i, 4)):
+        with ib.if_scope(tvm.tir.EQ(i, 4)):
            A[i + 1] = A[i] + 1
        with ib.else_scope():
            A[i + 1] = A[i] + 2
...
@@ -31,7 +31,7 @@ def test_vector_comparison():
    A = tvm.placeholder(n, dtype=dtype, name='A')
    B = tvm.compute(
        A.shape,
-        lambda i: tvm.expr.Select(
+        lambda i: tvm.tir.Select(
            A[i] >= 0, A[i] + tvm.const(1, dtype),
            tvm.const(0, dtype)), name='B')
    s = tvm.create_schedule(B.op)
...
@@ -18,7 +18,7 @@
import tvm
from ctypes import *
import topi
-import tvm.ir_pass as ir_pass
+import tvm.tir.ir_pass as ir_pass
import numpy as np
tgt = "llvm"
@@ -126,7 +126,7 @@ def test_bfloat_add_and_cast_FloatImm():
    Z = topi.cast(
        topi.add(
            topi.cast(X, dtype="custom[bfloat]16"),
-            tvm.expr.FloatImm("custom[bfloat]16", 1.5)),
+            tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
        dtype="float")
    s = tvm.create_schedule([Z.op])
...
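Note: tvm.tir.FloatImm takes the dtype string first, so custom datatypes ride along unchanged. A sketch with a plain float32 immediate (the custom[bfloat]16 form in the hunk additionally requires custom-datatype registration, omitted here):

    import tvm

    one_half = tvm.tir.FloatImm("float32", 1.5)
    assert isinstance(one_half, tvm.tir.FloatImm)
    assert abs(one_half.value - 1.5) < 1e-6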
@@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
    def tvm_val_2_py_val(val):
        val = tvm.ir_pass.Substitute(val, var_dict)
        val = tvm.ir_pass.Simplify(val)
-        assert isinstance(val, (tvm.expr.IntImm,))
+        assert isinstance(val, (tvm.tir.IntImm,))
        return val.value
    ctx = tvm.context(target, 0)
@@ -46,14 +46,14 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
            shape = [tvm_val_2_py_val(j) for j in i.shape]
            emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
            nd_args.append(tvm.nd.array(emu_args[-1], ctx))
-        elif isinstance(i, tvm.expr.Var):
+        elif isinstance(i, tvm.tir.Var):
            emu_args.append(tvm_val_2_py_val(i))
            nd_args.append(emu_args[-1])
        else:
            assert isinstance(i, list)
            emu_args.append(numpy.array(i))
-    compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
+    compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))] + \
                   (outs if isinstance(outs, list) else [outs])
    module = tvm.build(sch,
                       compile_args,
@@ -76,7 +76,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
    for nd, np in zip(out_tensors, ref_data):
        tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
-    module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))]
+    module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))]
    module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    h_module = tvm.hybrid.build(sch, module_args, module_outs)
@@ -111,32 +111,32 @@ def test_outer_product():
        return
    #Check for i in (0, n)
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
    assert ir.loop_var.name == 'i'
    assert ir.min.value == 0
    assert ir.extent.name == 'n'
    ibody = ir.body
-    assert isinstance(ibody, tvm.stmt.For)
+    assert isinstance(ibody, tvm.tir.For)
    #Check for j in (0, m)
    assert ibody.loop_var.name == 'j'
    assert ibody.min.value == 0
    assert ibody.extent.name == 'm'
    #Check loop body
    jblock = ibody.body
-    assert isinstance(jblock, tvm.stmt.SeqStmt)
+    assert isinstance(jblock, tvm.tir.SeqStmt)
    jbody = jblock[0]
-    assert isinstance(jbody, tvm.stmt.AssertStmt)
-    assert isinstance(jbody.message, tvm.expr.StringImm)
+    assert isinstance(jbody, tvm.tir.AssertStmt)
+    assert isinstance(jbody.message, tvm.tir.StringImm)
    assert jbody.message.value == "index out of range!"
    jbody = jblock[1]
-    assert isinstance(jbody, tvm.stmt.Provide)
+    assert isinstance(jbody, tvm.tir.Provide)
    assert jbody.func.name == 'c'
    assert len(jbody.args) == 2
    assert jbody.args[0].name == 'i'
    assert jbody.args[1].name == 'j'
-    assert isinstance(jbody.value, tvm.expr.Mul)
+    assert isinstance(jbody.value, tvm.tir.Mul)
    mul = jbody.value
-    assert isinstance(mul.a, tvm.expr.Call)
+    assert isinstance(mul.a, tvm.tir.Call)
    assert mul.a.name == 'a'
    assert mul.b.name == 'b'
@@ -177,21 +177,21 @@ def test_fanout():
        return
    #Check for i in (0, n-3)
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
    assert ir.loop_var.name == 'i'
    assert ir.min.value == 0
    assert tvm.ir_pass.Equal(ir.extent, n - 3)
    #Check loopbody
    ibody = ir.body
-    assert isinstance(ibody, tvm.stmt.AttrStmt)
+    assert isinstance(ibody, tvm.tir.AttrStmt)
    abody = ibody.body
-    assert isinstance(abody, tvm.stmt.Realize)
+    assert isinstance(abody, tvm.tir.Realize)
    assert abody.bounds[0].min.value == 0
    assert abody.bounds[0].extent.value == 1
    assert abody.func.name == 'sigma'
    #Check i loop body
    rbody = abody.body
-    assert isinstance(rbody[0], tvm.stmt.Provide)
+    assert isinstance(rbody[0], tvm.tir.Provide)
    assert rbody[0].func.name == 'sigma'
    assert len(rbody[0].args) == 1
    assert rbody[0].args[0].value == 0
@@ -201,13 +201,13 @@ def test_fanout():
    assert jloop.min.value == 0
    assert jloop.extent.value == 3
    jbody = jloop.body
-    assert isinstance(jbody, tvm.stmt.Provide)
+    assert isinstance(jbody, tvm.tir.Provide)
    assert len(jbody.args) == 1
    assert jbody.args[0].value == 0
    assert jbody.func.name == 'sigma'
-    assert isinstance(jbody.value, tvm.expr.Add)
+    assert isinstance(jbody.value, tvm.tir.Add)
    value = jbody.value
-    assert isinstance(value.a, tvm.expr.Call)
+    assert isinstance(value.a, tvm.tir.Call)
    assert value.a.name == 'sigma'
    assert len(value.a.args) == 1
    assert value.a.args[0].value == 0
@@ -215,17 +215,17 @@ def test_fanout():
    assert len(value.b.args) == 1
    assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
    divide = rbody[2]
-    assert isinstance(divide, tvm.stmt.Provide)
+    assert isinstance(divide, tvm.tir.Provide)
    assert len(divide.args) == 1
    assert divide.args[0].value == 0
    value = divide.value
-    assert isinstance(value, tvm.expr.Mul)
+    assert isinstance(value, tvm.tir.Mul)
    assert value.a.name == 'sigma'
    assert len(value.a.args) == 1
    assert value.a.args[0].value == 0
    assert abs(value.b.value - (1 / 3.0)) < 1e-5
    write = rbody[3]
-    assert isinstance(write, tvm.stmt.Provide)
+    assert isinstance(write, tvm.tir.Provide)
    assert write.func.name == 'b'
    assert write.value.name == 'sigma'
    assert len(write.value.args) == 1
@@ -260,9 +260,9 @@ def test_looptype():
    iloop = ir[0]
    jloop = ir[1]
    kloop = ir[2]
-    assert iloop.for_type == tvm.stmt.For.Parallel
-    assert jloop.for_type == tvm.stmt.For.Vectorized
-    assert kloop.for_type == tvm.stmt.For.Unrolled
+    assert iloop.for_type == tvm.tir.For.Parallel
+    assert jloop.for_type == tvm.tir.For.Vectorized
+    assert kloop.for_type == tvm.tir.For.Unrolled
    func, ins, outs = run_and_check(looptype, [a, b, c])
    run_and_check(func, ins, outs=outs)
@@ -364,7 +364,7 @@ def test_bind():
    c = foo(a)
    s = tvm.create_schedule(c.op)
    ir = tvm.lower(s, [a, c], simple_mode=True)
-    assert not isinstance(ir, tvm.stmt.AttrStmt)
+    assert not isinstance(ir, tvm.tir.AttrStmt)
    func, ins, outs = run_and_check(foo, [a], target='cuda')
    run_and_check(func, ins, outs=outs, target='cuda')
@@ -729,20 +729,20 @@ def test_schedule():
    sch[c].vectorize(ji)
    sch[c].reorder(ii, io, joo, joi, ji)
    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
-    assert isinstance(ir, tvm.stmt.ProducerConsumer)
+    assert isinstance(ir, tvm.tir.ProducerConsumer)
    ir = ir.body
-    assert isinstance(ir, tvm.stmt.AttrStmt)
+    assert isinstance(ir, tvm.tir.AttrStmt)
    ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
    assert ir.loop_var.name == 'i.inner'
    ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
    assert ir.loop_var.name == 'i.outer'
    ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
    assert ir.loop_var.name == 'j.outer.outer'
    ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
    assert ir.loop_var.name == 'j.outer.inner'
    ir = ir.body
    func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
@@ -752,11 +752,11 @@ def test_schedule():
    sch = tvm.create_schedule(c.op)
    sch[c].fuse(c.op.axis[0], c.op.axis[1])
    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
-    assert isinstance(ir, tvm.stmt.ProducerConsumer)
+    assert isinstance(ir, tvm.tir.ProducerConsumer)
    ir = ir.body
-    assert isinstance(ir, tvm.stmt.AttrStmt)
+    assert isinstance(ir, tvm.tir.AttrStmt)
    ir = ir.body
-    assert isinstance(ir, tvm.stmt.For)
+    assert isinstance(ir, tvm.tir.For)
    assert ir.loop_var.name == 'i.j.fused'
    func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
    run_and_check(func, ins, outs=outs)
...
@@ -28,14 +28,14 @@ def test_for():
    body = ib.get()
    print(body)
-    assert isinstance(body, tvm.stmt.AttrStmt)
+    assert isinstance(body, tvm.tir.AttrStmt)
    body = body.body
-    assert isinstance(body, tvm.stmt.Allocate)
+    assert isinstance(body, tvm.tir.Allocate)
    body = body.body
-    assert isinstance(body, tvm.stmt.For)
+    assert isinstance(body, tvm.tir.For)
    body = body.body
-    assert isinstance(body, tvm.stmt.SeqStmt)
-    assert isinstance(body[1], tvm.stmt.For)
+    assert isinstance(body, tvm.tir.SeqStmt)
+    assert isinstance(body[1], tvm.tir.For)

def test_if():
    ib = tvm.ir_builder.create()
@@ -50,11 +50,11 @@ def test_if():
    body = ib.get()
    assert A == A
-    assert isinstance(body, tvm.stmt.For)
+    assert isinstance(body, tvm.tir.For)
    body = body.body
-    assert isinstance(body, tvm.stmt.IfThenElse)
-    assert isinstance(body.condition, tvm.expr.EQ)
-    assert isinstance(body.then_case.index, tvm.expr.Var)
+    assert isinstance(body, tvm.tir.IfThenElse)
+    assert isinstance(body.condition, tvm.tir.EQ)
+    assert isinstance(body.then_case.index, tvm.tir.Var)
    assert body.else_case.index.value == 0

def test_prefetch():
@@ -64,10 +64,10 @@ def test_prefetch():
    with ib.for_range(0, n, name="i") as i:
        ib.emit(
-            tvm.make.Prefetch(
+            tvm.tir.Prefetch(
                A.op, A.value_index, A.dtype,
-                [tvm.make.range_by_min_extent(i+1, 2),
-                 tvm.make.range_by_min_extent(0, 20)]))
+                [tvm.ir.Range.make_by_min_extent(i+1, 2),
+                 tvm.ir.Range.make_by_min_extent(0, 20)]))
    body = ib.get()
    assert body.body.bounds[0].extent.value == 2
...
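Note the one exception in the hunk above: ranges are IR-level objects, so range_by_min_extent moves to tvm.ir.Range.make_by_min_extent rather than to tvm.tir. A small sketch:

    import tvm

    r = tvm.ir.Range.make_by_min_extent(1, 2)  # min=1, extent=2
    assert r.min.value == 1
    assert r.extent.value == 2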
@@ -22,7 +22,7 @@ def test_const():
    x = tvm.const(1, "int32")
    print(x.dtype)
    assert x.dtype == tvm.int32
-    assert isinstance(x, tvm.expr.IntImm)
+    assert isinstance(x, tvm.tir.IntImm)

def test_scalar_dtype_inference():
@@ -45,47 +45,47 @@ def test_make():
    x = tvm.const(1, "int32")
    y = tvm.var("x")
    z = x + y
-    assert isinstance(tvm.max(x, y), tvm.expr.Max)
-    assert isinstance(tvm.min(x, y), tvm.expr.Min)
+    assert isinstance(tvm.max(x, y), tvm.tir.Max)
+    assert isinstance(tvm.min(x, y), tvm.tir.Min)

def test_ir():
    x = tvm.const(1, "int32")
-    y = tvm.make.IntImm('int32', 1)
+    y = tvm.tir.IntImm('int32', 1)
    z = x + y
-    stmt = tvm.make.Evaluate(z)
-    assert isinstance(stmt, tvm.stmt.Evaluate)
+    stmt = tvm.tir.Evaluate(z)
+    assert isinstance(stmt, tvm.tir.Evaluate)

def test_ir2():
    x = tvm.var("n")
    a = tvm.var("array", tvm.handle)
-    st = tvm.make.Store(a, x + 1, 1)
-    assert isinstance(st, tvm.stmt.Store)
+    st = tvm.tir.Store(a, x + 1, 1)
+    assert isinstance(st, tvm.tir.Store)
    assert(st.buffer_var == a)

def test_let():
    x = tvm.var('x')
    y = tvm.var('y')
-    stmt = tvm.make.LetStmt(
-        x, 10, tvm.make.Evaluate(x + 1));
+    stmt = tvm.tir.LetStmt(
+        x, 10, tvm.tir.Evaluate(x + 1));

def test_cast():
    x = tvm.var('x', dtype="float32")
    y = x.astype("int32")
    z = x.astype("float32x4")
-    assert isinstance(y, tvm.expr.Cast)
-    assert isinstance(z, tvm.expr.Broadcast)
+    assert isinstance(y, tvm.tir.Cast)
+    assert isinstance(z, tvm.tir.Broadcast)
    assert z.lanes == 4

def test_attr():
    x = tvm.var('x')
    y = tvm.var('y')
-    stmt = tvm.make.AttrStmt(
-        y, "stride", 10, tvm.make.Evaluate(x + 1));
+    stmt = tvm.tir.AttrStmt(
+        y, "stride", 10, tvm.tir.Evaluate(x + 1));
    assert stmt.node == y
    a = tvm.convert(1)
@@ -105,9 +105,9 @@ def test_basic():

def test_stmt():
-    x = tvm.make.Evaluate(0)
-    tvm.make.For(tvm.var('i'), 0, 1,
-                 tvm.stmt.For.Serial, 0,
-                 x)
+    x = tvm.tir.Evaluate(0)
+    tvm.tir.For(tvm.var('i'), 0, 1,
+                tvm.tir.For.Serial, 0,
+                x)
@@ -207,7 +207,7 @@ def test_equality():

def test_equality_string_imm():
    x = 'a'
-    y = tvm.make.StringImm(x)
+    y = tvm.tir.StringImm(x)
    x == y.value
    x == y
...
@@ -17,50 +17,50 @@
import tvm

def test_expr_constructor():
-    x = tvm.expr.Var("xx", "float32")
-    assert isinstance(x, tvm.expr.Var)
+    x = tvm.tir.Var("xx", "float32")
+    assert isinstance(x, tvm.tir.Var)
    assert x.name == "xx"
-    x = tvm.expr.Reduce(None, [1],
+    x = tvm.tir.Reduce(None, [1],
                        [tvm.api._IterVar((0, 1), "x", 2)],
                        None, 0)
-    assert isinstance(x, tvm.expr.Reduce)
+    assert isinstance(x, tvm.tir.Reduce)
    assert x.combiner == None
    assert x.value_index == 0
-    x = tvm.expr.FloatImm("float32", 1.0)
-    assert isinstance(x, tvm.expr.FloatImm)
+    x = tvm.tir.FloatImm("float32", 1.0)
+    assert isinstance(x, tvm.tir.FloatImm)
    assert x.value == 1.0
    assert x.dtype == "float32"
-    x = tvm.expr.IntImm("int64", 2)
-    assert isinstance(x, tvm.expr.IntImm)
+    x = tvm.tir.IntImm("int64", 2)
+    assert isinstance(x, tvm.tir.IntImm)
    assert x.value == 2
    assert x.dtype == "int64"
-    x = tvm.expr.StringImm("xyza")
-    assert isinstance(x, tvm.expr.StringImm)
+    x = tvm.tir.StringImm("xyza")
+    assert isinstance(x, tvm.tir.StringImm)
    assert x.value == "xyza"
-    x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1))
-    assert isinstance(x, tvm.expr.Cast)
+    x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1))
+    assert isinstance(x, tvm.tir.Cast)
    assert x.dtype == "float32"
    assert x.value.value == 1
    a = tvm.const(1.0, dtype="float32")
    b = tvm.var("x", dtype="float32")
-    for cls in [tvm.expr.Add,
-                tvm.expr.Sub,
-                tvm.expr.Mul,
-                tvm.expr.Div,
-                tvm.expr.Mod,
-                tvm.expr.Min,
-                tvm.expr.Max,
-                tvm.expr.LT,
-                tvm.expr.LE,
-                tvm.expr.GT,
-                tvm.expr.GE]:
+    for cls in [tvm.tir.Add,
+                tvm.tir.Sub,
+                tvm.tir.Mul,
+                tvm.tir.Div,
+                tvm.tir.Mod,
+                tvm.tir.Min,
+                tvm.tir.Max,
+                tvm.tir.LT,
+                tvm.tir.LE,
+                tvm.tir.GT,
+                tvm.tir.GE]:
        x = cls(a, b)
        assert isinstance(x, cls)
        assert x.a == a
@@ -70,58 +70,58 @@ def test_expr_constructor():
    a = tvm.convert(tvm.var("x") > 1)
    b = tvm.convert(tvm.var("x") == 1)
-    for cls in [tvm.expr.And,
-                tvm.expr.Or]:
+    for cls in [tvm.tir.And,
+                tvm.tir.Or]:
        x = cls(a, b)
        assert isinstance(x, cls)
        assert x.a == a
        assert x.b.same_as(b)
-    x = tvm.expr.Not(a)
-    assert isinstance(x, tvm.expr.Not)
+    x = tvm.tir.Not(a)
+    assert isinstance(x, tvm.tir.Not)
    assert x.a == a
-    x = tvm.expr.Select(a, a, b)
-    assert isinstance(x, tvm.expr.Select)
+    x = tvm.tir.Select(a, a, b)
+    assert isinstance(x, tvm.tir.Select)
    assert x.true_value == a
    assert x.false_value == b
    assert x.condition == a
    buffer_var = tvm.var("x", dtype="handle")
-    x = tvm.expr.Load("float32", buffer_var, 1, a)
-    assert isinstance(x, tvm.expr.Load)
+    x = tvm.tir.Load("float32", buffer_var, 1, a)
+    assert isinstance(x, tvm.tir.Load)
    assert x.dtype == "float32"
    assert x.buffer_var == buffer_var
    assert x.index.value == 1
    assert x.predicate == a
-    x = tvm.expr.Ramp(1, 2, 10)
-    assert isinstance(x, tvm.expr.Ramp)
+    x = tvm.tir.Ramp(1, 2, 10)
+    assert isinstance(x, tvm.tir.Ramp)
    assert x.base.value == 1
    assert x.stride.value == 2
    assert x.lanes == 10
-    x = tvm.expr.Broadcast(a, 10)
-    assert isinstance(x, tvm.expr.Broadcast)
+    x = tvm.tir.Broadcast(a, 10)
+    assert isinstance(x, tvm.tir.Broadcast)
    assert x.value == a
    assert x.lanes == 10
-    x = tvm.expr.Shuffle([a], [0])
-    assert isinstance(x, tvm.expr.Shuffle)
+    x = tvm.tir.Shuffle([a], [0])
+    assert isinstance(x, tvm.tir.Shuffle)
    assert x.vectors[0] == a
    assert x.indices[0].value == 0
-    x = tvm.expr.Call("float32", "xyz", [a], tvm.expr.Call.Extern, None, 0)
-    assert isinstance(x, tvm.expr.Call)
+    x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern, None, 0)
+    assert isinstance(x, tvm.tir.Call)
    assert x.dtype == "float32"
    assert x.name == "xyz"
    assert x.args[0] == a
-    assert x.call_type == tvm.expr.Call.Extern
+    assert x.call_type == tvm.tir.Call.Extern
    assert x.func == None
    assert x.value_index == 0
    v = tvm.var("aa")
-    x = tvm.expr.Let(v, 1, v)
+    x = tvm.tir.Let(v, 1, v)
    assert x.var == v
    assert x.value.value == 1
    assert x.body == v
@@ -130,75 +130,75 @@ def test_expr_constructor():

def test_stmt_constructor():
    v = tvm.var("aa")
    buffer_var = tvm.var("buf", dtype="handle")
-    nop = tvm.stmt.Evaluate(1)
-    x = tvm.stmt.LetStmt(v, 1, tvm.stmt.Evaluate(1))
-    assert isinstance(x, tvm.stmt.LetStmt)
+    nop = tvm.tir.Evaluate(1)
+    x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1))
+    assert isinstance(x, tvm.tir.LetStmt)
    assert x.var == v
    assert x.value.value == 1
-    assert isinstance(x.body, tvm.stmt.Evaluate)
-    x = tvm.stmt.AttrStmt(v == 1, "xx", 1, tvm.stmt.Evaluate(1))
-    assert isinstance(x, tvm.stmt.AttrStmt)
+    assert isinstance(x.body, tvm.tir.Evaluate)
+    x = tvm.tir.AttrStmt(v == 1, "xx", 1, tvm.tir.Evaluate(1))
+    assert isinstance(x, tvm.tir.AttrStmt)
    assert x.value.value == 1
-    x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"),
+    x = tvm.tir.AssertStmt(tvm.const(1, "uint1"),
                           tvm.convert("hellow"),
                           nop)
-    assert isinstance(x, tvm.stmt.AssertStmt)
+    assert isinstance(x, tvm.tir.AssertStmt)
    assert x.body == nop
-    x = tvm.stmt.ProducerConsumer(None, True, nop)
-    assert isinstance(x, tvm.stmt.ProducerConsumer)
+    x = tvm.tir.ProducerConsumer(None, True, nop)
+    assert isinstance(x, tvm.tir.ProducerConsumer)
    assert x.body == nop
-    x = tvm.stmt.For(tvm.var("x"), 0, 10, 0, 0, nop)
-    assert isinstance(x, tvm.stmt.For)
+    x = tvm.tir.For(tvm.var("x"), 0, 10, 0, 0, nop)
+    assert isinstance(x, tvm.tir.For)
    assert x.min.value == 0
    assert x.extent.value == 10
    assert x.body == nop
-    x = tvm.stmt.Store(buffer_var, 1, 10, tvm.const(1, "uint1"))
-    assert isinstance(x, tvm.stmt.Store)
+    x = tvm.tir.Store(buffer_var, 1, 10, tvm.const(1, "uint1"))
+    assert isinstance(x, tvm.tir.Store)
    assert x.buffer_var == buffer_var
    assert x.index.value == 10
    assert x.value.value == 1
    tensor = tvm.placeholder((), dtype="float32")
-    x = tvm.stmt.Provide(tensor.op, 0, 10, [])
-    assert isinstance(x, tvm.stmt.Provide)
+    x = tvm.tir.Provide(tensor.op, 0, 10, [])
+    assert isinstance(x, tvm.tir.Provide)
    assert x.value_index == 0
    assert x.value.value == 10
-    x = tvm.stmt.Allocate(buffer_var, "float32", [10],
+    x = tvm.tir.Allocate(buffer_var, "float32", [10],
                         tvm.const(1, "uint1"), nop)
-    assert isinstance(x, tvm.stmt.Allocate)
+    assert isinstance(x, tvm.tir.Allocate)
    assert x.dtype == "float32"
    assert x.buffer_var == buffer_var
    assert x.body == nop
-    x = tvm.stmt.AttrStmt(buffer_var, "xyz", 1, nop)
-    assert isinstance(x, tvm.stmt.AttrStmt)
+    x = tvm.tir.AttrStmt(buffer_var, "xyz", 1, nop)
+    assert isinstance(x, tvm.tir.AttrStmt)
    assert x.node == buffer_var
    assert x.attr_key == "xyz"
    assert x.body == nop
-    x = tvm.stmt.Free(buffer_var)
-    assert isinstance(x, tvm.stmt.Free)
+    x = tvm.tir.Free(buffer_var)
+    assert isinstance(x, tvm.tir.Free)
    assert x.buffer_var == buffer_var
-    x = tvm.stmt.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop)
-    assert isinstance(x, tvm.stmt.Realize)
+    x = tvm.tir.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop)
+    assert isinstance(x, tvm.tir.Realize)
    assert x.body == nop
-    x = tvm.stmt.IfThenElse(tvm.const(1, "uint1"),
-                            tvm.stmt.Evaluate(11),
-                            nop)
-    assert isinstance(x, tvm.stmt.IfThenElse)
+    x = tvm.tir.IfThenElse(tvm.const(1, "uint1"),
+                           tvm.tir.Evaluate(11),
+                           nop)
+    assert isinstance(x, tvm.tir.IfThenElse)
    assert x.then_case.value.value == 11
    assert x.else_case == nop
-    x = tvm.stmt.Prefetch(None, 1, "float32", [])
-    assert isinstance(x, tvm.stmt.Prefetch)
+    x = tvm.tir.Prefetch(None, 1, "float32", [])
+    assert isinstance(x, tvm.tir.Prefetch)
    assert x.value_index == 1
...
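Note: the two constructor tests above enumerate essentially the whole node zoo; after this commit, expression and statement nodes share the single tvm.tir namespace instead of being split across tvm.expr and tvm.stmt. A compact sketch:

    import tvm

    e = tvm.tir.Add(tvm.const(1, "int32"), tvm.const(2, "int32"))  # expression node
    s = tvm.tir.Evaluate(e)                                        # statement node
    assert isinstance(e, tvm.tir.Add) and isinstance(s, tvm.tir.Evaluate)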
@@ -69,7 +69,7 @@ def test_map_save_load_json():
def test_in_container():
    arr = tvm.convert(['a', 'b', 'c'])
    assert 'a' in arr
-    assert tvm.make.StringImm('a') in arr
+    assert tvm.tir.StringImm('a') in arr
    assert 'd' not in arr

def test_ndarray_container():
...
@@ -20,9 +20,9 @@ import tvm
from topi.util import get_const_tuple

def test_layout():
-    layout = tvm.layout("NCHW16c")
+    layout = tvm.tir.layout("NCHW16c")
    assert layout is not None
-    assert isinstance(layout, tvm.tensor.Layout)
+    assert isinstance(layout, tvm.tir.Layout)
    assert layout.factor_of("c") == 16
    assert layout.factor_of("C") == 16
@@ -63,7 +63,7 @@ def test_bilayout_convertible():
def test_bilayout_shape():
    bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
-    assert isinstance(bilayout, tvm.tensor.BijectiveLayout)
+    assert isinstance(bilayout, tvm.tir.BijectiveLayout)
    dst_shape = bilayout.forward_shape((1, 32, 7, 7))
    assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
...
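Note: Layout and BijectiveLayout move out of tvm.tensor as well. A sketch (tvm.bijective_layout stays a top-level helper at this commit; only the result types are re-homed under tvm.tir):

    import tvm

    layout = tvm.tir.layout("NCHW16c")
    assert isinstance(layout, tvm.tir.Layout)
    assert layout.factor_of("c") == 16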
@@ -29,7 +29,7 @@ def test_const_fold():
    def check(f, *args):
        x = f(*[tvm.const(x, "int32") for x in args])
        y = f(*args)
-        if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y):
+        if not isinstance(x, (tvm.tir.IntImm,)) or x.value != int(y):
            raise ValueError("check error: %s vs %s " % (x, y))
    tmod = tvm.truncmod
@@ -56,7 +56,7 @@ def test_const_fold2():
    assert tmod(x, 1).value == 0
    assert (x * 1).same_as(x)
    assert (1 * x).same_as(x)
-    assert isinstance(tdiv(1, x), tvm.expr.Div)
+    assert isinstance(tdiv(1, x), tvm.tir.Div)

def test_const_fold3():
    # Test that using ints with logic operations is forbidden
@@ -92,17 +92,17 @@ def test_const_fold4():
    x1 = tvm.const(4, "int32")
    x2 = x1 + 5
    tdiv = tvm.truncdiv
-    assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
+    assert isinstance(x2, tvm.tir.IntImm) and x2.value == 9
    x3 = tdiv(x2, 3)
-    assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
+    assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3
    x4 = x3 + 0.55
-    assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
+    assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6
    x5 = tvm.ceil(x4)
-    assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
+    assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4
    x6 = x5.astype('int')
-    assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4, "x6={}".format(x6)
+    assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6)
    y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
-    assert isinstance(y, tvm.expr.IntImm) and y.value == 6
+    assert isinstance(y, tvm.tir.IntImm) and y.value == 6

def test_binary_dtype_match():
...
@@ -31,7 +31,7 @@ def test_make_smap():
    # save load json
    x = tvm.const(1, "int32")
    y = tvm.const(10, "int32")
-    z = tvm.expr.Add(x, y)
+    z = tvm.tir.Add(x, y)
    smap = tvm.convert({"z": z, "x": x})
    json_str = tvm.ir.save_json(tvm.convert([smap]))
    arr = tvm.ir.load_json(json_str)
@@ -40,11 +40,11 @@ def test_make_smap():

def test_make_node():
-    x = tvm.make.node("IntImm", dtype="int32", value=10)
-    assert isinstance(x, tvm.expr.IntImm)
+    x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
+    assert isinstance(x, tvm.tir.IntImm)
    assert x.value == 10
    A = tvm.placeholder((10, ), name='A')
-    AA = tvm.make.node("Tensor",
+    AA = tvm.ir.make_node("Tensor",
                       shape=A.shape,
                       dtype=A.dtype,
                       op=A.op,
@@ -55,25 +55,25 @@ def test_make_node():

def test_make_attrs():
    try:
-        x = tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx")
+        x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
        assert False
    except tvm.error.TVMError as e:
        assert str(e).find("unknown_key") != -1
    try:
-        x = tvm.make.node("attrs.TestAttrs", axis=100, name="xx")
+        x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
        assert False
    except tvm.error.TVMError as e:
        assert str(e).find("upper bound") != -1
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4))
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
    assert x.name == "xx"
    assert x.padding[0].value == 3
    assert x.padding[1].value == 4
    assert x.axis == 10
-    dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
+    dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
    assert dattr.x.value == 1
    datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
    assert dattr.name.value == "xyz"
@@ -104,7 +104,7 @@ def test_env_func():
    assert y(1) == 2
    assert y.func(1) == 2
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
    assert x.name == "xx"
    assert x.padding[0].value == 3
    assert x.padding[1].value == 4
...
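Note: tvm.ir.make_node validates attrs fields the same way tvm.make.node did, so the error-path tests above only change their spelling. A sketch of the failure mode:

    import tvm

    try:
        tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
        raise AssertionError("should have rejected unknown_key")
    except tvm.error.TVMError as e:
        assert "unknown_key" in str(e)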
...@@ -240,7 +240,7 @@ def test_tensor_intrin_scalar_params(): ...@@ -240,7 +240,7 @@ def test_tensor_intrin_scalar_params():
C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C") C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, C], simple_mode=True) stmt = tvm.lower(s, [A, C], simple_mode=True)
assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate) assert isinstance(stmt.body.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.body.value.args) == 5 assert len(stmt.body.body.body.value.args) == 5
assert str(stmt.body.body.body.value.args[3]) == "(i*i)" assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
assert str(stmt.body.body.body.value.args[4]) == "(i + j)" assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
......
...@@ -128,7 +128,7 @@ def test_tensor_compute1(): ...@@ -128,7 +128,7 @@ def test_tensor_compute1():
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True) stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body, tvm.stmt.Evaluate) assert isinstance(stmt.body.body, tvm.tir.Evaluate)
def test_tensor_compute2(): def test_tensor_compute2():
M = 2048 M = 2048
...@@ -171,8 +171,8 @@ def test_tensor_compute2(): ...@@ -171,8 +171,8 @@ def test_tensor_compute2():
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True) stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body.body[0], tvm.stmt.Evaluate) assert isinstance(stmt.body.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate) assert isinstance(stmt.body.body.body[1].body, tvm.tir.Evaluate)
def test_tensor_scan(): def test_tensor_scan():
m = tvm.size_var("m") m = tvm.size_var("m")
...@@ -259,7 +259,7 @@ def test_tuple_with_different_deps(): ...@@ -259,7 +259,7 @@ def test_tuple_with_different_deps():
stmt = tvm.schedule.ScheduleOps(sch, bounds) stmt = tvm.schedule.ScheduleOps(sch, bounds)
def get_B1_realize(x): def get_B1_realize(x):
if isinstance(x, tvm.stmt.Realize) and \ if isinstance(x, tvm.tir.Realize) and \
x.func == B1.op and x.value_index == 1: x.func == B1.op and x.value_index == 1:
ret.append(x) ret.append(x)
ret = [] ret = []
......
@@ -29,8 +29,8 @@ def test_operator_type_and_tags():
    B1 = B[0]
    B2 = B[0,0]
-    assert isinstance(k + n, tvm.expr.PrimExpr)
-    assert isinstance(n + n, tvm.expr.PrimExpr)
+    assert isinstance(k + n, tvm.tir.PrimExpr)
+    assert isinstance(n + n, tvm.tir.PrimExpr)
    assert isinstance(k + A, tvm.tensor.Tensor)
    assert isinstance(A + k, tvm.tensor.Tensor)
    assert isinstance(n + A, tvm.tensor.Tensor)
@@ -53,11 +53,11 @@ def test_operator_type_and_tags():
    assert (B + A).op.tag == topi.tag.BROADCAST
    assert (B + B).op.tag == topi.tag.BROADCAST
-    assert isinstance(k + B2, tvm.expr.PrimExpr)
-    assert isinstance(B2 + k, tvm.expr.PrimExpr)
-    assert isinstance(n + B2, tvm.expr.PrimExpr)
-    assert isinstance(B2 + n, tvm.expr.PrimExpr)
-    assert isinstance(B2 + B2, tvm.expr.PrimExpr)
+    assert isinstance(k + B2, tvm.tir.PrimExpr)
+    assert isinstance(B2 + k, tvm.tir.PrimExpr)
+    assert isinstance(n + B2, tvm.tir.PrimExpr)
+    assert isinstance(B2 + n, tvm.tir.PrimExpr)
+    assert isinstance(B2 + B2, tvm.tir.PrimExpr)
    assert isinstance(B2 + A, tvm.tensor.Tensor)
    assert isinstance(A + B2, tvm.tensor.Tensor)
    assert isinstance(B2 + B, tvm.tensor.Tensor)
...
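The base expression class is now exposed as tvm.tir.PrimExpr. A minimal sketch of the distinction these assertions test, using only core constructs (the topi-based broadcast cases are omitted here):

import tvm

n = tvm.size_var("n")
A = tvm.placeholder((n, n), name="A")

# Scalar arithmetic on vars, and scalar reads of a Tensor, yield PrimExpr
# nodes, now rooted at tvm.tir.PrimExpr instead of tvm.expr.PrimExpr.
assert isinstance(n + 1, tvm.tir.PrimExpr)
assert isinstance(A[0, 0] + n, tvm.tir.PrimExpr)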
@@ -17,15 +17,15 @@
import tvm

def test_attrs_equal():
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
-    y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
-    z = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4,1))
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1))
    assert tvm.ir_pass.AttrsEqual(x, y)
    assert not tvm.ir_pass.AttrsEqual(x, z)
-    dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
+    dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
    assert not tvm.ir_pass.AttrsEqual(dattr, x)
-    dattr2 = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
+    dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
    assert tvm.ir_pass.AttrsEqual(dattr, dattr2)
    assert tvm.ir_pass.AttrsEqual({"x": x}, {"x": y})
@@ -42,8 +42,8 @@ def test_attrs_equal():
def test_attrs_hash():
    fhash = tvm.ir_pass.AttrsHash
-    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
-    y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
+    y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
    assert fhash({"x": x}) == fhash({"x": y})
    assert fhash({"x": x}) != fhash({"x": [y, 1]})
    assert fhash({"x": [x, 1]}) == fhash({"x": [y, 1]})
...
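Structural attrs equality and hashing keep their tvm.ir_pass entry points at this commit; only node construction changes spelling. A standalone sketch (attrs.TestAttrs is the test-only node registered for these tests):

import tvm

x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))

# Equal field values give structural equality and matching hashes.
assert tvm.ir_pass.AttrsEqual(x, y)
assert tvm.ir_pass.AttrsHash(x) == tvm.ir_pass.AttrsHash(y)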
@@ -31,16 +31,16 @@ def test_simplify():
def test_verify_ssa():
    x = tvm.var('x')
    y = tvm.var()
-    z = tvm.make.Evaluate(x + y)
+    z = tvm.tir.Evaluate(x + y)
    assert(tvm.ir_pass.VerifySSA(z))

def test_convert_ssa():
    x = tvm.var('x')
    y = tvm.var()
-    let1 = tvm.make.Let(x, 1, x + 1)
-    let2 = tvm.make.Let(x, 1, x + y)
-    z = tvm.make.Evaluate(let1 + let2)
+    let1 = tvm.tir.Let(x, 1, x + 1)
+    let2 = tvm.tir.Let(x, 1, x + y)
+    z = tvm.tir.Evaluate(let1 + let2)
    assert(not tvm.ir_pass.VerifySSA(z))
    z_ssa = tvm.ir_pass.ConvertSSA(z)
    assert(tvm.ir_pass.VerifySSA(z_ssa))
...
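Expression and statement constructors (Let, Evaluate, ...) are now classes under tvm.tir rather than factory functions under tvm.make. A standalone sketch of the SSA round trip above:

import tvm

x = tvm.var("x")
y = tvm.var("y")

# Binding the same Var twice breaks SSA form...
z = tvm.tir.Evaluate(tvm.tir.Let(x, 1, x + 1) + tvm.tir.Let(x, 1, x + y))
assert not tvm.ir_pass.VerifySSA(z)

# ...and ConvertSSA renames the duplicate bindings so the result verifies.
assert tvm.ir_pass.VerifySSA(tvm.ir_pass.ConvertSSA(z))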
@@ -166,12 +166,12 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
def test_in_bounds_const_loop_partition_ir():
    def check_attr_stmt(x):
-        if isinstance(x, tvm.stmt.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
+        if isinstance(x, tvm.tir.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n):
            return True
        return False

    def check_branch_stmt(x):
-        if isinstance(x, tvm.stmt.IfThenElse):
+        if isinstance(x, tvm.tir.IfThenElse):
            return True
        return False
@@ -183,7 +183,7 @@ def test_in_bounds_const_loop_partition_ir():
    assert (count == nums)

    def collect_branch_stmt(x):
-        if isinstance(x, tvm.stmt.IfThenElse):
+        if isinstance(x, tvm.tir.IfThenElse):
            branch_collector.append(x)
    n = 21
...
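Checks like these are typically run over a statement tree with tvm.ir_pass.PostOrderVisit. A self-contained sketch using the renamed node classes (the loop body here is illustrative):

import tvm

ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
    with ib.if_scope(i < 2):
        ib.emit(tvm.tir.Evaluate(n))
stmt = ib.get()

# Collect every branch node; IfThenElse now lives in tvm.tir.
branches = []
def collect_branch_stmt(x):
    if isinstance(x, tvm.tir.IfThenElse):
        branches.append(x)

tvm.ir_pass.PostOrderVisit(stmt, collect_branch_stmt)
assert len(branches) == 1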
@@ -20,8 +20,8 @@ def test_for():
    dev_type = tvm.var("dev_type")
    def device_context(dev_id):
        ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id)
-        return tvm.make.Call(
-            "handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0)
+        return tvm.tir.Call(
+            "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic, None, 0)
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
...
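Call nodes are now constructed directly from the tvm.tir class, and the call-type enum moved with it (tvm.expr.Call.Intrinsic becomes tvm.tir.Call.Intrinsic). A standalone sketch with the same argument order as the hunk above (dtype, name, args, call_type, func, value_index):

import tvm

dev_type = tvm.var("dev_type")
ctx = tvm.call_extern("handle", "device_context", dev_type, 0)

# Wrap the extern call in a tvm_thread_context intrinsic call.
call = tvm.tir.Call("handle", "tvm_thread_context", [ctx],
                    tvm.tir.Call.Intrinsic, None, 0)
assert call.name == "tvm_thread_context"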
@@ -33,7 +33,7 @@ def test_decorate_device():
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt1 = tvm.ir_pass.Simplify(stmt)
    stmt2 = tvm.ir_pass.DecorateDeviceScope(stmt1)
-    assert isinstance(stmt2, tvm.stmt.AttrStmt)
+    assert isinstance(stmt2, tvm.tir.AttrStmt)
    assert stmt2.attr_key == "device_scope"
    assert stmt1 == stmt2.body
...
@@ -24,19 +24,19 @@ def verify_structure(stmt, expected_struct):
    struct = {}
    def _extract_vars(op):
        global var_list
-        if isinstance(op, tvm.expr.Var):
+        if isinstance(op, tvm.tir.Var):
            var_list.append(op.name)

    def _visit(op):
        key = op
-        if isinstance(op, tvm.stmt.IfThenElse):
+        if isinstance(op, tvm.tir.IfThenElse):
            global var_list
            tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars)
            val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))]
            var_list.clear()
-        elif isinstance(op, tvm.stmt.For):
+        elif isinstance(op, tvm.tir.For):
            val = [(op.body,), ("For", op.loop_var.name)]
-        elif isinstance(op, tvm.stmt.AttrStmt):
+        elif isinstance(op, tvm.tir.AttrStmt):
            val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))]
        else:
            return
@@ -61,9 +61,9 @@ def test_basic():
    with ib.for_range(0, m, "j") as j:
        with ib.for_range(0, n, "k") as k:
            with ib.if_scope(ib.likely(i < 2)):
-                ib.emit(tvm.make.Evaluate(m))
+                ib.emit(tvm.tir.Evaluate(m))
            with ib.else_scope():
-                ib.emit(tvm.make.Evaluate(n))
+                ib.emit(tvm.tir.Evaluate(n))
    stmt = ib.get()
    new_stmt = tvm.ir_pass.HoistIfThenElse(stmt)
@@ -82,7 +82,7 @@ def test_no_else():
    with ib.for_range(0, m, "j") as j:
        with ib.for_range(0, n, "k") as k:
            with ib.if_scope(ib.likely(i < 2)):
-                ib.emit(tvm.make.Evaluate(m))
+                ib.emit(tvm.tir.Evaluate(m))
    stmt = ib.get()
    new_stmt = tvm.ir_pass.HoistIfThenElse(stmt)
...
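A compact sketch of the hoisting flow above, building the loop nest with the IR builder and the renamed tvm.tir.Evaluate (the final check is deliberately weak, since the exact hoisted shape depends on the pass):

import tvm

ib = tvm.ir_builder.create()
m = tvm.var("m")
n = tvm.var("n")
with ib.for_range(0, n, "i") as i:
    with ib.for_range(0, m, "j") as j:
        # The condition depends only on the outer loop var, so the
        # branch is a candidate to be hoisted above the inner loop.
        with ib.if_scope(ib.likely(i < 2)):
            ib.emit(tvm.tir.Evaluate(m))
stmt = ib.get()

new_stmt = tvm.ir_pass.HoistIfThenElse(stmt)
assert isinstance(new_stmt, tvm.tir.Stmt)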
@@ -33,7 +33,7 @@ def test_copy2d():
        assert dst.strides[1].value == 1
        assert src.strides[0] == l
        assert tuple(src.shape) == (m, l)
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def test_copy_pad():
@@ -57,7 +57,7 @@ def test_copy_pad():
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def test_single_point_test():
@@ -76,7 +76,7 @@ def test_single_point_test():
        assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
        assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
        assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def assert_expr_equal(a, b):
@@ -109,7 +109,7 @@ def test_copy_pad_split():
        assert_expr_equal(pad_before[0], rpad_before)
        assert_expr_equal(pad_after[0], rpad_after)
        assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
-        return tvm.make.Evaluate(0)
+        return tvm.tir.Evaluate(0)
    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
...
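Each callback returns the canonical no-op statement after validating the copy layout; its new spelling is tvm.tir.Evaluate(0). A one-line sketch:

import tvm

# Evaluate(0) is the conventional "do nothing" statement returned by
# InjectCopyIntrin callbacks once the layout checks pass.
nop = tvm.tir.Evaluate(0)
assert isinstance(nop, tvm.tir.Stmt)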
@@ -37,13 +37,13 @@ def test_double_buffer():
    stmt = ib.get()
    stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2)
    stmt = tvm.ir_pass.Simplify(stmt)
-    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
+    assert isinstance(stmt.body.body, tvm.tir.Allocate)
    assert stmt.body.body.extents[0].value == 2
    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
    f = tvm.ir_pass.ThreadSync(f, "shared")
    count = [0]
    def count_sync(op):
-        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
+        if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
            count[0] += 1
    tvm.ir_pass.PostOrderVisit(f.body, count_sync)
    assert count[0] == 4
...
@@ -20,7 +20,7 @@ def test_inline():
    m = tvm.size_var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
-    stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
+    stmt = tvm.tir.Evaluate(T[10] + 11 * T[100])
    stmt = tvm.ir_pass.Inline(
        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
    print(stmt)
@@ -39,11 +39,11 @@ def test_inline2():
    m = tvm.size_var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
-    stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
+    stmt = tvm.tir.Evaluate(tvm.exp(T[10]) + 11 * T[100])
    stmt = tvm.ir_pass.Inline(
        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
    def check(op):
-        if isinstance(op, tvm.expr.Call):
+        if isinstance(op, tvm.tir.Call):
            assert op.func != T.op
    tvm.ir_pass.PostOrderVisit(stmt, check)
...
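The inline pass substitutes the compute body for every call into T.op, so no tvm.tir.Call referencing it should survive. A standalone sketch mirroring the flow above:

import tvm

m = tvm.size_var("m")
A = tvm.placeholder((m,), name="A")
T = tvm.compute((m,), lambda i: A[i] + 10, name="T")

stmt = tvm.tir.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])

# After inlining, no call into T.op should remain in the statement.
def check(op):
    if isinstance(op, tvm.tir.Call):
        assert op.func != T.op

tvm.ir_pass.PostOrderVisit(stmt, check)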
@@ -32,12 +32,12 @@ def test_ir_transform():
        return None

    def postorder(op):
-        assert isinstance(op, tvm.expr.Call)
+        assert isinstance(op, tvm.tir.Call)
        if op.name == "TestA":
            return tvm.call_extern("int32", "TestB", op.args[0] + 1)
        return op
    body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
-    stmt_list = tvm.make.stmt_list(body.body.body)
+    stmt_list = tvm.tir.stmt_list(body.body.body)
    assert stmt_list[0].value.args[0].name == "TestB"
    assert stmt_list[1].value.value == 0
...
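IRTransform rewrites only the node kinds listed in its final argument. A self-contained sketch of the same rename over a single Evaluate statement (the extern names TestA/TestB are illustrative, as in the test):

import tvm

body = tvm.tir.Evaluate(tvm.call_extern("int32", "TestA", 1))

def preorder(op):
    return None

def postorder(op):
    # Only "Call" nodes are visited, per the filter below.
    if op.name == "TestA":
        return tvm.call_extern("int32", "TestB", op.args[0] + 1)
    return op

body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
assert body.value.name == "TestB"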
@@ -20,7 +20,7 @@ def test_coproc_lift():
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    cp = tvm.thread_axis((0, 1), "cop")
-    value = tvm.make.StringImm("xxx")
+    value = tvm.tir.StringImm("xxx")
    A = ib.allocate("float32", n, name="A", scope="global")
    with ib.for_range(0, n, name="i") as i:
...
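String immediates follow the same move. A minimal sketch:

import tvm

# tvm.make.StringImm -> tvm.tir.StringImm; the payload is on .value.
value = tvm.tir.StringImm("xxx")
assert value.value == "xxx"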