Unverified Commit 08338dd5 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] Establish tvm.te and tvm.driver (#4900)

- Move the related files to tvm.te
- Move build_module.py to tvm.driver
parent 27a02844
......@@ -47,25 +47,30 @@ from . import tir
# tvm.target
from . import target
from .target import build_config
# others
from . import tensor
from . import arith
from . import make
from . import schedule
from . import hybrid
# tvm.te
from .te import decl_tensor_intrin, create_schedule, tag_scope
# tvm.testing
from . import testing
from .api import *
from .tensor_intrin import decl_tensor_intrin
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
# tvm.driver
from .driver import build, lower
# tvm.hybrid
from . import hybrid
# others
from . import arith
# backward compact for topi, to be removed later
from .api import *
from .tir import expr, stmt, ir_builder, ir_pass, generic
from .te import tensor, schedule
from .tir.op import *
from . import intrin
from . import make
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
......
......@@ -18,17 +18,16 @@
import tvm._ffi
from tvm.runtime import Object
from . import _api_internal
class IntSet(Object):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _api_internal._IntSetIsNothing(self)
return _IntSetIsNothing(self)
def is_everything(self):
"""Whether the set represent everything"""
return _api_internal._IntSetIsEverything(self)
return _IntSetIsEverything(self)
@tvm._ffi.register_object("arith.IntervalSet")
......
......@@ -29,7 +29,8 @@ There are two types of feature
import struct
import numpy as np
from tvm import schedule, ir_pass, build_module, get_global_func, target as _target
from tvm import schedule, ir_pass, get_global_func, target as _target
from tvm.driver import build_module
def ana_lower(sch, args,
binds=None,
......
......@@ -26,8 +26,9 @@ tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
import tvm.te._ffi_api
from ... import _api_internal, tensor, placeholder
from ... import tensor, placeholder
from .task import args_to_workload, dispatcher, register
from ..util import get_const_tuple
......@@ -420,10 +421,10 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None, o
attrs[k] = v
attrs['workload'] = args_to_workload(args, topi_compute)
if isinstance(op, tensor.ComputeOp):
op = _api_internal._ComputeOp(
op = tvm.te._ffi_api.ComputeOp(
op.name, op.tag, attrs, op.axis, op.body)
elif isinstance(op, tensor.ExternOp):
op = _api_internal._ExternOp(
op = tvm.te._ffi_api.ExternOp(
op.name, op.tag, attrs,
op.inputs, op.input_placeholders,
op.output_placeholders, op.body)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for driver APIs"""
from .build_module import lower, build
......@@ -21,7 +21,7 @@ See the example sections for for suggested message conventions.
To make the code more readable, we recommended developers to
copy the examples and raise errors with the same message convention.
"""
from ._ffi.base import register_error, TVMError
from tvm._ffi.base import register_error, TVMError
@register_error
class InternalError(TVMError):
......
......@@ -30,9 +30,9 @@ HalideIR.
# 2. Support multi-level HalideIR
import inspect
import tvm._ffi
from tvm.driver.build_module import form_body
from .._ffi.base import decorate
from ..build_module import form_body
from .module import HybridModule
from .parser import source_to_op
......
......@@ -26,19 +26,20 @@ import numbers
from enum import Enum
from tvm.ir import Array, Range
import tvm.tir
import tvm.te._ffi_api
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.tir import ir_pass as _ir_pass
from tvm.te.tensor import Tensor, Operation
from tvm.tir import all as _all
from tvm.tir import any as _any
from .util import _internal_assert
from . import calls
from . import util
from .preprocessor import determine_variable_usage
from ..api import all as _all
from ..api import any as _any
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import api as _api
......@@ -653,7 +654,7 @@ def source_to_op(src, args, symbols, closure_vars):
for i in args:
get_input_tensors(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
op = tvm.te._ffi_api.HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
......@@ -27,9 +27,9 @@ from tvm.ir.container import Array
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.te.tensor import Tensor
from .. import api as _api
from ..tensor import Tensor
#pylint: disable=invalid-name
......
......@@ -17,10 +17,10 @@
"""Common expressions data structures in the IR."""
import tvm._ffi
from .base import Node
from . import _ffi_api
class BaseExpr(Node):
"""Base class of all the expressions."""
......@@ -98,7 +98,29 @@ class Range(Node):
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
Parameters
----------
begin : PrimExpr
The begin value of the range when end is None.
Otherwise it is the length of the range.
end : Optional[PrimExpr]
The end value of the range.
Note
----
The constructor creates the range `[begin, end)`
if the end argument is not None. Otherwise, it creates `[0, begin)`.
"""
def __init__(self, begin, end=None):
if end is None:
self.__init_handle_by_constructor__(
_ffi_api.Range, 0, begin)
else:
self.__init_handle_by_constructor__(
_ffi_api.Range, begin, end)
@staticmethod
def make_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
......
......@@ -16,10 +16,9 @@
# under the License.
"""The interface of expr function exposed from C++."""
import tvm._ffi
import tvm.driver
from tvm.ir import container as _container
from ... import build_module as _build
@tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
......@@ -48,7 +47,7 @@ def lower(sch, inputs, func_name, source_func):
import traceback
try:
f = _build.lower(sch, inputs, name=func_name)
f = tvm.driver.lower(sch, inputs, name=func_name)
# logging.debug("lower function %s", func_name)
# logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
......@@ -85,7 +84,7 @@ def build(funcs, target, target_host=None):
"""
if target_host == "":
target_host = None
return _build.build(funcs, target=target, target_host=target_host)
return tvm.driver.build(funcs, target=target, target_host=target_host)
@tvm._ffi.register_func("relay._tensor_value_repr")
......
......@@ -18,11 +18,11 @@
"""The base node types for the Relay language."""
import topi
import tvm._ffi
from tvm.driver import lower, build
from ..base import register_relay_node
from ..expr import RelayExpr
from ...api import register_func
from ...build_module import lower, build
from . import _make
@register_relay_node
......
......@@ -20,6 +20,7 @@ import logging
import multiprocessing as mp
import numpy as np
import tvm
import tvm.driver
from tvm.ir import IRModule
from . import _quantize
......
......@@ -61,3 +61,4 @@ from .generic_func import generic_func, get_native_generic_func, override_native
from . import datatype
from . import codegen
from .intrin import register_intrin_rule
from .build_config import BuildConfig, build_config
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import, redefined-builtin, wildcard-import
"""Namespace for Tensor-level IR"""
# expose all operators in tvm tir.op
from tvm.tir.op import *
from .schedule import Schedule, create_schedule
from .tensor import TensorSlice, Tensor
from .tensor_intrin import decl_tensor_intrin
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var
from .operation import thread_axis, reduce_axis
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.te"""
import tvm._ffi
tvm._ffi._init_api("te", __name__)
......@@ -16,7 +16,7 @@
# under the License.
"""Tag class for TVM operators."""
import warnings
from ._ffi.base import decorate
from tvm._ffi.base import decorate
class TagScope(object):
"""Tag scope object to set tag for operators, working as context
......
......@@ -14,15 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor and Operation class for computation declaration."""
"""Tensor class for computation declaration."""
# pylint: disable=invalid-name
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr
from . import _api_internal
from . import _ffi_api
class TensorSlice(ObjectGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor."""
......@@ -52,9 +51,6 @@ class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic."""
itervar_cls = None
@tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
......@@ -68,7 +64,7 @@ class Tensor(Object, _expr.ExprOp):
for x in indices:
if isinstance(x, _expr.PrimExpr):
args.append(x)
elif isinstance(x, iter_var_cls):
elif isinstance(x, _expr.IterVar):
args.append(x.var)
else:
raise ValueError("The indices must be expression")
......@@ -81,7 +77,7 @@ class Tensor(Object, _expr.ExprOp):
return TensorSlice(self, indices)
def __hash__(self):
return _api_internal._TensorHash(self)
return _ffi_api.TensorHash(self)
def __eq__(self, other):
if not isinstance(other, Tensor):
......@@ -92,7 +88,7 @@ class Tensor(Object, _expr.ExprOp):
raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
"use Tensor.equal for content expression equvalence, "
"use Tensor.same_as for exact reference comparison")
return _api_internal._TensorEqual(self, other)
return _ffi_api.TensorEqual(self, other)
@property
def ndim(self):
......@@ -143,17 +139,17 @@ class Operation(Object):
out : Tensor
The i-th output.
"""
return _api_internal._OpGetOutput(self, index)
return _ffi_api.OpGetOutput(self, index)
@property
def num_outputs(self):
"""Number of outputs from this op."""
return _api_internal._OpNumOutputs(self)
return _ffi_api.OpNumOutputs(self)
@property
def input_tensors(self):
"""List of input tensors to this op."""
return _api_internal._OpInputTensors(self)
return _ffi_api.OpInputTensors(self)
@tvm._ffi.register_object
......
......@@ -16,17 +16,15 @@
# under the License.
"""Tensor intrinsics"""
import tvm._ffi
import tvm.tir
from tvm.runtime import Object
from tvm.runtime import Object, convert
from tvm.ir import Range
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.target import BuildConfig
from .tensor import PlaceholderOp
from . import _api_internal
from . import api as _api
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from . import _ffi_api
def _get_region(tslice):
......@@ -34,15 +32,16 @@ def _get_region(tslice):
for idx in tslice.indices:
if isinstance(idx, slice):
assert idx.step is None
region.append(_api.Range(idx.start, idx.stop))
region.append(Range(idx.start, idx.stop))
else:
if isinstance(idx, _schedule.IterVar):
if isinstance(idx, tvm.tir.IterVar):
begin = idx.var
else:
begin = idx
region.append(Range.make_by_min_extent(begin, 1))
return region
@tvm._ffi.register_object
class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation.
......@@ -60,10 +59,11 @@ class TensorIntrin(Object):
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
reduce_axis = convert(reduce_axis)
if scalar_inputs:
scalar_inputs = _api.convert(scalar_inputs)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
scalar_inputs = convert(scalar_inputs)
return _ffi_api.TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
def decl_tensor_intrin(op,
fcompute,
......@@ -119,15 +119,15 @@ def decl_tensor_intrin(op,
binds_list = []
for t in inputs:
if not isinstance(t.op, _tensor.PlaceholderOp):
if not isinstance(t.op, PlaceholderOp):
raise ValueError("Do not yet support composition op")
cfg = current_build_config()
cfg = BuildConfig.current()
for t in tensors:
buf = (binds[t] if t in binds else
_api.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor))
tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor))
binds_list.append(buf)
if scalar_params:
......@@ -135,10 +135,10 @@ def decl_tensor_intrin(op,
else:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
scalar_params = []
if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)):
if isinstance(body, (tvm.tir.PrimExpr, tvm.tir.Stmt)):
body = [body]
body = [_stmt.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
body = [tvm.tir.Evaluate(x) if isinstance(x, tvm.tir.PrimExpr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
return _ffi_api.TensorIntrin(
name, op, inputs, binds_list, scalar_params, *body)
......@@ -17,6 +17,8 @@
""" TVM testing utilities """
import logging
import numpy as np
import tvm._ffi
def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
""" Version of np.testing.assert_allclose with `atol` and `rtol` fields set
......@@ -161,3 +163,6 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
logging.info("Numerical grad test wrt '%s' of shape %s passes, "
"dist = %f, max_diff = %f, avg_diff = %f",
x_name, grad.shape, dist, max_diff, avg_diff)
tvm._ffi._init_api("testing", __name__)
......@@ -23,16 +23,18 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, min_value, max_value
from .op import call_llvm_intrin, all, any, min_value, max_value
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from . import ir_builder
from . import ir_pass
......@@ -309,6 +309,57 @@ class SizeVar(Var):
@tvm._ffi.register_object
class IterVar(Object, ExprOp):
"""Represent iteration variable.
IterVar represents axis iterations in the computation.
Parameters
----------
dom : Range
The domain of the iteration.
var : Union[Var, str]
The internal variable that is used for iteration.
iter_type : int
The iteration type.
thread_tag : str
The thread type tag.
See Also
--------
tvm.thread_axis: Create thread axis IterVar.
tvm.reduce_axis: Create reduce axis IterVar.
"""
DataPar = 0
ThreadIndex = 1
CommReduce = 2
Ordered = 3
DimInfo = 4
Unrolled = 5
Vectorized = 6
Parallelized = 7
Tensorized = 8
def __init__(self, dom, var, iter_type, thread_tag=""):
if dom is not None:
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
raise TypeError("need to be list of ranges")
dom = tvm.ir.Range(dom[0], dom[1])
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = var if var is not None else "iter"
var = Var(name, dtype="int32") if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
@tvm._ffi.register_object
class CommReducer(Object):
"""Communicative reduce operator
......
......@@ -14,13 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin, invalid-name
"""Operators used in TIR expression."""
import tvm._ffi
from tvm.runtime import convert, const
from tvm.schedule import Buffer
from tvm.ir import Array
from .expr import Call
from .buffer import Buffer
from .expr import Call, Var, CommReducer
from . import _ffi_api
......@@ -196,6 +197,53 @@ def call_llvm_intrin(dtype, name, *args):
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpOr(args[0], args[1])
for i in range(2, len(args)):
ret = _ffi_api._OpOr(ret, args[i])
return ret
def all(*args):
"""Create a new experssion of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpAnd(args[0], args[1])
for i in range(2, len(args)):
ret = _ffi_api._OpAnd(ret, args[i])
return ret
@tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
print(list(args))
......@@ -780,3 +828,137 @@ def floormod(a, b):
The result expression.
"""
return _ffi_api._OpFloorMod(a, b)
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
Parameters
----------
fcombine : function(Expr -> Expr -> Expr)
A binary function which takes two Expr as input to return a Expr.
fidentity : function(str -> Expr)
A function which takes a type string as input to return a const Expr.
Returns
-------
reducer : function
A function which creates a reduce expression over axis.
There are two ways to use it:
1. accept (expr, axis, where) to produce an Reduce Expr on
specified axis;
2. simply use it with multiple Exprs.
Example
-------
.. code-block:: python
n = tvm.var("n")
m = tvm.var("m")
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name="A")
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
"""
def _reduce_directly(*args):
num = len(args)
# process `where` is None
if num == 3 and args[2] is None:
num = 2
res = args[0]
for i in range(num-1):
res = fcombine(res, args[i+1])
return res
def _make_reduce(expr, axis, where=None):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
if isinstance(expr, Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + "_" + str(i)
larr.append(Var(lname, dtype))
rname = code.co_varnames[1] + "_" + str(i)
rarr.append(Var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, tvm.ir.PrimExpr)
size = 1
dtype = expr.dtype
lvar = Var(code.co_varnames[0], dtype)
rvar = Var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = convert(id_elem)
combiner = CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
# pylint: disable=keyword-arg-before-vararg
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where)
if where is None:
assert not args
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : PrimExpr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : PrimExpr
The result value.
Example
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
# there are two way to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce an Reduce Expr
B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = tvm.{0}(m, n)
"""
reducer.__doc__ = doc_str.format(name)
return reducer
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y), max_value, name="min")
max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y), min_value, name="max")
......@@ -64,16 +64,16 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound")
TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("_IntervalSetGetMin")
TVM_REGISTER_GLOBAL("arith._IntervalSetGetMin")
.set_body_method(&IntSet::min);
TVM_REGISTER_GLOBAL("_IntervalSetGetMax")
TVM_REGISTER_GLOBAL("arith._IntervalSetGetMax")
.set_body_method(&IntSet::max);
TVM_REGISTER_GLOBAL("_IntSetIsNothing")
TVM_REGISTER_GLOBAL("arith._IntSetIsNothing")
.set_body_method(&IntSet::is_nothing);
TVM_REGISTER_GLOBAL("_IntSetIsEverything")
TVM_REGISTER_GLOBAL("arith._IntSetIsEverything")
.set_body_method(&IntSet::is_everything);
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
......
......@@ -40,115 +40,113 @@ TVM_REGISTER_GLOBAL("tir.min_value")
TVM_REGISTER_GLOBAL("tir.max_value")
.set_body_typed(max_value);
TVM_REGISTER_GLOBAL("Range")
TVM_REGISTER_GLOBAL("ir.Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = Range(0, args[0]);
} else {
*ret = Range(args[0], args[1]);
}
*ret = Range(args[0], args[1]);
});
namespace tir {
TVM_REGISTER_GLOBAL("tir.IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
}
namespace te {
TVM_REGISTER_GLOBAL("_Tensor")
TVM_REGISTER_GLOBAL("te.Tensor")
.set_body_typed(TensorNode::make);
TVM_REGISTER_GLOBAL("_TensorIntrin")
TVM_REGISTER_GLOBAL("te.TensorIntrin")
.set_body_typed(TensorIntrinNode::make);
TVM_REGISTER_GLOBAL("_TensorIntrinCall")
TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
.set_body_typed(TensorIntrinCallNode::make);
TVM_REGISTER_GLOBAL("_TensorEqual")
TVM_REGISTER_GLOBAL("te.TensorEqual")
.set_body_method(&Tensor::operator==);
TVM_REGISTER_GLOBAL("_TensorHash")
TVM_REGISTER_GLOBAL("te.TensorHash")
.set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});
TVM_REGISTER_GLOBAL("_Placeholder")
TVM_REGISTER_GLOBAL("te.Placeholder")
.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name);
});
TVM_REGISTER_GLOBAL("_ComputeOp")
TVM_REGISTER_GLOBAL("te.ComputeOp")
.set_body_typed(ComputeOpNode::make);
TVM_REGISTER_GLOBAL("_ScanOp")
TVM_REGISTER_GLOBAL("te.ScanOp")
.set_body_typed(ScanOpNode::make);
TVM_REGISTER_GLOBAL("_TensorComputeOp")
TVM_REGISTER_GLOBAL("te.TensorComputeOp")
.set_body_typed(TensorComputeOpNode::make);
TVM_REGISTER_GLOBAL("_ExternOp")
TVM_REGISTER_GLOBAL("te.ExternOp")
.set_body_typed(ExternOpNode::make);
TVM_REGISTER_GLOBAL("_HybridOp")
TVM_REGISTER_GLOBAL("te.HybridOp")
.set_body_typed(HybridOpNode::make);
TVM_REGISTER_GLOBAL("_OpGetOutput")
TVM_REGISTER_GLOBAL("te.OpGetOutput")
.set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output));
});
TVM_REGISTER_GLOBAL("_OpNumOutputs")
TVM_REGISTER_GLOBAL("te.OpNumOutputs")
.set_body_method<Operation>(&OperationNode::num_outputs);
TVM_REGISTER_GLOBAL("_OpInputTensors")
TVM_REGISTER_GLOBAL("te.OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);
TVM_REGISTER_GLOBAL("_IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
TVM_REGISTER_GLOBAL("_CreateSchedule")
TVM_REGISTER_GLOBAL("te.CreateSchedule")
.set_body_typed(create_schedule);
TVM_REGISTER_GLOBAL("_StageSetScope")
TVM_REGISTER_GLOBAL("te.StageSetScope")
.set_body_method(&Stage::set_scope);
TVM_REGISTER_GLOBAL("_StageBind")
TVM_REGISTER_GLOBAL("te.StageBind")
.set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("_StageSplitByFactor")
TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("_StageSplitByNParts")
TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("_StageFuse")
TVM_REGISTER_GLOBAL("te.StageFuse")
.set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused;
stage.fuse(axes, &fused);
return fused;
});
TVM_REGISTER_GLOBAL("_StageComputeAt")
TVM_REGISTER_GLOBAL("te.StageComputeAt")
.set_body_method(&Stage::compute_at);
TVM_REGISTER_GLOBAL("_StageComputeInline")
TVM_REGISTER_GLOBAL("te.StageComputeInline")
.set_body_method(&Stage::compute_inline);
TVM_REGISTER_GLOBAL("_StageComputeRoot")
TVM_REGISTER_GLOBAL("te.StageComputeRoot")
.set_body_method(&Stage::compute_root);
TVM_REGISTER_GLOBAL("_StageReorder")
TVM_REGISTER_GLOBAL("te.StageReorder")
.set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("_StageTile")
TVM_REGISTER_GLOBAL("te.StageTile")
.set_body_typed([](
Stage stage,
IterVar x_parent, IterVar y_parent,
......@@ -162,49 +160,49 @@ TVM_REGISTER_GLOBAL("_StageTile")
return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_GLOBAL("_StageEnvThreads")
TVM_REGISTER_GLOBAL("te.StageEnvThreads")
.set_body_method(&Stage::env_threads);
TVM_REGISTER_GLOBAL("_StageSetStorePredicate")
TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
.set_body_method(&Stage::set_store_predicate);
TVM_REGISTER_GLOBAL("_StageUnroll")
TVM_REGISTER_GLOBAL("te.StageUnroll")
.set_body_method(&Stage::unroll);
TVM_REGISTER_GLOBAL("_StageVectorize")
TVM_REGISTER_GLOBAL("te.StageVectorize")
.set_body_method(&Stage::vectorize);
TVM_REGISTER_GLOBAL("_StageTensorize")
TVM_REGISTER_GLOBAL("te.StageTensorize")
.set_body_method(&Stage::tensorize);
TVM_REGISTER_GLOBAL("_StageParallel")
TVM_REGISTER_GLOBAL("te.StageParallel")
.set_body_method(&Stage::parallel);
TVM_REGISTER_GLOBAL("_StagePragma")
TVM_REGISTER_GLOBAL("te.StagePragma")
.set_body_method(&Stage::pragma);
TVM_REGISTER_GLOBAL("_StagePrefetch")
TVM_REGISTER_GLOBAL("te.StagePrefetch")
.set_body_method(&Stage::prefetch);
TVM_REGISTER_GLOBAL("_StageStorageAlign")
TVM_REGISTER_GLOBAL("te.StageStorageAlign")
.set_body_method(&Stage::storage_align);
TVM_REGISTER_GLOBAL("_StageDoubleBuffer")
TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
.set_body_method(&Stage::double_buffer);
TVM_REGISTER_GLOBAL("_StageOpenGL")
TVM_REGISTER_GLOBAL("te.StageOpenGL")
.set_body_method(&Stage::opengl);
TVM_REGISTER_GLOBAL("_ScheduleNormalize")
TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
.set_body_method(&Schedule::normalize);
TVM_REGISTER_GLOBAL("_ScheduleCreateGroup")
TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
.set_body_method(&Schedule::create_group);
TVM_REGISTER_GLOBAL("_ScheduleCacheRead")
TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
.set_body_method(&Schedule::cache_read);
TVM_REGISTER_GLOBAL("_ScheduleCacheWrite")
TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[1].IsObjectRef<Tensor>()) {
*ret = args[0].operator Schedule()
......@@ -215,11 +213,11 @@ TVM_REGISTER_GLOBAL("_ScheduleCacheWrite")
}
});
TVM_REGISTER_GLOBAL("_ScheduleRFactor")
TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
.set_body_method(&Schedule::rfactor);
} // namespace te
TVM_REGISTER_GLOBAL("_CommReducerCombine")
TVM_REGISTER_GLOBAL("te.CommReducerCombine")
.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
} // namespace tvm
......@@ -47,9 +47,9 @@ TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
*ret = ScheduleOps(args[0], args[1], args[2]);
});
#define REGISTER_SCHEDULE_PASS(PassName) \
#define REGISTER_SCHEDULE_PASS(PassName) \
TVM_REGISTER_GLOBAL("schedule."#PassName) \
.set_body_typed(PassName); \
.set_body_typed(PassName); \
REGISTER_SCHEDULE_PASS(InferBound);
......
......@@ -54,11 +54,11 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
TVM_REGISTER_NODE_TYPE(TestAttrs);
TVM_REGISTER_GLOBAL("_nop")
TVM_REGISTER_GLOBAL("testing.nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
TVM_REGISTER_GLOBAL("_test_wrap_callback")
TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0];
*ret = runtime::TypedPackedFunc<void()>([pf](){
......@@ -66,7 +66,7 @@ TVM_REGISTER_GLOBAL("_test_wrap_callback")
});
});
TVM_REGISTER_GLOBAL("_test_raise_error_callback")
TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
std::string msg = args[0];
*ret = runtime::TypedPackedFunc<void()>([msg](){
......@@ -74,7 +74,7 @@ TVM_REGISTER_GLOBAL("_test_raise_error_callback")
});
});
TVM_REGISTER_GLOBAL("_test_check_eq_callback")
TVM_REGISTER_GLOBAL("testing.test_check_eq_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
std::string msg = args[0];
*ret = runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y){
......@@ -82,7 +82,7 @@ TVM_REGISTER_GLOBAL("_test_check_eq_callback")
});
});
TVM_REGISTER_GLOBAL("_context_test")
TVM_REGISTER_GLOBAL("testing.context_test")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLContext ctx = args[0];
int dtype = args[1];
......@@ -103,11 +103,11 @@ void ErrorTest(int x, int y) {
}
}
TVM_REGISTER_GLOBAL("_ErrorTest")
TVM_REGISTER_GLOBAL("testing.ErrorTest")
.set_body_typed(ErrorTest);
// internal function used for debug and testing purposes
TVM_REGISTER_GLOBAL("_ndarray_use_count")
TVM_REGISTER_GLOBAL("testing.ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// substract the current one
......
......@@ -403,7 +403,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});
TVM_REGISTER_GLOBAL("_GetCurrentBuildConfig")
TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BuildConfig::Current();
});
......@@ -418,13 +418,13 @@ class BuildConfig::Internal {
}
};
TVM_REGISTER_GLOBAL("_EnterBuildConfigScope")
TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope")
.set_body_typed(BuildConfig::Internal::EnterScope);
TVM_REGISTER_GLOBAL("_ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
.set_body_typed(BuildConfig::Internal::ExitScope);
TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
......@@ -437,7 +437,7 @@ TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
cfg->add_lower_pass = add_lower_pass;
});
TVM_REGISTER_GLOBAL("_BuildConfigGetAddLowerPassInfo")
TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// Return one of the following:
// * Size of add_lower_pass if num_args == 1
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.schedule import Buffer
from tvm.tir import Buffer
import numpy as np
def test_buffer():
......@@ -25,7 +25,7 @@ def test_buffer():
Ab = tvm.decl_buffer((m, n), tvm.float32)
Bb = tvm.decl_buffer((n, l), tvm.float32)
assert isinstance(Ab, tvm.schedule.Buffer)
assert isinstance(Ab, tvm.tir.Buffer)
assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n)
......
......@@ -22,8 +22,8 @@ def test_expr_constructor():
assert x.name == "xx"
x = tvm.tir.Reduce(None, [1],
[tvm.api._IterVar((0, 1), "x", 2)],
None, 0)
[tvm.tir.IterVar((0, 1), "x", 2)],
None, 0)
assert isinstance(x, tvm.tir.Reduce)
assert x.combiner == None
assert x.value_index == 0
......
......@@ -16,9 +16,10 @@
# under the License.
"""Test runtime error handling"""
import tvm
import tvm.testing
def test_op_translation():
ferror = tvm._api_internal._test_raise_error_callback(
ferror = tvm.testing.test_raise_error_callback(
"OpNotImplemented: myop")
try:
ferror()
......@@ -28,7 +29,7 @@ def test_op_translation():
assert isinstance(e, NotImplementedError)
assert msg.find("api_test.cc") != -1
fchk_eq = tvm._api_internal._test_check_eq_callback(
fchk_eq = tvm.testing.test_check_eq_callback(
"InternalError: myop")
try:
fchk_eq(0, 1)
......@@ -38,7 +39,7 @@ def test_op_translation():
assert msg.find("api_test.cc") != -1
try:
tvm._api_internal._ErrorTest(0, 1)
tvm.testing.ErrorTest(0, 1)
assert False
except ValueError as e:
msg = str(e)
......@@ -48,13 +49,13 @@ def test_op_translation():
def test_deep_callback():
def error_callback():
raise ValueError("callback error")
wrap1 = tvm._api_internal._test_wrap_callback(error_callback)
wrap1 = tvm.testing.test_wrap_callback(error_callback)
def flevel2():
wrap1()
wrap2 = tvm._api_internal._test_wrap_callback(flevel2)
wrap2 = tvm.testing.test_wrap_callback(flevel2)
def flevel3():
wrap2()
wrap3 = tvm._api_internal._test_wrap_callback(flevel3)
wrap3 = tvm.testing.test_wrap_callback(flevel3)
try:
wrap3()
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
import numpy as np
def test_get_global():
......@@ -93,7 +94,7 @@ def test_ctx():
x = test_ctx_func(tvm.gpu(7))
assert x == tvm.cpu(0)
x = tvm.opencl(10)
x = tvm._api_internal._context_test(x, x.device_type, x.device_id)
x = tvm.testing.context_test(x, x.device_type, x.device_id)
assert x == tvm.opencl(10)
def test_trace_default_action():
......@@ -282,4 +283,3 @@ if __name__ == "__main__":
test_trace_default_action()
test_trace_can_change_traced_value_int()
test_trace_can_change_traced_value_float()
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
import os
import logging
import time
......@@ -210,7 +211,7 @@ def test_rpc_return_ndarray():
if name == "get_arr":
return lambda : nd
elif name == "ref_count":
return lambda : tvm._api_internal._ndarray_use_count(nd)
return lambda : tvm.testing.ndarray_use_count(nd)
elif name == "get_elem":
return lambda idx: nd.asnumpy()[idx]
elif name == "get_arr_elem":
......
......@@ -96,7 +96,7 @@ def lower(*args, **kwargs):
--------
tvm.lower : The original TVM's lower function
"""
cfg = tvm.build_module.current_build_config()
cfg = tvm.target.BuildConfig.current()
if not cfg.add_lower_pass:
with build_config():
return tvm.lower(*args, **kwargs)
......@@ -113,7 +113,7 @@ def build(*args, **kwargs):
--------
tvm.build : The original TVM's build function
"""
cfg = tvm.build_module.current_build_config()
cfg = tvm.target.BuildConfig.current()
if not cfg.add_lower_pass:
with build_config():
return tvm.build(*args, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment