Unverified Commit 6eecec92 by Tianqi Chen Committed by GitHub

[PYTHON] Enable constructors in Node (#1647)

parent 62d34ca5
......@@ -17,6 +17,7 @@ from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .node import NodeBase
from . import node as _node
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
......@@ -186,6 +187,23 @@ class FunctionBase(object):
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
def __init_handle_by_constructor__(fconstructor, args):
"""Initialize handle by constructor"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE
handle = ret_val.v_handle
return handle
def _return_module(x):
"""Return function"""
handle = x.v_handle
......@@ -202,6 +220,7 @@ def _handle_return_func(x):
# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
......
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
# pylint: disable=no-member, missing-docstring, not-callable
from __future__ import absolute_import
import ctypes
......@@ -9,6 +9,7 @@ from .types import TVMValue, TypeCode
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
NodeHandle = ctypes.c_void_p
__init_by_constructor__ = None
"""Maps node type to its constructor"""
NODE_TYPE = {}
......@@ -58,4 +59,26 @@ class NodeBase(object):
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
self.handle = handle
_set_class_node_base(NodeBase)
......@@ -196,37 +196,54 @@ cdef inline object make_ret(TVMValue value, int tcode):
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
cdef inline int FuncCall3(void* chandle,
tuple args,
int nargs,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef TVMValue[3] values
cdef int[3] tcodes
cdef TVMValue ret_val
cdef int ret_code
nargs = len(args)
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
nargs, ret_val, ret_tcode))
return 0
cdef inline object FuncCall(void* chandle, tuple args):
cdef inline int FuncCall(void* chandle,
tuple args,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef int nargs
nargs = len(args)
if nargs <= 3:
return FuncCall3(chandle, args, nargs)
FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
return 0
cdef vector[TVMValue] values
cdef vector[int] tcodes
cdef TVMValue ret_val
cdef int ret_code
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
nargs, ret_val, ret_tcode))
return 0
cdef inline int ConstructorCall(void* constructor_handle,
int type_code,
tuple args,
void** handle) except -1:
"""Call contructor of a handle function"""
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
assert ret_tcode == type_code
handle[0] = ret_val.v_handle
return 0
cdef class FunctionBase:
......@@ -264,7 +281,10 @@ cdef class FunctionBase:
CALL(TVMFuncFree(self.chandle))
def __call__(self, *args):
return FuncCall(self.chandle, args)
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
......
......@@ -65,4 +65,27 @@ cdef class NodeBase:
"'%s' object has no attribute '%s'" % (type(self), name))
return make_ret(ret_val, ret_type_code)
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
kNodeHandle, args, &chandle)
self.chandle = chandle
_set_class_node_base(NodeBase)
......@@ -262,23 +262,7 @@ def extract_ext_funcs(finit):
def _get_api(f):
flocal = f
flocal.is_global = True
def my_api_func(*args):
"""
This is a type erased API that calls into Global PackedFunc.
These APIs corresponds to functions registered from C++ backend
and can be used as developer functions.
args : list
The positional arguments to the function call.
Returns
-------
value : int, float, None, Node or Function
The result of the API function call.
"""
return flocal(*args)
return my_api_func
return flocal
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
......
......@@ -134,9 +134,9 @@ def any(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _make.Or(args[0], args[1])
ret = _expr.Or(args[0], args[1])
for i in range(2, len(args)):
ret = _make.Or(ret, args[i])
ret = _expr.Or(ret, args[i])
return ret
......@@ -158,9 +158,9 @@ def all(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _make.And(args[0], args[1])
ret = _expr.And(args[0], args[1])
for i in range(2, len(args)):
ret = _make.And(ret, args[i])
ret = _expr.And(ret, args[i])
return ret
......@@ -616,7 +616,7 @@ def select(cond, t, f):
node : Node
The tvm.expr.Select node
"""
return _make.Select(convert(cond), convert(t), convert(f))
return _expr.Select(convert(cond), convert(t), convert(f))
def comm_reducer(fcombine, fidentity, name="reduce"):
......@@ -699,7 +699,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
......@@ -751,5 +751,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make.Max(x, y), min_value, name='max')
min = comm_reducer(lambda x, y: _expr.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _expr.Max(x, y), min_value, name='max')
......@@ -225,127 +225,545 @@ class LogicalExpr(Expr):
@register_node("Variable")
class Var(Expr):
"""Symbolic variable."""
pass
"""Symbolic variable.
Parameters
----------
name : str
The name
dtype : int
The data type
"""
def __init__(self, name, dtype):
self.__init_handle_by_constructor__(
_api_internal._Var, name, dtype)
@register_node
class Reduce(Expr):
pass
"""Reduce node.
Parameters
----------
combiner : CommReducer
The combiner.
src : list of Expr
The source expression.
rdom : list of IterVar
The iteration domain
condition : Expr
The reduce condition.
value_index : int
The value index.
"""
def __init__(self, combiner, src, rdom, condition, value_index):
self.__init_handle_by_constructor__(
_make.Reduce, combiner, src, rdom,
condition, value_index)
@register_node
class FloatImm(ConstExpr):
pass
"""Float constant.
Parameters
----------
dtype : str
The data type
value : float
The constant value.
"""
def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
_make.FloatImm, dtype, value)
@register_node
class IntImm(ConstExpr):
pass
"""Int constant.
Parameters
----------
dtype : str
The data type
value : int
The constant value.
"""
def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
_make.IntImm, dtype, value)
@register_node
class UIntImm(ConstExpr):
pass
"""UInt constant.
Parameters
----------
dtype : str
The data type
value : int
The constant value.
"""
def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
_make.UIntImm, dtype, value)
@register_node
class StringImm(ConstExpr):
pass
"""String constant.
Parameters
----------
value : str
The value of the function.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.StringImm, value)
@register_node
class Cast(Expr):
pass
"""Cast expression.
Parameters
----------
dtype : str
The data type
value : Expr
The value of the function.
"""
def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
_make.Cast, dtype, value)
@register_node
class Add(BinaryOpExpr):
pass
"""Add node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Add, a, b)
@register_node
class Sub(BinaryOpExpr):
pass
"""Sub node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Sub, a, b)
@register_node
class Mul(BinaryOpExpr):
pass
"""Mul node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Mul, a, b)
@register_node
class Div(BinaryOpExpr):
pass
"""Div node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Div, a, b)
@register_node
class Mod(BinaryOpExpr):
pass
"""Mod node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Mod, a, b)
@register_node
class Min(BinaryOpExpr):
pass
"""Min node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Min, a, b)
@register_node
class Max(BinaryOpExpr):
pass
"""Max node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Max, a, b)
@register_node
class EQ(CmpExpr):
pass
"""EQ node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.EQ, a, b)
@register_node
class NE(CmpExpr):
pass
"""NE node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.NE, a, b)
@register_node
class LT(CmpExpr):
pass
"""LT node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.LT, a, b)
@register_node
class LE(CmpExpr):
pass
"""LE node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.LE, a, b)
@register_node
class GT(CmpExpr):
pass
"""GT node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.GT, a, b)
@register_node
class GE(CmpExpr):
pass
"""GE node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.GE, a, b)
@register_node
class And(LogicalExpr):
pass
"""And node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.And, a, b)
@register_node
class Or(LogicalExpr):
pass
"""Or node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def __init__(self, a, b):
self.__init_handle_by_constructor__(
_make.Or, a, b)
@register_node
class Not(LogicalExpr):
pass
"""Not node.
Parameters
----------
a : Expr
The input value
"""
def __init__(self, a):
self.__init_handle_by_constructor__(
_make.Not, a)
@register_node
class Select(Expr):
pass
"""Select node.
Parameters
----------
condition : Expr
The condition expression.
true_value : Expr
The value to take when condition is true.
false_value : Expr
The value to take when condition is false.
"""
def __init__(self, condition, true_value, false_value):
self.__init_handle_by_constructor__(
_make.Select, condition, true_value, false_value)
@register_node
class Load(Expr):
pass
"""Load node.
Parameters
----------
dtype : str
The data type.
buffer_var : Var
The buffer variable in the load expression.
index : Expr
The index in the load.
predicate : Expr
The load predicate.
"""
def __init__(self, dtype, buffer_var, index, predicate):
self.__init_handle_by_constructor__(
_make.Load, dtype, buffer_var, index, predicate)
@register_node
class Ramp(Expr):
pass
"""Ramp node.
Parameters
----------
base : Expr
The base expression.
stride : ramp stride
The stride of the ramp.
lanes : int
The lanes of the expression.
"""
def __init__(self, base, stride, lanes):
self.__init_handle_by_constructor__(
_make.Ramp, base, stride, lanes)
@register_node
class Broadcast(Expr):
pass
"""Broadcast node.
Parameters
----------
value : Expr
The value of the expression.
lanes : int
The lanes of the expression.
"""
def __init__(self, value, lanes):
self.__init_handle_by_constructor__(
_make.Broadcast, value, lanes)
@register_node
class Shuffle(Expr):
pass
"""Shuffle node.
Parameters
----------
vectors : Array of Expr
The vectors
indices : Array of indices
The indices
"""
def __init__(self, vectors, indices):
self.__init_handle_by_constructor__(
_make.Shuffle, vectors, indices)
@register_node
class Call(Expr):
"""Call node.
Parameters
----------
dtype : str
The return data type
name : str
The name of the function
args : list of Expr
The input arguments to the call
call_type : int
The type of the call
func : Operation, optional
Operation if call_type is Halide
value_index : int
The output value index
"""
Extern = 0
ExternCPlusPlus = 1
PureExtern = 2
Halide = 3
Intrinsic = 4
PureIntrinsic = 5
def __init__(self, dtype, name, args, call_type, func, value_index):
self.__init_handle_by_constructor__(
_make.Call, dtype, name, args, call_type, func, value_index)
@register_node
class Let(Expr):
pass
"""Let node.
Parameters
----------
var : Var
The variable in the binding.
value : Expr
The value in to be binded.
body : Expr
The body expression.
"""
def __init__(self, var, value, body):
self.__init_handle_by_constructor__(
_make.Let, var, value, body)
......@@ -6,9 +6,10 @@ The functions are automatically exported from C++ side via PackedFunc.
Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
"""
from __future__ import absolute_import as _abs
from ._ffi.function import _init_api
from ._ffi.runtime_ctypes import TVMType
from . import stmt as _stmt
def range_by_min_extent(min_value, extent):
"""Construct a Range by min and extent.
......@@ -98,44 +99,4 @@ def node(type_key, **kwargs):
return _Node(*args)
def stmt_seq(*args):
"""Make sequence of statements
Parameters
----------
args : list of Expr or Var
List of statements to be combined as sequence.
Returns
-------
stmt : Stmt
The combined statement.
"""
ret = None
for value in args:
if not isinstance(value, _stmt.Stmt):
value = Evaluate(value)
ret = value if ret is None else Block(ret, value)
return ret if ret else Evaluate(0)
def stmt_list(stmt):
"""Make list of stmt from blocks.
Parameters
----------
stmt : A block statement
Returns
-------
stmt_list : list of Stmt
The unpacked list of statements
"""
if isinstance(stmt, _stmt.Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest)
elif isinstance(stmt, _stmt.ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]
_init_api("tvm.make")
......@@ -15,65 +15,376 @@ Each statement node have subfields that can be visited from python side.
"""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from . import make as _make
class Stmt(NodeBase):
pass
@register_node
class LetStmt(Stmt):
pass
"""LetStmt node.
Parameters
----------
var : Var
The variable in the binding.
value : Expr
The value in to be binded.
body : Stmt
The body statement.
"""
def __init__(self, var, value, body):
self.__init_handle_by_constructor__(
_make.LetStmt, var, value, body)
@register_node
class AssertStmt(Stmt):
pass
"""AssertStmt node.
Parameters
----------
condition : Expr
The assert condition.
message : Expr
The error message.
body : Stmt
The body statement.
"""
def __init__(self, condition, message, body):
self.__init_handle_by_constructor__(
_make.AssertStmt, condition, message, body)
@register_node
class ProducerConsumer(Stmt):
pass
"""ProducerConsumer node.
Parameters
----------
func : Operation
The Operation.
is_producer : bool
Whether if the node is producer.
body : Stmt
The body statement.
"""
def __init__(self, func, is_producer, body):
self.__init_handle_by_constructor__(
_make.ProducerConsumer, func, is_producer, body)
@register_node
class For(Stmt):
"""For node.
Parameters
----------
loop_var : Var
The loop variable.
min_val : Expr
The begining value.
extent : Expr
The length of the loop.
for_type : int
The for type.
device_api : int
The device api type.
body : Stmt
The body statement.
"""
Serial = 0
Parallel = 1
Vectorized = 2
Unrolled = 3
def __init__(self,
loop_var,
min_val,
extent,
for_type,
device_api,
body):
self.__init_handle_by_constructor__(
_make.For, loop_var, min_val, extent,
for_type, device_api, body)
@register_node
class Store(Stmt):
pass
"""Store node.
Parameters
----------
buffer_var : Var
The buffer Variable.
value : Expr
The value we want to store.
index : Expr
The index in the store expression.
predicate : Expr
The store predicate.
"""
def __init__(self, buffer_var, value, index, predicate):
self.__init_handle_by_constructor__(
_make.Store, buffer_var, value, index, predicate)
@register_node
class Provide(Stmt):
pass
"""Provide node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
value : Expr
The value to be stored.
args : list of Expr
The index arguments of the Provide.
"""
def __init__(self, func, value_index, value, args):
self.__init_handle_by_constructor__(
_make.Provide, func, value_index, value, args)
@register_node
class Allocate(Stmt):
pass
"""Allocate node.
Parameters
----------
buffer_var : Var
The buffer variable.
dtype : str
The data type of the buffer.
extents : list of Expr
The extents of the allocate
condition : Expr
The condition.
body : Stmt
The body statement.
"""
def __init__(self,
buffer_var,
dtype,
extents,
condition,
body):
self.__init_handle_by_constructor__(
_make.Allocate, buffer_var, dtype,
extents, condition, body)
@register_node
class AttrStmt(Stmt):
pass
"""AttrStmt node.
Parameters
----------
node : Node
The node to annotate the attribute
attr_key : str
Attribute type key.
value : Expr
The value of the attribute
body : Stmt
The body statement.
"""
def __init__(self, node, attr_key, value, body):
self.__init_handle_by_constructor__(
_make.AttrStmt, node, attr_key, value, body)
@register_node
class Free(Stmt):
pass
"""Free node.
Parameters
----------
buffer_var : Var
The buffer variable.
"""
def __init__(self, buffer_var):
self.__init_handle_by_constructor__(
_make.Free, buffer_var)
@register_node
class Realize(Stmt):
pass
"""Realize node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
dtype : str
The data type of the operation.
bounds : list of range
The bound of realize
condition : Expr
The realize condition.
body : Stmt
The realize body
"""
def __init__(self,
func,
value_index,
dtype,
bounds,
condition,
body):
self.__init_handle_by_constructor__(
_make.Realize, func, value_index, dtype,
bounds, condition, body)
@register_node
class Block(Stmt):
pass
"""Block node.
Parameters
----------
first : Stmt
The first statement.
rest : Stmt
The following statement.
"""
def __init__(self, first, rest):
self.__init_handle_by_constructor__(
_make.Block, first, rest)
@register_node
class IfThenElse(Stmt):
pass
"""IfThenElse node.
Parameters
----------
condition : Expr
The expression
then_case : Stmt
The statement to execute if condition is true.
else_case : Stmt
The statement to execute if condition is false.
"""
def __init__(self, condition, then_case, else_case):
self.__init_handle_by_constructor__(
_make.IfThenElse, condition, then_case, else_case)
@register_node
class Evaluate(Stmt):
pass
"""Evaluate node.
Parameters
----------
value : Expr
The expression to be evalued.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.Evaluate, value)
@register_node
class Prefetch(Stmt):
pass
"""Prefetch node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
dtype : str
The data type to be prefetched.
bounds : list of Range
The bounds to be prefetched.
"""
def __init__(self, func, value_index, dtype, bounds):
self.__init_handle_by_constructor__(
_make.Prefetch, func, value_index, dtype, bounds)
def stmt_seq(*args):
"""Make sequence of statements
Parameters
----------
args : list of Expr or Var
List of statements to be combined as sequence.
Returns
-------
stmt : Stmt
The combined statement.
"""
ret = None
for value in args:
if not isinstance(value, Stmt):
value = Evaluate(value)
ret = value if ret is None else Block(ret, value)
return ret if ret else Evaluate(0)
def stmt_list(stmt):
"""Make list of stmt from blocks.
Parameters
----------
stmt : A block statement
Returns
-------
stmt_list : list of Stmt
The unpacked list of statements
"""
if isinstance(stmt, Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest)
elif isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]
_make.stmt_list = stmt_list
_make.stmt_seq = stmt_seq
......@@ -170,6 +170,7 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
......
import tvm
def test_expr_constructor():
x = tvm.expr.Var("xx", "float32")
assert isinstance(x, tvm.expr.Var)
assert x.name == "xx"
x = tvm.expr.Reduce(None, [1],
[tvm.api._IterVar((0, 1), "x", 2)],
None, 0)
assert isinstance(x, tvm.expr.Reduce)
assert x.combiner == None
assert x.value_index == 0
x = tvm.expr.FloatImm("float32", 1.0)
assert isinstance(x, tvm.expr.FloatImm)
assert x.value == 1.0
assert x.dtype == "float32"
x = tvm.expr.IntImm("int64", 2)
assert isinstance(x, tvm.expr.IntImm)
assert x.value == 2
assert x.dtype == "int64"
x = tvm.expr.UIntImm("uint16", 2)
assert isinstance(x, tvm.expr.UIntImm)
assert x.value == 2
assert x.dtype == "uint16"
x = tvm.expr.StringImm("xyza")
assert isinstance(x, tvm.expr.StringImm)
assert x.value == "xyza"
x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1))
assert isinstance(x, tvm.expr.Cast)
assert x.dtype == "float32"
assert x.value.value == 1
a = tvm.const(1.0, dtype="float32")
b = tvm.var("x", dtype="float32")
for cls in [tvm.expr.Add,
tvm.expr.Sub,
tvm.expr.Mul,
tvm.expr.Div,
tvm.expr.Mod,
tvm.expr.Min,
tvm.expr.Max,
tvm.expr.LT,
tvm.expr.LE,
tvm.expr.GT,
tvm.expr.GE]:
x = cls(a, b)
assert isinstance(x, cls)
assert x.a == a
assert x.b.same_as(b)
a = tvm.convert(tvm.var("x") > 1)
b = tvm.convert(tvm.var("x") == 1)
for cls in [tvm.expr.And,
tvm.expr.Or]:
x = cls(a, b)
assert isinstance(x, cls)
assert x.a == a
assert x.b.same_as(b)
x = tvm.expr.Not(a)
assert isinstance(x, tvm.expr.Not)
assert x.a == a
x = tvm.expr.Select(a, a, b)
assert isinstance(x, tvm.expr.Select)
assert x.true_value == a
assert x.false_value == b
assert x.condition == a
buffer_var = tvm.var("x", dtype="handle")
x = tvm.expr.Load("float32", buffer_var, 1, a)
assert isinstance(x, tvm.expr.Load)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
assert x.index.value == 1
assert x.predicate == a
x = tvm.expr.Ramp(1, 2, 10)
assert isinstance(x, tvm.expr.Ramp)
assert x.base.value == 1
assert x.stride.value == 2
assert x.lanes == 10
x = tvm.expr.Broadcast(a, 10)
assert isinstance(x, tvm.expr.Broadcast)
assert x.value == a
assert x.lanes == 10
x = tvm.expr.Shuffle([a], [0])
assert isinstance(x, tvm.expr.Shuffle)
assert x.vectors[0] == a
assert x.indices[0].value == 0
x = tvm.expr.Call("float32", "xyz", [a], tvm.expr.Call.Extern, None, 0)
assert isinstance(x, tvm.expr.Call)
assert x.dtype == "float32"
assert x.name == "xyz"
assert x.args[0] == a
assert x.call_type == tvm.expr.Call.Extern
assert x.func == None
assert x.value_index == 0
v = tvm.var("aa")
x = tvm.expr.Let(v, 1, v)
assert x.var == v
assert x.value.value == 1
assert x.body == v
def test_stmt_constructor():
v = tvm.var("aa")
buffer_var = tvm.var("buf", dtype="handle")
nop = tvm.stmt.Evaluate(1)
x = tvm.stmt.LetStmt(v, 1, tvm.stmt.Evaluate(1))
assert isinstance(x, tvm.stmt.LetStmt)
assert x.var == v
assert x.value.value == 1
assert isinstance(x.body, tvm.stmt.Evaluate)
x = tvm.stmt.AttrStmt(v == 1, "xx", 1, tvm.stmt.Evaluate(1))
assert isinstance(x, tvm.stmt.AttrStmt)
assert x.value.value == 1
x = tvm.stmt.Block(tvm.stmt.Evaluate(11),
nop)
assert isinstance(x, tvm.stmt.Block)
assert x.first.value.value == 11
assert x.rest == nop
x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"),
tvm.convert("hellow"),
nop)
assert isinstance(x, tvm.stmt.AssertStmt)
assert x.body == nop
x = tvm.stmt.ProducerConsumer(None, True, nop)
assert isinstance(x, tvm.stmt.ProducerConsumer)
assert x.body == nop
x = tvm.stmt.For(tvm.var("x"), 0, 10, 0, 0, nop)
assert isinstance(x, tvm.stmt.For)
assert x.min.value == 0
assert x.extent.value == 10
assert x.body == nop
x = tvm.stmt.Store(buffer_var, 1, 10, tvm.const(1, "uint1"))
assert isinstance(x, tvm.stmt.Store)
assert x.buffer_var == buffer_var
assert x.index.value == 10
assert x.value.value == 1
tensor = tvm.placeholder((), dtype="float32")
x = tvm.stmt.Provide(tensor.op, 0, 10, [])
assert isinstance(x, tvm.stmt.Provide)
assert x.value_index == 0
assert x.value.value == 10
x = tvm.stmt.Allocate(buffer_var, "float32", [10],
tvm.const(1, "uint1"), nop)
assert isinstance(x, tvm.stmt.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
assert x.body == nop
x = tvm.stmt.AttrStmt(buffer_var, "xyz", 1, nop)
assert isinstance(x, tvm.stmt.AttrStmt)
assert x.node == buffer_var
assert x.attr_key == "xyz"
assert x.body == nop
x = tvm.stmt.Free(buffer_var)
assert isinstance(x, tvm.stmt.Free)
assert x.buffer_var == buffer_var
x = tvm.stmt.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop)
assert isinstance(x, tvm.stmt.Realize)
assert x.body == nop
x = tvm.stmt.IfThenElse(tvm.const(1, "uint1"),
tvm.stmt.Evaluate(11),
nop)
assert isinstance(x, tvm.stmt.IfThenElse)
assert x.then_case.value.value == 11
assert x.else_case == nop
x = tvm.stmt.Prefetch(None, 1, "float32", [])
assert isinstance(x, tvm.stmt.Prefetch)
assert x.value_index == 1
if __name__ == "__main__":
test_expr_constructor()
test_stmt_constructor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment