Unverified Commit 08338dd5 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] Establish tvm.te and tvm.driver (#4900)

- Move the related files to tvm.te
- Move build_module.py to tvm.driver
parent 27a02844
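For downstream code the refactor is meant to be a pure move: tensor/schedule declaration APIs land in tvm.te and the build/lower drivers in tvm.driver, while the re-exports in tvm/__init__.py below keep the old top-level names working. A minimal sketch of the intended import layout (illustrative, not exhaustive):
.. code-block:: python

    import tvm

    # new homes introduced by this refactor
    from tvm.te import placeholder, compute, create_schedule
    from tvm.driver import build, lower

    # backward-compatible aliases kept at the top level for existing topi/user code
    A = tvm.placeholder((16,), name="A")
    s = tvm.create_schedule(A.op)  # still reachable via the re-exports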
......@@ -47,25 +47,30 @@ from . import tir
# tvm.target
from . import target
from .target import build_config
# others
from . import tensor
from . import arith
from . import make
from . import schedule
from . import hybrid
# tvm.te
from .te import decl_tensor_intrin, create_schedule, tag_scope
# tvm.testing
from . import testing
from .api import *
from .tensor_intrin import decl_tensor_intrin
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
# tvm.driver
from .driver import build, lower
# tvm.hybrid
from . import hybrid
# others
from . import arith
# backward compact for topi, to be removed later
from .api import *
from .tir import expr, stmt, ir_builder, ir_pass, generic
from .te import tensor, schedule
from .tir.op import *
from . import intrin
from . import make
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
......
......@@ -16,623 +16,23 @@
# under the License.
"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
from numbers import Integral as _Integral
import tvm._ffi
import tvm.ir
import tvm.tir
from tvm.runtime import convert, const, DataType
from tvm.ir import container as _container
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.ir import container as _container, Range
from tvm.tir import decl_buffer, layout, bijective_layout
from tvm.tir import min_value, max_value, indexdiv, indexmod
import tvm.tir._ffi_api
from tvm.tir import min_value, max_value, indexdiv, indexmod, all, any
from tvm.te import placeholder, compute, scan, extern, var, size_var, thread_axis, reduce_axis
from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from . import _api_internal
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from . import tag as _tag
int8 = "int8"
int32 = "int32"
float32 = "float32"
handle = "handle"
def var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : Var
The result symbolic variable.
"""
return _expr.Var(name, dtype)
def size_var(name="size", dtype=int32):
"""Create a new variable represents a tensor shape size, which is non-negative.
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : SizeVar
The result symbolic shape variable.
"""
return _expr.SizeVar(name, dtype)
def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = tvm.tir._ffi_api._OpOr(args[0], args[1])
for i in range(2, len(args)):
ret = tvm.tir._ffi_api._OpOr(ret, args[i])
return ret
def all(*args):
"""Create a new experssion of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = tvm.tir._ffi_api._OpAnd(args[0], args[1])
for i in range(2, len(args)):
ret = tvm.tir._ffi_api._OpAnd(ret, args[i])
return ret
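A hedged usage sketch of these helpers: tvm.any/tvm.all fold their boolean arguments with OR/AND, which is the common way to build padding predicates for tvm.compute (the top-level if_then_else intrinsic is used here for illustration):
.. code-block:: python

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n, n), name="A")
    # zero-pad A by one element on each side; tvm.all combines the four bound checks
    Apad = tvm.compute(
        (n + 2, n + 2),
        lambda i, j: tvm.if_then_else(
            tvm.all(i >= 1, i < n + 1, j >= 1, j < n + 1),
            A[i - 1, j - 1],
            tvm.const(0.0, A.dtype)),
        name="Apad")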
def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
Returns
-------
tensor: Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
dtype = float32 if dtype is None else dtype
return _api_internal._Placeholder(
shape, dtype, name)
def compute(shape, fcompute, name="compute", tag="", attrs=None):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of indices-> value
Specifies the input source expression
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor
The created tensor
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if isinstance(body, _tensor.TensorIntrinCall):
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
dim_var.append(_IterVar((0, s), var_name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
body.reduce_axis,
out_ndim,
body.intrin,
body.tensors,
body.regions,
body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
"""Construct new tensors by scanning over axis.
Parameters
----------
init: Tensor or list of Tensor
The initial condition of first init.shape[0] timestamps
update: Tensor or list of Tensor
The update rule of the scan given by symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
inputs: Tensor or list of Tensor, optional
The list of inputs to the scan. This is not required, but can
be useful for the compiler to detect scan body faster.
name: str, optional
The name hint of the tensor
tag: str, optional
        Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.
Example
-------
.. code-block:: python
# The following code is equivalent to numpy.cumsum
m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, X)
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
if inputs is None:
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, tag, attrs,
axis, init, update,
state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
def extern(shape,
inputs,
fcompute,
name="extern",
dtype=None,
in_buffers=None,
out_buffers=None,
tag="",
attrs=None):
"""Compute several tensor via extern function.
Parameters
----------
shape: tuple or list of tuples.
The shape of the outputs.
inputs: list of Tensor
The inputs
fcompute: lambda function of inputs, outputs-> stmt
Specifies the IR statement to do the computation.
See the following note for function signature of fcompute
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each inputs
- **outs** (list of :any:`Buffer`) - Placeholder for each outputs
**Returns**
- **stmt** (:any:`Stmt`) - The statement that carries out array computation.
name: str, optional
The name hint of the tensor
dtype: str or list of str, optional
The data types of outputs,
by default dtype will be same as inputs.
in_buffers: Buffer or list of Buffer, optional
Input buffers.
out_buffers: Buffer or list of Buffers, optional
Output buffers.
tag: str, optional
        Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.
Example
-------
In the code below, C is generated by calling external PackedFunc
`tvm.contrib.cblas.matmul`
.. code-block:: python
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = tvm.extern((n, m), [A, B],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape
if shape == () or isinstance(shape[0], (_expr.PrimExpr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
% (len(inputs), len(in_buffers)))
if out_buffers is not None:
out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
if len(shape) != len(out_buffers):
raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
% (len(shape), len(out_buffers)))
input_placeholders = in_buffers or []
output_placeholders = out_buffers or []
types = set()
for t in inputs:
if not isinstance(t, _tensor.Tensor):
raise ValueError("expect inputs to be tensor")
if in_buffers is None:
input_placeholders.append(
decl_buffer(t.shape, t.dtype, t.op.name))
types.add(t.dtype)
if dtype is None:
if len(types) != 1:
raise ValueError("Cannot infer output type, please provide dtype argument")
infered_type = types.pop()
dtype = [infered_type for _ in shape]
if isinstance(dtype, str):
dtype = [dtype]
if out_buffers is None:
for shp, dt in zip(shape, dtype):
output_placeholders.append(decl_buffer(shp, dt, name))
body = fcompute(input_placeholders, output_placeholders)
if isinstance(body, _expr.PrimExpr):
body = _stmt.Evaluate(body)
op = _api_internal._ExternOp(name, tag, attrs,
inputs, input_placeholders,
output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
return res[0] if len(res) == 1 else res
def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar
Parameters
----------
dom : Range
The domain of iteration.
name : str
The name of iteration variable.
iter_type : int
The type of iteration.
thread_tag : str
The thread tag of the iteration variable.
Returns
-------
iter_var : IterVar
The result itervar
"""
if dom is not None:
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
return _api_internal._IterVar(dom, v, iter_type, thread_tag)
def thread_axis(dom=None, tag='', name=''):
"""Create a new IterVar to represent thread index.
Parameters
----------
dom : Range or str
The domain of iteration
When str is passed, dom is set to None and str is used as tag
tag : str, optional
The thread tag
name : str, optional
The name of the var.
Returns
-------
axis : IterVar
The thread itervar.
"""
if isinstance(dom, string_types):
tag, dom = dom, None
if not tag:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag
return _IterVar(dom, name, 1, tag)
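Since the str form of dom is just a shorthand for the tag, the typical use is to create the IterVar and bind it to a split schedule axis. A small sketch of the standard GPU scheduling flow (illustrative only):
.. code-block:: python

    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))   # dom=None, the str becomes the tag
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))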
def reduce_axis(dom, name="rv"):
"""Create a new IterVar for reduction.
Parameters
----------
dom : Range
The domain of iteration.
name : str
The name of the variable.
Returns
-------
axis : IterVar
An iteration variable representing the value.
"""
return _IterVar(dom, name, 2)
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
Parameters
----------
fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Exprs as input and returns an Expr.
fidentity : function(str -> Expr)
A function which takes a type string as input to return a const Expr.
Returns
-------
reducer : function
A function which creates a reduce expression over axis.
There are two ways to use it:
        1. accept (expr, axis, where) to produce a Reduce Expr on
specified axis;
2. simply use it with multiple Exprs.
Example
-------
.. code-block:: python
n = tvm.var('n')
m = tvm.var('m')
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
"""
def _reduce_directly(*args):
num = len(args)
# process `where` is None
if num == 3 and args[2] is None:
num = 2
res = args[0]
for i in range(num-1):
res = fcombine(res, args[i+1])
return res
def _make_reduce(expr, axis, where=None):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
if isinstance(expr, _container.Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + '_' + str(i)
larr.append(var(lname, dtype))
rname = code.co_varnames[1] + '_' + str(i)
rarr.append(var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, _expr.PrimExpr)
size = 1
dtype = expr.dtype
lvar = var(code.co_varnames[0], dtype)
rvar = var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = convert(id_elem)
combiner = _expr.CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
# pylint: disable=keyword-arg-before-vararg
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (_schedule.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where)
if where is None:
assert not args
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : Expr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : Expr
The result value.
Example
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
        # there are two ways to use this {0} reducer:
        # mode 1, accept (expr, axis, where) to produce a Reduce Expr
B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = tvm.{0}(m, n)
"""
reducer.__doc__ = doc_str.format(name)
return reducer
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMax(x, y), min_value, name='max')
tvm._ffi._init_api("tvm.api")
......@@ -18,17 +18,16 @@
import tvm._ffi
from tvm.runtime import Object
from . import _api_internal
class IntSet(Object):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _api_internal._IntSetIsNothing(self)
return _IntSetIsNothing(self)
def is_everything(self):
"""Whether the set represent everything"""
return _api_internal._IntSetIsEverything(self)
return _IntSetIsEverything(self)
@tvm._ffi.register_object("arith.IntervalSet")
......
......@@ -29,7 +29,8 @@ There are two types of feature
import struct
import numpy as np
from tvm import schedule, ir_pass, build_module, get_global_func, target as _target
from tvm import schedule, ir_pass, get_global_func, target as _target
from tvm.driver import build_module
def ana_lower(sch, args,
binds=None,
......
......@@ -26,8 +26,9 @@ tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
import tvm.te._ffi_api
from ... import _api_internal, tensor, placeholder
from ... import tensor, placeholder
from .task import args_to_workload, dispatcher, register
from ..util import get_const_tuple
......@@ -420,10 +421,10 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None, o
attrs[k] = v
attrs['workload'] = args_to_workload(args, topi_compute)
if isinstance(op, tensor.ComputeOp):
op = _api_internal._ComputeOp(
op = tvm.te._ffi_api.ComputeOp(
op.name, op.tag, attrs, op.axis, op.body)
elif isinstance(op, tensor.ExternOp):
op = _api_internal._ExternOp(
op = tvm.te._ffi_api.ExternOp(
op.name, op.tag, attrs,
op.inputs, op.input_placeholders,
op.output_placeholders, op.body)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for driver APIs"""
from .build_module import lower, build
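A minimal sketch of the new namespace in use (assuming an LLVM-enabled build); the same two functions remain reachable through the backward-compatible tvm.lower/tvm.build aliases:
.. code-block:: python

    import tvm
    import tvm.driver

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] * 2.0, name="B")
    s = tvm.create_schedule(B.op)

    f = tvm.driver.lower(s, [A, B], name="double")
    m = tvm.driver.build(f, target="llvm")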
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The build utils in python.
This module provides the functions to transform a schedule into
a LoweredFunc and a compiled Module.
"""
import warnings
import tvm.tir
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.tir.stmt import LoweredFunc
from tvm.te import tensor
from tvm.te import schedule
from tvm import target as _target
def get_binds(args, compact=False, binds=None):
"""Internal function to get binds and arg_list given arguments.
Parameters
----------
args : list of Buffer or Tensor or Var
The argument lists to the function.
compact : bool
If the statement has already bound to a compact buffer.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
Returns
-------
binds: dict
The bind specification
arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
cfg = BuildConfig.current()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
buffer_type = "auto_broadcast" if any_dim and not compact else ""
if x not in binds:
buf = tvm.tir.decl_buffer(
x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
buffer_type=buffer_type)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, tvm.tir.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
return binds, arg_list
def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
The given scheduler to form the raw body
Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt
def lower(sch,
args,
name="default_function",
binds=None,
simple_mode=False):
"""Lowering step before build into target.
Parameters
----------
sch : tvm.schedule.Schedule
The schedule to be built
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str, optional
The name of result function.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
simple_mode : bool, optional
        Whether to only output a simple and compact statement; this will skip
        LoopPartition, API wrapper generation and Unrolling.
Returns
-------
    f : LoweredFunc or Stmt
        The result function. If simple_mode is True,
        the Stmt before MakeAPI is returned instead.
"""
cfg = BuildConfig.current()
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
if cfg.dump_pass_ir:
add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# Phase 0
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)
for f in lower_phase0:
stmt = f(stmt)
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
for f in lower_phase2:
stmt = f(stmt)
# Phase 3
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
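As a quick illustration of the simple_mode switch documented above (a sketch, not part of this diff): with simple_mode=True the function stops before MakeAPI and returns the Stmt, which is convenient for inspecting the lowered IR:
.. code-block:: python

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    stmt = tvm.lower(s, [A, B], simple_mode=True)  # a Stmt, not a LoweredFunc
    print(stmt)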
def _build_for_device(flist, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.
Parameters
----------
flist : list of LoweredFunc
The schedule to be built.
target : str or :any:`tvm.target.Target`
The target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
The host compilation target.
Returns
-------
fhost : list of LoweredFunc
A list of lowered functions for the host.
mdev : tvm.module
A module that contains device code.
"""
target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == LoweredFunc.MixedFunc:
if BuildConfig.current().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = list(ir_pass.SplitHostDevice(func))
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice
target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
return fhost, mdev
def build(inputs,
args=None,
target=None,
target_host=None,
name="default_function",
binds=None):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.
Parameters
----------
inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
The schedule to be built
args : list of Buffer or Tensor or Var, optional
The argument lists to the function.
target : str or :any:`tvm.target.Target`, optional
The target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if target is device.
        When TVM compiles device-specific programs such as CUDA,
        we also need host (CPU) side code to interact with the driver
        and set up the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.
name : str, optional
The name of result function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
Returns
-------
ret : tvm.module
A module that combines both host and device code.
Examples
    --------
There are two typical example uses of this function depending on the type
of the argument `inputs`:
1. it is a list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
f = tvm.lower(s, [A, B, C], name="test_add")
m = tvm.build(f, target="llvm")
2. it is a dict of compilation target to list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s1 = tvm.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
f1 = tvm.lower(s1, [A, B, C], name="test_add1")
f2 = tvm.lower(s2, [A, B, C], name="test_add2")
m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
elif isinstance(inputs, (list, tuple, container.Array)):
flist = inputs
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError("inputs must be Schedule, LoweredFunc, list of "
"LoweredFunc, or dict of target to list of "
"LoweredFunc.")
if not isinstance(inputs, (dict, container.Map)):
target = _target.Target.current() if target is None else target
target = target if target else "llvm"
target_flist = {target: flist}
else:
target_flist = inputs
for tar, flist in target_flist.items():
if not isinstance(tar, (str, _target.Target)):
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
fname_set.add(x.name)
if not target_host:
for tar, _ in target_flist.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = tar
break
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
fhost_all = []
device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
device_modules.append(mdev)
# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
......@@ -21,7 +21,7 @@ See the example sections for for suggested message conventions.
To make the code more readable, we recommend that developers
copy the examples and raise errors with the same message convention.
"""
from ._ffi.base import register_error, TVMError
from tvm._ffi.base import register_error, TVMError
@register_error
class InternalError(TVMError):
......
......@@ -30,9 +30,9 @@ HalideIR.
# 2. Support multi-level HalideIR
import inspect
import tvm._ffi
from tvm.driver.build_module import form_body
from .._ffi.base import decorate
from ..build_module import form_body
from .module import HybridModule
from .parser import source_to_op
......
......@@ -26,19 +26,20 @@ import numbers
from enum import Enum
from tvm.ir import Array, Range
import tvm.tir
import tvm.te._ffi_api
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.tir import ir_pass as _ir_pass
from tvm.te.tensor import Tensor, Operation
from tvm.tir import all as _all
from tvm.tir import any as _any
from .util import _internal_assert
from . import calls
from . import util
from .preprocessor import determine_variable_usage
from ..api import all as _all
from ..api import any as _any
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import api as _api
......@@ -653,7 +654,7 @@ def source_to_op(src, args, symbols, closure_vars):
for i in args:
get_input_tensors(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
op = tvm.te._ffi_api.HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
......@@ -27,9 +27,9 @@ from tvm.ir.container import Array
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.te.tensor import Tensor
from .. import api as _api
from ..tensor import Tensor
#pylint: disable=invalid-name
......
......@@ -17,10 +17,10 @@
"""Common expressions data structures in the IR."""
import tvm._ffi
from .base import Node
from . import _ffi_api
class BaseExpr(Node):
"""Base class of all the expressions."""
......@@ -98,7 +98,29 @@ class Range(Node):
You do not need to create a Range explicitly.
Python lists and tuples will be converted automatically to a Range in API functions.
Parameters
----------
    begin : PrimExpr
        The begin value of the range when end is not None.
        Otherwise it is treated as the length of the range [0, begin).
end : Optional[PrimExpr]
The end value of the range.
Note
----
The constructor creates the range `[begin, end)`
if the end argument is not None. Otherwise, it creates `[0, begin)`.
"""
def __init__(self, begin, end=None):
if end is None:
self.__init_handle_by_constructor__(
_ffi_api.Range, 0, begin)
else:
self.__init_handle_by_constructor__(
_ffi_api.Range, begin, end)
@staticmethod
def make_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
......
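A short sketch of the two constructor forms described in the note above (illustrative only):
.. code-block:: python

    from tvm.ir import Range

    r1 = Range(8)                          # [0, 8)
    r2 = Range(2, 10)                      # [2, 10)
    r3 = Range.make_by_min_extent(2, 8)    # also [2, 10)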
......@@ -16,10 +16,9 @@
# under the License.
"""The interface of expr function exposed from C++."""
import tvm._ffi
import tvm.driver
from tvm.ir import container as _container
from ... import build_module as _build
@tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
......@@ -48,7 +47,7 @@ def lower(sch, inputs, func_name, source_func):
import traceback
try:
f = _build.lower(sch, inputs, name=func_name)
f = tvm.driver.lower(sch, inputs, name=func_name)
# logging.debug("lower function %s", func_name)
# logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
......@@ -85,7 +84,7 @@ def build(funcs, target, target_host=None):
"""
if target_host == "":
target_host = None
return _build.build(funcs, target=target, target_host=target_host)
return tvm.driver.build(funcs, target=target, target_host=target_host)
@tvm._ffi.register_func("relay._tensor_value_repr")
......
......@@ -18,11 +18,11 @@
"""The base node types for the Relay language."""
import topi
import tvm._ffi
from tvm.driver import lower, build
from ..base import register_relay_node
from ..expr import RelayExpr
from ...api import register_func
from ...build_module import lower, build
from . import _make
@register_relay_node
......
......@@ -20,6 +20,7 @@ import logging
import multiprocessing as mp
import numpy as np
import tvm
import tvm.driver
from tvm.ir import IRModule
from . import _quantize
......
......@@ -61,3 +61,4 @@ from .generic_func import generic_func, get_native_generic_func, override_native
from . import datatype
from . import codegen
from .intrin import register_intrin_rule
from .build_config import BuildConfig, build_config
......@@ -14,31 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The build utils in python.
This module provides the functions to transform schedule to
LoweredFunc and compiled Module.
"""
import warnings
"""Target dependent BuildConfig for low-level passes."""
# TODO(tvm-team) consolidate with PassContext
import tvm._ffi
import tvm.runtime
import tvm.ir
from tvm.runtime import Object, ndarray
from tvm.runtime import Object
from tvm.ir import container
from tvm.target import codegen
from tvm.tir import expr
from tvm.tir import ir_pass
from tvm.tir import Stmt
from tvm.tir.stmt import LoweredFunc
from . import target as _target
from . import api
from . import _api_internal
from . import tensor
from . import schedule
from . import make
from . import _ffi_api
class DumpIR(object):
......@@ -166,11 +151,11 @@ class BuildConfig(Object):
@property
def add_lower_pass(self):
size = _api_internal._BuildConfigGetAddLowerPassInfo(self)
size = _ffi_api.BuildConfigGetAddLowerPassInfo(self)
result = []
for i in range(size):
phase = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, True)
func = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, False)
phase = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, True)
func = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, False)
result += [(phase, func)]
return result
......@@ -179,11 +164,11 @@ class BuildConfig(Object):
add_lower_pass_args = []
for x in value:
add_lower_pass_args += [x[0], x[1]]
_api_internal._BuildConfigSetAddLowerPass(self, *add_lower_pass_args)
_ffi_api.BuildConfigSetAddLowerPass(self, *add_lower_pass_args)
def __enter__(self):
# pylint: disable=protected-access
_api_internal._EnterBuildConfigScope(self)
_ffi_api.EnterBuildConfigScope(self)
if self.dump_pass_ir:
BuildConfig._dump_ir.enter()
return self
......@@ -191,7 +176,7 @@ class BuildConfig(Object):
def __exit__(self, ptype, value, trace):
if self.dump_pass_ir:
BuildConfig._dump_ir.exit()
_api_internal._ExitBuildConfigScope(self)
_ffi_api.ExitBuildConfigScope(self)
def __setattr__(self, name, value):
if name in BuildConfig._object_defaults:
......@@ -199,10 +184,10 @@ class BuildConfig(Object):
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value)
def current_build_config():
"""Get the current build configuration."""
return _api_internal._GetCurrentBuildConfig()
@staticmethod
def current():
"""Get the current build configuration."""
return _ffi_api.GetCurrentBuildConfig()
def build_config(**kwargs):
......@@ -261,393 +246,9 @@ def build_config(**kwargs):
"""
node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in BuildConfig._object_defaults.items()}
config = make.node("BuildConfig", **node_args)
config = tvm.ir.make_node("BuildConfig", **node_args)
if "add_lower_pass" in kwargs:
config.add_lower_pass = kwargs["add_lower_pass"]
return config
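A hedged usage sketch: build_config creates a BuildConfig node that is applied as a scope, so lower/build pick it up through BuildConfig.current() (the function is now also exposed from tvm.target):
.. code-block:: python

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    # override a few knobs for the duration of the scope
    with tvm.build_config(auto_unroll_max_step=16, unroll_explicit=False):
        f = tvm.lower(s, [A, B], name="fadd")
        m = tvm.build(f, target="llvm")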
def get_binds(args, compact=False, binds=None):
"""Internal function to get binds and arg_list given arguments.
Parameters
----------
args : list of Buffer or Tensor or Var
The argument lists to the function.
compact : bool
If the statement has already bound to a compact buffer.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
Returns
-------
binds: dict
The bind specification
arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
cfg = current_build_config()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
any_dim = any(isinstance(i, expr.Var) for i in x.shape)
buffer_type = "auto_broadcast" if any_dim and not compact else ""
if x not in binds:
buf = api.decl_buffer(x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
buffer_type=buffer_type)
binds[x] = buf
arg_list.append(buf)
else:
arg_list.append(binds[x])
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, expr.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
return binds, arg_list
def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
The given scheduler to form the raw body
Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt
def lower(sch,
args,
name="default_function",
binds=None,
simple_mode=False):
"""Lowering step before build into target.
Parameters
----------
sch : tvm.schedule.Schedule
The schedule to be built
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str, optional
The name of result function.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
simple_mode : bool, optional
        Whether to only output a simple and compact statement; this will skip
        LoopPartition, API wrapper generation and Unrolling.
Returns
-------
    f : LoweredFunc or Stmt
        The result function. If simple_mode is True,
        the Stmt before MakeAPI is returned instead.
"""
cfg = current_build_config()
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
if cfg.dump_pass_ir:
add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# Phase 0
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)
for f in lower_phase0:
stmt = f(stmt)
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(
stmt,
cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
for f in lower_phase2:
stmt = f(stmt)
# Phase 3
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
def _build_for_device(flist, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.
Parameters
----------
flist : list of LoweredFunc
The schedule to be built.
target : str or :any:`tvm.target.Target`
The target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
The host compilation target.
Returns
-------
fhost : list of LoweredFunc
A list of lowered functions for the host.
mdev : tvm.module
A module that contains device code.
"""
target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == LoweredFunc.MixedFunc:
if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = list(ir_pass.SplitHostDevice(func))
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice
target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
return fhost, mdev
def build(inputs,
args=None,
target=None,
target_host=None,
name="default_function",
binds=None):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.
Parameters
----------
inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
The schedule to be built
args : list of Buffer or Tensor or Var, optional
The argument lists to the function.
target : str or :any:`tvm.target.Target`, optional
The target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if target is device.
        When TVM compiles device-specific programs such as CUDA,
        we also need host (CPU) side code to interact with the driver
        and set up the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.
name : str, optional
The name of result function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
Returns
-------
ret : tvm.module
A module that combines both host and device code.
Examples
    --------
There are two typical example uses of this function depending on the type
of the argument `inputs`:
1. it is a list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
f = tvm.lower(s, [A, B, C], name="test_add")
m = tvm.build(f, target="llvm")
2. it is a dict of compilation target to list of lowered functions:
.. code-block:: python
n = 2
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s1 = tvm.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
f1 = tvm.lower(s1, [A, B, C], name="test_add1")
f2 = tvm.lower(s2, [A, B, C], name="test_add2")
m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
elif isinstance(inputs, (list, tuple, container.Array)):
flist = inputs
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError("inputs must be Schedule, LoweredFunc, list of "
"LoweredFunc, or dict of target to list of "
"LoweredFunc.")
if not isinstance(inputs, (dict, container.Map)):
target = _target.Target.current() if target is None else target
target = target if target else "llvm"
target_flist = {target: flist}
else:
target_flist = inputs
for tar, flist in target_flist.items():
if not isinstance(tar, (str, _target.Target)):
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
fname_set.add(x.name)
if not target_host:
for tar, _ in target_flist.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = tar
break
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
fhost_all = []
device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
device_modules.append(mdev)
# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import, redefined-builtin, wildcard-import
"""Namespace for Tensor-level IR"""
# expose all operators in tvm tir.op
from tvm.tir.op import *
from .schedule import Schedule, create_schedule
from .tensor import TensorSlice, Tensor
from .tensor_intrin import decl_tensor_intrin
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var
from .operation import thread_axis, reduce_axis
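A minimal sketch of what the new namespace exposes once this file is in place (illustrative; the top-level aliases continue to work as well):
.. code-block:: python

    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    r = te.reduce_axis((0, n), name="r")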
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.te"""
import tvm._ffi
tvm._ffi._init_api("te", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Operation class for computation declaration."""
# pylint: disable=invalid-name
from numbers import Integral as _Integral
import tvm._ffi
import tvm.tir
import tvm.tir._ffi_api
from tvm._ffi.base import string_types
from tvm.runtime import convert
from . import tag as _tag
from . import tensor as _tensor
from . import _ffi_api
def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
Returns
-------
tensor: Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
dtype = "float32" if dtype is None else dtype
return _ffi_api.Placeholder(
shape, dtype, name)
def compute(shape, fcompute, name="compute", tag="", attrs=None):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of indices-> value
Specifies the input source expression
name: str, optional
The name hint of the tensor
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor
The created tensor
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if isinstance(body, _tensor.TensorIntrinCall):
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
dim_var.append(tvm.tir.IterVar((0, s), var_name, 4))
op_node = _ffi_api.TensorComputeOp(name,
tag,
dim_var,
body.reduce_axis,
out_ndim,
body.intrin,
body.tensors,
body.regions,
body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _ffi_api.ComputeOp(
name, tag, attrs, dim_var, body)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
"""Construct new tensors by scanning over axis.
Parameters
----------
init: Tensor or list of Tensor
The initial condition of first init.shape[0] timestamps
update: Tensor or list of Tensor
The update rule of the scan given by symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
inputs: Tensor or list of Tensor, optional
The list of inputs to the scan. This is not required, but can
be useful for the compiler to detect scan body faster.
name: str, optional
The name hint of the tensor
tag: str, optional
        Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.
Example
-------
.. code-block:: python
# The following code is equivalent to numpy.cumsum
m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, X)
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
if inputs is None:
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = tvm.tir.IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _ffi_api.ScanOp(name, tag, attrs,
axis, init, update,
state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
def extern(shape,
inputs,
fcompute,
name="extern",
dtype=None,
in_buffers=None,
out_buffers=None,
tag="",
attrs=None):
"""Compute several tensor via extern function.
Parameters
----------
shape: tuple or list of tuples.
The shape of the outputs.
inputs: list of Tensor
The inputs
fcompute: lambda function of inputs, outputs -> stmt
Specifies the IR statement to do the computation.
See the following note for function signature of fcompute
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each input
- **outs** (list of :any:`Buffer`) - Placeholder for each output
**Returns**
- **stmt** (:any:`Stmt`) - The statement that carries out array computation.
name: str, optional
The name hint of the tensor
dtype: str or list of str, optional
The data types of outputs,
by default dtype will be the same as the inputs.
in_buffers: Buffer or list of Buffer, optional
Input buffers.
out_buffers: Buffer or list of Buffers, optional
Output buffers.
tag: str, optional
Additional tag information about the compute.
attrs: dict, optional
The additional auxiliary attributes about the compute.
Returns
-------
tensor: Tensor or list of Tensors
The created tensor or tuple of tensors if it contains multiple outputs.
Example
-------
In the code below, C is generated by calling external PackedFunc
`tvm.contrib.cblas.matmul`
.. code-block:: python
A = tvm.placeholder((n, l), name="A")
B = tvm.placeholder((l, m), name="B")
C = tvm.extern((n, m), [A, B],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (tvm.tir.PrimExpr, _Integral)) else shape
if shape == () or isinstance(shape[0], (tvm.tir.PrimExpr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
% (len(inputs), len(in_buffers)))
if out_buffers is not None:
out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
if len(shape) != len(out_buffers):
raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
% (len(shape), len(out_buffers)))
input_placeholders = in_buffers or []
output_placeholders = out_buffers or []
types = set()
for t in inputs:
if not isinstance(t, _tensor.Tensor):
raise ValueError("expect inputs to be tensor")
if in_buffers is None:
input_placeholders.append(
tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name))
types.add(t.dtype)
if dtype is None:
if len(types) != 1:
raise ValueError("Cannot infer output type, please provide dtype argument")
inferred_type = types.pop()
dtype = [inferred_type for _ in shape]
if isinstance(dtype, str):
dtype = [dtype]
if out_buffers is None:
for shp, dt in zip(shape, dtype):
output_placeholders.append(tvm.tir.decl_buffer(shp, dt, name))
body = fcompute(input_placeholders, output_placeholders)
if isinstance(body, tvm.tir.PrimExpr):
body = tvm.tir.Evaluate(body)
op = _ffi_api.ExternOp(name, tag, attrs,
inputs, input_placeholders,
output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
return res[0] if len(res) == 1 else res
def var(name="tindex", dtype="int32"):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : Var
The result symbolic variable.
"""
return tvm.tir.Var(name, dtype)
def size_var(name="size", dtype="int32"):
"""Create a new variable represents a tensor shape size, which is non-negative.
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : SizeVar
The result symbolic shape variable.
"""
return tvm.tir.SizeVar(name, dtype)
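For illustration, a short sketch of when to prefer one over the other (assuming the top-level aliases kept for backward compatibility): ``size_var`` carries the extra non-negativity guarantee, which helps the arithmetic simplifier when the variable appears in shapes.

.. code-block:: python

    import tvm

    n = tvm.var("n")          # generic symbolic integer
    m = tvm.size_var("m")     # known to be non-negative, suited for shapes
    A = tvm.placeholder((m, n), name="A")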
def thread_axis(dom=None, tag="", name=""):
"""Create a new IterVar to represent thread index.
Parameters
----------
dom : Range or str
The domain of iteration
When str is passed, dom is set to None and str is used as tag
tag : str, optional
The thread tag
name : str, optional
The name of the var.
Returns
-------
axis : IterVar
The thread itervar.
"""
if isinstance(dom, string_types):
tag, dom = dom, None
if not tag:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag
return tvm.tir.IterVar(dom, name, 1, tag)
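A hedged sketch of the usual pattern: the IterVar returned here is handed to ``Stage.bind`` after a split (illustrative names; the top-level schedule aliases are assumed to remain available):

.. code-block:: python

    import tvm

    A = tvm.placeholder((1024,), name="A")
    C = tvm.compute((1024,), lambda i: A[i] + 1.0, name="C")
    s = tvm.create_schedule(C.op)
    bx, tx = s[C].split(C.op.axis[0], factor=64)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))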
def reduce_axis(dom, name="rv"):
"""Create a new IterVar for reduction.
Parameters
----------
dom : Range
The domain of iteration.
name : str
The name of the variable.
Returns
-------
axis : IterVar
An iteration variable representing the value.
"""
return tvm.tir.IterVar(dom, name, 2)
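For illustration, a row-sum written with ``reduce_axis`` (a sketch; ``tvm.tir.sum`` is the reducer re-exported from ``tvm.tir`` in this patch):

.. code-block:: python

    import tvm

    n = tvm.var("n")
    m = tvm.var("m")
    A = tvm.placeholder((n, m), name="A")
    k = tvm.reduce_axis((0, m), name="k")
    B = tvm.compute((n,), lambda i: tvm.tir.sum(A[i, k], axis=k), name="B")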
......@@ -21,10 +21,10 @@ from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import expr as _expr, Buffer
from tvm.tir import IterVar, Buffer
from . import _api_internal
from . import tensor as _tensor
from . import _ffi_api
@tvm._ffi.register_object
......@@ -42,31 +42,6 @@ class Singleton(Object):
"""Singleton axis."""
@tvm._ffi.register_object
class IterVar(Object, _expr.ExprOp):
"""Represent iteration variable.
IterVar is normally created by Operation, to represent
axis iterations in the computation.
It can also be created by schedule primitives like :any:`tvm.schedule.Stage.split`.
See Also
--------
tvm.thread_axis: Create thread axis IterVar.
tvm.reduce_axis: Create reduce axis IterVar.
"""
DataPar = 0
ThreadIndex = 1
CommReduce = 2
Ordered = 3
DimInfo = 4
Unrolled = 5
Vectorized = 6
Parallelized = 7
Tensorized = 8
_tensor.iter_var_cls = IterVar
def create_schedule(ops):
"""Create a schedule for list of ops
......@@ -82,7 +57,7 @@ def create_schedule(ops):
"""
if not isinstance(ops, (list, _container.Array)):
ops = [ops]
return _api_internal._CreateSchedule(ops)
return _ffi_api.CreateSchedule(ops)
@tvm._ffi.register_object
......@@ -108,7 +83,7 @@ class Schedule(Object):
sch : Schedule
The normalized schedule.
"""
return _api_internal._ScheduleNormalize(self)
return _ffi_api.ScheduleNormalize(self)
def create_group(self, outputs, inputs, include_inputs=False):
"""Create stage group by giving output and input boundary.
......@@ -137,7 +112,7 @@ class Schedule(Object):
outputs = [outputs]
if isinstance(inputs, _tensor.Tensor):
inputs = [inputs]
return _api_internal._ScheduleCreateGroup(
return _ffi_api.ScheduleCreateGroup(
self, outputs, inputs, include_inputs)
def cache_read(self, tensor, scope, readers):
......@@ -164,7 +139,7 @@ class Schedule(Object):
if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
readers = [readers]
readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)
return _ffi_api.ScheduleCacheRead(self, tensor, scope, readers)
def cache_write(self, tensor, scope):
"""Create a cache write of original tensor, before storing into tensor.
......@@ -192,7 +167,7 @@ class Schedule(Object):
cache : Tensor
The created cache tensor.
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)
return _ffi_api.ScheduleCacheWrite(self, tensor, scope)
def rfactor(self, tensor, axis, factor_axis=0):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.
......@@ -215,7 +190,7 @@ class Schedule(Object):
tfactor : Tensor or Array of Tensor
The created factored tensor.
"""
factored = _api_internal._ScheduleRFactor(self, tensor, axis, factor_axis)
factored = _ffi_api.ScheduleRFactor(self, tensor, axis, factor_axis)
return factored[0] if len(factored) == 1 else factored
......@@ -247,11 +222,11 @@ class Stage(Object):
if nparts is not None:
if factor is not None:
raise ValueError("Do not need to provide both outer and nparts")
outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts)
else:
if factor is None:
raise ValueError("Either nparts or factor need to be provided")
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor)
return outer, inner
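A short sketch of both split modes (assuming the usual schedule workflow; names are illustrative):

.. code-block:: python

    import tvm

    A = tvm.placeholder((1024,), name="A")
    B = tvm.compute((1024,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=32)    # fix the inner extent
    # alternatively: xo, xi = s[B].split(B.op.axis[0], nparts=4)  # fix the outer extent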
def fuse(self, *args):
......@@ -270,7 +245,7 @@ class Stage(Object):
fused : IterVar
The fused variable of iteration.
"""
fused = _api_internal._StageFuse(self, args)
fused = _ffi_api.StageFuse(self, args)
return fused
def set_scope(self, scope):
......@@ -281,7 +256,7 @@ class Stage(Object):
scope : str
The thread scope of this stage
"""
return _api_internal._StageSetScope(self, scope)
return _ffi_api.StageSetScope(self, scope)
def bind(self, ivar, thread_ivar):
"""Bind ivar to thread index thread_ivar
......@@ -294,7 +269,7 @@ class Stage(Object):
thread_ivar : IterVar
The thread to be bound.
"""
_api_internal._StageBind(self, ivar, thread_ivar)
_ffi_api.StageBind(self, ivar, thread_ivar)
def env_threads(self, threads):
"""Mark threads to be launched at the outer scope of composed op.
......@@ -306,7 +281,7 @@ class Stage(Object):
"""
if isinstance(threads, IterVar):
threads = [threads]
_api_internal._StageEnvThreads(self, threads)
_ffi_api.StageEnvThreads(self, threads)
def set_store_predicate(self, predicate):
"""Set predicate under which store to the array can be performed.
......@@ -319,7 +294,7 @@ class Stage(Object):
predicate : Expr
The guard condition for the store.
"""
_api_internal._StageSetStorePredicate(self, predicate)
_ffi_api.StageSetStorePredicate(self, predicate)
def compute_at(self, parent, scope):
"""Attach the stage at parent's scope
......@@ -332,7 +307,7 @@ class Stage(Object):
scope : IterVar
The loop scope to be attached to.
"""
_api_internal._StageComputeAt(self, parent, scope)
_ffi_api.StageComputeAt(self, parent, scope)
def compute_inline(self):
"""Mark stage as inline
......@@ -342,7 +317,7 @@ class Stage(Object):
parent : Stage
The parent stage
"""
_api_internal._StageComputeInline(self)
_ffi_api.StageComputeInline(self)
def compute_root(self):
"""Attach the stage at parent, and mark it as root
......@@ -352,7 +327,7 @@ class Stage(Object):
parent : Stage
The parent stage
"""
_api_internal._StageComputeRoot(self)
_ffi_api.StageComputeRoot(self)
def reorder(self, *args):
"""reorder the arguments in the specified order.
......@@ -362,7 +337,7 @@ class Stage(Object):
args : list of IterVar
The iteration variables in the desired order
"""
_api_internal._StageReorder(self, args)
_ffi_api.StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor):
""" Perform tiling on two dimensions
......@@ -392,7 +367,7 @@ class Stage(Object):
p_y_inner : IterVar
Inner axis of y dimension
"""
x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
x_outer, y_outer, x_inner, y_inner = _ffi_api.StageTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner
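``tile`` splits both axes and reorders so the two outer loops come first; a minimal sketch:

.. code-block:: python

    import tvm

    A = tvm.placeholder((128, 128), name="A")
    B = tvm.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")
    s = tvm.create_schedule(B.op)
    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=8, y_factor=8)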
......@@ -404,7 +379,7 @@ class Stage(Object):
var : IterVar
The iteration to be vectorized
"""
_api_internal._StageVectorize(self, var)
_ffi_api.StageVectorize(self, var)
def tensorize(self, var, tensor_intrin):
"""Tensorize the computation enclosed by var with tensor_intrin
......@@ -417,7 +392,7 @@ class Stage(Object):
tensor_intrin : TensorIntrin
The tensor intrinsic used for computation.
"""
_api_internal._StageTensorize(self, var, tensor_intrin)
_ffi_api.StageTensorize(self, var, tensor_intrin)
def unroll(self, var):
"""Unroll the iteration.
......@@ -427,7 +402,7 @@ class Stage(Object):
var : IterVar
The iteration to be unrolled.
"""
_api_internal._StageUnroll(self, var)
_ffi_api.StageUnroll(self, var)
def parallel(self, var):
"""Parallelize the iteration.
......@@ -437,7 +412,7 @@ class Stage(Object):
var : IterVar
The iteration to be parallelized.
"""
_api_internal._StageParallel(self, var)
_ffi_api.StageParallel(self, var)
def pragma(self, var, pragma_type, pragma_value=None):
"""Annotate the iteration with pragma
......@@ -489,7 +464,7 @@ class Stage(Object):
"""
if isinstance(pragma_value, string_types):
pragma_value = convert(pragma_value)
_api_internal._StagePragma(self, var, pragma_type, pragma_value)
_ffi_api.StagePragma(self, var, pragma_type, pragma_value)
def prefetch(self, tensor, var, offset):
"""Prefetch the specified variable
......@@ -503,7 +478,7 @@ class Stage(Object):
offset : Expr
The number of iterations to be prefetched before actual execution
"""
_api_internal._StagePrefetch(self, tensor, var, offset)
_ffi_api.StagePrefetch(self, tensor, var, offset)
def storage_align(self, axis, factor, offset):
"""Set alignment requirement for specific axis
......@@ -523,7 +498,7 @@ class Stage(Object):
offset : int
The offset in the alignment specification.
"""
_api_internal._StageStorageAlign(self, axis, factor, offset)
_ffi_api.StageStorageAlign(self, axis, factor, offset)
def double_buffer(self):
"""Compute the current stage via double buffering.
......@@ -532,13 +507,14 @@ class Stage(Object):
This will double the storage cost of the current stage.
Can be useful to hide load latency.
"""
_api_internal._StageDoubleBuffer(self)
_ffi_api.StageDoubleBuffer(self)
def opengl(self):
"""The special OpenGL schedule
Maps each output element to a pixel.
"""
_api_internal._StageOpenGL(self)
_ffi_api.StageOpenGL(self)
tvm._ffi._init_api("tvm.schedule")
tvm._ffi._init_api("schedule", __name__)
......@@ -16,7 +16,7 @@
# under the License.
"""Tag class for TVM operators."""
import warnings
from ._ffi.base import decorate
from tvm._ffi.base import decorate
class TagScope(object):
"""Tag scope object to set tag for operators, working as context
......
......@@ -14,15 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor and Operation class for computation declaration."""
"""Tensor class for computation declaration."""
# pylint: disable=invalid-name
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr
from . import _api_internal
from . import _ffi_api
class TensorSlice(ObjectGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor."""
......@@ -52,9 +51,6 @@ class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic."""
itervar_cls = None
@tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
......@@ -68,7 +64,7 @@ class Tensor(Object, _expr.ExprOp):
for x in indices:
if isinstance(x, _expr.PrimExpr):
args.append(x)
elif isinstance(x, iter_var_cls):
elif isinstance(x, _expr.IterVar):
args.append(x.var)
else:
raise ValueError("The indices must be expression")
......@@ -81,7 +77,7 @@ class Tensor(Object, _expr.ExprOp):
return TensorSlice(self, indices)
def __hash__(self):
return _api_internal._TensorHash(self)
return _ffi_api.TensorHash(self)
def __eq__(self, other):
if not isinstance(other, Tensor):
......@@ -92,7 +88,7 @@ class Tensor(Object, _expr.ExprOp):
raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
"use Tensor.equal for content expression equvalence, "
"use Tensor.same_as for exact reference comparison")
return _api_internal._TensorEqual(self, other)
return _ffi_api.TensorEqual(self, other)
@property
def ndim(self):
......@@ -143,17 +139,17 @@ class Operation(Object):
out : Tensor
The i-th output.
"""
return _api_internal._OpGetOutput(self, index)
return _ffi_api.OpGetOutput(self, index)
@property
def num_outputs(self):
"""Number of outputs from this op."""
return _api_internal._OpNumOutputs(self)
return _ffi_api.OpNumOutputs(self)
@property
def input_tensors(self):
"""List of input tensors to this op."""
return _api_internal._OpInputTensors(self)
return _ffi_api.OpInputTensors(self)
@tvm._ffi.register_object
......
......@@ -16,17 +16,15 @@
# under the License.
"""Tensor intrinsics"""
import tvm._ffi
import tvm.tir
from tvm.runtime import Object
from tvm.runtime import Object, convert
from tvm.ir import Range
from tvm.tir import expr as _expr
from tvm.tir import stmt as _stmt
from tvm.target import BuildConfig
from .tensor import PlaceholderOp
from . import _api_internal
from . import api as _api
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from . import _ffi_api
def _get_region(tslice):
......@@ -34,15 +32,16 @@ def _get_region(tslice):
for idx in tslice.indices:
if isinstance(idx, slice):
assert idx.step is None
region.append(_api.Range(idx.start, idx.stop))
region.append(Range(idx.start, idx.stop))
else:
if isinstance(idx, _schedule.IterVar):
if isinstance(idx, tvm.tir.IterVar):
begin = idx.var
else:
begin = idx
region.append(Range.make_by_min_extent(begin, 1))
return region
@tvm._ffi.register_object
class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation.
......@@ -60,10 +59,11 @@ class TensorIntrin(Object):
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
reduce_axis = convert(reduce_axis)
if scalar_inputs:
scalar_inputs = _api.convert(scalar_inputs)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
scalar_inputs = convert(scalar_inputs)
return _ffi_api.TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
def decl_tensor_intrin(op,
fcompute,
......@@ -119,15 +119,15 @@ def decl_tensor_intrin(op,
binds_list = []
for t in inputs:
if not isinstance(t.op, _tensor.PlaceholderOp):
if not isinstance(t.op, PlaceholderOp):
raise ValueError("Do not yet support composition op")
cfg = current_build_config()
cfg = BuildConfig.current()
for t in tensors:
buf = (binds[t] if t in binds else
_api.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor))
tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor))
binds_list.append(buf)
if scalar_params:
......@@ -135,10 +135,10 @@ def decl_tensor_intrin(op,
else:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
scalar_params = []
if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)):
if isinstance(body, (tvm.tir.PrimExpr, tvm.tir.Stmt)):
body = [body]
body = [_stmt.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
body = [tvm.tir.Evaluate(x) if isinstance(x, tvm.tir.PrimExpr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
return _ffi_api.TensorIntrin(
name, op, inputs, binds_list, scalar_params, *body)
......@@ -17,6 +17,8 @@
""" TVM testing utilities """
import logging
import numpy as np
import tvm._ffi
def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
""" Version of np.testing.assert_allclose with `atol` and `rtol` fields set
......@@ -161,3 +163,6 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
logging.info("Numerical grad test wrt '%s' of shape %s passes, "
"dist = %f, max_diff = %f, avg_diff = %f",
x_name, grad.shape, dist, max_diff, avg_diff)
tvm._ffi._init_api("testing", __name__)
......@@ -23,16 +23,18 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, min_value, max_value
from .op import call_llvm_intrin, all, any, min_value, max_value
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from . import ir_builder
from . import ir_pass
......@@ -309,6 +309,57 @@ class SizeVar(Var):
@tvm._ffi.register_object
class IterVar(Object, ExprOp):
"""Represent iteration variable.
IterVar represents axis iterations in the computation.
Parameters
----------
dom : Range
The domain of the iteration.
var : Union[Var, str]
The internal variable that is used for iteration.
iter_type : int
The iteration type.
thread_tag : str
The thread type tag.
See Also
--------
tvm.thread_axis: Create thread axis IterVar.
tvm.reduce_axis: Create reduce axis IterVar.
"""
DataPar = 0
ThreadIndex = 1
CommReduce = 2
Ordered = 3
DimInfo = 4
Unrolled = 5
Vectorized = 6
Parallelized = 7
Tensorized = 8
def __init__(self, dom, var, iter_type, thread_tag=""):
if dom is not None:
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
raise TypeError("need to be list of ranges")
dom = tvm.ir.Range(dom[0], dom[1])
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = var if var is not None else "iter"
var = Var(name, dtype="int32") if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
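A small construction sketch; the ``iter_type`` integers map to the class constants above (0 = ``DataPar``, 2 = ``CommReduce``):

.. code-block:: python

    import tvm

    iv = tvm.tir.IterVar((0, 128), "i", 0)   # data-parallel axis over [0, 128)
    rv = tvm.tir.IterVar((0, 16), "k", 2)    # reduction axis, as reduce_axis creates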
@tvm._ffi.register_object
class CommReducer(Object):
"""Communicative reduce operator
......
......@@ -14,13 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin, invalid-name
"""Operators used in TIR expression."""
import tvm._ffi
from tvm.runtime import convert, const
from tvm.schedule import Buffer
from tvm.ir import Array
from .expr import Call
from .buffer import Buffer
from .expr import Call, Var, CommReducer
from . import _ffi_api
......@@ -196,6 +197,53 @@ def call_llvm_intrin(dtype, name, *args):
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpOr(args[0], args[1])
for i in range(2, len(args)):
ret = _ffi_api._OpOr(ret, args[i])
return ret
def all(*args):
"""Create a new experssion of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _ffi_api._OpAnd(args[0], args[1])
for i in range(2, len(args)):
ret = _ffi_api._OpAnd(ret, args[i])
return ret
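A brief sketch combining the two helpers into bounds-check predicates (illustrative names only):

.. code-block:: python

    import tvm

    n = tvm.tir.Var("n", "int32")
    i = tvm.tir.Var("i", "int32")
    in_bounds = tvm.tir.all(i >= 0, i < n)            # conjunction of boolean PrimExprs
    on_edge = tvm.tir.any(i < 1, i > n - 2)           # disjunction of boolean PrimExprs
    guarded = tvm.tir.if_then_else(in_bounds, 1, 0)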
@tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
print(list(args))
......@@ -780,3 +828,137 @@ def floormod(a, b):
The result expression.
"""
return _ffi_api._OpFloorMod(a, b)
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
Parameters
----------
fcombine : function(Expr -> Expr -> Expr)
A binary function which takes two Exprs as input and returns an Expr.
fidentity : function(str -> Expr)
A function which takes a type string as input and returns a const Expr.
Returns
-------
reducer : function
A function which creates a reduce expression over axis.
There are two ways to use it:
1. accept (expr, axis, where) to produce a Reduce Expr on
specified axis;
2. simply use it with multiple Exprs.
Example
-------
.. code-block:: python
n = tvm.var("n")
m = tvm.var("m")
mysum = tvm.comm_reducer(lambda x, y: x+y,
lambda t: tvm.const(0, dtype=t), name="mysum")
A = tvm.placeholder((n, m), name="A")
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
"""
def _reduce_directly(*args):
num = len(args)
# handle the case where `where` is None
if num == 3 and args[2] is None:
num = 2
res = args[0]
for i in range(num-1):
res = fcombine(res, args[i+1])
return res
def _make_reduce(expr, axis, where=None):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
if isinstance(expr, Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + "_" + str(i)
larr.append(Var(lname, dtype))
rname = code.co_varnames[1] + "_" + str(i)
rarr.append(Var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, tvm.ir.PrimExpr)
size = 1
dtype = expr.dtype
lvar = Var(code.co_varnames[0], dtype)
rvar = Var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = convert(id_elem)
combiner = CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
# pylint: disable=keyword-arg-before-vararg
def reducer(expr, axis, where=None, *args):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where)
if where is None:
assert not args
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
doc_str = """Create a {0} expression over axis.
Parameters
----------
expr : PrimExpr
The source expression.
axis : IterVar
The reduction IterVar axis
where : optional, Expr
Filtering predicate of the reduction.
Returns
-------
value : PrimExpr
The result value.
Example
-------
.. code-block:: python
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k = tvm.reduce_axis((0, n), name="k")
# there are two ways to use this {0} reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
# mode 2, simply use it with multiple Exprs:
{0}_res = tvm.{0}(m, n)
"""
reducer.__doc__ = doc_str.format(name)
return reducer
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y), max_value, name="min")
max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y), min_value, name="max")
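The reducers built here support both calling modes described in ``comm_reducer``; a small sketch (the top-level ``tvm.placeholder``/``tvm.compute`` aliases are assumed to remain exported):

.. code-block:: python

    import tvm
    from tvm.runtime import const

    # mode 2: fold plain expressions directly into a combined expression
    total = tvm.tir.sum(const(1, "int32"), const(2, "int32"))

    # mode 1: reduce over an IterVar axis inside a compute rule
    m = tvm.var("m")
    A = tvm.placeholder((4, m), name="A")
    k = tvm.reduce_axis((0, m), name="k")
    B = tvm.compute((4,), lambda i: tvm.tir.max(A[i, k], axis=k), name="B")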
......@@ -64,16 +64,16 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound")
TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("_IntervalSetGetMin")
TVM_REGISTER_GLOBAL("arith._IntervalSetGetMin")
.set_body_method(&IntSet::min);
TVM_REGISTER_GLOBAL("_IntervalSetGetMax")
TVM_REGISTER_GLOBAL("arith._IntervalSetGetMax")
.set_body_method(&IntSet::max);
TVM_REGISTER_GLOBAL("_IntSetIsNothing")
TVM_REGISTER_GLOBAL("arith._IntSetIsNothing")
.set_body_method(&IntSet::is_nothing);
TVM_REGISTER_GLOBAL("_IntSetIsEverything")
TVM_REGISTER_GLOBAL("arith._IntSetIsEverything")
.set_body_method(&IntSet::is_everything);
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
......
......@@ -40,115 +40,113 @@ TVM_REGISTER_GLOBAL("tir.min_value")
TVM_REGISTER_GLOBAL("tir.max_value")
.set_body_typed(max_value);
TVM_REGISTER_GLOBAL("Range")
TVM_REGISTER_GLOBAL("ir.Range")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = Range(0, args[0]);
} else {
*ret = Range(args[0], args[1]);
}
*ret = Range(args[0], args[1]);
});
namespace tir {
TVM_REGISTER_GLOBAL("tir.IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
}
namespace te {
TVM_REGISTER_GLOBAL("_Tensor")
TVM_REGISTER_GLOBAL("te.Tensor")
.set_body_typed(TensorNode::make);
TVM_REGISTER_GLOBAL("_TensorIntrin")
TVM_REGISTER_GLOBAL("te.TensorIntrin")
.set_body_typed(TensorIntrinNode::make);
TVM_REGISTER_GLOBAL("_TensorIntrinCall")
TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
.set_body_typed(TensorIntrinCallNode::make);
TVM_REGISTER_GLOBAL("_TensorEqual")
TVM_REGISTER_GLOBAL("te.TensorEqual")
.set_body_method(&Tensor::operator==);
TVM_REGISTER_GLOBAL("_TensorHash")
TVM_REGISTER_GLOBAL("te.TensorHash")
.set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});
TVM_REGISTER_GLOBAL("_Placeholder")
TVM_REGISTER_GLOBAL("te.Placeholder")
.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name);
});
TVM_REGISTER_GLOBAL("_ComputeOp")
TVM_REGISTER_GLOBAL("te.ComputeOp")
.set_body_typed(ComputeOpNode::make);
TVM_REGISTER_GLOBAL("_ScanOp")
TVM_REGISTER_GLOBAL("te.ScanOp")
.set_body_typed(ScanOpNode::make);
TVM_REGISTER_GLOBAL("_TensorComputeOp")
TVM_REGISTER_GLOBAL("te.TensorComputeOp")
.set_body_typed(TensorComputeOpNode::make);
TVM_REGISTER_GLOBAL("_ExternOp")
TVM_REGISTER_GLOBAL("te.ExternOp")
.set_body_typed(ExternOpNode::make);
TVM_REGISTER_GLOBAL("_HybridOp")
TVM_REGISTER_GLOBAL("te.HybridOp")
.set_body_typed(HybridOpNode::make);
TVM_REGISTER_GLOBAL("_OpGetOutput")
TVM_REGISTER_GLOBAL("te.OpGetOutput")
.set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output));
});
TVM_REGISTER_GLOBAL("_OpNumOutputs")
TVM_REGISTER_GLOBAL("te.OpNumOutputs")
.set_body_method<Operation>(&OperationNode::num_outputs);
TVM_REGISTER_GLOBAL("_OpInputTensors")
TVM_REGISTER_GLOBAL("te.OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);
TVM_REGISTER_GLOBAL("_IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
thread_tag);
});
TVM_REGISTER_GLOBAL("_CreateSchedule")
TVM_REGISTER_GLOBAL("te.CreateSchedule")
.set_body_typed(create_schedule);
TVM_REGISTER_GLOBAL("_StageSetScope")
TVM_REGISTER_GLOBAL("te.StageSetScope")
.set_body_method(&Stage::set_scope);
TVM_REGISTER_GLOBAL("_StageBind")
TVM_REGISTER_GLOBAL("te.StageBind")
.set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("_StageSplitByFactor")
TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("_StageSplitByNParts")
TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("_StageFuse")
TVM_REGISTER_GLOBAL("te.StageFuse")
.set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused;
stage.fuse(axes, &fused);
return fused;
});
TVM_REGISTER_GLOBAL("_StageComputeAt")
TVM_REGISTER_GLOBAL("te.StageComputeAt")
.set_body_method(&Stage::compute_at);
TVM_REGISTER_GLOBAL("_StageComputeInline")
TVM_REGISTER_GLOBAL("te.StageComputeInline")
.set_body_method(&Stage::compute_inline);
TVM_REGISTER_GLOBAL("_StageComputeRoot")
TVM_REGISTER_GLOBAL("te.StageComputeRoot")
.set_body_method(&Stage::compute_root);
TVM_REGISTER_GLOBAL("_StageReorder")
TVM_REGISTER_GLOBAL("te.StageReorder")
.set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("_StageTile")
TVM_REGISTER_GLOBAL("te.StageTile")
.set_body_typed([](
Stage stage,
IterVar x_parent, IterVar y_parent,
......@@ -162,49 +160,49 @@ TVM_REGISTER_GLOBAL("_StageTile")
return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_GLOBAL("_StageEnvThreads")
TVM_REGISTER_GLOBAL("te.StageEnvThreads")
.set_body_method(&Stage::env_threads);
TVM_REGISTER_GLOBAL("_StageSetStorePredicate")
TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
.set_body_method(&Stage::set_store_predicate);
TVM_REGISTER_GLOBAL("_StageUnroll")
TVM_REGISTER_GLOBAL("te.StageUnroll")
.set_body_method(&Stage::unroll);
TVM_REGISTER_GLOBAL("_StageVectorize")
TVM_REGISTER_GLOBAL("te.StageVectorize")
.set_body_method(&Stage::vectorize);
TVM_REGISTER_GLOBAL("_StageTensorize")
TVM_REGISTER_GLOBAL("te.StageTensorize")
.set_body_method(&Stage::tensorize);
TVM_REGISTER_GLOBAL("_StageParallel")
TVM_REGISTER_GLOBAL("te.StageParallel")
.set_body_method(&Stage::parallel);
TVM_REGISTER_GLOBAL("_StagePragma")
TVM_REGISTER_GLOBAL("te.StagePragma")
.set_body_method(&Stage::pragma);
TVM_REGISTER_GLOBAL("_StagePrefetch")
TVM_REGISTER_GLOBAL("te.StagePrefetch")
.set_body_method(&Stage::prefetch);
TVM_REGISTER_GLOBAL("_StageStorageAlign")
TVM_REGISTER_GLOBAL("te.StageStorageAlign")
.set_body_method(&Stage::storage_align);
TVM_REGISTER_GLOBAL("_StageDoubleBuffer")
TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
.set_body_method(&Stage::double_buffer);
TVM_REGISTER_GLOBAL("_StageOpenGL")
TVM_REGISTER_GLOBAL("te.StageOpenGL")
.set_body_method(&Stage::opengl);
TVM_REGISTER_GLOBAL("_ScheduleNormalize")
TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
.set_body_method(&Schedule::normalize);
TVM_REGISTER_GLOBAL("_ScheduleCreateGroup")
TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
.set_body_method(&Schedule::create_group);
TVM_REGISTER_GLOBAL("_ScheduleCacheRead")
TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
.set_body_method(&Schedule::cache_read);
TVM_REGISTER_GLOBAL("_ScheduleCacheWrite")
TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[1].IsObjectRef<Tensor>()) {
*ret = args[0].operator Schedule()
......@@ -215,11 +213,11 @@ TVM_REGISTER_GLOBAL("_ScheduleCacheWrite")
}
});
TVM_REGISTER_GLOBAL("_ScheduleRFactor")
TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
.set_body_method(&Schedule::rfactor);
} // namespace te
TVM_REGISTER_GLOBAL("_CommReducerCombine")
TVM_REGISTER_GLOBAL("te.CommReducerCombine")
.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
} // namespace tvm
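Since the globals now carry the ``te.`` prefix, they can also be looked up directly by name from Python; a hedged sketch using ``tvm.get_global_func`` (assumed to remain exported at the top level):

.. code-block:: python

    import tvm

    f = tvm.get_global_func("te.CreateSchedule")   # the same function _ffi_api.CreateSchedule binds to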
......@@ -47,9 +47,9 @@ TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
*ret = ScheduleOps(args[0], args[1], args[2]);
});
#define REGISTER_SCHEDULE_PASS(PassName) \
#define REGISTER_SCHEDULE_PASS(PassName) \
TVM_REGISTER_GLOBAL("schedule."#PassName) \
.set_body_typed(PassName); \
.set_body_typed(PassName); \
REGISTER_SCHEDULE_PASS(InferBound);
......
......@@ -54,11 +54,11 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
TVM_REGISTER_NODE_TYPE(TestAttrs);
TVM_REGISTER_GLOBAL("_nop")
TVM_REGISTER_GLOBAL("testing.nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
TVM_REGISTER_GLOBAL("_test_wrap_callback")
TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0];
*ret = runtime::TypedPackedFunc<void()>([pf](){
......@@ -66,7 +66,7 @@ TVM_REGISTER_GLOBAL("_test_wrap_callback")
});
});
TVM_REGISTER_GLOBAL("_test_raise_error_callback")
TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
std::string msg = args[0];
*ret = runtime::TypedPackedFunc<void()>([msg](){
......@@ -74,7 +74,7 @@ TVM_REGISTER_GLOBAL("_test_raise_error_callback")
});
});
TVM_REGISTER_GLOBAL("_test_check_eq_callback")
TVM_REGISTER_GLOBAL("testing.test_check_eq_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
std::string msg = args[0];
*ret = runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y){
......@@ -82,7 +82,7 @@ TVM_REGISTER_GLOBAL("_test_check_eq_callback")
});
});
TVM_REGISTER_GLOBAL("_context_test")
TVM_REGISTER_GLOBAL("testing.context_test")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLContext ctx = args[0];
int dtype = args[1];
......@@ -103,11 +103,11 @@ void ErrorTest(int x, int y) {
}
}
TVM_REGISTER_GLOBAL("_ErrorTest")
TVM_REGISTER_GLOBAL("testing.ErrorTest")
.set_body_typed(ErrorTest);
// internal function used for debug and testing purposes
TVM_REGISTER_GLOBAL("_ndarray_use_count")
TVM_REGISTER_GLOBAL("testing.ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0];
// subtract the current one
......
......@@ -403,7 +403,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});
TVM_REGISTER_GLOBAL("_GetCurrentBuildConfig")
TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BuildConfig::Current();
});
......@@ -418,13 +418,13 @@ class BuildConfig::Internal {
}
};
TVM_REGISTER_GLOBAL("_EnterBuildConfigScope")
TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope")
.set_body_typed(BuildConfig::Internal::EnterScope);
TVM_REGISTER_GLOBAL("_ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
.set_body_typed(BuildConfig::Internal::ExitScope);
TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
......@@ -437,7 +437,7 @@ TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
cfg->add_lower_pass = add_lower_pass;
});
TVM_REGISTER_GLOBAL("_BuildConfigGetAddLowerPassInfo")
TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// Return one of the following:
// * Size of add_lower_pass if num_args == 1
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.schedule import Buffer
from tvm.tir import Buffer
import numpy as np
def test_buffer():
......@@ -25,7 +25,7 @@ def test_buffer():
Ab = tvm.decl_buffer((m, n), tvm.float32)
Bb = tvm.decl_buffer((n, l), tvm.float32)
assert isinstance(Ab, tvm.schedule.Buffer)
assert isinstance(Ab, tvm.tir.Buffer)
assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n)
......
......@@ -22,8 +22,8 @@ def test_expr_constructor():
assert x.name == "xx"
x = tvm.tir.Reduce(None, [1],
[tvm.api._IterVar((0, 1), "x", 2)],
None, 0)
[tvm.tir.IterVar((0, 1), "x", 2)],
None, 0)
assert isinstance(x, tvm.tir.Reduce)
assert x.combiner == None
assert x.value_index == 0
......
......@@ -16,9 +16,10 @@
# under the License.
"""Test runtime error handling"""
import tvm
import tvm.testing
def test_op_translation():
ferror = tvm._api_internal._test_raise_error_callback(
ferror = tvm.testing.test_raise_error_callback(
"OpNotImplemented: myop")
try:
ferror()
......@@ -28,7 +29,7 @@ def test_op_translation():
assert isinstance(e, NotImplementedError)
assert msg.find("api_test.cc") != -1
fchk_eq = tvm._api_internal._test_check_eq_callback(
fchk_eq = tvm.testing.test_check_eq_callback(
"InternalError: myop")
try:
fchk_eq(0, 1)
......@@ -38,7 +39,7 @@ def test_op_translation():
assert msg.find("api_test.cc") != -1
try:
tvm._api_internal._ErrorTest(0, 1)
tvm.testing.ErrorTest(0, 1)
assert False
except ValueError as e:
msg = str(e)
......@@ -48,13 +49,13 @@ def test_op_translation():
def test_deep_callback():
def error_callback():
raise ValueError("callback error")
wrap1 = tvm._api_internal._test_wrap_callback(error_callback)
wrap1 = tvm.testing.test_wrap_callback(error_callback)
def flevel2():
wrap1()
wrap2 = tvm._api_internal._test_wrap_callback(flevel2)
wrap2 = tvm.testing.test_wrap_callback(flevel2)
def flevel3():
wrap2()
wrap3 = tvm._api_internal._test_wrap_callback(flevel3)
wrap3 = tvm.testing.test_wrap_callback(flevel3)
try:
wrap3()
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
import numpy as np
def test_get_global():
......@@ -93,7 +94,7 @@ def test_ctx():
x = test_ctx_func(tvm.gpu(7))
assert x == tvm.cpu(0)
x = tvm.opencl(10)
x = tvm._api_internal._context_test(x, x.device_type, x.device_id)
x = tvm.testing.context_test(x, x.device_type, x.device_id)
assert x == tvm.opencl(10)
def test_trace_default_action():
......@@ -282,4 +283,3 @@ if __name__ == "__main__":
test_trace_default_action()
test_trace_can_change_traced_value_int()
test_trace_can_change_traced_value_float()
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
import os
import logging
import time
......@@ -210,7 +211,7 @@ def test_rpc_return_ndarray():
if name == "get_arr":
return lambda : nd
elif name == "ref_count":
return lambda : tvm._api_internal._ndarray_use_count(nd)
return lambda : tvm.testing.ndarray_use_count(nd)
elif name == "get_elem":
return lambda idx: nd.asnumpy()[idx]
elif name == "get_arr_elem":
......
......@@ -96,7 +96,7 @@ def lower(*args, **kwargs):
--------
tvm.lower : The original TVM's lower function
"""
cfg = tvm.build_module.current_build_config()
cfg = tvm.target.BuildConfig.current()
if not cfg.add_lower_pass:
with build_config():
return tvm.lower(*args, **kwargs)
......@@ -113,7 +113,7 @@ def build(*args, **kwargs):
--------
tvm.build : The original TVM's build function
"""
cfg = tvm.build_module.current_build_config()
cfg = tvm.target.BuildConfig.current()
if not cfg.add_lower_pass:
with build_config():
return tvm.build(*args, **kwargs)
......