Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ba15a729
Commit
ba15a729
authored
May 11, 2019
by
Lianmin Zheng
Committed by
Tianqi Chen
May 10, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[HybridScript] Capture constant external python variables (#3157)
parent
654192de
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
78 additions
and
21 deletions
+78
-21
python/tvm/hybrid/__init__.py
+5
-1
python/tvm/hybrid/module.py
+1
-1
python/tvm/hybrid/parser.py
+35
-15
python/tvm/hybrid/preprocessor.py
+12
-4
python/tvm/hybrid/util.py
+6
-0
tests/python/unittest/test_hybrid_script.py
+19
-0
No files found.
python/tvm/hybrid/__init__.py
View file @
ba15a729
...
...
@@ -31,6 +31,8 @@ HalideIR.
from
__future__
import
absolute_import
as
_abs
import
inspect
from
.._ffi.base
import
decorate
from
.._ffi.function
import
_init_api
from
..build_module
import
form_body
...
...
@@ -55,7 +57,9 @@ def script(pyfunc):
from
.util
import
_is_tvm_arg_types
if
_is_tvm_arg_types
(
args
):
src
=
_pruned_source
(
func
)
return
source_to_op
(
src
,
func
.
__globals__
,
args
)
closure_vars
=
inspect
.
getclosurevars
(
func
)
.
nonlocals
closure_vars
.
update
(
inspect
.
getclosurevars
(
func
)
.
globals
)
return
source_to_op
(
src
,
args
,
func
.
__globals__
,
closure_vars
)
from
.runtime
import
_enter_hybrid_runtime
,
_restore_runtime
intersect
=
_enter_hybrid_runtime
(
func
)
...
...
python/tvm/hybrid/module.py
View file @
ba15a729
...
...
@@ -62,7 +62,7 @@ class HybridModule(object):
def
__call__
(
self
,
*
args
):
if
_is_tvm_arg_types
(
args
):
return
source_to_op
(
self
.
root_
,
globals
(),
args
)
return
source_to_op
(
self
.
root_
,
args
,
globals
(),
{}
)
return
self
.
func_
(
*
args
)
...
...
python/tvm/hybrid/parser.py
View file @
ba15a729
...
...
@@ -25,7 +25,7 @@ import numbers
from
enum
import
Enum
from
.util
import
_internal_assert
from
.util
import
_internal_assert
,
_apply_indices
from
.
import
calls
from
.
import
util
from
.preprocessor
import
determine_variable_usage
...
...
@@ -112,7 +112,7 @@ class HybridParser(ast.NodeVisitor):
}
def
__init__
(
self
,
args
,
usage
,
symbols
,
func_name
=
None
):
def
__init__
(
self
,
args
,
usage
,
symbols
,
closure_vars
,
func_name
=
None
):
"""
Parameters
----------
...
...
@@ -122,6 +122,12 @@ class HybridParser(ast.NodeVisitor):
usage: A dict of variables used in last in this function
Provided by last lower pass, which collects this information
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
func_name: str
...
...
@@ -136,6 +142,8 @@ class HybridParser(ast.NodeVisitor):
if
isinstance
(
v
,
types
.
FunctionType
):
self
.
add_symbol
(
k
,
Symbol
.
Callable
,
v
)
self
.
closure_vars
=
closure_vars
self
.
binds
=
{}
# Thread binds
self
.
device
=
0
# Is it generating device
...
...
@@ -236,7 +244,11 @@ class HybridParser(ast.NodeVisitor):
def
visit_Name
(
self
,
node
):
name
=
node
.
id
if
sys
.
version_info
[
0
]
==
2
and
name
in
[
'True'
,
'False'
]:
return
_api
.
convert
(
eval
(
name
))
#pylint: disable=eval-used
return
_api
.
convert
(
ast
.
literal_eval
(
name
))
if
name
in
self
.
closure_vars
:
return
_api
.
convert
(
self
.
closure_vars
[
name
])
ty
,
entry
=
self
.
symbols
[
name
]
_internal_assert
(
name
in
self
.
symbols
,
"Unknown symbol
%
s!"
%
name
)
if
ty
in
[
Symbol
.
LoopVar
,
Symbol
.
Input
,
Symbol
.
ConstLoopVar
]:
...
...
@@ -356,10 +368,12 @@ class HybridParser(ast.NodeVisitor):
buf
=
self
.
visit
(
node
.
value
)
return
getattr
(
buf
,
node
.
attr
)
def
visit_Subscript
(
self
,
node
):
args
=
self
.
visit
(
node
.
slice
)
if
isinstance
(
node
.
value
,
ast
.
Name
):
if
node
.
value
.
id
in
self
.
closure_vars
:
args
=
ast
.
literal_eval
(
str
(
args
))
return
_api
.
convert
(
_apply_indices
(
self
.
closure_vars
[
node
.
value
.
id
],
args
))
buf
=
self
.
visit
(
node
.
value
)
if
isinstance
(
buf
,
Array
):
...
...
@@ -576,7 +590,7 @@ class HybridParser(ast.NodeVisitor):
return
_make
.
AssertStmt
(
test
,
mesg
,
util
.
make_nop
())
def
parse_python
(
src
,
symbols
,
arg
s
):
def
parse_python
(
src
,
args
,
symbols
,
closure_var
s
):
"""The helper function of calling the AST visitor
Parameters
...
...
@@ -585,14 +599,17 @@ def parse_python(src, symbols, args):
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
root : Stmt
...
...
@@ -600,14 +617,14 @@ def parse_python(src, symbols, args):
"""
root
=
ast
.
parse
(
src
)
if
isinstance
(
src
,
str
)
else
src
_internal_assert
(
root
,
ast
.
AST
)
var_usage
=
determine_variable_usage
(
root
,
args
,
symbols
)
parser
=
HybridParser
(
args
,
var_usage
,
symbols
)
var_usage
=
determine_variable_usage
(
root
,
args
,
symbols
,
closure_vars
)
parser
=
HybridParser
(
args
,
var_usage
,
symbols
,
closure_vars
)
parser
.
parsed_body
=
parser
.
visit
(
root
)
_internal_assert
(
parser
.
returned
,
'No valid return found in the function body!'
)
return
parser
def
source_to_op
(
src
,
symbols
,
arg
s
):
def
source_to_op
(
src
,
args
,
symbols
,
closure_var
s
):
"""Another level of wrapper
Parameters
...
...
@@ -616,20 +633,23 @@ def source_to_op(src, symbols, args):
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
res : list of output tensors
The result of output tensors of the formed OpNode.
"""
parser
=
parse_python
(
src
,
symbols
,
arg
s
)
parser
=
parse_python
(
src
,
args
,
symbols
,
closure_var
s
)
input_tensors
=
[]
for
i
in
args
:
...
...
python/tvm/hybrid/preprocessor.py
View file @
ba15a729
...
...
@@ -26,14 +26,14 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def
__init__
(
self
,
args
,
symbols
):
def
__init__
(
self
,
args
,
symbols
,
closure_vars
):
self
.
status
=
{}
self
.
scope_level
=
[]
self
.
_args
=
{}
self
.
args
=
args
self
.
aug_assign_
=
False
self
.
symbols
=
symbols
self
.
closure_vars
=
closure_vars
def
visit_FunctionDef
(
self
,
node
):
self
.
scope_level
.
append
(
node
)
...
...
@@ -89,6 +89,14 @@ class PyVariableUsage(ast.NodeVisitor):
"Iter var cannot be overwritten"
)
if
node
.
id
not
in
self
.
status
.
keys
():
# It is a captured value in closure
if
node
.
id
in
self
.
closure_vars
:
try
:
ast
.
literal_eval
(
str
(
self
.
closure_vars
[
node
.
id
]))
except
ValueError
:
raise
ValueError
(
"Only support capturing constant values in closure"
)
return
_internal_assert
(
isinstance
(
node
.
ctx
,
ast
.
Store
),
\
'Undeclared variable
%
s'
%
node
.
id
)
if
self
.
aug_assign_
:
...
...
@@ -102,8 +110,8 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
status
[
node
.
id
]
=
(
decl
,
loop
,
usage
)
def
determine_variable_usage
(
root
,
args
,
symbols
):
def
determine_variable_usage
(
root
,
args
,
symbols
,
closure_vars
):
"""The helper function for calling the dedicated visitor."""
visitor
=
PyVariableUsage
(
args
,
symbols
)
visitor
=
PyVariableUsage
(
args
,
symbols
,
closure_vars
)
visitor
.
visit
(
root
)
return
visitor
.
status
python/tvm/hybrid/util.py
View file @
ba15a729
...
...
@@ -101,3 +101,9 @@ def _is_tvm_arg_types(args):
_internal_assert
(
isinstance
(
elem
,
np_arg_types
),
\
"Expect a numpy type but
%
s get!"
%
str
(
type
(
elem
)))
return
False
def
_apply_indices
(
value
,
indices
):
"""Apply multidimensional index"""
if
indices
:
return
_apply_indices
(
value
[
indices
[
0
]],
indices
[
1
:])
return
value
tests/python/unittest/test_hybrid_script.py
View file @
ba15a729
...
...
@@ -768,6 +768,24 @@ def test_schedule():
# Test loop binds
def
test_capture
():
n
=
8
constant_tuple
=
(
10
,
n
)
constant_list
=
[[
1
,
2
],
[
3
,
n
]]
const_value
=
1
@tvm.hybrid.script
def
add_something
(
a
):
c
=
output_tensor
((
constant_tuple
[
1
],),
'int32'
)
for
i
in
range
(
constant_tuple
[
1
]):
c
[
i
]
=
a
[
i
]
+
constant_list
[
1
][
const_value
]
return
c
a
=
tvm
.
placeholder
((
n
,
),
dtype
=
'int32'
,
name
=
'a'
)
func
,
ins
,
outs
=
run_and_check
(
add_something
,
[
a
])
run_and_check
(
func
,
ins
,
outs
=
outs
)
if
__name__
==
"__main__"
:
test_outer_product
()
...
...
@@ -786,5 +804,6 @@ if __name__ == "__main__":
test_bool
()
test_const_range
()
test_schedule
()
test_capture
()
# TODO:
# test_inplace()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment