Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
838e7181
Commit
838e7181
authored
Dec 19, 2018
by
Jian Weng
Committed by
Tianqi Chen
Dec 19, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Hybrid Script] Inter-function call supported! (#2287)
parent
001ab525
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
303 additions
and
184 deletions
+303
-184
python/tvm/hybrid/api.py
+1
-3
python/tvm/hybrid/calls.py
+92
-0
python/tvm/hybrid/intrin.py
+1
-14
python/tvm/hybrid/parser.py
+148
-159
python/tvm/hybrid/util.py
+18
-0
python/tvm/hybrid/var_decl.py
+10
-5
tests/python/unittest/test_hybrid_script.py
+33
-3
No files found.
python/tvm/hybrid/api.py
View file @
838e7181
...
...
@@ -24,17 +24,15 @@ def script(pyfunc):
from
.util
import
_enter_hybrid_runtime
,
_restore_runtime
,
_is_tvm_arg_types
if
_is_tvm_arg_types
(
args
):
src
=
_pruned_source
(
func
)
parser
=
parse_python
(
src
,
args
)
parser
=
parse_python
(
src
,
func
.
__globals__
,
args
)
input_tensors
=
[]
for
i
in
args
:
if
isinstance
(
i
,
Tensor
):
input_tensors
.
append
(
i
)
op
=
_tvm_internal
.
_HybridOp
(
parser
.
func_name
,
"HybridOp"
,
None
,
input_tensors
,
parser
.
outputs
,
parser
.
parsed_body
)
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
parser
.
outputs
))]
return
res
[
0
]
if
len
(
res
)
==
1
else
res
intersect
=
_enter_hybrid_runtime
(
func
)
...
...
python/tvm/hybrid/calls.py
0 → 100644
View file @
838e7181
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from
..
import
api
as
_api
from
..
import
expr
as
_expr
from
..
import
make
as
_make
from
..container
import
Array
from
..
import
ir_pass
from
..stmt
import
For
from
.util
import
_internal_assert
#pylint: disable=redefined-builtin
LOOP_INTRIN
=
{
'range'
:
For
.
Serial
,
'unroll'
:
For
.
Unrolled
,
'parallel'
:
For
.
Parallel
,
'vectorize'
:
For
.
Vectorized
,
}
def
_range
(
annotation
,
args
):
"""Handling TVM loop types"""
n
=
len
(
args
)
if
n
==
1
:
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
args
[
0
]
else
:
_internal_assert
(
n
==
2
,
"A loop intrinsic should only have 1 or 2 arguments!"
)
low
,
ext
=
args
[
0
],
args
[
1
]
if
not
ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
ext
=
ext
-
low
for_type
=
LOOP_INTRIN
[
annotation
]
iter_var
=
None
return
iter_var
,
low
,
ext
,
for_type
range
=
unroll
=
vectorize
=
parallel
=
_range
#pylint: disable=invalid-name
def
bind
(
func_id
,
args
):
"""Handling TVM thread binding"""
_internal_assert
(
func_id
==
"bind"
,
"This function cannot be directly invoked!"
)
_internal_assert
(
len
(
args
)
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
args
[
0
],
str
),
\
"A loop bind's first argument should be a string!"
)
iter_var
=
_api
.
thread_axis
(
args
[
0
])
low
,
ext
=
_api
.
const
(
0
),
args
[
1
]
for_type
=
None
return
iter_var
,
low
,
ext
,
for_type
def
_math_intrin
(
func_id
,
args
):
from
..
import
intrin
return
getattr
(
intrin
,
func_id
)(
*
args
)
sqrt
=
log
=
exp
=
tanh
=
sigmoid
=
power
=
popcount
=
_math_intrin
#pylint: disable=invalid-name
def
_min_max
(
func_id
,
args
):
_internal_assert
(
len
(
args
)
==
2
,
"Max/Min function should have 2 elements"
)
return
getattr
(
_make
,
func_id
.
title
())(
args
[
0
],
args
[
1
])
min
=
max
=
_min_max
#pylint: disable=invalid-name
def
_allocate_tensor
(
func_id
,
args
):
"""Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details."""
n
=
len
(
args
)
_internal_assert
(
isinstance
(
_api
.
convert
(
args
[
0
]),
Array
),
\
"allocate's first argument should be a tuple of shape!"
)
shape
=
args
[
0
]
for
i
in
shape
:
_internal_assert
(
isinstance
(
i
,
_expr
.
Expr
),
"The shape should be an expression"
)
if
n
>
1
:
_internal_assert
(
isinstance
(
args
[
1
],
str
),
"The data type should be an str"
)
_internal_assert
(
args
[
1
]
.
startswith
(
'int'
)
or
args
[
1
]
.
startswith
(
'float'
),
\
"The data type should be either int or float!"
)
dtype
=
args
[
1
]
else
:
dtype
=
'float32'
if
n
>
2
:
_internal_assert
(
isinstance
(
args
[
2
],
str
),
\
"The data scope should be an string"
)
_internal_assert
(
func_id
!=
'output_tensor'
,
"Output tensor cannot specify scope"
)
scope
=
args
[
2
]
else
:
scope
=
'global'
if
func_id
!=
'output_tensor'
else
'output'
return
(
shape
,
dtype
,
scope
)
output_tensor
=
allocate
=
_allocate_tensor
#pylint: disable=invalid-name
python/tvm/hybrid/intrin.py
View file @
838e7181
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
"""Intrinsics of TVM-Python Hybrid Script for Python
emulation
runtime"""
import
numpy
from
..stmt
import
For
class
_range
(
object
):
"""Base class of the loop ranges in hybrid script"""
...
...
@@ -102,15 +101,3 @@ HYBRID_GLOBALS = {
'sigmoid'
:
sigmoid
,
'popcount'
:
popcount
}
LOOP_INTRIN
=
{
'range'
:
For
.
Serial
,
'unroll'
:
For
.
Unrolled
,
'parallel'
:
For
.
Parallel
,
'vectorize'
:
For
.
Vectorized
,
'bind'
:
None
}
MATH_INTRIN
=
[
'sqrt'
,
'log'
,
'exp'
,
'tanh'
,
'sigmoid'
,
'power'
,
'popcount'
]
python/tvm/hybrid/parser.py
View file @
838e7181
...
...
@@ -4,24 +4,24 @@ import ast
import
operator
import
logging
import
sys
from
.util
import
make_nop
,
halide_imm_types
,
is_docstring
,
_internal_assert
from
.intrin
import
LOOP_INTRIN
,
MATH_INTRIN
from
.util
import
_internal_assert
from
.
import
calls
from
.
import
util
from
.var_decl
import
determine_variable_usage
from
..api
import
thread_axis
from
..api
import
all
as
_all
from
..api
import
any
as
_any
from
..tensor
import
Tensor
,
Operation
from
..
import
expr
as
_expr
from
..
import
make
as
_make
from
..
import
intrin
from
..
import
api
as
_api
from
..
import
ir_pass
as
_ir_pass
def
list_to_block
(
visit
,
lst
):
"""Convert a list of Python IR nodes to HalideIR Block"""
lst
=
[
visit
(
stmt
)
for
stmt
in
lst
if
not
is_docstring
(
stmt
)]
lst
=
[
stmt
for
stmt
in
lst
if
not
_ir_pass
.
Equal
(
stmt
,
make_nop
())]
lst
=
[
visit
(
stmt
)
for
stmt
in
lst
if
not
util
.
is_docstring
(
stmt
)]
lst
=
[
stmt
for
stmt
in
lst
if
not
_ir_pass
.
Equal
(
stmt
,
util
.
make_nop
())]
if
not
lst
:
return
make_nop
()
return
util
.
make_nop
()
if
len
(
lst
)
==
1
:
return
lst
[
0
]
body
=
lst
[
0
]
...
...
@@ -62,7 +62,7 @@ class HybridParser(ast.NodeVisitor):
}
def
__init__
(
self
,
args
,
usage
,
func_name
=
None
):
def
__init__
(
self
,
args
,
usage
,
symbols
,
func_name
=
None
):
"""
Parameters
----------
...
...
@@ -81,32 +81,49 @@ class HybridParser(ast.NodeVisitor):
self
.
args
=
list
(
args
)
self
.
usage
=
usage
.
copy
()
self
.
_args
=
{}
# Dict maps arg name to actual arg instance (either a var or a buffer)
self
.
alloc_buffers
=
{}
# Buffers formed by allocate instructions
self
.
alloc_buffers
=
{}
# Buffers formed by
explicit
allocate instructions
self
.
loops_above
=
{}
# State variable that indicates loop levels above the current node
self
.
var
_consts
=
{}
# Variables that are determined as readonly in previous stage
self
.
var
iables
=
{}
# The status of defined variables
self
.
func_name
=
func_name
# The name of the function to be lowered
self
.
outputs
=
[]
# Output tensors' name
self
.
side_effect
=
set
()
# Tensors with side effects
self
.
parsed_body
=
None
# The parsed HalideIR body
self
.
returned
=
False
self
.
returned
=
False
# If this function has a valid return
self
.
symbols
=
symbols
# The global context
def
wrap_up_realize
(
self
,
node
,
body
):
"""Wrap up all the variables which will no longer be used"""
pop_buf
=
[]
pop_var
=
[]
for
key
,
val
in
self
.
usage
.
items
():
if
key
in
self
.
var_consts
.
keys
():
continue
_
,
level
,
_
=
val
if
level
==
node
:
if
key
in
self
.
_args
.
keys
():
if
level
!=
node
:
continue
if
key
in
self
.
_args
.
keys
():
continue
if
key
in
self
.
alloc_buffers
.
keys
():
_buf
,
_scope
=
self
.
alloc_buffers
[
key
]
if
_scope
==
'output'
:
continue
else
:
_buf
,
_scope
=
self
.
alloc_buffers
[
key
]
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_dtype
=
_buf
.
dtype
_true
=
_api
.
convert
(
True
)
body
=
_make
.
Realize
(
_buf
.
op
,
0
,
_dtype
,
_domain
,
_true
,
body
)
body
=
_make
.
AttrStmt
(
_buf
.
op
,
'realize_scope'
,
_api
.
convert
(
_scope
),
body
)
pop_buf
.
append
(
key
)
else
:
_internal_assert
(
key
in
self
.
variables
.
keys
(),
"Key should be either in one of args, buffers, and vars"
)
if
not
isinstance
(
self
.
variables
[
key
],
tuple
):
continue
_buf
,
_scope
=
self
.
variables
[
key
]
pop_var
.
append
(
key
)
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_dtype
=
_buf
.
dtype
_true
=
_api
.
convert
(
True
)
body
=
_make
.
Realize
(
_buf
.
op
,
0
,
_dtype
,
_domain
,
_true
,
body
)
body
=
_make
.
AttrStmt
(
_buf
.
op
,
'realize_scope'
,
_api
.
convert
(
_scope
),
body
)
for
elem
in
pop_buf
:
self
.
alloc_buffers
.
pop
(
elem
)
for
elem
in
pop_var
:
self
.
variables
.
pop
(
elem
)
return
body
...
...
@@ -121,7 +138,6 @@ class HybridParser(ast.NodeVisitor):
return
self
.
alloc_buffers
[
s
][
0
]
#pylint: disable=invalid-name, missing-docstring
def
visit_Module
(
self
,
node
):
_internal_assert
(
len
(
node
.
body
)
==
1
,
\
...
...
@@ -133,13 +149,13 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
len
(
node
.
args
.
args
)
==
len
(
self
.
args
),
\
"The number of arguments passed to the
\
function should be the same as it is defined!"
)
if
self
.
func_name
is
None
:
self
.
func_name
=
node
.
name
for
idx
,
arg
in
enumerate
(
node
.
args
.
args
):
_attr
=
'id'
if
sys
.
version_info
[
0
]
<
3
else
'arg'
# To make py2 and 3 compatible
self
.
_args
[
getattr
(
arg
,
_attr
)]
=
self
.
args
[
idx
]
res
=
list_to_block
(
self
.
visit
,
node
.
body
)
res
=
self
.
wrap_up_realize
(
node
,
res
)
if
self
.
func_name
is
None
:
self
.
func_name
=
node
.
name
return
res
...
...
@@ -148,23 +164,22 @@ class HybridParser(ast.NodeVisitor):
def
visit_Name
(
self
,
node
):
_id
=
node
.
id
if
_id
in
self
.
_args
.
keys
()
and
isinstance
(
self
.
_args
[
_id
],
(
_expr
.
Var
,
_expr
.
ConstExpr
)):
return
self
.
_args
[
_id
]
elif
_id
in
self
.
loops_above
.
keys
():
return
self
.
loops_above
[
_id
]
_internal_assert
(
_id
not
in
self
.
_args
.
keys
(),
\
"This id
%
s should be handled in visit_Subscript!"
%
_id
)
_internal_assert
(
_id
in
self
.
usage
.
keys
(),
\
"This id
%
s is expected to be a defined variable!"
%
_id
)
# Buffer
if
_id
in
self
.
alloc_buffers
.
keys
():
_buf
,
_
=
self
.
alloc_buffers
[
_id
]
return
_make
.
Call
(
_buf
.
dtype
,
_id
,
[
_api
.
const
(
0
)],
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
# Compilation time constant
_internal_assert
(
_id
in
self
.
var_consts
.
keys
(),
"This id
%
s is expected to a compilation time constant!"
%
_id
)
return
self
.
var_consts
[
_id
]
name
=
node
.
id
if
name
in
self
.
loops_above
.
keys
():
return
self
.
loops_above
[
name
]
elif
name
in
self
.
variables
.
keys
():
res
=
self
.
variables
[
name
]
if
isinstance
(
res
,
tuple
):
buf
=
res
[
0
]
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
[
_api
.
const
(
0
)],
\
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
return
buf
,
[
_api
.
const
(
0
)]
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
return
res
return
None
buf
=
self
.
_get_buffer_from_id
(
name
)
return
buf
def
visit_Num
(
self
,
node
):
...
...
@@ -172,18 +187,36 @@ class HybridParser(ast.NodeVisitor):
def
visit_AugAssign
(
self
,
node
):
lhs
=
self
.
visit
(
node
.
target
)
buf
=
self
.
visit
(
node
.
target
)
rhs
=
self
.
visit
(
node
.
value
)
rhs
=
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
lhs
,
rhs
)
_internal_assert
(
isinstance
(
lhs
,
_expr
.
Call
),
\
"The LHS of an AugAssign is supposed to be a call!"
)
return
_make
.
Provide
(
lhs
.
func
,
0
,
rhs
,
lhs
.
args
)
if
isinstance
(
buf
,
tuple
):
_internal_assert
(
len
(
buf
)
==
2
,
"LHS is supposed to be (buf, args)!"
)
buf
,
args
=
buf
else
:
args
=
[
_api
.
const
(
0
)]
_internal_assert
(
isinstance
(
buf
,
Tensor
),
"LHS is supposed to be Tensor!"
)
read
=
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
args
,
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
value
=
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
read
,
rhs
)
return
_make
.
Provide
(
buf
.
op
,
0
,
value
,
args
)
def
visit_Assign
(
self
,
node
):
rhs
=
self
.
visit
(
node
.
value
)
if
isinstance
(
rhs
,
Operation
):
rmap
=
{}
_internal_assert
(
len
(
node
.
targets
)
==
rhs
.
num_outputs
,
\
"Unable to detuple the outs to targets"
)
for
i
in
range
(
rhs
.
num_outputs
):
_internal_assert
(
isinstance
(
node
.
targets
[
i
],
ast
.
Name
),
"You should bind a pure name to the tensors"
)
self
.
alloc_buffers
[
node
.
targets
[
i
]
.
id
]
=
(
rhs
.
output
(
i
),
'global'
)
rmap
[
rhs
.
outputs
[
i
]
.
op
]
=
rhs
.
output
(
i
)
return
util
.
replace_io
(
rhs
.
body
,
rmap
)
_internal_assert
(
len
(
node
.
targets
)
==
1
,
"So far only one-valued assignment is supported!"
)
lhs
=
node
.
targets
[
0
]
rhs
=
self
.
visit
(
node
.
value
)
if
isinstance
(
rhs
,
_expr
.
Expr
):
rhs
=
_ir_pass
.
Simplify
(
rhs
)
if
isinstance
(
lhs
,
ast
.
Name
):
...
...
@@ -194,65 +227,63 @@ class HybridParser(ast.NodeVisitor):
"Loop variable cannot be overwritten!"
)
decl
,
_
,
rw
=
self
.
usage
[
lhs
]
if
decl
==
lhs_
:
_internal_assert
(
lhs
not
in
self
.
var_consts
.
keys
(),
\
"A constant cannot be overwritten!"
)
_internal_assert
(
lhs
not
in
self
.
alloc_buffers
.
keys
(),
\
_internal_assert
(
lhs
not
in
self
.
variables
.
keys
()
and
lhs
not
in
self
.
alloc_buffers
.
keys
(),
\
"This value should not be defined before this point!"
)
if
isinstance
(
rhs
,
tuple
):
shape
,
dtype
,
scope
=
rhs
ph
=
_api
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
lhs
)
if
scope
!=
'output'
:
self
.
alloc_buffers
[
lhs
]
=
(
ph
,
scope
)
else
:
self
.
_args
[
lhs
]
=
ph
self
.
alloc_buffers
[
lhs
]
=
(
ph
,
scope
)
if
scope
==
'output'
:
self
.
outputs
.
append
(
lhs
)
return
make_nop
()
if
isinstance
(
rhs
,
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
self
.
var
_const
s
[
lhs
]
=
rhs
return
util
.
make_nop
()
if
isinstance
(
rhs
,
util
.
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
self
.
var
iable
s
[
lhs
]
=
rhs
else
:
ph
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
self
.
alloc_buffers
[
lhs
]
=
(
ph
,
'global'
)
if
lhs
in
self
.
var_consts
.
keys
():
return
make_nop
()
_internal_assert
(
lhs
in
self
.
alloc_buffers
.
keys
(),
\
"This variable should be defined before!"
)
tgt
,
_
=
self
.
alloc_buffers
[
lhs
]
return
_make
.
Provide
(
tgt
.
op
,
0
,
rhs
,
[
_api
.
const
(
0
,
dtype
=
rhs
.
dtype
)])
self
.
variables
[
lhs
]
=
(
ph
,
'global'
)
lhs
=
self
.
visit
(
lhs_
)
if
lhs
is
not
None
:
buf
,
args
=
lhs
return
_make
.
Provide
(
buf
.
op
,
0
,
rhs
,
args
)
return
util
.
make_nop
()
else
:
lhs
=
self
.
visit
(
lhs
)
_internal_assert
(
isinstance
(
lhs
,
_expr
.
Call
),
\
lhs
,
args
=
self
.
visit
(
lhs
)
_internal_assert
(
isinstance
(
lhs
,
Tensor
),
\
"An array access's LHS is expected to be a expr.Call!"
)
#TODO: support slice later
buf
=
self
.
_get_buffer_from_id
(
lhs
.
name
,
for_provide
=
True
)
return
_make
.
Provide
(
buf
.
op
,
0
,
rhs
,
lhs
.
args
)
res
=
_make
.
Provide
(
lhs
.
op
,
lhs
.
value_index
,
rhs
,
args
)
return
res
def
visit_Index
(
self
,
node
):
if
isinstance
(
node
.
value
,
ast
.
Tuple
):
return
[
self
.
visit
(
i
)
for
i
in
node
.
value
.
elts
]
return
self
.
visit
(
node
.
value
)
return
[
self
.
visit
(
node
.
value
)]
def
visit_Attribute
(
self
,
node
):
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Name
),
\
"For atrribute access, only both names are supported so far!"
)
buf
=
self
.
_get_buffer_from_id
(
node
.
value
.
id
)
return
getattr
(
buf
,
node
.
attr
)
def
visit_Subscript
(
self
,
node
):
args
=
self
.
visit
(
node
.
slice
)
if
isinstance
(
node
.
value
,
ast
.
Name
):
array
=
node
.
value
.
id
_buf
=
self
.
_get_buffer_from_id
(
array
)
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
_buf
.
value_index
)
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Attribute
),
\
"Only variable and attribute's subscript supported so far"
)
_internal_assert
(
isinstance
(
node
.
value
.
value
,
ast
.
Name
),
\
"The root of array access is expect to be a id!"
)
_internal_assert
(
node
.
value
.
attr
==
"shape"
,
\
"Attribute access so far only 'shape' is supported!"
)
buf
=
self
.
visit
(
node
.
value
)
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
args
,
\
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
return
buf
,
args
shape
=
self
.
visit
(
node
.
value
)
_internal_assert
(
len
(
args
)
==
1
,
"For 'shape' access the argument should be only one!"
)
args
=
args
[
0
]
#TODO: maybe support non-constant value later?
_internal_assert
(
isinstance
(
args
,
(
_expr
.
IntImm
,
_expr
.
UIntImm
)),
\
"So far only constant shape access supported!"
)
buf
=
self
.
_get_buffer_from_id
(
node
.
value
.
value
.
id
)
return
buf
.
shape
[
args
.
value
]
return
shape
[
args
.
value
]
def
visit_With
(
self
,
node
):
...
...
@@ -275,7 +306,7 @@ class HybridParser(ast.NodeVisitor):
if
node
.
orelse
:
else_body
=
list_to_block
(
self
.
visit
,
node
.
orelse
)
else
:
else_body
=
make_nop
()
else_body
=
util
.
make_nop
()
return
_make
.
IfThenElse
(
cond
,
if_body
,
else_body
)
...
...
@@ -305,13 +336,10 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
isinstance
(
node
.
op
,
ast
.
Not
),
\
"Unary is supposed to be not!"
)
return
operator
.
not_
(
self
.
visit
(
node
.
values
[
0
]))
elif
n
==
2
:
_internal_assert
(
isinstance
(
node
.
op
,
(
ast
.
And
,
ast
.
Or
)),
\
"Binary is supposed to be and/or!"
)
values
=
[
self
.
visit
(
i
)
for
i
in
node
.
values
]
return
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
*
values
)
else
:
raise
ValueError
(
"This Bool Op is not supported yet!"
)
_internal_assert
(
isinstance
(
node
.
op
,
(
ast
.
And
,
ast
.
Or
)),
\
"Binary is supposed to be and/or!"
)
values
=
[
self
.
visit
(
i
)
for
i
in
node
.
values
]
return
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
*
values
)
def
visit_UnaryOp
(
self
,
node
):
...
...
@@ -329,67 +357,17 @@ class HybridParser(ast.NodeVisitor):
# Yet, no function pointer supported
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
\
"Only id-function function call is supported so far!"
)
func_id
=
node
.
func
.
id
n
=
len
(
node
.
args
)
if
func_id
in
LOOP_INTRIN
.
keys
()
and
func_id
!=
'bind'
:
if
n
==
1
:
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
self
.
visit
(
node
.
args
[
0
])
else
:
_internal_assert
(
n
==
2
,
"A loop intrinsic should only have 1 or 2 arguments!"
)
low
,
ext
=
self
.
visit
(
node
.
args
[
0
]),
self
.
visit
(
node
.
args
[
1
])
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
ext
=
ext
-
low
for_type
=
LOOP_INTRIN
[
func_id
]
iter_var
=
None
return
iter_var
,
low
,
ext
,
for_type
elif
func_id
==
'bind'
:
_internal_assert
(
n
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Str
),
\
"A loop bind's first argument should be a string!"
)
_vn
=
node
.
args
[
0
]
.
s
iter_var
=
thread_axis
(
node
.
args
[
0
]
.
s
)
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
self
.
visit
(
node
.
args
[
1
])
for_type
=
None
return
iter_var
,
low
,
ext
,
for_type
elif
func_id
in
MATH_INTRIN
:
return
getattr
(
intrin
,
func_id
)(
*
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
elif
func_id
in
[
'allocate'
,
'output_tensor'
]:
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Tuple
),
\
"allocate's first argument should be a tuple of shape!"
)
shape
=
tuple
(
self
.
visit
(
i
)
for
i
in
node
.
args
[
0
]
.
elts
)
if
func_id
==
'output_tensor'
:
_internal_assert
(
not
self
.
loops_above
,
\
"Are you sure to allocate a output buffer multiple times?"
)
for
i
in
shape
:
_internal_assert
(
isinstance
(
i
,
_expr
.
Expr
),
"The shape should be an expression"
)
if
n
>
1
:
if
isinstance
(
node
.
args
[
1
],
ast
.
Str
):
dtype
=
node
.
args
[
1
]
.
s
else
:
_internal_assert
(
isinstance
(
node
.
args
[
1
],
ast
.
Attribute
),
\
"Unable to evaluate to get data type"
)
to_eval
=
node
.
args
[
1
]
_internal_assert
(
isinstance
(
to_eval
.
value
,
ast
.
Name
),
\
"Unable to evaluate the attribute to get data type"
)
_internal_assert
(
to_eval
.
attr
==
'dtype'
,
\
"Only dtype attribute is supported so far"
)
dtype
=
self
.
_get_buffer_from_id
(
to_eval
.
value
.
id
)
.
dtype
else
:
dtype
=
'float32'
if
n
>
2
:
_internal_assert
(
isinstance
(
node
.
args
[
2
],
ast
.
Str
),
\
"The data scope should be an string"
)
_internal_assert
(
func_id
!=
'output_tensor'
,
"Output tensor cannot specify scope"
)
scope
=
node
.
args
[
2
]
.
s
else
:
scope
=
'global'
if
func_id
!=
'output_tensor'
else
'output'
return
(
shape
,
dtype
,
scope
)
elif
func_id
==
'max'
or
func_id
==
'min'
:
_internal_assert
(
n
==
2
,
"Max/Min function should have 2 elements"
)
a
,
b
=
self
.
visit
(
node
.
args
[
0
]),
self
.
visit
(
node
.
args
[
1
])
return
getattr
(
_make
,
func_id
.
title
())(
a
,
b
)
else
:
raise
ValueError
(
"Function call not supported yet!"
)
args
=
[
self
.
visit
(
i
)
for
i
in
node
.
args
]
try
:
return
getattr
(
calls
,
func_id
)(
func_id
,
args
)
except
AttributeError
:
_internal_assert
(
func_id
in
self
.
symbols
.
keys
(),
\
"The function called is not in the context either!"
)
outs
=
self
.
symbols
[
func_id
](
*
args
)
op
=
outs
.
op
if
isinstance
(
outs
,
Tensor
)
else
outs
[
0
]
.
op
return
op
def
visit_For
(
self
,
node
):
...
...
@@ -400,7 +378,7 @@ class HybridParser(ast.NodeVisitor):
if
iter_var
is
None
:
_internal_assert
(
for_type
is
not
None
,
"The loop bind function parse error!"
)
offset
=
iter_var
=
_api
.
var
(
_name
)
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
)):
offset
=
iter_var
+
low
self
.
loops_above
[
_name
]
=
offset
else
:
...
...
@@ -411,7 +389,7 @@ class HybridParser(ast.NodeVisitor):
if
for_type
is
None
:
res
=
_make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
_body
)
else
:
res
=
_make
.
For
(
iter_var
,
_api
.
const
(
0
,
dtype
=
'int32'
),
ext
,
for_type
,
0
,
_body
)
res
=
_make
.
For
(
iter_var
,
_api
.
const
(
0
),
ext
,
for_type
,
0
,
_body
)
self
.
loops_above
.
pop
(
_name
)
return
res
...
...
@@ -428,14 +406,22 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
isinstance
(
i
,
ast
.
Name
),
"What do you return?"
)
ids
.
append
(
i
.
id
)
_internal_assert
(
len
(
set
(
ids
))
==
len
(
ids
),
"Duplicated tensors in the return tuples"
)
if
len
(
ids
)
!=
len
(
self
.
outputs
):
if
len
(
ids
)
<
len
(
self
.
outputs
):
logging
.
log
(
logging
.
CRITICAL
,
'[Warning] Not all the output buffers returned!'
)
self
.
outputs
=
[
self
.
_args
[
i
]
for
i
in
ids
]
self
.
outputs
=
[
self
.
alloc_buffers
[
i
][
0
]
for
i
in
ids
]
self
.
returned
=
True
return
make_nop
()
return
util
.
make_nop
()
def
visit_Tuple
(
self
,
node
):
return
tuple
(
self
.
visit
(
i
)
for
i
in
node
.
elts
)
def
parse_python
(
src
,
args
):
def
visit_Str
(
self
,
node
):
return
node
.
s
def
parse_python
(
src
,
symbols
,
args
):
"""The helper function of calling the AST visitor
Parameters
...
...
@@ -443,6 +429,9 @@ def parse_python(src, args):
src : str
The source code of the function to be parsed.
src : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
...
...
@@ -454,8 +443,8 @@ def parse_python(src, args):
The result Halide IR and the parser class instance.
"""
root
=
ast
.
parse
(
src
)
var_usage
=
determine_variable_usage
(
root
,
args
)
parser
=
HybridParser
(
args
,
var_usage
)
var_usage
=
determine_variable_usage
(
root
,
args
,
symbols
)
parser
=
HybridParser
(
args
,
var_usage
,
symbols
)
parser
.
parsed_body
=
parser
.
visit
(
root
)
_internal_assert
(
parser
.
returned
,
'No valid return found in the function body!'
)
return
parser
python/tvm/hybrid/util.py
View file @
838e7181
...
...
@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types
from
..
import
api
as
_api
from
..
import
make
as
_make
from
..
import
expr
as
_expr
from
..
import
stmt
as
_stmt
from
..tensor
import
Tensor
...
...
@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
_globals
.
pop
(
elem
)
for
k
,
v
in
intersect
:
_globals
[
k
]
=
v
def
replace_io
(
body
,
rmap
):
"""Replacing tensors usage according to the dict given"""
from
..
import
ir_pass
def
replace
(
op
):
if
isinstance
(
op
,
_stmt
.
Provide
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
return
_make
.
Provide
(
buf
.
op
,
op
.
value_index
,
op
.
value
,
op
.
args
)
elif
isinstance
(
op
,
_expr
.
Call
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
op
.
args
,
\
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
return
None
return
ir_pass
.
IRTransform
(
body
,
None
,
replace
,
[
'Provide'
,
'Call'
])
python/tvm/hybrid/var_decl.py
View file @
838e7181
...
...
@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
,
symbols
):
self
.
status
=
{}
self
.
scope_level
=
[]
self
.
_args
=
{}
self
.
args
=
args
self
.
aug_assign_
=
False
self
.
symbols
=
symbols
def
visit_FunctionDef
(
self
,
node
):
...
...
@@ -43,8 +44,10 @@ class PyVariableUsage(ast.NodeVisitor):
#No function pointer supported so far
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
"Function call should be an id"
)
func_id
=
node
.
func
.
id
_internal_assert
(
func_id
in
list
(
HYBRID_GLOBALS
.
keys
())
+
[
'range'
,
'max'
,
'min'
],
\
"Function call id not in intrinsics' list"
)
_internal_assert
(
func_id
in
list
(
HYBRID_GLOBALS
.
keys
())
+
\
[
'range'
,
'max'
,
'min'
]
+
\
list
(
self
.
symbols
.
keys
()),
\
"Function call id not in intrinsics' list"
)
for
elem
in
node
.
args
:
self
.
visit
(
elem
)
...
...
@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor):
else
:
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
usage
.
add
(
type
(
node
.
ctx
))
_internal_assert
(
loop
in
self
.
scope_level
,
"
%
s is used out of the scope it is defined!"
%
node
.
id
)
self
.
status
[
node
.
id
]
=
(
decl
,
loop
,
usage
)
def
determine_variable_usage
(
root
,
args
):
def
determine_variable_usage
(
root
,
args
,
symbols
):
"""The helper function for calling the dedicated visitor."""
visitor
=
PyVariableUsage
(
args
)
visitor
=
PyVariableUsage
(
args
,
symbols
)
visitor
.
visit
(
root
)
return
visitor
.
status
tests/python/unittest/test_hybrid_script.py
View file @
838e7181
...
...
@@ -270,7 +270,7 @@ def test_bind():
return
@script
def
vec_add
(
a
,
b
):
c
=
output_tensor
((
1000
,
),
dtype
=
'float32'
)
c
=
output_tensor
((
1000
,
),
'float32'
)
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
c
[
tx
]
=
a
[
tx
]
+
b
[
tx
]
return
c
...
...
@@ -506,7 +506,37 @@ def test_value_index():
module
(
tvm
.
ndarray
.
array
(
np_a
),
res
)
tvm
.
testing
.
assert_allclose
(
res
.
asnumpy
(),
ref
)
def
test_func_call
():
@tvm.hybrid.script
def
foo
(
a
,
b
):
for
i
in
range
(
10
):
a
[
i
]
=
i
+
1.0
for
i
in
range
(
10
):
b
[
i
]
=
i
+
1.0
c
=
outer_product
(
10
,
10
,
a
,
b
)
d
=
output_tensor
(
c
.
shape
,
c
.
dtype
)
for
i
in
range
(
10
):
for
j
in
range
(
10
):
d
[
i
,
j
]
=
c
[
i
,
j
]
+
i
*
j
return
d
a
=
tvm
.
placeholder
((
10
,
),
name
=
'a'
)
b
=
tvm
.
placeholder
((
10
,
),
name
=
'b'
)
run_and_check
(
foo
,
[
a
,
b
])
def
test_bool
():
@tvm.hybrid.script
def
foo
(
a
):
b
=
output_tensor
(
a
.
shape
,
a
.
dtype
)
b
[
0
]
=
1.2
for
i
in
range
(
1
,
a
.
shape
[
0
]
-
1
):
if
a
[
i
]
*
a
[
i
-
1
]
<
a
[
i
]
or
a
[
i
]
*
a
[
i
-
1
]
<
a
[
i
-
1
]
or
i
*
a
[
i
]
==
a
[
i
]:
b
[
i
]
=
a
[
i
]
else
:
b
[
i
]
=
0.0
return
b
a
=
tvm
.
placeholder
((
10
,
),
name
=
'a'
)
run_and_check
(
foo
,
[
a
])
if
__name__
==
"__main__"
:
test_outer_product
()
...
...
@@ -521,7 +551,7 @@ if __name__ == "__main__":
test_downstream
()
test_const_param
()
test_value_index
()
test_func_call
()
test_bool
()
# TODO:
# test_inplace()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment