Unverified Commit 3f2abfbc by Zhi Committed by GitHub

[relay] Relay annotation and partitioning for external compilers (#4570)

* [relay] Relay annotation and partitioning for codegen

* Add fusion unit test

* fix comments

* Update include/tvm/relay/attrs/annotation.h

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>

* rebase

* remove annotation helper

* rebase again

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: 雾雨魔理沙 <lolisa@marisa.moe>
parent d7d2a9b3
...@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> { ...@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
} }
}; };
/*!
* \brief Attributes carried by the compiler annotation operators
* (annotation.compiler_begin / annotation.compiler_end), which mark a
* region that is handled by a given compiler.
*/
struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
/*! \brief The 3rd party compiler, given as a string, used for code generation. */
std::string compiler;
// Declares `compiler` as an attribute field together with its description.
TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
TVM_ATTR_FIELD(compiler)
.describe("A 3rd party compiler used for code generation.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_ #endif // TVM_RELAY_ATTRS_ANNOTATION_H_
...@@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc< ...@@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
* operator with other expressions. This function will be invoked * operator with other expressions. This function will be invoked
* in AlterOpLayout pass. * in AlterOpLayout pass.
* \param attrs The attribute of the original node. * \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node. * \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape * \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs. * and dtype of the inputs.
* \return new_expr The modified expression. * \return new_expr The modified expression.
...@@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc< ...@@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc<
* \brief Legalizes an expression with another expression. This function will be * \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass. * invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node. * \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node. * \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape * \param arg_types An array of placeholders, use for getting the inferred shape
* and dtype of the inputs. * and dtype of the inputs.
* \return new_expr The modified expression. * \return new_expr The modified expression.
*/ */
......
...@@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var); ...@@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
*/ */
TVM_DLL Pass PrintIR(bool show_meta_data = true); TVM_DLL Pass PrintIR(bool show_meta_data = true);
/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* NOTE(review): presumably the regions are the ones delimited by the
* annotation.compiler_begin / annotation.compiler_end operators; confirm
* against the pass implementation.
*
* \return The created partitioning pass.
*/
TVM_DLL Pass PartitionGraph();
} // namespace transform } // namespace transform
/*! /*!
......
...@@ -62,6 +62,7 @@ def stop_fusion(data): ...@@ -62,6 +62,7 @@ def stop_fusion(data):
""" """
return _make.stop_fusion(data) return _make.stop_fusion(data)
def checkpoint(data): def checkpoint(data):
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization. """Annotate an expression to be a checkpoint for the checkpointing memory optimization.
...@@ -78,3 +79,43 @@ def checkpoint(data): ...@@ -78,3 +79,43 @@ def checkpoint(data):
return _make.checkpoint(data) return _make.checkpoint(data)
register_schedule("annotation.checkpoint", schedule_injective) register_schedule("annotation.checkpoint", schedule_injective)
def compiler_begin(data, compiler):
    """Mark the beginning of a region that the given compiler will handle.

    Parameters
    ----------
    data : tvm.relay.Expr
        The expression to be annotated.

    compiler : str
        The compiler used to generate code for the annotated region.

    Returns
    -------
    result : tvm.relay.Expr
        The annotated expression.
    """
    annotated = _make.compiler_begin(data, compiler)
    return annotated
def compiler_end(data, compiler):
    """Mark the end of a region that is handled by the provided compiler.

    Parameters
    ----------
    data : tvm.relay.Expr
        The expression to be annotated.

    compiler : str
        The compiler used to generate code for the annotated region.

    Returns
    -------
    result : tvm.relay.Expr
        The annotated expression.
    """
    annotated = _make.compiler_end(data, compiler)
    return annotated
...@@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True): ...@@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True):
return _transform.PrintIR(show_meta_data) return _transform.PrintIR(show_meta_data)
def PartitionGraph():
    """Create the pass that partitions a Relay program into regions that
    can be executed on different backends.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that partitions the Relay program.
    """
    partition_pass = _transform.PartitionGraph()
    return partition_pass
def gradient(expr, mod=None, mode='higher_order'): def gradient(expr, mod=None, mode='higher_order'):
""" """
Transform the input function, Transform the input function,
......
...@@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { ...@@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
if (ref->IsInstance<FunctionNode>()) { if (ref->IsInstance<FunctionNode>()) {
GenDNNLFunc(Downcast<Function>(ref)); GenDNNLFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<relay::ModuleNode>()) { } else if (ref->IsInstance<IRModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref); IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) { for (const auto& it : mod->functions) {
GenDNNLFunc(Downcast<Function>(it.second)); GenDNNLFunc(Downcast<Function>(it.second));
} }
......
...@@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization. ...@@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization.
return outputs; return outputs;
}); });
// Registers the "annotation.compiler_begin" operator: a single-input
// annotation marking the beginning of a region handled by a given compiler.
RELAY_REGISTER_OP("annotation.compiler_begin")
.describe(R"code(
Beginning of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
// Registered with the kOpaque pattern.
// NOTE(review): presumably so that fusion does not merge operators across
// the region boundary -- confirm against the fusion pass.
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
// The compute passes the single input through topi::identity.
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
// Global function "relay.op.annotation._make.compiler_begin": wraps the given
// expression in a call to annotation.compiler_begin whose CompilerAttrs record
// the requested compiler.
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
.set_body_typed([](Expr expr, std::string compiler) {
  // Function-local static: the operator is looked up on the first call only.
  static const Op& begin_op = Op::Get("annotation.compiler_begin");

  auto compiler_attrs = make_object<CompilerAttrs>();
  compiler_attrs->compiler = compiler;

  return CallNode::make(begin_op, {expr}, Attrs(compiler_attrs), {});
});
// Registers the "annotation.compiler_end" operator: a single-input
// annotation marking the end of a region handled by a given compiler.
RELAY_REGISTER_OP("annotation.compiler_end")
.describe(R"code(
End of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
// Registered with the kOpaque pattern.
// NOTE(review): presumably so that fusion does not merge operators across
// the region boundary -- confirm against the fusion pass.
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
// The compute passes the single input through topi::identity.
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
// Global function "relay.op.annotation._make.compiler_end": wraps the given
// expression in a call to annotation.compiler_end whose CompilerAttrs record
// the requested compiler.
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
.set_body_typed([](Expr expr, std::string compiler) {
  // Function-local static: the operator is looked up on the first call only.
  static const Op& end_op = Op::Get("annotation.compiler_end");

  auto compiler_attrs = make_object<CompilerAttrs>();
  compiler_attrs->compiler = compiler;

  return CallNode::make(end_op, {expr}, Attrs(compiler_attrs), {});
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// Finally if the operator position is not a call node we will // Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression. // need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque; OpPatternKind op_pattern = kOpaque;
const OpNode* opnode = call->op.as<OpNode>(); if (const OpNode* opnode = call->op.as<OpNode>()) {
if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]); op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
} else { } else {
this->Update(call->op, node, kOpaque); this->Update(call->op, node, kOpaque);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment