wenyuanbo / tic · Commits

Commit a42d1e3c
Authored Jan 04, 2019 by Jian Weng
Committed by Tianqi Chen on Jan 04, 2019
[Hybrid Script] Unify the symbol tables to one; support `tvm.container.Array` (#2366)
parent 151f550b
Showing 6 changed files with 113 additions and 61 deletions.
docs/langref/hybrid_script.rst               +33 -10
python/tvm/hybrid/calls.py                   +23 -9
python/tvm/hybrid/intrin.py                  +14 -26
python/tvm/hybrid/parser.py                  +0 -0
python/tvm/hybrid/util.py                    +2 -1
tests/python/unittest/test_hybrid_script.py  +41 -15
docs/langref/hybrid_script.rst

@@ -52,7 +52,8 @@ The current parse interface looks like:

    parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function

-If we pass these tvm tensors to this function, it returns a op node:
+If we pass these tvm data structures, like ``Tensor``, ``Var``, ``Expr.*Imm``,
+or ``tvm.container.Array``, to this function, it returns a op node:

 .. code-block:: python

@@ -60,12 +61,14 @@ If we pass these tvm tensors to this function, it returns a op node:

    b = tvm.placeholder((99, ), name='b')
    c = outer_product(a, b, c) # return the output tensor(s) of the operator

-**Under construction, we are still deciding what kind of node should be returned.**
+You can use any methods that can be applied on a TVM ``OpNode``, like create_schedule, although
+so far, the functionality of schedule is as limited as ``ExternOpNode``. At least, it can be built
+to LLVM module.

 Tuning
 ~~~~~~

-**Under construction, not truly supported yet.**
+**Under construction, not supported yet.**

 Follow up the example above, you can use some tvm like interfaces to tune the code:

@@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize``,
 these **4** keywords to annotate the corresponding types of for loops.
 The the usage is roughly the same as Python standard ``range``.

+Besides all the loop types supported in Halide, ``const_range`` is supported for some specific conditions.
+Sometimes, ``tvm.container.Array`` is desired to pass as an argument, but in TVM-HalideIR, there is no
+such support that converts ``tvm.container.Array`` to an ``Expr``. Thus, a limited feature is supported.
+Users can access containers by either constants or constants loops annotated.
+
+.. code-block:: python
+
+   @tvm.hybrid.script
+   def foo(a, b): # b is a tvm.container.Array
+       c = output_tensor(a.shape, a.dtype)
+       for i in const_range(len(a)): # because you have b access, i should be explicitly annotated as const_range
+           c[i] = a[i] + b[i]
+       return c
+
 Variables
 ~~~~~~~~~

@@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration.

             s += a[i, j] # do something with sum
         b[i] = sum # you can still use sum in this level
     a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python

 b = (1, 2) # this has NOT been supported yet!

 Attributes
 ~~~~~~~~~~

-So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a
-tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported.
+So far, ONLY tensors' ``shape`` and ``dtype`` attribute are supported!
+The ``shape`` atrribute is essentailly a tuple, so you MUST access it as an array.
+Currently, only constant-indexed access is supported.

 .. code-block:: python

@@ -133,8 +151,11 @@ Conditional Statement and Expression

 .. code-block:: python

-    if condition:
-        # do something
+    if condition1 and condition2 and condition3:
+        # do something
+    else:
+        # do something else
+    # Select
+    a = b if condition else c

 However, NO ``True`` and ``False`` keyword supported yet.

@@ -153,7 +174,9 @@ Array Allocation

 **Under construction, this function will be supported later!**

 Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer.
-The basic usage is roughly the same as a normal array.
+The basic usage is roughly the same as a normal ``numpy.array``, and you should access
+high-dim array in ``a[i, j, k]`` fashion instead of ``a[i][j][k]``,
+even for ``tvm.container.Array`` for compilation.

 Thread Bind

@@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this:

 Keywords
 ~~~~~~~~
-- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
+- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
 - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
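Taken together, the documentation changes above describe the new calling convention: plain TVM containers can now be passed straight into a hybrid function, as long as they are only indexed under ``const_range``. As a hedged end-to-end sketch (not part of the commit; the function name array_add and the exact build steps are illustrative assumptions drawn from the docs above):

    import tvm

    @tvm.hybrid.script
    def array_add(a, b):  # hypothetical example; b is a tvm.container.Array
        c = output_tensor(a.shape, a.dtype)
        for i in const_range(len(b)):  # container access requires const_range
            c[i] = a[i] + b[i]
        return c

    a = tvm.placeholder((5, ), name='a', dtype='int32')
    b = tvm.convert([1, 2, 3, 4, 5])   # plain list -> tvm.container.Array
    c = array_add(a, b)                # returns the output tensor
    sch = tvm.create_schedule(c.op)
    # the Array argument is folded into the IR, so only tensors are built:
    module = tvm.build(sch, [a, c])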
python/tvm/hybrid/calls.py

@@ -12,15 +12,17 @@ from .util import _internal_assert
 #pylint: disable=redefined-builtin

 LOOP_INTRIN = {
-    'range'    : For.Serial,
-    'unroll'   : For.Unrolled,
-    'parallel' : For.Parallel,
-    'vectorize': For.Vectorized,
+    'range'       : For.Serial,
+    'unroll'      : For.Unrolled,
+    'parallel'    : For.Parallel,
+    'vectorize'   : For.Vectorized,
+    'const_range' : (For.Unrolled, ),
 }

 def _range(annotation, args):
     """Handling TVM loop types"""
-    n = len(args)
+    n = args.__len__()
     if n == 1:
         low, ext = _api.const(0, dtype='int32'), args[0]
     else:

@@ -33,13 +35,13 @@ def _range(annotation, args):
     return iter_var, low, ext, for_type

-range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
+range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name

 def bind(func_id, args):
     """Handling TVM thread binding"""
     _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
-    _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
+    _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
     _internal_assert(isinstance(args[0], str), \
                      "A loop bind's first argument should be a string!")
     iter_var = _api.thread_axis(args[0])

@@ -56,7 +58,7 @@ sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
 def _min_max(func_id, args):
-    _internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
+    _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
     return getattr(_make, func_id.title())(args[0], args[1])

@@ -66,7 +68,7 @@ min = max = _min_max #pylint: disable=invalid-name
 def _allocate_tensor(func_id, args):
     """Handling TVM tensor allocation.
     You may refer hybrid.intrin.allocate for more details."""
-    n = len(args)
+    n = args.__len__()
     _internal_assert(isinstance(_api.convert(args[0]), Array), \
                      "allocate's first argument should be a tuple of shape!")
     shape = args[0]

@@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args):
         scope = 'global' if func_id != 'output_tensor' else 'output'
     return (shape, dtype, scope)

 output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
+
+def len(func_id, args):
+    """Iterpret the len function"""
+    _internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
+    _internal_assert(func_id == "len", "This function cannot be directly invoked!")
+    try:
+        return _api.convert(args[0].__len__())
+    except: #pylint: disable=bare-except
+        _internal_assert(args[0].shape.__len__() == 1, \
+                         "Only one-dimension array can get len")
+        return _api.convert(args[0].shape[0])
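The new len handler above first tries the argument's own __len__ (which tvm.container.Array provides) and only falls back to shape[0] for one-dimensional tensors. A hedged sketch of the two paths evaluated outside the parser (the variable names arr and vec are made up for illustration):

    import tvm

    arr = tvm.convert([3, 1, 4])              # tvm.container.Array has __len__
    vec = tvm.placeholder((7, ), name='vec')  # a Tensor has no __len__

    print(tvm.convert(arr.__len__()))   # the try-branch: constant 3
    print(tvm.convert(vec.shape[0]))    # the fallback: vec.shape[0], i.e. 7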
python/tvm/hybrid/intrin.py

@@ -2,32 +2,19 @@
 import numpy

-class _range(object):
-    """Base class of the loop ranges in hybrid script"""
-    def __init__(self, a, b=None):
-        if b is None:
-            self.low = 0
-            self.ext = a
-        else:
-            self.low = a
-            self.ext = b
+class bind(object): #pylint: disable=invalid-name
+    """GPU bind software emulataion runtime."""
+    def __init__(self, _, ext):
+        self.ext = ext

     def __iter__(self):
         i = 0
         while i < self.ext:
-            yield i + self.low
+            yield i
             i += 1

-class bind(_range): #pylint: disable=invalid-name
-    def __init__(self, tag, ext):
-        super(bind, self).__init__(ext)
-        self.tag = tag
-
-unroll = vectorize = parallel = _range #pylint: disable=invalid-name

 def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
     """Allocate a buffer with given shape

@@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
     """
     return numpy.zeros(shape).astype(dtype)

-output_tensor = allocate #pylint: disable=invalid-name

 def popcount(x):
     """

@@ -87,17 +73,19 @@ def sigmoid(x):
 HYBRID_GLOBALS = {
-    'unroll'       : unroll,
-    'vectorize'    : vectorize,
-    'parallel'     : parallel,
-    'allocate'     : allocate,
-    'output_tensor': output_tensor,
-    'len'          : len,
+    'unroll'       : range,
+    'vectorize'    : range,
+    'parallel'     : range,
+    'const_range'  : range,
     'bind'         : bind,
     'allocate'     : allocate,
+    'output_tensor': allocate,
     'sqrt'         : numpy.sqrt,
     'log'          : numpy.log,
     'tanh'         : numpy.tanh,
     'power'        : numpy.power,
     'exp'          : numpy.exp,
     'sigmoid'      : sigmoid,
-    'popcount'     : popcount
+    'popcount'     : popcount,
 }
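In pure-Python emulation, the unified table above maps every loop keyword to the builtin range, so a hybrid body can be executed directly against numpy buffers. A minimal sketch of that equivalence (the buffers here are made up, not from the commit):

    import numpy

    # as in HYBRID_GLOBALS above: all loop annotations degrade to range
    unroll = vectorize = parallel = const_range = range

    a = numpy.arange(4, dtype='int32')
    c = numpy.zeros((4, ), dtype='int32')
    for i in unroll(4):        # behaves exactly like range(4) in emulation
        c[i] = a[i] * 2
    assert (c == a * 2).all()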
python/tvm/hybrid/parser.py

This diff is collapsed in the original page and is not shown.
python/tvm/hybrid/util.py

@@ -11,12 +11,13 @@ from .. import api as _api
 from .. import make as _make
 from .. import expr as _expr
 from .. import stmt as _stmt
+from ..container import Array
 from ..tensor import Tensor

 #pylint: disable=invalid-name
 np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
-tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
+tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
 halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)

 def _internal_assert(cond, err):
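The widened tvm_arg_types above is what lets a converted list through the hybrid frontend's argument check. A quick hedged illustration, standalone and assuming the module paths of this commit's era:

    import tvm
    from tvm.container import Array

    tvm_arg_types = (tvm.tensor.Tensor, Array, tvm.expr.Var, tvm.expr.ConstExpr)

    b = tvm.convert([1, 2, 3])
    assert isinstance(b, Array)           # lists convert to container.Array
    assert isinstance(b, tvm_arg_types)   # accepted only after this change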
tests/python/unittest/test_hybrid_script.py

@@ -13,7 +13,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
     ctx = tvm.context(target, 0)
     op = None
-    outs = func(*args)
+    outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
     op = outs[0].op if isinstance(outs, list) else outs.op
     emu_args = []

@@ -23,13 +23,18 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
             shape = [tvm_val_2_py_val(j) for j in i.shape]
             emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
             nd_args.append(tvm.nd.array(emu_args[-1], ctx))
-        else:
-            assert isinstance(i, tvm.expr.Var)
+        elif isinstance(i, tvm.expr.Var):
             emu_args.append(tvm_val_2_py_val(i))
             nd_args.append(emu_args[-1])
+        else:
+            assert isinstance(i, list)
+            emu_args.append(numpy.array(i))

     sch = tvm.create_schedule(op)
-    module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
+    module = tvm.build(sch,
+                       [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
+                       (outs if isinstance(outs, list) else [outs]),
+                       target=target)
     assert module
     out_tensors = []

@@ -192,20 +197,20 @@ def test_fanout():
 def test_looptype():
     @script
     def looptype(a, b, c):
-        d = output_tensor((8, ), 'int32')
-        e = output_tensor((8, ), 'int32')
-        f = output_tensor((8, ), 'int32')
-        for i in parallel(8):
+        d = output_tensor((16, ), 'int32')
+        e = output_tensor((16, ), 'int32')
+        f = output_tensor((16, ), 'int32')
+        for i in parallel(16):
             d[i] = a[i]
-        for j in vectorize(8):
+        for j in vectorize(16):
             e[j] = b[j]
-        for k in unroll(8):
+        for k in unroll(16):
             f[k] = c[k]
         return d, e, f

-    a = tvm.placeholder((8, ), name='a', dtype='int32')
-    b = tvm.placeholder((8, ), name='b', dtype='int32')
-    c = tvm.placeholder((8, ), name='c', dtype='int32')
+    a = tvm.placeholder((16, ), name='a', dtype='int32')
+    b = tvm.placeholder((16, ), name='b', dtype='int32')
+    c = tvm.placeholder((16, ), name='c', dtype='int32')
     try:
         d, e, f = looptype(a, b, c)
         ir = d.op.body

@@ -509,9 +514,9 @@ def test_value_index():
 def test_func_call():
     @tvm.hybrid.script
     def foo(a, b):
-        for i in range(10):
+        for i in range(len(a)):
             a[i] = i + 1.0
-        for i in range(10):
+        for i in range(len(a)):
             b[i] = i + 1.0
         c = outer_product(10, 10, a, b)
         d = output_tensor(c.shape, c.dtype)

@@ -538,6 +543,26 @@ def test_bool():
     a = tvm.placeholder((10, ), name='a')
     run_and_check(foo, [a])

+def test_const_range():
+    @tvm.hybrid.script
+    def foo(a, b):
+        c = output_tensor(a.shape, a.dtype)
+        d = output_tensor(a.shape, a.dtype)
+        for i in const_range(2):
+            for j in const_range(5):
+                c[i, j] = a[i, j] + b[i, j]
+        for i in const_range(len(b)):
+            for j in const_range(len(b[0])):
+                d[i, j] = a[i, j] + b[i, j]
+        return c, d
+
+    a = tvm.placeholder((2, 5), name='a', dtype='int32')
+    b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
+    run_and_check(foo, [a, b])
+
 if __name__ == "__main__":
     test_outer_product()
     test_fanout()

@@ -553,5 +578,6 @@ if __name__ == "__main__":
     test_value_index()
     test_func_call()
     test_bool()
+    test_const_range()
     # TODO:
     # test_inplace()
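The updated run_and_check above converts plain lists before calling the hybrid function, and then filters them out of the build arguments, since Array contents are baked into the IR rather than passed at runtime. A hedged sketch of that idiom in isolation (convert_args is a made-up helper, not in the test file):

    import tvm

    def convert_args(args):
        # hypothetical helper mirroring the test's inline conversion
        return tuple(tvm.convert(i) if isinstance(i, list) else i for i in args)

    a = tvm.placeholder((2, 5), name='a', dtype='int32')
    b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
    args = convert_args([a, b])
    build_args = [i for i in args
                  if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))]
    assert len(build_args) == 1   # the Array argument needs no runtime buffer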