Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
76812dea
Commit
76812dea
authored
Feb 22, 2019
by
Jian Weng
Committed by
Tianqi Chen
Feb 22, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix lint (#2649)
parent
dee8cf9b
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
83 additions
and
32 deletions
+83
-32
python/tvm/hybrid/calls.py
+1
-1
python/tvm/hybrid/parser.py
+59
-16
python/tvm/hybrid/preprocessor.py
+0
-0
tests/python/unittest/test_hybrid_script.py
+23
-15
No files found.
python/tvm/hybrid/calls.py
View file @
76812dea
...
...
@@ -45,8 +45,8 @@ def bind(func_id, args):
_internal_assert
(
args
.
__len__
()
==
2
,
"A loop bind should only have 2 arguments!"
)
_internal_assert
(
isinstance
(
args
[
0
],
str
),
\
"A loop bind's first argument should be a string!"
)
iter_var
=
_api
.
thread_axis
(
args
[
0
])
low
,
ext
=
_api
.
const
(
0
,
"int32"
),
args
[
1
]
iter_var
=
_api
.
thread_axis
((
low
,
ext
),
args
[
0
])
for_type
=
None
return
iter_var
,
low
,
ext
,
for_type
...
...
python/tvm/hybrid/parser.py
View file @
76812dea
...
...
@@ -12,7 +12,7 @@ from enum import Enum
from
.util
import
_internal_assert
from
.
import
calls
from
.
import
util
from
.
var_decl
import
determine_variable_usage
from
.
preprocessor
import
determine_variable_usage
from
..api
import
all
as
_all
from
..api
import
any
as
_any
from
..container
import
Array
...
...
@@ -61,6 +61,7 @@ class Symbol(Enum):
BufferVar
=
7
LoopVar
=
8
ConstLoopVar
=
9
ThreadBind
=
10
class
HybridParser
(
ast
.
NodeVisitor
):
...
...
@@ -117,7 +118,10 @@ class HybridParser(ast.NodeVisitor):
self
.
symbols
=
{}
# Symbol table
for
k
,
v
in
symbols
.
items
():
if
isinstance
(
v
,
types
.
FunctionType
):
self
.
symbols
[
k
]
=
Symbol
.
Callable
,
v
self
.
add_symbol
(
k
,
Symbol
.
Callable
,
v
)
self
.
binds
=
{}
# Thread binds
self
.
device
=
0
# Is it generating device
self
.
func_name
=
func_name
# The name of the function to be lowered
self
.
outputs
=
[]
# Output tensors' name
...
...
@@ -126,6 +130,25 @@ class HybridParser(ast.NodeVisitor):
self
.
returned
=
False
# If this function has a valid return
def
add_symbol
(
self
,
key
,
ty
,
val
):
#pylint: disable=invalid-name
"""Add value to the symbol table context"""
if
key
in
self
.
symbols
.
keys
():
old
=
str
(
self
.
symbols
[
key
])
new
=
str
((
ty
,
val
))
_internal_assert
(
False
,
"Name conflict in symbol table! [
%
s]
%
s ->
%
s"
%
(
key
,
old
,
new
))
self
.
symbols
[
key
]
=
ty
,
val
if
ty
==
Symbol
.
ThreadBind
:
if
val
.
var
.
name
not
in
self
.
binds
.
keys
():
self
.
binds
[
val
.
var
.
name
]
=
val
return
val_
=
self
.
binds
[
val
.
var
.
name
]
_internal_assert
(
_ir_pass
.
Equal
(
val_
.
dom
.
extent
,
val
.
dom
.
extent
),
"Thread extents should be uniform!"
)
self
.
symbols
[
key
]
=
ty
,
val_
def
wrap_up_realize
(
self
,
node
,
body
):
"""Wrap up all the variables which will no longer be used"""
...
...
@@ -141,11 +164,14 @@ class HybridParser(ast.NodeVisitor):
continue
elif
'Buffer'
in
ty
.
name
:
_buf
=
entry
_scope
=
ty
.
name
[:
-
6
]
.
lower
()
if
ty
is
not
Symbol
.
BufferVar
else
'global'
_scope
=
'global'
if
ty
is
Symbol
.
BufferVar
else
ty
.
name
[:
-
6
]
.
lower
()
to_pop
.
append
(
key
)
else
:
continue
if
_scope
==
'global'
:
body
=
self
.
wrap_up_binds
(
body
)
_domain
=
[
_make
.
range_by_min_extent
(
0
,
i
)
for
i
in
_buf
.
shape
]
_dtype
=
_buf
.
dtype
_true
=
_api
.
convert
(
True
)
...
...
@@ -158,6 +184,14 @@ class HybridParser(ast.NodeVisitor):
return
body
def
wrap_up_binds
(
self
,
body
):
for
_
,
iter_var
in
self
.
binds
.
items
():
ext
=
iter_var
.
dom
.
extent
body
=
_make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
body
)
self
.
binds
=
{}
return
body
#pylint: disable=invalid-name, missing-docstring
def
visit_Module
(
self
,
node
):
_internal_assert
(
len
(
node
.
body
)
==
1
,
\
...
...
@@ -173,10 +207,10 @@ class HybridParser(ast.NodeVisitor):
self
.
func_name
=
node
.
name
for
idx
,
arg
in
enumerate
(
node
.
args
.
args
):
_attr
=
'id'
if
sys
.
version_info
[
0
]
<
3
else
'arg'
# To make py2 and 3 compatible
self
.
symbols
[
getattr
(
arg
,
_attr
)]
=
(
Symbol
.
Input
,
self
.
args
[
idx
])
self
.
add_symbol
(
getattr
(
arg
,
_attr
),
Symbol
.
Input
,
self
.
args
[
idx
])
res
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
res
=
self
.
wrap_up_realize
(
node
,
res
)
return
res
return
self
.
wrap_up_binds
(
res
)
def
visit_Expr
(
self
,
node
):
...
...
@@ -189,6 +223,8 @@ class HybridParser(ast.NodeVisitor):
_internal_assert
(
name
in
self
.
symbols
,
"Unknown symbol
%
s!"
%
name
)
if
ty
in
[
Symbol
.
LoopVar
,
Symbol
.
Input
,
Symbol
.
ConstLoopVar
]:
return
entry
if
ty
is
Symbol
.
ThreadBind
:
return
entry
.
var
if
ty
is
Symbol
.
ConstVar
:
return
entry
if
isinstance
(
node
.
ctx
,
ast
.
Load
)
else
None
if
ty
is
Symbol
.
BufferVar
:
...
...
@@ -237,7 +273,7 @@ class HybridParser(ast.NodeVisitor):
for
i
in
range
(
rhs
.
num_outputs
):
_internal_assert
(
isinstance
(
node
.
targets
[
i
],
ast
.
Name
),
"You should bind a pure name to the tensors"
)
self
.
symbols
[
node
.
targets
[
i
]
.
id
]
=
Symbol
.
GlobalBuffer
,
rhs
.
output
(
i
)
self
.
add_symbol
(
node
.
targets
[
i
]
.
id
,
Symbol
.
GlobalBuffer
,
rhs
.
output
(
i
)
)
rmap
[
rhs
.
outputs
[
i
]
.
op
]
=
rhs
.
output
(
i
)
return
util
.
replace_io
(
rhs
.
body
,
rmap
)
...
...
@@ -260,15 +296,19 @@ class HybridParser(ast.NodeVisitor):
if
isinstance
(
rhs
,
tuple
):
shape
,
dtype
,
scope
=
rhs
ph
=
_api
.
placeholder
(
shape
,
dtype
=
dtype
,
name
=
lhs
)
self
.
symbols
[
lhs
]
=
getattr
(
Symbol
,
scope
.
title
()
+
"Buffer"
),
ph
self
.
add_symbol
(
lhs
,
getattr
(
Symbol
,
scope
.
title
()
+
"Buffer"
),
ph
)
if
scope
==
'output'
:
self
.
outputs
.
append
(
lhs
)
return
util
.
make_nop
()
if
isinstance
(
rhs
,
util
.
halide_imm_types
)
and
ast
.
Store
not
in
rw
:
self
.
symbols
[
lhs
]
=
Symbol
.
ConstVar
,
rhs
self
.
add_symbol
(
lhs
,
Symbol
.
ConstVar
,
rhs
)
else
:
_internal_assert
(
self
.
device
==
0
,
"Single variable not supported in devices' side!
\n
"
+
\
"If you are using GPU, please allocate a 'local' spad "
+
\
"outside the bind body"
)
ph
=
_api
.
placeholder
((
1
,
),
dtype
=
rhs
.
dtype
,
name
=
lhs
)
self
.
symbols
[
lhs
]
=
Symbol
.
BufferVar
,
ph
self
.
add_symbol
(
lhs
,
Symbol
.
BufferVar
,
ph
)
lhs
=
self
.
visit
(
lhs_
)
if
lhs
is
not
None
:
buf
,
args
=
lhs
...
...
@@ -356,7 +396,7 @@ class HybridParser(ast.NodeVisitor):
if
node
.
orelse
:
else_body
=
visit_list_to_block
(
self
.
visit
,
node
.
orelse
)
else
:
else_body
=
util
.
make_nop
()
else_body
=
None
return
_make
.
IfThenElse
(
cond
,
if_body
,
else_body
)
...
...
@@ -445,28 +485,31 @@ class HybridParser(ast.NodeVisitor):
bodies
=
[]
for
i
in
range
(
low
,
low
+
ext
):
self
.
symbols
[
_name
]
=
Symbol
.
ConstLoopVar
,
i
self
.
add_symbol
(
_name
,
Symbol
.
ConstLoopVar
,
i
)
body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
body
=
self
.
wrap_up_realize
(
node
,
body
)
bodies
.
append
(
body
)
self
.
symbols
.
pop
(
_name
)
return
concat_list_to_block
(
bodies
)
if
iter_var
is
None
:
_internal_assert
(
for_type
is
not
None
,
"The loop
bind
function parse error!"
)
_internal_assert
(
for_type
is
not
None
,
"The loop
iterating
function parse error!"
)
offset
=
iter_var
=
_api
.
var
(
_name
)
if
not
_ir_pass
.
Equal
(
low
,
_api
.
const
(
0
,
'int32'
)):
offset
=
iter_var
+
low
self
.
symbols
[
_name
]
=
Symbol
.
LoopVar
,
offset
self
.
add_symbol
(
_name
,
Symbol
.
LoopVar
,
offset
)
_body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
else
:
_internal_assert
(
for_type
is
None
,
"The loop iterating function parse error!"
)
self
.
symbols
[
_name
]
=
Symbol
.
LoopVar
,
iter_var
.
var
_internal_assert
(
for_type
is
None
,
"The loop bind function parse error!"
)
self
.
add_symbol
(
_name
,
Symbol
.
ThreadBind
,
iter_var
)
self
.
device
+=
1
_body
=
visit_list_to_block
(
self
.
visit
,
node
.
body
)
self
.
device
-=
1
_body
=
self
.
wrap_up_realize
(
node
,
_body
)
if
for_type
is
None
:
res
=
_
make
.
AttrStmt
(
iter_var
,
'thread_extent'
,
ext
,
_body
)
res
=
_
body
else
:
_internal_assert
(
not
isinstance
(
for_type
,
tuple
),
\
"Micro expansion should be handled before!"
)
...
...
python/tvm/hybrid/
var_decl
.py
→
python/tvm/hybrid/
preprocessor
.py
View file @
76812dea
File moved
tests/python/unittest/test_hybrid_script.py
View file @
76812dea
...
...
@@ -300,6 +300,7 @@ def test_bind():
if
not
tvm
.
gpu
(
0
)
.
exist
:
print
(
'[Warning] No GPU found! Skip bind test!'
)
return
@script
def
vec_add
(
a
,
b
):
c
=
output_tensor
((
1000
,
),
'float32'
)
...
...
@@ -326,23 +327,29 @@ def test_bind():
func
,
ins
,
outs
=
run_and_check
(
raw
,
[
a
,
b
],
sch
=
sch
,
outs
=
[
c
],
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
# Test loop binds
@tvm.hybrid.script
def
goo
(
a
,
b
):
c
=
output_tensor
(
a
.
shape
,
a
.
dtype
)
len_b
=
len
(
b
)
for
i
in
const_range
(
len_b
*
2
):
if
i
<
len_b
:
c
[
i
]
=
a
[
i
]
+
b
[
i
]
else
:
c
[
i
-
len_b
]
=
a
[
i
-
len_b
]
+
b
[
i
-
len_b
]
def
foo
(
a
):
c
=
output_tensor
((
a
.
shape
[
0
],),
a
.
dtype
)
total
=
allocate
((
1
,),
a
.
dtype
,
'local'
)
len_i
=
a
.
shape
[
0
]
len_j
=
a
.
shape
[
1
]
for
i
in
bind
(
'threadIdx.x'
,
len_i
):
total
[
0
]
=
0.
for
k
in
const_range
(
len_j
):
total
[
0
]
+=
a
[
i
,
k
]
c
[
i
]
=
total
[
0
]
return
c
a
=
tvm
.
placeholder
((
5
,
),
name
=
'a'
,
dtype
=
'int32'
)
b
=
[
1
,
2
,
3
,
4
,
5
]
c
=
goo
(
a
,
tvm
.
convert
(
b
))
sch
=
tvm
.
create_schedule
(
c
.
op
)
func
,
ins
,
outs
=
run_and_check
(
goo
,
[
a
,
b
],
sch
=
sch
,
outs
=
[
c
])
run_and_check
(
func
,
ins
,
outs
=
outs
)
a
=
tvm
.
placeholder
((
8
,
4
),
'float32'
)
c
=
foo
(
a
)
s
=
tvm
.
create_schedule
(
c
.
op
)
ir
=
tvm
.
lower
(
s
,
[
a
,
c
],
simple_mode
=
True
)
assert
not
isinstance
(
ir
,
tvm
.
stmt
.
AttrStmt
)
func
,
ins
,
outs
=
run_and_check
(
foo
,
[
a
],
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
def
test_math_intrin
():
@script
...
...
@@ -455,6 +462,7 @@ def test_allocate():
a
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'a'
)
b
=
tvm
.
placeholder
((
256
,
),
dtype
=
'float32'
,
name
=
'b'
)
c
=
share_vec_add
(
a
,
b
)
func
,
ins
,
outs
=
run_and_check
(
share_vec_add
,
[
a
,
b
],
target
=
'cuda'
)
run_and_check
(
func
,
ins
,
outs
=
outs
,
target
=
'cuda'
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment