Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
76812dea
Commit
76812dea
authored
Feb 22, 2019
by
Jian Weng
Committed by
Tianqi Chen
Feb 22, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix lint (#2649)
parent
dee8cf9b
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
83 additions
and
32 deletions
+83
-32
python/tvm/hybrid/calls.py
+1
-1
python/tvm/hybrid/parser.py
+59
-16
python/tvm/hybrid/preprocessor.py
+0
-0
tests/python/unittest/test_hybrid_script.py
+23
-15
No files found.
python/tvm/hybrid/calls.py
View file @
76812dea
...
@@ -45,8 +45,8 @@ def bind(func_id, args):
...
@@ -45,8 +45,8 @@ def bind(func_id, args):
_internal_assert
(
args
.
__len__
()
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
args
.
__len__
()
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
args
[
0
],
str
),
\
_internal_assert
(
isinstance
(
args
[
0
],
str
),
\
"A loop bind's first argument should be a string!"
)
"A loop bind's first argument should be a string!"
)
iter_var
=
_api
.
thread_axis
(
args
[
0
])
low
,
ext
=
_api
.
const
(
0
,
"int32"
),
args
[
1
]
low
,
ext
=
_api
.
const
(
0
,
"int32"
),
args
[
1
]
iter_var
=
_api
.
thread_axis
((
low
,
ext
),
args
[
0
])
for_type
=
None
for_type
=
None
return
iter_var
,
low
,
ext
,
for_type
return
iter_var
,
low
,
ext
,
for_type
...
...
python/tvm/hybrid/parser.py
View file @
76812dea
...
@@ -12,7 +12,7 @@ from enum import Enum
...
@@ -12,7 +12,7 @@ from enum import Enum
from
.util
import
_internal_assert
from
.util
import
_internal_assert
from
.
import
calls
from
.
import
calls
from
.
import
util
from
.
import
util
from
.
var_decl
import
determine_variable_usage
from
.
preprocessor
import
determine_variable_usage
from
..api
import
all
as
_all
from
..api
import
all
as
_all
from
..api
import
any
as
_any
from
..api
import
any
as
_any
from
..container
import
Array
from
..container
import
Array
...
@@ -61,6 +61,7 @@ class Symbol(Enum):
...
@@ -61,6 +61,7 @@ class Symbol(Enum):
BufferVar
=
7
BufferVar
=
7
LoopVar
=
8
LoopVar
=
8
ConstLoopVar
=
9
ConstLoopVar
=
9
ThreadBind
=
10
class
HybridParser
(
ast
.
NodeVisitor
):
class
HybridParser
(
ast
.
NodeVisitor
):
...
@@ -117,7 +118,10 @@ class HybridParser(ast.NodeVisitor):
...
@@ -117,7 +118,10 @@ class HybridParser(ast.NodeVisitor):
self
.
symbols
=
{}
# Symbol table
self
.
symbols
=
{}
# Symbol table
for
k
,
v
in
symbols
.
items
():
for
k
,
v
in
symbols
.
items
():
if
isinstance
(
v
,
types
.
FunctionType
):
if
isinstance
(
v
,
types
.
FunctionType
):
self
.
symbols
[
k
]
=
Symbol
.
Callable
,
v
self
.
add_symbol
(
k
,
Symbol
.
Callable
,
v
)
self
.
binds
=
{}
# Thread binds
self
.
device
=
0
# Is it generating device
self
.
func_name
=
func_name
# The name of the function to be lowered
self
.
func_name
=
func_name
# The name of the function to be lowered
self
.
outputs
=
[]
# Output tensors' name
self
.
outputs
=
[]
# Output tensors' name
...
@@ -126,6 +130,25 @@ class HybridParser(ast.NodeVisitor):
...
@@ -126,6 +130,25 @@ class HybridParser(ast.NodeVisitor):
self
.
returned
=
False
# If this function has a valid return
self
.
returned
=
False
# If this function has a valid return
def
add_symbol
(
self
,
key
,
ty
,
val
):
#pylint: disable=invalid-name
"""Add value to the symbol table context"""
if
key
in
self
.
symbols
.
keys
():
old
=
str
(
self
.
symbols
[
key
])
new
=
str
((
ty
,
val
))
_internal_assert
(
False
,
"Name conflict in symbol table! [
%
s]
%
s ->
%
s"
%
(
key
,
old
,
new
))
self
.
symbols
[
key
]
=
ty
,
val
if
ty
==
Symbol
.
ThreadBind
:
if
val
.
var
.
name
not
in
self
.
binds
.
keys
():
self
.
binds
[
val
.
var
.
name
]
=
val
return
val_
=
self
.
binds
[
val
.
var
.
name
]
_internal_assert
(
_ir_pass
.
Equal
(
val_
.
dom
.
extent
,
val
.
dom
.
extent
),
"Thread extents should be uniform!"
)
self
.
symbols
[
key
]
=
ty
,
val_
def
wrap_up_realize
(
self
,
node
,
body
):
def
wrap_up_realize
(
self
,
node
,
body
):
"""Wrap up all the variables which will no longer be used"""
"""Wrap up all the variables which will no longer be used"""
...
@@ -141,11 +164,14 @@ class HybridParser(ast.NodeVisitor):
...
@@ -141,11 +164,14 @@ class HybridParser(ast.NodeVisitor):
continue
continue
elif
'Buffer'
in
ty
.
name
:
elif
'Buffer'
in
ty
.
name
:
_buf
=
entry
_buf
=
entry
_scope
=
ty
.
name
[:
-
6
]
.
lower
()
if
ty
is
not
Symbol
.
BufferVar
else
'global'
_scope
=
'global'
if
ty
is
Symbol
.
BufferVar
else
ty
.
name
[:
-
6
]
.
lower
()
to_pop
.
append
(
key
)
to_pop
.
append
(
key
)
else
:
else
:
continue
continue
if
_scope
==
'global'
:
body
=
self
.
wrap_up_binds
(
body
)
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_dtype
=
_buf
.
dtype
_dtype
=
_buf
.
dtype
_true
=
_api
.
convert
(
True
)
_true
=
_api
.
convert
(
True
)
...
@@ -158,6 +184,14 @@ class HybridParser(ast.NodeVisitor):
...
@@ -158,6 +184,14 @@ class HybridParser(ast.NodeVisitor):
return
body
return
body
def
wrap_up_binds
(
self
,
body
):
for
_
,
iter_var
in
self
.
binds
.
items
():
ext
=
iter_var
.
dom
.
extent
body
=
_make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
body
)
self
.
binds
=
{}
return
body
#pylint: disable=invalid-name, missing-docstring
#pylint: disable=invalid-name, missing-docstring
def
visit_Module
(
self
,
node
):
def
visit_Module
(
self
,
node
):
_internal_assert
(
len
(
node
.
body
)
==
1
,
\
_internal_assert
(
len
(
node
.
body
)
==
1
,
\
...
@@ -173,10 +207,10 @@ class HybridParser(ast.NodeVisitor):
...
@@ -173,10 +207,10 @@ class HybridParser(ast.NodeVisitor):
self
.
func_name
=
node
.
name
self
.
func_name
=
node
.
name
for
idx
,
arg
in
enumerate
(
node
.
args
.
args
):
for
idx
,
arg
in
enumerate
(
node
.
args
.
args
):
_attr
=
'id'
if
sys
.
version_info
[
0
]
<
3
else
'arg'
# To make py2 and 3 compatible
_attr
=
'id'
if
sys
.
version_info
[
0
]
<
3
else
'arg'
# To make py2 and 3 compatible
self
.
symbols
[
getattr
(
arg
,
_attr
)]
=
(
Symbol
.
Input
,
self
.
args
[
idx
])
self
.
add_symbol
(
getattr
(
arg
,
_attr
),
Symbol
.
Input
,
self
.
args
[
idx
])
res
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
res
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
res
=
self
.
wrap_up_realize
(
node
,
res
)
res
=
self
.
wrap_up_realize
(
node
,
res
)
return
res
return
self
.
wrap_up_binds
(
res
)
def
visit_Expr
(
self
,
node
):
def
visit_Expr
(
self
,
node
):
...
@@ -189,6 +223,8 @@ class HybridParser(ast.NodeVisitor):
...
@@ -189,6 +223,8 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
name
in
self
.
symbols
,
"Unknown symbol
%
s!"
%
name
)
_internal_assert
(
name
in
self
.
symbols
,
"Unknown symbol
%
s!"
%
name
)
if
ty
in
[
Symbol
.
LoopVar
,
Symbol
.
Input
,
Symbol
.
ConstLoopVar
]:
if
ty
in
[
Symbol
.
LoopVar
,
Symbol
.
Input
,
Symbol
.
ConstLoopVar
]:
return
entry
return
entry
if
ty
is
Symbol
.
ThreadBind
:
return
entry
.
var
if
ty
is
Symbol
.
ConstVar
:
if
ty
is
Symbol
.
ConstVar
:
return
entry
if
isinstance
(
node
.
ctx
,
ast
.
Load
)
else
None
return
entry
if
isinstance
(
node
.
ctx
,
ast
.
Load
)
else
None
if
ty
is
Symbol
.
BufferVar
:
if
ty
is
Symbol
.
BufferVar
:
...
@@ -237,7 +273,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -237,7 +273,7 @@ class HybridParser(ast.NodeVisitor):
for
i
in
range
(
rhs
.
num_outputs
):
for
i
in
range
(
rhs
.
num_outputs
):
_internal_assert
(
isinstance
(
node
.
targets
[
i
],
ast
.
Name
),
_internal_assert
(
isinstance
(
node
.
targets
[
i
],
ast
.
Name
),
"You should bind a pure name to the tensors"
)
"You should bind a pure name to the tensors"
)
self
.
symbols
[
node
.
targets
[
i
]
.
id
]
=
Symbol
.
GlobalBuffer
,
rhs
.
output
(
i
)
self
.
add_symbol
(
node
.
targets
[
i
]
.
id
,
Symbol
.
GlobalBuffer
,
rhs
.
output
(
i
)
)
rmap
[
rhs
.
outputs
[
i
]
.
op
]
=
rhs
.
output
(
i
)
rmap
[
rhs
.
outputs
[
i
]
.
op
]
=
rhs
.
output
(
i
)
return
util
.
replace_io
(
rhs
.
body
,
rmap
)
return
util
.
replace_io
(
rhs
.
body
,
rmap
)
...
@@ -260,15 +296,19 @@ class HybridParser(ast.NodeVisitor):
...
@@ -260,15 +296,19 @@ class HybridParser(ast.NodeVisitor):
if
isinstance
(
rhs
,
tuple
):
if
isinstance
(
rhs
,
tuple
):
shape
,
dtype
,
scope
=
rhs
shape
,
dtype
,
scope
=
rhs
ph
=
_api
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
lhs
)
ph
=
_api
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
lhs
)
self
.
symbols
[
lhs
]
=
getattr
(
Symbol
,
scope
.
title
()
+
"Buffer"
),
ph
self
.
add_symbol
(
lhs
,
getattr
(
Symbol
,
scope
.
title
()
+
"Buffer"
),
ph
)
if
scope
==
'output'
:
if
scope
==
'output'
:
self
.
outputs
.
append
(
lhs
)
self
.
outputs
.
append
(
lhs
)
return
util
.
make_nop
()
return
util
.
make_nop
()
if
isinstance
(
rhs
,
util
.
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
if
isinstance
(
rhs
,
util
.
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
self
.
symbols
[
lhs
]
=
Symbol
.
ConstVar
,
rhs
self
.
add_symbol
(
lhs
,
Symbol
.
ConstVar
,
rhs
)
else
:
else
:
_internal_assert
(
self
.
device
==
0
,
"Single variable not supported in devices' side!
\n
"
+
\
"If you are using GPU, please allocate a 'local' spad "
+
\
"outside the bind body"
)
ph
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
ph
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
self
.
symbols
[
lhs
]
=
Symbol
.
BufferVar
,
ph
self
.
add_symbol
(
lhs
,
Symbol
.
BufferVar
,
ph
)
lhs
=
self
.
visit
(
lhs_
)
lhs
=
self
.
visit
(
lhs_
)
if
lhs
is
not
None
:
if
lhs
is
not
None
:
buf
,
args
=
lhs
buf
,
args
=
lhs
...
@@ -356,7 +396,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -356,7 +396,7 @@ class HybridParser(ast.NodeVisitor):
if
node
.
orelse
:
if
node
.
orelse
:
else_body
=
visit_list_to_block
(
self
.
visit
,
node
.
orelse
)
else_body
=
visit_list_to_block
(
self
.
visit
,
node
.
orelse
)
else
:
else
:
else_body
=
util
.
make_nop
()
else_body
=
None
return
_make
.
IfThenElse
(
cond
,
if_body
,
else_body
)
return
_make
.
IfThenElse
(
cond
,
if_body
,
else_body
)
...
@@ -445,28 +485,31 @@ class HybridParser(ast.NodeVisitor):
...
@@ -445,28 +485,31 @@ class HybridParser(ast.NodeVisitor):
bodies
=
[]
bodies
=
[]
for
i
in
range
(
low
,
low
+
ext
):
for
i
in
range
(
low
,
low
+
ext
):
self
.
symbols
[
_name
]
=
Symbol
.
ConstLoopVar
,
i
self
.
add_symbol
(
_name
,
Symbol
.
ConstLoopVar
,
i
)
body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
body
=
self
.
wrap_up_realize
(
node
,
body
)
body
=
self
.
wrap_up_realize
(
node
,
body
)
bodies
.
append
(
body
)
bodies
.
append
(
body
)
self
.
symbols
.
pop
(
_name
)
return
concat_list_to_block
(
bodies
)
return
concat_list_to_block
(
bodies
)
if
iter_var
is
None
:
if
iter_var
is
None
:
_internal_assert
(
for_type
is
not
None
,
"The loop
bind
function parse error!"
)
_internal_assert
(
for_type
is
not
None
,
"The loop
iterating
function parse error!"
)
offset
=
iter_var
=
_api
.
var
(
_name
)
offset
=
iter_var
=
_api
.
var
(
_name
)
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
'int32'
)):
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
'int32'
)):
offset
=
iter_var
+
low
offset
=
iter_var
+
low
self
.
symbols
[
_name
]
=
Symbol
.
LoopVar
,
offset
self
.
add_symbol
(
_name
,
Symbol
.
LoopVar
,
offset
)
_body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
_body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
else
:
else
:
_internal_assert
(
for_type
is
None
,
"The loop iterating function parse error!"
)
_internal_assert
(
for_type
is
None
,
"The loop bind function parse error!"
)
self
.
symbols
[
_name
]
=
Symbol
.
LoopVar
,
iter_var
.
var
self
.
add_symbol
(
_name
,
Symbol
.
ThreadBind
,
iter_var
)
self
.
device
+=
1
_body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
_body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
self
.
device
-=
1
_body
=
self
.
wrap_up_realize
(
node
,
_body
)
_body
=
self
.
wrap_up_realize
(
node
,
_body
)
if
for_type
is
None
:
if
for_type
is
None
:
res
=
_
make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
_body
)
res
=
_
body
else
:
else
:
_internal_assert
(
not
isinstance
(
for_type
,
tuple
),
\
_internal_assert
(
not
isinstance
(
for_type
,
tuple
),
\
"Micro expansion should be handled before!"
)
"Micro expansion should be handled before!"
)
...
...
python/tvm/hybrid/
var_decl
.py
→
python/tvm/hybrid/
preprocessor
.py
View file @
76812dea
File moved
tests/python/unittest/test_hybrid_script.py
View file @
76812dea
...
@@ -300,6 +300,7 @@ def test_bind():
...
@@ -300,6 +300,7 @@ def test_bind():
if
not
tvm
.
gpu
(
0
)
.
exist
:
if
not
tvm
.
gpu
(
0
)
.
exist
:
print
(
'[Warning] No GPU found! Skip bind test!'
)
print
(
'[Warning] No GPU found! Skip bind test!'
)
return
return
@script
@script
def
vec_add
(
a
,
b
):
def
vec_add
(
a
,
b
):
c
=
output_tensor
((
1000
,
),
'float32'
)
c
=
output_tensor
((
1000
,
),
'float32'
)
...
@@ -326,23 +327,29 @@ def test_bind():
...
@@ -326,23 +327,29 @@ def test_bind():
func
,
ins
,
outs
=
run_and_check
(
raw
,
[
a
,
b
],
sch
=
sch
,
outs
=
[
c
],
target
=
'cuda'
)
func
,
ins
,
outs
=
run_and_check
(
raw
,
[
a
,
b
],
sch
=
sch
,
outs
=
[
c
],
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
# Test loop binds
@tvm.hybrid.script
@tvm.hybrid.script
def
goo
(
a
,
b
):
def
foo
(
a
):
c
=
output_tensor
(
a
.
shape
,
a
.
dtype
)
c
=
output_tensor
((
a
.
shape
[
0
],),
a
.
dtype
)
len_b
=
len
(
b
)
total
=
allocate
((
1
,),
a
.
dtype
,
'local'
)
for
i
in
const_range
(
len_b
*
2
):
len_i
=
a
.
shape
[
0
]
if
i
<
len_b
:
len_j
=
a
.
shape
[
1
]
c
[
i
]
=
a
[
i
]
+
b
[
i
]
for
i
in
bind
(
'threadIdx.x'
,
len_i
):
else
:
total
[
0
]
=
0.
c
[
i
-
len_b
]
=
a
[
i
-
len_b
]
+
b
[
i
-
len_b
]
for
k
in
const_range
(
len_j
):
total
[
0
]
+=
a
[
i
,
k
]
c
[
i
]
=
total
[
0
]
return
c
return
c
a
=
tvm
.
placeholder
((
5
,
),
name
=
'a'
,
dtype
=
'int32'
)
b
=
[
1
,
2
,
3
,
4
,
5
]
a
=
tvm
.
placeholder
((
8
,
4
),
'float32'
)
c
=
goo
(
a
,
tvm
.
convert
(
b
))
c
=
foo
(
a
)
sch
=
tvm
.
create_schedule
(
c
.
op
)
s
=
tvm
.
create_schedule
(
c
.
op
)
func
,
ins
,
outs
=
run_and_check
(
goo
,
[
a
,
b
],
sch
=
sch
,
outs
=
[
c
])
ir
=
tvm
.
lower
(
s
,
[
a
,
c
],
simple_mode
=
True
)
run_and_check
(
func
,
ins
,
outs
=
outs
)
assert
not
isinstance
(
ir
,
tvm
.
stmt
.
AttrStmt
)
func
,
ins
,
outs
=
run_and_check
(
foo
,
[
a
],
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
def
test_math_intrin
():
def
test_math_intrin
():
@script
@script
...
@@ -455,6 +462,7 @@ def test_allocate():
...
@@ -455,6 +462,7 @@ def test_allocate():
a
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'a'
)
a
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'a'
)
b
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'b'
)
b
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'b'
)
c
=
share_vec_add
(
a
,
b
)
func
,
ins
,
outs
=
run_and_check
(
share_vec_add
,
[
a
,
b
],
target
=
'cuda'
)
func
,
ins
,
outs
=
run_and_check
(
share_vec_add
,
[
a
,
b
],
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment