Commit 79922bd3 by Animesh Jain Committed by Yizhi Liu

[Relay] Legalize pass (#3672)

* [Relay] Rewrite pass.

This pass transforms an expression to other expression.

This pass has many usecases
 * Replace a expr to another expr, if the other expr has faster performance.
 * For ASICs, we might want to modify the inputs to adapt to the HW support.
 * Alter op layout can work in conjunction with this pass.

The supporting usecase is the Intel i8 x i8 conv. Intel HW supports u8 x i8 conv
in HW. Using this pass, we can replace an i8 x i8 conv to a sequence of
operators where one of the operators is now u8 x i8 conv. This will also help
automatic quantizaion performance.

* Better API name.

* Removing the conv2d legalization for x86. Will send a separate PR.

* Test name changes.

* Registering one funtion to register FTVMLegalize.

* Better comments.
parent 831b32e7
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file nnvm/compiler/op_attr_types.h * \file tvm/relay/op_attr_types.h
* \brief The Expr and related elements in DataFlow construction. * \brief The Expr and related elements in DataFlow construction.
*/ */
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_ #ifndef TVM_RELAY_OP_ATTR_TYPES_H_
...@@ -128,6 +128,20 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< ...@@ -128,6 +128,20 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
const Array<Tensor>& tinfos)>; const Array<Tensor>& tinfos)>;
/*! /*!
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
using FTVMLegalize = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<tvm::relay::Type>& arg_types)>;
/*!
* \brief Forward rewriting rule for a specific op. * \brief Forward rewriting rule for a specific op.
* *
* \param ref_call The reference old call type to be rewritten. * \param ref_call The reference old call type to be rewritten.
......
...@@ -521,6 +521,13 @@ TVM_DLL Pass CanonicalizeOps(); ...@@ -521,6 +521,13 @@ TVM_DLL Pass CanonicalizeOps();
TVM_DLL Pass AlterOpLayout(); TVM_DLL Pass AlterOpLayout();
/*! /*!
* \brief Legalizes an expr with another expression.
*
* \return The pass.
*/
TVM_DLL Pass Legalize();
/*!
* \brief Canonicalize cast expressions to make operator fusion more efficient. * \brief Canonicalize cast expressions to make operator fusion more efficient.
* *
* \return The pass. * \return The pass.
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
"""Relay core operators.""" """Relay core operators."""
# operator defs # operator defs
from .op import get, register, register_schedule, register_compute, register_gradient, \ from .op import get, register, register_schedule, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, schedule_injective, Op, OpPattern, debug register_pattern, register_alter_op_layout, register_legalize, \
schedule_injective, Op, OpPattern, debug
# Operators # Operators
from .reduce import * from .reduce import *
......
...@@ -204,6 +204,10 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): ...@@ -204,6 +204,10 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
from ... import op from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes):
return None
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......
...@@ -170,6 +170,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10): ...@@ -170,6 +170,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
return register(op_name, "FTVMAlterOpLayout", alter_layout, level) return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
def register_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for an op
Parameters
----------
op_name : str
The name of the operator
legal_op: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
The function for transforming an expr to another expr.
level : int
The priority level
"""
return register(op_name, "FTVMLegalize", legal_op, level)
def register_pattern(op_name, pattern, level=10): def register_pattern(op_name, pattern, level=10):
"""Register operator pattern for an op. """Register operator pattern for an op.
......
...@@ -437,6 +437,21 @@ def AlterOpLayout(): ...@@ -437,6 +437,21 @@ def AlterOpLayout():
return _transform.AlterOpLayout() return _transform.AlterOpLayout()
def Legalize():
"""Legalizes an expression with another expression.
This pass can be used to replace an expr with another expr for target
dependent optimizations. For example, one expr, though semnatically
equivalent to the other, can have better performance on a target. This pass
can be used to legalize the expr in a target-dependent manner.
Returns
-------
ret : tvm.relay.Pass
The registered pass that rewrites an expr.
"""
return _transform.Legalize()
def RewriteAnnotatedOps(fallback_device): def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g. """Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to. `on_deivce`, mark which device an expression should be scheduled to.
......
...@@ -304,6 +304,11 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -304,6 +304,11 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs.push_back(transform::CanonicalizeCast()); pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps()); pass_seqs.push_back(transform::CanonicalizeOps());
// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}
// Alter layout transformation is only applied to homogeneous execution yet. // Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) { if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout()); pass_seqs.push_back(transform::AlterOpLayout());
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file legalize.cc
* \brief Converts an expr to another expr. This pass can be used to transform an op based on its
* shape, dtype or layout to another op or a sequence of ops.
*/
#include <tvm/operation.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
namespace legalize {
// Call registered FTVMLegalize of an op
// Returns the legalized expression
Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
static auto fop_legalize = Op::GetAttr<FTVMLegalize>("FTVMLegalize");
Op op = Downcast<Op>(ref_call->op);
Expr new_e;
bool modified = false;
if (fop_legalize.count(op)) {
tvm::Array<tvm::relay::Type> arg_types;
for (auto& expr : ref_call->args) {
arg_types.push_back(expr->checked_type());
}
Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
if (legalized_value.defined()) {
new_e = legalized_value;
modified = true;
}
}
if (!modified) {
new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
}
const CallNode* new_call = new_e.as<CallNode>();
CHECK(new_call) << "Can only replace the original operator with another call node";
return GetRef<Call>(new_call);
}
Expr Legalize(const Expr& expr) { return ForwardRewrite(expr, Legalizer, nullptr); }
} // namespace legalize
namespace transform {
Pass Legalize() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f));
};
return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
}
TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
} // namespace transform
} // namespace relay
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import tvm
from tvm import relay
from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = relay.Module.from_expr(expr)
seq = transform.Sequential(passes)
with transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_legalize():
"""Test directly replacing an operator with a new one"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var('weight', shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.nn.relu(y)
y = relay.Function([x, weight], y)
return y
@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, arg_types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var('weight', shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.nn.relu(y)
y = relay.Function([x, weight], y)
return y
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_legalize_none():
"""Test doing nothing by returning 'None' """
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
y = relay.nn.global_max_pool2d(x)
y = relay.Function([x], y)
return y
called = [False]
@register_legalize("nn.global_max_pool2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
called[0] = True
return None
a = before()
a = run_opt_pass(a, transform.Legalize())
b = before()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
def test_legalize_multi_input():
"""Test directly replacing an operator with a new one"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
y = relay.var("y", shape=(1, 64, 56, 20))
z = relay.var("z", shape=(1, 64, 56, 10))
func = relay.concatenate([x, y, z], axis=3)
func = relay.Function([x, y, z], func)
return func
@register_legalize("concatenate", level=100)
def legalize_concatenate(attrs, inputs, arg_types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
assert isinstance(inputs[0], tvm.relay.expr.Tuple)
assert len(arg_types) == 1
assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
return None
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
y = relay.var("y", shape=(1, 64, 56, 20))
z = relay.var("z", shape=(1, 64, 56, 10))
func = relay.concatenate([x, y, z], axis=3)
func = relay.Function([x, y, z], func)
return func
a = before()
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
test_legalize()
test_legalize_none()
test_legalize_multi_input()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment