Unverified Commit 3f2abfbc by Zhi Committed by GitHub

[relay] Relay annotation and partitioning for external compilers (#4570)

* [relay] Relay annotation and partitioning for codegen

* Add fusion unit test

* fix comments

* Update include/tvm/relay/attrs/annotation.h

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>

* rebase

* remove annotation helper

* rebase again

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: 雾雨魔理沙 <lolisa@marisa.moe>
parent d7d2a9b3
...@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> { ...@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
} }
}; };
/*!
* \brief Options for the operators used to annotate a compiler.
*/
struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
/*! \brief A 3rd party compiler for code generation. */
std::string compiler;
TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
TVM_ATTR_FIELD(compiler)
.describe("A 3rd party compiler used for code generation.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_ #endif // TVM_RELAY_ATTRS_ANNOTATION_H_
...@@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc< ...@@ -123,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
* operator with other expressions. This function will be invoked * operator with other expressions. This function will be invoked
* in AlterOpLayout pass. * in AlterOpLayout pass.
* \param attrs The attribute of the original node. * \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node. * \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape * \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs. * and dtype of the inputs.
* \return new_expr The modified expression. * \return new_expr The modified expression.
...@@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc< ...@@ -153,8 +153,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc<
* \brief Legalizes an expression with another expression. This function will be * \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass. * invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node. * \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node. * \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape * \param arg_types An array of placeholders, use for getting the inferred shape
* and dtype of the inputs. * and dtype of the inputs.
* \return new_expr The modified expression. * \return new_expr The modified expression.
*/ */
......
...@@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var); ...@@ -310,6 +310,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
*/ */
TVM_DLL Pass PrintIR(bool show_meta_data = true); TVM_DLL Pass PrintIR(bool show_meta_data = true);
/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* \return The pass.
*/
TVM_DLL Pass PartitionGraph();
} // namespace transform } // namespace transform
/*! /*!
......
...@@ -62,6 +62,7 @@ def stop_fusion(data): ...@@ -62,6 +62,7 @@ def stop_fusion(data):
""" """
return _make.stop_fusion(data) return _make.stop_fusion(data)
def checkpoint(data): def checkpoint(data):
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization. """Annotate an expression to be a checkpoint for the checkpointing memory optimization.
...@@ -78,3 +79,43 @@ def checkpoint(data): ...@@ -78,3 +79,43 @@ def checkpoint(data):
return _make.checkpoint(data) return _make.checkpoint(data)
register_schedule("annotation.checkpoint", schedule_injective) register_schedule("annotation.checkpoint", schedule_injective)
def compiler_begin(data, compiler):
"""Annotate an expression to indicate that it is the beginning of
a regeion that will be handled by the given compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_begin(data, compiler)
def compiler_end(data, compiler):
"""Annotate an expression to indicate that it is the end of a region that
is handled by the provided compiler.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
compiler : Str
The compiler used to generate code of the annotated region.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.compiler_end(data, compiler)
...@@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True): ...@@ -663,6 +663,18 @@ def PrintIR(show_meta_data=True):
return _transform.PrintIR(show_meta_data) return _transform.PrintIR(show_meta_data)
def PartitionGraph():
"""Partition a Relay program into regions that can be executed on different
backends.
Returns
-------
ret: tvm.relay.Pass
The registered pass that partitions the Relay program.
"""
return _transform.PartitionGraph()
def gradient(expr, mod=None, mode='higher_order'): def gradient(expr, mod=None, mode='higher_order'):
""" """
Transform the input function, Transform the input function,
......
...@@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { ...@@ -270,8 +270,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
if (ref->IsInstance<FunctionNode>()) { if (ref->IsInstance<FunctionNode>()) {
GenDNNLFunc(Downcast<Function>(ref)); GenDNNLFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<relay::ModuleNode>()) { } else if (ref->IsInstance<IRModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref); IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) { for (const auto& it : mod->functions) {
GenDNNLFunc(Downcast<Function>(it.second)); GenDNNLFunc(Downcast<Function>(it.second));
} }
......
...@@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization. ...@@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization.
return outputs; return outputs;
}); });
RELAY_REGISTER_OP("annotation.compiler_begin")
.describe(R"code(
Beginning of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_begin");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("annotation.compiler_end")
.describe(R"code(
End of a region that is handled by a given compiler.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
.set_body_typed([](Expr expr, std::string compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_end");
return CallNode::make(op, {expr}, Attrs(attrs), {});
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -242,8 +242,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// Finally if the operator position is not a call node we will // Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression. // need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque; OpPatternKind op_pattern = kOpaque;
const OpNode* opnode = call->op.as<OpNode>(); if (const OpNode* opnode = call->op.as<OpNode>()) {
if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]); op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
} else { } else {
this->Update(call->op, node, kOpaque); this->Update(call->op, node, kOpaque);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/relay/pass/partition_graph.cc
*
* \brief Partition an input function into multiple functions according based
* on the inserted annotation nodes (i.e. compiler_begin and compiler_end).
* These nodes are used as boundaries to partition the Relay function into
* multiple regions that can be offloaded to different accelerators/backends.
*
* Each of these paritioned functions, a.k.a subgraphs, will be viewed as
* external functions, and they will use the provided compiler for codegen.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace tvm {
namespace relay {
namespace partitioning {
// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*!
* \brief The subgraph properties for partitioning.
*/
struct Subgraph {
/*! \brief The subgraph ID. */
int id;
/*! \brief The input arguments of this subgraph. */
std::vector<std::pair<Var, Expr>> args;
/*! \brief Nodes in this subgraph. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
};
/*!
* \brief The checker that verifies if a Relay program is annotated correctly
* for partitioning.
*/
class AnnotationChecker : public ExprVisitor {
public:
bool Check() {
if (!found_start_ && !found_end_) {
LOG(WARNING) << "No compiler annotation found";
} else if (!found_start_) {
LOG(ERROR) << "compiler_begin annotation is missing";
return false;
} else if (!found_end_) {
LOG(ERROR) << "compiler_end annotation is missing";
return false;
}
return true;
}
void VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return;
} else if (call->op == compiler_begin_op) {
found_start_ = true;
} else if (call->op == compiler_end_op) {
found_end_ = true;
}
}
private:
bool found_start_{false};
bool found_end_{false};
};
/*! \brief This class partitions the expr labeled with begin and end annoations
* into function containing multiple regions. Each region is labeled with
* a compiler attribute so that it will be handled by any compilers that are not
* in the TVM stack.
*
* TODO(@zhiics) This following algorithm is not adequate to handle all cases,
* i.e. multiple `compiler_end` nodes.
*/
class Partitioner : public ExprMutator {
public:
std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
for (auto candidate : this->subgraphs_) {
if (candidate->nodes.find(node) != candidate->nodes.end()) {
return candidate;
}
}
return nullptr;
}
void MergeSubgraph(std::shared_ptr<Subgraph> subgraph1,
std::shared_ptr<Subgraph> subgraph2) {
if (subgraph1 == subgraph2) {
return;
}
// Merge subgraph 2 to subgraph 1 and erase subgraph 2.
subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end());
for (auto arg : subgraph2->args) {
subgraph1->args.push_back(arg);
}
this->subgraphs_.erase(subgraph2);
}
void AddToSubgraph(std::shared_ptr<Subgraph> subgraph, const Expr expr) {
auto subgraph2 = GetSubgraph(expr);
if (subgraph2) {
MergeSubgraph(subgraph, subgraph2);
} else {
subgraph->nodes.insert(expr);
}
}
Expr VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
// Propogate subgraph to arguments
auto subgraph = GetSubgraph(GetRef<Call>(call));
if (subgraph) {
for (auto arg : call->args) {
AddToSubgraph(subgraph, arg);
}
}
return ExprMutator::VisitExpr_(call);
} else if (call->op == compiler_begin_op) {
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
// Traverse the rest graph.
auto input_expr = VisitExpr(call->args[0]);
// Replace the begin annotation with an external call input variable.
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
input_expr->checked_type_);
// Find the corresponding subgraph and add the argument.
auto subgraph = GetSubgraph(GetRef<Call>(call));
if (!subgraph) {
throw Error(ErrorBuilder()
<< "Cannot find the corresponding subgraph for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
subgraph->args.push_back({var, input_expr});
return std::move(var);
} else {
CHECK_EQ(call->op, compiler_end_op);
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
// Check if the argument already belongs to an exist subgraph
auto subgraph = GetSubgraph(call->args[0]);
if (!subgraph) {
auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
subgraph = *ret.first;
subgraph->nodes.insert(call->args[0]);
subgraph->id = this->subgraph_id_++;
}
subgraph->nodes.insert(GetRef<Call>(call));
// Traverse subgraph inputs.
auto input = VisitExpr(call->args[0]);
Array<Var> params;
Array<Expr> args;
// The subgraph may be merged so we need to update it again.
subgraph = GetSubgraph(GetRef<Call>(call));
CHECK(subgraph);
for (auto pair : subgraph->args) {
params.push_back(pair.first);
args.push_back(pair.second);
}
auto subgraph_func =
FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
Expr arg0 = call->args[0];
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func =
FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImmNode::make(name));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
tvm::ir::StringImmNode::make(compiler_attrs->compiler));
return CallNode::make(subgraph_func, args);
}
}
Expr VisitExpr_(const TupleNode* op) final {
auto subgraph = GetSubgraph(GetRef<Tuple>(op));
if (!subgraph) {
return ExprMutator::VisitExpr_(op);
} else {
for (auto field : op->fields) {
AddToSubgraph(subgraph, field);
}
Array<Expr> fields;
for (auto field : op->fields) {
fields.push_back(VisitExpr(field));
}
return TupleNode::make(fields);
}
}
Expr VisitExpr_(const TupleGetItemNode* g) final {
auto subgraph = GetSubgraph(GetRef<TupleGetItem>(g));
if (!subgraph) {
return ExprMutator::VisitExpr_(g);
} else {
AddToSubgraph(subgraph, g->tuple);
auto t = VisitExpr(g->tuple);
return TupleGetItemNode::make(t, g->index);
}
}
Expr VisitExpr_(const FunctionNode* op) final {
auto subgraph = GetSubgraph(GetRef<Function>(op));
if (!subgraph) {
return ExprMutator::VisitExpr_(op);
} else {
Array<Var> params;
for (auto param : op->params) {
AddToSubgraph(subgraph, param);
}
for (auto param : op->params) {
Var new_param = Downcast<Var>(VisitExpr(param));
params.push_back(new_param);
}
auto body = VisitExpr(op->body);
return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs);
}
}
Expr VisitExpr_(const LetNode* op) final {
auto subgraph = GetSubgraph(GetRef<Let>(op));
if (!subgraph) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->var);
AddToSubgraph(subgraph, op->value);
AddToSubgraph(subgraph, op->body);
Var var = Downcast<Var>(VisitExpr(op->var));
auto value = VisitExpr(op->value);
auto body = VisitExpr(op->body);
return LetNode::make(var, value, body);
}
}
Expr VisitExpr_(const IfNode* op) final {
auto subgraph = GetSubgraph(GetRef<If>(op));
if (!subgraph) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->cond);
AddToSubgraph(subgraph, op->true_branch);
AddToSubgraph(subgraph, op->false_branch);
auto guard = VisitExpr(op->cond);
auto true_b = VisitExpr(op->true_branch);
auto false_b = VisitExpr(op->false_branch);
return IfNode::make(guard, true_b, false_b);
}
}
Expr VisitExpr_(const RefCreateNode* op) final {
auto subgraph = GetSubgraph(GetRef<RefCreate>(op));
if (!subgraph) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->value);
Expr value = VisitExpr(op->value);
return RefCreateNode::make(value);
}
}
Expr VisitExpr_(const RefReadNode* op) final {
auto subgraph = GetSubgraph(GetRef<RefRead>(op));
if (!subgraph) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->ref);
Expr ref = VisitExpr(op->ref);
return RefReadNode::make(ref);
}
}
Expr VisitExpr_(const RefWriteNode* op) final {
auto subgraph = GetSubgraph(GetRef<RefWrite>(op));
if (!subgraph) {
return ExprMutator::VisitExpr_(op);
} else {
AddToSubgraph(subgraph, op->ref);
Expr ref = VisitExpr(op->ref);
Expr value = VisitExpr(op->value);
return RefWriteNode::make(ref, value);
}
}
private:
int var_id_{0};
int subgraph_id_{0};
std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
};
/*!
* \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
* the same codegen backend. This reduces rounds trips between TVM and external
* backends. Likely we can borrow some ideas from operator fusion.
*
* For example, sg1 and sg2 should be combined if they belong to the same
* codegen tool in the following case.
*
* op1
* / \
* sg1 sg2
*
* |
* \|/
*
* op1
* |
* sg1_sg2
*
* where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
* inputs that obtained from the tuple.
*/
Expr PartitionGraph(const Expr& expr) {
Partitioner part;
return part.Mutate(expr);
}
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(partitioning::PartitionGraph(f));
};
auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
.set_body_typed(transform::PartitionGraph);
} // namespace transform
} // namespace relay
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for graph partitioning."""
import os
import sys
import numpy as np
import pytest
import tvm
import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay
from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator
# Leverage the pass manager to write a simple white list based annotator
@transform.function_pass(opt_level=0)
class WhiteListAnnotator:
def __init__(self, op_list, compiler):
assert isinstance(op_list, (list, tuple, set))
self.op_list = op_list
self.compiler = compiler
def transform_function(self, func, mod, ctx):
annotator = self
class Annotator(tvm.relay.ExprMutator):
def visit_call(self, call):
op_name = call.op.name
if op_name in annotator.op_list:
new_args = []
for arg in call.args:
ann = compiler_begin(super().visit(arg),
annotator.compiler)
new_args.append(ann)
new_call = relay.Call(call.op, new_args, call.attrs,
call.type_args)
return compiler_end(new_call, annotator.compiler)
else:
return super().visit_call(call)
return Annotator().visit(func)
class CcompilerAnnotator(ExprMutator):
"""
A simple annotator that creates the following program:
|
-- begin --
|
add
|
subtract
|
multiply
|
-- end --
|
"""
def __init__(self):
super(CcompilerAnnotator, self).__init__()
self.in_compiler = 0
def visit_call(self, call):
if call.op.name == "add": # Annotate begin at args
if self.in_compiler == 1:
lhs = compiler_begin(super().visit(call.args[0]), "ccompiler")
rhs = compiler_begin(super().visit(call.args[1]), "ccompiler")
op = relay.add(lhs, rhs)
self.in_compiler = 2
return op
elif call.op.name == "subtract":
if self.in_compiler == 1:
lhs = super().visit(call.args[0])
rhs = super().visit(call.args[1])
if isinstance(lhs, relay.expr.Var):
lhs = compiler_begin(lhs, "ccompiler")
if isinstance(rhs, relay.expr.Var):
rhs = compiler_begin(rhs, "ccompiler")
return relay.subtract(lhs, rhs)
elif call.op.name == "multiply": # Annotate end at output
self.in_compiler = 1
lhs = super().visit(call.args[0])
rhs = super().visit(call.args[1])
if isinstance(lhs, relay.expr.Var):
lhs = compiler_begin(lhs, "ccompiler")
if isinstance(rhs, relay.expr.Var):
rhs = compiler_begin(rhs, "ccompiler")
op = relay.multiply(lhs, rhs)
if self.in_compiler == 2:
op = compiler_end(op, "ccompiler")
self.in_compiler = 0
return op
return super().visit_call(call)
class WholeGraphAnnotator(ExprMutator):
"""
An annotator that creates a compiler for an entire graph.
"""
def __init__(self, compiler):
super(WholeGraphAnnotator, self).__init__()
self.compiler = compiler
self.last_call = True
def visit_call(self, call):
curr_last = self.last_call
self.last_call = False
params = []
for arg in call.args:
param = super().visit(arg)
if isinstance(param, relay.expr.Var):
param = compiler_begin(param, self.compiler)
params.append(param)
new_call = relay.Call(call.op, params, call.attrs)
if curr_last:
new_call = compiler_end(new_call, self.compiler)
return new_call
class MobileNetAnnotator(ExprMutator):
"""
Annotate mobilenet until global_avg_pool.
"""
def __init__(self, compiler):
super(MobileNetAnnotator, self).__init__()
self.compiler = compiler
self.compiler_open = False
def visit_call(self, call):
if call.op.name == 'nn.global_avg_pool2d':
self.compiler_open = True
compiler_open = self.compiler_open
params = []
for arg in call.args:
param = super().visit(arg)
if call.op.name == 'nn.global_avg_pool2d':
param = compiler_end(param, self.compiler)
if compiler_open and isinstance(param, relay.expr.Var):
param = compiler_begin(param, self.compiler)
params.append(param)
new_call = relay.Call(call.op, params, call.attrs)
return new_call
def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
ctx=tvm.cpu(), params=None):
if sys.platform == "win32":
print("Skip test on Windows for now")
return
def update_lib(lib):
test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
source_dir = os.path.join(test_dir, "..", "..", "..")
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)
return lib
def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
for name, data in map_inputs.items():
rt_mod.set_input(name, data)
rt_mod.set_input(**param)
rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
check_vm_result()
check_graph_runtime_result()
def test_multi_node_compiler():
x = relay.var('x', shape=(10, 10))
w0 = relay.var('w0', shape=(10, 10))
w1 = relay.var('w1', shape=(10, 10))
w2 = relay.var('w2', shape=(10, 10))
w3 = relay.var('w3', shape=(10, 10))
w4 = relay.var('w4', shape=(10, 10))
w5 = relay.var('w5', shape=(10, 10))
w6 = relay.var('w6', shape=(10, 10))
w7 = relay.var('w7', shape=(10, 10))
# C compiler
# FIXME: We generate two compilers for this case but they should be merged to one
# due to the common input (x).
z0 = relay.add(x, w0)
p0 = relay.subtract(z0, w1)
q0 = relay.multiply(p0, w2)
z1 = relay.add(x, w3)
p1 = relay.subtract(z1, w4)
q1 = relay.multiply(p1, w5)
# Other parts on TVM
z2 = relay.add(x, w6)
q2 = relay.subtract(z2, w7)
r = relay.concatenate((q0, q1, q2), axis=0)
f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r)
mod = relay.Module()
ann = CcompilerAnnotator()
mod["main"] = ann.visit(f)
mod = transform.PartitionGraph()(mod)
mod = transform.InferType()(mod)
x_data = np.random.rand(10, 10).astype('float32')
w_data = []
for _ in range(8):
w_data.append(np.random.rand(10, 10).astype('float32'))
map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
map_inputs["x"] = x_data
check_result(
mod, map_inputs, (30, 10),
np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
((x_data + w_data[3]) - w_data[4]) * w_data[5],
x_data + w_data[6] - w_data[7]),
axis=0))
def test_extern_ccompiler_single_op():
@transform.function_pass(opt_level=0)
class MyAnnotator:
def transform_function(self, func, mod, ctx):
class Annotator(tvm.relay.ExprMutator):
def visit_call(self, call):
new_args = []
for arg in call.args:
ann = compiler_begin(self.visit(arg), "ccompiler")
new_args.append(ann)
new_call = relay.Call(call.op, new_args)
return compiler_end(new_call, "ccompiler")
return Annotator().visit(func)
x = relay.var('x', shape=(8, 8))
y = relay.var('y', shape=(8, 8))
z = x + y
f = relay.Function([x, y], z)
x_data = np.random.rand(8, 8).astype('float32')
y_data = np.random.rand(8, 8).astype('float32')
mod = relay.Module()
mod["main"] = f
mod = MyAnnotator()(mod)
mod = transform.PartitionGraph()(mod)
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
def test_extern_ccompiler_default_ops():
def expected():
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
x0 = relay.var("x0", shape=(8, 8))
y0 = relay.var("y0", shape=(8, 8))
add = x0 + y0
# Function that uses C compiler
func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
func = func.set_attribute("Compiler",
tvm.expr.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol",
tvm.expr.StringImm("ccompiler_0"))
add_call = relay.Call(func, [x, y])
# Function that uses default compiler. Ops are fused in this function.
p0 = relay.var("p0", shape=(8, 8))
log = relay.log(p0)
exp = relay.exp(p0)
concat = relay.concatenate([log, exp], axis=0)
fused_func = relay.Function([p0], concat)
fused_func = fused_func.set_attribute("Primitive",
tvm.expr.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
mod = relay.Module()
mod["main"] = main
return mod
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
add = x + y
log = relay.log(add)
exp = relay.exp(add)
concat = relay.concatenate([log, exp], axis=0)
f = relay.Function([x, y], concat)
mod = relay.Module()
mod["main"] = f
mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
mod = transform.PartitionGraph()(mod)
fused_mod = transform.FuseOps(2)(mod)
expected_mod = expected()
assert relay.alpha_equal(fused_mod, expected_mod)
x_data = np.random.rand(8, 8).astype('float32')
y_data = np.random.rand(8, 8).astype('float32')
np_add = x_data + y_data
res = np.concatenate([np.log(np_add), np.exp(np_add)])
check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)
def test_extern_ccompiler():
x = relay.var('x', shape=(2, 2))
y = relay.var('y', shape=(2, 2))
z = x + x
p = y * y
f = relay.Function([x, y], p - z)
x_data = np.random.rand(2, 2).astype('float32')
y_data = np.random.rand(2, 2).astype('float32')
mod = relay.Module()
mod["main"] = f
mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
mod = transform.PartitionGraph()(mod)
check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
def test_extern_dnnl():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return
dtype = 'float32'
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
data = relay.var('data', shape=(ishape), dtype=dtype)
weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
f = relay.Function([data, weight1], out)
mod = relay.Module()
mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
mod = transform.PartitionGraph()(mod)
ref_mod = relay.Module()
ref_mod['main'] = f
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
ref_res = ref_ex.evaluate()(i_data, w1_data)
check_result(mod, {"data": i_data, "weight1": w1_data},
(1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
def test_extern_dnnl_mobilenet():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return
dtype = 'float32'
ishape = (1, 3, 224, 224)
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')
op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod = WhiteListAnnotator(op_list, "dnnl")(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
dtype='float32')
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
ref_res = ref_ex.evaluate()(i_data, **params)
check_result(mod, {"data": i_data},
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
test_extern_ccompiler_default_ops()
test_extern_ccompiler()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment