Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b8fedfb1
Commit
b8fedfb1
authored
Jul 20, 2018
by
Jian Weng
Committed by
Tianqi Chen
Jul 20, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FRONTEND] [HYBRID] Augmented assign operator supported! (#1459)
parent
84eea572
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
26 deletions
+67
-26
python/tvm/hybrid/parser.py
+8
-1
python/tvm/hybrid/var_decl.py
+9
-0
tests/python/unittest/test_hybrid_script.py
+50
-25
No files found.
python/tvm/hybrid/parser.py
View file @
b8fedfb1
...
@@ -15,7 +15,7 @@ from .. import ir_pass as _ir_pass
...
@@ -15,7 +15,7 @@ from .. import ir_pass as _ir_pass
def
list_to_block
(
visit
,
lst
):
def
list_to_block
(
visit
,
lst
):
"""Convert a list of Python IR nodes to HalideIR Block"""
"""Convert a list of Python IR nodes to HalideIR Block"""
lst
=
list
(
map
(
visit
,
lst
))
lst
=
[
visit
(
i
)
for
i
in
lst
]
lst
=
[
stmt
for
stmt
in
lst
if
not
_ir_pass
.
Equal
(
stmt
,
make_nop
())]
lst
=
[
stmt
for
stmt
in
lst
if
not
_ir_pass
.
Equal
(
stmt
,
make_nop
())]
if
not
lst
:
if
not
lst
:
return
make_nop
()
return
make_nop
()
...
@@ -162,6 +162,13 @@ class HybridParser(ast.NodeVisitor):
...
@@ -162,6 +162,13 @@ class HybridParser(ast.NodeVisitor):
def
visit_Num
(
self
,
node
):
def
visit_Num
(
self
,
node
):
return
_api
.
const
(
node
.
n
)
return
_api
.
const
(
node
.
n
)
def
visit_AugAssign
(
self
,
node
):
lhs
=
self
.
visit
(
node
.
target
)
rhs
=
self
.
visit
(
node
.
value
)
rhs
=
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
lhs
,
rhs
)
if
not
isinstance
(
lhs
,
_expr
.
Call
):
raise
ValueError
(
"The LHS of an AugAssign is supposed to be a call!"
)
return
_make
.
Provide
(
lhs
.
func
,
0
,
rhs
,
lhs
.
args
)
def
visit_Assign
(
self
,
node
):
def
visit_Assign
(
self
,
node
):
if
len
(
node
.
targets
)
!=
1
:
if
len
(
node
.
targets
)
!=
1
:
...
...
python/tvm/hybrid/var_decl.py
View file @
b8fedfb1
...
@@ -14,6 +14,7 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -14,6 +14,7 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
scope_level
=
[]
self
.
scope_level
=
[]
self
.
_args
=
{}
self
.
_args
=
{}
self
.
args
=
args
self
.
args
=
args
self
.
aug_assign_
=
False
def
visit_FunctionDef
(
self
,
node
):
def
visit_FunctionDef
(
self
,
node
):
...
@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
visit
(
elem
)
self
.
visit
(
elem
)
def
visit_AugAssign
(
self
,
node
):
self
.
aug_assign_
=
True
self
.
generic_visit
(
node
)
self
.
aug_assign_
=
False
def
visit_Name
(
self
,
node
):
def
visit_Name
(
self
,
node
):
# If it is from the argument list or loop variable, we do not worry about it!
# If it is from the argument list or loop variable, we do not worry about it!
if
node
.
id
in
self
.
_args
.
keys
():
if
node
.
id
in
self
.
_args
.
keys
():
...
@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor):
if
node
.
id
not
in
self
.
status
.
keys
():
if
node
.
id
not
in
self
.
status
.
keys
():
if
not
isinstance
(
node
.
ctx
,
ast
.
Store
):
if
not
isinstance
(
node
.
ctx
,
ast
.
Store
):
raise
ValueError
(
'In Python, "first store" indicates "declaration"'
)
raise
ValueError
(
'In Python, "first store" indicates "declaration"'
)
if
self
.
aug_assign_
:
raise
ValueError
(
'"First store" cannot be an AugAssign'
)
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
else
:
else
:
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
...
...
tests/python/unittest/test_hybrid_script.py
View file @
b8fedfb1
...
@@ -38,7 +38,9 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'):
...
@@ -38,7 +38,9 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'):
module
(
*
nd_args
)
module
(
*
nd_args
)
for
nd
,
np
in
to_check
:
for
nd
,
np
in
to_check
:
numpy
.
testing
.
assert_allclose
(
nd
.
asnumpy
(),
np
,
rtol
=
1e-5
,
atol
=
1e-5
)
numpy
.
testing
.
assert_allclose
(
nd
.
asnumpy
(),
np
,
rtol
=
1e-3
,
atol
=
1e-3
)
return
module
@script
@script
...
@@ -83,7 +85,7 @@ def test_outer_product():
...
@@ -83,7 +85,7 @@ def test_outer_product():
func
=
tvm
.
lower
(
ir
,
[
n
,
m
,
a
,
b
,
c
])
func
=
tvm
.
lower
(
ir
,
[
n
,
m
,
a
,
b
,
c
])
func
=
tvm
.
build
(
func
)
func
=
tvm
.
build
(
func
)
run_and_check
(
outer_product
,
[
n
,
m
,
a
,
b
,
c
],
[
c
],
{
n
:
99
9
,
m
:
10
01
})
run_and_check
(
outer_product
,
[
n
,
m
,
a
,
b
,
c
],
[
c
],
{
n
:
99
,
m
:
1
01
})
for
key
,
_
in
HYBRID_GLOBALS
.
items
():
for
key
,
_
in
HYBRID_GLOBALS
.
items
():
assert
key
not
in
globals
()
.
keys
()
assert
key
not
in
globals
()
.
keys
()
...
@@ -165,20 +167,32 @@ def test_fanout():
...
@@ -165,20 +167,32 @@ def test_fanout():
run_and_check
(
fanout
,
[
n
,
a
,
b
],
[
b
],
{
n
:
10
})
run_and_check
(
fanout
,
[
n
,
a
,
b
],
[
b
],
{
n
:
10
})
@script
def
failure
():
for
i
in
range
(
1
,
100
):
i
=
0
def
test_failure
():
def
test_failure
():
try
:
try
:
@script
def
failure
():
for
i
in
range
(
1
,
100
):
i
=
0
tvm
.
hybrid
.
parse
(
failure
,
[])
tvm
.
hybrid
.
parse
(
failure
,
[])
except
IOError
as
err
:
except
IOError
as
err
:
assert
sys
.
version_info
[
0
]
==
2
assert
sys
.
version_info
[
0
]
==
2
print
(
'[Warning] Case test_failure is skipped by Python2 because "
%
s"'
%
str
(
err
))
print
(
'[Warning] Case test_failure
.0
is skipped by Python2 because "
%
s"'
%
str
(
err
))
except
Exception
as
err
:
except
ValueError
as
err
:
assert
str
(
err
)
==
'You CAN NEVER overwrite a loop variable!'
assert
str
(
err
)
==
'You CAN NEVER overwrite a loop variable!'
try
:
@tvm.hybrid.script
def
augdefine
():
for
i
in
range
(
10
):
es
+=
0
tvm
.
hybrid
.
parse
(
augdefine
,
[])
except
IOError
as
err
:
assert
sys
.
version_info
[
0
]
==
2
print
(
'[Warning] Case test_failure.1 is skipped by Python2 because "
%
s"'
%
str
(
err
))
except
ValueError
as
err
:
assert
str
(
err
)
==
'"First store" cannot be an AugAssign'
def
test_looptype
():
def
test_looptype
():
@script
@script
...
@@ -280,7 +294,7 @@ def test_non_zero():
...
@@ -280,7 +294,7 @@ def test_non_zero():
s
=
0.0
s
=
0.0
for
di
in
range
(
3
):
for
di
in
range
(
3
):
for
dj
in
range
(
3
):
for
dj
in
range
(
3
):
s
=
s
+
a
[
i
-
di
,
j
-
dj
]
s
+=
a
[
i
-
di
,
j
-
dj
]
b
[
i
-
2
,
j
-
2
]
=
s
/
9.0
b
[
i
-
2
,
j
-
2
]
=
s
/
9.0
try
:
try
:
a
=
tvm
.
placeholder
((
32
,
32
),
'float32'
,
'a'
)
a
=
tvm
.
placeholder
((
32
,
32
),
'float32'
,
'a'
)
...
@@ -315,29 +329,39 @@ def test_allocate():
...
@@ -315,29 +329,39 @@ def test_allocate():
a
=
tvm
.
placeholder
((
32
,
32
),
'float32'
,
'a'
)
a
=
tvm
.
placeholder
((
32
,
32
),
'float32'
,
'a'
)
b
=
tvm
.
placeholder
((
30
,
30
),
'float32'
,
'b'
)
b
=
tvm
.
placeholder
((
30
,
30
),
'float32'
,
'b'
)
run_and_check
(
blur2d
,
[
a
,
b
],
[
b
])
run_and_check
(
blur2d
,
[
a
,
b
],
[
b
])
if
tvm
.
gpu
()
.
exist
:
if
tvm
.
gpu
()
.
exist
:
@tvm.hybrid.script
@tvm.hybrid.script
def
share_vec_add
(
a
,
b
,
c
):
def
shared_gemm
(
a
,
b
,
c
):
shared
=
allocate
((
256
,
),
'float32'
,
'shared'
)
for
io
in
bind
(
'blockIdx.x'
,
8
):
for
i
in
bind
(
"threadIdx.x"
,
256
):
for
ii
in
bind
(
'blockIdx.y'
,
8
):
shared
[
i
]
=
a
[
i
]
shared_b
=
allocate
((
64
,
64
),
'float32'
,
'shared'
)
local
=
allocate
((
256
,
),
'float32'
,
'local'
)
for
k
in
range
(
64
):
for
i
in
bind
(
"threadIdx.x"
,
256
):
shared_b
[
io
*
8
+
ii
,
k
]
=
b
[
io
*
8
+
ii
,
k
]
local
[
i
]
=
b
[
i
]
for
jo
in
bind
(
'threadIdx.y'
,
8
):
for
i
in
bind
(
"threadIdx.x"
,
256
):
for
ji
in
bind
(
'threadIdx.x'
,
8
):
c
[
i
]
=
shared
[
i
]
+
local
[
i
]
for
k
in
range
(
64
):
c
[
io
*
8
+
ii
,
jo
*
8
+
ji
]
+=
a
[
io
*
8
+
ii
,
k
]
*
shared_b
[
k
,
jo
*
8
+
ji
]
a
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'a'
)
b
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'b'
)
a
=
tvm
.
placeholder
((
64
,
64
),
dtype
=
'float32'
,
name
=
'a'
)
c
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'c'
)
b
=
tvm
.
placeholder
((
64
,
64
),
dtype
=
'float32'
,
name
=
'b'
)
run_and_check
(
share_vec_add
,
[
a
,
b
,
c
],
[
c
],
target
=
'cuda'
)
c
=
tvm
.
placeholder
((
64
,
64
),
dtype
=
'float32'
,
name
=
'c'
)
module
=
run_and_check
(
shared_gemm
,
[
a
,
b
,
c
],
[
c
],
target
=
'cuda'
)
assert
"__syncthreads()"
in
module
.
imported_modules
[
0
]
.
get_source
()
else
:
else
:
print
(
'[Warning] No GPU found! Skip shared mem test!'
)
print
(
'[Warning] No GPU found! Skip shared mem test!'
)
def
test_augassign
():
@tvm.hybrid.script
def
augassign
(
a
):
for
i
in
range
(
a
.
shape
[
0
]):
a
[
i
]
+=
1.0
a
=
tvm
.
placeholder
((
16
,
),
dtype
=
'float32'
,
name
=
'a'
)
run_and_check
(
augassign
,
[
a
],
[
a
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_outer_product
()
test_outer_product
()
test_fanout
()
test_fanout
()
...
@@ -348,4 +372,5 @@ if __name__ == "__main__":
...
@@ -348,4 +372,5 @@ if __name__ == "__main__":
test_math_intrin
()
test_math_intrin
()
test_non_zero
()
test_non_zero
()
test_allocate
()
test_allocate
()
test_augassign
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment