Commit a42d1e3c in wenyuanbo/tic
Authored Jan 04, 2019 by Jian Weng; committed by Tianqi Chen, Jan 04, 2019
[Hybrid Script] Unify the symbol tables to one; support `tvm.container.Array` (#2366)
parent 151f550b

Showing 6 changed files with 258 additions and 157 deletions.
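In short, after this commit a hybrid script function can take a ``tvm.container.Array`` (for example, a Python list passed through ``tvm.convert``) as an argument, provided the container is only accessed at constant indices or inside ``const_range`` loops. A minimal sketch of the new usage; the function name and shapes below are illustrative, not part of the diff:

    import tvm

    @tvm.hybrid.script
    def scale(a, w):                      # w is a tvm.container.Array
        c = output_tensor(a.shape, a.dtype)
        for i in const_range(len(w)):     # container access requires const_range
            c[i] = a[i] * w[i]
        return c

    a = tvm.placeholder((4, ), name='a', dtype='int32')
    w = tvm.convert([1, 2, 3, 4])         # plain list -> tvm.container.Array
    c = scale(a, w)                       # returns the output tensor of the op node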
docs/langref/hybrid_script.rst                 +33  -10
python/tvm/hybrid/calls.py                     +23   -9
python/tvm/hybrid/intrin.py                    +14  -26
python/tvm/hybrid/parser.py                   +145  -96
python/tvm/hybrid/util.py                       +2   -1
tests/python/unittest/test_hybrid_script.py    +41  -15
docs/langref/hybrid_script.rst

@@ -52,7 +52,8 @@ The current parse interface looks like:
     parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function

-If we pass these tvm tensors to this function, it returns an op node:
+If we pass these tvm data structures, like ``Tensor``, ``Var``, ``Expr.*Imm``,
+or ``tvm.container.Array``, to this function, it returns an op node:

 .. code-block:: python

@@ -60,12 +61,14 @@ If we pass these tvm tensors to this function, it returns an op node:
    b = tvm.placeholder((99, ), name='b')
    c = outer_product(a, b, c) # return the output tensor(s) of the operator

-**Under construction, we are still deciding what kind of node should be returned.**
+You can use any methods that can be applied to a TVM ``OpNode``, like ``create_schedule``,
+although so far the schedule functionality is as limited as that of ``ExternOpNode``.
+At least, it can be built to an LLVM module.

 Tuning
 ~~~~~~

-**Under construction, not truly supported yet.**
+**Under construction, not supported yet.**

 Following the example above, you can use some TVM-like interfaces to tune the code:

@@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize``,
 these **4** keywords to annotate the corresponding types of for loops.
 The usage is roughly the same as the Python standard ``range``.

+Besides all the loop types supported in Halide, ``const_range`` is supported under some
+specific conditions. Sometimes it is desirable to pass a ``tvm.container.Array`` as an
+argument, but TVM-HalideIR has no support for converting a ``tvm.container.Array`` to an
+``Expr``. Thus, a limited feature is supported: users can access containers only with
+constant indices, or inside loops annotated as ``const_range``.
+
+.. code-block:: python
+
+   @tvm.hybrid.script
+   def foo(a, b): # b is a tvm.container.Array
+       c = output_tensor(a.shape, a.dtype)
+       for i in const_range(len(a)): # because b is accessed, i should be explicitly annotated as const_range
+           c[i] = a[i] + b[i]
+       return c
+
 Variables
 ~~~~~~~~~

@@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration.
         s += a[i, j] # do something with sum
     b[i] = sum # you can still use sum in this level
 a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
+b = (1, 2) # this has NOT been supported yet!

 Attributes
 ~~~~~~~~~~

-So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` attribute is
-essentially a tuple, so you MUST access it as an array. Also, currently, only
-constant-indexed access is supported.
+So far, ONLY tensors' ``shape`` and ``dtype`` attributes are supported!
+The ``shape`` attribute is essentially a tuple, so you MUST access it as an array.
+Currently, only constant-indexed access is supported.

 .. code-block:: python

@@ -133,8 +151,11 @@ Conditional Statement and Expression
 .. code-block:: python

-   if condition:
-       # do something
+   if condition1 and condition2 and condition3:
+       # do something
+   else:
+       # do something else

    # Select
    a = b if condition else c

 However, NO ``True`` and ``False`` keyword supported yet.

@@ -153,7 +174,9 @@ Array Allocation
 **Under construction, this function will be supported later!**

 Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer.
-The basic usage is roughly the same as a normal array.
+The basic usage is roughly the same as a normal ``numpy.array``, but you should access
+a high-dimensional array in ``a[i, j, k]`` fashion instead of ``a[i][j][k]``,
+even for a ``tvm.container.Array``, for compilation.

 Thread Bind

@@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this:
 Keywords
 ~~~~~~~~
-- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
+- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
 - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
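Tying the documented pieces together, here is a sketch of my own (not from the commit) that combines the newly supported ``dtype`` attribute with constant-indexed ``shape`` access and a container argument. Names and shapes are made up, and the use of ``a.shape[0]`` as a ``const_range`` bound is an assumption that follows from the constant-indexing rule above:

    import tvm

    @tvm.hybrid.script
    def add_bias(a, bias):                     # bias is a tvm.container.Array
        out = output_tensor(a.shape, a.dtype)  # shape and dtype attributes are supported
        for i in const_range(a.shape[0]):      # shape is accessed with a constant index
            out[i] = a[i] + bias[i]
        return out

    a = tvm.placeholder((3, ), name='a', dtype='float32')
    bias = tvm.convert([0.5, 1.5, 2.5])
    out = add_bias(a, bias)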
python/tvm/hybrid/calls.py

@@ -12,15 +12,17 @@ from .util import _internal_assert
 #pylint: disable=redefined-builtin

 LOOP_INTRIN = {
-    'range'    : For.Serial,
-    'unroll'   : For.Unrolled,
-    'parallel' : For.Parallel,
-    'vectorize': For.Vectorized,
+    'range'       : For.Serial,
+    'unroll'      : For.Unrolled,
+    'parallel'    : For.Parallel,
+    'vectorize'   : For.Vectorized,
+    'const_range' : (For.Unrolled, ),
 }

 def _range(annotation, args):
     """Handling TVM loop types"""
-    n = len(args)
+    n = args.__len__()
     if n == 1:
         low, ext = _api.const(0, dtype='int32'), args[0]
     else:

@@ -33,13 +35,13 @@ def _range(annotation, args):
     return iter_var, low, ext, for_type

-range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
+range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name

 def bind(func_id, args):
     """Handling TVM thread binding"""
     _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
-    _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
+    _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
     _internal_assert(isinstance(args[0], str), \
                      "A loop bind's first argument should be a string!")
     iter_var = _api.thread_axis(args[0])

@@ -56,7 +58,7 @@ sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: dis
 def _min_max(func_id, args):
-    _internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
+    _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
     return getattr(_make, func_id.title())(args[0], args[1])

@@ -66,7 +68,7 @@ min = max = _min_max #pylint: disable=invalid-name
 def _allocate_tensor(func_id, args):
     """Handling TVM tensor allocation.
     You may refer hybrid.intrin.allocate for more details."""
-    n = len(args)
+    n = args.__len__()
     _internal_assert(isinstance(_api.convert(args[0]), Array), \
                      "allocate's first argument should be a tuple of shape!")
     shape = args[0]

@@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args):
     scope = 'global' if func_id != 'output_tensor' else 'output'
     return (shape, dtype, scope)

 output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
+
+def len(func_id, args):
+    """Interpret the len function"""
+    _internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
+    _internal_assert(func_id == "len", "This function cannot be directly invoked!")
+    try:
+        return _api.convert(args[0].__len__())
+    except: #pylint: disable=bare-except
+        _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
+        return _api.convert(args[0].shape[0])
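Note why the module's own length checks switched from ``len(args)`` to ``args.__len__()``: this file now defines a ``len`` of its own (the parse-time handler above), so the builtin is shadowed. At parse time the handler first tries the argument's ``__len__`` (which covers ``tvm.container.Array``) and falls back to ``shape[0]`` for one-dimensional tensors. A small usage sketch in the spirit of the updated ``test_func_call``; the function name ``copy`` is hypothetical:

    import tvm

    @tvm.hybrid.script
    def copy(a):
        b = output_tensor(a.shape, a.dtype)
        for i in range(len(a)):   # len(a) lowers to a.shape[0] for a 1-D tensor
            b[i] = a[i]
        return b

    a = tvm.placeholder((8, ), name='a', dtype='float32')
    b = copy(a)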
python/tvm/hybrid/intrin.py

@@ -2,32 +2,19 @@
 import numpy

-class _range(object):
-    """Base class of the loop ranges in hybrid script"""
-    def __init__(self, a, b=None):
-        if b is None:
-            self.low = 0
-            self.ext = a
-        else:
-            self.low = a
-            self.ext = b
+class bind(object): #pylint: disable=invalid-name
+    """GPU bind software emulation runtime."""
+    def __init__(self, _, ext):
+        self.ext = ext

     def __iter__(self):
         i = 0
         while i < self.ext:
-            yield i + self.low
+            yield i
             i += 1

-class bind(_range): #pylint: disable=invalid-name
-    def __init__(self, tag, ext):
-        super(bind, self).__init__(ext)
-        self.tag = tag
-
-unroll = vectorize = parallel = _range #pylint: disable=invalid-name
-
 def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
     """Allocate a buffer with given shape

@@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
     """
     return numpy.zeros(shape).astype(dtype)

-output_tensor = allocate #pylint: disable=invalid-name

 def popcount(x):
     """

@@ -87,17 +73,19 @@ def sigmoid(x):
 HYBRID_GLOBALS = {
-    'unroll'       : unroll,
-    'vectorize'    : vectorize,
-    'parallel'     : parallel,
-    'allocate'     : allocate,
-    'output_tensor': output_tensor,
-    'len'          : len,
+    'unroll'       : range,
+    'vectorize'    : range,
+    'parallel'     : range,
+    'const_range'  : range,
     'bind'         : bind,
+    'allocate'     : allocate,
+    'output_tensor': allocate,
     'sqrt'         : numpy.sqrt,
     'log'          : numpy.log,
     'tanh'         : numpy.tanh,
     'power'        : numpy.power,
     'exp'          : numpy.exp,
     'sigmoid'      : sigmoid,
-    'popcount'     : popcount
+    'popcount'     : popcount,
 }
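These globals are the software-emulation half of hybrid script: before compilation the same function body can run as plain Python, with loop annotations standing in for ``range`` and ``allocate``/``output_tensor`` standing in for ``numpy.zeros``. A self-contained sketch of that idea; the stand-ins below are hand-written to mirror HYBRID_GLOBALS, not imported from TVM:

    import numpy

    def output_tensor(shape, dtype='float32'):   # mirrors 'output_tensor': allocate
        return numpy.zeros(shape).astype(dtype)

    def vectorize(n):                            # mirrors 'vectorize': range
        return range(n)

    def saxpy(a, b):
        c = output_tensor(a.shape, a.dtype)
        for i in vectorize(a.shape[0]):
            c[i] = 2 * a[i] + b[i]
        return c

    a = numpy.ones(4, dtype='float32')
    b = numpy.ones(4, dtype='float32')
    print(saxpy(a, b))                           # runs as pure Python, no compilation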
python/tvm/hybrid/parser.py

@@ -4,7 +4,10 @@ import ast
 import operator
 import logging
 import sys
-from numbers import Integral
+import types
+import numbers
+from enum import Enum

 from .util import _internal_assert
 from . import calls

@@ -12,18 +15,15 @@ from . import util
 from .var_decl import determine_variable_usage
 from ..api import all as _all
 from ..api import any as _any
+from ..container import Array
 from ..tensor import Tensor, Operation
 from .. import expr as _expr
 from .. import make as _make
 from .. import api as _api
 from .. import ir_pass as _ir_pass

-def list_to_block(visit, lst):
-    """Convert a list of Python IR nodes to HalideIR Block"""
-    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
-    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
-    if not lst:
-        return util.make_nop()
+def pack_list_to_block(lst):
     if len(lst) == 1:
         return lst[0]
     body = lst[0]
@@ -32,6 +32,29 @@ def list_to_block(visit, lst):
     return body

+def visit_list_to_block(visit, lst):
+    """Convert a list of Python IR nodes to HalideIR Block"""
+    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
+    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
+    if not lst:
+        return util.make_nop()
+    return pack_list_to_block(lst)
+
+
+class Symbol(Enum):
+    """Enumerates types in the symbol table"""
+    Callable = 0
+    Input = 1
+    OutputBuffer = 2
+    GlobalBuffer = 3
+    LocalBuffer = 4
+    SharedBuffer = 5
+    ConstVar = 6
+    BufferVar = 7
+    LoopVar = 8
+    ConstLoopVar = 9
+
+
 class HybridParser(ast.NodeVisitor):
     """Python AST visitor pass which finally lowers it to HalideIR"""
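This is the commit's central refactor: the parser previously kept four parallel lookup dicts (``_args``, ``alloc_buffers``, ``loops_above``, ``variables``) and now keeps a single table mapping each name to a ``(Symbol, entry)`` pair. A toy illustration of the convention, not the parser code itself; the entries are placeholder strings:

    from enum import Enum

    class Symbol(Enum):
        """Abbreviated copy of the parser's symbol kinds"""
        Input = 1
        GlobalBuffer = 3
        ConstVar = 6
        LoopVar = 8

    # One table: name -> (kind, entry), instead of four parallel dicts
    symbols = {
        'a'  : (Symbol.Input, '<placeholder a>'),   # function argument
        'tmp': (Symbol.GlobalBuffer, '<buffer>'),   # allocated buffer
        'n'  : (Symbol.ConstVar, 10),               # immutable scalar
        'i'  : (Symbol.LoopVar, '<loop var i>'),    # live loop variable
    }

    # Every lookup dispatches on the kind, as visit_Name does below
    ty, entry = symbols['n']
    if ty is Symbol.ConstVar:
        print('constant value:', entry)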
@@ -82,77 +105,55 @@ class HybridParser(ast.NodeVisitor):
         """
         self.args = list(args)
         self.usage = usage.copy()
-        self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
-        self.alloc_buffers = {} # Buffers formed by explicit allocate instructions
-        self.loops_above = {} # State variable that indicates loop levels above the current node
-        self.variables = {} # The status of defined variables
+        self.symbols = {} # Symbol table
+        for k, v in symbols.items():
+            if isinstance(v, types.FunctionType):
+                self.symbols[k] = Symbol.Callable, v
         self.func_name = func_name # The name of the function to be lowered
         self.outputs = [] # Output tensors' name
         self.side_effect = set() # Tensors with side effects
         self.parsed_body = None # The parsed HalideIR body
         self.returned = False # If this function has a valid return
-        self.symbols = symbols # The global context

     def wrap_up_realize(self, node, body):
         """Wrap up all the variables which will no longer be used"""
-        pop_buf = []
-        pop_var = []
+        to_pop = []
         for key, val in self.usage.items():
             _, level, _ = val
             if level != node:
                 continue
-            if key in self._args.keys():
-                continue
-            if key in self.alloc_buffers.keys():
-                _buf, _scope = self.alloc_buffers[key]
-                if _scope == 'output':
-                    continue
-                pop_buf.append(key)
+            _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
+            ty, entry = self.symbols[key] #pylint: disable=invalid-name
+            if ty in [Symbol.Input, Symbol.OutputBuffer]:
+                continue
+            elif 'Buffer' in ty.name:
+                _buf = entry
+                _scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global'
+                to_pop.append(key)
             else:
-                _internal_assert(key in self.variables.keys(),
-                                 "Key should be either in one of args, buffers, and vars")
-                if not isinstance(self.variables[key], tuple):
-                    continue
-                _buf, _scope = self.variables[key]
-                pop_var.append(key)
+                continue

             _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
             _dtype = _buf.dtype
             _true = _api.convert(True)
             body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
             body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)

-        for elem in pop_buf:
-            self.alloc_buffers.pop(elem)
-        for elem in pop_var:
-            self.variables.pop(elem)
-        return body
-
-    def _get_buffer_from_id(self, s, for_provide=False):
-        _internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1,
-                         "This %s is expected to be in either \
-                         argument list or allocated buffer!" % s)
-        if s in self._args.keys():
-            if for_provide:
-                self.side_effect.add(self._args[s])
-            return self._args[s]
-        return self.alloc_buffers[s][0]
-
-    def _const(self, value, dtype=None):
-        if dtype is None:
-            if isinstance(value, bool):
-                dtype = "bool"
-            elif isinstance(value, Integral):
-                dtype = "int32"
-            else:
-                dtype = "float32"
-        return _api.const(value, dtype)
+        for elem in to_pop:
+            self.symbols.pop(elem)
+        return body

     #pylint: disable=invalid-name, missing-docstring
     def visit_Module(self, node):
         _internal_assert(len(node.body) == 1, \
-                         "Only one-function source code can be fed to this parser!")
+                         "Only one-function source code will be fed to this parser!")
         return self.visit(node.body[0])
@@ -164,8 +165,8 @@ class HybridParser(ast.NodeVisitor):
         self.func_name = node.name
         for idx, arg in enumerate(node.args.args):
             _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
-            self._args[getattr(arg, _attr)] = self.args[idx]
-        res = list_to_block(self.visit, node.body)
+            self.symbols[getattr(arg, _attr)] = (Symbol.Input, self.args[idx])
+        res = visit_list_to_block(self.visit, node.body)
         res = self.wrap_up_realize(node, res)
         return res

@@ -176,25 +177,31 @@ class HybridParser(ast.NodeVisitor):
     def visit_Name(self, node):
         name = node.id
-        if name in self.loops_above.keys():
-            return self.loops_above[name]
-        elif name in self.variables.keys():
-            res = self.variables[name]
-            if isinstance(res, tuple):
-                buf = res[0]
-                if isinstance(node.ctx, ast.Load):
-                    return _make.Call(buf.dtype, buf.name, [self._const(0)], \
-                                      _expr.Call.Halide, buf.op, buf.value_index)
-                return buf, [self._const(0)]
-            if isinstance(node.ctx, ast.Load):
-                return res
-            return None
-        buf = self._get_buffer_from_id(name)
-        return buf
+        ty, entry = self.symbols[name]
+        _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
+        if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
+            return entry
+        elif ty is Symbol.ConstVar:
+            return entry if isinstance(node.ctx, ast.Load) else None
+        elif ty is Symbol.BufferVar:
+            if isinstance(node.ctx, ast.Load):
+                return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
+                                  _expr.Call.Halide, entry.op, entry.value_index)
+            return entry, [_api.const(0, 'int32')]
+        # Do I need any assertion here?
+        return entry

     def visit_Num(self, node):
-        return self._const(node.n)
+        if isinstance(node.n, numbers.Integral):
+            dtype = "int32"
+        elif isinstance(node.n, float):
+            dtype = "float32"
+        else:
+            _internal_assert(isinstance(node.n, bool),
+                             "The data type should be one of (int, float, bool)")
+            dtype = "bool"
+        return _api.const(node.n, dtype)

     def visit_AugAssign(self, node):
@@ -204,7 +211,7 @@ class HybridParser(ast.NodeVisitor):
             _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
             buf, args = buf
         else:
-            args = [self._const(0)]
+            args = [_api.const(0, 'int32')]
         _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
         read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)

@@ -222,7 +229,7 @@ class HybridParser(ast.NodeVisitor):
             for i in range(rhs.num_outputs):
                 _internal_assert(isinstance(node.targets[i], ast.Name),
                                  "You should bind a pure name to the tensors")
-                self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global')
+                self.symbols[node.targets[i].id] = Symbol.GlobalBuffer, rhs.output(i)
                 rmap[rhs.outputs[i].op] = rhs.output(i)
             return util.replace_io(rhs.body, rmap)
@@ -234,25 +241,26 @@ class HybridParser(ast.NodeVisitor):
             #TODO: support defined intermediate buffer later
             lhs_ = lhs
             lhs = lhs.id
-            _internal_assert(lhs not in self.loops_above.keys(), \
-                             "Loop variable cannot be overwritten!")
+            if lhs in self.symbols.keys():
+                ty, _ = self.symbols[lhs]
+                _internal_assert(ty != Symbol.LoopVar, \
+                                 "Loop variable cannot be overwritten!")
             decl, _, rw = self.usage[lhs]
             if decl == lhs_:
-                _internal_assert(lhs not in self.variables.keys() and
-                                 lhs not in self.alloc_buffers.keys(), \
+                _internal_assert(lhs not in self.symbols.keys(), \
                                  "This value should not be defined before this point!")
                 if isinstance(rhs, tuple):
                     shape, dtype, scope = rhs
                     ph = _api.placeholder(shape, dtype=dtype, name=lhs)
-                    self.alloc_buffers[lhs] = (ph, scope)
+                    self.symbols[lhs] = getattr(Symbol, scope.title() + "Buffer"), ph
                     if scope == 'output':
                         self.outputs.append(lhs)
                     return util.make_nop()
                 if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
-                    self.variables[lhs] = rhs
+                    self.symbols[lhs] = Symbol.ConstVar, rhs
                 else:
                     ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
-                    self.variables[lhs] = (ph, 'global')
+                    self.symbols[lhs] = Symbol.BufferVar, ph
         lhs = self.visit(lhs_)
         if lhs is not None:
             buf, args = lhs
@@ -275,17 +283,30 @@ class HybridParser(ast.NodeVisitor):
     def visit_Attribute(self, node):
         _internal_assert(isinstance(node.value, ast.Name), \
                          "For attribute access, only names are supported so far!")
-        buf = self._get_buffer_from_id(node.value.id)
+        buf = self.visit(node.value)
         return getattr(buf, node.attr)

     def visit_Subscript(self, node):
         args = self.visit(node.slice)
         if isinstance(node.value, ast.Name):
             buf = self.visit(node.value)
+            if isinstance(buf, Array):
+                for i in args:
+                    if isinstance(i, numbers.Integral):
+                        buf = buf[i]
+                    else:
+                        _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
+                                         "All indices are supposed to be constants")
+                        buf = buf[i.value]
+                return buf
             if isinstance(node.ctx, ast.Load):
                 return _make.Call(buf.dtype, buf.name, args, \
                                   _expr.Call.Halide, buf.op, buf.value_index)
             return buf, args

         shape = self.visit(node.value)
@@ -308,14 +329,14 @@ class HybridParser(ast.NodeVisitor):
         _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
         _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
         self.annotation[option.id] = context.func.id
-        return list_to_block(self.visit, node.body)
+        return visit_list_to_block(self.visit, node.body)

     def visit_If(self, node):
         cond = self.visit(node.test)
-        if_body = list_to_block(self.visit, node.body)
+        if_body = visit_list_to_block(self.visit, node.body)
         if node.orelse:
-            else_body = list_to_block(self.visit, node.orelse)
+            else_body = visit_list_to_block(self.visit, node.orelse)
         else:
             else_body = util.make_nop()
         return _make.IfThenElse(cond, if_body, else_body)
@@ -376,7 +397,10 @@ class HybridParser(ast.NodeVisitor):
         except AttributeError:
             _internal_assert(func_id in self.symbols.keys(), \
                              "The function called is not in the context either!")
-            outs = self.symbols[func_id](*args)
+            ty, entry = self.symbols[func_id]
+            _internal_assert(ty is Symbol.Callable, \
+                             "Are you sure what you call is a function?!")
+            outs = entry(*args)
         op = outs.op if isinstance(outs, Tensor) else outs[0].op
         return op
@@ -385,41 +409,66 @@ class HybridParser(ast.NodeVisitor):
         iter_var, low, ext, for_type = self.visit(node.iter)
         _internal_assert(isinstance(node.target, ast.Name), \
                          "The loop iterator should be a variable!")
         _name = node.target.id
-        if iter_var is None:
+        if isinstance(for_type, tuple):
+            low = _ir_pass.Simplify(low)
+            ext = _ir_pass.Simplify(ext)
+            _internal_assert(isinstance(low, _expr.ConstExpr) and
+                             isinstance(ext, _expr.ConstExpr), \
+                             "Const range should start from a const " + \
+                             "and iterate const times")
+            low, ext = low.value, ext.value
+            if ext > 114514:
+                logging.log(logging.CRITICAL, \
+                            '[Warning] Are you sure to unroll a large loop in Python?')
+            bodies = []
+            for i in range(low, low + ext):
+                self.symbols[_name] = Symbol.ConstLoopVar, i
+                bodies.append(visit_list_to_block(self.visit, node.body))
+            return pack_list_to_block(bodies)
+        elif iter_var is None:
             _internal_assert(for_type is not None, "The loop bind function parse error!")
             offset = iter_var = _api.var(_name)
-            if not _ir_pass.Equal(low, self._const(0)):
+            if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                 offset = iter_var + low
-            self.loops_above[_name] = offset
+            self.symbols[_name] = Symbol.LoopVar, offset
+            _body = visit_list_to_block(self.visit, node.body)
         else:
             _internal_assert(for_type is not None, "The loop iterating function parse error!")
-            self.loops_above[_name] = iter_var.var
-            _body = list_to_block(self.visit, node.body)
+            self.symbols[_name] = Symbol.LoopVar, iter_var.var
+            _body = visit_list_to_block(self.visit, node.body)
         _body = self.wrap_up_realize(node, _body)
         if for_type is None:
             res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
-        else:
-            res = _make.For(iter_var, self._const(0), ext, for_type, 0, _body)
-        self.loops_above.pop(_name)
+        elif not isinstance(for_type, tuple):
+            res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
+        self.symbols.pop(_name)
         return res
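The tuple-valued ``for_type`` is how ``const_range`` arrives here: instead of emitting a HalideIR ``For``, the parser re-visits the loop body once per iteration with the loop variable bound as a ``ConstLoopVar`` (a plain Python int), then packs the copies into one block. A rough stand-alone illustration of that unrolling idea; the helper below is hypothetical, not the parser code:

    def unroll_const_range(low, ext, body_template):
        """Emit one copy of the body per iteration, like the ConstLoopVar path."""
        bodies = []
        for i in range(low, low + ext):
            # each visit sees the loop variable as a concrete Python int
            bodies.append(body_template.format(i=i))
        return bodies  # the parser packs these into a single block

    print(unroll_const_range(0, 3, "c[{i}] = a[{i}] + b[{i}]"))
    # ['c[0] = a[0] + b[0]', 'c[1] = a[1] + b[1]', 'c[2] = a[2] + b[2]']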
     def visit_Return(self, node):
-        _internal_assert(not self.loops_above, "Return should not be in a loop body!")
+        _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()), \
+                         "Return should not be in a loop body!")
-        ids = []
         if isinstance(node.value, ast.Name):
-            ids.append(node.value.id)
+            ids = [node.value.id]
         else:
             _internal_assert(isinstance(node.value, ast.Tuple), \
                              "You should return either a single tensor or a tuple")
-            for i in node.value.elts:
-                _internal_assert(isinstance(i, ast.Name), "What do you return?")
-                ids.append(i.id)
+            _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \
+                             "What do you return?")
+            ids = [i.id for i in node.value.elts]
         _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
         if len(ids) < len(self.outputs):
             logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
-        self.outputs = [self.alloc_buffers[i][0] for i in ids]
+        self.outputs = [self.symbols[i][1] for i in ids]
         self.returned = True
         return util.make_nop()
python/tvm/hybrid/util.py

@@ -11,12 +11,13 @@ from .. import api as _api
 from .. import make as _make
 from .. import expr as _expr
 from .. import stmt as _stmt
+from ..container import Array
 from ..tensor import Tensor

 #pylint: disable=invalid-name
 np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
-tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
+tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
 halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)

 def _internal_assert(cond, err):
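With ``Array`` added to ``tvm_arg_types``, a plain Python list becomes a legal hybrid script argument once it goes through ``tvm.convert``. A quick, illustrative check of that gate:

    import tvm

    b = tvm.convert([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])  # nested list -> tvm.container.Array
    print(isinstance(b, tvm.container.Array))            # True: passes the tvm_arg_types check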
tests/python/unittest/test_hybrid_script.py

@@ -13,7 +13,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
     ctx = tvm.context(target, 0)
     op = None
-    outs = func(*args)
+    outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
     op = outs[0].op if isinstance(outs, list) else outs.op
     emu_args = []

@@ -23,13 +23,18 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
             shape = [tvm_val_2_py_val(j) for j in i.shape]
             emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
             nd_args.append(tvm.nd.array(emu_args[-1], ctx))
-        else:
-            assert isinstance(i, tvm.expr.Var)
+        elif isinstance(i, tvm.expr.Var):
             emu_args.append(tvm_val_2_py_val(i))
             nd_args.append(emu_args[-1])
+        else:
+            assert isinstance(i, list)
+            emu_args.append(numpy.array(i))

     sch = tvm.create_schedule(op)
-    module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
+    module = tvm.build(sch,
+                       [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
+                       (outs if isinstance(outs, list) else [outs]),
+                       target=target)
     assert module
     out_tensors = []
@@ -192,20 +197,20 @@ def test_fanout():
 def test_looptype():
     @script
     def looptype(a, b, c):
-        d = output_tensor((8, ), 'int32')
-        e = output_tensor((8, ), 'int32')
-        f = output_tensor((8, ), 'int32')
-        for i in parallel(8):
+        d = output_tensor((16, ), 'int32')
+        e = output_tensor((16, ), 'int32')
+        f = output_tensor((16, ), 'int32')
+        for i in parallel(16):
             d[i] = a[i]
-        for j in vectorize(8):
+        for j in vectorize(16):
             e[j] = b[j]
-        for k in unroll(8):
+        for k in unroll(16):
             f[k] = c[k]
         return d, e, f

-    a = tvm.placeholder((8, ), name='a', dtype='int32')
-    b = tvm.placeholder((8, ), name='b', dtype='int32')
-    c = tvm.placeholder((8, ), name='c', dtype='int32')
+    a = tvm.placeholder((16, ), name='a', dtype='int32')
+    b = tvm.placeholder((16, ), name='b', dtype='int32')
+    c = tvm.placeholder((16, ), name='c', dtype='int32')
     try:
         d, e, f = looptype(a, b, c)
         ir = d.op.body
@@ -509,9 +514,9 @@ def test_value_index():
 def test_func_call():
     @tvm.hybrid.script
     def foo(a, b):
-        for i in range(10):
+        for i in range(len(a)):
             a[i] = i + 1.0
-        for i in range(10):
+        for i in range(len(a)):
             b[i] = i + 1.0
         c = outer_product(10, 10, a, b)
         d = output_tensor(c.shape, c.dtype)
@@ -538,6 +543,26 @@ def test_bool():
     a = tvm.placeholder((10, ), name='a')
     run_and_check(foo, [a])

+def test_const_range():
+    @tvm.hybrid.script
+    def foo(a, b):
+        c = output_tensor(a.shape, a.dtype)
+        d = output_tensor(a.shape, a.dtype)
+        for i in const_range(2):
+            for j in const_range(5):
+                c[i, j] = a[i, j] + b[i, j]
+        for i in const_range(len(b)):
+            for j in const_range(len(b[0])):
+                d[i, j] = a[i, j] + b[i, j]
+        return c, d
+
+    a = tvm.placeholder((2, 5), name='a', dtype='int32')
+    b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
+    run_and_check(foo, [a, b])
+
 if __name__ == "__main__":
     test_outer_product()
     test_fanout()

@@ -553,5 +578,6 @@ if __name__ == "__main__":
     test_value_index()
     test_func_call()
     test_bool()
+    test_const_range()
     # TODO:
     # test_inplace()