Commit 73dda6be by Animesh Jain Committed by Yizhi Liu

[Relay] Convert Layout Pass. (#4335)

parent 641024f5
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -133,6 +134,22 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< ...@@ -133,6 +134,22 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
const Array<Tensor>& tinfos)>; const Array<Tensor>& tinfos)>;
/*! /*!
* \brief Convert the layout of operators or replace the
* operator with other expressions. This function will be invoked
* in ConvertLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \param desired_layout The desired layout.
* \return new_expr The modified expression.
*/
using FTVMConvertOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos,
const std::string& desired_layout)>;
/*!
* \brief Legalizes an expression with another expression. This function will be * \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass. * invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node. * \param attrs The attribute of the original node.
......
...@@ -533,6 +533,26 @@ TVM_DLL Pass CanonicalizeOps(); ...@@ -533,6 +533,26 @@ TVM_DLL Pass CanonicalizeOps();
TVM_DLL Pass AlterOpLayout(); TVM_DLL Pass AlterOpLayout();
/*! /*!
* \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
* at the start and one at the end.
*
* This pass is not a part of relay.build and is expected to be called between framework-relay
* parser and relay.build call. This is very helpful for hardware backends that support/prefer only
* type of data layout.
*
* RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
*
* This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define new
* layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout
* using the InferCorrectLayout infrastructure.
*
* \param desired_layout The desired layout.
* \return The pass.
*/
TVM_DLL Pass ConvertLayout(const std::string& desired_layout);
/*!
* \brief Legalizes an expr with another expression. * \brief Legalizes an expr with another expression.
* \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function. * \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function.
* One can collect and isolate similar type of legalize transformations using this param. For * One can collect and isolate similar type of legalize transformations using this param. For
......
...@@ -251,6 +251,47 @@ def legalize_conv2d(attrs, inputs, types): ...@@ -251,6 +251,47 @@ def legalize_conv2d(attrs, inputs, types):
""" """
return topi.nn.conv2d_legalize(attrs, inputs, types) return topi.nn.conv2d_legalize(attrs, inputs, types)
@reg.register_convert_op_layout("nn.conv2d")
def convert_conv2d(attrs, inputs, tinfos, desired_layout):
"""Convert Layout pass registration for conv2d op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layout : str
The desired layout
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
from tvm import relay
data_layout = attrs['data_layout']
kernel_layout = attrs['kernel_layout']
data, weight = inputs
assert desired_layout == 'NCHW', \
"Currently only transformation to NCHW layout is supported."
if desired_layout == 'NCHW':
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'OIHW'
if data_layout == 'NHWC' and kernel_layout == 'HWIO':
# Convert (NHWC, HWIO) to (NCHW, OIHW)
return relay.nn.conv2d(data, weight, **new_attrs)
if data_layout == 'NHWC' and kernel_layout == 'HWOI':
# Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
return relay.nn.conv2d(data, weight, **new_attrs)
return None
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......
...@@ -196,6 +196,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10): ...@@ -196,6 +196,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
return register(op_name, "FTVMAlterOpLayout", alter_layout, level) return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
def register_convert_op_layout(op_name, convert_layout=None, level=10):
"""Register convert op layout function for an op
Parameters
----------
op_name : str
The name of the operator
convert_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
The function for changing the layout or replacing the operator
level : int
The priority level
"""
return register(op_name, "FTVMConvertOpLayout", convert_layout, level)
def register_legalize(op_name, legal_op=None, level=10): def register_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for an op """Register legal transformation function for an op
......
...@@ -460,6 +460,34 @@ def AlterOpLayout(): ...@@ -460,6 +460,34 @@ def AlterOpLayout():
return _transform.AlterOpLayout() return _transform.AlterOpLayout()
def ConvertLayout(desired_layout):
""" Given a dest layout, this pass transforms the expr such that most of the ops input data
layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms,
one at the start and one at the end.
This pass is not a part of relay.build and is expected to be called between framework-relay
parser and relay.build call. This is very helpful for hardware backends that support/prefer only
type of data layout.
RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define
new layouts for conv2d ops for now. Most of the other operators try to adapt to their input
layout using the InferCorrectLayout infrastructure.
Parameters
----------
desired_layout : str
The desired layout for the transformed expr.
Returns
-------
pass: FunctionPass
The pass.
"""
return _transform.ConvertLayout(desired_layout)
def Legalize(legalize_map_attr_name="FTVMLegalize"): def Legalize(legalize_map_attr_name="FTVMLegalize"):
"""Legalizes an expression with another expression. """Legalizes an expression with another expression.
This pass can be used to replace an expr with another expr for target This pass can be used to replace an expr with another expr for target
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
#include "../type_relations.h" #include "../type_relations.h"
namespace tvm { namespace tvm {
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include "type_relations.h" #include "type_relations.h"
#include "../pass/alter_op_layout.h" #include "../pass/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include <tvm/relay/attrs/memory.h> #include <tvm/relay/attrs/memory.h>
#include "../op_common.h" #include "../op_common.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
#include "../type_relations.h" #include "../type_relations.h"
namespace tvm { namespace tvm {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/relay/attrs/bitserial.h> #include <tvm/relay/attrs/bitserial.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
#include "../op_common.h" #include "../op_common.h"
#include "convolution.h" #include "convolution.h"
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "../type_relations.h" #include "../type_relations.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
#include "../op_common.h" #include "../op_common.h"
#include "nn.h" #include "nn.h"
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h> #include <topi/nn/pooling.h>
#include <vector> #include <vector>
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "type_relations.h" #include "type_relations.h"
#include "../pass/alter_op_layout.h" #include "../pass/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#include <vector> #include <vector>
#include "../op_common.h" #include "../op_common.h"
#include "../../../arithmetic/compute_expr.h" #include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h" #include "../../pass/infer_layout_util.h"
#include "../../pass/pattern_util.h" #include "../../pass/pattern_util.h"
#include "transform.h" #include "transform.h"
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file convert_op_layout.cc
* \brief Alternate the layouts of operators or replace primitive operators with
other expressions. This pass can be used for computing convolution in
custom layouts or other general weight pre-transformation.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/operation.h>
#include <tuple>
#include <vector>
#include <functional>
#include <string>
#include <utility>
#include <unordered_map>
#include "transform_layout.h"
#include "pattern_util.h"
namespace tvm {
namespace relay {
namespace convert_op_layout {
/*!
* \brief Container for the transformations for ConvertLayout.
*/
class ConvertTransformMemorizerNode : public TransformMemorizerNode {
public:
/*!
* \brief Initializes the desired_layout.
* \param desired_layout The desired layout.
*/
explicit ConvertTransformMemorizerNode(const std::string& desired_layout)
: desired_layout_(desired_layout) {}
/*! \brief The desired layout for the Convert Layout pass */
std::string desired_layout_;
};
/*!
* \brief Container that provides the transformation function for convert layout.
*/
class ConvertTransformMemorizer : public TransformMemorizer {
public:
ConvertTransformMemorizer() {}
explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}
ConvertTransformMemorizerNode* operator->() {
return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
}
/*!
* \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the
* desired layout as specified by the user.
* \param ref_call The original call.
* \param new_args The traversed/recursed args to the call.
* \return The new Call after calling the packed func.
*/
Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
static auto fconvert_layout = Op::GetAttr<FTVMConvertOpLayout>("FTVMConvertOpLayout");
Op op = Downcast<Op>(ref_call->op);
Expr new_e;
bool modified = false;
if (fconvert_layout.count(op)) {
tvm::Array<tvm::Tensor> tinfos;
for (auto expr : ref_call->args) {
auto ttype = expr->type_as<TensorTypeNode>();
tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
}
Expr altered_value =
fconvert_layout[op](ref_call->attrs, new_args, tinfos, operator->()->desired_layout_);
if (altered_value.defined()) {
new_e = altered_value;
modified = true;
}
}
if (!modified) {
new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
}
const CallNode* new_call = new_e.as<CallNode>();
CHECK(new_call) << "Can only replace the original operator with another call node";
return GetRef<Call>(new_call);
}
using ContainerType = ConvertTransformMemorizerNode;
};
/*!
* Limitations:
* 1. The altered op should have the same number of arguments as the previous one.
* 2. Do not support nested tuple arguments.
*/
Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) {
ConvertTransformMemorizer transformMemorizer(
make_node<ConvertTransformMemorizerNode>(desired_layout));
auto fcontext = [&](const Call& call) -> NodeRef { return transformMemorizer; };
return ForwardRewrite(expr, LayoutRewriter<ConvertTransformMemorizer>, fcontext);
}
} // namespace convert_op_layout
namespace transform {
Pass ConvertLayout(const std::string& desired_layout) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
return CreateFunctionPass(
pass_func, 3, "ConvertLayout",
{ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"),
ir::StringImm::make("CanonicalizeOps")});
}
TVM_REGISTER_API("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
} // namespace transform
} // namespace relay
} // namespace tvm
...@@ -18,18 +18,20 @@ ...@@ -18,18 +18,20 @@
*/ */
/*! /*!
* \file alter_op_layout.h * \file infer_layout_util.h
* \brief Alternate the layouts of operators or replace primitive operators with * \brief Utility functions to alter the layouts of operators or replace primitive operators with
other expressions. This pass can be used for computing convolution in other expressions. This pass can be used for computing convolution in
custom layouts or other general weight pre-transformation. custom layouts or other general weight pre-transformation.
*/ */
#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ #ifndef TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_
#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ #define TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
#include <tuple>
#include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -193,7 +195,40 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs, ...@@ -193,7 +195,40 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
} }
} }
/*!
* Call registered FInferCorrectLayout of an op.
* Parameters are the same as the parameters for FInferCorrectLayout
* Returns inferred_input_layout, inferred_output_layout, success
*/
static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts(
const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
if (!call->op.as<OpNode>()) {
return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
}
Op op = Downcast<Op>(call->op);
if (finfer_layout.count(op)) {
Array<Array<Layout>> inferred_layouts;
inferred_layouts =
finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_shapes);
CHECK_EQ(inferred_layouts.size(), 2)
<< "FInferCorrectLayout should return an array with size of 2";
for (auto x : inferred_layouts) {
for (auto y : x) {
if (!y.defined()) { // inference fails
return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
}
}
}
return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true);
} else {
return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
}
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ #endif // TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment