Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6eecec92
Unverified
Commit
6eecec92
authored
Aug 23, 2018
by
Tianqi Chen
Committed by
GitHub
Aug 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PYTHON] Enable constructors in Node (#1647)
parent
62d34ca5
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1084 additions
and
122 deletions
+1084
-122
python/tvm/_ffi/_ctypes/function.py
+19
-0
python/tvm/_ffi/_ctypes/node.py
+24
-1
python/tvm/_ffi/_cython/function.pxi
+32
-12
python/tvm/_ffi/_cython/node.pxi
+23
-0
python/tvm/_ffi/function.py
+1
-17
python/tvm/api.py
+8
-8
python/tvm/expr.py
+448
-30
python/tvm/make.py
+2
-41
python/tvm/stmt.py
+324
-13
src/api/api_ir.cc
+1
-0
tests/python/unittest/test_lang_constructor.py
+202
-0
No files found.
python/tvm/_ffi/_ctypes/function.py
View file @
6eecec92
...
@@ -17,6 +17,7 @@ from .types import TVMValue, TypeCode
...
@@ -17,6 +17,7 @@ from .types import TVMValue, TypeCode
from
.types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
.types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
.node
import
NodeBase
from
.node
import
NodeBase
from
.
import
node
as
_node
FunctionHandle
=
ctypes
.
c_void_p
FunctionHandle
=
ctypes
.
c_void_p
ModuleHandle
=
ctypes
.
c_void_p
ModuleHandle
=
ctypes
.
c_void_p
...
@@ -186,6 +187,23 @@ class FunctionBase(object):
...
@@ -186,6 +187,23 @@ class FunctionBase(object):
_
=
args
_
=
args
return
RETURN_SWITCH
[
ret_tcode
.
value
](
ret_val
)
return
RETURN_SWITCH
[
ret_tcode
.
value
](
ret_val
)
def
__init_handle_by_constructor__
(
fconstructor
,
args
):
"""Initialize handle by constructor"""
temp_args
=
[]
values
,
tcodes
,
num_args
=
_make_tvm_args
(
args
,
temp_args
)
ret_val
=
TVMValue
()
ret_tcode
=
ctypes
.
c_int
()
check_call
(
_LIB
.
TVMFuncCall
(
fconstructor
.
handle
,
values
,
tcodes
,
ctypes
.
c_int
(
num_args
),
ctypes
.
byref
(
ret_val
),
ctypes
.
byref
(
ret_tcode
)))
_
=
temp_args
_
=
args
assert
ret_tcode
.
value
==
TypeCode
.
NODE_HANDLE
handle
=
ret_val
.
v_handle
return
handle
def
_return_module
(
x
):
def
_return_module
(
x
):
"""Return function"""
"""Return function"""
handle
=
x
.
v_handle
handle
=
x
.
v_handle
...
@@ -202,6 +220,7 @@ def _handle_return_func(x):
...
@@ -202,6 +220,7 @@ def _handle_return_func(x):
# setup return handle for function type
# setup return handle for function type
_node
.
__init_by_constructor__
=
__init_handle_by_constructor__
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
)
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
)
...
...
python/tvm/_ffi/_ctypes/node.py
View file @
6eecec92
# pylint: disable=invalid-name, protected-access
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
# pylint: disable=no-member, missing-docstring
, not-callable
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
ctypes
import
ctypes
...
@@ -9,6 +9,7 @@ from .types import TVMValue, TypeCode
...
@@ -9,6 +9,7 @@ from .types import TVMValue, TypeCode
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
NodeHandle
=
ctypes
.
c_void_p
NodeHandle
=
ctypes
.
c_void_p
__init_by_constructor__
=
None
"""Maps node type to its constructor"""
"""Maps node type to its constructor"""
NODE_TYPE
=
{}
NODE_TYPE
=
{}
...
@@ -58,4 +59,26 @@ class NodeBase(object):
...
@@ -58,4 +59,26 @@ class NodeBase(object):
"'
%
s' object has no attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
"'
%
s' object has no attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
return
RETURN_SWITCH
[
ret_type_code
.
value
](
ret_val
)
return
RETURN_SWITCH
[
ret_type_code
.
value
](
ret_val
)
def
__init_handle_by_constructor__
(
self
,
fconstructor
,
*
args
):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
handle
=
__init_by_constructor__
(
fconstructor
,
args
)
if
not
isinstance
(
handle
,
NodeHandle
):
handle
=
NodeHandle
(
handle
)
self
.
handle
=
handle
_set_class_node_base
(
NodeBase
)
_set_class_node_base
(
NodeBase
)
python/tvm/_ffi/_cython/function.pxi
View file @
6eecec92
...
@@ -196,37 +196,54 @@ cdef inline object make_ret(TVMValue value, int tcode):
...
@@ -196,37 +196,54 @@ cdef inline object make_ret(TVMValue value, int tcode):
raise ValueError("Unhandled type code %d" % tcode)
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
cdef inline int FuncCall3(void* chandle,
tuple args,
int nargs,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef TVMValue[3] values
cdef TVMValue[3] values
cdef int[3] tcodes
cdef int[3] tcodes
cdef TVMValue ret_val
cdef int ret_code
nargs = len(args)
nargs = len(args)
temp_args = []
temp_args = []
for i in range(nargs):
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs,
&ret_val, &ret_
code))
nargs,
ret_val, ret_t
code))
return
make_ret(ret_val, ret_code)
return
0
cdef inline object FuncCall(void* chandle, tuple args):
cdef inline int FuncCall(void* chandle,
tuple args,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef int nargs
cdef int nargs
nargs = len(args)
nargs = len(args)
if nargs <= 3:
if nargs <= 3:
return FuncCall3(chandle, args, nargs)
FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
return 0
cdef vector[TVMValue] values
cdef vector[TVMValue] values
cdef vector[int] tcodes
cdef vector[int] tcodes
cdef TVMValue ret_val
cdef int ret_code
values.resize(max(nargs, 1))
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
temp_args = []
temp_args = []
for i in range(nargs):
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
nargs, ret_val, ret_tcode))
return make_ret(ret_val, ret_code)
return 0
cdef inline int ConstructorCall(void* constructor_handle,
int type_code,
tuple args,
void** handle) except -1:
"""Call contructor of a handle function"""
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
assert ret_tcode == type_code
handle[0] = ret_val.v_handle
return 0
cdef class FunctionBase:
cdef class FunctionBase:
...
@@ -264,7 +281,10 @@ cdef class FunctionBase:
...
@@ -264,7 +281,10 @@ cdef class FunctionBase:
CALL(TVMFuncFree(self.chandle))
CALL(TVMFuncFree(self.chandle))
def __call__(self, *args):
def __call__(self, *args):
return FuncCall(self.chandle, args)
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
_CLASS_FUNCTION = None
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_MODULE = None
...
...
python/tvm/_ffi/_cython/node.pxi
View file @
6eecec92
...
@@ -65,4 +65,27 @@ cdef class NodeBase:
...
@@ -65,4 +65,27 @@ cdef class NodeBase:
"'%s' object has no attribute '%s'" % (type(self), name))
"'%s' object has no attribute '%s'" % (type(self), name))
return make_ret(ret_val, ret_type_code)
return make_ret(ret_val, ret_type_code)
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
kNodeHandle, args, &chandle)
self.chandle = chandle
_set_class_node_base(NodeBase)
_set_class_node_base(NodeBase)
python/tvm/_ffi/function.py
View file @
6eecec92
...
@@ -262,23 +262,7 @@ def extract_ext_funcs(finit):
...
@@ -262,23 +262,7 @@ def extract_ext_funcs(finit):
def
_get_api
(
f
):
def
_get_api
(
f
):
flocal
=
f
flocal
=
f
flocal
.
is_global
=
True
flocal
.
is_global
=
True
def
my_api_func
(
*
args
):
return
flocal
"""
This is a type erased API that calls into Global PackedFunc.
These APIs corresponds to functions registered from C++ backend
and can be used as developer functions.
args : list
The positional arguments to the function call.
Returns
-------
value : int, float, None, Node or Function
The result of the API function call.
"""
return
flocal
(
*
args
)
return
my_api_func
def
_init_api
(
namespace
,
target_module_name
=
None
):
def
_init_api
(
namespace
,
target_module_name
=
None
):
"""Initialize api for a given module name
"""Initialize api for a given module name
...
...
python/tvm/api.py
View file @
6eecec92
...
@@ -134,9 +134,9 @@ def any(*args):
...
@@ -134,9 +134,9 @@ def any(*args):
raise
ValueError
(
"Any must take at least 1 argument"
)
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
return
args
[
0
]
return
args
[
0
]
ret
=
_
make
.
Or
(
args
[
0
],
args
[
1
])
ret
=
_
expr
.
Or
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_
make
.
Or
(
ret
,
args
[
i
])
ret
=
_
expr
.
Or
(
ret
,
args
[
i
])
return
ret
return
ret
...
@@ -158,9 +158,9 @@ def all(*args):
...
@@ -158,9 +158,9 @@ def all(*args):
raise
ValueError
(
"Any must take at least 1 argument"
)
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
return
args
[
0
]
return
args
[
0
]
ret
=
_
make
.
And
(
args
[
0
],
args
[
1
])
ret
=
_
expr
.
And
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_
make
.
And
(
ret
,
args
[
i
])
ret
=
_
expr
.
And
(
ret
,
args
[
i
])
return
ret
return
ret
...
@@ -616,7 +616,7 @@ def select(cond, t, f):
...
@@ -616,7 +616,7 @@ def select(cond, t, f):
node : Node
node : Node
The tvm.expr.Select node
The tvm.expr.Select node
"""
"""
return
_
make
.
Select
(
convert
(
cond
),
convert
(
t
),
convert
(
f
))
return
_
expr
.
Select
(
convert
(
cond
),
convert
(
t
),
convert
(
f
))
def
comm_reducer
(
fcombine
,
fidentity
,
name
=
"reduce"
):
def
comm_reducer
(
fcombine
,
fidentity
,
name
=
"reduce"
):
...
@@ -699,7 +699,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
...
@@ -699,7 +699,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
axis
=
convert
(
axis
if
isinstance
(
axis
,
(
list
,
tuple
))
else
[
axis
])
axis
=
convert
(
axis
if
isinstance
(
axis
,
(
list
,
tuple
))
else
[
axis
])
if
where
is
None
:
if
where
is
None
:
where
=
convert
(
True
)
where
=
convert
(
True
)
outputs
=
tuple
(
_
make
.
Reduce
(
combiner
,
expr
,
axis
,
where
,
i
)
outputs
=
tuple
(
_
expr
.
Reduce
(
combiner
,
expr
,
axis
,
where
,
i
)
for
i
in
range
(
size
))
for
i
in
range
(
size
))
return
outputs
[
0
]
if
size
==
1
else
outputs
return
outputs
[
0
]
if
size
==
1
else
outputs
...
@@ -751,5 +751,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
...
@@ -751,5 +751,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
_init_api
(
"tvm.api"
)
_init_api
(
"tvm.api"
)
#pylint: disable=unnecessary-lambda
#pylint: disable=unnecessary-lambda
sum
=
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
lambda
t
:
const
(
0
,
dtype
=
t
),
name
=
"sum"
)
sum
=
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
lambda
t
:
const
(
0
,
dtype
=
t
),
name
=
"sum"
)
min
=
comm_reducer
(
lambda
x
,
y
:
_
make
.
Min
(
x
,
y
),
max_value
,
name
=
'min'
)
min
=
comm_reducer
(
lambda
x
,
y
:
_
expr
.
Min
(
x
,
y
),
max_value
,
name
=
'min'
)
max
=
comm_reducer
(
lambda
x
,
y
:
_
make
.
Max
(
x
,
y
),
min_value
,
name
=
'max'
)
max
=
comm_reducer
(
lambda
x
,
y
:
_
expr
.
Max
(
x
,
y
),
min_value
,
name
=
'max'
)
python/tvm/expr.py
View file @
6eecec92
...
@@ -225,127 +225,545 @@ class LogicalExpr(Expr):
...
@@ -225,127 +225,545 @@ class LogicalExpr(Expr):
@register_node
(
"Variable"
)
@register_node
(
"Variable"
)
class
Var
(
Expr
):
class
Var
(
Expr
):
"""Symbolic variable."""
"""Symbolic variable.
pass
Parameters
----------
name : str
The name
dtype : int
The data type
"""
def
__init__
(
self
,
name
,
dtype
):
self
.
__init_handle_by_constructor__
(
_api_internal
.
_Var
,
name
,
dtype
)
@register_node
@register_node
class
Reduce
(
Expr
):
class
Reduce
(
Expr
):
pass
"""Reduce node.
Parameters
----------
combiner : CommReducer
The combiner.
src : list of Expr
The source expression.
rdom : list of IterVar
The iteration domain
condition : Expr
The reduce condition.
value_index : int
The value index.
"""
def
__init__
(
self
,
combiner
,
src
,
rdom
,
condition
,
value_index
):
self
.
__init_handle_by_constructor__
(
_make
.
Reduce
,
combiner
,
src
,
rdom
,
condition
,
value_index
)
@register_node
@register_node
class
FloatImm
(
ConstExpr
):
class
FloatImm
(
ConstExpr
):
pass
"""Float constant.
Parameters
----------
dtype : str
The data type
value : float
The constant value.
"""
def
__init__
(
self
,
dtype
,
value
):
self
.
__init_handle_by_constructor__
(
_make
.
FloatImm
,
dtype
,
value
)
@register_node
@register_node
class
IntImm
(
ConstExpr
):
class
IntImm
(
ConstExpr
):
pass
"""Int constant.
Parameters
----------
dtype : str
The data type
value : int
The constant value.
"""
def
__init__
(
self
,
dtype
,
value
):
self
.
__init_handle_by_constructor__
(
_make
.
IntImm
,
dtype
,
value
)
@register_node
@register_node
class
UIntImm
(
ConstExpr
):
class
UIntImm
(
ConstExpr
):
pass
"""UInt constant.
Parameters
----------
dtype : str
The data type
value : int
The constant value.
"""
def
__init__
(
self
,
dtype
,
value
):
self
.
__init_handle_by_constructor__
(
_make
.
UIntImm
,
dtype
,
value
)
@register_node
@register_node
class
StringImm
(
ConstExpr
):
class
StringImm
(
ConstExpr
):
pass
"""String constant.
Parameters
----------
value : str
The value of the function.
"""
def
__init__
(
self
,
value
):
self
.
__init_handle_by_constructor__
(
_make
.
StringImm
,
value
)
@register_node
@register_node
class
Cast
(
Expr
):
class
Cast
(
Expr
):
pass
"""Cast expression.
Parameters
----------
dtype : str
The data type
value : Expr
The value of the function.
"""
def
__init__
(
self
,
dtype
,
value
):
self
.
__init_handle_by_constructor__
(
_make
.
Cast
,
dtype
,
value
)
@register_node
@register_node
class
Add
(
BinaryOpExpr
):
class
Add
(
BinaryOpExpr
):
pass
"""Add node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Add
,
a
,
b
)
@register_node
@register_node
class
Sub
(
BinaryOpExpr
):
class
Sub
(
BinaryOpExpr
):
pass
"""Sub node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Sub
,
a
,
b
)
@register_node
@register_node
class
Mul
(
BinaryOpExpr
):
class
Mul
(
BinaryOpExpr
):
pass
"""Mul node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Mul
,
a
,
b
)
@register_node
@register_node
class
Div
(
BinaryOpExpr
):
class
Div
(
BinaryOpExpr
):
pass
"""Div node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Div
,
a
,
b
)
@register_node
@register_node
class
Mod
(
BinaryOpExpr
):
class
Mod
(
BinaryOpExpr
):
pass
"""Mod node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Mod
,
a
,
b
)
@register_node
@register_node
class
Min
(
BinaryOpExpr
):
class
Min
(
BinaryOpExpr
):
pass
"""Min node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Min
,
a
,
b
)
@register_node
@register_node
class
Max
(
BinaryOpExpr
):
class
Max
(
BinaryOpExpr
):
pass
"""Max node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Max
,
a
,
b
)
@register_node
@register_node
class
EQ
(
CmpExpr
):
class
EQ
(
CmpExpr
):
pass
"""EQ node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
EQ
,
a
,
b
)
@register_node
@register_node
class
NE
(
CmpExpr
):
class
NE
(
CmpExpr
):
pass
"""NE node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
NE
,
a
,
b
)
@register_node
@register_node
class
LT
(
CmpExpr
):
class
LT
(
CmpExpr
):
pass
"""LT node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
LT
,
a
,
b
)
@register_node
@register_node
class
LE
(
CmpExpr
):
class
LE
(
CmpExpr
):
pass
"""LE node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
LE
,
a
,
b
)
@register_node
@register_node
class
GT
(
CmpExpr
):
class
GT
(
CmpExpr
):
pass
"""GT node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
GT
,
a
,
b
)
@register_node
@register_node
class
GE
(
CmpExpr
):
class
GE
(
CmpExpr
):
pass
"""GE node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
GE
,
a
,
b
)
@register_node
@register_node
class
And
(
LogicalExpr
):
class
And
(
LogicalExpr
):
pass
"""And node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
And
,
a
,
b
)
@register_node
@register_node
class
Or
(
LogicalExpr
):
class
Or
(
LogicalExpr
):
pass
"""Or node.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
"""
def
__init__
(
self
,
a
,
b
):
self
.
__init_handle_by_constructor__
(
_make
.
Or
,
a
,
b
)
@register_node
@register_node
class
Not
(
LogicalExpr
):
class
Not
(
LogicalExpr
):
pass
"""Not node.
Parameters
----------
a : Expr
The input value
"""
def
__init__
(
self
,
a
):
self
.
__init_handle_by_constructor__
(
_make
.
Not
,
a
)
@register_node
@register_node
class
Select
(
Expr
):
class
Select
(
Expr
):
pass
"""Select node.
Parameters
----------
condition : Expr
The condition expression.
true_value : Expr
The value to take when condition is true.
false_value : Expr
The value to take when condition is false.
"""
def
__init__
(
self
,
condition
,
true_value
,
false_value
):
self
.
__init_handle_by_constructor__
(
_make
.
Select
,
condition
,
true_value
,
false_value
)
@register_node
@register_node
class
Load
(
Expr
):
class
Load
(
Expr
):
pass
"""Load node.
Parameters
----------
dtype : str
The data type.
buffer_var : Var
The buffer variable in the load expression.
index : Expr
The index in the load.
predicate : Expr
The load predicate.
"""
def
__init__
(
self
,
dtype
,
buffer_var
,
index
,
predicate
):
self
.
__init_handle_by_constructor__
(
_make
.
Load
,
dtype
,
buffer_var
,
index
,
predicate
)
@register_node
@register_node
class
Ramp
(
Expr
):
class
Ramp
(
Expr
):
pass
"""Ramp node.
Parameters
----------
base : Expr
The base expression.
stride : ramp stride
The stride of the ramp.
lanes : int
The lanes of the expression.
"""
def
__init__
(
self
,
base
,
stride
,
lanes
):
self
.
__init_handle_by_constructor__
(
_make
.
Ramp
,
base
,
stride
,
lanes
)
@register_node
@register_node
class
Broadcast
(
Expr
):
class
Broadcast
(
Expr
):
pass
"""Broadcast node.
Parameters
----------
value : Expr
The value of the expression.
lanes : int
The lanes of the expression.
"""
def
__init__
(
self
,
value
,
lanes
):
self
.
__init_handle_by_constructor__
(
_make
.
Broadcast
,
value
,
lanes
)
@register_node
@register_node
class
Shuffle
(
Expr
):
class
Shuffle
(
Expr
):
pass
"""Shuffle node.
Parameters
----------
vectors : Array of Expr
The vectors
indices : Array of indices
The indices
"""
def
__init__
(
self
,
vectors
,
indices
):
self
.
__init_handle_by_constructor__
(
_make
.
Shuffle
,
vectors
,
indices
)
@register_node
@register_node
class
Call
(
Expr
):
class
Call
(
Expr
):
"""Call node.
Parameters
----------
dtype : str
The return data type
name : str
The name of the function
args : list of Expr
The input arguments to the call
call_type : int
The type of the call
func : Operation, optional
Operation if call_type is Halide
value_index : int
The output value index
"""
Extern
=
0
Extern
=
0
ExternCPlusPlus
=
1
ExternCPlusPlus
=
1
PureExtern
=
2
PureExtern
=
2
Halide
=
3
Halide
=
3
Intrinsic
=
4
Intrinsic
=
4
PureIntrinsic
=
5
PureIntrinsic
=
5
def
__init__
(
self
,
dtype
,
name
,
args
,
call_type
,
func
,
value_index
):
self
.
__init_handle_by_constructor__
(
_make
.
Call
,
dtype
,
name
,
args
,
call_type
,
func
,
value_index
)
@register_node
@register_node
class
Let
(
Expr
):
class
Let
(
Expr
):
pass
"""Let node.
Parameters
----------
var : Var
The variable in the binding.
value : Expr
The value in to be binded.
body : Expr
The body expression.
"""
def
__init__
(
self
,
var
,
value
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
Let
,
var
,
value
,
body
)
python/tvm/make.py
View file @
6eecec92
...
@@ -6,9 +6,10 @@ The functions are automatically exported from C++ side via PackedFunc.
...
@@ -6,9 +6,10 @@ The functions are automatically exported from C++ side via PackedFunc.
Each api is a PackedFunc that can be called in a positional argument manner.
Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
You can use make function to build the IR node.
"""
"""
from
__future__
import
absolute_import
as
_abs
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
from
._ffi.runtime_ctypes
import
TVMType
from
._ffi.runtime_ctypes
import
TVMType
from
.
import
stmt
as
_stmt
def
range_by_min_extent
(
min_value
,
extent
):
def
range_by_min_extent
(
min_value
,
extent
):
"""Construct a Range by min and extent.
"""Construct a Range by min and extent.
...
@@ -98,44 +99,4 @@ def node(type_key, **kwargs):
...
@@ -98,44 +99,4 @@ def node(type_key, **kwargs):
return
_Node
(
*
args
)
return
_Node
(
*
args
)
def
stmt_seq
(
*
args
):
"""Make sequence of statements
Parameters
----------
args : list of Expr or Var
List of statements to be combined as sequence.
Returns
-------
stmt : Stmt
The combined statement.
"""
ret
=
None
for
value
in
args
:
if
not
isinstance
(
value
,
_stmt
.
Stmt
):
value
=
Evaluate
(
value
)
ret
=
value
if
ret
is
None
else
Block
(
ret
,
value
)
return
ret
if
ret
else
Evaluate
(
0
)
def
stmt_list
(
stmt
):
"""Make list of stmt from blocks.
Parameters
----------
stmt : A block statement
Returns
-------
stmt_list : list of Stmt
The unpacked list of statements
"""
if
isinstance
(
stmt
,
_stmt
.
Block
):
return
stmt_list
(
stmt
.
first
)
+
stmt_list
(
stmt
.
rest
)
elif
isinstance
(
stmt
,
_stmt
.
ProducerConsumer
):
return
stmt_list
(
stmt
.
body
)
return
[
stmt
]
_init_api
(
"tvm.make"
)
_init_api
(
"tvm.make"
)
python/tvm/stmt.py
View file @
6eecec92
...
@@ -15,65 +15,376 @@ Each statement node have subfields that can be visited from python side.
...
@@ -15,65 +15,376 @@ Each statement node have subfields that can be visited from python side.
"""
"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
._ffi.node
import
NodeBase
,
register_node
from
._ffi.node
import
NodeBase
,
register_node
from
.
import
make
as
_make
class
Stmt
(
NodeBase
):
class
Stmt
(
NodeBase
):
pass
pass
@register_node
@register_node
class
LetStmt
(
Stmt
):
class
LetStmt
(
Stmt
):
pass
"""LetStmt node.
Parameters
----------
var : Var
The variable in the binding.
value : Expr
The value in to be binded.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
var
,
value
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
LetStmt
,
var
,
value
,
body
)
@register_node
@register_node
class
AssertStmt
(
Stmt
):
class
AssertStmt
(
Stmt
):
pass
"""AssertStmt node.
Parameters
----------
condition : Expr
The assert condition.
message : Expr
The error message.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
condition
,
message
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
AssertStmt
,
condition
,
message
,
body
)
@register_node
@register_node
class
ProducerConsumer
(
Stmt
):
class
ProducerConsumer
(
Stmt
):
pass
"""ProducerConsumer node.
Parameters
----------
func : Operation
The Operation.
is_producer : bool
Whether if the node is producer.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
func
,
is_producer
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
ProducerConsumer
,
func
,
is_producer
,
body
)
@register_node
@register_node
class
For
(
Stmt
):
class
For
(
Stmt
):
"""For node.
Parameters
----------
loop_var : Var
The loop variable.
min_val : Expr
The begining value.
extent : Expr
The length of the loop.
for_type : int
The for type.
device_api : int
The device api type.
body : Stmt
The body statement.
"""
Serial
=
0
Serial
=
0
Parallel
=
1
Parallel
=
1
Vectorized
=
2
Vectorized
=
2
Unrolled
=
3
Unrolled
=
3
def
__init__
(
self
,
loop_var
,
min_val
,
extent
,
for_type
,
device_api
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
For
,
loop_var
,
min_val
,
extent
,
for_type
,
device_api
,
body
)
@register_node
@register_node
class
Store
(
Stmt
):
class
Store
(
Stmt
):
pass
"""Store node.
Parameters
----------
buffer_var : Var
The buffer Variable.
value : Expr
The value we want to store.
index : Expr
The index in the store expression.
predicate : Expr
The store predicate.
"""
def
__init__
(
self
,
buffer_var
,
value
,
index
,
predicate
):
self
.
__init_handle_by_constructor__
(
_make
.
Store
,
buffer_var
,
value
,
index
,
predicate
)
@register_node
@register_node
class
Provide
(
Stmt
):
class
Provide
(
Stmt
):
pass
"""Provide node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
value : Expr
The value to be stored.
args : list of Expr
The index arguments of the Provide.
"""
def
__init__
(
self
,
func
,
value_index
,
value
,
args
):
self
.
__init_handle_by_constructor__
(
_make
.
Provide
,
func
,
value_index
,
value
,
args
)
@register_node
@register_node
class
Allocate
(
Stmt
):
class
Allocate
(
Stmt
):
pass
"""Allocate node.
Parameters
----------
buffer_var : Var
The buffer variable.
dtype : str
The data type of the buffer.
extents : list of Expr
The extents of the allocate
condition : Expr
The condition.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
buffer_var
,
dtype
,
extents
,
condition
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
Allocate
,
buffer_var
,
dtype
,
extents
,
condition
,
body
)
@register_node
@register_node
class
AttrStmt
(
Stmt
):
class
AttrStmt
(
Stmt
):
pass
"""AttrStmt node.
Parameters
----------
node : Node
The node to annotate the attribute
attr_key : str
Attribute type key.
value : Expr
The value of the attribute
body : Stmt
The body statement.
"""
def
__init__
(
self
,
node
,
attr_key
,
value
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
AttrStmt
,
node
,
attr_key
,
value
,
body
)
@register_node
@register_node
class
Free
(
Stmt
):
class
Free
(
Stmt
):
pass
"""Free node.
Parameters
----------
buffer_var : Var
The buffer variable.
"""
def
__init__
(
self
,
buffer_var
):
self
.
__init_handle_by_constructor__
(
_make
.
Free
,
buffer_var
)
@register_node
@register_node
class
Realize
(
Stmt
):
class
Realize
(
Stmt
):
pass
"""Realize node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
dtype : str
The data type of the operation.
bounds : list of range
The bound of realize
condition : Expr
The realize condition.
body : Stmt
The realize body
"""
def
__init__
(
self
,
func
,
value_index
,
dtype
,
bounds
,
condition
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
Realize
,
func
,
value_index
,
dtype
,
bounds
,
condition
,
body
)
@register_node
@register_node
class
Block
(
Stmt
):
class
Block
(
Stmt
):
pass
"""Block node.
Parameters
----------
first : Stmt
The first statement.
rest : Stmt
The following statement.
"""
def
__init__
(
self
,
first
,
rest
):
self
.
__init_handle_by_constructor__
(
_make
.
Block
,
first
,
rest
)
@register_node
@register_node
class
IfThenElse
(
Stmt
):
class
IfThenElse
(
Stmt
):
pass
"""IfThenElse node.
Parameters
----------
condition : Expr
The expression
then_case : Stmt
The statement to execute if condition is true.
else_case : Stmt
The statement to execute if condition is false.
"""
def
__init__
(
self
,
condition
,
then_case
,
else_case
):
self
.
__init_handle_by_constructor__
(
_make
.
IfThenElse
,
condition
,
then_case
,
else_case
)
@register_node
@register_node
class
Evaluate
(
Stmt
):
class
Evaluate
(
Stmt
):
pass
"""Evaluate node.
Parameters
----------
value : Expr
The expression to be evalued.
"""
def
__init__
(
self
,
value
):
self
.
__init_handle_by_constructor__
(
_make
.
Evaluate
,
value
)
@register_node
@register_node
class
Prefetch
(
Stmt
):
class
Prefetch
(
Stmt
):
pass
"""Prefetch node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
dtype : str
The data type to be prefetched.
bounds : list of Range
The bounds to be prefetched.
"""
def
__init__
(
self
,
func
,
value_index
,
dtype
,
bounds
):
self
.
__init_handle_by_constructor__
(
_make
.
Prefetch
,
func
,
value_index
,
dtype
,
bounds
)
def
stmt_seq
(
*
args
):
"""Make sequence of statements
Parameters
----------
args : list of Expr or Var
List of statements to be combined as sequence.
Returns
-------
stmt : Stmt
The combined statement.
"""
ret
=
None
for
value
in
args
:
if
not
isinstance
(
value
,
Stmt
):
value
=
Evaluate
(
value
)
ret
=
value
if
ret
is
None
else
Block
(
ret
,
value
)
return
ret
if
ret
else
Evaluate
(
0
)
def
stmt_list
(
stmt
):
"""Make list of stmt from blocks.
Parameters
----------
stmt : A block statement
Returns
-------
stmt_list : list of Stmt
The unpacked list of statements
"""
if
isinstance
(
stmt
,
Block
):
return
stmt_list
(
stmt
.
first
)
+
stmt_list
(
stmt
.
rest
)
elif
isinstance
(
stmt
,
ProducerConsumer
):
return
stmt_list
(
stmt
.
body
)
return
[
stmt
]
_make
.
stmt_list
=
stmt_list
_make
.
stmt_seq
=
stmt_seq
src/api/api_ir.cc
View file @
6eecec92
...
@@ -170,6 +170,7 @@ REGISTER_MAKE3(Select);
...
@@ -170,6 +170,7 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE2
(
Cast
);
REGISTER_MAKE2
(
Cast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Shuffle
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
AssertStmt
);
REGISTER_MAKE3
(
AssertStmt
);
...
...
tests/python/unittest/test_lang_constructor.py
0 → 100644
View file @
6eecec92
import
tvm
def
test_expr_constructor
():
x
=
tvm
.
expr
.
Var
(
"xx"
,
"float32"
)
assert
isinstance
(
x
,
tvm
.
expr
.
Var
)
assert
x
.
name
==
"xx"
x
=
tvm
.
expr
.
Reduce
(
None
,
[
1
],
[
tvm
.
api
.
_IterVar
((
0
,
1
),
"x"
,
2
)],
None
,
0
)
assert
isinstance
(
x
,
tvm
.
expr
.
Reduce
)
assert
x
.
combiner
==
None
assert
x
.
value_index
==
0
x
=
tvm
.
expr
.
FloatImm
(
"float32"
,
1.0
)
assert
isinstance
(
x
,
tvm
.
expr
.
FloatImm
)
assert
x
.
value
==
1.0
assert
x
.
dtype
==
"float32"
x
=
tvm
.
expr
.
IntImm
(
"int64"
,
2
)
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
assert
x
.
value
==
2
assert
x
.
dtype
==
"int64"
x
=
tvm
.
expr
.
UIntImm
(
"uint16"
,
2
)
assert
isinstance
(
x
,
tvm
.
expr
.
UIntImm
)
assert
x
.
value
==
2
assert
x
.
dtype
==
"uint16"
x
=
tvm
.
expr
.
StringImm
(
"xyza"
)
assert
isinstance
(
x
,
tvm
.
expr
.
StringImm
)
assert
x
.
value
==
"xyza"
x
=
tvm
.
expr
.
Cast
(
"float32"
,
tvm
.
expr
.
IntImm
(
"int32"
,
1
))
assert
isinstance
(
x
,
tvm
.
expr
.
Cast
)
assert
x
.
dtype
==
"float32"
assert
x
.
value
.
value
==
1
a
=
tvm
.
const
(
1.0
,
dtype
=
"float32"
)
b
=
tvm
.
var
(
"x"
,
dtype
=
"float32"
)
for
cls
in
[
tvm
.
expr
.
Add
,
tvm
.
expr
.
Sub
,
tvm
.
expr
.
Mul
,
tvm
.
expr
.
Div
,
tvm
.
expr
.
Mod
,
tvm
.
expr
.
Min
,
tvm
.
expr
.
Max
,
tvm
.
expr
.
LT
,
tvm
.
expr
.
LE
,
tvm
.
expr
.
GT
,
tvm
.
expr
.
GE
]:
x
=
cls
(
a
,
b
)
assert
isinstance
(
x
,
cls
)
assert
x
.
a
==
a
assert
x
.
b
.
same_as
(
b
)
a
=
tvm
.
convert
(
tvm
.
var
(
"x"
)
>
1
)
b
=
tvm
.
convert
(
tvm
.
var
(
"x"
)
==
1
)
for
cls
in
[
tvm
.
expr
.
And
,
tvm
.
expr
.
Or
]:
x
=
cls
(
a
,
b
)
assert
isinstance
(
x
,
cls
)
assert
x
.
a
==
a
assert
x
.
b
.
same_as
(
b
)
x
=
tvm
.
expr
.
Not
(
a
)
assert
isinstance
(
x
,
tvm
.
expr
.
Not
)
assert
x
.
a
==
a
x
=
tvm
.
expr
.
Select
(
a
,
a
,
b
)
assert
isinstance
(
x
,
tvm
.
expr
.
Select
)
assert
x
.
true_value
==
a
assert
x
.
false_value
==
b
assert
x
.
condition
==
a
buffer_var
=
tvm
.
var
(
"x"
,
dtype
=
"handle"
)
x
=
tvm
.
expr
.
Load
(
"float32"
,
buffer_var
,
1
,
a
)
assert
isinstance
(
x
,
tvm
.
expr
.
Load
)
assert
x
.
dtype
==
"float32"
assert
x
.
buffer_var
==
buffer_var
assert
x
.
index
.
value
==
1
assert
x
.
predicate
==
a
x
=
tvm
.
expr
.
Ramp
(
1
,
2
,
10
)
assert
isinstance
(
x
,
tvm
.
expr
.
Ramp
)
assert
x
.
base
.
value
==
1
assert
x
.
stride
.
value
==
2
assert
x
.
lanes
==
10
x
=
tvm
.
expr
.
Broadcast
(
a
,
10
)
assert
isinstance
(
x
,
tvm
.
expr
.
Broadcast
)
assert
x
.
value
==
a
assert
x
.
lanes
==
10
x
=
tvm
.
expr
.
Shuffle
([
a
],
[
0
])
assert
isinstance
(
x
,
tvm
.
expr
.
Shuffle
)
assert
x
.
vectors
[
0
]
==
a
assert
x
.
indices
[
0
]
.
value
==
0
x
=
tvm
.
expr
.
Call
(
"float32"
,
"xyz"
,
[
a
],
tvm
.
expr
.
Call
.
Extern
,
None
,
0
)
assert
isinstance
(
x
,
tvm
.
expr
.
Call
)
assert
x
.
dtype
==
"float32"
assert
x
.
name
==
"xyz"
assert
x
.
args
[
0
]
==
a
assert
x
.
call_type
==
tvm
.
expr
.
Call
.
Extern
assert
x
.
func
==
None
assert
x
.
value_index
==
0
v
=
tvm
.
var
(
"aa"
)
x
=
tvm
.
expr
.
Let
(
v
,
1
,
v
)
assert
x
.
var
==
v
assert
x
.
value
.
value
==
1
assert
x
.
body
==
v
def
test_stmt_constructor
():
v
=
tvm
.
var
(
"aa"
)
buffer_var
=
tvm
.
var
(
"buf"
,
dtype
=
"handle"
)
nop
=
tvm
.
stmt
.
Evaluate
(
1
)
x
=
tvm
.
stmt
.
LetStmt
(
v
,
1
,
tvm
.
stmt
.
Evaluate
(
1
))
assert
isinstance
(
x
,
tvm
.
stmt
.
LetStmt
)
assert
x
.
var
==
v
assert
x
.
value
.
value
==
1
assert
isinstance
(
x
.
body
,
tvm
.
stmt
.
Evaluate
)
x
=
tvm
.
stmt
.
AttrStmt
(
v
==
1
,
"xx"
,
1
,
tvm
.
stmt
.
Evaluate
(
1
))
assert
isinstance
(
x
,
tvm
.
stmt
.
AttrStmt
)
assert
x
.
value
.
value
==
1
x
=
tvm
.
stmt
.
Block
(
tvm
.
stmt
.
Evaluate
(
11
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Block
)
assert
x
.
first
.
value
.
value
==
11
assert
x
.
rest
==
nop
x
=
tvm
.
stmt
.
AssertStmt
(
tvm
.
const
(
1
,
"uint1"
),
tvm
.
convert
(
"hellow"
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
AssertStmt
)
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
ProducerConsumer
(
None
,
True
,
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
ProducerConsumer
)
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
For
(
tvm
.
var
(
"x"
),
0
,
10
,
0
,
0
,
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
For
)
assert
x
.
min
.
value
==
0
assert
x
.
extent
.
value
==
10
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
Store
(
buffer_var
,
1
,
10
,
tvm
.
const
(
1
,
"uint1"
))
assert
isinstance
(
x
,
tvm
.
stmt
.
Store
)
assert
x
.
buffer_var
==
buffer_var
assert
x
.
index
.
value
==
10
assert
x
.
value
.
value
==
1
tensor
=
tvm
.
placeholder
((),
dtype
=
"float32"
)
x
=
tvm
.
stmt
.
Provide
(
tensor
.
op
,
0
,
10
,
[])
assert
isinstance
(
x
,
tvm
.
stmt
.
Provide
)
assert
x
.
value_index
==
0
assert
x
.
value
.
value
==
10
x
=
tvm
.
stmt
.
Allocate
(
buffer_var
,
"float32"
,
[
10
],
tvm
.
const
(
1
,
"uint1"
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Allocate
)
assert
x
.
dtype
==
"float32"
assert
x
.
buffer_var
==
buffer_var
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
AttrStmt
(
buffer_var
,
"xyz"
,
1
,
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
AttrStmt
)
assert
x
.
node
==
buffer_var
assert
x
.
attr_key
==
"xyz"
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
Free
(
buffer_var
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Free
)
assert
x
.
buffer_var
==
buffer_var
x
=
tvm
.
stmt
.
Realize
(
None
,
0
,
"float"
,
[],
tvm
.
const
(
1
,
"uint1"
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Realize
)
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
IfThenElse
(
tvm
.
const
(
1
,
"uint1"
),
tvm
.
stmt
.
Evaluate
(
11
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
IfThenElse
)
assert
x
.
then_case
.
value
.
value
==
11
assert
x
.
else_case
==
nop
x
=
tvm
.
stmt
.
Prefetch
(
None
,
1
,
"float32"
,
[])
assert
isinstance
(
x
,
tvm
.
stmt
.
Prefetch
)
assert
x
.
value_index
==
1
if
__name__
==
"__main__"
:
test_expr_constructor
()
test_stmt_constructor
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment