Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3b3b8cbe
Commit
3b3b8cbe
authored
Dec 07, 2018
by
Jian Weng
Committed by
Tianqi Chen
Dec 07, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
allows constant param in op construct (#2257)
parent
d50f7b66
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
20 deletions
+63
-20
python/tvm/hybrid/parser.py
+10
-1
python/tvm/hybrid/util.py
+14
-16
python/tvm/hybrid/var_decl.py
+9
-0
tests/python/unittest/test_hybrid_script.py
+30
-3
No files found.
python/tvm/hybrid/parser.py
View file @
3b3b8cbe
...
@@ -144,7 +144,7 @@ class HybridParser(ast.NodeVisitor):
...
@@ -144,7 +144,7 @@ class HybridParser(ast.NodeVisitor):
def
visit_Name
(
self
,
node
):
def
visit_Name
(
self
,
node
):
_id
=
node
.
id
_id
=
node
.
id
if
_id
in
self
.
_args
.
keys
()
and
isinstance
(
self
.
_args
[
_id
],
_expr
.
Var
):
if
_id
in
self
.
_args
.
keys
()
and
isinstance
(
self
.
_args
[
_id
],
(
_expr
.
Var
,
_expr
.
ConstExpr
)
):
return
self
.
_args
[
_id
]
return
self
.
_args
[
_id
]
elif
_id
in
self
.
loops_above
.
keys
():
elif
_id
in
self
.
loops_above
.
keys
():
return
self
.
loops_above
[
_id
]
return
self
.
loops_above
[
_id
]
...
@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor):
...
@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor):
return
_api
.
const
(
node
.
n
)
return
_api
.
const
(
node
.
n
)
def
visit_AugAssign
(
self
,
node
):
lhs
=
self
.
visit
(
node
.
target
)
rhs
=
self
.
visit
(
node
.
value
)
rhs
=
HybridParser
.
_binop_maker
[
type
(
node
.
op
)](
lhs
,
rhs
)
_internal_assert
(
isinstance
(
lhs
,
_expr
.
Call
),
\
"The LHS of an AugAssign is supposed to be a call!"
)
return
_make
.
Provide
(
lhs
.
func
,
0
,
rhs
,
lhs
.
args
)
def
visit_Assign
(
self
,
node
):
def
visit_Assign
(
self
,
node
):
_internal_assert
(
len
(
node
.
targets
)
==
1
,
"So far only one-valued assignment is supported!"
)
_internal_assert
(
len
(
node
.
targets
)
==
1
,
"So far only one-valued assignment is supported!"
)
lhs
=
node
.
targets
[
0
]
lhs
=
node
.
targets
[
0
]
...
...
python/tvm/hybrid/util.py
View file @
3b3b8cbe
...
@@ -15,9 +15,14 @@ from ..tensor import Tensor
...
@@ -15,9 +15,14 @@ from ..tensor import Tensor
#pylint: disable=invalid-name
#pylint: disable=invalid-name
np_arg_types
=
tuple
(
list
(
numeric_types
)
+
[
numpy
.
ndarray
])
np_arg_types
=
tuple
(
list
(
numeric_types
)
+
[
numpy
.
ndarray
])
tvm_arg_types
=
(
Tensor
,
_expr
.
Var
)
tvm_arg_types
=
(
Tensor
,
_expr
.
Var
,
_expr
.
ConstExpr
)
halide_imm_types
=
(
_expr
.
IntImm
,
_expr
.
FloatImm
,
_expr
.
UIntImm
)
halide_imm_types
=
(
_expr
.
IntImm
,
_expr
.
FloatImm
,
_expr
.
UIntImm
)
def
_internal_assert
(
cond
,
err
):
"""Simplify the code segment like if not XXX then raise an error"""
if
not
cond
:
raise
ValueError
(
err
)
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def
make_nop
():
def
make_nop
():
...
@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args):
...
@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args):
If neither is true, raise a value error."""
If neither is true, raise a value error."""
if
isinstance
(
args
[
0
],
tvm_arg_types
):
if
isinstance
(
args
[
0
],
tvm_arg_types
):
for
elem
in
args
[
1
:]:
for
elem
in
args
[
1
:]:
if
not
isinstance
(
elem
,
tvm_arg_types
):
_internal_assert
(
isinstance
(
elem
,
tvm_arg_types
),
raise
ValueError
(
"Expect a Var or Tensor instance but
%
get!"
%
str
(
type
(
elem
)))
"Expecting a Var, Tensor or ConstExpr instance but
%
s get!"
\
%
str
(
type
(
elem
)))
return
True
return
True
if
not
isinstance
(
args
[
0
],
np_arg_types
):
raise
ValueError
(
"Expect a numpy type but
%
get!"
%
str
(
type
(
args
[
0
])))
_internal_assert
(
isinstance
(
args
[
0
],
np_arg_types
),
\
"Expect a numpy type but
%
s get!"
%
str
(
type
(
args
[
0
])))
for
elem
in
args
[
1
:]:
for
elem
in
args
[
1
:]:
if
not
isinstance
(
elem
,
np_arg_types
):
_internal_assert
(
isinstance
(
elem
,
np_arg_types
),
\
raise
ValueError
(
"Expect a numpy type but
%
get!"
%
str
(
type
(
elem
)))
"Expect a numpy type but
%
s
get!"
%
str
(
type
(
elem
)))
return
False
return
False
...
@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect):
...
@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect):
_globals
.
pop
(
elem
)
_globals
.
pop
(
elem
)
for
k
,
v
in
intersect
:
for
k
,
v
in
intersect
:
_globals
[
k
]
=
v
_globals
[
k
]
=
v
def
_internal_assert
(
cond
,
err
):
"""Simplify the code segment like if not XXX then raise an error"""
if
not
cond
:
raise
ValueError
(
err
)
# Almost the same functionality as the one above, but in this case,
# the error is caused by users inproper usage.
_user_assert
=
_internal_assert
python/tvm/hybrid/var_decl.py
View file @
3b3b8cbe
...
@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
scope_level
=
[]
self
.
scope_level
=
[]
self
.
_args
=
{}
self
.
_args
=
{}
self
.
args
=
args
self
.
args
=
args
self
.
aug_assign_
=
False
def
visit_FunctionDef
(
self
,
node
):
def
visit_FunctionDef
(
self
,
node
):
...
@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self
.
visit
(
elem
)
self
.
visit
(
elem
)
def
visit_AugAssign
(
self
,
node
):
self
.
aug_assign_
=
True
self
.
generic_visit
(
node
)
self
.
aug_assign_
=
False
def
visit_Name
(
self
,
node
):
def
visit_Name
(
self
,
node
):
# If it is from the argument list or loop variable, we do not worry about it!
# If it is from the argument list or loop variable, we do not worry about it!
if
node
.
id
in
self
.
_args
.
keys
():
if
node
.
id
in
self
.
_args
.
keys
():
...
@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor):
...
@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor):
if
node
.
id
not
in
self
.
status
.
keys
():
if
node
.
id
not
in
self
.
status
.
keys
():
_internal_assert
(
isinstance
(
node
.
ctx
,
ast
.
Store
),
\
_internal_assert
(
isinstance
(
node
.
ctx
,
ast
.
Store
),
\
'Undeclared variable
%
s'
%
node
.
id
)
'Undeclared variable
%
s'
%
node
.
id
)
if
self
.
aug_assign_
:
raise
ValueError
(
'"First store" cannot be an AugAssign'
)
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
self
.
status
[
node
.
id
]
=
(
node
,
self
.
scope_level
[
-
1
],
set
())
else
:
else
:
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
decl
,
loop
,
usage
=
self
.
status
[
node
.
id
]
...
...
tests/python/unittest/test_hybrid_script.py
View file @
3b3b8cbe
...
@@ -115,7 +115,7 @@ def test_fanout():
...
@@ -115,7 +115,7 @@ def test_fanout():
for
i
in
range
(
a
.
shape
[
0
]
-
3
):
for
i
in
range
(
a
.
shape
[
0
]
-
3
):
sigma
=
0.0
sigma
=
0.0
for
j
in
range
(
3
):
for
j
in
range
(
3
):
sigma
=
sigma
+
a
[
i
+
j
]
sigma
+=
a
[
i
+
j
]
sigma
=
sigma
/
three
sigma
=
sigma
/
three
b
[
i
]
=
sigma
b
[
i
]
=
sigma
return
b
return
b
...
@@ -246,7 +246,7 @@ def test_bind():
...
@@ -246,7 +246,7 @@ def test_bind():
def
vec_add
(
a
,
b
):
def
vec_add
(
a
,
b
):
c
=
output_tensor
((
1000
,
),
dtype
=
'float32'
)
c
=
output_tensor
((
1000
,
),
dtype
=
'float32'
)
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
for
tx
in
bind
(
'threadIdx.x'
,
1000
):
c
[
tx
]
=
b
[
tx
]
+
c
[
tx
]
c
[
tx
]
=
a
[
tx
]
+
b
[
tx
]
return
c
return
c
a
=
tvm
.
placeholder
((
1000
,
),
dtype
=
'float32'
,
name
=
'a'
)
a
=
tvm
.
placeholder
((
1000
,
),
dtype
=
'float32'
,
name
=
'a'
)
...
@@ -308,7 +308,7 @@ def test_non_zero():
...
@@ -308,7 +308,7 @@ def test_non_zero():
s
=
0.0
s
=
0.0
for
di
in
range
(
3
):
for
di
in
range
(
3
):
for
dj
in
range
(
3
):
for
dj
in
range
(
3
):
s
=
s
+
a
[
i
-
di
,
j
-
dj
]
s
+=
a
[
i
-
di
,
j
-
dj
]
b
[
i
-
2
,
j
-
2
]
=
s
/
9.0
b
[
i
-
2
,
j
-
2
]
=
s
/
9.0
return
b
return
b
...
@@ -419,6 +419,32 @@ def test_downstream():
...
@@ -419,6 +419,32 @@ def test_downstream():
module
(
tvm_a
,
tvm_c
)
module
(
tvm_a
,
tvm_c
)
tvm
.
testing
.
assert_allclose
(
tvm_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
tvm
.
testing
.
assert_allclose
(
tvm_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
def
test_const_param
():
@tvm.hybrid.script
def
add_something
(
a
,
b
):
c
=
output_tensor
((
11
,
),
'int32'
)
for
i
in
range
(
11
):
c
[
i
]
=
a
[
i
]
+
b
return
c
a
=
tvm
.
placeholder
((
11
,
),
dtype
=
'int32'
,
name
=
'a'
)
b
=
tvm
.
const
(
11
,
'int32'
)
c
=
add_something
(
a
,
b
)
sch
=
tvm
.
create_schedule
(
c
.
op
)
module
=
tvm
.
build
(
sch
,
[
a
,
c
],
'llvm'
)
assert
(
module
)
np_a
=
numpy
.
arange
(
11
)
.
astype
(
'int32'
)
np_b
=
11
np_c
=
numpy
.
zeros
((
11
,
))
.
astype
(
'int32'
)
nd_a
=
tvm
.
ndarray
.
array
(
np_a
)
nd_c
=
tvm
.
ndarray
.
array
(
numpy
.
zeros
((
11
,
))
.
astype
(
'int32'
))
module
(
nd_a
,
nd_c
)
ref
=
add_something
(
np_a
,
11
)
tvm
.
testing
.
assert_allclose
(
nd_c
.
asnumpy
(),
ref
,
1e-5
,
1e-5
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_outer_product
()
test_outer_product
()
...
@@ -432,5 +458,6 @@ if __name__ == "__main__":
...
@@ -432,5 +458,6 @@ if __name__ == "__main__":
#test_inplace()
#test_inplace()
test_upstream
()
test_upstream
()
test_downstream
()
test_downstream
()
test_const_param
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment