Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
95de08ba
Unverified
Commit
95de08ba
authored
Feb 16, 2020
by
Zhi
Committed by
GitHub
Feb 16, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix alpha_equal bug (#4897)
parent
e7be8bf4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
109 additions
and
50 deletions
+109
-50
src/relay/ir/alpha_equal.cc
+1
-1
tests/python/relay/test_ir_nodes.py
+2
-0
tests/python/relay/test_pass_alpha_equal.py
+24
-1
tests/python/relay/test_pass_fuse_ops.py
+35
-1
tests/python/relay/test_pass_merge_composite.py
+47
-47
No files found.
src/relay/ir/alpha_equal.cc
View file @
95de08ba
...
...
@@ -92,7 +92,7 @@ class AlphaEqualHandler:
auto
compute
=
[
&
]()
{
if
(
&
lhs
==
&
rhs
)
return
true
;
if
(
auto
lhsd
=
lhs
.
as
<
DictAttrsNode
>
())
{
auto
rhsd
=
l
hs
.
as
<
DictAttrsNode
>
();
auto
rhsd
=
r
hs
.
as
<
DictAttrsNode
>
();
if
(
!
rhsd
)
return
false
;
if
(
lhsd
->
dict
.
size
()
!=
rhsd
->
dict
.
size
())
return
false
;
for
(
const
auto
&
k
:
lhsd
->
dict
)
{
...
...
tests/python/relay/test_ir_nodes.py
View file @
95de08ba
...
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
""" test ir"""
import
pytest
import
tvm
from
tvm
import
relay
from
tvm.tir.expr
import
*
...
...
@@ -174,6 +175,7 @@ def test_function():
str
(
fn
)
check_json_roundtrip
(
fn
)
@pytest.mark.skip
(
reason
=
"AttrsEqualHandler doesn't handle Map so far."
)
def
test_function_attrs
():
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
params
=
tvm
.
convert
([
relay
.
var
(
n
,
shape
=
(
5
,
2
))
for
n
in
param_names
])
...
...
tests/python/relay/test_pass_alpha_equal.py
View file @
95de08ba
...
...
@@ -18,6 +18,7 @@ import numpy as np
import
tvm
from
tvm
import
relay
from
tvm.relay
import
analysis
from
tvm.relay.testing
import
run_opt_pass
def
alpha_equal
(
x
,
y
):
"""
...
...
@@ -313,7 +314,7 @@ def test_tuple_get_item_alpha_equal():
assert
alpha_equal
(
relay
.
TupleGetItem
(
x
,
1
),
relay
.
TupleGetItem
(
x
,
1
))
def
test_
multi_node_subgraph
():
def
test_
function_attr
():
x0
=
relay
.
var
(
'x0'
,
shape
=
(
10
,
10
))
w00
=
relay
.
var
(
'w00'
,
shape
=
(
10
,
10
))
w01
=
relay
.
var
(
'w01'
,
shape
=
(
10
,
10
))
...
...
@@ -608,6 +609,7 @@ def test_graph_equal():
z3
=
relay
.
add
(
relay
.
add
(
x
,
x
),
relay
.
add
(
x
,
x
))
assert
alpha_equal
(
z0
,
z1
)
assert
alpha_equal
(
z0
,
z1
)
# z3's dataflow format is different from z0
# z0 is computed from a common y0 node
...
...
@@ -649,6 +651,26 @@ def test_tuple_match():
assert
analysis
.
structural_hash
(
x
)
==
analysis
.
structural_hash
(
y
)
def
test_fn_attribute
():
# create function that performs add
a
=
relay
.
var
(
'a'
,
shape
=
(
10
,
10
))
b
=
relay
.
var
(
'b'
,
shape
=
(
10
,
10
))
add
=
relay
.
add
(
a
,
b
)
add_fn
=
relay
.
Function
([
a
,
b
],
add
)
add_fn
=
run_opt_pass
(
add_fn
,
relay
.
transform
.
InferType
())
# create function that performs add with test attribute
c
=
relay
.
var
(
'c'
,
shape
=
(
10
,
10
))
d
=
relay
.
var
(
'd'
,
shape
=
(
10
,
10
))
add_1
=
relay
.
add
(
c
,
d
)
add_1_fn
=
relay
.
Function
([
c
,
d
],
add_1
)
add_1_fn
=
add_1_fn
.
set_attribute
(
"TestAttribute"
,
tvm
.
tir
.
StringImm
(
"test"
))
add_1_fn
=
run_opt_pass
(
add_1_fn
,
relay
.
transform
.
InferType
())
assert
not
relay
.
analysis
.
alpha_equal
(
add_1_fn
,
add_fn
)
assert
not
relay
.
analysis
.
alpha_equal
(
add_fn
,
add_1_fn
)
if
__name__
==
"__main__"
:
test_tensor_type_alpha_equal
()
test_incomplete_type_alpha_equal
()
...
...
@@ -672,3 +694,4 @@ if __name__ == "__main__":
test_var_alpha_equal
()
test_graph_equal
()
test_hash_unequal
()
test_fn_attribute
()
tests/python/relay/test_pass_fuse_ops.py
View file @
95de08ba
...
...
@@ -35,6 +35,7 @@ def test_fuse_simple():
z
=
relay
.
exp
(
y
)
w
=
relay
.
squeeze
(
z
)
f1
=
relay
.
Function
([
x
],
w
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
y
=
relay
.
Call
(
f1
,
[
x
])
return
relay
.
Function
([
x
],
y
)
...
...
@@ -76,6 +77,8 @@ def test_conv2d_fuse():
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
f0
=
relay
.
Function
([
x
],
y
)
f0
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
# segment 1
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
w
=
relay
.
var
(
"p1"
)
...
...
@@ -86,6 +89,8 @@ def test_conv2d_fuse():
y1
=
relay
.
add
(
relay
.
const
(
1
,
"float32"
),
y
)
y
=
relay
.
add
(
y
,
y1
)
f1
=
relay
.
Function
([
x
,
w
],
y
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
# segment 2
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
w
=
relay
.
var
(
"p1"
)
...
...
@@ -94,6 +99,8 @@ def test_conv2d_fuse():
padding
=
(
1
,
1
),
channels
=
16
)
f2
=
relay
.
Function
([
x
,
w
],
z2
)
f2
=
f2
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
# segment 3
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
w
=
relay
.
var
(
"p1"
)
...
...
@@ -104,6 +111,8 @@ def test_conv2d_fuse():
channels
=
16
)
z3
=
relay
.
add
(
z3
,
offset
)
f3
=
relay
.
Function
([
x
,
w
,
offset
],
z3
)
f3
=
f3
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
# compose
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f0
,
[
x
])
...
...
@@ -135,6 +144,7 @@ def test_concatenate():
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
pooled
=
relay
.
nn
.
max_pool2d
(
x
,
pool_size
=
(
2
,
2
),
strides
=
(
2
,
2
),
padding
=
(
0
,
0
))
f0
=
relay
.
Function
([
x
],
pooled
)
f0
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p0
=
relay
.
var
(
"p0"
,
shape
=
(
dshape
[
0
],
dshape
[
1
],
dshape
[
2
]
//
2
,
dshape
[
3
]
//
2
))
p1
=
relay
.
var
(
"p1"
,
shape
=
dshape
)
...
...
@@ -142,6 +152,7 @@ def test_concatenate():
concat
=
relay
.
concatenate
((
upsampled
,
p1
),
axis
=
1
)
out
=
relay
.
add
(
concat
,
relay
.
const
(
1
,
"float32"
))
f1
=
relay
.
Function
([
p0
,
p1
],
out
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f0
,
[
x
])
...
...
@@ -172,10 +183,12 @@ def test_tuple_root():
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
pooled
=
relay
.
nn
.
max_pool2d
(
x
,
pool_size
=
(
2
,
2
),
strides
=
(
2
,
2
),
padding
=
(
0
,
0
))
f0
=
relay
.
Function
([
x
],
pooled
)
f0
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p0
=
relay
.
var
(
"p0"
,
shape
=
(
dshape
[
0
],
dshape
[
1
],
dshape
[
2
]
//
2
,
dshape
[
3
]
//
2
))
upsampled
=
relay
.
nn
.
upsampling
(
p0
,
scale_h
=
2
,
scale_w
=
2
,
layout
=
"NCHW"
)
f1
=
relay
.
Function
([
p0
],
upsampled
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f0
,
[
x
])
...
...
@@ -205,10 +218,12 @@ def test_stop_fusion():
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
f1
=
relay
.
Function
([
x
],
y
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"p01"
,
shape
=
dshape
)
y
=
relay
.
exp
(
x
)
f2
=
relay
.
Function
([
x
],
y
)
f2
=
f2
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f1
,
[
x
])
...
...
@@ -242,6 +257,7 @@ def test_fuse_myia_regression():
p2
=
relay
.
var
(
'p2'
,
shape
=
dshape
,
dtype
=
dtype
)
fused_gt
=
relay
.
Function
([
p1
,
p2
],
relay
.
op
.
greater
(
p1
,
p2
))
fused_gt
=
fused_gt
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
with
sb
.
if_scope
(
fused_gt
(
x
,
y
)):
sb
.
ret
(
relay
.
Function
([],
x
))
with
sb
.
else_scope
():
...
...
@@ -271,11 +287,13 @@ def test_fuse_tuple_get_elemwise():
p1
=
relay
.
var
(
"p1"
,
shape
=
(
3
*
dim
,
dim
))
matmul
=
relay
.
nn
.
dense
(
p0
,
p1
)
f0
=
relay
.
Function
([
p0
,
p1
],
matmul
)
f0
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p01
=
relay
.
var
(
"p01"
,
shape
=
(
1
,
3
*
dim
))
splitted
=
relay
.
split
(
p01
,
indices_or_sections
=
3
,
axis
=
1
)
out
=
relay
.
sigmoid
(
splitted
[
0
])
+
relay
.
tanh
(
splitted
[
1
])
*
relay
.
exp
(
splitted
[
2
])
f1
=
relay
.
Function
([
p01
],
out
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
X
=
relay
.
var
(
"X"
,
shape
=
(
1
,
dim
))
W
=
relay
.
var
(
"W"
,
shape
=
(
3
*
dim
,
dim
))
...
...
@@ -306,11 +324,13 @@ def test_tuple_get_root():
splitted
=
relay
.
split
(
p0
,
indices_or_sections
=
3
,
axis
=
1
)
out
=
splitted
[
0
]
f0
=
relay
.
Function
([
p0
],
out
)
f0
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p01
=
relay
.
var
(
"p01"
,
shape
=
(
1
,
dim
))
p1
=
relay
.
var
(
"p1"
,
shape
=
(
dim
,
dim
))
out
=
relay
.
nn
.
dense
(
p01
,
p1
)
f1
=
relay
.
Function
([
p01
,
p1
],
out
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
X
=
relay
.
var
(
"X"
,
shape
=
(
1
,
3
*
dim
))
W
=
relay
.
var
(
"W"
,
shape
=
(
dim
,
dim
))
...
...
@@ -346,8 +366,9 @@ def test_tuple_intermediate():
def
expected
(
p0
):
f0
=
before
(
p0
)
f1
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f
0
,
[
x
])
y
=
relay
.
Call
(
f
1
,
[
x
])
return
relay
.
Function
([
x
],
y
)
dshape
=
(
1
,
16
,
64
,
64
)
...
...
@@ -388,15 +409,18 @@ def test_tuple_consecutive():
p0
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
concat
=
gen_consecutive_tuple
(
p0
)
f0
=
relay
.
Function
([
p0
],
concat
)
f0
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p01
=
relay
.
var
(
"p01"
,
shape
=
(
1
,
dshape
[
1
]
*
9
,
dshape
[
2
],
dshape
[
3
]))
pooled
=
relay
.
nn
.
max_pool2d
(
p01
,
pool_size
=
(
2
,
2
),
strides
=
(
2
,
2
),
padding
=
(
0
,
0
))
out
=
relay
.
add
(
pooled
,
relay
.
const
(
1
,
"float32"
))
f1
=
relay
.
Function
([
p01
],
out
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p02
=
relay
.
var
(
"p02"
,
shape
=
(
1
,
dshape
[
1
]
*
9
,
dshape
[
2
]
//
2
,
dshape
[
3
]
//
2
))
out
=
relay
.
add
(
p02
,
relay
.
const
(
1
,
"float32"
))
f2
=
relay
.
Function
([
p02
],
out
)
f2
=
f2
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f0
,
[
x
])
...
...
@@ -438,30 +462,36 @@ def test_inception_like():
p0
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
c
=
conv
(
p0
)
f0
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
c
),
c
)
f0
=
f0
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p01
=
relay
.
var
(
"p01"
,
shape
=
dshape
)
c
=
conv
(
p01
)
f1
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
c
),
c
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p02
=
relay
.
var
(
"p02"
,
shape
=
dshape
)
p12
=
relay
.
var
(
"p12"
,
shape
=
dshape
)
concat1
=
relay
.
concatenate
((
p02
,
p12
),
axis
=
1
)
f_concat1
=
relay
.
Function
([
p02
,
p12
],
concat1
)
f_concat1
=
f_concat1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
dshape2
=
(
dshape
[
0
],
dshape
[
1
]
*
2
,
dshape
[
2
],
dshape
[
3
])
p03
=
relay
.
var
(
"p03"
,
shape
=
dshape2
)
c
=
conv
(
p03
)
f2
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
c
),
c
)
f2
=
f2
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p04
=
relay
.
var
(
"p04"
,
shape
=
dshape2
)
c
=
conv
(
p04
)
f3
=
relay
.
Function
(
relay
.
analysis
.
free_vars
(
c
),
c
)
f3
=
f3
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
p05
=
relay
.
var
(
"p05"
,
shape
=
dshape
)
p15
=
relay
.
var
(
"p15"
,
shape
=
dshape
)
concat2
=
relay
.
concatenate
((
p05
,
p15
),
axis
=
1
)
f_concat2
=
relay
.
Function
([
p05
,
p15
],
concat2
)
f_concat2
=
f_concat2
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
c1
=
relay
.
Call
(
f0
,
[
x
,
relay
.
var
(
"w1"
)])
...
...
@@ -499,6 +529,7 @@ def test_fuse_parallel_injective():
u
=
relay
.
transpose
(
y
,
axes
=
[
0
,
1
])
w
=
relay
.
left_shift
(
z
,
u
)
f1
=
relay
.
Function
([
x
],
w
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
y
=
relay
.
Call
(
f1
,
[
x
])
return
relay
.
Function
([
x
],
y
)
...
...
@@ -529,6 +560,7 @@ def test_immutable():
z
=
relay
.
exp
(
y
)
w
=
relay
.
squeeze
(
z
)
f1
=
relay
.
Function
([
x
],
w
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
y
=
relay
.
Call
(
f1
,
[
x
])
mod
=
tvm
.
IRModule
()
...
...
@@ -570,6 +602,7 @@ def test_fuse_max():
for
i
in
range
(
max_fused_ops
):
y
=
relay
.
exp
(
y
)
f1
=
relay
.
Function
([
x
],
y
)
f1
=
f1
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
z
=
relay
.
Call
(
f1
,
[
x
])
xx
=
relay
.
var
(
"pp"
,
shape
=
(
10
,
20
))
...
...
@@ -577,6 +610,7 @@ def test_fuse_max():
for
i
in
range
(
n
-
max_fused_ops
):
yy
=
relay
.
exp
(
yy
)
f2
=
relay
.
Function
([
xx
],
yy
)
f2
=
f2
.
set_attribute
(
"Primitive"
,
tvm
.
tir
.
IntImm
(
"int32"
,
1
))
zz
=
relay
.
Call
(
f2
,
[
z
])
return
relay
.
Function
([
x
],
zz
)
...
...
tests/python/relay/test_pass_merge_composite.py
View file @
95de08ba
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment