Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
26eb4053
Commit
26eb4053
authored
Nov 18, 2019
by
Animesh Jain
Committed by
Tianqi Chen
Nov 18, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay tests] AlterOpLayout - Temporary attr update (#4357)
parent
f1d6f335
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
215 additions
and
12 deletions
+215
-12
include/tvm/relay/op.h
+6
-0
python/tvm/relay/op/op.py
+10
-0
python/tvm/relay/testing/__init__.py
+1
-0
python/tvm/relay/testing/temp_op_attr.py
+63
-0
src/relay/ir/op.cc
+27
-1
tests/python/relay/test_ir_op.py
+47
-0
tests/python/relay/test_op_qnn_conv2d.py
+34
-0
tests/python/relay/test_op_qnn_dense.py
+15
-0
tests/python/relay/test_pass_alter_op_layout.py
+0
-0
tests/python/relay/test_pass_legalize.py
+9
-9
tests/python/relay/test_pass_qnn_legalize.py
+3
-2
No files found.
include/tvm/relay/op.h
View file @
26eb4053
...
...
@@ -258,6 +258,12 @@ class OpRegistry {
inline
OpRegistry
&
set_attr
(
const
std
::
string
&
attr_name
,
// NOLINT(*)
const
ValueType
&
value
,
int
plevel
=
10
);
/*!
* \brief Resets an attr of the registry.
* \param attr_name The name of the attribute.
*/
inline
void
reset_attr
(
const
std
::
string
&
attr_name
);
// set the name of the op to be the same as registry
inline
OpRegistry
&
set_name
()
{
// NOLINT(*)
if
(
get
()
->
name
.
length
()
==
0
)
{
...
...
python/tvm/relay/op/op.py
View file @
26eb4053
...
...
@@ -64,6 +64,16 @@ class Op(Expr):
"""
_OpSetAttr
(
self
,
attr_name
,
value
,
plevel
)
def
reset_attr
(
self
,
attr_name
):
"""Reset attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
"""
_OpResetAttr
(
self
,
attr_name
)
def
get
(
op_name
):
"""Get the Op for a given name
...
...
python/tvm/relay/testing/__init__.py
View file @
26eb4053
...
...
@@ -37,6 +37,7 @@ from . import squeezenet
from
.
import
vgg
from
.
import
densenet
from
.
import
yolo_detection
from
.
import
temp_op_attr
from
.config
import
ctx_list
from
.init
import
create_workload
...
...
python/tvm/relay/testing/temp_op_attr.py
0 → 100644
View file @
26eb4053
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Defines a TempOpAttr class that allows temporarily changing an attr of the
operator to allow unit testing. This is useful for AlterOpLayout and Legalize
tests."""
from
tvm
import
relay
class
TempOpAttr
(
object
):
""" Temporarily changes the attr of an op. """
def
__init__
(
self
,
op_name
,
attr_key
,
attr_value
):
""" Saves the required info for RAII pattern usage.
Parameters
----------
op_name : str
The op name.
attr_key : str
The attribute name.
attr_value : object
The attribute value.
Examples
--------
.. code-block:: python
# Temporarily update FTVMAlterOpLayout to a user-defined packed function.
# After the test is finished, the attr value will be set back to the original value.
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
my_mod = relay.transform.AlterOpLayout()(my_mod)
"""
self
.
op
=
relay
.
op
.
get
(
op_name
)
self
.
attr_key
=
attr_key
self
.
attr_value
=
attr_value
def
__enter__
(
self
):
self
.
older_attr
=
self
.
op
.
get_attr
(
self
.
attr_key
)
self
.
op
.
reset_attr
(
self
.
attr_key
)
self
.
op
.
set_attr
(
self
.
attr_key
,
self
.
attr_value
)
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
self
.
op
.
reset_attr
(
self
.
attr_key
)
if
self
.
older_attr
:
self
.
op
.
set_attr
(
self
.
attr_key
,
self
.
older_attr
)
src/relay/ir/op.cc
View file @
26eb4053
...
...
@@ -95,6 +95,20 @@ const bool Op::HasGenericAttr(const std::string& key) {
return
true
;
}
// Resets attr of the OpMap.
void
OpRegistry
::
reset_attr
(
const
std
::
string
&
key
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
mgr
->
mutex
);
std
::
unique_ptr
<
GenericOpMap
>&
op_map
=
mgr
->
attr
[
key
];
if
(
op_map
==
nullptr
)
{
return
;
}
uint32_t
index
=
op_
->
index_
;
if
(
op_map
->
data_
.
size
()
>
index
)
{
op_map
->
data_
[
index
]
=
std
::
make_pair
(
TVMRetValue
(),
0
);
}
}
void
OpRegistry
::
UpdateAttr
(
const
std
::
string
&
key
,
TVMRetValue
value
,
int
plevel
)
{
...
...
@@ -113,7 +127,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
CHECK
(
p
.
second
!=
plevel
)
<<
"Attribute "
<<
key
<<
" of operator "
<<
this
->
name
<<
" is already registered with same plevel="
<<
plevel
;
if
(
p
.
second
<
plevel
)
{
CHECK
(
value
.
type_code
()
!=
kNull
)
<<
"Registered packed_func is Null for "
<<
key
<<
" of operator "
<<
this
->
name
;
if
(
p
.
second
<
plevel
&&
value
.
type_code
()
!=
kNull
)
{
op_map
->
data_
[
index
]
=
std
::
make_pair
(
value
,
plevel
);
}
}
...
...
@@ -152,6 +169,15 @@ TVM_REGISTER_API("relay.op._OpSetAttr")
reg
.
set_attr
(
attr_name
,
value
,
plevel
);
});
TVM_REGISTER_API
(
"relay.op._OpResetAttr"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Op
op
=
args
[
0
];
std
::
string
attr_name
=
args
[
1
];
auto
&
reg
=
OpRegistry
::
Registry
()
->
__REGISTER_OR_GET__
(
op
->
name
);
reg
.
reset_attr
(
attr_name
);
});
TVM_REGISTER_API
(
"relay.op._Register"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
string
op_name
=
args
[
0
];
...
...
tests/python/relay/test_ir_op.py
View file @
26eb4053
...
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from
tvm
import
relay
from
tvm.relay.testing.temp_op_attr
import
TempOpAttr
def
test_op_attr
():
log_op
=
relay
.
op
.
get
(
"log"
)
...
...
@@ -27,6 +28,50 @@ def test_op_attr():
assert
log_op
.
get_attr
(
"ftest"
)
is
None
assert
relay
.
op
.
get
(
"exp"
)
.
get_attr
(
"ftest"
)(
1
)
==
2
def
test_op_reset_attr
():
""" Tests reset_attr functionality. """
def
add1
(
x
):
return
x
+
1
def
add2
(
x
):
return
x
+
2
# Register fadd1 and fadd2 attributes.
relay
.
op
.
register
(
"exp"
,
"fadd1"
,
add1
)
relay
.
op
.
register
(
"log"
,
"fadd1"
,
add1
)
relay
.
op
.
register
(
"log"
,
"fadd2"
,
add2
)
# Reset log fadd1 attr.
log_op
=
relay
.
op
.
get
(
"log"
)
log_op
.
reset_attr
(
"fadd1"
)
# Check that fadd1 attr is resetted.
assert
log_op
.
get_attr
(
"fadd1"
)
is
None
# Check that fadd1 attr of other ops are intact.
assert
relay
.
op
.
get
(
"exp"
)
.
get_attr
(
"fadd1"
)(
1
)
==
2
# Check that other attrs of the log op are intact.
assert
relay
.
op
.
get
(
"log"
)
.
get_attr
(
"fadd2"
)(
1
)
==
3
def
test_op_temp_attr
():
""" Tests reset_attr functionality. """
def
add1
(
x
):
return
x
+
1
def
add2
(
x
):
return
x
+
2
# Set original attr value is add1.
relay
.
op
.
register
(
"sqrt"
,
"ftest"
,
add1
)
with
TempOpAttr
(
"sqrt"
,
"ftest"
,
add2
):
# Check that the attr value is updated to add2.
assert
relay
.
op
.
get
(
"sqrt"
)
.
get_attr
(
"ftest"
)(
1
)
==
3
# Check that the attr value is recovered to add1.
assert
relay
.
op
.
get
(
"sqrt"
)
.
get_attr
(
"ftest"
)(
1
)
==
2
def
test_op_level1
():
x
=
relay
.
Var
(
"x"
)
...
...
@@ -47,5 +92,7 @@ def test_op_level3():
if
__name__
==
"__main__"
:
test_op_attr
()
test_op_reset_attr
()
test_op_temp_attr
()
test_op_level1
()
test_op_level3
()
tests/python/relay/test_op_qnn_conv2d.py
View file @
26eb4053
...
...
@@ -21,6 +21,14 @@ from tvm import relay
from
tvm.relay
import
transform
from
tvm.relay.testing
import
run_infer_type
from
tvm.contrib
import
graph_runtime
from
tvm.relay.testing.temp_op_attr
import
TempOpAttr
# We use llvm target for testing functionality. `llvm` points to an older Intel
# generation machine, that legalizes to a simple lowering. Therefore, the
# legalization is overwritten such that it can be skipped and we use the
# QNNCanonicalizeOps lowering for the testing.
def
legalize_qnn_conv2d
(
attrs
,
inputs
,
types
):
return
None
def
get_ref_func
(
data
,
kernel
,
...
...
@@ -173,6 +181,8 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
np
.
testing
.
assert_equal
(
qnn_output
,
golden_output
)
def
test_no_zero_point
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
1
,
2
,
4
)
data_dtype
=
'uint8'
...
...
@@ -220,6 +230,8 @@ def test_no_zero_point():
kernel_shape
,
kernel_dtype
)
def
test_kernel_zero_point
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
4
,
2
,
4
)
data_dtype
=
'uint8'
...
...
@@ -268,6 +280,8 @@ def test_kernel_zero_point():
def
test_input_zero_point
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
4
,
2
,
4
)
data_dtype
=
'uint8'
...
...
@@ -315,6 +329,8 @@ def test_input_zero_point():
kernel_shape
,
kernel_dtype
)
def
test_both_zero_point
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
4
,
2
,
4
)
data_dtype
=
'uint8'
...
...
@@ -362,6 +378,8 @@ def test_both_zero_point():
kernel_shape
,
kernel_dtype
)
def
test_layout
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
2
,
4
,
4
)
# NHWC
data_dtype
=
'uint8'
...
...
@@ -411,6 +429,8 @@ def test_layout():
def
test_padding
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
1
,
4
,
2
,
2
)
data_dtype
=
'uint8'
...
...
@@ -458,6 +478,8 @@ def test_padding():
kernel_shape
,
kernel_dtype
)
def
test_dilation
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
4
,
4
,
4
)
data_dtype
=
'uint8'
...
...
@@ -483,6 +505,8 @@ def test_dilation():
def
test_const_folding
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
data_shape
=
(
2
,
4
,
2
,
4
)
data_dtype
=
'uint8'
kernel_shape
=
(
3
,
4
,
2
,
2
)
...
...
@@ -511,6 +535,8 @@ def test_const_folding():
assert
"reshape"
not
in
folded_func
.
astext
()
def
test_kernel_size_1x1
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
4
,
2
,
4
)
data_dtype
=
'uint8'
...
...
@@ -536,6 +562,8 @@ def test_kernel_size_1x1():
kernel_shape
,
kernel_dtype
)
def
test_tflite_large_irregular
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
1
,
1024
,
1
,
1
)
data_dtype
=
'uint8'
...
...
@@ -571,6 +599,8 @@ def test_tflite_large_irregular():
np
.
testing
.
assert_equal
(
qnn_output
,
golden_output
)
def
test_tflite_output_multiplier_greater_than_one
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
2
,
1
,
2
,
4
)
data_dtype
=
'uint8'
...
...
@@ -617,6 +647,8 @@ def test_tflite_output_multiplier_greater_than_one():
np
.
testing
.
assert_equal
(
qnn_output
,
golden_output
)
def
test_tflite_anistropic_strides
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# uint8 input
data_shape
=
(
1
,
1
,
3
,
6
)
data_dtype
=
'uint8'
...
...
@@ -656,6 +688,8 @@ def test_tflite_anistropic_strides():
np
.
testing
.
assert_equal
(
qnn_output
,
golden_output
)
def
test_broadcast_layout
():
with
TempOpAttr
(
"qnn.conv2d"
,
"FTVMQnnLegalize"
,
legalize_qnn_conv2d
):
# Test broadcast support for NHWC layout.
data_shape
=
(
1
,
229
,
229
,
3
)
# NHWC
data_dtype
=
'uint8'
...
...
tests/python/relay/test_op_qnn_dense.py
View file @
26eb4053
...
...
@@ -19,6 +19,15 @@ import tvm
import
numpy
as
np
from
tvm
import
relay
from
tvm.contrib
import
graph_runtime
from
tvm.relay.testing.temp_op_attr
import
TempOpAttr
# We use llvm target for testing functionality. `llvm` points to an older Intel
# generation machine, that legalizes to a simple lowering. Therefore, the
# legalization is overwritten such that it can be skipped and we use the
# QNNCanonicalizeOps lowering for the testing.
def
legalize_qnn_dense
(
attrs
,
inputs
,
types
):
return
None
def
make_requantize_params
(
input_scale
,
output_scale
,
output_zero_point
,
out_dtype
):
...
...
@@ -209,18 +218,24 @@ def qnn_dense_driver(test_configuration):
def
test_qnn_dense_without_bias
():
with
TempOpAttr
(
"qnn.dense"
,
"FTVMQnnLegalize"
,
legalize_qnn_dense
):
int32_output_without_bias_params
=
\
make_int_configuration
(
use_bias
=
False
)
qnn_dense_driver
(
int32_output_without_bias_params
)
def
test_qnn_dense_with_bias
():
with
TempOpAttr
(
"qnn.dense"
,
"FTVMQnnLegalize"
,
legalize_qnn_dense
):
int32_output_with_bias_params
=
\
make_int_configuration
(
use_bias
=
True
)
qnn_dense_driver
(
int32_output_with_bias_params
)
def
test_qnn_dense_with_requantized_output
():
with
TempOpAttr
(
"qnn.dense"
,
"FTVMQnnLegalize"
,
legalize_qnn_dense
):
int8_requantized_output_with_bias_params
=
\
make_int_configuration
(
use_bias
=
True
,
requantize_output
=
True
)
qnn_dense_driver
(
int8_requantized_output_with_bias_params
)
...
...
tests/python/relay/test_pass_alter_op_layout.py
View file @
26eb4053
This diff is collapsed.
Click to expand it.
tests/python/relay/test_pass_legalize.py
View file @
26eb4053
...
...
@@ -20,8 +20,8 @@ import tvm
from
tvm
import
relay
from
tvm.contrib
import
graph_runtime
from
tvm.relay.op
import
register_legalize
from
tvm.relay
import
transform
,
analysis
from
tvm.relay.testing.temp_op_attr
import
TempOpAttr
def
run_opt_pass
(
expr
,
passes
):
...
...
@@ -46,7 +46,6 @@ def test_legalize():
y
=
relay
.
Function
([
x
,
weight
],
y
)
return
y
@register_legalize
(
"nn.conv2d"
,
level
=
100
)
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
data
,
weight
=
inputs
weight
=
relay
.
multiply
(
weight
,
relay
.
const
(
2.0
,
"float32"
))
...
...
@@ -63,6 +62,7 @@ def test_legalize():
y
=
relay
.
Function
([
x
,
weight
],
y
)
return
y
with
TempOpAttr
(
"nn.conv2d"
,
"FTVMLegalize"
,
legalize_conv2d
):
a
=
before
()
a
=
run_opt_pass
(
a
,
transform
.
Legalize
())
b
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
...
...
@@ -79,16 +79,15 @@ def test_legalize_none():
called
=
[
False
]
@register_legalize
(
"nn.global_max_pool2d"
,
level
=
101
)
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
called
[
0
]
=
True
return
None
with
TempOpAttr
(
"nn.global_max_pool2d"
,
"FTVMLegalize"
,
legalize_conv2d
):
a
=
before
()
a
=
run_opt_pass
(
a
,
transform
.
Legalize
())
b
=
run_opt_pass
(
before
(),
transform
.
InferType
())
b
=
before
()
b
=
run_opt_pass
(
b
,
transform
.
InferType
())
assert
analysis
.
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
assert
(
called
[
0
])
...
...
@@ -105,14 +104,12 @@ def test_legalize_multiple_ops():
y
=
relay
.
Function
([
x
,
weight
],
y
)
return
y
@register_legalize
(
"nn.conv2d"
,
level
=
102
)
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
data
,
weight
=
inputs
weight
=
relay
.
multiply
(
weight
,
relay
.
const
(
2.0
,
"float32"
))
return
relay
.
nn
.
conv2d
(
data
,
weight
,
**
attrs
)
@register_legalize
(
"nn.relu"
,
level
=
103
)
def
legalize_conv2d
(
attrs
,
inputs
,
types
):
def
legalize_relu
(
attrs
,
inputs
,
types
):
data
=
inputs
[
0
]
add
=
relay
.
add
(
tvm
.
relay
.
const
(
0
,
"float32"
),
data
)
return
relay
.
nn
.
relu
(
add
)
...
...
@@ -130,6 +127,8 @@ def test_legalize_multiple_ops():
y
=
relay
.
Function
([
x
,
weight
],
y
)
return
y
with
TempOpAttr
(
"nn.conv2d"
,
"FTVMLegalize"
,
legalize_conv2d
):
with
TempOpAttr
(
"nn.relu"
,
"FTVMLegalize"
,
legalize_relu
):
a
=
before
()
a
=
run_opt_pass
(
a
,
transform
.
Legalize
())
b
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
...
...
@@ -147,7 +146,6 @@ def test_legalize_multi_input():
func
=
relay
.
Function
([
x
,
y
,
z
],
func
)
return
func
@register_legalize
(
"concatenate"
,
level
=
104
)
def
legalize_concatenate
(
attrs
,
inputs
,
types
):
# Check that the correct multi-input case is handled.
assert
len
(
inputs
)
==
1
...
...
@@ -165,6 +163,8 @@ def test_legalize_multi_input():
func
=
relay
.
Function
([
x
,
y
,
z
],
func
)
return
func
with
TempOpAttr
(
"concatenate"
,
"FTVMLegalize"
,
legalize_concatenate
):
a
=
before
()
a
=
run_opt_pass
(
a
,
transform
.
Legalize
())
b
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
...
...
tests/python/relay/test_pass_qnn_legalize.py
View file @
26eb4053
...
...
@@ -20,8 +20,8 @@ import tvm
from
tvm
import
relay
from
tvm.contrib
import
graph_runtime
from
tvm.relay.qnn.op
import
register_qnn_legalize
from
tvm.relay
import
transform
,
analysis
from
tvm.relay.testing.temp_op_attr
import
TempOpAttr
def
alpha_equal
(
x
,
y
):
"""
...
...
@@ -54,7 +54,6 @@ def test_qnn_legalize():
y
=
relay
.
Function
([
x
],
y
)
return
y
@register_qnn_legalize
(
"qnn.requantize"
,
level
=
100
)
def
legalize_qnn_requantize
(
attrs
,
inputs
,
types
):
data
=
inputs
[
0
]
data
=
relay
.
add
(
relay
.
const
(
0
,
'int8'
),
data
)
...
...
@@ -80,6 +79,8 @@ def test_qnn_legalize():
a
=
before
()
with
TempOpAttr
(
"qnn.requantize"
,
"FTVMQnnLegalize"
,
legalize_qnn_requantize
):
# Check that Relay Legalize does not change the graph.
a
=
run_opt_pass
(
a
,
relay
.
transform
.
Legalize
())
b
=
run_opt_pass
(
before
(),
transform
.
InferType
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment