Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
838e7181
Commit
838e7181
authored
Dec 19, 2018
by
Jian Weng
Committed by
Tianqi Chen
Dec 19, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Hybrid Script] Inter-function call supported! (#2287)
parent
001ab525
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
155 additions
and
25 deletions
+155
-25
python/tvm/hybrid/api.py
+1
-3
python/tvm/hybrid/calls.py
+92
-0
python/tvm/hybrid/intrin.py
+1
-14
python/tvm/hybrid/parser.py
+0
-0
python/tvm/hybrid/util.py
+18
-0
python/tvm/hybrid/var_decl.py
+10
-5
tests/python/unittest/test_hybrid_script.py
+33
-3
No files found.
python/tvm/hybrid/api.py
View file @
838e7181
...
...
@@ -24,17 +24,15 @@ def script(pyfunc):
from
.util
import
_enter_hybrid_runtime
,
_restore_runtime
,
_is_tvm_arg_types
if
_is_tvm_arg_types
(
args
):
src
=
_pruned_source
(
func
)
parser
=
parse_python
(
src
,
args
)
parser
=
parse_python
(
src
,
func
.
__globals__
,
args
)
input_tensors
=
[]
for
i
in
args
:
if
isinstance
(
i
,
Tensor
):
input_tensors
.
append
(
i
)
op
=
_tvm_internal
.
_HybridOp
(
parser
.
func_name
,
"HybridOp"
,
None
,
input_tensors
,
parser
.
outputs
,
parser
.
parsed_body
)
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
parser
.
outputs
))]
return
res
[
0
]
if
len
(
res
)
==
1
else
res
intersect
=
_enter_hybrid_runtime
(
func
)
...
...
python/tvm/hybrid/calls.py
0 → 100644
View file @
838e7181
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from
..
import
api
as
_api
from
..
import
expr
as
_expr
from
..
import
make
as
_make
from
..container
import
Array
from
..
import
ir_pass
from
..stmt
import
For
from
.util
import
_internal_assert
#pylint: disable=redefined-builtin
LOOP_INTRIN
=
{
'range'
:
For
.
Serial
,
'unroll'
:
For
.
Unrolled
,
'parallel'
:
For
.
Parallel
,
'vectorize'
:
For
.
Vectorized
,
}
def
_range
(
annotation
,
args
):
"""Handling TVM loop types"""
n
=
len
(
args
)
if
n
==
1
:
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
args
[
0
]
else
:
_internal_assert
(
n
==
2
,
"A loop intrinsic should only have 1 or 2 arguments!"
)
low
,
ext
=
args
[
0
],
args
[
1
]
if
not
ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
ext
=
ext
-
low
for_type
=
LOOP_INTRIN
[
annotation
]
iter_var
=
None
return
iter_var
,
low
,
ext
,
for_type
range
=
unroll
=
vectorize
=
parallel
=
_range
#pylint: disable=invalid-name
def
bind
(
func_id
,
args
):
"""Handling TVM thread binding"""
_internal_assert
(
func_id
==
"bind"
,
"This function cannot be directly invoked!"
)
_internal_assert
(
len
(
args
)
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
args
[
0
],
str
),
\
"A loop bind's first argument should be a string!"
)
iter_var
=
_api
.
thread_axis
(
args
[
0
])
low
,
ext
=
_api
.
const
(
0
),
args
[
1
]
for_type
=
None
return
iter_var
,
low
,
ext
,
for_type
def
_math_intrin
(
func_id
,
args
):
from
..
import
intrin
return
getattr
(
intrin
,
func_id
)(
*
args
)
sqrt
=
log
=
exp
=
tanh
=
sigmoid
=
power
=
popcount
=
_math_intrin
#pylint: disable=invalid-name
def
_min_max
(
func_id
,
args
):
_internal_assert
(
len
(
args
)
==
2
,
"Max/Min function should have 2 elements"
)
return
getattr
(
_make
,
func_id
.
title
())(
args
[
0
],
args
[
1
])
min
=
max
=
_min_max
#pylint: disable=invalid-name
def
_allocate_tensor
(
func_id
,
args
):
"""Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details."""
n
=
len
(
args
)
_internal_assert
(
isinstance
(
_api
.
convert
(
args
[
0
]),
Array
),
\
"allocate's first argument should be a tuple of shape!"
)
shape
=
args
[
0
]
for
i
in
shape
:
_internal_assert
(
isinstance
(
i
,
_expr
.
Expr
),
"The shape should be an expression"
)
if
n
>
1
:
_internal_assert
(
isinstance
(
args
[
1
],
str
),
"The data type should be an str"
)
_internal_assert
(
args
[
1
]
.
startswith
(
'int'
)
or
args
[
1
]
.
startswith
(
'float'
),
\
"The data type should be either int or float!"
)
dtype
=
args
[
1
]
else
:
dtype
=
'float32'
if
n
>
2
:
_internal_assert
(
isinstance
(
args
[
2
],
str
),
\
"The data scope should be an string"
)
_internal_assert
(
func_id
!=
'output_tensor'
,
"Output tensor cannot specify scope"
)
scope
=
args
[
2
]
else
:
scope
=
'global'
if
func_id
!=
'output_tensor'
else
'output'
return
(
shape
,
dtype
,
scope
)
output_tensor
=
allocate
=
_allocate_tensor
#pylint: disable=invalid-name
python/tvm/hybrid/intrin.py
View file @
838e7181
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
"""Intrinsics of TVM-Python Hybrid Script for Python
emulation
runtime"""
import
numpy
from
..stmt
import
For
class
_range
(
object
):
"""Base class of the loop ranges in hybrid script"""
...
...
@@ -102,15 +101,3 @@ HYBRID_GLOBALS = {
'sigmoid'
:
sigmoid
,
'popcount'
:
popcount
}
LOOP_INTRIN
=
{
'range'
:
For
.
Serial
,
'unroll'
:
For
.
Unrolled
,
'parallel'
:
For
.
Parallel
,
'vectorize'
:
For
.
Vectorized
,
'bind'
:
None
}
MATH_INTRIN
=
[
'sqrt'
,
'log'
,
'exp'
,
'tanh'
,
'sigmoid'
,
'power'
,
'popcount'
]
python/tvm/hybrid/parser.py
View file @
838e7181
This diff is collapsed.
Click to expand it.
python/tvm/hybrid/util.py
View file @
838e7181
...
...
@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types
from
..
import
api
as
_api
from
..
import
make
as
_make
from
..
import
expr
as
_expr
from
..
import
stmt
as
_stmt
from
..tensor
import
Tensor
...
...
@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
_globals
.
pop
(
elem
)
for
k
,
v
in
intersect
:
_globals
[
k
]
=
v
def
replace_io
(
body
,
rmap
):
"""Replacing tensors usage according to the dict given"""
from
..
import
ir_pass
def
replace
(
op
):
if
isinstance
(
op
,
_stmt
.
Provide
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
return
_make
.
Provide
(
buf
.
op
,
op
.
value_index
,
op
.
value
,
op
.
args
)
elif
isinstance
(
op
,
_expr
.
Call
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
op
.
args
,
\
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
return
None
return
ir_pass
.
IRTransform
(
body
,
None
,
replace
,
[
'Provide'
,
'Call'
])
python/tvm/hybrid/var_decl.py
View file @
838e7181
...
...
@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
,
symbols
):
self
.
status
=
{}
self
.
scope_level
=
[]
self
.
_args
=
{}
self
.
args
=
args
self
.
aug_assign_
=
False
self
.
symbols
=
symbols
def
visit_FunctionDef
(
self
,
node
):
...
...
@@ -43,8 +44,10 @@ class PyVariableUsage(ast.NodeVisitor):
#No function pointer supported so far
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
"Function call should be an id"
)
func_id
=
node
.
func
.
id
_internal_assert
(
func_id
in
list
(
HYBRID_GLOBALS
.
keys
())
+
[
'range'
,
'max'
,
'min'
],
\
"Function call id not in intrinsics' list"
)
_internal_assert
(
func_id
in
list
(
HYBRID_GLOBALS
.
keys
())
+
\
[
'range'
,
'max'
,
'min'
]
+
\
list
(
self
.
symbols
.
keys
()),
\
"Function call id not in intrinsics' list"
)
for
elem
in
node
.
args
:
self
.
visit
(
elem
)
...
...
@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor):
else
:
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
usage
.
add
(
type
(
node
.
ctx
))
_internal_assert
(
loop
in
self
.
scope_level
,
"
%
s is used out of the scope it is defined!"
%
node
.
id
)
self
.
status
[
node
.
id
]
=
(
decl
,
loop
,
usage
)
def
determine_variable_usage
(
root
,
args
):
def
determine_variable_usage
(
root
,
args
,
symbols
):
"""The helper function for calling the dedicated visitor."""
visitor
=
PyVariableUsage
(
args
)
visitor
=
PyVariableUsage
(
args
,
symbols
)
visitor
.
visit
(
root
)
return
visitor
.
status
tests/python/unittest/test_hybrid_script.py
View file @
838e7181
...
...
@@ -270,7 +270,7 @@ def test_bind():
return
@script
def
vec_add
(
a
,
b
):
c
=
output_tensor
((
1000
,
),
dtype
=
'float32'
)
c
=
output_tensor
((
1000
,
),
'float32'
)
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
c
[
tx
]
=
a
[
tx
]
+
b
[
tx
]
return
c
...
...
@@ -506,7 +506,37 @@ def test_value_index():
module
(
tvm
.
ndarray
.
array
(
np_a
),
res
)
tvm
.
testing
.
assert_allclose
(
res
.
asnumpy
(),
ref
)
def
test_func_call
():
@tvm.hybrid.script
def
foo
(
a
,
b
):
for
i
in
range
(
10
):
a
[
i
]
=
i
+
1.0
for
i
in
range
(
10
):
b
[
i
]
=
i
+
1.0
c
=
outer_product
(
10
,
10
,
a
,
b
)
d
=
output_tensor
(
c
.
shape
,
c
.
dtype
)
for
i
in
range
(
10
):
for
j
in
range
(
10
):
d
[
i
,
j
]
=
c
[
i
,
j
]
+
i
*
j
return
d
a
=
tvm
.
placeholder
((
10
,
),
name
=
'a'
)
b
=
tvm
.
placeholder
((
10
,
),
name
=
'b'
)
run_and_check
(
foo
,
[
a
,
b
])
def
test_bool
():
@tvm.hybrid.script
def
foo
(
a
):
b
=
output_tensor
(
a
.
shape
,
a
.
dtype
)
b
[
0
]
=
1.2
for
i
in
range
(
1
,
a
.
shape
[
0
]
-
1
):
if
a
[
i
]
*
a
[
i
-
1
]
<
a
[
i
]
or
a
[
i
]
*
a
[
i
-
1
]
<
a
[
i
-
1
]
or
i
*
a
[
i
]
==
a
[
i
]:
b
[
i
]
=
a
[
i
]
else
:
b
[
i
]
=
0.0
return
b
a
=
tvm
.
placeholder
((
10
,
),
name
=
'a'
)
run_and_check
(
foo
,
[
a
])
if
__name__
==
"__main__"
:
test_outer_product
()
...
...
@@ -521,7 +551,7 @@ if __name__ == "__main__":
test_downstream
()
test_const_param
()
test_value_index
()
test_func_call
()
test_bool
()
# TODO:
# test_inplace()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment