Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d29b1c9e
Commit
d29b1c9e
authored
Jun 25, 2018
by
Jian Weng
Committed by
Tianqi Chen
Jun 25, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FRONTEND] [HYBRID] Non-zero starting supported; Buffer AttrStmt add! (#1330)
parent
9b8cb1b6
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
219 additions
and
75 deletions
+219
-75
python/tvm/hybrid/intrin.py
+3
-1
python/tvm/hybrid/parser.py
+72
-29
python/tvm/hybrid/util.py
+0
-10
python/tvm/hybrid/var_decl.py
+2
-2
tests/python/unittest/test_hybrid_script.py
+142
-33
No files found.
python/tvm/hybrid/intrin.py
View file @
d29b1c9e
...
@@ -29,7 +29,7 @@ class bind(_range): #pylint: disable=invalid-name
...
@@ -29,7 +29,7 @@ class bind(_range): #pylint: disable=invalid-name
unroll
=
vectorize
=
parallel
=
_range
#pylint: disable=invalid-name
unroll
=
vectorize
=
parallel
=
_range
#pylint: disable=invalid-name
def
allocate
(
shape
,
dtype
=
'float32'
):
def
allocate
(
shape
,
dtype
=
'float32'
,
scope
=
'global'
):
#pylint: disable=unused-argument
"""Allocate a buffer with given shape
"""Allocate a buffer with given shape
Parameters
Parameters
...
@@ -38,6 +38,8 @@ def allocate(shape, dtype='float32'):
...
@@ -38,6 +38,8 @@ def allocate(shape, dtype='float32'):
The shape of the tensor to be allocated
The shape of the tensor to be allocated
dtype: string
dtype: string
The data type of the tensor
The data type of the tensor
scope: string
The storage scope of the tensor
Returns
Returns
-------
-------
...
...
python/tvm/hybrid/parser.py
View file @
d29b1c9e
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
ast
import
ast
import
operator
import
operator
import
sys
import
sys
from
.util
import
make_nop
,
make_const_true
,
make_range_one
,
halide_imm_types
from
.util
import
make_nop
,
halide_imm_types
from
.intrin
import
LOOP_INTRIN
,
MATH_INTRIN
from
.intrin
import
LOOP_INTRIN
,
MATH_INTRIN
from
.var_decl
import
determine_variable_usage
from
.var_decl
import
determine_variable_usage
from
..api
import
thread_axis
from
..api
import
thread_axis
...
@@ -75,7 +75,8 @@ class HybridParser(ast.NodeVisitor):
...
@@ -75,7 +75,8 @@ class HybridParser(ast.NodeVisitor):
self
.
args
=
args
[:]
self
.
args
=
args
[:]
self
.
usage
=
usage
.
copy
()
self
.
usage
=
usage
.
copy
()
self
.
_args
=
{}
# Dict maps arg name to actual arg instance (either a var or a buffer)
self
.
_args
=
{}
# Dict maps arg name to actual arg instance (either a var or a buffer)
self
.
buffers
=
{}
self
.
var_buffers
=
{}
# Buffers formed by mutatble variables
self
.
alloc_buffers
=
{}
# Buffers formed by allocate instructions
self
.
loops_above
=
{}
# State variable that indicates loop levels above the current node
self
.
loops_above
=
{}
# State variable that indicates loop levels above the current node
self
.
var_consts
=
{}
# Variables that are determined as readonly in previous stage
self
.
var_consts
=
{}
# Variables that are determined as readonly in previous stage
self
.
func_name
=
func_name
# The name of the function to be lowered
self
.
func_name
=
func_name
# The name of the function to be lowered
...
@@ -87,19 +88,30 @@ class HybridParser(ast.NodeVisitor):
...
@@ -87,19 +88,30 @@ class HybridParser(ast.NodeVisitor):
for
key
,
val
in
self
.
usage
.
items
():
for
key
,
val
in
self
.
usage
.
items
():
if
key
in
self
.
var_consts
.
keys
():
if
key
in
self
.
var_consts
.
keys
():
continue
continue
_
,
scope
,
_
=
val
_
,
level
,
_
=
val
if
scope
==
node
:
if
level
==
node
:
_buf
=
self
.
buffers
[
key
]
if
key
in
self
.
var_buffers
.
keys
():
_buf
=
self
.
var_buffers
[
key
]
_scope
=
'global'
else
:
_buf
,
_scope
=
self
.
alloc_buffers
[
key
]
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_dtype
=
_buf
.
dtype
_dtype
=
_buf
.
dtype
_
one
=
make_range_one
(
)
_
true
=
_api
.
convert
(
True
)
_true
=
make_const_true
(
)
body
=
_make
.
Realize
(
_buf
.
op
,
0
,
_dtype
,
_domain
,
_true
,
body
)
body
=
_make
.
Realize
(
_buf
.
op
,
0
,
_dtype
,
[
_one
],
_true
,
body
)
body
=
_make
.
AttrStmt
(
_buf
.
op
,
'realize_scope'
,
_api
.
convert
(
_scope
)
,
body
)
return
body
return
body
def
_
check_id_a_buffer
(
self
,
s
):
def
_
get_buffer_from_id
(
self
,
s
):
if
s
not
in
self
.
_args
.
keys
():
if
s
not
in
self
.
_args
.
keys
()
and
s
not
in
self
.
alloc_buffers
.
keys
()
:
raise
ValueError
(
"This
%
s is expected to be in argument list or allocated buffer!"
%
s
)
raise
ValueError
(
"This
%
s is expected to be in argument list or allocated buffer!"
%
s
)
if
s
in
self
.
_args
.
keys
()
and
s
in
self
.
alloc_buffers
.
keys
():
raise
ValueError
(
"
%
s, a buffer cannot be both argument and allocated!"
%
s
)
if
s
in
self
.
_args
.
keys
():
return
self
.
_args
[
s
]
return
self
.
alloc_buffers
[
s
][
0
]
#pylint: disable=invalid-name, missing-docstring
#pylint: disable=invalid-name, missing-docstring
...
@@ -138,8 +150,8 @@ class HybridParser(ast.NodeVisitor):
...
@@ -138,8 +150,8 @@ class HybridParser(ast.NodeVisitor):
if
_id
not
in
self
.
usage
.
keys
():
if
_id
not
in
self
.
usage
.
keys
():
raise
ValueError
(
"This id
%
s is expected to be a defined variable!"
%
_id
)
raise
ValueError
(
"This id
%
s is expected to be a defined variable!"
%
_id
)
# Buffer
# Buffer
if
_id
in
self
.
buffers
.
keys
():
if
_id
in
self
.
var_
buffers
.
keys
():
_buf
=
self
.
buffers
[
_id
]
_buf
=
self
.
var_
buffers
[
_id
]
return
_make
.
Call
(
_buf
.
dtype
,
_id
,
[
_api
.
const
(
0
)],
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
return
_make
.
Call
(
_buf
.
dtype
,
_id
,
[
_api
.
const
(
0
)],
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
# Compilation time constant
# Compilation time constant
if
_id
not
in
self
.
var_consts
.
keys
():
if
_id
not
in
self
.
var_consts
.
keys
():
...
@@ -155,7 +167,9 @@ class HybridParser(ast.NodeVisitor):
...
@@ -155,7 +167,9 @@ class HybridParser(ast.NodeVisitor):
if
len
(
node
.
targets
)
!=
1
:
if
len
(
node
.
targets
)
!=
1
:
raise
ValueError
(
"So far only one-valued assignment is supported!"
)
raise
ValueError
(
"So far only one-valued assignment is supported!"
)
lhs
=
node
.
targets
[
0
]
lhs
=
node
.
targets
[
0
]
rhs
=
_ir_pass
.
Simplify
(
self
.
visit
(
node
.
value
))
rhs
=
self
.
visit
(
node
.
value
)
if
isinstance
(
rhs
,
_expr
.
Expr
):
rhs
=
_ir_pass
.
Simplify
(
rhs
)
if
isinstance
(
lhs
,
ast
.
Name
):
if
isinstance
(
lhs
,
ast
.
Name
):
#TODO: support defined intermediate buffer later
#TODO: support defined intermediate buffer later
lhs_
=
lhs
lhs_
=
lhs
...
@@ -166,25 +180,31 @@ class HybridParser(ast.NodeVisitor):
...
@@ -166,25 +180,31 @@ class HybridParser(ast.NodeVisitor):
if
decl
==
lhs_
:
if
decl
==
lhs_
:
if
lhs
in
self
.
var_consts
.
keys
():
if
lhs
in
self
.
var_consts
.
keys
():
raise
ValueError
(
"BUG: A constant cannot be overwritten!"
)
raise
ValueError
(
"BUG: A constant cannot be overwritten!"
)
if
lhs
in
self
.
buffers
.
keys
():
if
lhs
in
self
.
var_buffers
.
keys
()
or
lhs
in
self
.
alloc_
buffers
.
keys
():
raise
ValueError
(
"BUG: This value should not be defined before this point!"
)
raise
ValueError
(
"BUG: This value should not be defined before this point!"
)
if
isinstance
(
rhs
,
tuple
):
shape
,
dtype
,
scope
=
rhs
ph
=
_api
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
lhs
)
self
.
alloc_buffers
[
lhs
]
=
(
ph
,
scope
)
return
make_nop
()
if
isinstance
(
rhs
,
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
if
isinstance
(
rhs
,
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
self
.
var_consts
[
lhs
]
=
rhs
self
.
var_consts
[
lhs
]
=
rhs
else
:
else
:
self
.
buffers
[
lhs
]
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
self
.
var_
buffers
[
lhs
]
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
if
lhs
in
self
.
var_consts
.
keys
():
if
lhs
in
self
.
var_consts
.
keys
():
return
make_nop
()
return
make_nop
()
else
:
else
:
if
lhs
not
in
self
.
buffers
.
keys
():
if
lhs
not
in
self
.
var_buffers
.
keys
():
raise
ValueError
(
"BUG: This value should be defined before!"
)
raise
ValueError
(
"BUG: This variable should be defined before!"
)
return
_make
.
Provide
(
self
.
buffers
[
lhs
]
.
op
,
0
,
rhs
,
[
_api
.
const
(
0
,
dtype
=
rhs
.
dtype
)])
tgt
=
self
.
var_buffers
[
lhs
]
return
_make
.
Provide
(
tgt
.
op
,
0
,
rhs
,
[
_api
.
const
(
0
,
dtype
=
rhs
.
dtype
)])
else
:
else
:
lhs
=
self
.
visit
(
lhs
)
lhs
=
self
.
visit
(
lhs
)
if
not
isinstance
(
lhs
,
_expr
.
Call
):
if
not
isinstance
(
lhs
,
_expr
.
Call
):
raise
ValueError
(
"An array access's LHS is expected to be a expr.Call!"
)
raise
ValueError
(
"An array access's LHS is expected to be a expr.Call!"
)
#TODO: support slice later
#TODO: support slice later
self
.
_check_id_a_buffer
(
lhs
.
name
)
buf
=
self
.
_get_buffer_from_id
(
lhs
.
name
)
return
_make
.
Provide
(
self
.
_args
[
lhs
.
name
]
.
op
,
0
,
rhs
,
lhs
.
args
)
return
_make
.
Provide
(
buf
.
op
,
0
,
rhs
,
lhs
.
args
)
def
visit_Index
(
self
,
node
):
def
visit_Index
(
self
,
node
):
...
@@ -197,8 +217,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -197,8 +217,7 @@ class HybridParser(ast.NodeVisitor):
args
=
self
.
visit
(
node
.
slice
)
args
=
self
.
visit
(
node
.
slice
)
if
isinstance
(
node
.
value
,
ast
.
Name
):
if
isinstance
(
node
.
value
,
ast
.
Name
):
array
=
node
.
value
.
id
array
=
node
.
value
.
id
self
.
_check_id_a_buffer
(
array
)
_buf
=
self
.
_get_buffer_from_id
(
array
)
_buf
=
self
.
_args
[
array
]
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
elif
isinstance
(
node
.
value
,
ast
.
Attribute
):
elif
isinstance
(
node
.
value
,
ast
.
Attribute
):
if
not
isinstance
(
node
.
value
.
value
,
ast
.
Name
):
if
not
isinstance
(
node
.
value
.
value
,
ast
.
Name
):
...
@@ -211,8 +230,8 @@ class HybridParser(ast.NodeVisitor):
...
@@ -211,8 +230,8 @@ class HybridParser(ast.NodeVisitor):
#TODO: maybe support non-constant value later?
#TODO: maybe support non-constant value later?
if
not
isinstance
(
args
,
(
_expr
.
IntImm
,
_expr
.
UIntImm
)):
if
not
isinstance
(
args
,
(
_expr
.
IntImm
,
_expr
.
UIntImm
)):
raise
ValueError
(
"So far only constant shape access supported!"
)
raise
ValueError
(
"So far only constant shape access supported!"
)
self
.
_check_id_a_buffer
(
node
.
value
.
value
.
id
)
buf
=
self
.
_get_buffer_from_id
(
node
.
value
.
value
.
id
)
return
self
.
_args
[
node
.
value
.
value
.
id
]
.
shape
[
args
.
value
]
return
buf
.
shape
[
args
.
value
]
else
:
else
:
raise
ValueError
(
"Not supported yet!"
)
raise
ValueError
(
"Not supported yet!"
)
...
@@ -303,8 +322,30 @@ class HybridParser(ast.NodeVisitor):
...
@@ -303,8 +322,30 @@ class HybridParser(ast.NodeVisitor):
elif
func_id
in
MATH_INTRIN
:
elif
func_id
in
MATH_INTRIN
:
return
getattr
(
intrin
,
func_id
)(
*
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
return
getattr
(
intrin
,
func_id
)(
*
[
self
.
visit
(
arg
)
for
arg
in
node
.
args
])
elif
func_id
==
'allocate'
:
elif
func_id
==
'allocate'
:
#TODO: Support it later!
if
not
isinstance
(
node
.
args
[
0
],
ast
.
Tuple
):
return
make_nop
()
raise
ValueError
(
"allocate's first argument should be a tuple of shape!"
)
shape
=
tuple
(
self
.
visit
(
i
)
for
i
in
node
.
args
[
0
]
.
elts
)
for
i
in
shape
:
if
not
isinstance
(
i
,
_expr
.
Expr
):
raise
ValueError
(
"The shape should be an expression"
)
if
n
>
1
:
if
not
isinstance
(
node
.
args
[
1
],
ast
.
Str
):
raise
ValueError
(
"The data type should be an string"
)
dtype
=
node
.
args
[
1
]
.
s
else
:
dtype
=
'float32'
if
n
>
2
:
if
not
isinstance
(
node
.
args
[
2
],
ast
.
Str
):
raise
ValueError
(
"The data type should be an string"
)
scope
=
node
.
args
[
2
]
.
s
else
:
scope
=
'global'
return
(
shape
,
dtype
,
scope
)
elif
func_id
==
'max'
or
func_id
==
'min'
:
if
n
!=
2
:
raise
ValueError
(
"Max/Min function should have 2 elements"
)
a
,
b
=
self
.
visit
(
node
.
args
[
0
]),
self
.
visit
(
node
.
args
[
1
])
return
getattr
(
_make
,
func_id
.
title
())(
a
,
b
)
else
:
else
:
raise
ValueError
(
"Function call not supported yet!"
)
raise
ValueError
(
"Function call not supported yet!"
)
...
@@ -317,8 +358,10 @@ class HybridParser(ast.NodeVisitor):
...
@@ -317,8 +358,10 @@ class HybridParser(ast.NodeVisitor):
if
iter_var
is
None
:
if
iter_var
is
None
:
if
for_type
is
None
:
if
for_type
is
None
:
raise
ValueError
(
"The loop bind function parse error!"
)
raise
ValueError
(
"The loop bind function parse error!"
)
iter_var
=
_api
.
var
(
_name
)
offset
=
iter_var
=
_api
.
var
(
_name
)
self
.
loops_above
[
_name
]
=
iter_var
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
dtype
=
'int32'
)):
offset
=
iter_var
+
low
self
.
loops_above
[
_name
]
=
offset
else
:
else
:
if
for_type
is
not
None
:
if
for_type
is
not
None
:
raise
ValueError
(
"The loop iterating function parse error!"
)
raise
ValueError
(
"The loop iterating function parse error!"
)
...
@@ -328,7 +371,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -328,7 +371,7 @@ class HybridParser(ast.NodeVisitor):
if
for_type
is
None
:
if
for_type
is
None
:
res
=
_make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
_body
)
res
=
_make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
_body
)
else
:
else
:
res
=
_make
.
For
(
iter_var
,
low
,
ext
,
for_type
,
0
,
_body
)
res
=
_make
.
For
(
iter_var
,
_api
.
const
(
0
,
dtype
=
'int32'
)
,
ext
,
for_type
,
0
,
_body
)
self
.
loops_above
.
pop
(
_name
)
self
.
loops_above
.
pop
(
_name
)
return
res
return
res
...
...
python/tvm/hybrid/util.py
View file @
d29b1c9e
...
@@ -22,16 +22,6 @@ def make_nop():
...
@@ -22,16 +22,6 @@ def make_nop():
return
_make
.
Evaluate
(
_api
.
const
(
0
,
dtype
=
'int32'
))
return
_make
.
Evaluate
(
_api
.
const
(
0
,
dtype
=
'int32'
))
def
make_range_one
():
"""Returns a [0, 1] range node in HalideIR."""
return
_make
.
range_by_min_extent
(
0
,
1
)
def
make_const_true
():
"""Returns a constant True node in HalideIR."""
return
_api
.
convert
(
True
)
def
_pruned_source
(
func
):
def
_pruned_source
(
func
):
"""Prune source code's extra leading spaces"""
"""Prune source code's extra leading spaces"""
lines
=
inspect
.
getsource
(
func
)
.
split
(
'
\n
'
)
lines
=
inspect
.
getsource
(
func
)
.
split
(
'
\n
'
)
...
...
python/tvm/hybrid/var_decl.py
View file @
d29b1c9e
...
@@ -41,7 +41,8 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -41,7 +41,8 @@ class PyVariableUsage(ast.NodeVisitor):
#No function pointer supported so far
#No function pointer supported so far
if
not
isinstance
(
node
.
func
,
ast
.
Name
):
if
not
isinstance
(
node
.
func
,
ast
.
Name
):
raise
ValueError
(
"Function call should be an id"
)
raise
ValueError
(
"Function call should be an id"
)
if
(
node
.
func
.
id
not
in
HYBRID_GLOBALS
.
keys
())
and
node
.
func
.
id
!=
'range'
:
func_id
=
node
.
func
.
id
if
func_id
not
in
list
(
HYBRID_GLOBALS
.
keys
())
+
[
'range'
,
'max'
,
'min'
]:
raise
ValueError
(
"Function call id not in intrinsics' list"
)
raise
ValueError
(
"Function call id not in intrinsics' list"
)
for
elem
in
node
.
args
:
for
elem
in
node
.
args
:
self
.
visit
(
elem
)
self
.
visit
(
elem
)
...
@@ -64,7 +65,6 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -64,7 +65,6 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
else
:
else
:
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
loop
=
self
.
scope_level
[
-
1
]
usage
.
add
(
type
(
node
.
ctx
))
usage
.
add
(
type
(
node
.
ctx
))
self
.
status
[
node
.
id
]
=
(
decl
,
loop
,
usage
)
self
.
status
[
node
.
id
]
=
(
decl
,
loop
,
usage
)
...
...
tests/python/unittest/test_hybrid_script.py
View file @
d29b1c9e
...
@@ -2,6 +2,7 @@ import tvm, inspect, sys, traceback, numpy
...
@@ -2,6 +2,7 @@ import tvm, inspect, sys, traceback, numpy
from
tvm.hybrid
import
script
from
tvm.hybrid
import
script
from
tvm.hybrid.intrin
import
HYBRID_GLOBALS
from
tvm.hybrid.intrin
import
HYBRID_GLOBALS
@script
@script
def
outer_product
(
n
,
m
,
a
,
b
,
c
):
def
outer_product
(
n
,
m
,
a
,
b
,
c
):
for
i
in
range
(
n
):
for
i
in
range
(
n
):
...
@@ -56,6 +57,7 @@ def test_outer_product():
...
@@ -56,6 +57,7 @@ def test_outer_product():
tvm_c
=
tvm
.
ndarray
.
array
(
numpy
.
zeros
((
_n
,
_m
),
dtype
=
'float32'
))
tvm_c
=
tvm
.
ndarray
.
array
(
numpy
.
zeros
((
_n
,
_m
),
dtype
=
'float32'
))
func
(
_n
,
_m
,
tvm_a
,
tvm_b
,
tvm_c
)
func
(
_n
,
_m
,
tvm_a
,
tvm_b
,
tvm_c
)
numpy
.
testing
.
assert_allclose
(
tvm_c
.
asnumpy
(),
c_python
,
rtol
=
1e-5
)
numpy
.
testing
.
assert_allclose
(
tvm_c
.
asnumpy
(),
c_python
,
rtol
=
1e-5
)
for
key
,
_
in
HYBRID_GLOBALS
.
items
():
for
key
,
_
in
HYBRID_GLOBALS
.
items
():
assert
key
not
in
globals
()
.
keys
()
assert
key
not
in
globals
()
.
keys
()
assert
key
not
in
outer_product
.
__globals__
.
keys
()
assert
key
not
in
outer_product
.
__globals__
.
keys
()
...
@@ -74,8 +76,8 @@ def test_fanout():
...
@@ -74,8 +76,8 @@ def test_fanout():
b
[
i
]
=
sigma
b
[
i
]
=
sigma
n
=
tvm
.
var
(
'n'
)
n
=
tvm
.
var
(
'n'
)
a
=
tvm
.
placeholder
((
n
,
),
name
=
'a'
)
a
=
tvm
.
placeholder
((
n
,
),
'float32'
,
name
=
'a'
)
b
=
tvm
.
placeholder
((
n
-
3
,
),
name
=
'b'
)
b
=
tvm
.
placeholder
((
n
-
3
,
),
'float32'
,
name
=
'b'
)
ir
=
fanout
(
n
,
a
,
b
)
ir
=
fanout
(
n
,
a
,
b
)
#Check for i in (0, n-3)
#Check for i in (0, n-3)
...
@@ -85,12 +87,14 @@ def test_fanout():
...
@@ -85,12 +87,14 @@ def test_fanout():
assert
tvm
.
ir_pass
.
Equal
(
ir
.
extent
,
n
-
3
)
assert
tvm
.
ir_pass
.
Equal
(
ir
.
extent
,
n
-
3
)
#Check loopbody
#Check loopbody
ibody
=
ir
.
body
ibody
=
ir
.
body
assert
isinstance
(
ibody
,
tvm
.
stmt
.
Realize
)
assert
isinstance
(
ibody
,
tvm
.
stmt
.
AttrStmt
)
assert
ibody
.
bounds
[
0
]
.
min
.
value
==
0
abody
=
ibody
.
body
assert
ibody
.
bounds
[
0
]
.
extent
.
value
==
1
assert
isinstance
(
abody
,
tvm
.
stmt
.
Realize
)
assert
ibody
.
func
.
name
==
'sigma'
assert
abody
.
bounds
[
0
]
.
min
.
value
==
0
assert
abody
.
bounds
[
0
]
.
extent
.
value
==
1
assert
abody
.
func
.
name
==
'sigma'
#Check i loop body
#Check i loop body
rbody
=
i
body
.
body
rbody
=
a
body
.
body
assert
isinstance
(
rbody
.
first
,
tvm
.
stmt
.
Provide
)
assert
isinstance
(
rbody
.
first
,
tvm
.
stmt
.
Provide
)
assert
rbody
.
first
.
func
.
name
==
'sigma'
assert
rbody
.
first
.
func
.
name
==
'sigma'
assert
len
(
rbody
.
first
.
args
)
==
1
assert
len
(
rbody
.
first
.
args
)
==
1
...
@@ -131,6 +135,21 @@ def test_fanout():
...
@@ -131,6 +135,21 @@ def test_fanout():
assert
len
(
write
.
value
.
args
)
==
1
assert
len
(
write
.
value
.
args
)
==
1
assert
write
.
value
.
args
[
0
]
.
value
==
0
assert
write
.
value
.
args
[
0
]
.
value
==
0
func
=
tvm
.
build
(
tvm
.
lower
(
ir
,
[
n
,
a
,
b
]))
assert
func
np_a
=
numpy
.
random
.
randn
(
10
)
.
astype
(
'float32'
)
np_b
=
numpy
.
zeros
(
7
)
.
astype
(
'float32'
)
nd_a
=
tvm
.
ndarray
.
array
(
np_a
)
nd_b
=
tvm
.
ndarray
.
array
(
np_b
)
fanout
(
10
,
np_a
,
np_b
)
func
(
10
,
nd_a
,
nd_b
)
numpy
.
testing
.
assert_allclose
(
nd_b
.
asnumpy
(),
np_b
,
rtol
=
1e-5
,
atol
=
1e-5
)
@script
@script
def
failure
():
def
failure
():
for
i
in
range
(
1
,
100
):
for
i
in
range
(
1
,
100
):
...
@@ -148,15 +167,18 @@ def test_failure():
...
@@ -148,15 +167,18 @@ def test_failure():
def
test_looptype
():
def
test_looptype
():
@script
@script
def
looptype
(
a
):
def
looptype
(
a
,
b
,
c
):
for
i
in
parallel
(
6
):
for
i
in
parallel
(
8
):
a
[
i
]
=
i
a
[
i
]
=
i
for
j
in
vectorize
(
6
):
for
j
in
vectorize
(
8
):
a
[
j
]
=
j
b
[
j
]
=
j
for
k
in
unroll
(
6
):
for
k
in
unroll
(
8
):
a
[
k
]
=
k
c
[
k
]
=
k
a
=
tvm
.
placeholder
((
6
,
),
name
=
'a'
)
ir
=
looptype
(
a
)
a
=
tvm
.
placeholder
((
8
,
),
name
=
'a'
,
dtype
=
'int32'
)
b
=
tvm
.
placeholder
((
8
,
),
name
=
'b'
,
dtype
=
'int32'
)
c
=
tvm
.
placeholder
((
8
,
),
name
=
'c'
,
dtype
=
'int32'
)
ir
=
looptype
(
a
,
b
,
c
)
iloop
=
ir
.
first
iloop
=
ir
.
first
jloop
=
ir
.
rest
.
first
jloop
=
ir
.
rest
.
first
kloop
=
ir
.
rest
.
rest
kloop
=
ir
.
rest
.
rest
...
@@ -164,6 +186,24 @@ def test_looptype():
...
@@ -164,6 +186,24 @@ def test_looptype():
assert
jloop
.
for_type
==
tvm
.
stmt
.
For
.
Vectorized
assert
jloop
.
for_type
==
tvm
.
stmt
.
For
.
Vectorized
assert
kloop
.
for_type
==
tvm
.
stmt
.
For
.
Unrolled
assert
kloop
.
for_type
==
tvm
.
stmt
.
For
.
Unrolled
func
=
tvm
.
build
(
tvm
.
lower
(
ir
,
[
a
,
b
,
c
]))
np_a
=
numpy
.
zeros
((
8
,
))
.
astype
(
'int32'
)
np_b
=
numpy
.
zeros
((
8
,
))
.
astype
(
'int32'
)
np_c
=
numpy
.
zeros
((
8
,
))
.
astype
(
'int32'
)
nd_a
=
tvm
.
ndarray
.
array
(
np_a
)
nd_b
=
tvm
.
ndarray
.
array
(
np_b
)
nd_c
=
tvm
.
ndarray
.
array
(
np_c
)
looptype
(
np_a
,
np_b
,
np_c
)
func
(
nd_a
,
nd_b
,
nd_c
)
numpy
.
testing
.
assert_allclose
(
np_a
,
nd_a
.
asnumpy
())
numpy
.
testing
.
assert_allclose
(
np_b
,
nd_b
.
asnumpy
())
numpy
.
testing
.
assert_allclose
(
np_c
,
nd_c
.
asnumpy
())
def
test_if
():
def
test_if
():
@script
@script
def
if_then_else
(
a
,
b
):
def
if_then_else
(
a
,
b
):
...
@@ -234,12 +274,14 @@ def test_math_intrin():
...
@@ -234,12 +274,14 @@ def test_math_intrin():
a
[
3
]
=
sigmoid
(
a
[
3
])
a
[
3
]
=
sigmoid
(
a
[
3
])
a
[
4
]
=
power
(
a
[
4
],
a
[
5
])
a
[
4
]
=
power
(
a
[
4
],
a
[
5
])
a
[
5
]
=
tanh
(
a
[
5
])
a
[
5
]
=
tanh
(
a
[
5
])
a
[
6
]
=
min
(
a
[
4
],
a
[
5
])
a
[
7
]
=
max
(
a
[
5
],
a
[
6
])
a6
=
tvm
.
placeholder
((
6
,
),
dtype
=
'float32'
,
name
=
'a'
)
a6
=
tvm
.
placeholder
((
8
,
),
dtype
=
'float32'
,
name
=
'a'
)
ir
=
intrin_real
(
a6
)
ir
=
intrin_real
(
a6
)
func
=
tvm
.
build
(
tvm
.
lower
(
ir
,
[
a6
]))
func
=
tvm
.
build
(
tvm
.
lower
(
ir
,
[
a6
]))
assert
func
assert
func
a
=
numpy
.
arange
(
2
,
8
)
.
astype
(
'float32'
)
a
=
numpy
.
arange
(
2
,
10
)
.
astype
(
'float32'
)
tvm_a
=
tvm
.
ndarray
.
array
(
a
)
tvm_a
=
tvm
.
ndarray
.
array
(
a
)
func
(
tvm_a
)
func
(
tvm_a
)
intrin_real
(
a
)
intrin_real
(
a
)
...
@@ -259,22 +301,87 @@ def test_math_intrin():
...
@@ -259,22 +301,87 @@ def test_math_intrin():
func
(
tvm_a
)
func
(
tvm_a
)
assert
tvm_a
.
asnumpy
()[
0
]
==
a
[
0
]
assert
tvm_a
.
asnumpy
()[
0
]
==
a
[
0
]
def
test_allocate_buffer
():
def
test_non_zero
():
def
blur
(
a
):
@tvm.hybrid.script
for
i
in
serail
(
32
):
def
blur
(
a
,
b
):
h_blur
=
allocate
((
4
,
36
))
for
i
in
range
(
2
,
32
):
for
j
in
serail
(
4
):
for
j
in
range
(
2
,
32
):
for
k
in
serail
(
36
):
s
=
0.0
s
=
allocate
((
1
,
),
'float32'
)
for
di
in
range
(
3
):
for
dj
in
serail
(
4
):
for
dj
in
range
(
3
):
s
[
0
]
=
s
[
0
]
+
a
[
i
,
j
+
dj
]
s
=
s
+
a
[
i
-
di
,
j
-
dj
]
h_blur
[
j
,
k
]
=
s
[
0
]
/
4.
b
[
i
-
2
,
j
-
2
]
=
s
/
9.0
for
j
in
serail
(
32
):
try
:
s
=
0.
np_a
=
numpy
.
random
.
randn
(
32
,
32
)
.
astype
(
'float32'
)
for
di
in
serail
(
4
):
np_b
=
numpy
.
zeros
((
30
,
30
),
dtype
=
'float32'
)
s
=
s
+
h_blur
[
di
,
j
]
blur
(
np_a
,
np_b
)
h_blur
[
i
,
j
]
=
s
/
4.
ph_a
=
tvm
.
placeholder
((
32
,
32
),
'float32'
,
'a'
)
ph_b
=
tvm
.
placeholder
((
30
,
30
),
'float32'
,
'b'
)
ir
=
tvm
.
hybrid
.
parse
(
blur
,
[
ph_a
,
ph_b
])
func
=
tvm
.
lower
(
ir
,
[
ph_a
,
ph_b
])
func
=
tvm
.
build
(
func
)
nd_a
=
tvm
.
ndarray
.
array
(
np_a
)
nd_b
=
tvm
.
ndarray
.
array
(
np_b
)
func
(
nd_a
,
nd_b
)
numpy
.
testing
.
assert_allclose
(
np_b
,
nd_b
.
asnumpy
(),
atol
=
1e-5
,
rtol
=
1e-5
)
except
IOError
:
print
(
'[Warning] Non-zero first test skipped by Python2'
)
@tvm.hybrid.script
def
triangle
(
a
,
b
,
c
):
for
i
in
range
(
10
):
for
j
in
range
(
i
,
10
):
c
[
i
,
j
]
=
a
[
i
]
*
b
[
j
]
a
=
tvm
.
placeholder
((
10
,
),
dtype
=
'float32'
,
name
=
'a'
)
b
=
tvm
.
placeholder
((
10
,
),
dtype
=
'float32'
,
name
=
'b'
)
c
=
tvm
.
placeholder
((
10
,
10
),
dtype
=
'float32'
,
name
=
'c'
)
np_a
=
numpy
.
random
.
randn
(
10
)
.
astype
(
'float32'
)
np_b
=
numpy
.
random
.
randn
(
10
)
.
astype
(
'float32'
)
np_c
=
numpy
.
zeros
((
10
,
10
))
.
astype
(
'float32'
)
nd_a
=
tvm
.
ndarray
.
array
(
np_a
)
nd_b
=
tvm
.
ndarray
.
array
(
np_b
)
nd_c
=
tvm
.
ndarray
.
array
(
np_c
)
triangle
(
np_a
,
np_b
,
np_c
)
func
=
tvm
.
build
(
tvm
.
lower
(
triangle
(
a
,
b
,
c
),
[
a
,
b
,
c
]))
assert
func
func
(
nd_a
,
nd_b
,
nd_c
)
numpy
.
testing
.
assert_allclose
(
nd_c
.
asnumpy
(),
np_c
)
def
test_allocate
():
@tvm.hybrid.script
def
blur2d
(
a
,
b
):
for
i
in
range
(
30
):
ha
=
allocate
((
3
,
30
),
'float32'
)
for
j
in
range
(
3
):
for
k
in
range
(
30
):
ha
[
j
,
k
]
=
a
[
i
+
j
,
k
]
+
a
[
i
+
j
,
k
+
1
]
+
a
[
i
+
j
,
k
+
2
]
for
j
in
range
(
30
):
b
[
i
,
j
]
=
(
ha
[
0
,
j
]
+
ha
[
1
,
j
]
+
ha
[
2
,
j
])
/
9.0
a
=
tvm
.
placeholder
((
32
,
32
),
'float32'
,
'a'
)
b
=
tvm
.
placeholder
((
30
,
30
),
'float32'
,
'b'
)
func
=
tvm
.
build
(
tvm
.
lower
(
blur2d
(
a
,
b
),
[
a
,
b
]))
assert
func
np_a
=
numpy
.
random
.
randn
(
32
,
32
)
.
astype
(
'float32'
)
np_b
=
numpy
.
zeros
((
30
,
30
))
.
astype
(
'float32'
)
nd_a
=
tvm
.
ndarray
.
array
(
np_a
)
nd_b
=
tvm
.
ndarray
.
array
(
np_b
)
func
(
nd_a
,
nd_b
)
blur2d
(
np_a
,
np_b
)
numpy
.
testing
.
assert_allclose
(
nd_b
.
asnumpy
(),
np_b
,
atol
=
1e-5
,
rtol
=
1e-5
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_outer_product
()
test_outer_product
()
...
@@ -284,4 +391,6 @@ if __name__ == "__main__":
...
@@ -284,4 +391,6 @@ if __name__ == "__main__":
test_if
()
test_if
()
test_bind
()
test_bind
()
test_math_intrin
()
test_math_intrin
()
test_non_zero
()
test_allocate
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment