Commit 26eb4053 by Animesh Jain, committed by Tianqi Chen

[Relay tests] AlterOpLayout - Temporary attr update (#4357)

parent f1d6f335
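In short, the commit adds an OpRegistry::reset_attr hook in C++, exposes it to Python as Op.reset_attr, and layers a TempOpAttr context manager on top so tests can temporarily override an operator attribute (for example FTVMAlterOpLayout or FTVMLegalize) and have the original value restored afterwards. A minimal sketch of the intended test usage, assuming the TempOpAttr helper added below; the conv2d module and the alter_conv2d hook are illustrative placeholders, not part of the commit:

from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr

def alter_conv2d(attrs, inputs, tinfos):
    # Placeholder FTVMAlterOpLayout hook; returning None keeps the original conv2d.
    return None

x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
mod = relay.Module.from_expr(relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=64))

# The override is active only inside the with-block; on exit, TempOpAttr
# resets the attr and re-registers the original FTVMAlterOpLayout value.
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
    mod = relay.transform.AlterOpLayout()(mod)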
......@@ -258,6 +258,12 @@ class OpRegistry {
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value, int plevel = 10);
/*!
* \brief Resets an attr of the registry.
* \param attr_name The name of the attribute.
*/
inline void reset_attr(const std::string& attr_name);
// set the name of the op to be the same as registry
inline OpRegistry& set_name() { // NOLINT(*)
if (get()->name.length() == 0) {
......
......@@ -64,6 +64,16 @@ class Op(Expr):
"""
_OpSetAttr(self, attr_name, value, plevel)
def reset_attr(self, attr_name):
"""Reset attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
"""
_OpResetAttr(self, attr_name)
def get(op_name):
"""Get the Op for a given name
......
......@@ -37,6 +37,7 @@ from . import squeezenet
from . import vgg
from . import densenet
from . import yolo_detection
from . import temp_op_attr
from .config import ctx_list
from .init import create_workload
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Defines a TempOpAttr class that allows temporarily changing an attr of the
operator to allow unit testing. This is useful for AlterOpLayout and Legalize
tests."""
from tvm import relay
class TempOpAttr(object):
""" Temporarily changes the attr of an op. """
def __init__(self, op_name, attr_key, attr_value):
""" Saves the required info for RAII pattern usage.
Parameters
----------
op_name : str
The op name.
attr_key : str
The attribute name.
attr_value : object
The attribute value.
Examples
--------
.. code-block:: python
# Temporarily update FTVMAlterOpLayout to a user-defined packed function.
# After the test is finished, the attr value will be set back to the original value.
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
my_mod = relay.transform.AlterOpLayout()(my_mod)
"""
self.op = relay.op.get(op_name)
self.attr_key = attr_key
self.attr_value = attr_value
def __enter__(self):
self.older_attr = self.op.get_attr(self.attr_key)
self.op.reset_attr(self.attr_key)
self.op.set_attr(self.attr_key, self.attr_value)
return self
def __exit__(self, ptype, value, trace):
self.op.reset_attr(self.attr_key)
if self.older_attr:
self.op.set_attr(self.attr_key, self.older_attr)
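For comparison, the same save/override/restore sequence could be written as a generator-based context manager. This is only a sketch of the design pattern, assuming the Op.get_attr/set_attr/reset_attr API introduced above; it is not part of the commit:

from contextlib import contextmanager
from tvm import relay

@contextmanager
def temp_op_attr(op_name, attr_key, attr_value):
    # Hypothetical functional equivalent of the TempOpAttr class.
    op = relay.op.get(op_name)
    older_attr = op.get_attr(attr_key)
    op.reset_attr(attr_key)
    op.set_attr(attr_key, attr_value)
    try:
        yield op
    finally:
        # Restore the original attribute (if any) on exit.
        op.reset_attr(attr_key)
        if older_attr:
            op.set_attr(attr_key, older_attr)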
......@@ -95,6 +95,20 @@ const bool Op::HasGenericAttr(const std::string& key) {
return true;
}
// Resets an attr of the registry by clearing the op's entry in the attr's OpMap.
void OpRegistry::reset_attr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
return;
}
uint32_t index = op_->index_;
if (op_map->data_.size() > index) {
op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
}
}
void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
......@@ -113,7 +127,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
CHECK(p.second != plevel)
<< "Attribute " << key << " of operator " << this->name
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
CHECK(value.type_code() != kNull)
<< "Registered packed_func is Null for " << key
<< " of operator " << this->name;
if (p.second < plevel && value.type_code() != kNull) {
op_map->data_[index] = std::make_pair(value, plevel);
}
}
......@@ -152,6 +169,15 @@ TVM_REGISTER_API("relay.op._OpSetAttr")
reg.set_attr(attr_name, value, plevel);
});
TVM_REGISTER_API("relay.op._OpResetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
auto& reg =
OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
reg.reset_attr(attr_name);
});
TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string op_name = args[0];
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
def test_op_attr():
log_op = relay.op.get("log")
......@@ -27,6 +28,50 @@ def test_op_attr():
assert log_op.get_attr("ftest") is None
assert relay.op.get("exp").get_attr("ftest")(1) == 2
def test_op_reset_attr():
""" Tests reset_attr functionality. """
def add1(x):
return x + 1
def add2(x):
return x + 2
# Register fadd1 and fadd2 attributes.
relay.op.register("exp", "fadd1", add1)
relay.op.register("log", "fadd1", add1)
relay.op.register("log", "fadd2", add2)
# Reset log fadd1 attr.
log_op = relay.op.get("log")
log_op.reset_attr("fadd1")
# Check that the fadd1 attr is reset.
assert log_op.get_attr("fadd1") is None
# Check that the fadd1 attr of other ops is intact.
assert relay.op.get("exp").get_attr("fadd1")(1) == 2
# Check that other attrs of the log op are intact.
assert relay.op.get("log").get_attr("fadd2")(1) == 3
def test_op_temp_attr():
""" Tests reset_attr functionality. """
def add1(x):
return x + 1
def add2(x):
return x + 2
# Set the original attr value to add1.
relay.op.register("sqrt", "ftest", add1)
with TempOpAttr("sqrt", "ftest", add2):
# Check that the attr value is updated to add2.
assert relay.op.get("sqrt").get_attr("ftest")(1) == 3
# Check that the attr value is restored to add1.
assert relay.op.get("sqrt").get_attr("ftest")(1) == 2
def test_op_level1():
x = relay.Var("x")
......@@ -47,5 +92,7 @@ def test_op_level3():
if __name__ == "__main__":
test_op_attr()
test_op_reset_attr()
test_op_temp_attr()
test_op_level1()
test_op_level3()
......@@ -19,6 +19,15 @@ import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.temp_op_attr import TempOpAttr
# We use the llvm target for testing functionality. `llvm` points to an older
# Intel-generation machine, which legalizes to a simple lowering. Therefore, the
# legalization is overridden so that it is skipped and the QNNCanonicalizeOps
# lowering is used for the testing.
def legalize_qnn_dense(attrs, inputs, types):
return None
def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype):
......@@ -209,21 +218,27 @@ def qnn_dense_driver(test_configuration):
def test_qnn_dense_without_bias():
int32_output_without_bias_params = \
make_int_configuration(use_bias=False)
qnn_dense_driver(int32_output_without_bias_params)
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int32_output_without_bias_params = \
make_int_configuration(use_bias=False)
qnn_dense_driver(int32_output_without_bias_params)
def test_qnn_dense_with_bias():
int32_output_with_bias_params = \
make_int_configuration(use_bias=True)
qnn_dense_driver(int32_output_with_bias_params)
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int32_output_with_bias_params = \
make_int_configuration(use_bias=True)
qnn_dense_driver(int32_output_with_bias_params)
def test_qnn_dense_with_requantized_output():
int8_requantized_output_with_bias_params = \
make_int_configuration(use_bias=True, requantize_output=True)
qnn_dense_driver(int8_requantized_output_with_bias_params)
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
int8_requantized_output_with_bias_params = \
make_int_configuration(use_bias=True, requantize_output=True)
qnn_dense_driver(int8_requantized_output_with_bias_params)
if __name__ == "__main__":
......
......@@ -20,8 +20,8 @@ import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
......@@ -46,7 +46,6 @@ def test_legalize():
y = relay.Function([x, weight], y)
return y
@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
......@@ -63,9 +62,10 @@ def test_legalize():
y = relay.Function([x, weight], y)
return y
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -79,16 +79,15 @@ def test_legalize_none():
called = [False]
@register_legalize("nn.global_max_pool2d", level=101)
def legalize_conv2d(attrs, inputs, types):
called[0] = True
return None
a = before()
a = run_opt_pass(a, transform.Legalize())
with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_conv2d):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
b = before()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
......@@ -105,14 +104,12 @@ def test_legalize_multiple_ops():
y = relay.Function([x, weight], y)
return y
@register_legalize("nn.conv2d", level=102)
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
@register_legalize("nn.relu", level=103)
def legalize_conv2d(attrs, inputs, types):
def legalize_relu(attrs, inputs, types):
data = inputs[0]
add = relay.add(tvm.relay.const(0, "float32"), data)
return relay.nn.relu(add)
......@@ -130,9 +127,11 @@ def test_legalize_multiple_ops():
y = relay.Function([x, weight], y)
return y
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -147,7 +146,6 @@ def test_legalize_multi_input():
func = relay.Function([x, y, z], func)
return func
@register_legalize("concatenate", level=104)
def legalize_concatenate(attrs, inputs, types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
......@@ -165,9 +163,11 @@ def test_legalize_multi_input():
func = relay.Function([x, y, z], func)
return func
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......
......@@ -20,8 +20,8 @@ import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr
def alpha_equal(x, y):
"""
......@@ -54,7 +54,6 @@ def test_qnn_legalize():
y = relay.Function([x], y)
return y
@register_qnn_legalize("qnn.requantize", level=100)
def legalize_qnn_requantize(attrs, inputs, types):
data = inputs[0]
data = relay.add(relay.const(0, 'int8'), data)
......@@ -80,15 +79,17 @@ def test_qnn_legalize():
a = before()
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
with TempOpAttr("qnn.requantize", "FTVMQnnLegalize", legalize_qnn_requantize):
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_qnn_legalize_qnn_conv2d():
......