Commit 26eb4053 authored by Animesh Jain, committed by Tianqi Chen

[Relay tests] AlterOpLayout - Temporary attr update (#4357)

parent f1d6f335
......@@ -258,6 +258,12 @@ class OpRegistry {
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value, int plevel = 10);
/*!
* \brief Resets an attr of the registry.
* \param attr_name The name of the attribute.
*/
inline void reset_attr(const std::string& attr_name);
// set the name of the op to be the same as registry
inline OpRegistry& set_name() { // NOLINT(*)
if (get()->name.length() == 0) {
......
......@@ -64,6 +64,16 @@ class Op(Expr):
"""
_OpSetAttr(self, attr_name, value, plevel)
def reset_attr(self, attr_name):
"""Reset attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
"""
_OpResetAttr(self, attr_name)
def get(op_name):
"""Get the Op for a given name
......
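For orientation, a minimal sketch of the new Python API in use, mirroring the tests added later in this diff (the attr key "fexample" and the helper fexample_impl are hypothetical, chosen only for illustration):

from tvm import relay

def fexample_impl(x):
    return x + 1

# Attach a hypothetical attr, check it, then clear it with the new reset_attr.
relay.op.register("log", "fexample", fexample_impl)
assert relay.op.get("log").get_attr("fexample")(1) == 2
relay.op.get("log").reset_attr("fexample")
assert relay.op.get("log").get_attr("fexample") is None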
......@@ -37,6 +37,7 @@ from . import squeezenet
from . import vgg
from . import densenet
from . import yolo_detection
from . import temp_op_attr
from .config import ctx_list
from .init import create_workload
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Defines a TempOpAttr class that allows temporarily changing an attr of the
operator to allow unit testing. This is useful for AlterOpLayout and Legalize
tests."""
from tvm import relay
class TempOpAttr(object):
""" Temporarily changes the attr of an op. """
def __init__(self, op_name, attr_key, attr_value):
""" Saves the required info for RAII pattern usage.
Parameters
----------
op_name : str
The op name.
attr_key : str
The attribute name.
attr_value : object
The attribute value.
Examples
--------
.. code-block:: python
# Temporarily update FTVMAlterOpLayout to a user-defined packed function.
# After the test is finished, the attr value will be set back to the original value.
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
my_mod = relay.transform.AlterOpLayout()(my_mod)
"""
self.op = relay.op.get(op_name)
self.attr_key = attr_key
self.attr_value = attr_value
def __enter__(self):
self.older_attr = self.op.get_attr(self.attr_key)
self.op.reset_attr(self.attr_key)
self.op.set_attr(self.attr_key, self.attr_value)
return self
def __exit__(self, ptype, value, trace):
self.op.reset_attr(self.attr_key)
if self.older_attr:
self.op.set_attr(self.attr_key, self.older_attr)
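A brief sketch of how TempOpAttr composes when several attrs must be overridden at once, mirroring the nested usage in the Legalize tests further down (my_alter is a hypothetical packed function):

from tvm.relay.testing.temp_op_attr import TempOpAttr

def my_alter(attrs, inputs, tinfos):
    return None  # returning None keeps the original op untouched

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", my_alter):
    with TempOpAttr("nn.dense", "FTVMAlterOpLayout", my_alter):
        # Both overrides are active here; each __exit__ restores the
        # previously registered attr (if any) when its block ends.
        pass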
......@@ -95,6 +95,20 @@ const bool Op::HasGenericAttr(const std::string& key) {
return true;
}
// Resets an attr of the registry.
void OpRegistry::reset_attr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
return;
}
uint32_t index = op_->index_;
if (op_map->data_.size() > index) {
op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
}
}
void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
......@@ -113,7 +127,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
CHECK(p.second != plevel)
<< "Attribute " << key << " of operator " << this->name
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
CHECK(value.type_code() != kNull)
<< "Registered packed_func is Null for " << key
<< " of operator " << this->name;
if (p.second < plevel && value.type_code() != kNull) {
op_map->data_[index] = std::make_pair(value, plevel);
}
}
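The plevel guard above means a later set_attr only takes effect when its plevel is strictly higher (re-registering at the same plevel is an error), and Null packed functions are now rejected with an explicit message. A Python-level sketch of the plevel behavior (the attr key "fdemo" and both helpers are hypothetical):

from tvm import relay

def fdemo_low(x):
    return x + 1

def fdemo_high(x):
    return x + 10

relay.op.register("exp", "fdemo", fdemo_low, level=10)
relay.op.register("exp", "fdemo", fdemo_high, level=11)  # higher plevel wins
assert relay.op.get("exp").get_attr("fdemo")(1) == 11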
......@@ -152,6 +169,15 @@ TVM_REGISTER_API("relay.op._OpSetAttr")
reg.set_attr(attr_name, value, plevel);
});
TVM_REGISTER_API("relay.op._OpResetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
reg.reset_attr(attr_name);
});
TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0];
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
def test_op_attr():
log_op = relay.op.get("log")
......@@ -27,6 +28,50 @@ def test_op_attr():
assert log_op.get_attr("ftest") is None
assert relay.op.get("exp").get_attr("ftest")(1) == 2
def test_op_reset_attr():
""" Tests reset_attr functionality. """
def add1(x):
return x + 1
def add2(x):
return x + 2
# Register fadd1 and fadd2 attributes.
relay.op.register("exp", "fadd1", add1)
relay.op.register("log", "fadd1", add1)
relay.op.register("log", "fadd2", add2)
# Reset log fadd1 attr.
log_op = relay.op.get("log")
log_op.reset_attr("fadd1")
# Check that the fadd1 attr has been reset.
assert log_op.get_attr("fadd1") is None
# Check that the fadd1 attr of other ops is intact.
assert relay.op.get("exp").get_attr("fadd1")(1) == 2
# Check that other attrs of the log op are intact.
assert relay.op.get("log").get_attr("fadd2")(1) == 3
def test_op_temp_attr():
""" Tests reset_attr functionality. """
def add1(x):
return x + 1
def add2(x):
return x + 2
# Set the original attr value to add1.
relay.op.register("sqrt", "ftest", add1)
with TempOpAttr("sqrt", "ftest", add2):
# Check that the attr value is updated to add2.
assert relay.op.get("sqrt").get_attr("ftest")(1) == 3
# Check that the attr value is restored to add1.
assert relay.op.get("sqrt").get_attr("ftest")(1) == 2
def test_op_level1():
x = relay.Var("x")
......@@ -47,5 +92,7 @@ def test_op_level3():
if __name__ == "__main__":
test_op_attr()
test_op_reset_attr()
test_op_temp_attr()
test_op_level1()
test_op_level3()
......@@ -21,6 +21,14 @@ from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_infer_type
from tvm.contrib import graph_runtime
from tvm.relay.testing.temp_op_attr import TempOpAttr
# We use the llvm target for testing functionality. `llvm` points to an older
# Intel generation machine, which legalizes to a simple lowering. Therefore,
# the legalization is overridden so that it can be skipped and the
# QNNCanonicalizeOps lowering is used for the testing.
def legalize_qnn_conv2d(attrs, inputs, types):
return None
def get_ref_func(data,
kernel,
......@@ -173,522 +181,548 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
np.testing.assert_equal(qnn_output, golden_output)
def test_no_zero_point():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_kernel_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=1,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=5,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=1,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=0,
kernel_zero_point=5,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_input_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=0,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_both_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# int8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'int8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'int8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_layout():
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 4, 3) # HWIO
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# NHWC and HWOI layout. Used in depthwise conv.
data_shape = (2, 2, 4, 1) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 1, 1) # HWOI
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWOI",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 4, 3) # HWIO
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# NHWC and HWOI layout. Used in depthwise conv.
data_shape = (2, 2, 4, 1) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 1, 1) # HWOI
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWOI",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_padding():
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=5,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(1, 1),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# Try different layout
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 4, 3) # HWIO
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(1, 1),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=5,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(1, 1),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
# Try different layout
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 4, 3) # HWIO
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(1, 1),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_dilation():
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(2, 2),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 1),
dilation=(2, 2),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_const_folding():
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
golden_weight = np.random.random_integers(low=0, high=255,
size=kernel_shape).astype(kernel_dtype)
data = relay.var("data", shape=data_shape,
dtype=data_dtype)
kernel = relay.const(golden_weight)
qnn_func = get_qnn_func(data,
kernel,
input_zero_point=8,
kernel_zero_point=3,
kernel_size=(2, 2),
input_scale=1.0,
kernel_scale=1.0,
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
folded_mod = transform.FoldConstant()(qnn_func)
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
kernel_dtype = 'uint8'
golden_weight = np.random.random_integers(low=0, high=255,
size=kernel_shape).astype(kernel_dtype)
data = relay.var("data", shape=data_shape,
dtype=data_dtype)
kernel = relay.const(golden_weight)
qnn_func = get_qnn_func(data,
kernel,
input_zero_point=8,
kernel_zero_point=3,
kernel_size=(2, 2),
input_scale=1.0,
kernel_scale=1.0,
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
folded_mod = transform.FoldConstant()(qnn_func)
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()
def test_kernel_size_1x1():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 1, 1)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(1, 1),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
assert 'avg_pool2d' not in qnn_func.astext()
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 1, 1)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=5,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(1, 1),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
assert 'avg_pool2d' not in qnn_func.astext()
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
def test_tflite_large_irregular():
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
kernel_shape = (1001, 1024, 1, 1)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=127,
kernel_zero_point=127,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(1, 1),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
golden_data = np.full(data_shape, 127).astype('uint8')
golden_weight = np.full(kernel_shape, 127).astype('uint8')
with relay.build_config(opt_level=2):
params = {'kernel': golden_weight}
graph, lib, params = relay.build(qnn_func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", golden_data)
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
np.testing.assert_equal(qnn_output, golden_output)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
kernel_shape = (1001, 1024, 1, 1)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=127,
kernel_zero_point=127,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(1, 1),
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
golden_data = np.full(data_shape, 127).astype('uint8')
golden_weight = np.full(kernel_shape, 127).astype('uint8')
with relay.build_config(opt_level=2):
params = {'kernel': golden_weight}
graph, lib, params = relay.build(qnn_func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", golden_data)
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
np.testing.assert_equal(qnn_output, golden_output)
def test_tflite_output_multiplier_greater_than_one():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_scale=1.0,
kernel_scale=1.0,
input_zero_point=128,
kernel_zero_point=128,
kernel_size=(2, 2),
padding=(0, 0),
strides=(2, 2),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
golden_data = 128 + np.array((1, 1, 1, 1,
2, 2, 2, 2,
1, 2, 3, 4,
1, 2, 3, 4)).reshape(data_shape).astype('uint8')
golden_weight = 128 + np.array((1, 2, 3, 4,
-1, 1, -1, 1,
-1, -1, 1, 1)).reshape(kernel_shape)
golden_weight = golden_weight.astype('uint8')
with relay.build_config(opt_level=2):
params = {'kernel': golden_weight}
graph, lib, params = relay.build(qnn_func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", golden_data)
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.array((17, 17,
0, 0,
2, 2,
16, 36,
2, 2,
0, 0)).reshape(2, 3, 1, 2)
np.testing.assert_equal(qnn_output, golden_output)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 1, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_scale=1.0,
kernel_scale=1.0,
input_zero_point=128,
kernel_zero_point=128,
kernel_size=(2, 2),
padding=(0, 0),
strides=(2, 2),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
golden_data = 128 + np.array((1, 1, 1, 1,
2, 2, 2, 2,
1, 2, 3, 4,
1, 2, 3, 4)).reshape(data_shape).astype('uint8')
golden_weight = 128 + np.array((1, 2, 3, 4,
-1, 1, -1, 1,
-1, -1, 1, 1)).reshape(kernel_shape)
golden_weight = golden_weight.astype('uint8')
with relay.build_config(opt_level=2):
params = {'kernel': golden_weight}
graph, lib, params = relay.build(qnn_func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", golden_data)
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.array((17, 17,
0, 0,
2, 2,
16, 36,
2, 2,
0, 0)).reshape(2, 3, 1, 2)
np.testing.assert_equal(qnn_output, golden_output)
def test_tflite_anistropic_strides():
# uint8 input
data_shape = (1, 1, 3, 6)
data_dtype = 'uint8'
kernel_shape = (1, 1, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=127,
kernel_zero_point=127,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 3),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
golden_data = np.array((133, 131, 129, 125, 123, 121,
135, 133, 131, 123, 121, 119,
137, 135, 133, 121, 119, 117)).reshape(data_shape)
golden_data = golden_data.astype('uint8')
golden_weight = np.array((129, 131, 133, 135)).reshape(kernel_shape)
golden_weight = golden_weight.astype('uint8')
with relay.build_config(opt_level=2):
params = {'kernel': golden_weight}
graph, lib, params = relay.build(qnn_func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", golden_data)
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (1, 1, 3, 6)
data_dtype = 'uint8'
kernel_shape = (1, 1, 2, 2)
kernel_dtype = 'uint8'
ref_func, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=127,
kernel_zero_point=127,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(2, 2),
padding=(0, 0),
strides=(1, 3),
dilation=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32")
golden_data = np.array((133, 131, 129, 125, 123, 121,
135, 133, 131, 123, 121, 119,
137, 135, 133, 121, 119, 117)).reshape(data_shape)
golden_data = golden_data.astype('uint8')
golden_weight = np.array((129, 131, 133, 135)).reshape(kernel_shape)
golden_weight = golden_weight.astype('uint8')
with relay.build_config(opt_level=2):
params = {'kernel': golden_weight}
graph, lib, params = relay.build(qnn_func, "llvm", params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", golden_data)
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)
def test_broadcast_layout():
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
kernel_shape = (7, 7, 3, 64) # HWIO
kernel_dtype = 'int8'
_, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(7, 7),
padding=(1, 1),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
func = qnn_func['main'].body
bias = relay.var("bias", shape=(64,), dtype="int32")
bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32")
# Check broadcast support on both lhs and rhs
func = relay.add(func, bias2)
func = relay.add(bias2, func)
func = relay.add(bias, func)
func = relay.add(func, bias)
func = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(func)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
kernel_shape = (7, 7, 3, 64) # HWIO
kernel_dtype = 'int8'
_, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=3,
input_scale=1.0,
kernel_scale=1.0,
kernel_size=(7, 7),
padding=(1, 1),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
func = qnn_func['main'].body
bias = relay.var("bias", shape=(64,), dtype="int32")
bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32")
# Check broadcast support on both lhs and rhs
func = relay.add(func, bias2)
func = relay.add(bias2, func)
func = relay.add(bias, func)
func = relay.add(func, bias)
func = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(func)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
if __name__ == "__main__":
test_no_zero_point()
......
......@@ -19,6 +19,15 @@ import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.temp_op_attr import TempOpAttr
# We use the llvm target for testing functionality. `llvm` points to an older
# Intel generation machine, which legalizes to a simple lowering. Therefore,
# the legalization is overridden so that it can be skipped and the
# QNNCanonicalizeOps lowering is used for the testing.
def legalize_qnn_dense(attrs, inputs, types):
return None
def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype):
......@@ -209,21 +218,27 @@ def qnn_dense_driver(test_configuration):
def test_qnn_dense_without_bias():
int32_output_without_bias_params = \
make_int_configuration(use_bias=False)
qnn_dense_driver(int32_output_without_bias_params)
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int32_output_without_bias_params = \
make_int_configuration(use_bias=False)
qnn_dense_driver(int32_output_without_bias_params)
def test_qnn_dense_with_bias():
int32_output_with_bias_params = \
make_int_configuration(use_bias=True)
qnn_dense_driver(int32_output_with_bias_params)
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int32_output_with_bias_params = \
make_int_configuration(use_bias=True)
qnn_dense_driver(int32_output_with_bias_params)
def test_qnn_dense_with_requantized_output():
int8_requantized_output_with_bias_params = \
make_int_configuration(use_bias=True, requantize_output=True)
qnn_dense_driver(int8_requantized_output_with_bias_params)
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int8_requantized_output_with_bias_params = \
make_int_configuration(use_bias=True, requantize_output=True)
qnn_dense_driver(int8_requantized_output_with_bias_params)
if __name__ == "__main__":
......
......@@ -18,9 +18,8 @@
import tvm
from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
......@@ -31,7 +30,6 @@ def run_opt_pass(expr, passes):
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_alter_op():
"""Test directly replacing an operator with a new one"""
def before():
......@@ -45,8 +43,6 @@ def test_alter_op():
y = relay.Function([x, weight], y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=100)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
......@@ -63,9 +59,10 @@ def test_alter_op():
y = relay.Function([x, weight], y)
return y
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -80,17 +77,15 @@ def test_alter_return_none():
called = [False]
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.global_max_pool2d", level=101)
def alter_conv2d(attrs, inputs, tinfos):
called[0] = True
return None
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
with TempOpAttr("nn.global_max_pool2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(before(), transform.InferType())
b = before()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
......@@ -114,8 +109,6 @@ def test_alter_layout():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=102)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
......@@ -123,6 +116,7 @@ def test_alter_layout():
new_attrs['kernel_layout'] = 'OIHW16i'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
bias = relay.var("bias", shape=(64,))
......@@ -149,12 +143,11 @@ def test_alter_layout():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -183,14 +176,13 @@ def test_alter_layout_dual_path():
y = relay.Function(analysis.free_vars(ret), ret)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=103)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
......@@ -215,11 +207,10 @@ def test_alter_layout_dual_path():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -245,14 +236,13 @@ def test_alter_layout_resnet():
y = relay.nn.global_max_pool2d(y)
return relay.Function(analysis.free_vars(y), y)
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=104)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
......@@ -274,11 +264,10 @@ def test_alter_layout_resnet():
y = relay.layout_transform(y, "NCHW16c", "NCHW")
return relay.Function(analysis.free_vars(y), y)
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -296,8 +285,6 @@ def test_alter_layout_broadcast_op():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=105)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
......@@ -323,12 +310,11 @@ def test_alter_layout_broadcast_op():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -344,8 +330,6 @@ def test_alter_layout_scalar():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=106)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
......@@ -368,26 +352,24 @@ def test_alter_layout_scalar():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_concatenate():
""" NCHW, NHWC and corner case concatenate layout transform."""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=107)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# NCHW layout transformation.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
......@@ -425,11 +407,10 @@ def test_alter_layout_concatenate():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -472,11 +453,10 @@ def test_alter_layout_concatenate():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -492,8 +472,6 @@ def test_alter_layout_nchw_upsamping_op():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=108)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
......@@ -512,12 +490,10 @@ def test_alter_layout_nchw_upsamping_op():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -532,8 +508,6 @@ def test_alter_layout_strided_slice():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=109)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
......@@ -551,12 +525,11 @@ def test_alter_layout_strided_slice():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -570,12 +543,11 @@ def test_alter_layout_depthwise_conv2d():
return y
import topi
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=110)
def alter_conv2d(attrs, inputs, tinfos):
with tvm.target.create("llvm"):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay)
def expected():
x = relay.var("x", shape=(1, 32, 56, 56))
w = relay.var("w", shape=(32, 1, 3, 3))
......@@ -588,12 +560,11 @@ def test_alter_layout_depthwise_conv2d():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert(analysis.alpha_equal(a, b))
......@@ -608,8 +579,6 @@ def test_alter_layout_prelu():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=111)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
......@@ -632,25 +601,23 @@ def test_alter_layout_prelu():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert(analysis.alpha_equal(a, b))
def test_alter_layout_pad():
""" Check NCHW, NHWC and corner case for pad layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=112)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
......@@ -677,11 +644,10 @@ def test_alter_layout_pad():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -712,15 +678,14 @@ def test_alter_layout_pad():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check that conversion does not happen when padding along split axis..
# Check that conversion does not happen when padding along split axis.
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
......@@ -746,25 +711,23 @@ def test_alter_layout_pad():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_pool():
""" Check NCHW, NHWC pool layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=113)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
......@@ -791,11 +754,10 @@ def test_alter_layout_pool():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -826,25 +788,23 @@ def test_alter_layout_pool():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_sum():
""" Check NCHW, NHWC sum layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=114)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
......@@ -871,11 +831,10 @@ def test_alter_layout_sum():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -907,19 +866,16 @@ def test_alter_layout_sum():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_nhwc_nchw_arm():
""" Check NHWC to NHCW conversion for a small sequence of ops."""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=115)
def alter_conv2d(attrs, inputs, tinfos):
from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)
......@@ -968,11 +924,10 @@ def test_alter_layout_nhwc_nchw_arm():
y = relay.Function(analysis.free_vars(y), y)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......
......@@ -20,8 +20,8 @@ import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
......@@ -46,7 +46,6 @@ def test_legalize():
y = relay.Function([x, weight], y)
return y
@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
......@@ -63,9 +62,10 @@ def test_legalize():
y = relay.Function([x, weight], y)
return y
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -79,16 +79,15 @@ def test_legalize_none():
called = [False]
@register_legalize("nn.global_max_pool2d", level=101)
def legalize_conv2d(attrs, inputs, types):
called[0] = True
return None
a = before()
a = run_opt_pass(a, transform.Legalize())
with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_conv2d):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
b = before()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
......@@ -105,14 +104,12 @@ def test_legalize_multiple_ops():
y = relay.Function([x, weight], y)
return y
@register_legalize("nn.conv2d", level=102)
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
@register_legalize("nn.relu", level=103)
def legalize_conv2d(attrs, inputs, types):
def legalize_relu(attrs, inputs, types):
data = inputs[0]
add = relay.add(tvm.relay.const(0, "float32"), data)
return relay.nn.relu(add)
......@@ -130,9 +127,11 @@ def test_legalize_multiple_ops():
y = relay.Function([x, weight], y)
return y
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -147,7 +146,6 @@ def test_legalize_multi_input():
func = relay.Function([x, y, z], func)
return func
@register_legalize("concatenate", level=104)
def legalize_concatenate(attrs, inputs, types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
......@@ -165,9 +163,11 @@ def test_legalize_multi_input():
func = relay.Function([x, y, z], func)
return func
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......
......@@ -20,8 +20,8 @@ import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def alpha_equal(x, y):
"""
......@@ -54,7 +54,6 @@ def test_qnn_legalize():
y = relay.Function([x], y)
return y
@register_qnn_legalize("qnn.requantize", level=100)
def legalize_qnn_requantize(attrs, inputs, types):
data = inputs[0]
data = relay.add(relay.const(0, 'int8'), data)
......@@ -80,15 +79,17 @@ def test_qnn_legalize():
a = before()
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
with TempOpAttr("qnn.requantize", "FTVMQnnLegalize", legalize_qnn_requantize):
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_qnn_legalize_qnn_conv2d():
......