Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4bbf96e4
Commit
4bbf96e4
authored
Dec 13, 2018
by
Jian Weng
Committed by
Tianqi Chen
Dec 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[BUGFIX] [Hybrid Script] fix in-correct value index in hybrid script (#2268)
parent
6b405824
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
7 deletions
+48
-7
python/tvm/hybrid/parser.py
+4
-3
tests/python/unittest/test_hybrid_script.py
+44
-4
No files found.
python/tvm/hybrid/parser.py
View file @
4bbf96e4
...
@@ -39,10 +39,11 @@ class HybridParser(ast.NodeVisitor):
...
@@ -39,10 +39,11 @@ class HybridParser(ast.NodeVisitor):
ast
.
Sub
:
operator
.
sub
,
ast
.
Sub
:
operator
.
sub
,
ast
.
Mult
:
operator
.
mul
,
ast
.
Mult
:
operator
.
mul
,
ast
.
Div
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
truediv
,
ast
.
Div
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
truediv
,
ast
.
FloorDiv
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
truediv
,
ast
.
Mod
:
operator
.
mod
,
ast
.
Mod
:
operator
.
mod
,
ast
.
BitOr
:
operator
.
or_
,
ast
.
BitOr
:
operator
.
or_
,
ast
.
BitAnd
:
operator
.
and_
,
ast
.
BitAnd
:
operator
.
and_
,
ast
.
BitXor
:
operator
.
xor
,
ast
.
BitXor
:
operator
.
xor
,
ast
.
Gt
:
operator
.
gt
,
ast
.
Gt
:
operator
.
gt
,
ast
.
GtE
:
operator
.
ge
,
ast
.
GtE
:
operator
.
ge
,
ast
.
Lt
:
operator
.
lt
,
ast
.
Lt
:
operator
.
lt
,
...
@@ -237,7 +238,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -237,7 +238,7 @@ class HybridParser(ast.NodeVisitor):
if
isinstance
(
node
.
value
,
ast
.
Name
):
if
isinstance
(
node
.
value
,
ast
.
Name
):
array
=
node
.
value
.
id
array
=
node
.
value
.
id
_buf
=
self
.
_get_buffer_from_id
(
array
)
_buf
=
self
.
_get_buffer_from_id
(
array
)
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
0
)
return
_make
.
Call
(
_buf
.
dtype
,
array
,
args
,
_expr
.
Call
.
Halide
,
_buf
.
op
,
_buf
.
value_index
)
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Attribute
),
\
_internal_assert
(
isinstance
(
node
.
value
,
ast
.
Attribute
),
\
"Only variable and attribute's subscript supported so far"
)
"Only variable and attribute's subscript supported so far"
)
...
...
tests/python/unittest/test_hybrid_script.py
View file @
4bbf96e4
import
tvm
,
inspect
,
sys
,
traceback
,
numpy
,
nose
import
tvm
,
inspect
,
sys
,
traceback
,
numpy
,
nose
,
types
from
tvm.hybrid
import
script
from
tvm.hybrid
import
script
from
tvm.hybrid.intrin
import
HYBRID_GLOBALS
from
tvm.hybrid.intrin
import
HYBRID_GLOBALS
...
@@ -11,6 +11,10 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
...
@@ -11,6 +11,10 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
return
val
.
value
return
val
.
value
ctx
=
tvm
.
context
(
target
,
0
)
ctx
=
tvm
.
context
(
target
,
0
)
op
=
None
outs
=
func
(
*
args
)
op
=
outs
[
0
]
.
op
if
isinstance
(
outs
,
list
)
else
outs
.
op
emu_args
=
[]
emu_args
=
[]
nd_args
=
[]
nd_args
=
[]
...
@@ -24,8 +28,6 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
...
@@ -24,8 +28,6 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
emu_args
.
append
(
tvm_val_2_py_val
(
i
))
emu_args
.
append
(
tvm_val_2_py_val
(
i
))
nd_args
.
append
(
emu_args
[
-
1
])
nd_args
.
append
(
emu_args
[
-
1
])
outs
=
func
(
*
args
)
op
=
outs
[
0
]
.
op
if
isinstance
(
outs
,
list
)
else
outs
.
op
sch
=
tvm
.
create_schedule
(
op
)
sch
=
tvm
.
create_schedule
(
op
)
module
=
tvm
.
build
(
sch
,
args
+
(
outs
if
isinstance
(
outs
,
list
)
else
[
outs
]),
target
=
target
)
module
=
tvm
.
build
(
sch
,
args
+
(
outs
if
isinstance
(
outs
,
list
)
else
[
outs
]),
target
=
target
)
assert
module
assert
module
...
@@ -426,9 +428,11 @@ def test_downstream():
...
@@ -426,9 +428,11 @@ def test_downstream():
b
[
i
]
=
a
[
i
]
*
i
b
[
i
]
=
a
[
i
]
*
i
return
b
return
b
a
=
tvm
.
placeholder
((
20
,
),
'float32'
)
a
=
tvm
.
placeholder
((
20
,
),
'float32'
)
b
=
downstream
(
a
)
b
=
downstream
(
a
)
c
=
tvm
.
compute
((
20
,
),
lambda
x
:
b
[
x
]
+
1.0
)
c
=
tvm
.
compute
((
20
,
),
lambda
x
:
b
[
x
]
+
1.0
)
sch
=
tvm
.
create_schedule
(
c
.
op
)
sch
=
tvm
.
create_schedule
(
c
.
op
)
module
=
tvm
.
build
(
sch
,
[
a
,
c
])
module
=
tvm
.
build
(
sch
,
[
a
,
c
])
assert
module
assert
module
...
@@ -469,6 +473,40 @@ def test_const_param():
...
@@ -469,6 +473,40 @@ def test_const_param():
tvm
.
testing
.
assert_allclose
(
nd_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
tvm
.
testing
.
assert_allclose
(
nd_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
def
test_value_index
():
@tvm.hybrid.script
def
kernel_a
(
a
):
b
=
output_tensor
((
16
,
),
'int32'
)
c
=
output_tensor
((
4
,
4
),
'int32'
)
for
i
in
range
(
16
):
b
[
i
]
=
a
[
i
]
+
2
c
[
i
//
4
,
i
%
4
]
=
a
[
i
]
+
1
return
b
,
c
@tvm.hybrid.script
def
kernel_b
(
b
,
a
):
c
=
output_tensor
((
4
,
4
),
'int32'
)
for
i
in
range
(
4
):
for
j
in
range
(
4
):
c
[
i
,
j
]
=
a
[
i
*
4
+
j
]
*
b
[
i
,
j
]
return
c
a
=
tvm
.
placeholder
((
16
,
),
'int32'
)
b
,
c
=
kernel_a
(
a
)
d
=
kernel_b
(
c
,
b
)
sch
=
tvm
.
create_schedule
(
d
.
op
)
module
=
tvm
.
build
(
sch
,
[
a
,
d
])
assert
module
np_a
=
numpy
.
arange
(
16
)
.
astype
(
'int32'
)
np_b
,
np_c
=
kernel_a
(
np_a
)
ref
=
kernel_b
(
np_c
,
np_b
)
res
=
tvm
.
ndarray
.
array
(
numpy
.
zeros
((
4
,
4
))
.
astype
(
'int32'
))
module
(
tvm
.
ndarray
.
array
(
np_a
),
res
)
tvm
.
testing
.
assert_allclose
(
res
.
asnumpy
(),
ref
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_outer_product
()
test_outer_product
()
...
@@ -479,9 +517,11 @@ if __name__ == "__main__":
...
@@ -479,9 +517,11 @@ if __name__ == "__main__":
test_math_intrin
()
test_math_intrin
()
test_non_zero
()
test_non_zero
()
test_allocate
()
test_allocate
()
#test_inplace()
test_upstream
()
test_upstream
()
test_downstream
()
test_downstream
()
test_const_param
()
test_const_param
()
test_value_index
()
# TODO:
# test_inplace()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment