Commit 26eb4053 by Animesh Jain, committed by Tianqi Chen

[Relay tests] AlterOpLayout - Temporary attr update (#4357)

parent f1d6f335
...@@ -258,6 +258,12 @@ class OpRegistry {
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value, int plevel = 10);
/*!
* \brief Resets an attr of the registry.
* \param attr_name The name of the attribute.
*/
inline void reset_attr(const std::string& attr_name);
// set the name of the op to be the same as registry
inline OpRegistry& set_name() { // NOLINT(*)
if (get()->name.length() == 0) {
......
...@@ -64,6 +64,16 @@ class Op(Expr):
"""
_OpSetAttr(self, attr_name, value, plevel)
def reset_attr(self, attr_name):
"""Reset attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
"""
_OpResetAttr(self, attr_name)
def get(op_name):
"""Get the Op for a given name
......
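A minimal usage sketch for the new reset_attr method (the attr name "ftest" and the helper function are illustrative placeholders chosen to mirror the unit tests further down; this is not part of the commit):

from tvm import relay

def ftest(x):
    return x + 1

# Register a test attr on an op, then clear it again with the new reset_attr.
relay.op.register("log", "ftest", ftest)
log_op = relay.op.get("log")
assert log_op.get_attr("ftest")(1) == 2
log_op.reset_attr("ftest")
assert log_op.get_attr("ftest") is None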
...@@ -37,6 +37,7 @@ from . import squeezenet
from . import vgg
from . import densenet
from . import yolo_detection
from . import temp_op_attr
from .config import ctx_list
from .init import create_workload
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Defines a TempOpAttr class that allows temporarily changing an attr of the
operator to allow unit testing. This is useful for AlterOpLayout and Legalize
tests."""
from tvm import relay
class TempOpAttr(object):
""" Temporarily changes the attr of an op. """
def __init__(self, op_name, attr_key, attr_value):
""" Saves the required info for RAII pattern usage.
Parameters
----------
op_name : str
The op name.
attr_key : str
The attribute name.
attr_value : object
The attribute value.
Examples
--------
.. code-block:: python
# Temporarily update FTVMAlterOpLayout to a user-defined packed function.
# After the test is finished, the attr value will be set back to the original value.
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
my_mod = relay.transform.AlterOpLayout()(my_mod)
"""
self.op = relay.op.get(op_name)
self.attr_key = attr_key
self.attr_value = attr_value
def __enter__(self):
self.older_attr = self.op.get_attr(self.attr_key)
self.op.reset_attr(self.attr_key)
self.op.set_attr(self.attr_key, self.attr_value)
return self
def __exit__(self, ptype, value, trace):
self.op.reset_attr(self.attr_key)
if self.older_attr:
self.op.set_attr(self.attr_key, self.older_attr)
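Where a test needs to override attrs on more than one op, the context managers nest, as the updated Legalize tests below do; a minimal sketch (legalize_conv2d, legalize_relu and mod are placeholders, not part of this file):

from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr

with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
    with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
        mod = relay.transform.Legalize()(mod)
# Both attrs are restored to their previous values when the blocks exit.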
...@@ -95,6 +95,20 @@ const bool Op::HasGenericAttr(const std::string& key) {
return true;
}
// Resets attr of the OpMap.
void OpRegistry::reset_attr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
return;
}
uint32_t index = op_->index_;
if (op_map->data_.size() > index) {
op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
}
}
void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
...@@ -113,7 +127,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
CHECK(p.second != plevel)
<< "Attribute " << key << " of operator " << this->name
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
CHECK(value.type_code() != kNull)
<< "Registered packed_func is Null for " << key
<< " of operator " << this->name;
if (p.second < plevel && value.type_code() != kNull) {
op_map->data_[index] = std::make_pair(value, plevel);
}
}
...@@ -152,6 +169,15 @@ TVM_REGISTER_API("relay.op._OpSetAttr")
reg.set_attr(attr_name, value, plevel);
});
TVM_REGISTER_API("relay.op._OpResetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
reg.reset_attr(attr_name);
});
TVM_REGISTER_API("relay.op._Register") TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0]; std::string op_name = args[0];
......
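For reference, the plevel rules enforced by UpdateAttr above can be exercised from the Python side; a minimal sketch (the attr name "ftest" and both functions are placeholders, and set_attr's default plevel of 10 is taken from the declaration in op.h above):

from tvm import relay

def f_low(x):
    return x + 1

def f_high(x):
    return x + 100

log_op = relay.op.get("log")
log_op.set_attr("ftest", f_low, 10)
log_op.set_attr("ftest", f_high, 20)  # a higher plevel overrides the earlier registration
assert log_op.get_attr("ftest")(1) == 101
log_op.reset_attr("ftest")            # clean up so other tests see no "ftest" attr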
...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
def test_op_attr():
log_op = relay.op.get("log")
...@@ -27,6 +28,50 @@ def test_op_attr():
assert log_op.get_attr("ftest") is None
assert relay.op.get("exp").get_attr("ftest")(1) == 2
def test_op_reset_attr():
""" Tests reset_attr functionality. """
def add1(x):
return x + 1
def add2(x):
return x + 2
# Register fadd1 and fadd2 attributes.
relay.op.register("exp", "fadd1", add1)
relay.op.register("log", "fadd1", add1)
relay.op.register("log", "fadd2", add2)
# Reset log fadd1 attr.
log_op = relay.op.get("log")
log_op.reset_attr("fadd1")
# Check that the fadd1 attr has been reset.
assert log_op.get_attr("fadd1") is None
# Check that the fadd1 attr of other ops is intact.
assert relay.op.get("exp").get_attr("fadd1")(1) == 2
# Check that other attrs of the log op are intact.
assert relay.op.get("log").get_attr("fadd2")(1) == 3
def test_op_temp_attr():
""" Tests reset_attr functionality. """
def add1(x):
return x + 1
def add2(x):
return x + 2
# Set the original attr value to add1.
relay.op.register("sqrt", "ftest", add1)
with TempOpAttr("sqrt", "ftest", add2):
# Check that the attr value is updated to add2.
assert relay.op.get("sqrt").get_attr("ftest")(1) == 3
# Check that the attr value is recovered to add1.
assert relay.op.get("sqrt").get_attr("ftest")(1) == 2
def test_op_level1():
x = relay.Var("x")
...@@ -47,5 +92,7 @@ def test_op_level3():
if __name__ == "__main__":
test_op_attr()
test_op_reset_attr()
test_op_temp_attr()
test_op_level1()
test_op_level3()
...@@ -21,6 +21,14 @@ from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_infer_type
from tvm.contrib import graph_runtime
from tvm.relay.testing.temp_op_attr import TempOpAttr
# We use the llvm target for testing functionality. `llvm` points to an older
# Intel-generation machine, which legalizes to a simple lowering. Therefore,
# the legalization is overridden here so that it can be skipped and the
# QNNCanonicalizeOps lowering is used for the testing.
def legalize_qnn_conv2d(attrs, inputs, types):
return None
def get_ref_func(data,
kernel,
...@@ -173,6 +181,8 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
np.testing.assert_equal(qnn_output, golden_output)
def test_no_zero_point():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
...@@ -220,6 +230,8 @@ def test_no_zero_point():
kernel_shape, kernel_dtype)
def test_kernel_zero_point():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
...@@ -268,6 +280,8 @@ def test_kernel_zero_point():
def test_input_zero_point():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
...@@ -315,6 +329,8 @@ def test_input_zero_point():
kernel_shape, kernel_dtype)
def test_both_zero_point():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
...@@ -362,6 +378,8 @@ def test_both_zero_point():
kernel_shape, kernel_dtype)
def test_layout():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
...@@ -411,6 +429,8 @@ def test_layout():
def test_padding():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
...@@ -458,6 +478,8 @@ def test_padding():
kernel_shape, kernel_dtype)
def test_dilation():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
...@@ -483,6 +505,8 @@ def test_dilation():
def test_const_folding():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
...@@ -511,6 +535,8 @@ def test_const_folding():
assert "reshape" not in folded_func.astext()
def test_kernel_size_1x1():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
...@@ -536,6 +562,8 @@ def test_kernel_size_1x1():
kernel_shape, kernel_dtype)
def test_tflite_large_irregular():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
...@@ -571,6 +599,8 @@ def test_tflite_large_irregular():
np.testing.assert_equal(qnn_output, golden_output)
def test_tflite_output_multiplier_greater_than_one():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
...@@ -617,6 +647,8 @@ def test_tflite_output_multiplier_greater_than_one():
np.testing.assert_equal(qnn_output, golden_output)
def test_tflite_anistropic_strides():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# uint8 input
data_shape = (1, 1, 3, 6)
data_dtype = 'uint8'
...@@ -656,6 +688,8 @@ def test_tflite_anistropic_strides():
np.testing.assert_equal(qnn_output, golden_output)
def test_broadcast_layout():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
......
...@@ -19,6 +19,15 @@ import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.temp_op_attr import TempOpAttr
# We use the llvm target for testing functionality. `llvm` points to an older
# Intel-generation machine, which legalizes to a simple lowering. Therefore,
# the legalization is overridden here so that it can be skipped and the
# QNNCanonicalizeOps lowering is used for the testing.
def legalize_qnn_dense(attrs, inputs, types):
return None
def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype):
...@@ -209,18 +218,24 @@ def qnn_dense_driver(test_configuration):
def test_qnn_dense_without_bias():
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int32_output_without_bias_params = \
make_int_configuration(use_bias=False)
qnn_dense_driver(int32_output_without_bias_params)
def test_qnn_dense_with_bias():
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int32_output_with_bias_params = \
make_int_configuration(use_bias=True)
qnn_dense_driver(int32_output_with_bias_params)
def test_qnn_dense_with_requantized_output():
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int8_requantized_output_with_bias_params = \
make_int_configuration(use_bias=True, requantize_output=True)
qnn_dense_driver(int8_requantized_output_with_bias_params)
......
...@@ -18,9 +18,8 @@
import tvm
from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
...@@ -31,7 +30,6 @@ def run_opt_pass(expr, passes):
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_alter_op():
"""Test directly replacing an operator with a new one"""
def before():
...@@ -45,8 +43,6 @@ def test_alter_op():
y = relay.Function([x, weight], y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=100)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
...@@ -63,6 +59,7 @@ def test_alter_op():
y = relay.Function([x, weight], y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
...@@ -80,17 +77,15 @@ def test_alter_return_none():
called = [False]
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.global_max_pool2d", level=101)
def alter_conv2d(attrs, inputs, tinfos):
called[0] = True
return None
with TempOpAttr("nn.global_max_pool2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(before(), transform.InferType())
b = before()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
...@@ -114,8 +109,6 @@ def test_alter_layout():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=102)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
...@@ -123,6 +116,7 @@ def test_alter_layout():
new_attrs['kernel_layout'] = 'OIHW16i'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
bias = relay.var("bias", shape=(64,))
...@@ -149,12 +143,11 @@ def test_alter_layout():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -183,14 +176,13 @@ def test_alter_layout_dual_path():
y = relay.Function(analysis.free_vars(ret), ret)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=103)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
...@@ -215,11 +207,10 @@ def test_alter_layout_dual_path():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -245,14 +236,13 @@ def test_alter_layout_resnet():
y = relay.nn.global_max_pool2d(y)
return relay.Function(analysis.free_vars(y), y)
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=104)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
...@@ -274,11 +264,10 @@ def test_alter_layout_resnet():
y = relay.layout_transform(y, "NCHW16c", "NCHW")
return relay.Function(analysis.free_vars(y), y)
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -296,8 +285,6 @@ def test_alter_layout_broadcast_op():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=105)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
...@@ -323,12 +310,11 @@ def test_alter_layout_broadcast_op():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -344,8 +330,6 @@ def test_alter_layout_scalar():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=106)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
...@@ -368,26 +352,24 @@ def test_alter_layout_scalar():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_concatenate():
""" NCHW, NHWC and corner case concatenate layout transform."""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=107)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# NCHW layout transformation.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
...@@ -425,11 +407,10 @@ def test_alter_layout_concatenate():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -472,11 +453,10 @@ def test_alter_layout_concatenate():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -492,8 +472,6 @@ def test_alter_layout_nchw_upsamping_op():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=108)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
...@@ -512,12 +490,10 @@ def test_alter_layout_nchw_upsamping_op():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -532,8 +508,6 @@ def test_alter_layout_strided_slice():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=109)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
...@@ -551,12 +525,11 @@ def test_alter_layout_strided_slice():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -570,12 +543,11 @@ def test_alter_layout_depthwise_conv2d():
return y
import topi
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=110)
def alter_conv2d(attrs, inputs, tinfos):
with tvm.target.create("llvm"):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay)
def expected():
x = relay.var("x", shape=(1, 32, 56, 56))
w = relay.var("w", shape=(32, 1, 3, 3))
...@@ -588,12 +560,11 @@ def test_alter_layout_depthwise_conv2d():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert(analysis.alpha_equal(a, b))
...@@ -608,8 +579,6 @@ def test_alter_layout_prelu():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=111)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
...@@ -632,25 +601,23 @@ def test_alter_layout_prelu():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert(analysis.alpha_equal(a, b))
def test_alter_layout_pad():
""" Check NCHW, NHWC and corner case for pad layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=112)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
...@@ -677,11 +644,10 @@ def test_alter_layout_pad():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -712,15 +678,14 @@ def test_alter_layout_pad():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check that conversion does not happen when padding along split axis.
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
...@@ -746,25 +711,23 @@ def test_alter_layout_pad():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
b = expected()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_pool():
""" Check NCHW, NHWC pool layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=113)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
...@@ -791,11 +754,10 @@ def test_alter_layout_pool():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -826,25 +788,23 @@ def test_alter_layout_pool():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_sum():
""" Check NCHW, NHWC sum layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=114)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
...@@ -871,11 +831,10 @@ def test_alter_layout_sum():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
...@@ -907,19 +866,16 @@ def test_alter_layout_sum():
y = relay.Function(analysis.free_vars(ret), ret)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_nhwc_nchw_arm():
""" Check NHWC to NCHW conversion for a small sequence of ops."""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=115)
def alter_conv2d(attrs, inputs, tinfos):
from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)
...@@ -968,11 +924,10 @@ def test_alter_layout_nhwc_nchw_arm():
y = relay.Function(analysis.free_vars(y), y)
return y
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......
...@@ -20,8 +20,8 @@ import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
...@@ -46,7 +46,6 @@ def test_legalize():
y = relay.Function([x, weight], y)
return y
@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
...@@ -63,6 +62,7 @@ def test_legalize():
y = relay.Function([x, weight], y)
return y
with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
...@@ -79,16 +79,15 @@ def test_legalize_none():
called = [False]
@register_legalize("nn.global_max_pool2d", level=101)
def legalize_conv2d(attrs, inputs, types):
called[0] = True
return None
with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_conv2d):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
b = before()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
...@@ -105,14 +104,12 @@ def test_legalize_multiple_ops():
y = relay.Function([x, weight], y)
return y
@register_legalize("nn.conv2d", level=102)
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
@register_legalize("nn.relu", level=103)
def legalize_conv2d(attrs, inputs, types):
def legalize_relu(attrs, inputs, types):
data = inputs[0]
add = relay.add(tvm.relay.const(0, "float32"), data)
return relay.nn.relu(add)
...@@ -130,6 +127,8 @@ def test_legalize_multiple_ops():
y = relay.Function([x, weight], y)
return y
with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
...@@ -147,7 +146,6 @@ def test_legalize_multi_input():
func = relay.Function([x, y, z], func)
return func
@register_legalize("concatenate", level=104)
def legalize_concatenate(attrs, inputs, types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
...@@ -165,6 +163,8 @@ def test_legalize_multi_input():
func = relay.Function([x, y, z], func)
return func
with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
......
...@@ -20,8 +20,8 @@ import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def alpha_equal(x, y):
"""
...@@ -54,7 +54,6 @@ def test_qnn_legalize():
y = relay.Function([x], y)
return y
@register_qnn_legalize("qnn.requantize", level=100)
def legalize_qnn_requantize(attrs, inputs, types):
data = inputs[0]
data = relay.add(relay.const(0, 'int8'), data)
...@@ -80,6 +79,8 @@ def test_qnn_legalize():
a = before()
with TempOpAttr("qnn.requantize", "FTVMQnnLegalize", legalize_qnn_requantize):
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
......