Commit 08338dd5 by Tianqi Chen, committed by GitHub

[REFACTOR][PY] Establish tvm.te and tvm.driver (#4900)

- Move the related files to tvm.te
- Move build_module.py to tvm.driver
parent 27a02844
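
For users, the visible effect is a namespace move: the tensor-expression API now lives under tvm.te and the lower/build entry points under tvm.driver, while the old top-level names are kept as re-exports for backward compatibility. A minimal sketch of the new spellings (not part of the commit; assumes an LLVM-enabled build):

    import tvm

    n = tvm.var("n")                              # still re-exported at the top level
    A = tvm.te.placeholder((n,), name="A")        # new canonical home of placeholder/compute
    B = tvm.te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.te.create_schedule(B.op)
    f = tvm.driver.lower(s, [A, B])               # tvm.lower is an alias of this
    m = tvm.driver.build(f, target="llvm")        # tvm.build is an alias of this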
@@ -47,25 +47,30 @@ from . import tir
# tvm.target
from . import target
-from .target import build_config
-# others
-from . import tensor
-from . import arith
-from . import make
-from . import schedule
-from . import hybrid
+# tvm.te
+from .te import decl_tensor_intrin, create_schedule, tag_scope
+# tvm.testing
from . import testing
-from .api import *
-from .tensor_intrin import decl_tensor_intrin
-from .schedule import create_schedule
-from .build_module import build, lower, build_config
-from .tag import tag_scope
+# tvm.driver
+from .driver import build, lower
+# tvm.hybrid
+from . import hybrid
+# others
+from . import arith
# backward compact for topi, to be removed later
+from .api import *
from .tir import expr, stmt, ir_builder, ir_pass, generic
+from .te import tensor, schedule
from .tir.op import *
from . import intrin
+from . import make
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
...
@@ -16,623 +16,23 @@
# under the License.
"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
-from numbers import Integral as _Integral
import tvm._ffi
import tvm.ir
-import tvm.tir
from tvm.runtime import convert, const, DataType
-from tvm.ir import container as _container
+from tvm.ir import container as _container, Range
-from tvm.tir import expr as _expr
-from tvm.tir import stmt as _stmt
from tvm.tir import decl_buffer, layout, bijective_layout
-from tvm.tir import min_value, max_value, indexdiv, indexmod
+from tvm.tir import min_value, max_value, indexdiv, indexmod, all, any
-import tvm.tir._ffi_api
+from tvm.te import placeholder, compute, scan, extern, var, size_var, thread_axis, reduce_axis
from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
-from . import _api_internal
from . import make as _make
-from . import tensor as _tensor
-from . import schedule as _schedule
-from . import tag as _tag
int8 = "int8"
int32 = "int32"
float32 = "float32"
handle = "handle"
def var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : Var
The result symbolic variable.
"""
return _expr.Var(name, dtype)
def size_var(name="size", dtype=int32):
"""Create a new variable represents a tensor shape size, which is non-negative.
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : SizeVar
The result symbolic shape variable.
"""
return _expr.SizeVar(name, dtype)
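# A usage sketch, assuming the names defined in this module: both helpers
# create symbolic scalars; SizeVar additionally tells the arithmetic analyzer
# that the value is non-negative, which helps simplify shape expressions.
#
#   n = var("n")            # Var of dtype "int32"
#   m = size_var("m")       # SizeVar, known to satisfy m >= 0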
def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = tvm.tir._ffi_api._OpOr(args[0], args[1])
for i in range(2, len(args)):
ret = tvm.tir._ffi_api._OpOr(ret, args[i])
return ret
def all(*args):
"""Create a new experssion of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = tvm.tir._ffi_api._OpAnd(args[0], args[1])
for i in range(2, len(args)):
ret = tvm.tir._ffi_api._OpAnd(ret, args[i])
return ret
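# A usage sketch, assuming the names defined in this module: both helpers fold
# a Python-level argument list into one boolean PrimExpr, e.g. a border test.
#
#   i, j = var("i"), var("j")
#   inside = all(i >= 0, i < 16, j >= 0, j < 16)   # (i >= 0) && ... && (j < 16)
#   on_edge = any(i == 0, j == 0)                  # (i == 0) || (j == 0)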
def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
Returns
-------
tensor: Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
dtype = float32 if dtype is None else dtype
return _api_internal._Placeholder(
shape, dtype, name)
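# A usage sketch, assuming the names defined in this module: a placeholder
# carries only shape/dtype and becomes an input argument of the built function.
#
#   m = size_var("m")
#   A = placeholder((m, 128), dtype="float32", name="A")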
def compute(shape, fcompute, name="compute", tag="", attrs=None):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of indices-> value
Specifies the input source expression
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor
The created tensor
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if isinstance(body, _tensor.TensorIntrinCall):
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
dim_var.append(_IterVar((0, s), var_name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
body.reduce_axis,
out_ndim,
body.intrin,
body.tensors,
body.regions,
body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
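# A usage sketch, assuming the names defined in this module: fcompute receives
# one index variable per output dimension, named after the lambda's arguments.
#
#   A = placeholder((8, 8), name="A")
#   B = compute((8, 8), lambda i, j: A[i, j] * 2.0, name="B")   # B[i, j] = 2 * A[i, j]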
def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
"""Construct new tensors by scanning over axis.
Parameters
----------
init: Tensor or list of Tensor
The initial condition for the first init.shape[0] time steps
update: Tensor or list of Tensor
The update rule of the scan given by symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
inputs: Tensor or list of Tensor, optional
The list of inputs to the scan. This is not required, but can
help the compiler detect the scan body faster.
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
The created tensor, or a tuple of tensors if it contains multiple outputs.
Example
-------
.. code-block:: python
# The following code is equivalent to numpy.cumsum
m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, X)
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
if inputs is None:
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, tag, attrs,
axis, init, update,
state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
def extern(shape,
inputs,
fcompute,
name="extern",
dtype=None,
in_buffers=None,
out_buffers=None,
tag="",
attrs=None):
"""Compute several tensor via extern function.
Parameters
----------
shape: tuple or list of tuples.
The shape of the outputs.
inputs: list of Tensor
The inputs
fcompute: lambda function of inputs, outputs-> stmt
Specifies the IR statement to do the computation.
See the following note for function signature of fcompute
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each input
- **outs** (list of :any:`Buffer`) - Placeholder for each output
**Returns**
- **stmt** (:any:`Stmt`) - The statement that carries out array computation.
name: str, optional
The name hint of the tensor
dtype: str or list of str, optional
The data types of outputs,
by default dtype will be same as inputs.
in_buffers: Buffer or list of Buffer, optional
Input buffers.
out_buffers: Buffer or list of Buffers, optional
Output buffers.
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
The created tensor, or a tuple of tensors if it contains multiple outputs.
Example
-------
In the code below, C is generated by calling external PackedFunc
`tvm.contrib.cblas.matmul`
.. code-block:: python
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = tvm.extern((n, m), [A, B],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape
if shape == () or isinstance(shape[0], (_expr.PrimExpr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
% (len(inputs), len(in_buffers)))
if out_buffers is not None:
out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
if len(shape) != len(out_buffers):
raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
% (len(shape), len(out_buffers)))
input_placeholders = in_buffers or []
output_placeholders = out_buffers or []
types = set()
for t in inputs:
if not isinstance(t, _tensor.Tensor):
raise ValueError("expect inputs to be tensor")
if in_buffers is None:
input_placeholders.append(
decl_buffer(t.shape, t.dtype, t.op.name))
types.add(t.dtype)
if dtype is None:
if len(types) != 1:
raise ValueError("Cannot infer output type, please provide dtype argument")
inferred_type = types.pop()
dtype = [inferred_type for _ in shape]
if isinstance(dtype, str):
dtype = [dtype]
if out_buffers is None:
for shp, dt in zip(shape, dtype):
output_placeholders.append(decl_buffer(shp, dt, name))
body = fcompute(input_placeholders, output_placeholders)
if isinstance(body, _expr.PrimExpr):
body = _stmt.Evaluate(body)
op = _api_internal._ExternOp(name, tag, attrs,
inputs, input_placeholders,
output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
return res[0] if len(res) == 1 else res
def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar
Parameters
----------
dom : Range
The domain of iteration.
name : str
The name of iteration variable.
iter_type : int
The type of iteration.
thread_tag : str
The thread tag of the iteration variable.
Returns
-------
iter_var : IterVar
The result itervar
"""
if dom is not None:
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
return _api_internal._IterVar(dom, v, iter_type, thread_tag)
def thread_axis(dom=None, tag='', name=''):
"""Create a new IterVar to represent thread index.
Parameters
----------
dom : Range or str
The domain of iteration
When str is passed, dom is set to None and str is used as tag
tag : str, optional
The thread tag
name : str, optional
The name of the var.
Returns
-------
axis : IterVar
The thread itervar.
"""
if isinstance(dom, string_types):
tag, dom = dom, None
if not tag:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag
return _IterVar(dom, name, 1, tag)
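# A usage sketch, assuming the names defined in this module: thread itervars
# carry GPU thread tags and are later bound to schedule axes.
#
#   bx = thread_axis("blockIdx.x")              # str form: tag only, dom=None
#   tx = thread_axis((0, 64), "threadIdx.x")    # with an explicit iteration domain
#   # in a schedule: s[C].bind(xo, bx); s[C].bind(xi, tx)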
def reduce_axis(dom, name="rv"):
"""Create a new IterVar for reduction.
Parameters
----------
dom : Range
The domain of iteration.
name : str
The name of the variable.
Returns
-------
axis : IterVar
An iteration variable representing the value.
"""
return _IterVar(dom, name, 2)
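# A usage sketch, assuming the names defined in this module: the reduction
# itervar spans the reduce domain and is consumed by a reducer such as the
# `sum` defined at the bottom of this file.
#
#   A = placeholder((16, 64), name="A")
#   k = reduce_axis((0, 64), name="k")
#   row_sum = compute((16,), lambda i: sum(A[i, k], axis=k), name="row_sum")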
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
Parameters
----------
fcombine : function(Expr -> Expr -> Expr)
A binary function which takes two Exprs as input and returns an Expr.
fidentity : function(str -> Expr)
A function which takes a type string as input to return a const Expr.
Returns
-------
reducer : function
A function which creates a reduce expression over axis.
There are two ways to use it:
1. accept (expr, axis, where) to produce a Reduce Expr on
specified axis;
2. simply use it with multiple Exprs.
Example
-------
.. code-block:: python
n = tvm.var('n')
m = tvm.var('m')
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
"""
def _reduce_directly(*args):
num = len(args)
# handle the case where `where` is None
if num == 3 and args[2] is None:
num = 2
res = args[0]
for i in range(num-1):
res = fcombine(res, args[i+1])
return res
def _make_reduce(expr, axis, where=None):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
if isinstance(expr, _container.Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + '_' + str(i)
larr.append(var(lname, dtype))
rname = code.co_varnames[1] + '_' + str(i)
rarr.append(var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, _expr.PrimExpr)
size = 1
dtype = expr.dtype
lvar = var(code.co_varnames[0], dtype)
rvar = var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = convert(id_elem)
combiner = _expr.CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
# pylint: disable=keyword-arg-before-vararg
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (_schedule.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where)
if where is None:
assert not args
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : Expr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : Expr
The result value.
Example
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
# there are two way to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = tvm.{0}(m, n)
"""
reducer.__doc__ = doc_str.format(name)
return reducer
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMax(x, y), min_value, name='max')
tvm._ffi._init_api("tvm.api")
@@ -18,17 +18,16 @@
import tvm._ffi
from tvm.runtime import Object
-from . import _api_internal
class IntSet(Object):
"""Represent a set of integers in one dimension."""
def is_nothing(self):
"""Whether the set represents nothing"""
-return _api_internal._IntSetIsNothing(self)
+return _IntSetIsNothing(self)
def is_everything(self):
"""Whether the set represents everything"""
-return _api_internal._IntSetIsEverything(self)
+return _IntSetIsEverything(self)
@tvm._ffi.register_object("arith.IntervalSet")
...
@@ -29,7 +29,8 @@ There are two types of feature
import struct
import numpy as np
-from tvm import schedule, ir_pass, build_module, get_global_func, target as _target
+from tvm import schedule, ir_pass, get_global_func, target as _target
+from tvm.driver import build_module
def ana_lower(sch, args,
binds=None,
...
@@ -26,8 +26,9 @@ tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
+import tvm.te._ffi_api
-from ... import _api_internal, tensor, placeholder
+from ... import tensor, placeholder
from .task import args_to_workload, dispatcher, register
from ..util import get_const_tuple
@@ -420,10 +421,10 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None, o
attrs[k] = v
attrs['workload'] = args_to_workload(args, topi_compute)
if isinstance(op, tensor.ComputeOp):
-op = _api_internal._ComputeOp(
+op = tvm.te._ffi_api.ComputeOp(
op.name, op.tag, attrs, op.axis, op.body)
elif isinstance(op, tensor.ExternOp):
-op = _api_internal._ExternOp(
+op = tvm.te._ffi_api.ExternOp(
op.name, op.tag, attrs,
op.inputs, op.input_placeholders,
op.output_placeholders, op.body)
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for driver APIs"""
from .build_module import lower, build
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The build utils in python.
This module provides the functions to transform schedule to
LoweredFunc and compiled Module.
"""
import warnings
import tvm.tir
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.tir.stmt import LoweredFunc
from tvm.te import tensor
from tvm.te import schedule
from tvm import target as _target
def get_binds(args, compact=False, binds=None):
"""Internal function to get binds and arg_list given arguments.
Parameters
----------
args : list of Buffer or Tensor or Var
The argument lists to the function.
compact : bool
If the statement has already bound to a compact buffer.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
Returns
-------
binds: dict
The bind specification
arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
cfg = BuildConfig.current()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
buffer_type = "auto_broadcast" if any_dim and not compact else ""
if x not in binds:
buf = tvm.tir.decl_buffer(
x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
buffer_type=buffer_type)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, tvm.tir.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
return binds, arg_list
def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
The given schedule from which to form the raw body
Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt
def lower(sch,
args,
name="default_function",
binds=None,
simple_mode=False):
"""Lowering step before build into target.
Parameters
----------
sch : tvm.schedule.Schedule
The schedule to be built
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str, optional
The name of result function.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
simple_mode : bool, optional
Whether only output simple and compact statement, this will skip
LoopPartition, api wrapper generation and Unrolling.
Returns
-------
f : LoweredFunc or Stmt
The result function.
If simple_mode=True, the Stmt before MakeAPI is returned instead.
"""
cfg = BuildConfig.current()
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
if cfg.dump_pass_ir:
add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# Phase 0
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)
for f in lower_phase0:
stmt = f(stmt)
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
for f in lower_phase2:
stmt = f(stmt)
# Phase 3
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
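# A usage sketch, assuming the te API re-exported at the tvm top level:
# simple_mode=True stops before LoopPartition/MakeAPI and returns the Stmt,
# which is convenient for inspecting what a schedule does.
#
#   import tvm
#   A = tvm.placeholder((128,), name="A")
#   B = tvm.compute((128,), lambda i: A[i] + 1.0, name="B")
#   s = tvm.create_schedule(B.op)
#   print(tvm.lower(s, [A, B], simple_mode=True))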
def _build_for_device(flist, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.
Parameters
----------
flist : list of LoweredFunc
The list of lowered functions to be built.
target : str or :any:`tvm.target.Target`
The target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
The host compilation target.
Returns
-------
fhost : list of LoweredFunc
A list of lowered functions for the host.
mdev : tvm.module
A module that contains device code.
"""
target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == LoweredFunc.MixedFunc:
if BuildConfig.current().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = list(ir_pass.SplitHostDevice(func))
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice
target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
return fhost, mdev
def build(inputs,
args=None,
target=None,
target_host=None,
name="default_function",
binds=None):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.
Parameters
----------
inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
The schedule, lowered function(s), or dict of target to lowered functions to be built.
args : list of Buffer or Tensor or Var, optional
The argument lists to the function.
target : str or :any:`tvm.target.Target`, optional
The target and option of the compilation.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
and set up the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
name : str, optional
The name of result function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
Returns
-------
ret : tvm.module
A module that combines both host and device code.
Examples
--------
There are two typical example uses of this function depending on the type
of the argument `inputs`:
1. it is a list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
f = tvm.lower(s, [A, B, C], name="test_add")
m = tvm.build(f, target="llvm")
2. it is a dict of compilation target to list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s1 = tvm.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
f1 = tvm.lower(s1, [A, B, C], name="test_add1")
f2 = tvm.lower(s2, [A, B, C], name="test_add2")
m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
elif isinstance(inputs, (list, tuple, container.Array)):
flist = inputs
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError("inputs must be Schedule, LoweredFunc, list of "
"LoweredFunc, or dict of target to list of "
"LoweredFunc.")
if not isinstance(inputs, (dict, container.Map)):
target = _target.Target.current() if target is None else target
target = target if target else "llvm"
target_flist = {target: flist}
else:
target_flist = inputs
for tar, flist in target_flist.items():
if not isinstance(tar, (str, _target.Target)):
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
fname_set.add(x.name)
if not target_host:
for tar, _ in target_flist.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = tar
break
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
fhost_all = []
device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
device_modules.append(mdev)
# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
@@ -21,7 +21,7 @@ See the example sections for suggested message conventions.
To make the code more readable, we recommend developers to
copy the examples and raise errors with the same message convention.
"""
-from ._ffi.base import register_error, TVMError
+from tvm._ffi.base import register_error, TVMError
@register_error
class InternalError(TVMError):
...
@@ -30,9 +30,9 @@ HalideIR.
# 2. Support multi-level HalideIR
import inspect
import tvm._ffi
+from tvm.driver.build_module import form_body
from .._ffi.base import decorate
-from ..build_module import form_body
from .module import HybridModule
from .parser import source_to_op
...
@@ -26,19 +26,20 @@ import numbers
from enum import Enum
from tvm.ir import Array, Range
import tvm.tir
+import tvm.te._ffi_api
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.tir import ir_pass as _ir_pass
+from tvm.te.tensor import Tensor, Operation
+from tvm.tir import all as _all
+from tvm.tir import any as _any
from .util import _internal_assert
from . import calls
from . import util
from .preprocessor import determine_variable_usage
-from ..api import all as _all
-from ..api import any as _any
-from ..tensor import Tensor, Operation
-from .. import _api_internal as _tvm_internal
from .. import api as _api
@@ -653,7 +654,7 @@ def source_to_op(src, args, symbols, closure_vars):
for i in args:
get_input_tensors(i)
-op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
+op = tvm.te._ffi_api.HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
@@ -27,9 +27,9 @@ from tvm.ir.container import Array
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
+from tvm.te.tensor import Tensor
from .. import api as _api
-from ..tensor import Tensor
#pylint: disable=invalid-name
...
@@ -17,10 +17,10 @@
"""Common expressions data structures in the IR."""
import tvm._ffi
from .base import Node
+from . import _ffi_api
class BaseExpr(Node):
"""Base class of all the expressions."""
@@ -98,7 +98,29 @@ class Range(Node):
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
+Parameters
+----------
+begin : PrimExpr
+The begin value of the range when end is not None.
+Otherwise it is the length of the range.
+end : Optional[PrimExpr]
+The end value of the range.
+Note
+----
+The constructor creates the range `[begin, end)`
+if the end argument is not None. Otherwise, it creates `[0, begin)`.
"""
+def __init__(self, begin, end=None):
+if end is None:
+self.__init_handle_by_constructor__(
+_ffi_api.Range, 0, begin)
+else:
+self.__init_handle_by_constructor__(
+_ffi_api.Range, begin, end)
@staticmethod
def make_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
...
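
A quick sketch of the two constructor forms added above, together with the existing factory (assuming tvm.ir.Range is exported, as the api.py hunk earlier in this diff suggests):

    import tvm

    r0 = tvm.ir.Range(8)                          # the range [0, 8)
    r1 = tvm.ir.Range(2, 10)                      # the range [2, 10)
    r2 = tvm.ir.Range.make_by_min_extent(2, 8)    # also [2, 10), via (min, extent)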
@@ -16,10 +16,9 @@
# under the License.
"""The interface of expr function exposed from C++."""
import tvm._ffi
+import tvm.driver
from tvm.ir import container as _container
-from ... import build_module as _build
@tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
@@ -48,7 +47,7 @@ def lower(sch, inputs, func_name, source_func):
import traceback
try:
-f = _build.lower(sch, inputs, name=func_name)
+f = tvm.driver.lower(sch, inputs, name=func_name)
# logging.debug("lower function %s", func_name)
# logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
@@ -85,7 +84,7 @@ def build(funcs, target, target_host=None):
"""
if target_host == "":
target_host = None
-return _build.build(funcs, target=target, target_host=target_host)
+return tvm.driver.build(funcs, target=target, target_host=target_host)
@tvm._ffi.register_func("relay._tensor_value_repr")
...
@@ -18,11 +18,11 @@
"""The base node types for the Relay language."""
import topi
import tvm._ffi
+from tvm.driver import lower, build
from ..base import register_relay_node
from ..expr import RelayExpr
from ...api import register_func
-from ...build_module import lower, build
from . import _make
@register_relay_node
...
@@ -20,6 +20,7 @@ import logging
import multiprocessing as mp
import numpy as np
import tvm
+import tvm.driver
from tvm.ir import IRModule
from . import _quantize
...
@@ -61,3 +61,4 @@ from .generic_func import generic_func, get_native_generic_func, override_native
from . import datatype
from . import codegen
from .intrin import register_intrin_rule
+from .build_config import BuildConfig, build_config
@@ -14,31 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""The build utils in python.
-This module provides the functions to transform schedule to
-LoweredFunc and compiled Module.
-"""
-import warnings
+"""Target dependent BuildConfig for low-level passes."""
+# TODO(tvm-team) consolidate with PassContext
import tvm._ffi
-import tvm.runtime
-from tvm.runtime import Object, ndarray
+import tvm.ir
+from tvm.runtime import Object
from tvm.ir import container
-from tvm.target import codegen
-from tvm.tir import expr
-from tvm.tir import ir_pass
from tvm.tir import Stmt
from tvm.tir.stmt import LoweredFunc
-from . import target as _target
-from . import api
-from . import _api_internal
-from . import tensor
-from . import schedule
-from . import make
+from . import _ffi_api
class DumpIR(object):
@@ -166,11 +151,11 @@ class BuildConfig(Object):
@property
def add_lower_pass(self):
-size = _api_internal._BuildConfigGetAddLowerPassInfo(self)
+size = _ffi_api.BuildConfigGetAddLowerPassInfo(self)
result = []
for i in range(size):
-phase = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, True)
-func = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, False)
+phase = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, True)
+func = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, False)
result += [(phase, func)]
return result
@@ -179,11 +164,11 @@ class BuildConfig(Object):
add_lower_pass_args = []
for x in value:
add_lower_pass_args += [x[0], x[1]]
-_api_internal._BuildConfigSetAddLowerPass(self, *add_lower_pass_args)
+_ffi_api.BuildConfigSetAddLowerPass(self, *add_lower_pass_args)
def __enter__(self):
# pylint: disable=protected-access
-_api_internal._EnterBuildConfigScope(self)
+_ffi_api.EnterBuildConfigScope(self)
if self.dump_pass_ir:
BuildConfig._dump_ir.enter()
return self
@@ -191,7 +176,7 @@ class BuildConfig(Object):
def __exit__(self, ptype, value, trace):
if self.dump_pass_ir:
BuildConfig._dump_ir.exit()
-_api_internal._ExitBuildConfigScope(self)
+_ffi_api.ExitBuildConfigScope(self)
def __setattr__(self, name, value):
if name in BuildConfig._object_defaults:
@@ -199,10 +184,10 @@ class BuildConfig(Object):
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value)
-def current_build_config():
+@staticmethod
+def current():
"""Get the current build configuration."""
-return _api_internal._GetCurrentBuildConfig()
+return _ffi_api.GetCurrentBuildConfig()
def build_config(**kwargs):
@@ -261,393 +246,9 @@ def build_config(**kwargs):
"""
node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in BuildConfig._object_defaults.items()}
-config = make.node("BuildConfig", **node_args)
+config = tvm.ir.make_node("BuildConfig", **node_args)
if "add_lower_pass" in kwargs:
config.add_lower_pass = kwargs["add_lower_pass"]
return config
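
A usage sketch of the relocated configuration helper (assuming the tvm.target re-export added earlier in this diff):

    import tvm

    # Entering the config pushes it on the scope stack; lower()/build() calls
    # inside the block read it back via BuildConfig.current().
    with tvm.target.build_config(auto_unroll_max_step=16, unroll_explicit=False):
        cfg = tvm.target.BuildConfig.current()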
def get_binds(args, compact=False, binds=None):
"""Internal function to get binds and arg_list given arguments.
Parameters
----------
args : list of Buffer or Tensor or Var
The argument lists to the function.
compact : bool
If the statement has already bound to a compact buffer.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
Returns
-------
binds: dict
The bind specification
arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
cfg = current_build_config()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
any_dim = any(isinstance(i, expr.Var) for i in x.shape)
buffer_type = "auto_broadcast" if any_dim and not compact else ""
if x not in binds:
buf = api.decl_buffer(x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
buffer_type=buffer_type)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, expr.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
return binds, arg_list
def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
The given schedule from which to form the raw body
Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt
def lower(sch,
args,
name="default_function",
binds=None,
simple_mode=False):
"""Lowering step before build into target.
Parameters
----------
sch : tvm.schedule.Schedule
The schedule to be built
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str, optional
The name of result function.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
simple_mode : bool, optional
Whether only output simple and compact statement, this will skip
LoopPartition, api wrapper generation and Unrolling.
Returns
-------
f : LoweredFunc or Stmt
The result function.
If simple_mode=True, the Stmt before MakeAPI is returned instead.
"""
cfg = current_build_config()
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
if cfg.dump_pass_ir:
add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# Phase 0
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)
for f in lower_phase0:
stmt = f(stmt)
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
for f in lower_phase2:
stmt = f(stmt)
# Phase 3
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
def _build_for_device(flist, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.
Parameters
----------
flist : list of LoweredFunc
The list of lowered functions to be built.
target : str or :any:`tvm.target.Target`
The target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
The host compilation target.
Returns
-------
fhost : list of LoweredFunc
A list of lowered functions for the host.
mdev : tvm.module
A module that contains device code.
"""
target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == LoweredFunc.MixedFunc:
if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = list(ir_pass.SplitHostDevice(func))
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice
target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
return fhost, mdev
def build(inputs,
args=None,
target=None,
target_host=None,
name="default_function",
binds=None):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.
Parameters
----------
inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
The schedule, lowered function(s), or dict of target to lowered functions to be built.
args : list of Buffer or Tensor or Var, optional
The argument lists to the function.
target : str or :any:`tvm.target.Target`, optional
The target and option of the compilation.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
and set up the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
name : str, optional
The name of result function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
Returns
-------
ret : tvm.module
A module that combines both host and device code.
Examples
--------
There are two typical example uses of this function depending on the type
of the argument `inputs`:
1. it is a list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
f = tvm.lower(s, [A, B, C], name="test_add")
m = tvm.build(f, target="llvm")
2. it is a dict of compilation target to list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s1 = tvm.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
f1 = tvm.lower(s1, [A, B, C], name="test_add1")
f2 = tvm.lower(s2, [A, B, C], name="test_add2")
m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
elif isinstance(inputs, (list, tuple, container.Array)):
flist = inputs
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError("inputs must be Schedule, LoweredFunc, list of "
"LoweredFunc, or dict of target to list of "
"LoweredFunc.")
if not isinstance(inputs, (dict, container.Map)):
target = _target.Target.current() if target is None else target
target = target if target else "llvm"
target_flist = {target: flist}
else:
target_flist = inputs
for tar, flist in target_flist.items():
if not isinstance(tar, (str, _target.Target)):
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
fname_set.add(x.name)
if not target_host:
for tar, _ in target_flist.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = tar
break
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
fhost_all = []
device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
device_modules.append(mdev)
# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import, redefined-builtin, wildcard-import
"""Namespace for Tensor-level IR"""
# expose all operators in tvm tir.op
from tvm.tir.op import *
from .schedule import Schedule, create_schedule
from .tensor import TensorSlice, Tensor
from .tensor_intrin import decl_tensor_intrin
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var
from .operation import thread_axis, reduce_axis
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.te"""
import tvm._ffi
tvm._ffi._init_api("te", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Operation class for computation declaration."""
# pylint: disable=invalid-name
from numbers import Integral as _Integral
import tvm._ffi
import tvm.tir
import tvm.tir._ffi_api
from tvm._ffi.base import string_types
from tvm.runtime import convert
from . import tag as _tag
from . import tensor as _tensor
from . import _ffi_api
def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
Returns
-------
tensor: Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
dtype = "float32" if dtype is None else dtype
return _ffi_api.Placeholder(
shape, dtype, name)
def compute(shape, fcompute, name="compute", tag="", attrs=None):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of indices -> value
Specifies the input source expression
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor
The created tensor
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
# for python3: coerce float shape values (e.g. from division) to int
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if isinstance(body, _tensor.TensorIntrinCall):
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
dim_var.append(tvm.tir.IterVar((0, s), var_name, 4))
op_node = _ffi_api.TensorComputeOp(name,
tag,
dim_var,
body.reduce_axis,
out_ndim,
body.intrin,
body.tensors,
body.regions,
body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _ffi_api.ComputeOp(
name, tag, attrs, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
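A minimal usage sketch (hypothetical names): the lambda receives one index variable per output dimension, and its body becomes the body of the ComputeOp constructed above:
.. code-block:: python
    n = tvm.var("n")
    A = tvm.placeholder((n, n), name="A")
    # Two lambda arguments -> out_ndim == 2; dim_var holds IterVars i, j
    B = tvm.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")
    assert B.op.num_outputs == 1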
def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
"""Construct new tensors by scanning over axis.
Parameters
----------
init: Tensor or list of Tensor
The initial condition for the first init.shape[0] time steps
update: Tensor or list of Tensor
The update rule of the scan given by symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
inputs: Tensor or list of Tensor, optional
The list of inputs to the scan. This is not required, but can
help the compiler detect the scan body faster.
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
The created tensor, or a tuple of tensors if it contains multiple outputs.
Example
-------
.. code-block:: python
# The following code is equivalent to numpy.cumsum
m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, X)
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
if inputs is None:
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = tvm.tir.IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _ffi_api.ScanOp(name, tag, attrs,
axis, init, update,
state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
def extern(shape,
inputs,
fcompute,
name="extern",
dtype=None,
in_buffers=None,
out_buffers=None,
tag="",
attrs=None):
"""Compute several tensor via extern function.
Parameters
----------
shape: tuple or list of tuples.
The shape of the outputs.
inputs: list of Tensor
The inputs
fcompute: lambda function of inputs, outputs -> stmt
Specifies the IR statement to do the computation.
See the following note for function signature of fcompute
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each inputs
- **outs** (list of :any:`Buffer`) - Placeholder for each outputs
**Returns**
- **stmt** (:any:`Stmt`) - The statement that carries out array computation.
name: str, optional
The name hint of the tensor
dtype: str or list of str, optional
The data types of outputs,
by default dtype will be same as inputs.
in_buffers: Buffer or list of Buffer, optional
Input buffers.
out_buffers: Buffer or list of Buffers, optional
Output buffers.
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
The created tensor, or a tuple of tensors if it contains multiple outputs.
Example
-------
In the code below, C is generated by calling external PackedFunc
`tvm.contrib.cblas.matmul`
.. code-block:: python
A = tvm.placeholder((n, l), name="A")
B = tvm.placeholder((l, m), name="B")
C = tvm.extern((n, m), [A, B],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (tvm.tir.PrimExpr, _Integral)) else shape
if shape == () or isinstance(shape[0], (tvm.tir.PrimExpr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
% (len(inputs), len(in_buffers)))
if out_buffers is not None:
out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
if len(shape) != len(out_buffers):
raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
% (len(shape), len(out_buffers)))
input_placeholders = in_buffers or []
output_placeholders = out_buffers or []
types = set()
for t in inputs:
if not isinstance(t, _tensor.Tensor):
raise ValueError("expect inputs to be tensor")
if in_buffers is None:
input_placeholders.append(
tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name))
types.add(t.dtype)
if dtype is None:
if len(types) != 1:
raise ValueError("Cannot infer output type, please provide dtype argument")
inferred_type = types.pop()
dtype = [inferred_type for _ in shape]
if isinstance(dtype, str):
dtype = [dtype]
if out_buffers is None:
for shp, dt in zip(shape, dtype):
output_placeholders.append(tvm.tir.decl_buffer(shp, dt, name))
body = fcompute(input_placeholders, output_placeholders)
if isinstance(body, tvm.tir.PrimExpr):
body = tvm.tir.Evaluate(body)
op = _ffi_api.ExternOp(name, tag, attrs,
inputs, input_placeholders,
output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
return res[0] if len(res) == 1 else res
def var(name="tindex", dtype="int32"):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : Var
The result symbolic variable.
"""
return tvm.tir.Var(name, dtype)
def size_var(name="size", dtype="int32"):
"""Create a new variable represents a tensor shape size, which is non-negative.
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : SizeVar
The result symbolic shape variable.
"""
return tvm.tir.SizeVar(name, dtype)
def thread_axis(dom=None, tag="", name=""):
"""Create a new IterVar to represent thread index.
Parameters
----------
dom : Range or str
The domain of iteration
When str is passed, dom is set to None and str is used as tag
tag : str, optional
The thread tag
name : str, optional
The name of the var.
Returns
-------
axis : IterVar
The thread itervar.
"""
if isinstance(dom, string_types):
tag, dom = dom, None
if not tag:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag
return tvm.tir.IterVar(dom, name, 1, tag)
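Both call forms in a short sketch; when a bare string is passed it becomes the tag, per the branch above:
.. code-block:: python
    # String-only form: the string is treated as the thread tag
    bx = tvm.thread_axis("blockIdx.x")
    # Explicit domain form: the (begin, end) pair is promoted to a Range
    tx = tvm.thread_axis((0, 64), "threadIdx.x")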
def reduce_axis(dom, name="rv"):
"""Create a new IterVar for reduction.
Parameters
----------
dom : Range
The domain of iteration.
name : str
The name of the variable.
Returns
-------
axis : IterVar
An iteration variable representing the value.
"""
return tvm.tir.IterVar(dom, name, 2)
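A typical reduction sketch combining reduce_axis with the tvm.sum reducer (illustrative names):
.. code-block:: python
    n = tvm.var("n")
    m = tvm.var("m")
    A = tvm.placeholder((n, m), name="A")
    k = tvm.reduce_axis((0, m), name="k")
    # Row-wise sum over the CommReduce IterVar k
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")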
...@@ -21,10 +21,10 @@ from tvm._ffi.base import string_types ...@@ -21,10 +21,10 @@ from tvm._ffi.base import string_types
from tvm.runtime import Object, convert from tvm.runtime import Object, convert
from tvm.ir import container as _container from tvm.ir import container as _container
from tvm.tir import expr as _expr, Buffer from tvm.tir import IterVar, Buffer
from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
from . import _ffi_api
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -42,31 +42,6 @@ class Singleton(Object): ...@@ -42,31 +42,6 @@ class Singleton(Object):
"""Singleton axis.""" """Singleton axis."""
@tvm._ffi.register_object
class IterVar(Object, _expr.ExprOp):
"""Represent iteration variable.
IterVar is normally created by Operation, to represent
axis iterations in the computation.
It can also created by schedule primitives like :any:`tvm.schedule.Stage.split`.
See Also
--------
tvm.thread_axis: Create thread axis IterVar.
tvm.reduce_axis: Create reduce axis IterVar.
"""
DataPar = 0
ThreadIndex = 1
CommReduce = 2
Ordered = 3
DimInfo = 4
Unrolled = 5
Vectorized = 6
Parallelized = 7
Tensorized = 8
_tensor.iter_var_cls = IterVar
def create_schedule(ops): def create_schedule(ops):
"""Create a schedule for list of ops """Create a schedule for list of ops
...@@ -82,7 +57,7 @@ def create_schedule(ops): ...@@ -82,7 +57,7 @@ def create_schedule(ops):
""" """
if not isinstance(ops, (list, _container.Array)): if not isinstance(ops, (list, _container.Array)):
ops = [ops] ops = [ops]
return _api_internal._CreateSchedule(ops) return _ffi_api.CreateSchedule(ops)
@tvm._ffi.register_object @tvm._ffi.register_object
...@@ -108,7 +83,7 @@ class Schedule(Object): ...@@ -108,7 +83,7 @@ class Schedule(Object):
sch : Schedule sch : Schedule
The normalized schedule. The normalized schedule.
""" """
return _api_internal._ScheduleNormalize(self) return _ffi_api.ScheduleNormalize(self)
def create_group(self, outputs, inputs, include_inputs=False): def create_group(self, outputs, inputs, include_inputs=False):
"""Create stage group by giving output and input boundary. """Create stage group by giving output and input boundary.
...@@ -137,7 +112,7 @@ class Schedule(Object): ...@@ -137,7 +112,7 @@ class Schedule(Object):
outputs = [outputs] outputs = [outputs]
if isinstance(inputs, _tensor.Tensor): if isinstance(inputs, _tensor.Tensor):
inputs = [inputs] inputs = [inputs]
return _api_internal._ScheduleCreateGroup( return _ffi_api.ScheduleCreateGroup(
self, outputs, inputs, include_inputs) self, outputs, inputs, include_inputs)
def cache_read(self, tensor, scope, readers): def cache_read(self, tensor, scope, readers):
...@@ -164,7 +139,7 @@ class Schedule(Object): ...@@ -164,7 +139,7 @@ class Schedule(Object):
if isinstance(readers, (_tensor.Tensor, _tensor.Operation)): if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
readers = [readers] readers = [readers]
readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers] readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
return _api_internal._ScheduleCacheRead(self, tensor, scope, readers) return _ffi_api.ScheduleCacheRead(self, tensor, scope, readers)
def cache_write(self, tensor, scope): def cache_write(self, tensor, scope):
"""Create a cache write of original tensor, before storing into tensor. """Create a cache write of original tensor, before storing into tensor.
...@@ -192,7 +167,7 @@ class Schedule(Object): ...@@ -192,7 +167,7 @@ class Schedule(Object):
cache : Tensor cache : Tensor
The created cache tensor. The created cache tensor.
""" """
return _api_internal._ScheduleCacheWrite(self, tensor, scope) return _ffi_api.ScheduleCacheWrite(self, tensor, scope)
def rfactor(self, tensor, axis, factor_axis=0): def rfactor(self, tensor, axis, factor_axis=0):
""" Factor a reduction axis in tensor's schedule to be an explicit axis. """ Factor a reduction axis in tensor's schedule to be an explicit axis.
...@@ -215,7 +190,7 @@ class Schedule(Object): ...@@ -215,7 +190,7 @@ class Schedule(Object):
tfactor : Tensor or Array of Tensor tfactor : Tensor or Array of Tensor
The created factored tensor. The created factored tensor.
""" """
factored = _api_internal._ScheduleRFactor(self, tensor, axis, factor_axis) factored = _ffi_api.ScheduleRFactor(self, tensor, axis, factor_axis)
return factored[0] if len(factored) == 1 else factored return factored[0] if len(factored) == 1 else factored
...@@ -247,11 +222,11 @@ class Stage(Object): ...@@ -247,11 +222,11 @@ class Stage(Object):
if nparts is not None: if nparts is not None:
if factor is not None: if factor is not None:
raise ValueError("Do not need to provide both outer and nparts") raise ValueError("Do not need to provide both outer and nparts")
outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts) outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts)
else: else:
if factor is None: if factor is None:
raise ValueError("Either nparts or factor need to be provided") raise ValueError("Either nparts or factor need to be provided")
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor) outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor)
return outer, inner return outer, inner
def fuse(self, *args): def fuse(self, *args):
...@@ -270,7 +245,7 @@ class Stage(Object): ...@@ -270,7 +245,7 @@ class Stage(Object):
fused : IterVar fused : IterVar
The fused variable of iteration. The fused variable of iteration.
""" """
fused = _api_internal._StageFuse(self, args) fused = _ffi_api.StageFuse(self, args)
return fused return fused
def set_scope(self, scope): def set_scope(self, scope):
...@@ -281,7 +256,7 @@ class Stage(Object): ...@@ -281,7 +256,7 @@ class Stage(Object):
scope : str scope : str
The thread scope of this stage The thread scope of this stage
""" """
return _api_internal._StageSetScope(self, scope) return _ffi_api.StageSetScope(self, scope)
def bind(self, ivar, thread_ivar): def bind(self, ivar, thread_ivar):
"""Bind ivar to thread index thread_ivar """Bind ivar to thread index thread_ivar
...@@ -294,7 +269,7 @@ class Stage(Object): ...@@ -294,7 +269,7 @@ class Stage(Object):
thread_ivar : IterVar thread_ivar : IterVar
The thread to be binded. The thread to be binded.
""" """
_api_internal._StageBind(self, ivar, thread_ivar) _ffi_api.StageBind(self, ivar, thread_ivar)
def env_threads(self, threads): def env_threads(self, threads):
"""Mark threads to be launched at the outer scope of composed op. """Mark threads to be launched at the outer scope of composed op.
...@@ -306,7 +281,7 @@ class Stage(Object): ...@@ -306,7 +281,7 @@ class Stage(Object):
""" """
if isinstance(threads, IterVar): if isinstance(threads, IterVar):
threads = [threads] threads = [threads]
_api_internal._StageEnvThreads(self, threads) _ffi_api.StageEnvThreads(self, threads)
def set_store_predicate(self, predicate): def set_store_predicate(self, predicate):
"""Set predicate under which store to the array can be performed. """Set predicate under which store to the array can be performed.
...@@ -319,7 +294,7 @@ class Stage(Object): ...@@ -319,7 +294,7 @@ class Stage(Object):
predicate : Expr predicate : Expr
The guard condition for store. The guard condition for store.
""" """
_api_internal._StageSetStorePredicate(self, predicate) _ffi_api.StageSetStorePredicate(self, predicate)
def compute_at(self, parent, scope): def compute_at(self, parent, scope):
"""Attach the stage at parent's scope """Attach the stage at parent's scope
...@@ -332,7 +307,7 @@ class Stage(Object): ...@@ -332,7 +307,7 @@ class Stage(Object):
scope : IterVar scope : IterVar
The loop scope to be attached to. The loop scope to be attached to.
""" """
_api_internal._StageComputeAt(self, parent, scope) _ffi_api.StageComputeAt(self, parent, scope)
def compute_inline(self): def compute_inline(self):
"""Mark stage as inline """Mark stage as inline
...@@ -342,7 +317,7 @@ class Stage(Object): ...@@ -342,7 +317,7 @@ class Stage(Object):
parent : Stage parent : Stage
The parent stage The parent stage
""" """
_api_internal._StageComputeInline(self) _ffi_api.StageComputeInline(self)
def compute_root(self): def compute_root(self):
"""Attach the stage at parent, and mark it as root """Attach the stage at parent, and mark it as root
...@@ -352,7 +327,7 @@ class Stage(Object): ...@@ -352,7 +327,7 @@ class Stage(Object):
parent : Stage parent : Stage
The parent stage The parent stage
""" """
_api_internal._StageComputeRoot(self) _ffi_api.StageComputeRoot(self)
def reorder(self, *args): def reorder(self, *args):
"""reorder the arguments in the specified order. """reorder the arguments in the specified order.
...@@ -362,7 +337,7 @@ class Stage(Object): ...@@ -362,7 +337,7 @@ class Stage(Object):
args : list of IterVar args : list of IterVar
The desired iteration order The desired iteration order
""" """
_api_internal._StageReorder(self, args) _ffi_api.StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor): def tile(self, x_parent, y_parent, x_factor, y_factor):
""" Perform tiling on two dimensions """ Perform tiling on two dimensions
...@@ -392,7 +367,7 @@ class Stage(Object): ...@@ -392,7 +367,7 @@ class Stage(Object):
p_y_inner : IterVar p_y_inner : IterVar
Inner axis of y dimension Inner axis of y dimension
""" """
x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile( x_outer, y_outer, x_inner, y_inner = _ffi_api.StageTile(
self, x_parent, y_parent, x_factor, y_factor) self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner return x_outer, y_outer, x_inner, y_inner
...@@ -404,7 +379,7 @@ class Stage(Object): ...@@ -404,7 +379,7 @@ class Stage(Object):
var : IterVar var : IterVar
The iteration to be vectorized The iteration to be vectorized
""" """
_api_internal._StageVectorize(self, var) _ffi_api.StageVectorize(self, var)
def tensorize(self, var, tensor_intrin): def tensorize(self, var, tensor_intrin):
"""Tensorize the computation enclosed by var with tensor_intrin """Tensorize the computation enclosed by var with tensor_intrin
...@@ -417,7 +392,7 @@ class Stage(Object): ...@@ -417,7 +392,7 @@ class Stage(Object):
tensor_intrin : TensorIntrin tensor_intrin : TensorIntrin
The tensor intrinsic used for computation. The tensor intrinsic used for computation.
""" """
_api_internal._StageTensorize(self, var, tensor_intrin) _ffi_api.StageTensorize(self, var, tensor_intrin)
def unroll(self, var): def unroll(self, var):
"""Unroll the iteration. """Unroll the iteration.
...@@ -427,7 +402,7 @@ class Stage(Object): ...@@ -427,7 +402,7 @@ class Stage(Object):
var : IterVar var : IterVar
The iteration to be unrolled. The iteration to be unrolled.
""" """
_api_internal._StageUnroll(self, var) _ffi_api.StageUnroll(self, var)
def parallel(self, var): def parallel(self, var):
"""Parallelize the iteration. """Parallelize the iteration.
...@@ -437,7 +412,7 @@ class Stage(Object): ...@@ -437,7 +412,7 @@ class Stage(Object):
var : IterVar var : IterVar
The iteration to be parallelized. The iteration to be parallelized.
""" """
_api_internal._StageParallel(self, var) _ffi_api.StageParallel(self, var)
def pragma(self, var, pragma_type, pragma_value=None): def pragma(self, var, pragma_type, pragma_value=None):
"""Annotate the iteration with pragma """Annotate the iteration with pragma
...@@ -489,7 +464,7 @@ class Stage(Object): ...@@ -489,7 +464,7 @@ class Stage(Object):
""" """
if isinstance(pragma_value, string_types): if isinstance(pragma_value, string_types):
pragma_value = convert(pragma_value) pragma_value = convert(pragma_value)
_api_internal._StagePragma(self, var, pragma_type, pragma_value) _ffi_api.StagePragma(self, var, pragma_type, pragma_value)
def prefetch(self, tensor, var, offset): def prefetch(self, tensor, var, offset):
"""Prefetch the specified variable """Prefetch the specified variable
...@@ -503,7 +478,7 @@ class Stage(Object): ...@@ -503,7 +478,7 @@ class Stage(Object):
offset : Expr offset : Expr
The number of iterations to be prefetched before actual execution The number of iterations to be prefetched before actual execution
""" """
_api_internal._StagePrefetch(self, tensor, var, offset) _ffi_api.StagePrefetch(self, tensor, var, offset)
def storage_align(self, axis, factor, offset): def storage_align(self, axis, factor, offset):
"""Set alignment requirement for specific axis """Set alignment requirement for specific axis
...@@ -523,7 +498,7 @@ class Stage(Object): ...@@ -523,7 +498,7 @@ class Stage(Object):
offset : int offset : int
The offset in the alignment specification. The offset in the alignment specification.
""" """
_api_internal._StageStorageAlign(self, axis, factor, offset) _ffi_api.StageStorageAlign(self, axis, factor, offset)
def double_buffer(self): def double_buffer(self):
"""Compute the current stage via double buffering. """Compute the current stage via double buffering.
...@@ -532,13 +507,14 @@ class Stage(Object): ...@@ -532,13 +507,14 @@ class Stage(Object):
This will double the storage cost of the current stage. This will double the storage cost of the current stage.
Can be useful to hide load latency. Can be useful to hide load latency.
""" """
_api_internal._StageDoubleBuffer(self) _ffi_api.StageDoubleBuffer(self)
def opengl(self): def opengl(self):
"""The special OpenGL schedule """The special OpenGL schedule
Maps each output element to a pixel. Maps each output element to a pixel.
""" """
_api_internal._StageOpenGL(self) _ffi_api.StageOpenGL(self)
tvm._ffi._init_api("tvm.schedule") tvm._ffi._init_api("schedule", __name__)
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
"""Tag class for TVM operators.""" """Tag class for TVM operators."""
import warnings import warnings
from ._ffi.base import decorate from tvm._ffi.base import decorate
class TagScope(object): class TagScope(object):
"""Tag scope object to set tag for operators, working as context """Tag scope object to set tag for operators, working as context
......
...@@ -14,15 +14,14 @@ ...@@ -14,15 +14,14 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Tensor and Operation class for computation declaration.""" """Tensor class for computation declaration."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
import tvm._ffi import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, convert_to_object from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr from tvm.tir import expr as _expr
from . import _api_internal from . import _ffi_api
class TensorSlice(ObjectGeneric, _expr.ExprOp): class TensorSlice(ObjectGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor.""" """Auxiliary data structure for enable slicing syntax from tensor."""
...@@ -52,9 +51,6 @@ class TensorIntrinCall(Object): ...@@ -52,9 +51,6 @@ class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic.""" """Intermediate structure for calling a tensor intrinsic."""
itervar_cls = None
@tvm._ffi.register_object @tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp): class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor""" """Tensor object, to construct, see function.Tensor"""
...@@ -68,7 +64,7 @@ class Tensor(Object, _expr.ExprOp): ...@@ -68,7 +64,7 @@ class Tensor(Object, _expr.ExprOp):
for x in indices: for x in indices:
if isinstance(x, _expr.PrimExpr): if isinstance(x, _expr.PrimExpr):
args.append(x) args.append(x)
elif isinstance(x, iter_var_cls): elif isinstance(x, _expr.IterVar):
args.append(x.var) args.append(x.var)
else: else:
raise ValueError("The indices must be expression") raise ValueError("The indices must be expression")
...@@ -81,7 +77,7 @@ class Tensor(Object, _expr.ExprOp): ...@@ -81,7 +77,7 @@ class Tensor(Object, _expr.ExprOp):
return TensorSlice(self, indices) return TensorSlice(self, indices)
def __hash__(self): def __hash__(self):
return _api_internal._TensorHash(self) return _ffi_api.TensorHash(self)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Tensor): if not isinstance(other, Tensor):
...@@ -92,7 +88,7 @@ class Tensor(Object, _expr.ExprOp): ...@@ -92,7 +88,7 @@ class Tensor(Object, _expr.ExprOp):
raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, " raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
"use Tensor.equal for content expression equvalence, " "use Tensor.equal for content expression equvalence, "
"use Tensor.same_as for exact reference comparison") "use Tensor.same_as for exact reference comparison")
return _api_internal._TensorEqual(self, other) return _ffi_api.TensorEqual(self, other)
@property @property
def ndim(self): def ndim(self):
...@@ -143,17 +139,17 @@ class Operation(Object): ...@@ -143,17 +139,17 @@ class Operation(Object):
out : Tensor out : Tensor
The i-th output. The i-th output.
""" """
return _api_internal._OpGetOutput(self, index) return _ffi_api.OpGetOutput(self, index)
@property @property
def num_outputs(self): def num_outputs(self):
"""Number of outputs from this op.""" """Number of outputs from this op."""
return _api_internal._OpNumOutputs(self) return _ffi_api.OpNumOutputs(self)
@property @property
def input_tensors(self): def input_tensors(self):
"""List of input tensors to this op.""" """List of input tensors to this op."""
return _api_internal._OpInputTensors(self) return _ffi_api.OpInputTensors(self)
@tvm._ffi.register_object @tvm._ffi.register_object
......
...@@ -16,17 +16,15 @@ ...@@ -16,17 +16,15 @@
# under the License. # under the License.
"""Tensor intrinsics""" """Tensor intrinsics"""
import tvm._ffi import tvm._ffi
import tvm.tir
from tvm.runtime import Object from tvm.runtime import Object, convert
from tvm.ir import Range from tvm.ir import Range
from tvm.tir import expr as _expr from tvm.target import BuildConfig
from tvm.tir import stmt as _stmt from .tensor import PlaceholderOp
from . import _api_internal
from . import api as _api
from . import tensor as _tensor from . import tensor as _tensor
from . import schedule as _schedule from . import _ffi_api
from .build_module import current_build_config
def _get_region(tslice): def _get_region(tslice):
...@@ -34,15 +32,16 @@ def _get_region(tslice): ...@@ -34,15 +32,16 @@ def _get_region(tslice):
for idx in tslice.indices: for idx in tslice.indices:
if isinstance(idx, slice): if isinstance(idx, slice):
assert idx.step is None assert idx.step is None
region.append(_api.Range(idx.start, idx.stop)) region.append(Range(idx.start, idx.stop))
else: else:
if isinstance(idx, _schedule.IterVar): if isinstance(idx, tvm.tir.IterVar):
begin = idx.var begin = idx.var
else: else:
begin = idx begin = idx
region.append(Range.make_by_min_extent(begin, 1)) region.append(Range.make_by_min_extent(begin, 1))
return region return region
@tvm._ffi.register_object @tvm._ffi.register_object
class TensorIntrin(Object): class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation. """Tensor intrinsic functions for certain computation.
...@@ -60,10 +59,11 @@ class TensorIntrin(Object): ...@@ -60,10 +59,11 @@ class TensorIntrin(Object):
reduce_axis = kwargs["reduce_axis"] reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)): if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis] reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis) reduce_axis = convert(reduce_axis)
if scalar_inputs: if scalar_inputs:
scalar_inputs = _api.convert(scalar_inputs) scalar_inputs = convert(scalar_inputs)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs) return _ffi_api.TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
def decl_tensor_intrin(op, def decl_tensor_intrin(op,
fcompute, fcompute,
...@@ -119,15 +119,15 @@ def decl_tensor_intrin(op, ...@@ -119,15 +119,15 @@ def decl_tensor_intrin(op,
binds_list = [] binds_list = []
for t in inputs: for t in inputs:
if not isinstance(t.op, _tensor.PlaceholderOp): if not isinstance(t.op, PlaceholderOp):
raise ValueError("Do not yet support composition op") raise ValueError("Do not yet support composition op")
cfg = current_build_config() cfg = BuildConfig.current()
for t in tensors: for t in tensors:
buf = (binds[t] if t in binds else buf = (binds[t] if t in binds else
_api.decl_buffer(t.shape, t.dtype, t.op.name, tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment, data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor)) offset_factor=cfg.offset_factor))
binds_list.append(buf) binds_list.append(buf)
if scalar_params: if scalar_params:
...@@ -135,10 +135,10 @@ def decl_tensor_intrin(op, ...@@ -135,10 +135,10 @@ def decl_tensor_intrin(op,
else: else:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):]) body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
scalar_params = [] scalar_params = []
if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)): if isinstance(body, (tvm.tir.PrimExpr, tvm.tir.Stmt)):
body = [body] body = [body]
body = [_stmt.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body] body = [tvm.tir.Evaluate(x) if isinstance(x, tvm.tir.PrimExpr) else x for x in body]
if len(body) < 3: if len(body) < 3:
body += [None] * (3 - len(body)) body += [None] * (3 - len(body))
return _api_internal._TensorIntrin( return _ffi_api.TensorIntrin(
name, op, inputs, binds_list, scalar_params, *body) name, op, inputs, binds_list, scalar_params, *body)
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
""" TVM testing utilities """ """ TVM testing utilities """
import logging import logging
import numpy as np import numpy as np
import tvm._ffi
def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
""" Version of np.testing.assert_allclose with `atol` and `rtol` fields set """ Version of np.testing.assert_allclose with `atol` and `rtol` fields set
...@@ -161,3 +163,6 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No ...@@ -161,3 +163,6 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
logging.info("Numerical grad test wrt '%s' of shape %s passes, " logging.info("Numerical grad test wrt '%s' of shape %s passes, "
"dist = %f, max_diff = %f, avg_diff = %f", "dist = %f, max_diff = %f, avg_diff = %f",
x_name, grad.shape, dist, max_diff, avg_diff) x_name, grad.shape, dist, max_diff, avg_diff)
tvm._ffi._init_api("testing", __name__)
...@@ -23,16 +23,18 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast ...@@ -23,16 +23,18 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, min_value, max_value from .op import call_llvm_intrin, all, any, min_value, max_value
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from . import ir_builder from . import ir_builder
from . import ir_pass from . import ir_pass
...@@ -309,6 +309,57 @@ class SizeVar(Var): ...@@ -309,6 +309,57 @@ class SizeVar(Var):
@tvm._ffi.register_object @tvm._ffi.register_object
class IterVar(Object, ExprOp):
"""Represent iteration variable.
IterVar represents axis iterations in the computation.
Parameters
----------
dom : Range
The domain of the iteration.
var : Union[Var, str]
The internal variable that is used for iteration.
iter_type : int
The iteration type.
thread_tag : str
The thread type tag.
See Also
--------
tvm.thread_axis: Create thread axis IterVar.
tvm.reduce_axis: Create reduce axis IterVar.
"""
DataPar = 0
ThreadIndex = 1
CommReduce = 2
Ordered = 3
DimInfo = 4
Unrolled = 5
Vectorized = 6
Parallelized = 7
Tensorized = 8
def __init__(self, dom, var, iter_type, thread_tag=""):
if dom is not None:
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
raise TypeError("need to be list of ranges")
dom = tvm.ir.Range(dom[0], dom[1])
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = var if var is not None else "iter"
var = Var(name, dtype="int32") if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
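A short construction sketch; note how the (begin, end) pair is converted to a Range by the constructor above:
.. code-block:: python
    # Data-parallel axis over [0, 16)
    iv = tvm.tir.IterVar((0, 16), "i", tvm.tir.IterVar.DataPar)
    # Reduction axis; this is what tvm.reduce_axis((0, 16), "k") creates
    rv = tvm.tir.IterVar((0, 16), "k", tvm.tir.IterVar.CommReduce)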
@tvm._ffi.register_object
class CommReducer(Object): class CommReducer(Object):
"""Communicative reduce operator """Communicative reduce operator
......
...@@ -14,13 +14,14 @@ ...@@ -14,13 +14,14 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin, invalid-name
"""Operators used in TIR expression.""" """Operators used in TIR expression."""
import tvm._ffi import tvm._ffi
from tvm.runtime import convert, const from tvm.runtime import convert, const
from tvm.schedule import Buffer from tvm.ir import Array
from .expr import Call from .buffer import Buffer
from .expr import Call, Var, CommReducer
from . import _ffi_api from . import _ffi_api
...@@ -196,6 +197,53 @@ def call_llvm_intrin(dtype, name, *args): ...@@ -196,6 +197,53 @@ def call_llvm_intrin(dtype, name, *args):
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args) return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpOr(args[0], args[1])
for i in range(2, len(args)):
ret = _ffi_api._OpOr(ret, args[i])
return ret
def all(*args):
"""Create a new experssion of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpAnd(args[0], args[1])
for i in range(2, len(args)):
ret = _ffi_api._OpAnd(ret, args[i])
return ret
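Both helpers simply fold the binary Or/And primitives over the argument list; a short sketch:
.. code-block:: python
    i = tvm.var("i")
    j = tvm.var("j")
    # (i > 0) || (j > 0)
    cond_any = tvm.any(i > 0, j > 0)
    # (i > 0) && (j > 0) && (i + j > 0)
    cond_all = tvm.all(i > 0, j > 0, i + j > 0)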
@tvm._ffi.register_func("tvm.default_trace_action") @tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args): def _tvm_default_trace_action(*args):
print(list(args)) print(list(args))
...@@ -780,3 +828,137 @@ def floormod(a, b): ...@@ -780,3 +828,137 @@ def floormod(a, b):
The result expression. The result expression.
""" """
return _ffi_api._OpFloorMod(a, b) return _ffi_api._OpFloorMod(a, b)
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
Parameters
----------
fcombine : function(Expr -> Expr -> Expr)
A binary function which takes two Exprs as input and returns an Expr.
fidentity : function(str -> Expr)
A function which takes a type string as input to return a const Expr.
Returns
-------
reducer : function
A function which creates a reduce expression over axis.
There are two ways to use it:
1. accept (expr, axis, where) to produce a Reduce Expr on
specified axis;
2. simply use it with multiple Exprs.
Example
-------
.. code-block:: python
n = tvm.var("n")
m = tvm.var("m")
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name="A")
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
"""
def _reduce_directly(*args):
num = len(args)
# handle the case where `where` is None
if num == 3 and args[2] is None:
num = 2
res = args[0]
for i in range(num-1):
res = fcombine(res, args[i+1])
return res
def _make_reduce(expr, axis, where=None):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
if isinstance(expr, Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + "_" + str(i)
larr.append(Var(lname, dtype))
rname = code.co_varnames[1] + "_" + str(i)
rarr.append(Var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, tvm.ir.PrimExpr)
size = 1
dtype = expr.dtype
lvar = Var(code.co_varnames[0], dtype)
rvar = Var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = convert(id_elem)
combiner = CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
# pylint: disable=keyword-arg-before-vararg
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where)
if where is None:
assert not args
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : PrimExpr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : PrimExpr
The result value.
Example
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
# there are two ways to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = tvm.{0}(m, n)
"""
reducer.__doc__ = doc_str.format(name)
return reducer
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y), max_value, name="min")
max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y), min_value, name="max")
...@@ -64,16 +64,16 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound") ...@@ -64,16 +64,16 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound")
TVM_REGISTER_GLOBAL("arith.DomainTouched") TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched); .set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("_IntervalSetGetMin") TVM_REGISTER_GLOBAL("arith._IntervalSetGetMin")
.set_body_method(&IntSet::min); .set_body_method(&IntSet::min);
TVM_REGISTER_GLOBAL("_IntervalSetGetMax") TVM_REGISTER_GLOBAL("arith._IntervalSetGetMax")
.set_body_method(&IntSet::max); .set_body_method(&IntSet::max);
TVM_REGISTER_GLOBAL("_IntSetIsNothing") TVM_REGISTER_GLOBAL("arith._IntSetIsNothing")
.set_body_method(&IntSet::is_nothing); .set_body_method(&IntSet::is_nothing);
TVM_REGISTER_GLOBAL("_IntSetIsEverything") TVM_REGISTER_GLOBAL("arith._IntSetIsEverything")
.set_body_method(&IntSet::is_everything); .set_body_method(&IntSet::is_everything);
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
......
...@@ -40,115 +40,113 @@ TVM_REGISTER_GLOBAL("tir.min_value") ...@@ -40,115 +40,113 @@ TVM_REGISTER_GLOBAL("tir.min_value")
TVM_REGISTER_GLOBAL("tir.max_value") TVM_REGISTER_GLOBAL("tir.max_value")
.set_body_typed(max_value); .set_body_typed(max_value);
TVM_REGISTER_GLOBAL("Range") TVM_REGISTER_GLOBAL("ir.Range")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { *ret = Range(args[0], args[1]);
*ret = Range(0, args[0]);
} else {
*ret = Range(args[0], args[1]);
}
}); });
namespace tir {
TVM_REGISTER_GLOBAL("tir.IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
}
namespace te { namespace te {
TVM_REGISTER_GLOBAL("_Tensor") TVM_REGISTER_GLOBAL("te.Tensor")
.set_body_typed(TensorNode::make); .set_body_typed(TensorNode::make);
TVM_REGISTER_GLOBAL("_TensorIntrin") TVM_REGISTER_GLOBAL("te.TensorIntrin")
.set_body_typed(TensorIntrinNode::make); .set_body_typed(TensorIntrinNode::make);
TVM_REGISTER_GLOBAL("_TensorIntrinCall") TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
.set_body_typed(TensorIntrinCallNode::make); .set_body_typed(TensorIntrinCallNode::make);
TVM_REGISTER_GLOBAL("_TensorEqual") TVM_REGISTER_GLOBAL("te.TensorEqual")
.set_body_method(&Tensor::operator==); .set_body_method(&Tensor::operator==);
TVM_REGISTER_GLOBAL("_TensorHash") TVM_REGISTER_GLOBAL("te.TensorHash")
.set_body_typed([](Tensor tensor) -> int64_t { .set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor)); return static_cast<int64_t>(std::hash<Tensor>()(tensor));
}); });
TVM_REGISTER_GLOBAL("_Placeholder") TVM_REGISTER_GLOBAL("te.Placeholder")
.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) { .set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name); return placeholder(shape, dtype, name);
}); });
TVM_REGISTER_GLOBAL("_ComputeOp") TVM_REGISTER_GLOBAL("te.ComputeOp")
.set_body_typed(ComputeOpNode::make); .set_body_typed(ComputeOpNode::make);
TVM_REGISTER_GLOBAL("_ScanOp") TVM_REGISTER_GLOBAL("te.ScanOp")
.set_body_typed(ScanOpNode::make); .set_body_typed(ScanOpNode::make);
TVM_REGISTER_GLOBAL("_TensorComputeOp") TVM_REGISTER_GLOBAL("te.TensorComputeOp")
.set_body_typed(TensorComputeOpNode::make); .set_body_typed(TensorComputeOpNode::make);
TVM_REGISTER_GLOBAL("_ExternOp") TVM_REGISTER_GLOBAL("te.ExternOp")
.set_body_typed(ExternOpNode::make); .set_body_typed(ExternOpNode::make);
TVM_REGISTER_GLOBAL("_HybridOp") TVM_REGISTER_GLOBAL("te.HybridOp")
.set_body_typed(HybridOpNode::make); .set_body_typed(HybridOpNode::make);
TVM_REGISTER_GLOBAL("_OpGetOutput") TVM_REGISTER_GLOBAL("te.OpGetOutput")
.set_body_typed([](Operation op, int64_t output) { .set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output)); return op.output(static_cast<size_t>(output));
}); });
TVM_REGISTER_GLOBAL("_OpNumOutputs") TVM_REGISTER_GLOBAL("te.OpNumOutputs")
.set_body_method<Operation>(&OperationNode::num_outputs); .set_body_method<Operation>(&OperationNode::num_outputs);
TVM_REGISTER_GLOBAL("_OpInputTensors") TVM_REGISTER_GLOBAL("te.OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors); .set_body_method<Operation>(&OperationNode::InputTensors);
TVM_REGISTER_GLOBAL("_IterVar") TVM_REGISTER_GLOBAL("te.CreateSchedule")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
TVM_REGISTER_GLOBAL("_CreateSchedule")
.set_body_typed(create_schedule); .set_body_typed(create_schedule);
TVM_REGISTER_GLOBAL("_StageSetScope") TVM_REGISTER_GLOBAL("te.StageSetScope")
.set_body_method(&Stage::set_scope); .set_body_method(&Stage::set_scope);
TVM_REGISTER_GLOBAL("_StageBind") TVM_REGISTER_GLOBAL("te.StageBind")
.set_body_method(&Stage::bind); .set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("_StageSplitByFactor") TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
IterVar outer, inner; IterVar outer, inner;
stage.split(parent, factor, &outer, &inner); stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner}); return Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_GLOBAL("_StageSplitByNParts") TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
IterVar outer, inner; IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner); stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner}); return Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_GLOBAL("_StageFuse") TVM_REGISTER_GLOBAL("te.StageFuse")
.set_body_typed([](Stage stage, Array<IterVar> axes) { .set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused; IterVar fused;
stage.fuse(axes, &fused); stage.fuse(axes, &fused);
return fused; return fused;
}); });
TVM_REGISTER_GLOBAL("_StageComputeAt") TVM_REGISTER_GLOBAL("te.StageComputeAt")
.set_body_method(&Stage::compute_at); .set_body_method(&Stage::compute_at);
TVM_REGISTER_GLOBAL("_StageComputeInline") TVM_REGISTER_GLOBAL("te.StageComputeInline")
.set_body_method(&Stage::compute_inline); .set_body_method(&Stage::compute_inline);
TVM_REGISTER_GLOBAL("_StageComputeRoot") TVM_REGISTER_GLOBAL("te.StageComputeRoot")
.set_body_method(&Stage::compute_root); .set_body_method(&Stage::compute_root);
TVM_REGISTER_GLOBAL("_StageReorder") TVM_REGISTER_GLOBAL("te.StageReorder")
.set_body_method(&Stage::reorder); .set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("_StageTile") TVM_REGISTER_GLOBAL("te.StageTile")
.set_body_typed([]( .set_body_typed([](
Stage stage, Stage stage,
IterVar x_parent, IterVar y_parent, IterVar x_parent, IterVar y_parent,
...@@ -162,49 +160,49 @@ TVM_REGISTER_GLOBAL("_StageTile") ...@@ -162,49 +160,49 @@ TVM_REGISTER_GLOBAL("_StageTile")
return Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
}); });
TVM_REGISTER_GLOBAL("_StageEnvThreads") TVM_REGISTER_GLOBAL("te.StageEnvThreads")
.set_body_method(&Stage::env_threads); .set_body_method(&Stage::env_threads);
TVM_REGISTER_GLOBAL("_StageSetStorePredicate") TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
.set_body_method(&Stage::set_store_predicate); .set_body_method(&Stage::set_store_predicate);
TVM_REGISTER_GLOBAL("_StageUnroll") TVM_REGISTER_GLOBAL("te.StageUnroll")
.set_body_method(&Stage::unroll); .set_body_method(&Stage::unroll);
TVM_REGISTER_GLOBAL("_StageVectorize") TVM_REGISTER_GLOBAL("te.StageVectorize")
.set_body_method(&Stage::vectorize); .set_body_method(&Stage::vectorize);
TVM_REGISTER_GLOBAL("_StageTensorize") TVM_REGISTER_GLOBAL("te.StageTensorize")
.set_body_method(&Stage::tensorize); .set_body_method(&Stage::tensorize);
TVM_REGISTER_GLOBAL("_StageParallel") TVM_REGISTER_GLOBAL("te.StageParallel")
.set_body_method(&Stage::parallel); .set_body_method(&Stage::parallel);
TVM_REGISTER_GLOBAL("_StagePragma") TVM_REGISTER_GLOBAL("te.StagePragma")
.set_body_method(&Stage::pragma); .set_body_method(&Stage::pragma);
TVM_REGISTER_GLOBAL("_StagePrefetch") TVM_REGISTER_GLOBAL("te.StagePrefetch")
.set_body_method(&Stage::prefetch); .set_body_method(&Stage::prefetch);
TVM_REGISTER_GLOBAL("_StageStorageAlign") TVM_REGISTER_GLOBAL("te.StageStorageAlign")
.set_body_method(&Stage::storage_align); .set_body_method(&Stage::storage_align);
TVM_REGISTER_GLOBAL("_StageDoubleBuffer") TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
.set_body_method(&Stage::double_buffer); .set_body_method(&Stage::double_buffer);
TVM_REGISTER_GLOBAL("_StageOpenGL") TVM_REGISTER_GLOBAL("te.StageOpenGL")
.set_body_method(&Stage::opengl); .set_body_method(&Stage::opengl);
TVM_REGISTER_GLOBAL("_ScheduleNormalize") TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
.set_body_method(&Schedule::normalize); .set_body_method(&Schedule::normalize);
TVM_REGISTER_GLOBAL("_ScheduleCreateGroup") TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
.set_body_method(&Schedule::create_group); .set_body_method(&Schedule::create_group);
TVM_REGISTER_GLOBAL("_ScheduleCacheRead") TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
.set_body_method(&Schedule::cache_read); .set_body_method(&Schedule::cache_read);
TVM_REGISTER_GLOBAL("_ScheduleCacheWrite") TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[1].IsObjectRef<Tensor>()) { if (args[1].IsObjectRef<Tensor>()) {
*ret = args[0].operator Schedule() *ret = args[0].operator Schedule()
...@@ -215,11 +213,11 @@ TVM_REGISTER_GLOBAL("_ScheduleCacheWrite") ...@@ -215,11 +213,11 @@ TVM_REGISTER_GLOBAL("_ScheduleCacheWrite")
} }
}); });
TVM_REGISTER_GLOBAL("_ScheduleRFactor") TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
.set_body_method(&Schedule::rfactor); .set_body_method(&Schedule::rfactor);
} // namespace te } // namespace te
TVM_REGISTER_GLOBAL("_CommReducerCombine") TVM_REGISTER_GLOBAL("te.CommReducerCombine")
.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator()); .set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
} // namespace tvm } // namespace tvm
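On the Python side, these prefixed names are what `tvm._ffi._init_api("te", __name__)` picks up; a quick registry check (assuming a built runtime):
.. code-block:: python
    import tvm
    # Renamed from the old "_CreateSchedule" global
    assert tvm.get_global_func("te.CreateSchedule") is not None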
...@@ -47,9 +47,9 @@ TVM_REGISTER_GLOBAL("schedule.ScheduleOps") ...@@ -47,9 +47,9 @@ TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
*ret = ScheduleOps(args[0], args[1], args[2]); *ret = ScheduleOps(args[0], args[1], args[2]);
}); });
#define REGISTER_SCHEDULE_PASS(PassName) \ #define REGISTER_SCHEDULE_PASS(PassName) \
TVM_REGISTER_GLOBAL("schedule."#PassName) \ TVM_REGISTER_GLOBAL("schedule."#PassName) \
.set_body_typed(PassName); \ .set_body_typed(PassName); \
REGISTER_SCHEDULE_PASS(InferBound); REGISTER_SCHEDULE_PASS(InferBound);
......
@@ -54,11 +54,11 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
 TVM_REGISTER_NODE_TYPE(TestAttrs);
-TVM_REGISTER_GLOBAL("_nop")
+TVM_REGISTER_GLOBAL("testing.nop")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
 });
-TVM_REGISTER_GLOBAL("_test_wrap_callback")
+TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   PackedFunc pf = args[0];
   *ret = runtime::TypedPackedFunc<void()>([pf](){
@@ -66,7 +66,7 @@ TVM_REGISTER_GLOBAL("_test_wrap_callback")
   });
 });
-TVM_REGISTER_GLOBAL("_test_raise_error_callback")
+TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   std::string msg = args[0];
   *ret = runtime::TypedPackedFunc<void()>([msg](){
@@ -74,7 +74,7 @@ TVM_REGISTER_GLOBAL("_test_raise_error_callback")
   });
 });
-TVM_REGISTER_GLOBAL("_test_check_eq_callback")
+TVM_REGISTER_GLOBAL("testing.test_check_eq_callback")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   std::string msg = args[0];
   *ret = runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y){
@@ -82,7 +82,7 @@ TVM_REGISTER_GLOBAL("_test_check_eq_callback")
   });
 });
-TVM_REGISTER_GLOBAL("_context_test")
+TVM_REGISTER_GLOBAL("testing.context_test")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   DLContext ctx = args[0];
   int dtype = args[1];
@@ -103,11 +103,11 @@ void ErrorTest(int x, int y) {
   }
 }
-TVM_REGISTER_GLOBAL("_ErrorTest")
+TVM_REGISTER_GLOBAL("testing.ErrorTest")
 .set_body_typed(ErrorTest);
 // internal function used for debug and testing purposes
-TVM_REGISTER_GLOBAL("_ndarray_use_count")
+TVM_REGISTER_GLOBAL("testing.ndarray_use_count")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   runtime::NDArray nd = args[0];
   // substract the current one
...
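On the Python side these callbacks are now reachable as members of tvm.testing, which is populated from the testing. registry namespace. A minimal sketch of the wrap-callback round trip:

    import tvm
    import tvm.testing

    def error_callback():
        raise ValueError("callback error")

    # test_wrap_callback returns a PackedFunc that invokes the Python callback.
    wrapped = tvm.testing.test_wrap_callback(error_callback)
    try:
        wrapped()
    except Exception:
        pass  # the error raised in the callback propagates back through the FFI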
@@ -403,7 +403,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
   p->stream << ")";
 });
-TVM_REGISTER_GLOBAL("_GetCurrentBuildConfig")
+TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   *ret = BuildConfig::Current();
 });
@@ -418,13 +418,13 @@ class BuildConfig::Internal {
   }
 };
-TVM_REGISTER_GLOBAL("_EnterBuildConfigScope")
+TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope")
 .set_body_typed(BuildConfig::Internal::EnterScope);
-TVM_REGISTER_GLOBAL("_ExitBuildConfigScope")
+TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
 .set_body_typed(BuildConfig::Internal::ExitScope);
-TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
+TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   BuildConfig cfg = args[0];
   std::vector< std::pair<int, PackedFunc> > add_lower_pass;
@@ -437,7 +437,7 @@ TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
   cfg->add_lower_pass = add_lower_pass;
 });
-TVM_REGISTER_GLOBAL("_BuildConfigGetAddLowerPassInfo")
+TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   // Return one of the following:
   // * Size of add_lower_pass if num_args == 1
...
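These globals back tvm.target.build_config and tvm.target.BuildConfig on the Python side. A small sketch of entering a config scope and reading it back; the no-op pass is purely illustrative:

    import tvm

    def noop_pass(stmt):
        # hypothetical lowering pass used only for illustration
        return stmt

    with tvm.target.build_config(add_lower_pass=[(1, noop_pass)]):
        cfg = tvm.target.BuildConfig.current()
        assert cfg.add_lower_pass  # the scoped config carries the custom pass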
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm.schedule import Buffer
+from tvm.tir import Buffer
 import numpy as np
 def test_buffer():
@@ -25,7 +25,7 @@ def test_buffer():
     Ab = tvm.decl_buffer((m, n), tvm.float32)
     Bb = tvm.decl_buffer((n, l), tvm.float32)
-    assert isinstance(Ab, tvm.schedule.Buffer)
+    assert isinstance(Ab, tvm.tir.Buffer)
     assert Ab.dtype == tvm.float32
     assert tuple(Ab.shape) == (m, n)
...
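Buffer therefore moves from tvm.schedule to tvm.tir; a minimal sketch of the updated import path:

    import tvm
    from tvm.tir import Buffer

    n = tvm.var("n")
    Ab = tvm.decl_buffer((n,), "float32", name="A")
    assert isinstance(Ab, Buffer)  # decl_buffer now yields a tvm.tir.Buffer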
@@ -22,8 +22,8 @@ def test_expr_constructor():
     assert x.name == "xx"
     x = tvm.tir.Reduce(None, [1],
-                       [tvm.api._IterVar((0, 1), "x", 2)],
+                       [tvm.tir.IterVar((0, 1), "x", 2)],
                        None, 0)
     assert isinstance(x, tvm.tir.Reduce)
     assert x.combiner == None
     assert x.value_index == 0
...
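IterVar is likewise constructed through tvm.tir now. Mirroring the updated test, a quick sketch (iteration type 2 is the commutative-reduction kind, kCommReduce):

    import tvm

    # The (0, 1) tuple is converted to a Range; 2 selects kCommReduce.
    iv = tvm.tir.IterVar((0, 1), "x", 2)
    assert iv.var.name == "x"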
@@ -16,9 +16,10 @@
 # under the License.
 """Test runtime error handling"""
 import tvm
+import tvm.testing
 def test_op_translation():
-    ferror = tvm._api_internal._test_raise_error_callback(
+    ferror = tvm.testing.test_raise_error_callback(
         "OpNotImplemented: myop")
     try:
         ferror()
@@ -28,7 +29,7 @@ def test_op_translation():
         assert isinstance(e, NotImplementedError)
         assert msg.find("api_test.cc") != -1
-    fchk_eq = tvm._api_internal._test_check_eq_callback(
+    fchk_eq = tvm.testing.test_check_eq_callback(
         "InternalError: myop")
     try:
         fchk_eq(0, 1)
@@ -38,7 +39,7 @@ def test_op_translation():
     assert msg.find("api_test.cc") != -1
     try:
-        tvm._api_internal._ErrorTest(0, 1)
+        tvm.testing.ErrorTest(0, 1)
         assert False
     except ValueError as e:
         msg = str(e)
@@ -48,13 +49,13 @@ def test_op_translation():
 def test_deep_callback():
     def error_callback():
         raise ValueError("callback error")
-    wrap1 = tvm._api_internal._test_wrap_callback(error_callback)
+    wrap1 = tvm.testing.test_wrap_callback(error_callback)
     def flevel2():
         wrap1()
-    wrap2 = tvm._api_internal._test_wrap_callback(flevel2)
+    wrap2 = tvm.testing.test_wrap_callback(flevel2)
     def flevel3():
         wrap2()
-    wrap3 = tvm._api_internal._test_wrap_callback(flevel3)
+    wrap3 = tvm.testing.test_wrap_callback(flevel3)
     try:
         wrap3()
...
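The error tests depend on TVM mapping a registered message prefix back to the matching Python exception type when an error crosses the FFI boundary. A compact sketch using the renamed hook:

    import tvm
    import tvm.testing

    # The "ValueError:" prefix is translated back into a Python ValueError.
    ferror = tvm.testing.test_raise_error_callback("ValueError: myop")
    try:
        ferror()
        assert False
    except ValueError as e:
        assert "myop" in str(e)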
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import tvm.testing
 import numpy as np
 def test_get_global():
@@ -93,7 +94,7 @@ def test_ctx():
     x = test_ctx_func(tvm.gpu(7))
     assert x == tvm.cpu(0)
     x = tvm.opencl(10)
-    x = tvm._api_internal._context_test(x, x.device_type, x.device_id)
+    x = tvm.testing.context_test(x, x.device_type, x.device_id)
     assert x == tvm.opencl(10)
 def test_trace_default_action():
@@ -282,4 +283,3 @@ if __name__ == "__main__":
     test_trace_default_action()
     test_trace_can_change_traced_value_int()
     test_trace_can_change_traced_value_float()
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import tvm.testing
 import os
 import logging
 import time
@@ -210,7 +211,7 @@ def test_rpc_return_ndarray():
     if name == "get_arr":
         return lambda : nd
     elif name == "ref_count":
-        return lambda : tvm._api_internal._ndarray_use_count(nd)
+        return lambda : tvm.testing.ndarray_use_count(nd)
     elif name == "get_elem":
         return lambda idx: nd.asnumpy()[idx]
     elif name == "get_arr_elem":
...
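ndarray_use_count reports how many references hold the underlying NDArray, which the RPC test uses to check for leaks. A standalone sketch (assuming the callee subtracts its own temporary reference, per the comment in the registration above):

    import numpy as np
    import tvm
    import tvm.testing

    nd = tvm.nd.array(np.zeros(3, dtype="float32"))
    # A single live Python handle should report a use count of 1.
    assert tvm.testing.ndarray_use_count(nd) == 1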
@@ -96,7 +96,7 @@ def lower(*args, **kwargs):
     --------
     tvm.lower : The original TVM's lower function
     """
-    cfg = tvm.build_module.current_build_config()
+    cfg = tvm.target.BuildConfig.current()
     if not cfg.add_lower_pass:
         with build_config():
             return tvm.lower(*args, **kwargs)
@@ -113,7 +113,7 @@ def build(*args, **kwargs):
     --------
     tvm.build : The original TVM's build function
     """
-    cfg = tvm.build_module.current_build_config()
+    cfg = tvm.target.BuildConfig.current()
     if not cfg.add_lower_pass:
         with build_config():
             return tvm.build(*args, **kwargs)
...
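The wrappers' logic is untouched; only the accessor for the active build configuration changes. As a migration sketch (the old accessor is gone after this refactor):

    import tvm

    # Before: cfg = tvm.build_module.current_build_config()
    # After:
    cfg = tvm.target.BuildConfig.current()
    if not cfg.add_lower_pass:
        pass  # no custom lowering passes are registered in this scope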