Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
838e7181
Commit
838e7181
authored
Dec 19, 2018
by
Jian Weng
Committed by
Tianqi Chen
Dec 19, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Hybrid Script] Inter-function call supported! (#2287)
parent
001ab525
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
290 additions
and
171 deletions
+290
-171
python/tvm/hybrid/api.py
+1
-3
python/tvm/hybrid/calls.py
+92
-0
python/tvm/hybrid/intrin.py
+1
-14
python/tvm/hybrid/parser.py
+136
-147
python/tvm/hybrid/util.py
+18
-0
python/tvm/hybrid/var_decl.py
+9
-4
tests/python/unittest/test_hybrid_script.py
+33
-3
No files found.
python/tvm/hybrid/api.py
View file @
838e7181
...
@@ -24,17 +24,15 @@ def script(pyfunc):
...
@@ -24,17 +24,15 @@ def script(pyfunc):
from
.util
import
_enter_hybrid_runtime
,
_restore_runtime
,
_is_tvm_arg_types
from
.util
import
_enter_hybrid_runtime
,
_restore_runtime
,
_is_tvm_arg_types
if
_is_tvm_arg_types
(
args
):
if
_is_tvm_arg_types
(
args
):
src
=
_pruned_source
(
func
)
src
=
_pruned_source
(
func
)
parser
=
parse_python
(
src
,
args
)
parser
=
parse_python
(
src
,
func
.
__globals__
,
args
)
input_tensors
=
[]
input_tensors
=
[]
for
i
in
args
:
for
i
in
args
:
if
isinstance
(
i
,
Tensor
):
if
isinstance
(
i
,
Tensor
):
input_tensors
.
append
(
i
)
input_tensors
.
append
(
i
)
op
=
_tvm_internal
.
_HybridOp
(
parser
.
func_name
,
"HybridOp"
,
None
,
input_tensors
,
op
=
_tvm_internal
.
_HybridOp
(
parser
.
func_name
,
"HybridOp"
,
None
,
input_tensors
,
parser
.
outputs
,
parser
.
parsed_body
)
parser
.
outputs
,
parser
.
parsed_body
)
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
parser
.
outputs
))]
res
=
[
op
.
output
(
i
)
for
i
in
range
(
len
(
parser
.
outputs
))]
return
res
[
0
]
if
len
(
res
)
==
1
else
res
return
res
[
0
]
if
len
(
res
)
==
1
else
res
intersect
=
_enter_hybrid_runtime
(
func
)
intersect
=
_enter_hybrid_runtime
(
func
)
...
...
python/tvm/hybrid/calls.py
0 → 100644
View file @
838e7181
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from
..
import
api
as
_api
from
..
import
expr
as
_expr
from
..
import
make
as
_make
from
..container
import
Array
from
..
import
ir_pass
from
..stmt
import
For
from
.util
import
_internal_assert
#pylint: disable=redefined-builtin
LOOP_INTRIN
=
{
'range'
:
For
.
Serial
,
'unroll'
:
For
.
Unrolled
,
'parallel'
:
For
.
Parallel
,
'vectorize'
:
For
.
Vectorized
,
}
def
_range
(
annotation
,
args
):
"""Handling TVM loop types"""
n
=
len
(
args
)
if
n
==
1
:
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
args
[
0
]
else
:
_internal_assert
(
n
==
2
,
"A loop intrinsic should only have 1 or 2 arguments!"
)
low
,
ext
=
args
[
0
],
args
[
1
]
if
not
ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
ext
=
ext
-
low
for_type
=
LOOP_INTRIN
[
annotation
]
iter_var
=
None
return
iter_var
,
low
,
ext
,
for_type
range
=
unroll
=
vectorize
=
parallel
=
_range
#pylint: disable=invalid-name
def
bind
(
func_id
,
args
):
"""Handling TVM thread binding"""
_internal_assert
(
func_id
==
"bind"
,
"This function cannot be directly invoked!"
)
_internal_assert
(
len
(
args
)
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
args
[
0
],
str
),
\
"A loop bind's first argument should be a string!"
)
iter_var
=
_api
.
thread_axis
(
args
[
0
])
low
,
ext
=
_api
.
const
(
0
),
args
[
1
]
for_type
=
None
return
iter_var
,
low
,
ext
,
for_type
def
_math_intrin
(
func_id
,
args
):
from
..
import
intrin
return
getattr
(
intrin
,
func_id
)(
*
args
)
sqrt
=
log
=
exp
=
tanh
=
sigmoid
=
power
=
popcount
=
_math_intrin
#pylint: disable=invalid-name
def
_min_max
(
func_id
,
args
):
_internal_assert
(
len
(
args
)
==
2
,
"Max/Min function should have 2 elements"
)
return
getattr
(
_make
,
func_id
.
title
())(
args
[
0
],
args
[
1
])
min
=
max
=
_min_max
#pylint: disable=invalid-name
def
_allocate_tensor
(
func_id
,
args
):
"""Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details."""
n
=
len
(
args
)
_internal_assert
(
isinstance
(
_api
.
convert
(
args
[
0
]),
Array
),
\
"allocate's first argument should be a tuple of shape!"
)
shape
=
args
[
0
]
for
i
in
shape
:
_internal_assert
(
isinstance
(
i
,
_expr
.
Expr
),
"The shape should be an expression"
)
if
n
>
1
:
_internal_assert
(
isinstance
(
args
[
1
],
str
),
"The data type should be an str"
)
_internal_assert
(
args
[
1
]
.
startswith
(
'int'
)
or
args
[
1
]
.
startswith
(
'float'
),
\
"The data type should be either int or float!"
)
dtype
=
args
[
1
]
else
:
dtype
=
'float32'
if
n
>
2
:
_internal_assert
(
isinstance
(
args
[
2
],
str
),
\
"The data scope should be an string"
)
_internal_assert
(
func_id
!=
'output_tensor'
,
"Output tensor cannot specify scope"
)
scope
=
args
[
2
]
else
:
scope
=
'global'
if
func_id
!=
'output_tensor'
else
'output'
return
(
shape
,
dtype
,
scope
)
output_tensor
=
allocate
=
_allocate_tensor
#pylint: disable=invalid-name
python/tvm/hybrid/intrin.py
View file @
838e7181
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
"""Intrinsics of TVM-Python Hybrid Script for Python
emulation
runtime"""
import
numpy
import
numpy
from
..stmt
import
For
class
_range
(
object
):
class
_range
(
object
):
"""Base class of the loop ranges in hybrid script"""
"""Base class of the loop ranges in hybrid script"""
...
@@ -102,15 +101,3 @@ HYBRID_GLOBALS = {
...
@@ -102,15 +101,3 @@ HYBRID_GLOBALS = {
'sigmoid'
:
sigmoid
,
'sigmoid'
:
sigmoid
,
'popcount'
:
popcount
'popcount'
:
popcount
}
}
LOOP_INTRIN
=
{
'range'
:
For
.
Serial
,
'unroll'
:
For
.
Unrolled
,
'parallel'
:
For
.
Parallel
,
'vectorize'
:
For
.
Vectorized
,
'bind'
:
None
}
MATH_INTRIN
=
[
'sqrt'
,
'log'
,
'exp'
,
'tanh'
,
'sigmoid'
,
'power'
,
'popcount'
]
python/tvm/hybrid/parser.py
View file @
838e7181
...
@@ -4,24 +4,24 @@ import ast
...
@@ -4,24 +4,24 @@ import ast
import
operator
import
operator
import
logging
import
logging
import
sys
import
sys
from
.util
import
make_nop
,
halide_imm_types
,
is_docstring
,
_internal_assert
from
.util
import
_internal_assert
from
.intrin
import
LOOP_INTRIN
,
MATH_INTRIN
from
.
import
calls
from
.
import
util
from
.var_decl
import
determine_variable_usage
from
.var_decl
import
determine_variable_usage
from
..api
import
thread_axis
from
..api
import
all
as
_all
from
..api
import
all
as
_all
from
..api
import
any
as
_any
from
..api
import
any
as
_any
from
..tensor
import
Tensor
,
Operation
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
make
as
_make
from
..
import
make
as
_make
from
..
import
intrin
from
..
import
api
as
_api
from
..
import
api
as
_api
from
..
import
ir_pass
as
_ir_pass
from
..
import
ir_pass
as
_ir_pass
def
list_to_block
(
visit
,
lst
):
def
list_to_block
(
visit
,
lst
):
"""Convert a list of Python IR nodes to HalideIR Block"""
"""Convert a list of Python IR nodes to HalideIR Block"""
lst
=
[
visit
(
stmt
)
for
stmt
in
lst
if
not
is_docstring
(
stmt
)]
lst
=
[
visit
(
stmt
)
for
stmt
in
lst
if
not
util
.
is_docstring
(
stmt
)]
lst
=
[
stmt
for
stmt
in
lst
if
not
_ir_pass
.
Equal
(
stmt
,
make_nop
())]
lst
=
[
stmt
for
stmt
in
lst
if
not
_ir_pass
.
Equal
(
stmt
,
util
.
make_nop
())]
if
not
lst
:
if
not
lst
:
return
make_nop
()
return
util
.
make_nop
()
if
len
(
lst
)
==
1
:
if
len
(
lst
)
==
1
:
return
lst
[
0
]
return
lst
[
0
]
body
=
lst
[
0
]
body
=
lst
[
0
]
...
@@ -62,7 +62,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -62,7 +62,7 @@ class HybridParser(ast.NodeVisitor):
}
}
def
__init__
(
self
,
args
,
usage
,
func_name
=
None
):
def
__init__
(
self
,
args
,
usage
,
symbols
,
func_name
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -81,32 +81,49 @@ class HybridParser(ast.NodeVisitor):
...
@@ -81,32 +81,49 @@ class HybridParser(ast.NodeVisitor):
self
.
args
=
list
(
args
)
self
.
args
=
list
(
args
)
self
.
usage
=
usage
.
copy
()
self
.
usage
=
usage
.
copy
()
self
.
_args
=
{}
# Dict maps arg name to actual arg instance (either a var or a buffer)
self
.
_args
=
{}
# Dict maps arg name to actual arg instance (either a var or a buffer)
self
.
alloc_buffers
=
{}
# Buffers formed by allocate instructions
self
.
alloc_buffers
=
{}
# Buffers formed by
explicit
allocate instructions
self
.
loops_above
=
{}
# State variable that indicates loop levels above the current node
self
.
loops_above
=
{}
# State variable that indicates loop levels above the current node
self
.
var
_consts
=
{}
# Variables that are determined as readonly in previous stage
self
.
var
iables
=
{}
# The status of defined variables
self
.
func_name
=
func_name
# The name of the function to be lowered
self
.
func_name
=
func_name
# The name of the function to be lowered
self
.
outputs
=
[]
# Output tensors' name
self
.
outputs
=
[]
# Output tensors' name
self
.
side_effect
=
set
()
# Tensors with side effects
self
.
side_effect
=
set
()
# Tensors with side effects
self
.
parsed_body
=
None
# The parsed HalideIR body
self
.
parsed_body
=
None
# The parsed HalideIR body
self
.
returned
=
False
self
.
returned
=
False
# If this function has a valid return
self
.
symbols
=
symbols
# The global context
def
wrap_up_realize
(
self
,
node
,
body
):
def
wrap_up_realize
(
self
,
node
,
body
):
"""Wrap up all the variables which will no longer be used"""
"""Wrap up all the variables which will no longer be used"""
pop_buf
=
[]
pop_var
=
[]
for
key
,
val
in
self
.
usage
.
items
():
for
key
,
val
in
self
.
usage
.
items
():
if
key
in
self
.
var_consts
.
keys
():
continue
_
,
level
,
_
=
val
_
,
level
,
_
=
val
if
level
==
node
:
if
level
!=
node
:
continue
if
key
in
self
.
_args
.
keys
():
if
key
in
self
.
_args
.
keys
():
continue
continue
else
:
if
key
in
self
.
alloc_buffers
.
keys
()
:
_buf
,
_scope
=
self
.
alloc_buffers
[
key
]
_buf
,
_scope
=
self
.
alloc_buffers
[
key
]
if
_scope
==
'output'
:
continue
pop_buf
.
append
(
key
)
else
:
_internal_assert
(
key
in
self
.
variables
.
keys
(),
"Key should be either in one of args, buffers, and vars"
)
if
not
isinstance
(
self
.
variables
[
key
],
tuple
):
continue
_buf
,
_scope
=
self
.
variables
[
key
]
pop_var
.
append
(
key
)
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_dtype
=
_buf
.
dtype
_dtype
=
_buf
.
dtype
_true
=
_api
.
convert
(
True
)
_true
=
_api
.
convert
(
True
)
body
=
_make
.
Realize
(
_buf
.
op
,
0
,
_dtype
,
_domain
,
_true
,
body
)
body
=
_make
.
Realize
(
_buf
.
op
,
0
,
_dtype
,
_domain
,
_true
,
body
)
body
=
_make
.
AttrStmt
(
_buf
.
op
,
'realize_scope'
,
_api
.
convert
(
_scope
),
body
)
body
=
_make
.
AttrStmt
(
_buf
.
op
,
'realize_scope'
,
_api
.
convert
(
_scope
),
body
)
for
elem
in
pop_buf
:
self
.
alloc_buffers
.
pop
(
elem
)
for
elem
in
pop_var
:
self
.
variables
.
pop
(
elem
)
return
body
return
body
...
@@ -121,7 +138,6 @@ class HybridParser(ast.NodeVisitor):
...
@@ -121,7 +138,6 @@ class HybridParser(ast.NodeVisitor):
return
self
.
alloc_buffers
[
s
][
0
]
return
self
.
alloc_buffers
[
s
][
0
]
#pylint: disable=invalid-name, missing-docstring
#pylint: disable=invalid-name, missing-docstring
def
visit_Module
(
self
,
node
):
def
visit_Module
(
self
,
node
):
_internal_assert
(
len
(
node
.
body
)
==
1
,
\
_internal_assert
(
len
(
node
.
body
)
==
1
,
\
...
@@ -133,13 +149,13 @@ class HybridParser(ast.NodeVisitor):
...
@@ -133,13 +149,13 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
len
(
node
.
args
.
args
)
==
len
(
self
.
args
),
\
_internal_assert
(
len
(
node
.
args
.
args
)
==
len
(
self
.
args
),
\
"The number of arguments passed to the
\
"The number of arguments passed to the
\
function should be the same as it is defined!"
)
function should be the same as it is defined!"
)
if
self
.
func_name
is
None
:
self
.
func_name
=
node
.
name
for
idx
,
arg
in
enumerate
(
node
.
args
.
args
):
for
idx
,
arg
in
enumerate
(
node
.
args
.
args
):
_attr
=
'id'
if
sys
.
version_info
[
0
]
<
3
else
'arg'
# To make py2 and 3 compatible
_attr
=
'id'
if
sys
.
version_info
[
0
]
<
3
else
'arg'
# To make py2 and 3 compatible
self
.
_args
[
getattr
(
arg
,
_attr
)]
=
self
.
args
[
idx
]
self
.
_args
[
getattr
(
arg
,
_attr
)]
=
self
.
args
[
idx
]
res
=
list_to_block
(
self
.
visit
,
node
.
body
)
res
=
list_to_block
(
self
.
visit
,
node
.
body
)
res
=
self
.
wrap_up_realize
(
node
,
res
)
res
=
self
.
wrap_up_realize
(
node
,
res
)
if
self
.
func_name
is
None
:
self
.
func_name
=
node
.
name
return
res
return
res
...
@@ -148,23 +164,22 @@ class HybridParser(ast.NodeVisitor):
...
@@ -148,23 +164,22 @@ class HybridParser(ast.NodeVisitor):
def
visit_Name
(
self
,
node
):
def
visit_Name
(
self
,
node
):
_id
=
node
.
id
name
=
node
.
id
if
_id
in
self
.
_args
.
keys
()
and
isinstance
(
self
.
_args
[
_id
],
(
_expr
.
Var
,
_expr
.
ConstExpr
)):
if
name
in
self
.
loops_above
.
keys
():
return
self
.
_args
[
_id
]
return
self
.
loops_above
[
name
]
elif
_id
in
self
.
loops_above
.
keys
():
elif
name
in
self
.
variables
.
keys
():
return
self
.
loops_above
[
_id
]
res
=
self
.
variables
[
name
]
_internal_assert
(
_id
not
in
self
.
_args
.
keys
(),
\
if
isinstance
(
res
,
tuple
):
"This id
%
s should be handled in visit_Subscript!"
%
_id
)
buf
=
res
[
0
]
_internal_assert
(
_id
in
self
.
usage
.
keys
(),
\
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
"This id
%
s is expected to be a defined variable!"
%
_id
)
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
[
_api
.
const
(
0
)],
\
# Buffer
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
if
_id
in
self
.
alloc_buffers
.
keys
():
return
buf
,
[
_api
.
const
(
0
)]
_buf
,
_
=
self
.
alloc_buffers
[
_id
]
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
return
_make
.
Call
(
_buf
.
dtype
,
_id
,
[
_api
.
const
(
0
)],
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
return
res
# Compilation time constant
return
None
_internal_assert
(
_id
in
self
.
var_consts
.
keys
(),
buf
=
self
.
_get_buffer_from_id
(
name
)
"This id
%
s is expected to a compilation time constant!"
%
_id
)
return
buf
return
self
.
var_consts
[
_id
]
def
visit_Num
(
self
,
node
):
def
visit_Num
(
self
,
node
):
...
@@ -172,18 +187,36 @@ class HybridParser(ast.NodeVisitor):
...
@@ -172,18 +187,36 @@ class HybridParser(ast.NodeVisitor):
def
visit_AugAssign
(
self
,
node
):
def
visit_AugAssign
(
self
,
node
):
lhs
=
self
.
visit
(
node
.
target
)
buf
=
self
.
visit
(
node
.
target
)
rhs
=
self
.
visit
(
node
.
value
)
rhs
=
self
.
visit
(
node
.
value
)
rhs
=
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
lhs
,
rhs
)
if
isinstance
(
buf
,
tuple
):
_internal_assert
(
isinstance
(
lhs
,
_expr
.
Call
),
\
_internal_assert
(
len
(
buf
)
==
2
,
"LHS is supposed to be (buf, args)!"
)
"The LHS of an AugAssign is supposed to be a call!"
)
buf
,
args
=
buf
return
_make
.
Provide
(
lhs
.
func
,
0
,
rhs
,
lhs
.
args
)
else
:
args
=
[
_api
.
const
(
0
)]
_internal_assert
(
isinstance
(
buf
,
Tensor
),
"LHS is supposed to be Tensor!"
)
read
=
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
args
,
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
value
=
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
read
,
rhs
)
return
_make
.
Provide
(
buf
.
op
,
0
,
value
,
args
)
def
visit_Assign
(
self
,
node
):
def
visit_Assign
(
self
,
node
):
rhs
=
self
.
visit
(
node
.
value
)
if
isinstance
(
rhs
,
Operation
):
rmap
=
{}
_internal_assert
(
len
(
node
.
targets
)
==
rhs
.
num_outputs
,
\
"Unable to detuple the outs to targets"
)
for
i
in
range
(
rhs
.
num_outputs
):
_internal_assert
(
isinstance
(
node
.
targets
[
i
],
ast
.
Name
),
"You should bind a pure name to the tensors"
)
self
.
alloc_buffers
[
node
.
targets
[
i
]
.
id
]
=
(
rhs
.
output
(
i
),
'global'
)
rmap
[
rhs
.
outputs
[
i
]
.
op
]
=
rhs
.
output
(
i
)
return
util
.
replace_io
(
rhs
.
body
,
rmap
)
_internal_assert
(
len
(
node
.
targets
)
==
1
,
"So far only one-valued assignment is supported!"
)
_internal_assert
(
len
(
node
.
targets
)
==
1
,
"So far only one-valued assignment is supported!"
)
lhs
=
node
.
targets
[
0
]
lhs
=
node
.
targets
[
0
]
rhs
=
self
.
visit
(
node
.
value
)
if
isinstance
(
rhs
,
_expr
.
Expr
):
if
isinstance
(
rhs
,
_expr
.
Expr
):
rhs
=
_ir_pass
.
Simplify
(
rhs
)
rhs
=
_ir_pass
.
Simplify
(
rhs
)
if
isinstance
(
lhs
,
ast
.
Name
):
if
isinstance
(
lhs
,
ast
.
Name
):
...
@@ -194,65 +227,63 @@ class HybridParser(ast.NodeVisitor):
...
@@ -194,65 +227,63 @@ class HybridParser(ast.NodeVisitor):
"Loop variable cannot be overwritten!"
)
"Loop variable cannot be overwritten!"
)
decl
,
_
,
rw
=
self
.
usage
[
lhs
]
decl
,
_
,
rw
=
self
.
usage
[
lhs
]
if
decl
==
lhs_
:
if
decl
==
lhs_
:
_internal_assert
(
lhs
not
in
self
.
var_consts
.
keys
(),
\
_internal_assert
(
lhs
not
in
self
.
variables
.
keys
()
and
"A constant cannot be overwritten!"
)
lhs
not
in
self
.
alloc_buffers
.
keys
(),
\
_internal_assert
(
lhs
not
in
self
.
alloc_buffers
.
keys
(),
\
"This value should not be defined before this point!"
)
"This value should not be defined before this point!"
)
if
isinstance
(
rhs
,
tuple
):
if
isinstance
(
rhs
,
tuple
):
shape
,
dtype
,
scope
=
rhs
shape
,
dtype
,
scope
=
rhs
ph
=
_api
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
lhs
)
ph
=
_api
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
lhs
)
if
scope
!=
'output'
:
self
.
alloc_buffers
[
lhs
]
=
(
ph
,
scope
)
self
.
alloc_buffers
[
lhs
]
=
(
ph
,
scope
)
else
:
if
scope
==
'output'
:
self
.
_args
[
lhs
]
=
ph
self
.
outputs
.
append
(
lhs
)
self
.
outputs
.
append
(
lhs
)
return
make_nop
()
return
util
.
make_nop
()
if
isinstance
(
rhs
,
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
if
isinstance
(
rhs
,
util
.
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
self
.
var
_const
s
[
lhs
]
=
rhs
self
.
var
iable
s
[
lhs
]
=
rhs
else
:
else
:
ph
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
ph
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
self
.
alloc_buffers
[
lhs
]
=
(
ph
,
'global'
)
self
.
variables
[
lhs
]
=
(
ph
,
'global'
)
if
lhs
in
self
.
var_consts
.
keys
():
lhs
=
self
.
visit
(
lhs_
)
return
make_nop
()
if
lhs
is
not
None
:
_internal_assert
(
lhs
in
self
.
alloc_buffers
.
keys
(),
\
buf
,
args
=
lhs
"This variable should be defined before!"
)
return
_make
.
Provide
(
buf
.
op
,
0
,
rhs
,
args
)
tgt
,
_
=
self
.
alloc_buffers
[
lhs
]
return
util
.
make_nop
()
return
_make
.
Provide
(
tgt
.
op
,
0
,
rhs
,
[
_api
.
const
(
0
,
dtype
=
rhs
.
dtype
)])
else
:
else
:
lhs
=
self
.
visit
(
lhs
)
lhs
,
args
=
self
.
visit
(
lhs
)
_internal_assert
(
isinstance
(
lhs
,
_expr
.
Call
),
\
_internal_assert
(
isinstance
(
lhs
,
Tensor
),
\
"An array access's LHS is expected to be a expr.Call!"
)
"An array access's LHS is expected to be a expr.Call!"
)
#TODO: support slice later
res
=
_make
.
Provide
(
lhs
.
op
,
lhs
.
value_index
,
rhs
,
args
)
buf
=
self
.
_get_buffer_from_id
(
lhs
.
name
,
for_provide
=
True
)
return
res
return
_make
.
Provide
(
buf
.
op
,
0
,
rhs
,
lhs
.
args
)
def
visit_Index
(
self
,
node
):
def
visit_Index
(
self
,
node
):
if
isinstance
(
node
.
value
,
ast
.
Tuple
):
if
isinstance
(
node
.
value
,
ast
.
Tuple
):
return
[
self
.
visit
(
i
)
for
i
in
node
.
value
.
elts
]
return
self
.
visit
(
node
.
value
)
return
[
self
.
visit
(
node
.
value
)]
return
[
self
.
visit
(
node
.
value
)]
def
visit_Attribute
(
self
,
node
):
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Name
),
\
"For atrribute access, only both names are supported so far!"
)
buf
=
self
.
_get_buffer_from_id
(
node
.
value
.
id
)
return
getattr
(
buf
,
node
.
attr
)
def
visit_Subscript
(
self
,
node
):
def
visit_Subscript
(
self
,
node
):
args
=
self
.
visit
(
node
.
slice
)
args
=
self
.
visit
(
node
.
slice
)
if
isinstance
(
node
.
value
,
ast
.
Name
):
if
isinstance
(
node
.
value
,
ast
.
Name
):
array
=
node
.
value
.
id
buf
=
self
.
visit
(
node
.
value
)
_buf
=
self
.
_get_buffer_from_id
(
array
)
if
isinstance
(
node
.
ctx
,
ast
.
Load
):
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
_buf
.
value_index
)
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
args
,
\
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Attribute
),
\
return
buf
,
args
"Only variable and attribute's subscript supported so far"
)
_internal_assert
(
isinstance
(
node
.
value
.
value
,
ast
.
Name
),
\
shape
=
self
.
visit
(
node
.
value
)
"The root of array access is expect to be a id!"
)
_internal_assert
(
node
.
value
.
attr
==
"shape"
,
\
"Attribute access so far only 'shape' is supported!"
)
_internal_assert
(
len
(
args
)
==
1
,
"For 'shape' access the argument should be only one!"
)
_internal_assert
(
len
(
args
)
==
1
,
"For 'shape' access the argument should be only one!"
)
args
=
args
[
0
]
args
=
args
[
0
]
#TODO: maybe support non-constant value later?
#TODO: maybe support non-constant value later?
_internal_assert
(
isinstance
(
args
,
(
_expr
.
IntImm
,
_expr
.
UIntImm
)),
\
_internal_assert
(
isinstance
(
args
,
(
_expr
.
IntImm
,
_expr
.
UIntImm
)),
\
"So far only constant shape access supported!"
)
"So far only constant shape access supported!"
)
buf
=
self
.
_get_buffer_from_id
(
node
.
value
.
value
.
id
)
return
shape
[
args
.
value
]
return
buf
.
shape
[
args
.
value
]
def
visit_With
(
self
,
node
):
def
visit_With
(
self
,
node
):
...
@@ -275,7 +306,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -275,7 +306,7 @@ class HybridParser(ast.NodeVisitor):
if
node
.
orelse
:
if
node
.
orelse
:
else_body
=
list_to_block
(
self
.
visit
,
node
.
orelse
)
else_body
=
list_to_block
(
self
.
visit
,
node
.
orelse
)
else
:
else
:
else_body
=
make_nop
()
else_body
=
util
.
make_nop
()
return
_make
.
IfThenElse
(
cond
,
if_body
,
else_body
)
return
_make
.
IfThenElse
(
cond
,
if_body
,
else_body
)
...
@@ -305,13 +336,10 @@ class HybridParser(ast.NodeVisitor):
...
@@ -305,13 +336,10 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
isinstance
(
node
.
op
,
ast
.
Not
),
\
_internal_assert
(
isinstance
(
node
.
op
,
ast
.
Not
),
\
"Unary is supposed to be not!"
)
"Unary is supposed to be not!"
)
return
operator
.
not_
(
self
.
visit
(
node
.
values
[
0
]))
return
operator
.
not_
(
self
.
visit
(
node
.
values
[
0
]))
elif
n
==
2
:
_internal_assert
(
isinstance
(
node
.
op
,
(
ast
.
And
,
ast
.
Or
)),
\
_internal_assert
(
isinstance
(
node
.
op
,
(
ast
.
And
,
ast
.
Or
)),
\
"Binary is supposed to be and/or!"
)
"Binary is supposed to be and/or!"
)
values
=
[
self
.
visit
(
i
)
for
i
in
node
.
values
]
values
=
[
self
.
visit
(
i
)
for
i
in
node
.
values
]
return
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
*
values
)
return
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
*
values
)
else
:
raise
ValueError
(
"This Bool Op is not supported yet!"
)
def
visit_UnaryOp
(
self
,
node
):
def
visit_UnaryOp
(
self
,
node
):
...
@@ -329,67 +357,17 @@ class HybridParser(ast.NodeVisitor):
...
@@ -329,67 +357,17 @@ class HybridParser(ast.NodeVisitor):
# Yet, no function pointer supported
# Yet, no function pointer supported
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
\
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
\
"Only id-function function call is supported so far!"
)
"Only id-function function call is supported so far!"
)
func_id
=
node
.
func
.
id
func_id
=
node
.
func
.
id
n
=
len
(
node
.
args
)
args
=
[
self
.
visit
(
i
)
for
i
in
node
.
args
]
if
func_id
in
LOOP_INTRIN
.
keys
()
and
func_id
!=
'bind'
:
try
:
if
n
==
1
:
return
getattr
(
calls
,
func_id
)(
func_id
,
args
)
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
self
.
visit
(
node
.
args
[
0
])
except
AttributeError
:
else
:
_internal_assert
(
func_id
in
self
.
symbols
.
keys
(),
\
_internal_assert
(
n
==
2
,
"A loop intrinsic should only have 1 or 2 arguments!"
)
"The function called is not in the context either!"
)
low
,
ext
=
self
.
visit
(
node
.
args
[
0
]),
self
.
visit
(
node
.
args
[
1
])
outs
=
self
.
symbols
[
func_id
](
*
args
)
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
op
=
outs
.
op
if
isinstance
(
outs
,
Tensor
)
else
outs
[
0
]
.
op
ext
=
ext
-
low
return
op
for_type
=
LOOP_INTRIN
[
func_id
]
iter_var
=
None
return
iter_var
,
low
,
ext
,
for_type
elif
func_id
==
'bind'
:
_internal_assert
(
n
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Str
),
\
"A loop bind's first argument should be a string!"
)
_vn
=
node
.
args
[
0
]
.
s
iter_var
=
thread_axis
(
node
.
args
[
0
]
.
s
)
low
,
ext
=
_api
.
const
(
0
,
dtype
=
'int32'
),
self
.
visit
(
node
.
args
[
1
])
for_type
=
None
return
iter_var
,
low
,
ext
,
for_type
elif
func_id
in
MATH_INTRIN
:
return
getattr
(
intrin
,
func_id
)(
*
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
elif
func_id
in
[
'allocate'
,
'output_tensor'
]:
_internal_assert
(
isinstance
(
node
.
args
[
0
],
ast
.
Tuple
),
\
"allocate's first argument should be a tuple of shape!"
)
shape
=
tuple
(
self
.
visit
(
i
)
for
i
in
node
.
args
[
0
]
.
elts
)
if
func_id
==
'output_tensor'
:
_internal_assert
(
not
self
.
loops_above
,
\
"Are you sure to allocate a output buffer multiple times?"
)
for
i
in
shape
:
_internal_assert
(
isinstance
(
i
,
_expr
.
Expr
),
"The shape should be an expression"
)
if
n
>
1
:
if
isinstance
(
node
.
args
[
1
],
ast
.
Str
):
dtype
=
node
.
args
[
1
]
.
s
else
:
_internal_assert
(
isinstance
(
node
.
args
[
1
],
ast
.
Attribute
),
\
"Unable to evaluate to get data type"
)
to_eval
=
node
.
args
[
1
]
_internal_assert
(
isinstance
(
to_eval
.
value
,
ast
.
Name
),
\
"Unable to evaluate the attribute to get data type"
)
_internal_assert
(
to_eval
.
attr
==
'dtype'
,
\
"Only dtype attribute is supported so far"
)
dtype
=
self
.
_get_buffer_from_id
(
to_eval
.
value
.
id
)
.
dtype
else
:
dtype
=
'float32'
if
n
>
2
:
_internal_assert
(
isinstance
(
node
.
args
[
2
],
ast
.
Str
),
\
"The data scope should be an string"
)
_internal_assert
(
func_id
!=
'output_tensor'
,
"Output tensor cannot specify scope"
)
scope
=
node
.
args
[
2
]
.
s
else
:
scope
=
'global'
if
func_id
!=
'output_tensor'
else
'output'
return
(
shape
,
dtype
,
scope
)
elif
func_id
==
'max'
or
func_id
==
'min'
:
_internal_assert
(
n
==
2
,
"Max/Min function should have 2 elements"
)
a
,
b
=
self
.
visit
(
node
.
args
[
0
]),
self
.
visit
(
node
.
args
[
1
])
return
getattr
(
_make
,
func_id
.
title
())(
a
,
b
)
else
:
raise
ValueError
(
"Function call not supported yet!"
)
def
visit_For
(
self
,
node
):
def
visit_For
(
self
,
node
):
...
@@ -400,7 +378,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -400,7 +378,7 @@ class HybridParser(ast.NodeVisitor):
if
iter_var
is
None
:
if
iter_var
is
None
:
_internal_assert
(
for_type
is
not
None
,
"The loop bind function parse error!"
)
_internal_assert
(
for_type
is
not
None
,
"The loop bind function parse error!"
)
offset
=
iter_var
=
_api
.
var
(
_name
)
offset
=
iter_var
=
_api
.
var
(
_name
)
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
)):
offset
=
iter_var
+
low
offset
=
iter_var
+
low
self
.
loops_above
[
_name
]
=
offset
self
.
loops_above
[
_name
]
=
offset
else
:
else
:
...
@@ -411,7 +389,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -411,7 +389,7 @@ class HybridParser(ast.NodeVisitor):
if
for_type
is
None
:
if
for_type
is
None
:
res
=
_make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
_body
)
res
=
_make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
_body
)
else
:
else
:
res
=
_make
.
For
(
iter_var
,
_api
.
const
(
0
,
dtype
=
'int32'
),
ext
,
for_type
,
0
,
_body
)
res
=
_make
.
For
(
iter_var
,
_api
.
const
(
0
),
ext
,
for_type
,
0
,
_body
)
self
.
loops_above
.
pop
(
_name
)
self
.
loops_above
.
pop
(
_name
)
return
res
return
res
...
@@ -428,14 +406,22 @@ class HybridParser(ast.NodeVisitor):
...
@@ -428,14 +406,22 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
isinstance
(
i
,
ast
.
Name
),
"What do you return?"
)
_internal_assert
(
isinstance
(
i
,
ast
.
Name
),
"What do you return?"
)
ids
.
append
(
i
.
id
)
ids
.
append
(
i
.
id
)
_internal_assert
(
len
(
set
(
ids
))
==
len
(
ids
),
"Duplicated tensors in the return tuples"
)
_internal_assert
(
len
(
set
(
ids
))
==
len
(
ids
),
"Duplicated tensors in the return tuples"
)
if
len
(
ids
)
!=
len
(
self
.
outputs
):
if
len
(
ids
)
<
len
(
self
.
outputs
):
logging
.
log
(
logging
.
CRITICAL
,
'[Warning] Not all the output buffers returned!'
)
logging
.
log
(
logging
.
CRITICAL
,
'[Warning] Not all the output buffers returned!'
)
self
.
outputs
=
[
self
.
_args
[
i
]
for
i
in
ids
]
self
.
outputs
=
[
self
.
alloc_buffers
[
i
][
0
]
for
i
in
ids
]
self
.
returned
=
True
self
.
returned
=
True
return
make_nop
()
return
util
.
make_nop
()
def
visit_Tuple
(
self
,
node
):
return
tuple
(
self
.
visit
(
i
)
for
i
in
node
.
elts
)
def
parse_python
(
src
,
args
):
def
visit_Str
(
self
,
node
):
return
node
.
s
def
parse_python
(
src
,
symbols
,
args
):
"""The helper function of calling the AST visitor
"""The helper function of calling the AST visitor
Parameters
Parameters
...
@@ -443,6 +429,9 @@ def parse_python(src, args):
...
@@ -443,6 +429,9 @@ def parse_python(src, args):
src : str
src : str
The source code of the function to be parsed.
The source code of the function to be parsed.
src : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
args : list of Tensors or Vars
The argument lists to the function.
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function without arguments.
...
@@ -454,8 +443,8 @@ def parse_python(src, args):
...
@@ -454,8 +443,8 @@ def parse_python(src, args):
The result Halide IR and the parser class instance.
The result Halide IR and the parser class instance.
"""
"""
root
=
ast
.
parse
(
src
)
root
=
ast
.
parse
(
src
)
var_usage
=
determine_variable_usage
(
root
,
args
)
var_usage
=
determine_variable_usage
(
root
,
args
,
symbols
)
parser
=
HybridParser
(
args
,
var_usage
)
parser
=
HybridParser
(
args
,
var_usage
,
symbols
)
parser
.
parsed_body
=
parser
.
visit
(
root
)
parser
.
parsed_body
=
parser
.
visit
(
root
)
_internal_assert
(
parser
.
returned
,
'No valid return found in the function body!'
)
_internal_assert
(
parser
.
returned
,
'No valid return found in the function body!'
)
return
parser
return
parser
python/tvm/hybrid/util.py
View file @
838e7181
...
@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types
...
@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types
from
..
import
api
as
_api
from
..
import
api
as
_api
from
..
import
make
as
_make
from
..
import
make
as
_make
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
stmt
as
_stmt
from
..tensor
import
Tensor
from
..tensor
import
Tensor
...
@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
...
@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
_globals
.
pop
(
elem
)
_globals
.
pop
(
elem
)
for
k
,
v
in
intersect
:
for
k
,
v
in
intersect
:
_globals
[
k
]
=
v
_globals
[
k
]
=
v
def
replace_io
(
body
,
rmap
):
"""Replacing tensors usage according to the dict given"""
from
..
import
ir_pass
def
replace
(
op
):
if
isinstance
(
op
,
_stmt
.
Provide
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
return
_make
.
Provide
(
buf
.
op
,
op
.
value_index
,
op
.
value
,
op
.
args
)
elif
isinstance
(
op
,
_expr
.
Call
)
and
op
.
func
in
rmap
.
keys
():
buf
=
rmap
[
op
.
func
]
return
_make
.
Call
(
buf
.
dtype
,
buf
.
name
,
op
.
args
,
\
_expr
.
Call
.
Halide
,
buf
.
op
,
buf
.
value_index
)
return
None
return
ir_pass
.
IRTransform
(
body
,
None
,
replace
,
[
'Provide'
,
'Call'
])
python/tvm/hybrid/var_decl.py
View file @
838e7181
...
@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
#pylint: disable=missing-docstring
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
,
symbols
):
self
.
status
=
{}
self
.
status
=
{}
self
.
scope_level
=
[]
self
.
scope_level
=
[]
self
.
_args
=
{}
self
.
_args
=
{}
self
.
args
=
args
self
.
args
=
args
self
.
aug_assign_
=
False
self
.
aug_assign_
=
False
self
.
symbols
=
symbols
def
visit_FunctionDef
(
self
,
node
):
def
visit_FunctionDef
(
self
,
node
):
...
@@ -43,7 +44,9 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -43,7 +44,9 @@ class PyVariableUsage(ast.NodeVisitor):
#No function pointer supported so far
#No function pointer supported so far
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
"Function call should be an id"
)
_internal_assert
(
isinstance
(
node
.
func
,
ast
.
Name
),
"Function call should be an id"
)
func_id
=
node
.
func
.
id
func_id
=
node
.
func
.
id
_internal_assert
(
func_id
in
list
(
HYBRID_GLOBALS
.
keys
())
+
[
'range'
,
'max'
,
'min'
],
\
_internal_assert
(
func_id
in
list
(
HYBRID_GLOBALS
.
keys
())
+
\
[
'range'
,
'max'
,
'min'
]
+
\
list
(
self
.
symbols
.
keys
()),
\
"Function call id not in intrinsics' list"
)
"Function call id not in intrinsics' list"
)
for
elem
in
node
.
args
:
for
elem
in
node
.
args
:
self
.
visit
(
elem
)
self
.
visit
(
elem
)
...
@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor):
else
:
else
:
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
usage
.
add
(
type
(
node
.
ctx
))
usage
.
add
(
type
(
node
.
ctx
))
_internal_assert
(
loop
in
self
.
scope_level
,
"
%
s is used out of the scope it is defined!"
%
node
.
id
)
self
.
status
[
node
.
id
]
=
(
decl
,
loop
,
usage
)
self
.
status
[
node
.
id
]
=
(
decl
,
loop
,
usage
)
def
determine_variable_usage
(
root
,
args
):
def
determine_variable_usage
(
root
,
args
,
symbols
):
"""The helper function for calling the dedicated visitor."""
"""The helper function for calling the dedicated visitor."""
visitor
=
PyVariableUsage
(
args
)
visitor
=
PyVariableUsage
(
args
,
symbols
)
visitor
.
visit
(
root
)
visitor
.
visit
(
root
)
return
visitor
.
status
return
visitor
.
status
tests/python/unittest/test_hybrid_script.py
View file @
838e7181
...
@@ -270,7 +270,7 @@ def test_bind():
...
@@ -270,7 +270,7 @@ def test_bind():
return
return
@script
@script
def
vec_add
(
a
,
b
):
def
vec_add
(
a
,
b
):
c
=
output_tensor
((
1000
,
),
dtype
=
'float32'
)
c
=
output_tensor
((
1000
,
),
'float32'
)
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
c
[
tx
]
=
a
[
tx
]
+
b
[
tx
]
c
[
tx
]
=
a
[
tx
]
+
b
[
tx
]
return
c
return
c
...
@@ -506,7 +506,37 @@ def test_value_index():
...
@@ -506,7 +506,37 @@ def test_value_index():
module
(
tvm
.
ndarray
.
array
(
np_a
),
res
)
module
(
tvm
.
ndarray
.
array
(
np_a
),
res
)
tvm
.
testing
.
assert_allclose
(
res
.
asnumpy
(),
ref
)
tvm
.
testing
.
assert_allclose
(
res
.
asnumpy
(),
ref
)
def
test_func_call
():
@tvm.hybrid.script
def
foo
(
a
,
b
):
for
i
in
range
(
10
):
a
[
i
]
=
i
+
1.0
for
i
in
range
(
10
):
b
[
i
]
=
i
+
1.0
c
=
outer_product
(
10
,
10
,
a
,
b
)
d
=
output_tensor
(
c
.
shape
,
c
.
dtype
)
for
i
in
range
(
10
):
for
j
in
range
(
10
):
d
[
i
,
j
]
=
c
[
i
,
j
]
+
i
*
j
return
d
a
=
tvm
.
placeholder
((
10
,
),
name
=
'a'
)
b
=
tvm
.
placeholder
((
10
,
),
name
=
'b'
)
run_and_check
(
foo
,
[
a
,
b
])
def
test_bool
():
@tvm.hybrid.script
def
foo
(
a
):
b
=
output_tensor
(
a
.
shape
,
a
.
dtype
)
b
[
0
]
=
1.2
for
i
in
range
(
1
,
a
.
shape
[
0
]
-
1
):
if
a
[
i
]
*
a
[
i
-
1
]
<
a
[
i
]
or
a
[
i
]
*
a
[
i
-
1
]
<
a
[
i
-
1
]
or
i
*
a
[
i
]
==
a
[
i
]:
b
[
i
]
=
a
[
i
]
else
:
b
[
i
]
=
0.0
return
b
a
=
tvm
.
placeholder
((
10
,
),
name
=
'a'
)
run_and_check
(
foo
,
[
a
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_outer_product
()
test_outer_product
()
...
@@ -521,7 +551,7 @@ if __name__ == "__main__":
...
@@ -521,7 +551,7 @@ if __name__ == "__main__":
test_downstream
()
test_downstream
()
test_const_param
()
test_const_param
()
test_value_index
()
test_value_index
()
test_func_call
()
test_bool
()
# TODO:
# TODO:
# test_inplace()
# test_inplace()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment