wenyuanbo / tic · Commit a42d1e3c

Authored Jan 04, 2019 by Jian Weng; committed by Tianqi Chen, Jan 04, 2019
[Hybrid Script] Unify the symbol tables to one; support `tvm.container.Array` (#2366)
Parent: 151f550b
Showing 6 changed files with 252 additions and 151 deletions:
  docs/langref/hybrid_script.rst               +32  -9
  python/tvm/hybrid/calls.py                   +20  -6
  python/tvm/hybrid/intrin.py                  +14  -26
  python/tvm/hybrid/parser.py                  +143 -94
  python/tvm/hybrid/util.py                    +2   -1
  tests/python/unittest/test_hybrid_script.py  +41  -15
docs/langref/hybrid_script.rst
@@ -52,7 +52,8 @@ The current parse interface looks like:

     parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function

-If we pass these tvm tensors to this function, it returns a op node:
+If we pass these tvm data structures, like ``Tensor``, ``Var``, ``Expr.*Imm``,
+or ``tvm.container.Array``, to this function, it returns a op node:

 .. code-block:: python

@@ -60,12 +61,14 @@ If we pass these tvm tensors to this function, it returns a op node:

     b = tvm.placeholder((99, ), name='b')
     c = outer_product(a, b, c) # return the output tensor(s) of the operator

-**Under construction, we are still deciding what kind of node should be returned.**
+You can use any methods that can be applied on a TVM ``OpNode``, like create_schedule, although
+so far, the functionality of schedule is as limited as ``ExternOpNode``. At least, it can be built
+to LLVM module.

 Tuning
 ~~~~~~

-**Under construction, not truly supported yet.**
+**Under construction, not supported yet.**

 Follow up the example above, you can use some tvm like interfaces to tune the code:

@@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize``,
 these **4** keywords to annotate the corresponding types of for loops.
 The usage is roughly the same as Python standard ``range``.

+Besides all the loop types supported in Halide, ``const_range`` is supported for some specific conditions.
+Sometimes, ``tvm.container.Array`` is desired to pass as an argument, but in TVM-HalideIR, there is no
+such support that converts ``tvm.container.Array`` to an ``Expr``. Thus, a limited feature is supported.
+Users can access containers by either constants or constants loops annotated.
+
+.. code-block:: python
+
+   @tvm.hybrid.script
+   def foo(a, b): # b is a tvm.container.Array
+       c = output_tensor(a.shape, a.dtype)
+       for i in const_range(len(a)): # because you have b access, i should be explicitly annotated as const_range
+           c[i] = a[i] + b[i]
+       return c
+
 Variables
 ~~~~~~~~~

@@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration.
         s += a[i, j] # do something with sum
     b[i] = sum # you can still use sum in this level
 a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
+b = (1, 2) # this has NOT been supported yet!

 Attributes
 ~~~~~~~~~~

-So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` attribute is essentially a
-tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported.
+So far, ONLY tensors' ``shape`` and ``dtype`` attribute are supported!
+The ``shape`` attribute is essentially a tuple, so you MUST access it as an array.
+Currently, only constant-indexed access is supported.

 .. code-block:: python

@@ -133,8 +151,11 @@ Conditional Statement and Expression

 .. code-block:: python

-    if condition:
+    if condition1 and condition2 and condition3:
         # do something
+    else:
+        # do something else
+    # Select
     a = b if condition else c

 However, NO ``True`` and ``False`` keyword supported yet.

@@ -153,7 +174,9 @@ Array Allocation

 **Under construction, this function will be supported later!**

 Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer.
-The basic usage is roughly the same as a normal array.
+The basic usage is roughly the same as a normal ``numpy.array``, and you should access
+high-dim array in ``a[i, j, k]`` fashion instead of ``a[i][j][k]``,
+even for ``tvm.container.Array`` for compilation.

 Thread Bind

@@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this:

 Keywords
 ~~~~~~~~

-- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
+- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
 - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
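Read together, these documentation changes describe the new calling convention end to end. Below is a minimal sketch of the workflow (hypothetical function name and shapes, assuming a TVM build that includes this commit); the plain Python list is wrapped with tvm.convert to obtain a tvm.container.Array:

    import tvm

    @tvm.hybrid.script
    def add_lut(a, b):  # b is a tvm.container.Array once converted
        c = output_tensor(a.shape, a.dtype)
        for i in const_range(len(b)):  # const_range: every access into b is a compile-time constant
            c[i] = a[i] + b[i]
        return c

    a = tvm.placeholder((5, ), name='a', dtype='int32')
    b = tvm.convert([1, 2, 3, 4, 5])   # list -> tvm.container.Array
    c = add_lut(a, b)                  # returns the output tensor of the op node
    s = tvm.create_schedule(c.op)
    module = tvm.build(s, [a, c])      # lowers to an LLVM module, as the doc promises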
python/tvm/hybrid/calls.py
@@ -15,12 +15,14 @@ LOOP_INTRIN = {
     'range'    : For.Serial,
     'unroll'   : For.Unrolled,
     'parallel' : For.Parallel,
     'vectorize': For.Vectorized,
+    'const_range' : (For.Unrolled, ),
 }


 def _range(annotation, args):
     """Handling TVM loop types"""
-    n = len(args)
+    n = args.__len__()
     if n == 1:
         low, ext = _api.const(0, dtype='int32'), args[0]
     else:

@@ -33,13 +35,13 @@ def _range(annotation, args):
     return iter_var, low, ext, for_type


-range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
+range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name


 def bind(func_id, args):
     """Handling TVM thread binding"""
     _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
-    _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
+    _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
     _internal_assert(isinstance(args[0], str), \
                      "A loop bind's first argument should be a string!")
     iter_var = _api.thread_axis(args[0])

@@ -56,7 +58,7 @@ sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: dis
 def _min_max(func_id, args):
-    _internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
+    _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
     return getattr(_make, func_id.title())(args[0], args[1])

@@ -66,7 +68,7 @@ min = max = _min_max #pylint: disable=invalid-name
 def _allocate_tensor(func_id, args):
     """Handling TVM tensor allocation.
     You may refer hybrid.intrin.allocate for more details."""
-    n = len(args)
+    n = args.__len__()
     _internal_assert(isinstance(_api.convert(args[0]), Array), \
                      "allocate's first argument should be a tuple of shape!")
     shape = args[0]

@@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args):
         scope = 'global' if func_id != 'output_tensor' else 'output'
     return (shape, dtype, scope)

 output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
+
+
+def len(func_id, args):
+    """Interpret the len function"""
+    _internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
+    _internal_assert(func_id == "len", "This function cannot be directly invoked!")
+    try:
+        return _api.convert(args[0].__len__())
+    except: #pylint: disable=bare-except
+        _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
+        return _api.convert(args[0].shape[0])
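The new len handler above is a duck-typing dispatch: it first tries the argument's own __len__ (which tvm.container.Array provides) and only then falls back to the first dimension of a one-dimensional tensor's shape. A standalone sketch of the same fallback pattern (hypothetical stand-in class, not the TVM types):

    def length_of(arg):
        """Prefer __len__; fall back to shape[0], mirroring calls.len."""
        try:
            return len(arg)              # containers such as Array land here
        except TypeError:
            assert len(arg.shape) == 1, "Only one-dimension array can get len"
            return arg.shape[0]          # 1-D tensors land here

    class FakeTensor:                    # hypothetical stand-in for a 1-D tensor
        shape = (8, )

    print(length_of([1, 2, 3]))          # 3, via __len__
    print(length_of(FakeTensor()))       # 8, via shape[0]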
python/tvm/hybrid/intrin.py
@@ -2,32 +2,19 @@
 import numpy

-class _range(object):
-    """Base class of the loop ranges in hybrid script"""
-    def __init__(self, a, b=None):
-        if b is None:
-            self.low = 0
-            self.ext = a
-        else:
-            self.low = a
-            self.ext = b
+class bind(object): #pylint: disable=invalid-name
+    """GPU bind software emulation runtime."""
+    def __init__(self, _, ext):
+        self.ext = ext

     def __iter__(self):
         i = 0
         while i < self.ext:
-            yield i + self.low
+            yield i
             i += 1

-class bind(_range): #pylint: disable=invalid-name
-    def __init__(self, tag, ext):
-        super(bind, self).__init__(ext)
-        self.tag = tag
-
-unroll = vectorize = parallel = _range #pylint: disable=invalid-name

 def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
     """Allocate a buffer with given shape

@@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
     """
     return numpy.zeros(shape).astype(dtype)

-output_tensor = allocate #pylint: disable=invalid-name

 def popcount(x):
     """

@@ -87,17 +73,19 @@ def sigmoid(x):
 HYBRID_GLOBALS = {
-    'unroll'       : unroll,
-    'vectorize'    : vectorize,
-    'parallel'     : parallel,
-    'allocate'     : allocate,
-    'output_tensor': output_tensor,
+    'len'          : len,
+    'unroll'       : range,
+    'vectorize'    : range,
+    'parallel'     : range,
+    'const_range'  : range,
     'bind'         : bind,
+    'allocate'     : allocate,
+    'output_tensor': allocate,
     'sqrt'         : numpy.sqrt,
     'log'          : numpy.log,
     'tanh'         : numpy.tanh,
     'power'        : numpy.power,
     'exp'          : numpy.exp,
     'sigmoid'      : sigmoid,
-    'popcount'     : popcount
+    'popcount'     : popcount,
 }
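HYBRID_GLOBALS is what lets the same hybrid-script source also run as plain Python for software emulation: after this commit every loop annotation (including the new const_range) simply degrades to the builtin range, and output_tensor degrades to a numpy allocation. A rough sketch of the injection idea (hypothetical helper; TVM's actual runtime wraps functions differently):

    import types
    import numpy

    EMU_GLOBALS = {   # in the spirit of HYBRID_GLOBALS
        'unroll': range, 'vectorize': range, 'parallel': range, 'const_range': range,
        'allocate': lambda shape, dtype='float32': numpy.zeros(shape).astype(dtype),
        'output_tensor': lambda shape, dtype='float32': numpy.zeros(shape).astype(dtype),
    }

    def emulate(func, *args):
        """Re-bind func over a globals dict that contains the emulation table."""
        g = dict(func.__globals__, **EMU_GLOBALS)
        return types.FunctionType(func.__code__, g, func.__name__)(*args)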
python/tvm/hybrid/parser.py
@@ -4,7 +4,10 @@ import ast
 import operator
 import logging
 import sys
-from numbers import Integral
+import types
+import numbers
+from enum import Enum

 from .util import _internal_assert
 from . import calls

@@ -12,18 +15,15 @@ from . import util
 from .var_decl import determine_variable_usage
 from ..api import all as _all
 from ..api import any as _any
+from ..container import Array
 from ..tensor import Tensor, Operation
 from .. import expr as _expr
 from .. import make as _make
 from .. import api as _api
 from .. import ir_pass as _ir_pass


-def list_to_block(visit, lst):
-    """Convert a list of Python IR nodes to HalideIR Block"""
-    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
-    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
-    if not lst:
-        return util.make_nop()
+def pack_list_to_block(lst):
     if len(lst) == 1:
         return lst[0]
     body = lst[0]

@@ -32,6 +32,29 @@ def list_to_block(visit, lst):
     return body


+def visit_list_to_block(visit, lst):
+    """Convert a list of Python IR nodes to HalideIR Block"""
+    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
+    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
+    if not lst:
+        return util.make_nop()
+    return pack_list_to_block(lst)
+
+
+class Symbol(Enum):
+    """Enumerates types in the symbol table"""
+    Callable = 0
+    Input = 1
+    OutputBuffer = 2
+    GlobalBuffer = 3
+    LocalBuffer = 4
+    SharedBuffer = 5
+    ConstVar = 6
+    BufferVar = 7
+    LoopVar = 8
+    ConstLoopVar = 9
+
+
 class HybridParser(ast.NodeVisitor):
     """Python AST visitor pass which finally lowers it to HalideIR"""
@@ -82,77 +105,55 @@ class HybridParser(ast.NodeVisitor):
         """
         self.args = list(args)
         self.usage = usage.copy()
-        self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
-        self.alloc_buffers = {} # Buffers formed by explicit allocate instructions
-        self.loops_above = {} # State variable that indicates loop levels above the current node
-        self.variables = {} # The status of defined variables
+        self.symbols = {} # Symbol table
+        for k, v in symbols.items():
+            if isinstance(v, types.FunctionType):
+                self.symbols[k] = Symbol.Callable, v
         self.func_name = func_name # The name of the function to be lowered
         self.outputs = [] # Output tensors' name
         self.side_effect = set() # Tensors with side effects
         self.parsed_body = None # The parsed HalideIR body
         self.returned = False # If this function has a valid return
-        self.symbols = symbols # The global context

     def wrap_up_realize(self, node, body):
         """Wrap up all the variables which will no longer be used"""
-        pop_buf = []
-        pop_var = []
+        to_pop = []
         for key, val in self.usage.items():
             _, level, _ = val
             if level != node:
                 continue
-            if key in self._args.keys():
-                continue
-            if key in self.alloc_buffers.keys():
-                _buf, _scope = self.alloc_buffers[key]
-                if _scope == 'output':
-                    continue
-                pop_buf.append(key)
+            _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
+            ty, entry = self.symbols[key] #pylint: disable=invalid-name
+            if ty in [Symbol.Input, Symbol.OutputBuffer]:
+                continue
+            elif 'Buffer' in ty.name:
+                _buf = entry
+                _scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global'
+                to_pop.append(key)
             else:
-                _internal_assert(key in self.variables.keys(),
-                                 "Key should be either in one of args, buffers, and vars")
-                if not isinstance(self.variables[key], tuple):
-                    continue
-                _buf, _scope = self.variables[key]
-                pop_var.append(key)
+                continue

             _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
             _dtype = _buf.dtype
             _true = _api.convert(True)
             body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
             body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)

-        for elem in pop_buf:
-            self.alloc_buffers.pop(elem)
-        for elem in pop_var:
-            self.variables.pop(elem)
+        for elem in to_pop:
+            self.symbols.pop(elem)

         return body

-    def _get_buffer_from_id(self, s, for_provide=False):
-        _internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1,
-                         "This %s is expected to be in either \
-                          argument list or allocated buffer!" % s)
-        if s in self._args.keys():
-            if for_provide:
-                self.side_effect.add(self._args[s])
-            return self._args[s]
-        return self.alloc_buffers[s][0]
-
-    def _const(self, value, dtype=None):
-        if dtype is None:
-            if isinstance(value, bool):
-                dtype = "bool"
-            elif isinstance(value, Integral):
-                dtype = "int32"
-            else:
-                dtype = "float32"
-        return _api.const(value, dtype)

     #pylint: disable=invalid-name, missing-docstring
     def visit_Module(self, node):
         _internal_assert(len(node.body) == 1, \
-                         "Only one-function source code can be fed to this parser!")
+                         "Only one-function source code will be fed to this parser!")
         return self.visit(node.body[0])
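One subtlety in the new wrap_up_realize: the buffer's storage scope is recovered from the name of the Symbol member itself, by stripping the trailing "Buffer" (6 characters) and lowercasing, with BufferVar special-cased to 'global'. A quick check of that string trick:

    from enum import Enum

    class Symbol(Enum):
        GlobalBuffer = 3
        LocalBuffer = 4
        SharedBuffer = 5
        BufferVar = 7

    for ty in (Symbol.GlobalBuffer, Symbol.LocalBuffer, Symbol.SharedBuffer, Symbol.BufferVar):
        scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global'
        print(ty.name, '->', scope)   # global, local, shared, global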
@@ -164,8 +165,8 @@ class HybridParser(ast.NodeVisitor):
         self.func_name = node.name
         for idx, arg in enumerate(node.args.args):
             _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
-            self._args[getattr(arg, _attr)] = self.args[idx]
-        res = list_to_block(self.visit, node.body)
+            self.symbols[getattr(arg, _attr)] = (Symbol.Input, self.args[idx])
+        res = visit_list_to_block(self.visit, node.body)
         res = self.wrap_up_realize(node, res)
         return res
@@ -176,25 +177,31 @@ class HybridParser(ast.NodeVisitor):
     def visit_Name(self, node):
         name = node.id
-        if name in self.loops_above.keys():
-            return self.loops_above[name]
-        elif name in self.variables.keys():
-            res = self.variables[name]
-            if isinstance(res, tuple):
-                buf = res[0]
-                if isinstance(node.ctx, ast.Load):
-                    return _make.Call(buf.dtype, buf.name, [self._const(0)], \
-                                      _expr.Call.Halide, buf.op, buf.value_index)
-                return buf, [self._const(0)]
-            if isinstance(node.ctx, ast.Load):
-                return res
-            return None
-        buf = self._get_buffer_from_id(name)
-        return buf
+        ty, entry = self.symbols[name]
+        _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
+        if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
+            return entry
+        elif ty is Symbol.ConstVar:
+            return entry if isinstance(node.ctx, ast.Load) else None
+        elif ty is Symbol.BufferVar:
+            if isinstance(node.ctx, ast.Load):
+                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
+                                  _expr.Call.Halide, entry.op, entry.value_index)
+            return entry, [_api.const(0, 'int32')]
+        # Do I need any assertion here?
+        return entry

     def visit_Num(self, node):
-        return self._const(node.n)
+        if isinstance(node.n, numbers.Integral):
+            dtype = "int32"
+        elif isinstance(node.n, float):
+            dtype = "float32"
+        else:
+            _internal_assert(isinstance(node.n, bool),
                             "The data type should be one of (int, float, bool)")
+            dtype = "bool"
+        return _api.const(node.n, dtype)

     def visit_AugAssign(self, node):
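visit_Num now infers the HalideIR immediate type straight from the Python literal instead of routing through the removed _const helper. The inference reduces to this pure-Python restatement of the branch order above (note that bool is a subclass of int in Python, so a literal True would take the Integral branch; in practice True/False are not ast.Num nodes, and the documentation above says they are not supported yet):

    import numbers

    def infer_dtype(n):
        if isinstance(n, numbers.Integral):
            return 'int32'
        elif isinstance(n, float):
            return 'float32'
        assert isinstance(n, bool), "The data type should be one of (int, float, bool)"
        return 'bool'

    print(infer_dtype(7), infer_dtype(7.0))   # int32 float32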
@@ -204,7 +211,7 @@ class HybridParser(ast.NodeVisitor):
             _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
             buf, args = buf
         else:
-            args = [self._const(0)]
+            args = [_api.const(0, 'int32')]
         _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
         read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
@@ -222,7 +229,7 @@ class HybridParser(ast.NodeVisitor):
             for i in range(rhs.num_outputs):
                 _internal_assert(isinstance(node.targets[i], ast.Name),
                                  "You should bind a pure name to the tensors")
-                self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global')
+                self.symbols[node.targets[i].id] = Symbol.GlobalBuffer, rhs.output(i)
                 rmap[rhs.outputs[i].op] = rhs.output(i)
             return util.replace_io(rhs.body, rmap)
@@ -234,25 +241,26 @@ class HybridParser(ast.NodeVisitor):
         #TODO: support defined intermediate buffer later
         lhs_ = lhs
         lhs = lhs.id
-        _internal_assert(lhs not in self.loops_above.keys(), \
-                         "Loop variable cannot be overwritten!")
+        if lhs in self.symbols.keys():
+            ty, _ = self.symbols[lhs]
+            _internal_assert(ty != Symbol.LoopVar, \
+                             "Loop variable cannot be overwritten!")
         decl, _, rw = self.usage[lhs]
         if decl == lhs_:
-            _internal_assert(lhs not in self.variables.keys() and
-                             lhs not in self.alloc_buffers.keys(), \
+            _internal_assert(lhs not in self.symbols.keys(),
                              "This value should not be defined before this point!")
             if isinstance(rhs, tuple):
                 shape, dtype, scope = rhs
                 ph = _api.placeholder(shape, dtype=dtype, name=lhs)
-                self.alloc_buffers[lhs] = (ph, scope)
+                self.symbols[lhs] = getattr(Symbol, scope.title() + "Buffer"), ph
                 if scope == 'output':
                     self.outputs.append(lhs)
                 return util.make_nop()
             if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
-                self.variables[lhs] = rhs
+                self.symbols[lhs] = Symbol.ConstVar, rhs
             else:
                 ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
-                self.variables[lhs] = (ph, 'global')
+                self.symbols[lhs] = Symbol.BufferVar, ph
         lhs = self.visit(lhs_)
         if lhs is not None:
             buf, args = lhs
@@ -275,17 +283,30 @@ class HybridParser(ast.NodeVisitor):
     def visit_Attribute(self, node):
         _internal_assert(isinstance(node.value, ast.Name), \
                          "For atrribute access, only both names are supported so far!")
-        buf = self._get_buffer_from_id(node.value.id)
+        buf = self.visit(node.value)
         return getattr(buf, node.attr)

     def visit_Subscript(self, node):
         args = self.visit(node.slice)
         if isinstance(node.value, ast.Name):
             buf = self.visit(node.value)
+            if isinstance(buf, Array):
+                for i in args:
+                    if isinstance(i, numbers.Integral):
+                        buf = buf[i]
+                    else:
+                        _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
+                                         "All indices are supposed to be constants")
+                        buf = buf[i.value]
+                return buf
             if isinstance(node.ctx, ast.Load):
                 return _make.Call(buf.dtype, buf.name, args, \
                                   _expr.Call.Halide, buf.op, buf.value_index)
             return buf, args
         shape = self.visit(node.value)
@@ -308,14 +329,14 @@ class HybridParser(ast.NodeVisitor):
         _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
         _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
         self.annotation[option.id] = context.func.id
-        return list_to_block(self.visit, node.body)
+        return visit_list_to_block(self.visit, node.body)

     def visit_If(self, node):
         cond = self.visit(node.test)
-        if_body = list_to_block(self.visit, node.body)
+        if_body = visit_list_to_block(self.visit, node.body)
         if node.orelse:
-            else_body = list_to_block(self.visit, node.orelse)
+            else_body = visit_list_to_block(self.visit, node.orelse)
         else:
             else_body = util.make_nop()
         return _make.IfThenElse(cond, if_body, else_body)
@@ -376,7 +397,10 @@ class HybridParser(ast.NodeVisitor):
         except AttributeError:
             _internal_assert(func_id in self.symbols.keys(), \
                              "The function called is not in the context either!")
-            outs = self.symbols[func_id](*args)
+            ty, entry = self.symbols[func_id]
+            _internal_assert(ty is Symbol.Callable, \
+                             "Are you sure what you call is a function?!")
+            outs = entry(*args)
         op = outs.op if isinstance(outs, Tensor) else outs[0].op
         return op
@@ -385,41 +409,66 @@ class HybridParser(ast.NodeVisitor):
         iter_var, low, ext, for_type = self.visit(node.iter)
         _internal_assert(isinstance(node.target, ast.Name), \
                          "The loop iterator should be a variable!")
         _name = node.target.id
-        if iter_var is None:
+        if isinstance(for_type, tuple):
+            low = _ir_pass.Simplify(low)
+            ext = _ir_pass.Simplify(ext)
+            _internal_assert(isinstance(low, _expr.ConstExpr) and
+                             isinstance(ext, _expr.ConstExpr), \
+                             "Const range should start from a const" + \
+                             "and iterate const times")
+            low, ext = low.value, ext.value
+            if ext > 114514:
+                logging.log(logging.CRITICAL, \
+                            '[Warning] Are you sure to unroll a large loop in Python?')
+            bodies = []
+            for i in range(low, low + ext):
+                self.symbols[_name] = Symbol.ConstLoopVar, i
+                bodies.append(visit_list_to_block(self.visit, node.body))
+            return pack_list_to_block(bodies)
+        elif iter_var is None:
             _internal_assert(for_type is not None, "The loop bind function parse error!")
             offset = iter_var = _api.var(_name)
-            if not _ir_pass.Equal(low, self._const(0)):
+            if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                 offset = iter_var + low
-            self.loops_above[_name] = offset
+            self.symbols[_name] = Symbol.LoopVar, offset
+            _body = visit_list_to_block(self.visit, node.body)
         else:
             _internal_assert(for_type is None, "The loop iterating function parse error!")
-            self.loops_above[_name] = iter_var.var
-        _body = list_to_block(self.visit, node.body)
+            self.symbols[_name] = Symbol.LoopVar, iter_var.var
+            _body = visit_list_to_block(self.visit, node.body)
         _body = self.wrap_up_realize(node, _body)
         if for_type is None:
             res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
-        else:
-            res = _make.For(iter_var, self._const(0), ext, for_type, 0, _body)
+        elif not isinstance(for_type, tuple):
+            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
-        self.loops_above.pop(_name)
+        self.symbols.pop(_name)
         return res
     def visit_Return(self, node):
-        _internal_assert(not self.loops_above, "Return should not be in a loop body!")
+        _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()), \
+                         "Return should not be in a loop body!")
         ids = []
         if isinstance(node.value, ast.Name):
-            ids.append(node.value.id)
+            ids = [node.value.id]
         else:
             _internal_assert(isinstance(node.value, ast.Tuple), \
                              "You should return either a single tensor or a tuple")
-            for i in node.value.elts:
-                _internal_assert(isinstance(i, ast.Name), "What do you return?")
-                ids.append(i.id)
+            _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \
+                             "What do you return?")
+            ids = [i.id for i in node.value.elts]
         _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
         if len(ids) < len(self.outputs):
             logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
-        self.outputs = [self.alloc_buffers[i][0] for i in ids]
+        self.outputs = [self.symbols[i][1] for i in ids]
         self.returned = True
         return util.make_nop()
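The const_range branch added to visit_For above is worth restating: instead of emitting a HalideIR For node, the parser re-visits the loop body once for every constant index, binding the iterator as a ConstLoopVar each time, and packs the resulting statements into a single block, i.e. compile-time unrolling. A toy model of that behaviour (strings stand in for IR statements):

    def unroll_const_loop(low, ext, visit_body):
        """Visit the body once per constant index, then pack the pieces."""
        bodies = [visit_body(i) for i in range(low, low + ext)]
        return '; '.join(bodies)          # stands in for pack_list_to_block

    print(unroll_const_loop(0, 3, lambda i: 'c[%d] = a[%d] + b[%d]' % (i, i, i)))
    # c[0] = a[0] + b[0]; c[1] = a[1] + b[1]; c[2] = a[2] + b[2]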
python/tvm/hybrid/util.py
@@ -11,12 +11,13 @@ from .. import api as _api
 from .. import make as _make
 from .. import expr as _expr
 from .. import stmt as _stmt
+from ..container import Array
 from ..tensor import Tensor

 #pylint: disable=invalid-name
 np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
-tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
+tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
 halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)

 def _internal_assert(cond, err):
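Adding Array to tvm_arg_types is the one-line change that lets a converted Python list pass the hybrid front end's argument check. A hedged sketch of what that check amounts to (assuming the module paths of this era of TVM):

    import tvm

    def is_tvm_arg(x):  # hypothetical helper mirroring the tvm_arg_types tuple
        return isinstance(x, (tvm.tensor.Tensor, tvm.container.Array,
                              tvm.expr.Var, tvm.expr.ConstExpr))

    print(is_tvm_arg(tvm.convert([1, 2, 3])))   # True after this commit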
tests/python/unittest/test_hybrid_script.py
@@ -13,7 +13,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
     ctx = tvm.context(target, 0)
     op = None
-    outs = func(*args)
+    outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
     op = outs[0].op if isinstance(outs, list) else outs.op
     emu_args = []

@@ -23,13 +23,18 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
             shape = [tvm_val_2_py_val(j) for j in i.shape]
             emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
             nd_args.append(tvm.nd.array(emu_args[-1], ctx))
-        else:
-            assert isinstance(i, tvm.expr.Var)
+        elif isinstance(i, tvm.expr.Var):
             emu_args.append(tvm_val_2_py_val(i))
             nd_args.append(emu_args[-1])
+        else:
+            assert isinstance(i, list)
+            emu_args.append(numpy.array(i))

     sch = tvm.create_schedule(op)
-    module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
+    module = tvm.build(sch,
+                       [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
+                       (outs if isinstance(outs, list) else [outs]),
+                       target=target)
     assert module
     out_tensors = []

@@ -192,20 +197,20 @@ def test_fanout():
 def test_looptype():
     @script
     def looptype(a, b, c):
-        d = output_tensor((8, ), 'int32')
-        e = output_tensor((8, ), 'int32')
-        f = output_tensor((8, ), 'int32')
-        for i in parallel(8):
+        d = output_tensor((16, ), 'int32')
+        e = output_tensor((16, ), 'int32')
+        f = output_tensor((16, ), 'int32')
+        for i in parallel(16):
             d[i] = a[i]
-        for j in vectorize(8):
+        for j in vectorize(16):
             e[j] = b[j]
-        for k in unroll(8):
+        for k in unroll(16):
             f[k] = c[k]
         return d, e, f

-    a = tvm.placeholder((8, ), name='a', dtype='int32')
-    b = tvm.placeholder((8, ), name='b', dtype='int32')
-    c = tvm.placeholder((8, ), name='c', dtype='int32')
+    a = tvm.placeholder((16, ), name='a', dtype='int32')
+    b = tvm.placeholder((16, ), name='b', dtype='int32')
+    c = tvm.placeholder((16, ), name='c', dtype='int32')
     try:
         d, e, f = looptype(a, b, c)
         ir = d.op.body

@@ -509,9 +514,9 @@ def test_value_index():
 def test_func_call():
     @tvm.hybrid.script
     def foo(a, b):
-        for i in range(10):
+        for i in range(len(a)):
             a[i] = i + 1.0
-        for i in range(10):
+        for i in range(len(a)):
             b[i] = i + 1.0
         c = outer_product(10, 10, a, b)
         d = output_tensor(c.shape, c.dtype)

@@ -538,6 +543,26 @@ def test_bool():
     a = tvm.placeholder((10, ), name='a')
     run_and_check(foo, [a])

+def test_const_range():
+    @tvm.hybrid.script
+    def foo(a, b):
+        c = output_tensor(a.shape, a.dtype)
+        d = output_tensor(a.shape, a.dtype)
+        for i in const_range(2):
+            for j in const_range(5):
+                c[i, j] = a[i, j] + b[i, j]
+        for i in const_range(len(b)):
+            for j in const_range(len(b[0])):
+                d[i, j] = a[i, j] + b[i, j]
+        return c, d
+
+    a = tvm.placeholder((2, 5), name='a', dtype='int32')
+    b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
+    run_and_check(foo, [a, b])
+

 if __name__ == "__main__":
     test_outer_product()
     test_fanout()

@@ -553,5 +578,6 @@ if __name__ == "__main__":
     test_value_index()
     test_func_call()
     test_bool()
+    test_const_range()
     # TODO:
     # test_inplace()
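The new test_const_range is wired into the module's __main__ block, so the whole suite can be exercised directly (standard usage for TVM's Python unittests, assuming a local build):

    python tests/python/unittest/test_hybrid_script.py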