Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6eecec92
Unverified
Commit
6eecec92
authored
Aug 23, 2018
by
Tianqi Chen
Committed by
GitHub
Aug 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PYTHON] Enable constructors in Node (#1647)
parent
62d34ca5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
636 additions
and
92 deletions
+636
-92
python/tvm/_ffi/_ctypes/function.py
+19
-0
python/tvm/_ffi/_ctypes/node.py
+24
-1
python/tvm/_ffi/_cython/function.pxi
+32
-12
python/tvm/_ffi/_cython/node.pxi
+23
-0
python/tvm/_ffi/function.py
+1
-17
python/tvm/api.py
+8
-8
python/tvm/expr.py
+0
-0
python/tvm/make.py
+2
-41
python/tvm/stmt.py
+324
-13
src/api/api_ir.cc
+1
-0
tests/python/unittest/test_lang_constructor.py
+202
-0
No files found.
python/tvm/_ffi/_ctypes/function.py
View file @
6eecec92
...
@@ -17,6 +17,7 @@ from .types import TVMValue, TypeCode
...
@@ -17,6 +17,7 @@ from .types import TVMValue, TypeCode
from
.types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
.types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
.node
import
NodeBase
from
.node
import
NodeBase
from
.
import
node
as
_node
FunctionHandle
=
ctypes
.
c_void_p
FunctionHandle
=
ctypes
.
c_void_p
ModuleHandle
=
ctypes
.
c_void_p
ModuleHandle
=
ctypes
.
c_void_p
...
@@ -186,6 +187,23 @@ class FunctionBase(object):
...
@@ -186,6 +187,23 @@ class FunctionBase(object):
_
=
args
_
=
args
return
RETURN_SWITCH
[
ret_tcode
.
value
](
ret_val
)
return
RETURN_SWITCH
[
ret_tcode
.
value
](
ret_val
)
def
__init_handle_by_constructor__
(
fconstructor
,
args
):
"""Initialize handle by constructor"""
temp_args
=
[]
values
,
tcodes
,
num_args
=
_make_tvm_args
(
args
,
temp_args
)
ret_val
=
TVMValue
()
ret_tcode
=
ctypes
.
c_int
()
check_call
(
_LIB
.
TVMFuncCall
(
fconstructor
.
handle
,
values
,
tcodes
,
ctypes
.
c_int
(
num_args
),
ctypes
.
byref
(
ret_val
),
ctypes
.
byref
(
ret_tcode
)))
_
=
temp_args
_
=
args
assert
ret_tcode
.
value
==
TypeCode
.
NODE_HANDLE
handle
=
ret_val
.
v_handle
return
handle
def
_return_module
(
x
):
def
_return_module
(
x
):
"""Return function"""
"""Return function"""
handle
=
x
.
v_handle
handle
=
x
.
v_handle
...
@@ -202,6 +220,7 @@ def _handle_return_func(x):
...
@@ -202,6 +220,7 @@ def _handle_return_func(x):
# setup return handle for function type
# setup return handle for function type
_node
.
__init_by_constructor__
=
__init_handle_by_constructor__
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
)
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
)
...
...
python/tvm/_ffi/_ctypes/node.py
View file @
6eecec92
# pylint: disable=invalid-name, protected-access
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
# pylint: disable=no-member, missing-docstring
, not-callable
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
ctypes
import
ctypes
...
@@ -9,6 +9,7 @@ from .types import TVMValue, TypeCode
...
@@ -9,6 +9,7 @@ from .types import TVMValue, TypeCode
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
NodeHandle
=
ctypes
.
c_void_p
NodeHandle
=
ctypes
.
c_void_p
__init_by_constructor__
=
None
"""Maps node type to its constructor"""
"""Maps node type to its constructor"""
NODE_TYPE
=
{}
NODE_TYPE
=
{}
...
@@ -58,4 +59,26 @@ class NodeBase(object):
...
@@ -58,4 +59,26 @@ class NodeBase(object):
"'
%
s' object has no attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
"'
%
s' object has no attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
return
RETURN_SWITCH
[
ret_type_code
.
value
](
ret_val
)
return
RETURN_SWITCH
[
ret_type_code
.
value
](
ret_val
)
def
__init_handle_by_constructor__
(
self
,
fconstructor
,
*
args
):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
handle
=
__init_by_constructor__
(
fconstructor
,
args
)
if
not
isinstance
(
handle
,
NodeHandle
):
handle
=
NodeHandle
(
handle
)
self
.
handle
=
handle
_set_class_node_base
(
NodeBase
)
_set_class_node_base
(
NodeBase
)
python/tvm/_ffi/_cython/function.pxi
View file @
6eecec92
...
@@ -196,37 +196,54 @@ cdef inline object make_ret(TVMValue value, int tcode):
...
@@ -196,37 +196,54 @@ cdef inline object make_ret(TVMValue value, int tcode):
raise ValueError("Unhandled type code %d" % tcode)
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
cdef inline int FuncCall3(void* chandle,
tuple args,
int nargs,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef TVMValue[3] values
cdef TVMValue[3] values
cdef int[3] tcodes
cdef int[3] tcodes
cdef TVMValue ret_val
cdef int ret_code
nargs = len(args)
nargs = len(args)
temp_args = []
temp_args = []
for i in range(nargs):
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs,
&ret_val, &ret_
code))
nargs,
ret_val, ret_t
code))
return
make_ret(ret_val, ret_code)
return
0
cdef inline object FuncCall(void* chandle, tuple args):
cdef inline int FuncCall(void* chandle,
tuple args,
TVMValue* ret_val,
int* ret_tcode) except -1:
cdef int nargs
cdef int nargs
nargs = len(args)
nargs = len(args)
if nargs <= 3:
if nargs <= 3:
return FuncCall3(chandle, args, nargs)
FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
return 0
cdef vector[TVMValue] values
cdef vector[TVMValue] values
cdef vector[int] tcodes
cdef vector[int] tcodes
cdef TVMValue ret_val
cdef int ret_code
values.resize(max(nargs, 1))
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
temp_args = []
temp_args = []
for i in range(nargs):
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
nargs, ret_val, ret_tcode))
return make_ret(ret_val, ret_code)
return 0
cdef inline int ConstructorCall(void* constructor_handle,
int type_code,
tuple args,
void** handle) except -1:
"""Call contructor of a handle function"""
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
assert ret_tcode == type_code
handle[0] = ret_val.v_handle
return 0
cdef class FunctionBase:
cdef class FunctionBase:
...
@@ -264,7 +281,10 @@ cdef class FunctionBase:
...
@@ -264,7 +281,10 @@ cdef class FunctionBase:
CALL(TVMFuncFree(self.chandle))
CALL(TVMFuncFree(self.chandle))
def __call__(self, *args):
def __call__(self, *args):
return FuncCall(self.chandle, args)
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
_CLASS_FUNCTION = None
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_MODULE = None
...
...
python/tvm/_ffi/_cython/node.pxi
View file @
6eecec92
...
@@ -65,4 +65,27 @@ cdef class NodeBase:
...
@@ -65,4 +65,27 @@ cdef class NodeBase:
"'%s' object has no attribute '%s'" % (type(self), name))
"'%s' object has no attribute '%s'" % (type(self), name))
return make_ret(ret_val, ret_type_code)
return make_ret(ret_val, ret_type_code)
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
kNodeHandle, args, &chandle)
self.chandle = chandle
_set_class_node_base(NodeBase)
_set_class_node_base(NodeBase)
python/tvm/_ffi/function.py
View file @
6eecec92
...
@@ -262,23 +262,7 @@ def extract_ext_funcs(finit):
...
@@ -262,23 +262,7 @@ def extract_ext_funcs(finit):
def
_get_api
(
f
):
def
_get_api
(
f
):
flocal
=
f
flocal
=
f
flocal
.
is_global
=
True
flocal
.
is_global
=
True
def
my_api_func
(
*
args
):
return
flocal
"""
This is a type erased API that calls into Global PackedFunc.
These APIs corresponds to functions registered from C++ backend
and can be used as developer functions.
args : list
The positional arguments to the function call.
Returns
-------
value : int, float, None, Node or Function
The result of the API function call.
"""
return
flocal
(
*
args
)
return
my_api_func
def
_init_api
(
namespace
,
target_module_name
=
None
):
def
_init_api
(
namespace
,
target_module_name
=
None
):
"""Initialize api for a given module name
"""Initialize api for a given module name
...
...
python/tvm/api.py
View file @
6eecec92
...
@@ -134,9 +134,9 @@ def any(*args):
...
@@ -134,9 +134,9 @@ def any(*args):
raise
ValueError
(
"Any must take at least 1 argument"
)
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
return
args
[
0
]
return
args
[
0
]
ret
=
_
make
.
Or
(
args
[
0
],
args
[
1
])
ret
=
_
expr
.
Or
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_
make
.
Or
(
ret
,
args
[
i
])
ret
=
_
expr
.
Or
(
ret
,
args
[
i
])
return
ret
return
ret
...
@@ -158,9 +158,9 @@ def all(*args):
...
@@ -158,9 +158,9 @@ def all(*args):
raise
ValueError
(
"Any must take at least 1 argument"
)
raise
ValueError
(
"Any must take at least 1 argument"
)
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
return
args
[
0
]
return
args
[
0
]
ret
=
_
make
.
And
(
args
[
0
],
args
[
1
])
ret
=
_
expr
.
And
(
args
[
0
],
args
[
1
])
for
i
in
range
(
2
,
len
(
args
)):
for
i
in
range
(
2
,
len
(
args
)):
ret
=
_
make
.
And
(
ret
,
args
[
i
])
ret
=
_
expr
.
And
(
ret
,
args
[
i
])
return
ret
return
ret
...
@@ -616,7 +616,7 @@ def select(cond, t, f):
...
@@ -616,7 +616,7 @@ def select(cond, t, f):
node : Node
node : Node
The tvm.expr.Select node
The tvm.expr.Select node
"""
"""
return
_
make
.
Select
(
convert
(
cond
),
convert
(
t
),
convert
(
f
))
return
_
expr
.
Select
(
convert
(
cond
),
convert
(
t
),
convert
(
f
))
def
comm_reducer
(
fcombine
,
fidentity
,
name
=
"reduce"
):
def
comm_reducer
(
fcombine
,
fidentity
,
name
=
"reduce"
):
...
@@ -699,7 +699,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
...
@@ -699,7 +699,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
axis
=
convert
(
axis
if
isinstance
(
axis
,
(
list
,
tuple
))
else
[
axis
])
axis
=
convert
(
axis
if
isinstance
(
axis
,
(
list
,
tuple
))
else
[
axis
])
if
where
is
None
:
if
where
is
None
:
where
=
convert
(
True
)
where
=
convert
(
True
)
outputs
=
tuple
(
_
make
.
Reduce
(
combiner
,
expr
,
axis
,
where
,
i
)
outputs
=
tuple
(
_
expr
.
Reduce
(
combiner
,
expr
,
axis
,
where
,
i
)
for
i
in
range
(
size
))
for
i
in
range
(
size
))
return
outputs
[
0
]
if
size
==
1
else
outputs
return
outputs
[
0
]
if
size
==
1
else
outputs
...
@@ -751,5 +751,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
...
@@ -751,5 +751,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
_init_api
(
"tvm.api"
)
_init_api
(
"tvm.api"
)
#pylint: disable=unnecessary-lambda
#pylint: disable=unnecessary-lambda
sum
=
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
lambda
t
:
const
(
0
,
dtype
=
t
),
name
=
"sum"
)
sum
=
comm_reducer
(
lambda
x
,
y
:
x
+
y
,
lambda
t
:
const
(
0
,
dtype
=
t
),
name
=
"sum"
)
min
=
comm_reducer
(
lambda
x
,
y
:
_
make
.
Min
(
x
,
y
),
max_value
,
name
=
'min'
)
min
=
comm_reducer
(
lambda
x
,
y
:
_
expr
.
Min
(
x
,
y
),
max_value
,
name
=
'min'
)
max
=
comm_reducer
(
lambda
x
,
y
:
_
make
.
Max
(
x
,
y
),
min_value
,
name
=
'max'
)
max
=
comm_reducer
(
lambda
x
,
y
:
_
expr
.
Max
(
x
,
y
),
min_value
,
name
=
'max'
)
python/tvm/expr.py
View file @
6eecec92
This diff is collapsed.
Click to expand it.
python/tvm/make.py
View file @
6eecec92
...
@@ -6,9 +6,10 @@ The functions are automatically exported from C++ side via PackedFunc.
...
@@ -6,9 +6,10 @@ The functions are automatically exported from C++ side via PackedFunc.
Each api is a PackedFunc that can be called in a positional argument manner.
Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
You can use make function to build the IR node.
"""
"""
from
__future__
import
absolute_import
as
_abs
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
from
._ffi.runtime_ctypes
import
TVMType
from
._ffi.runtime_ctypes
import
TVMType
from
.
import
stmt
as
_stmt
def
range_by_min_extent
(
min_value
,
extent
):
def
range_by_min_extent
(
min_value
,
extent
):
"""Construct a Range by min and extent.
"""Construct a Range by min and extent.
...
@@ -98,44 +99,4 @@ def node(type_key, **kwargs):
...
@@ -98,44 +99,4 @@ def node(type_key, **kwargs):
return
_Node
(
*
args
)
return
_Node
(
*
args
)
def
stmt_seq
(
*
args
):
"""Make sequence of statements
Parameters
----------
args : list of Expr or Var
List of statements to be combined as sequence.
Returns
-------
stmt : Stmt
The combined statement.
"""
ret
=
None
for
value
in
args
:
if
not
isinstance
(
value
,
_stmt
.
Stmt
):
value
=
Evaluate
(
value
)
ret
=
value
if
ret
is
None
else
Block
(
ret
,
value
)
return
ret
if
ret
else
Evaluate
(
0
)
def
stmt_list
(
stmt
):
"""Make list of stmt from blocks.
Parameters
----------
stmt : A block statement
Returns
-------
stmt_list : list of Stmt
The unpacked list of statements
"""
if
isinstance
(
stmt
,
_stmt
.
Block
):
return
stmt_list
(
stmt
.
first
)
+
stmt_list
(
stmt
.
rest
)
elif
isinstance
(
stmt
,
_stmt
.
ProducerConsumer
):
return
stmt_list
(
stmt
.
body
)
return
[
stmt
]
_init_api
(
"tvm.make"
)
_init_api
(
"tvm.make"
)
python/tvm/stmt.py
View file @
6eecec92
...
@@ -15,65 +15,376 @@ Each statement node have subfields that can be visited from python side.
...
@@ -15,65 +15,376 @@ Each statement node have subfields that can be visited from python side.
"""
"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
._ffi.node
import
NodeBase
,
register_node
from
._ffi.node
import
NodeBase
,
register_node
from
.
import
make
as
_make
class
Stmt
(
NodeBase
):
class
Stmt
(
NodeBase
):
pass
pass
@register_node
@register_node
class
LetStmt
(
Stmt
):
class
LetStmt
(
Stmt
):
pass
"""LetStmt node.
Parameters
----------
var : Var
The variable in the binding.
value : Expr
The value in to be binded.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
var
,
value
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
LetStmt
,
var
,
value
,
body
)
@register_node
@register_node
class
AssertStmt
(
Stmt
):
class
AssertStmt
(
Stmt
):
pass
"""AssertStmt node.
Parameters
----------
condition : Expr
The assert condition.
message : Expr
The error message.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
condition
,
message
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
AssertStmt
,
condition
,
message
,
body
)
@register_node
@register_node
class
ProducerConsumer
(
Stmt
):
class
ProducerConsumer
(
Stmt
):
pass
"""ProducerConsumer node.
Parameters
----------
func : Operation
The Operation.
is_producer : bool
Whether if the node is producer.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
func
,
is_producer
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
ProducerConsumer
,
func
,
is_producer
,
body
)
@register_node
@register_node
class
For
(
Stmt
):
class
For
(
Stmt
):
"""For node.
Parameters
----------
loop_var : Var
The loop variable.
min_val : Expr
The begining value.
extent : Expr
The length of the loop.
for_type : int
The for type.
device_api : int
The device api type.
body : Stmt
The body statement.
"""
Serial
=
0
Serial
=
0
Parallel
=
1
Parallel
=
1
Vectorized
=
2
Vectorized
=
2
Unrolled
=
3
Unrolled
=
3
def
__init__
(
self
,
loop_var
,
min_val
,
extent
,
for_type
,
device_api
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
For
,
loop_var
,
min_val
,
extent
,
for_type
,
device_api
,
body
)
@register_node
@register_node
class
Store
(
Stmt
):
class
Store
(
Stmt
):
pass
"""Store node.
Parameters
----------
buffer_var : Var
The buffer Variable.
value : Expr
The value we want to store.
index : Expr
The index in the store expression.
predicate : Expr
The store predicate.
"""
def
__init__
(
self
,
buffer_var
,
value
,
index
,
predicate
):
self
.
__init_handle_by_constructor__
(
_make
.
Store
,
buffer_var
,
value
,
index
,
predicate
)
@register_node
@register_node
class
Provide
(
Stmt
):
class
Provide
(
Stmt
):
pass
"""Provide node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
value : Expr
The value to be stored.
args : list of Expr
The index arguments of the Provide.
"""
def
__init__
(
self
,
func
,
value_index
,
value
,
args
):
self
.
__init_handle_by_constructor__
(
_make
.
Provide
,
func
,
value_index
,
value
,
args
)
@register_node
@register_node
class
Allocate
(
Stmt
):
class
Allocate
(
Stmt
):
pass
"""Allocate node.
Parameters
----------
buffer_var : Var
The buffer variable.
dtype : str
The data type of the buffer.
extents : list of Expr
The extents of the allocate
condition : Expr
The condition.
body : Stmt
The body statement.
"""
def
__init__
(
self
,
buffer_var
,
dtype
,
extents
,
condition
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
Allocate
,
buffer_var
,
dtype
,
extents
,
condition
,
body
)
@register_node
@register_node
class
AttrStmt
(
Stmt
):
class
AttrStmt
(
Stmt
):
pass
"""AttrStmt node.
Parameters
----------
node : Node
The node to annotate the attribute
attr_key : str
Attribute type key.
value : Expr
The value of the attribute
body : Stmt
The body statement.
"""
def
__init__
(
self
,
node
,
attr_key
,
value
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
AttrStmt
,
node
,
attr_key
,
value
,
body
)
@register_node
@register_node
class
Free
(
Stmt
):
class
Free
(
Stmt
):
pass
"""Free node.
Parameters
----------
buffer_var : Var
The buffer variable.
"""
def
__init__
(
self
,
buffer_var
):
self
.
__init_handle_by_constructor__
(
_make
.
Free
,
buffer_var
)
@register_node
@register_node
class
Realize
(
Stmt
):
class
Realize
(
Stmt
):
pass
"""Realize node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
dtype : str
The data type of the operation.
bounds : list of range
The bound of realize
condition : Expr
The realize condition.
body : Stmt
The realize body
"""
def
__init__
(
self
,
func
,
value_index
,
dtype
,
bounds
,
condition
,
body
):
self
.
__init_handle_by_constructor__
(
_make
.
Realize
,
func
,
value_index
,
dtype
,
bounds
,
condition
,
body
)
@register_node
@register_node
class
Block
(
Stmt
):
class
Block
(
Stmt
):
pass
"""Block node.
Parameters
----------
first : Stmt
The first statement.
rest : Stmt
The following statement.
"""
def
__init__
(
self
,
first
,
rest
):
self
.
__init_handle_by_constructor__
(
_make
.
Block
,
first
,
rest
)
@register_node
@register_node
class
IfThenElse
(
Stmt
):
class
IfThenElse
(
Stmt
):
pass
"""IfThenElse node.
Parameters
----------
condition : Expr
The expression
then_case : Stmt
The statement to execute if condition is true.
else_case : Stmt
The statement to execute if condition is false.
"""
def
__init__
(
self
,
condition
,
then_case
,
else_case
):
self
.
__init_handle_by_constructor__
(
_make
.
IfThenElse
,
condition
,
then_case
,
else_case
)
@register_node
@register_node
class
Evaluate
(
Stmt
):
class
Evaluate
(
Stmt
):
pass
"""Evaluate node.
Parameters
----------
value : Expr
The expression to be evalued.
"""
def
__init__
(
self
,
value
):
self
.
__init_handle_by_constructor__
(
_make
.
Evaluate
,
value
)
@register_node
@register_node
class
Prefetch
(
Stmt
):
class
Prefetch
(
Stmt
):
pass
"""Prefetch node.
Parameters
----------
func : Operation
The operation to create the function.
value_index : int
The output value index
dtype : str
The data type to be prefetched.
bounds : list of Range
The bounds to be prefetched.
"""
def
__init__
(
self
,
func
,
value_index
,
dtype
,
bounds
):
self
.
__init_handle_by_constructor__
(
_make
.
Prefetch
,
func
,
value_index
,
dtype
,
bounds
)
def
stmt_seq
(
*
args
):
"""Make sequence of statements
Parameters
----------
args : list of Expr or Var
List of statements to be combined as sequence.
Returns
-------
stmt : Stmt
The combined statement.
"""
ret
=
None
for
value
in
args
:
if
not
isinstance
(
value
,
Stmt
):
value
=
Evaluate
(
value
)
ret
=
value
if
ret
is
None
else
Block
(
ret
,
value
)
return
ret
if
ret
else
Evaluate
(
0
)
def
stmt_list
(
stmt
):
"""Make list of stmt from blocks.
Parameters
----------
stmt : A block statement
Returns
-------
stmt_list : list of Stmt
The unpacked list of statements
"""
if
isinstance
(
stmt
,
Block
):
return
stmt_list
(
stmt
.
first
)
+
stmt_list
(
stmt
.
rest
)
elif
isinstance
(
stmt
,
ProducerConsumer
):
return
stmt_list
(
stmt
.
body
)
return
[
stmt
]
_make
.
stmt_list
=
stmt_list
_make
.
stmt_seq
=
stmt_seq
src/api/api_ir.cc
View file @
6eecec92
...
@@ -170,6 +170,7 @@ REGISTER_MAKE3(Select);
...
@@ -170,6 +170,7 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE2
(
Cast
);
REGISTER_MAKE2
(
Cast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Shuffle
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
AssertStmt
);
REGISTER_MAKE3
(
AssertStmt
);
...
...
tests/python/unittest/test_lang_constructor.py
0 → 100644
View file @
6eecec92
import
tvm
def
test_expr_constructor
():
x
=
tvm
.
expr
.
Var
(
"xx"
,
"float32"
)
assert
isinstance
(
x
,
tvm
.
expr
.
Var
)
assert
x
.
name
==
"xx"
x
=
tvm
.
expr
.
Reduce
(
None
,
[
1
],
[
tvm
.
api
.
_IterVar
((
0
,
1
),
"x"
,
2
)],
None
,
0
)
assert
isinstance
(
x
,
tvm
.
expr
.
Reduce
)
assert
x
.
combiner
==
None
assert
x
.
value_index
==
0
x
=
tvm
.
expr
.
FloatImm
(
"float32"
,
1.0
)
assert
isinstance
(
x
,
tvm
.
expr
.
FloatImm
)
assert
x
.
value
==
1.0
assert
x
.
dtype
==
"float32"
x
=
tvm
.
expr
.
IntImm
(
"int64"
,
2
)
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
assert
x
.
value
==
2
assert
x
.
dtype
==
"int64"
x
=
tvm
.
expr
.
UIntImm
(
"uint16"
,
2
)
assert
isinstance
(
x
,
tvm
.
expr
.
UIntImm
)
assert
x
.
value
==
2
assert
x
.
dtype
==
"uint16"
x
=
tvm
.
expr
.
StringImm
(
"xyza"
)
assert
isinstance
(
x
,
tvm
.
expr
.
StringImm
)
assert
x
.
value
==
"xyza"
x
=
tvm
.
expr
.
Cast
(
"float32"
,
tvm
.
expr
.
IntImm
(
"int32"
,
1
))
assert
isinstance
(
x
,
tvm
.
expr
.
Cast
)
assert
x
.
dtype
==
"float32"
assert
x
.
value
.
value
==
1
a
=
tvm
.
const
(
1.0
,
dtype
=
"float32"
)
b
=
tvm
.
var
(
"x"
,
dtype
=
"float32"
)
for
cls
in
[
tvm
.
expr
.
Add
,
tvm
.
expr
.
Sub
,
tvm
.
expr
.
Mul
,
tvm
.
expr
.
Div
,
tvm
.
expr
.
Mod
,
tvm
.
expr
.
Min
,
tvm
.
expr
.
Max
,
tvm
.
expr
.
LT
,
tvm
.
expr
.
LE
,
tvm
.
expr
.
GT
,
tvm
.
expr
.
GE
]:
x
=
cls
(
a
,
b
)
assert
isinstance
(
x
,
cls
)
assert
x
.
a
==
a
assert
x
.
b
.
same_as
(
b
)
a
=
tvm
.
convert
(
tvm
.
var
(
"x"
)
>
1
)
b
=
tvm
.
convert
(
tvm
.
var
(
"x"
)
==
1
)
for
cls
in
[
tvm
.
expr
.
And
,
tvm
.
expr
.
Or
]:
x
=
cls
(
a
,
b
)
assert
isinstance
(
x
,
cls
)
assert
x
.
a
==
a
assert
x
.
b
.
same_as
(
b
)
x
=
tvm
.
expr
.
Not
(
a
)
assert
isinstance
(
x
,
tvm
.
expr
.
Not
)
assert
x
.
a
==
a
x
=
tvm
.
expr
.
Select
(
a
,
a
,
b
)
assert
isinstance
(
x
,
tvm
.
expr
.
Select
)
assert
x
.
true_value
==
a
assert
x
.
false_value
==
b
assert
x
.
condition
==
a
buffer_var
=
tvm
.
var
(
"x"
,
dtype
=
"handle"
)
x
=
tvm
.
expr
.
Load
(
"float32"
,
buffer_var
,
1
,
a
)
assert
isinstance
(
x
,
tvm
.
expr
.
Load
)
assert
x
.
dtype
==
"float32"
assert
x
.
buffer_var
==
buffer_var
assert
x
.
index
.
value
==
1
assert
x
.
predicate
==
a
x
=
tvm
.
expr
.
Ramp
(
1
,
2
,
10
)
assert
isinstance
(
x
,
tvm
.
expr
.
Ramp
)
assert
x
.
base
.
value
==
1
assert
x
.
stride
.
value
==
2
assert
x
.
lanes
==
10
x
=
tvm
.
expr
.
Broadcast
(
a
,
10
)
assert
isinstance
(
x
,
tvm
.
expr
.
Broadcast
)
assert
x
.
value
==
a
assert
x
.
lanes
==
10
x
=
tvm
.
expr
.
Shuffle
([
a
],
[
0
])
assert
isinstance
(
x
,
tvm
.
expr
.
Shuffle
)
assert
x
.
vectors
[
0
]
==
a
assert
x
.
indices
[
0
]
.
value
==
0
x
=
tvm
.
expr
.
Call
(
"float32"
,
"xyz"
,
[
a
],
tvm
.
expr
.
Call
.
Extern
,
None
,
0
)
assert
isinstance
(
x
,
tvm
.
expr
.
Call
)
assert
x
.
dtype
==
"float32"
assert
x
.
name
==
"xyz"
assert
x
.
args
[
0
]
==
a
assert
x
.
call_type
==
tvm
.
expr
.
Call
.
Extern
assert
x
.
func
==
None
assert
x
.
value_index
==
0
v
=
tvm
.
var
(
"aa"
)
x
=
tvm
.
expr
.
Let
(
v
,
1
,
v
)
assert
x
.
var
==
v
assert
x
.
value
.
value
==
1
assert
x
.
body
==
v
def
test_stmt_constructor
():
v
=
tvm
.
var
(
"aa"
)
buffer_var
=
tvm
.
var
(
"buf"
,
dtype
=
"handle"
)
nop
=
tvm
.
stmt
.
Evaluate
(
1
)
x
=
tvm
.
stmt
.
LetStmt
(
v
,
1
,
tvm
.
stmt
.
Evaluate
(
1
))
assert
isinstance
(
x
,
tvm
.
stmt
.
LetStmt
)
assert
x
.
var
==
v
assert
x
.
value
.
value
==
1
assert
isinstance
(
x
.
body
,
tvm
.
stmt
.
Evaluate
)
x
=
tvm
.
stmt
.
AttrStmt
(
v
==
1
,
"xx"
,
1
,
tvm
.
stmt
.
Evaluate
(
1
))
assert
isinstance
(
x
,
tvm
.
stmt
.
AttrStmt
)
assert
x
.
value
.
value
==
1
x
=
tvm
.
stmt
.
Block
(
tvm
.
stmt
.
Evaluate
(
11
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Block
)
assert
x
.
first
.
value
.
value
==
11
assert
x
.
rest
==
nop
x
=
tvm
.
stmt
.
AssertStmt
(
tvm
.
const
(
1
,
"uint1"
),
tvm
.
convert
(
"hellow"
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
AssertStmt
)
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
ProducerConsumer
(
None
,
True
,
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
ProducerConsumer
)
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
For
(
tvm
.
var
(
"x"
),
0
,
10
,
0
,
0
,
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
For
)
assert
x
.
min
.
value
==
0
assert
x
.
extent
.
value
==
10
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
Store
(
buffer_var
,
1
,
10
,
tvm
.
const
(
1
,
"uint1"
))
assert
isinstance
(
x
,
tvm
.
stmt
.
Store
)
assert
x
.
buffer_var
==
buffer_var
assert
x
.
index
.
value
==
10
assert
x
.
value
.
value
==
1
tensor
=
tvm
.
placeholder
((),
dtype
=
"float32"
)
x
=
tvm
.
stmt
.
Provide
(
tensor
.
op
,
0
,
10
,
[])
assert
isinstance
(
x
,
tvm
.
stmt
.
Provide
)
assert
x
.
value_index
==
0
assert
x
.
value
.
value
==
10
x
=
tvm
.
stmt
.
Allocate
(
buffer_var
,
"float32"
,
[
10
],
tvm
.
const
(
1
,
"uint1"
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Allocate
)
assert
x
.
dtype
==
"float32"
assert
x
.
buffer_var
==
buffer_var
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
AttrStmt
(
buffer_var
,
"xyz"
,
1
,
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
AttrStmt
)
assert
x
.
node
==
buffer_var
assert
x
.
attr_key
==
"xyz"
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
Free
(
buffer_var
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Free
)
assert
x
.
buffer_var
==
buffer_var
x
=
tvm
.
stmt
.
Realize
(
None
,
0
,
"float"
,
[],
tvm
.
const
(
1
,
"uint1"
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
Realize
)
assert
x
.
body
==
nop
x
=
tvm
.
stmt
.
IfThenElse
(
tvm
.
const
(
1
,
"uint1"
),
tvm
.
stmt
.
Evaluate
(
11
),
nop
)
assert
isinstance
(
x
,
tvm
.
stmt
.
IfThenElse
)
assert
x
.
then_case
.
value
.
value
==
11
assert
x
.
else_case
==
nop
x
=
tvm
.
stmt
.
Prefetch
(
None
,
1
,
"float32"
,
[])
assert
isinstance
(
x
,
tvm
.
stmt
.
Prefetch
)
assert
x
.
value_index
==
1
if
__name__
==
"__main__"
:
test_expr_constructor
()
test_stmt_constructor
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment