Unverified Commit 3f2abfbc by Zhi Committed by GitHub

[relay] Relay annotation and partitioning for external compilers (#4570)

* [relay] Relay annotation and partitioning for codegen

* Add fusion unit test

* fix comments

* Update include/tvm/relay/attrs/annotation.h

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>

* rebase

* remove annotation helper

* rebase again

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: 雾雨魔理沙 <lolisa@marisa.moe>
parent d7d2a9b3
......@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
}
};
/*!
* \brief Options for the operators used to annotate a compiler.
*/
struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
/*! \brief A 3rd party compiler for code generation. */
std::string compiler;
TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
TVM_ATTR_FIELD(compiler)
.describe("A 3rd party compiler used for code generation.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
......@@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
* operator with other expressions. This function will be invoked
* in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
......@@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc<
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* \param args The input symbols of the original node.
* \param arg_types An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
......
......@@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
*/
TVM_DLL Pass PrintIR(bool show_meta_data = true);
/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* \return The pass.
*/
TVM_DLL Pass PartitionGraph();
} // namespace transform
/*!
......
......@@ -62,6 +62,7 @@ def stop_fusion(data):
"""
return _make.stop_fusion(data)
def checkpoint(data):
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization.
......@@ -78,3 +79,43 @@ def checkpoint(data):
return _make.checkpoint(data)
register_schedule("annotation.checkpoint", schedule_injective)
def compiler_begin(data, compiler):
"""Annotate an expression to indicate that it is the beginning of
a regeion that will be handled by the given compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_begin(data, compiler)
def compiler_end(data, compiler):
"""Annotate an expression to indicate that it is the end of a region that
is handled by the provided compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_end(data, compiler)
......@@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True):
return _transform.PrintIR(show_meta_data)
def PartitionGraph():
"""Partition a Relay program into regions that can be executed on different
backends.
Returns
-------
ret: tvm.relay.Pass
The registered pass that partitions the Relay program.
"""
return _transform.PartitionGraph()
def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
......
......@@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
if (ref->IsInstance<FunctionNode>()) {
GenDNNLFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<relay::ModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref);
} else if (ref->IsInstance<IRModuleNode>()) {
IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) {
GenDNNLFunc(Downcast<Function>(it.second));
}
......
......@@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization.
return outputs;
});
RELAY_REGISTER_OP("annotation.compiler_begin")
.describe(R"code(
Beginning of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_begin");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("annotation.compiler_end")
.describe(R"code(
End of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_end");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});
} // namespace relay
} // namespace tvm
......@@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque;
const OpNode* opnode = call->op.as<OpNode>();
if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
if (const OpNode* opnode = call->op.as<OpNode>()) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
} else {
this->Update(call->op, node, kOpaque);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment