Unverified Commit f506c8b1 by Cody Yu Committed by GitHub

[BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)

* add target to region

* refactor annotate_target

* Make all unit test working

* quick fix

* enable BN, unit test failed

* Fix vm test, unit test. Refactor annotate_target a bit.

* quick fix fusion

* revert fusion change

* style fix

* Refactor merge region pass

* format

* minor fix

* Skip e2e test

* lint

* support AnnotateTarget multiple runs

* Add HasAttr and revert DNNL codegen

* address comment

Co-authored-by: Zhi Chen <chzhi@amazon.com>
parent 5795539c
...@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True): ...@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d") _register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense") _register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu") _register_external_op_helper("nn.relu")
_register_external_op_helper("add") _register_external_op_helper("add")
_register_external_op_helper("subtract") _register_external_op_helper("subtract")
_register_external_op_helper("multiply") _register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
"""
return False
...@@ -587,14 +587,14 @@ def PartitionGraph(): ...@@ -587,14 +587,14 @@ def PartitionGraph():
def AnnotateTarget(target): def AnnotateTarget(targets):
"""Annotate ops in an experession with a provied compiler/target and then """Annotate ops in an experession with a provied compiler/target and then
use it for codegen. use it for codegen.
Parameters Parameters
---------- ----------
target : String targets : str or List[str]
The target compiler used for codegen. The list of target compilers used for codegen.
Returns Returns
------- -------
...@@ -602,7 +602,9 @@ def AnnotateTarget(target): ...@@ -602,7 +602,9 @@ def AnnotateTarget(target):
The annotated pass that wrapps ops with subgraph_start and The annotated pass that wrapps ops with subgraph_start and
subgraph_end. subgraph_end.
""" """
return _ffi_api.AnnotateTarget(target) if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
def Inline(): def Inline():
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <tvm/runtime/container.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -31,7 +32,7 @@ namespace relay { ...@@ -31,7 +32,7 @@ namespace relay {
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) { for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) { if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
return candidate; return candidate;
} }
} }
...@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, ...@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
} }
// Merge src to dest and erase src. // Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end()); dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
for (const auto& input : src->ins) { for (const auto& input : src->ins_) {
dest->ins.push_back(input); dest->ins_.push_back(input);
} }
for (const auto& output : src->outs) { for (const auto& output : src->outs_) {
dest->outs.push_back(output); dest->outs_.push_back(output);
} }
// if any of the outputs of src are inputs of dest, they become internal nodes // if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs // so remove them from outs
std::vector<Expr> ins_to_remove; std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) { for (const auto& input : dest->ins_) {
auto call = Downcast<Call>(input); auto call = Downcast<Call>(input);
auto it = src->nodes.find(call->args[0]); auto it = src->nodes_.find(call->args[0]);
if (it != src->nodes.end()) { if (it != src->nodes_.end()) {
dest->outs.remove(*it); dest->outs_.remove(*it);
ins_to_remove.push_back(input); ins_to_remove.push_back(input);
} }
} }
for (const auto& input : ins_to_remove) { for (const auto& input : ins_to_remove) {
dest->ins.remove(input); dest->ins_.remove(input);
} }
regions_.erase(src); regions_.erase(src);
} }
...@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) ...@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
if (src.defined()) { if (src.defined()) {
MergeRegions(src, dest); MergeRegions(src, dest);
} else { } else {
dest->nodes.insert(expr); dest->nodes_.insert(expr);
} }
} }
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() { AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
auto ret = regions_.emplace(AnnotatedRegion()); auto ret = regions_.emplace(AnnotatedRegion());
(*ret.first)->id = region_id_++; (*ret.first)->id_ = region_id_++;
(*ret.first)->target_ = target;
return *ret.first; return *ret.first;
} }
class AnnotatedRegionSet::Creator : public ExprVisitor { class AnnotatedRegionSet::Creator : public ExprVisitor {
public: public:
Creator(const Op& region_begin_op, const Op& region_end_op) : Creator(const Op& region_begin_op, const Op& region_end_op)
begin_op_(region_begin_op), end_op_(region_end_op) {} : begin_op_(region_begin_op), end_op_(region_end_op) {}
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
void VisitExpr_(const CallNode* call) { void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>(); auto op_node = call->op.as<OpNode>();
...@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { ...@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
<< "Cannot find the corresponding region for start annotation:\n" << "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false)); << AsText(GetRef<Call>(call), false));
} }
region->ins.push_back(GetRef<Call>(call)); region->ins_.push_back(GetRef<Call>(call));
} else { } else {
CHECK_EQ(call->op, end_op_); CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument. // The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U); CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Check if the argument already belongs to a region // Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]); auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) { if (!region.defined()) {
region = region_set_->MakeRegion(); // Create a new region if the argument is not belonged to any regions yet.
region->nodes.insert(call->args[0]); region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
} else {
// If the argument is belonged to a region, it must have the same target.
// Otherwise we should see a region_begin op.
CHECK_EQ(region->GetTarget(), target);
} }
region->nodes.insert(GetRef<Call>(call)); region->nodes_.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call)); region->outs_.push_back(GetRef<Call>(call));
} }
ExprVisitor::VisitExpr_(call); ExprVisitor::VisitExpr_(call);
} }
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
void VisitExpr_(const TupleNode* op) { void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op)); auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) { if (region.defined()) {
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <string> #include <string>
...@@ -49,33 +50,39 @@ class AnnotatedRegionSet; ...@@ -49,33 +50,39 @@ class AnnotatedRegionSet;
class AnnotatedRegionNode : public Object { class AnnotatedRegionNode : public Object {
public: public:
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("id", &id); v->Visit("id", &id_);
Array<Expr> nodes_array(nodes.begin(), nodes.end()); v->Visit("target", &target_);
Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
v->Visit("nodes", &nodes_array); v->Visit("nodes", &nodes_array);
Array<Expr> args_array(ins.begin(), ins.end()); Array<Expr> args_array(ins_.begin(), ins_.end());
v->Visit("args", &args_array); v->Visit("args", &args_array);
Array<Expr> rets_array(outs.begin(), outs.end()); Array<Expr> rets_array(outs_.begin(), outs_.end());
v->Visit("rets", &rets_array); v->Visit("rets", &rets_array);
} }
/*! \brief Get the region ID. */ /*! \brief Get the region ID. */
int GetID() const { int GetID() const {
return id; return id_;
}
/*! \brief Get the region target. */
std::string GetTarget() const {
return target_;
} }
/*! \brief Get the region's inputs. */ /*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const { std::list<Expr> GetInputs() const {
return ins; return ins_;
} }
/*! \brief Get the region's outputs. */ /*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const { std::list<Expr> GetOutputs() const {
return outs; return outs_;
} }
/*! \brief Get the region's nodes. */ /*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const { std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
return nodes; return nodes_;
} }
static constexpr const char* _type_key = "relay.AnnotatedRegion"; static constexpr const char* _type_key = "relay.AnnotatedRegion";
...@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object { ...@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object {
protected: protected:
/*! \brief The region ID. */ /*! \brief The region ID. */
int id{-1}; int id_{-1};
/*! \brief The target for this region. */
std::string target_ = "default";
/*! \brief The inputs to this region. */ /*! \brief The inputs to this region. */
std::list<Expr> ins; std::list<Expr> ins_;
/*! \brief The outputs of this region */ /*! \brief The outputs of this region */
std::list<Expr> outs; std::list<Expr> outs_;
/*! \brief Nodes in this region. */ /*! \brief Nodes in this region. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes; std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;
friend class AnnotatedRegionSet; friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode; friend class AnnotatedRegionSetNode;
...@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object { ...@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
void AddToRegion(AnnotatedRegion dest, const Expr& expr); void AddToRegion(AnnotatedRegion dest, const Expr& expr);
/*! /*!
* \brief Make a new region. * \brief Make a new region for a target.
* *
* \return The new region. * \return The new region.
*/ */
AnnotatedRegion MakeRegion(); AnnotatedRegion MakeRegion(const std::string& target);
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_; std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */ /*! \brief The next region ID to assign. */
......
...@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
void VisitExpr_(const TupleGetItemNode* op) final { void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple); // Do nothing
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
} }
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream; std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID // Args: ID
std::vector<std::string> args; std::vector<std::string> args;
...@@ -103,45 +96,20 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -103,45 +96,20 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
} }
// Analyze the output buffers // Analyze the output buffer
std::vector<Type> out_types; auto type_node = call->checked_type().as<TensorTypeNode>();
if (call->checked_type()->IsInstance<TupleTypeNode>()) { CHECK(type_node);
auto type_node = call->checked_type().as<TupleTypeNode>(); const auto& dtype = GetDtypeString(type_node);
for (auto field : type_node->fields) { std::string out = "buf_" + std::to_string(buf_idx_++);
CHECK(field->IsInstance<TensorTypeNode>()); auto out_shape = GetShape(call->checked_type());
out_types.push_back(field); int out_size = 1;
} for (size_t i = 0; i < out_shape.size(); ++i) {
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) { out_size *= out_shape[i];
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}
out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(out_type);
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Update output buffer
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
} }
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Attach attribute arguments // Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
...@@ -149,6 +117,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -149,6 +117,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
decl_stream << ");"; decl_stream << ");";
ext_func_body.push_back(decl_stream.str()); ext_func_body.push_back(decl_stream.str());
// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
} }
std::string JIT(void) { std::string JIT(void) {
......
...@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this // Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation // pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op` // pass. This is because memory allocation pass will insert `invoke_tvm_op`
...@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen. // external codegen.
pass_seqs.push_back(transform::Inline()); pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions. // Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
...@@ -19,132 +19,203 @@ ...@@ -19,132 +19,203 @@
/*! /*!
* \file src/relay/transforms/annotate_target.cc * \file src/relay/transforms/annotate_target.cc
* \brief Wraps a call with compiler_begin and compiler_end to indicate that * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
* the op of this call node will use external compiler. * this expr should be handled by the external compiler.
*/ */
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace annotate_target { namespace annotate_target {
// Cache compiler_begin op for equivalence check.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
// A helper class to insert annotation boundaries for a program region that will // A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler. // be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator { class AnnotateTargetWrapper : public ExprMutator {
public: public:
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {} explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {}
Expr Annotate(const Expr& expr) { /*!
return InsertEnd(Mutate(expr)); * \brief This function annotates a compiler end and a compiler begin to all arguments.
} *
* The compiler end is based on the arg target while the compiler begin is based on the given
bool IsSupported(const Expr& expr) { * target. If target is not given and all arguments are going to the same target, then we will
if (expr->IsInstance<CallNode>()) { * use that target; otherwise we use default for this op. Note that all arg exprs must be
Call call = Downcast<Call>(expr); * available in op_expr_to_target before calling this function.
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_); *
if (call->op->IsInstance<OpNode>()) { * \param args An array of arguments of the given node.
Op op = Downcast<Op>(call->op); * \param target The target of the current node.
CHECK(op.defined()); * \return A pair of target and annotated argument expressions.
if (fannotate.count(op)) { */
return fannotate[op](call->attrs, call->args); std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
} const std::string& target = "") {
} else if (call->op->IsInstance<FunctionNode>()) { std::string ref_target = "";
// handle composite functions Array<Expr> compiler_ends;
Function func = Downcast<Function>(call->op); for (auto arg : args) {
CHECK(func.defined()); std::string arg_target = "defualt";
auto comp_name = func->GetAttr<String>(attr::kComposite); const CallNode* call = arg.as<CallNode>();
if (comp_name.defined()) {
std::string comp_name_str = comp_name; if (call && call->op == compiler_begin_op) {
size_t i = comp_name_str.find('.'); // Argument is already compiler begin node meaning that this is not the first time
if (i != std::string::npos) { // running this pass, so we simply remove it and will add a new one later.
std::string target = comp_name_str.substr(0, i); CHECK_EQ(call->args.size(), 1U);
if (target == target_) return true; const CallNode* end = call->args[0].as<CallNode>();
} if (end->op == compiler_end_op) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
} }
compiler_ends.push_back(call->args[0]);
} else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
arg_target = op_expr_to_target_[arg];
compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
} else {
// Input vars.
compiler_ends.push_back(arg);
} }
}
if (expr->IsInstance<TupleGetItemNode>()) { // Maintain reference target in case the target of the current node is unassigned.
TupleGetItem get = Downcast<TupleGetItem>(expr); if (ref_target == "") {
if (get->tuple->IsInstance<CallNode>() && ref_target = arg_target;
get->tuple.as<CallNode>()->op == compiler_begin_op) { } else if (ref_target != arg_target) {
return true; ref_target = "default";
} }
} }
return false;
}
Expr InsertEnd(const Expr& arg) { // Determine compiler begin target.
if (IsSupported(arg)) { std::string op_target = (target == "") ? ref_target : target;
const auto *end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end"); Array<Expr> compiler_begins;
CHECK(end_op); for (const auto& end : compiler_ends) {
Expr end = (*end_op)(arg, target_); compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
return end;
} }
return arg;
return {op_target, compiler_begins};
} }
Expr VisitExpr_(const CallNode* cn) { Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
auto new_e = ExprMutator::VisitExpr_(cn); Expr new_op = (*ann_op)(expr, target);
new_op->checked_type_ = expr->checked_type_;
return new_op;
}
Call call = Downcast<Call>(new_e); Expr VisitExpr_(const CallNode* cn) final {
// Supported targets for this node. The order implies the priority.
std::vector<std::string> supported_targets;
auto op_node = cn->op.as<OpNode>();
// This graph has annotations, meaning that this is not the first time running this pass.
if (op_node && cn->op == compiler_begin_op) {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
CHECK_EQ(cn->args.size(), 1U);
return VisitExpr(cn->args[0]);
} else if (op_node && cn->op == compiler_end_op) {
// Override compiler end with the new target.
CHECK_EQ(cn->args.size(), 1U);
auto input_expr = VisitExpr(cn->args[0]);
CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
}
// add end annotations if the args are supported // Peek the first argument. If it is compiler begin then this node had annotated by
Array<Expr> compiler_ends; // another target before, so we also consider that target as a supported target.
for (const auto& it : call->args) { const CallNode* first_arg_call = cn->args[0].as<CallNode>();
compiler_ends.push_back(InsertEnd(it)); if (first_arg_call && first_arg_call->op == compiler_begin_op) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
}
} }
call = Call(call->op, compiler_ends, call->attrs);
// Check which targets this op can be offloaded.
// add begin annotations if the call node is supported if (op_node) {
if (IsSupported(call)) { // TVM operators: Check target specific op checking function and add to supported_targets
tvm::Array<tvm::relay::Expr> compiler_begins; // if it is supported.
const auto* begin_op = Op op = Downcast<Op>(cn->op);
runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); CHECK(op.defined());
for (const auto& it : call->args) { for (const auto& target : this->targets_) {
CHECK(begin_op); if (!Op::HasAttr("target." + std::string(target))) {
Expr begin = (*begin_op)(it, target_); continue;
compiler_begins.push_back(begin); }
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) {
supported_targets.push_back(target);
}
}
} else if (cn->op->IsInstance<FunctionNode>()) {
// Composite function: Add the target of a composite function to supported_targets
// if it is in the target list.
Function func = Downcast<Function>(cn->op);
CHECK(func.defined());
auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) {
std::string comp_name_str = comp_name;
size_t i = comp_name_str.find('.');
if (i != std::string::npos) {
std::string comp_target = comp_name_str.substr(0, i);
for (const auto& target : this->targets_) {
if (std::string(target) == comp_target) {
supported_targets.push_back(comp_target);
break;
}
}
}
} }
call = Call(call->op, compiler_begins, call->attrs);
} }
supported_targets.push_back("default"); // Make default as the last option.
// TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
// the highest priority, but we should preserve all supported targets so that
// we can make a better decision.
std::string target = supported_targets[0];
// Visit and mutate arguments after the target of this op has been determined.
auto new_call = Downcast<Call>(ExprMutator::VisitExpr_(cn));
// Add annotations to each arg.
auto target_n_args = AnnotateArgs(new_call->args, target);
Array<Expr> compiler_begins = std::get<1>(target_n_args);
Call call = Call(new_call->op, compiler_begins, new_call->attrs);
call->checked_type_ = cn->checked_type_;
// Update the target map.
op_expr_to_target_[call] = target;
return std::move(call); return std::move(call);
} }
Expr VisitExpr_(const TupleNode* op) { Expr VisitExpr_(const TupleNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<Tuple>(new_e);
auto tup = Downcast<Tuple>(new_e); auto target_n_args = AnnotateArgs(expr->fields);
Array<Expr> new_fields; auto new_expr = Tuple(std::get<1>(target_n_args));
for (auto field : tup->fields) { op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
new_fields.push_back(InsertEnd(field)); return std::move(new_expr);
}
return Tuple(new_fields);
} }
Expr VisitExpr_(const TupleGetItemNode* op) { Expr VisitExpr_(const TupleGetItemNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<TupleGetItem>(new_e);
auto get = Downcast<TupleGetItem>(new_e); auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
if (IsSupported(get->tuple)) { auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index);
const auto* begin_op = op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); return std::move(new_expr);
CHECK(begin_op);
return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index);
} else {
return TupleGetItem(InsertEnd(get->tuple), get->index);
}
} }
Expr VisitExpr_(const FunctionNode* fn) { Expr VisitExpr_(const FunctionNode* fn) final {
Function func; Function func;
Expr new_body; Expr new_body;
// don't step into composite functions // don't step into composite functions
...@@ -154,84 +225,93 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -154,84 +225,93 @@ class AnnotateTargetWrapper : public ExprMutator {
} else { } else {
auto new_e = ExprMutator::VisitExpr_(fn); auto new_e = ExprMutator::VisitExpr_(fn);
func = Downcast<Function>(new_e); func = Downcast<Function>(new_e);
new_body = InsertEnd(func->body); new_body = func->body;
if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
op_expr_to_target_[new_body] = op_expr_to_target_[func->body];
}
} }
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
return Function(
func->params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
} }
Expr VisitExpr_(const LetNode* op) { Expr VisitExpr_(const LetNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto let = Downcast<Let>(new_e); auto let = Downcast<Let>(new_e);
return Let(
let->var, auto target_n_args = AnnotateArgs({let->value, let->body});
InsertEnd(let->value), auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
InsertEnd(let->body)); op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
} }
Expr VisitExpr_(const IfNode* op) { Expr VisitExpr_(const IfNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<If>(new_e);
auto iff = Downcast<If>(new_e);
return If( auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
InsertEnd(iff->cond), CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
InsertEnd(iff->true_branch), auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1],
InsertEnd(iff->false_branch)); std::get<1>(target_n_args)[2]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
} }
Expr VisitExpr_(const RefCreateNode* op) { Expr VisitExpr_(const RefCreateNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<RefCreate>(new_e);
auto create = Downcast<RefCreate>(new_e); auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
return RefCreate(InsertEnd(create->value)); auto new_expr = RefCreate(std::get<1>(target_n_args)[0]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
} }
Expr VisitExpr_(const RefReadNode* op) { Expr VisitExpr_(const RefReadNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<RefRead>(new_e);
auto read = Downcast<RefRead>(new_e); auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
return RefRead(InsertEnd(read->ref)); auto new_expr = RefRead(std::get<1>(target_n_args)[0]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
} }
Expr VisitExpr_(const RefWriteNode* op) { Expr VisitExpr_(const RefWriteNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<RefWrite>(new_e);
auto write = Downcast<RefWrite>(new_e); auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
return RefWrite( auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
InsertEnd(write->ref), op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
InsertEnd(write->value)); return std::move(new_expr);
} }
private: private:
std::string target_; /*! \brief The target backends for annotation. */
Array<runtime::String> targets_;
/*! \brief Maintain the decision of the target for each op expr. */
std::unordered_map<Expr, std::string, ObjectHash, ObjectEqual> op_expr_to_target_;
}; };
Expr AnnotateTarget(const Expr& expr, const std::string& target) { Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
return AnnotateTargetWrapper(target).Annotate(expr); return AnnotateTargetWrapper(targets).Mutate(expr);
} }
} // namespace annotate_target } // namespace annotate_target
namespace transform { namespace transform {
Pass AnnotateTarget(const std::string& target) { Pass AnnotateTarget(const Array<runtime::String>& targets) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target)); return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
}; };
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
{"InferType"}); {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
} }
TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget") TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget").set_body_typed(AnnotateTarget);
.set_body_typed(AnnotateTarget);
} // namespace transform } // namespace transform
......
...@@ -46,216 +46,13 @@ ...@@ -46,216 +46,13 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace partitioning { namespace merge_compiler_region {
// Cache compiler_begin and compiler_end annotation ops for equivalence check to // Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead. // reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*! \brief This is a pre-requisite pass to merge-supported pass.
* The AnnotateRestDefault pass will put "default" Compiler Annotations to
* nodes that are not annotated already. This is there to ensure that the
* user will not leave un-annotated nodes MergeCompilerRegions pass is run.
* Why? Because, MergeCompilerRegions pass assumes every node to be annotated.
*/
class AnnotateRestDefault : public ExprMutator {
public:
explicit AnnotateRestDefault(const Expr& expr) {
regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
}
Expr Annotate(const Expr& expr) {
// Its a function that is being passed on to annotate
func_ = Downcast<Function>(expr);
// Corner Case CC1 : If the last node does not belong
// to a region node to add a compiler_end
auto region = regions_->GetRegion(func_->body);
auto mutated_expr = this->VisitExpr(expr);
if (!region.defined()) {
func_ = Downcast<Function>(mutated_expr);
// CC1 : add that compiler end after mutation
auto body = InsertEnd(func_->body);
func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs());
return Downcast<Expr>(func_);
}
return mutated_expr;
}
/*! \brief This function adds compiler ends to nodes that
* don't belong to a region already (default).
* \param expr The expression to add a compiler end to.
* \return expr The expression with or without a compiler end added.
*/
Expr InsertEnd(const Expr& expr) {
if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance<VarNode>() &&
!expr->IsInstance<ConstantNode>()) {
const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(expr, target_);
return end;
}
return expr;
}
/*! \brief This function adds compiler begins to nodes that
* don't belong to a region already (default).
* \param expr The expression to add a compiler begin to.
* \return expr The expression with or without a compiler begin added.
*/
Expr InsertBegin(const Expr& expr) {
const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(expr, target_);
annotated_nodes_.insert(begin);
return begin;
}
Expr VisitExpr_(const CallNode* cn) final {
auto region = regions_->GetRegion(GetRef<Call>(cn));
auto new_e = ExprMutator::VisitExpr_(cn);
Call call = Downcast<Call>(new_e);
// Add compiler ends if the parent isn't annotated
Array<Expr> args;
for (auto arg : call->args) {
args.push_back(InsertEnd(arg));
}
Expr updated_call = Call(call->op, args, call->attrs);
if (!region.defined()) {
// if the current node does not belong to annotated region
// annotate the all incoming edges (args)
// with "default" compiler_begin annotations.
Array<Expr> compiler_begins;
for (auto arg : args) {
compiler_begins.push_back(InsertBegin(arg));
}
updated_call = Call(call->op, compiler_begins, call->attrs);
} else {
annotated_nodes_.insert(updated_call);
}
return updated_call;
};
Expr VisitExpr_(const TupleNode* op) {
auto region = regions_->GetRegion(GetRef<Tuple>(op));
auto new_e = ExprMutator::VisitExpr_(op);
Tuple tup = Downcast<Tuple>(new_e);
Array<Expr> fields;
for (auto field : tup->fields) {
fields.push_back(InsertEnd(field));
}
Expr updated_tuple = Tuple(fields);
if (!region.defined()) {
Array<Expr> compiler_begins;
for (const auto& field : fields) {
compiler_begins.push_back(InsertBegin(field));
}
updated_tuple = Tuple(compiler_begins);
} else {
annotated_nodes_.insert(updated_tuple);
}
return updated_tuple;
}
Expr VisitExpr_(const TupleGetItemNode* op) {
auto region = regions_->GetRegion(GetRef<TupleGetItem>(op));
auto new_e = ExprMutator::VisitExpr_(op);
auto get = Downcast<TupleGetItem>(new_e);
auto updated_tuple = InsertEnd(get->tuple);
Expr updated_get = TupleGetItem(updated_tuple, get->index);
if (!region.defined()) {
updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index);
} else {
annotated_nodes_.insert(updated_get);
}
return updated_get;
}
Expr VisitExpr_(const IfNode* op) {
auto region = regions_->GetRegion(GetRef<If>(op));
auto new_e = ExprMutator::VisitExpr_(op);
auto iff = Downcast<If>(new_e);
if (!region.defined()) {
return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)),
InsertBegin(InsertEnd(iff->false_branch)));
} else {
Expr updated_iff =
If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch));
annotated_nodes_.insert(updated_iff);
return updated_iff;
}
}
Expr VisitExpr_(const LetNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto let = Downcast<Let>(new_e);
return Let(let->var, InsertEnd(let->value), InsertEnd(let->body));
}
Expr VisitExpr_(const RefCreateNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto create = Downcast<RefCreate>(new_e);
return RefCreate(InsertEnd(create->value));
}
Expr VisitExpr_(const RefReadNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto read = Downcast<RefRead>(new_e);
return RefRead(InsertEnd(read->ref));
}
Expr VisitExpr_(const RefWriteNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto write = Downcast<RefWrite>(new_e);
return RefWrite(InsertEnd(write->ref), InsertEnd(write->value));
}
private:
AnnotatedRegionSet regions_;
const std::string target_ = "default";
Function func_;
std::unordered_set<Expr, ObjectHash, ObjectEqual> annotated_nodes_;
};
class MergeAnnotations : public ExprMutator {
public:
explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
Expr VisitExpr_(const CallNode* call) final {
// remove 'default' annotations
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
}
// Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a
// compiler end and they're both in the same region.
if (call->op == compiler_begin_op) {
if (call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]);
if (arg->op == compiler_end_op) {
auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg);
if (region1 == region2) {
return VisitExpr(arg->args[0]);
}
}
}
}
return ExprMutator::VisitExpr_(call);
}
private:
AnnotatedRegionSet regions_;
};
class RegionMerger : public ExprVisitor { class RegionMerger : public ExprVisitor {
public: public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
...@@ -263,62 +60,74 @@ class RegionMerger : public ExprVisitor { ...@@ -263,62 +60,74 @@ class RegionMerger : public ExprVisitor {
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) { if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(GetRef<Call>(call)); auto region = regions_->GetRegion(GetRef<Call>(call));
if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
// set the region target // Skip this region if it has been merged to the other region.
if (merged_regions_.find(region->GetID()) != merged_regions_.end()) {
return;
}
// Check the region target.
auto compiler_attrs = call->attrs.as<CompilerAttrs>(); auto compiler_attrs = call->attrs.as<CompilerAttrs>();
region_targets_[region->GetID()] = compiler_attrs->compiler; CHECK_EQ(region->GetTarget(), compiler_attrs->compiler);
// first look at the region args to determine the parent regions
// Visit the unmerged parent regions.
for (const auto& arg : region->GetInputs()) { for (const auto& arg : region->GetInputs()) {
// all args should be begin annotations // Region inputs must be begin annotation, and the region of
// the begin annotation's argument is the parent region.
auto begin = Downcast<Call>(arg); auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op); CHECK_EQ(begin->op, compiler_begin_op);
// the arguments of the begin annotations will be in the parent regions
auto parent_region = regions_->GetRegion(begin->args[0]); auto parent_region = regions_->GetRegion(begin->args[0]);
// if there is no parent region, move on
if (!parent_region.defined()) continue; // Skip this region if it has been merged.
// merge the parent region if it hasn't been done already if (!parent_region.defined()) {
if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { continue;
} else if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
VisitExpr(begin->args[0]); VisitExpr(begin->args[0]);
} }
} }
// get the mergeable regions now all the parents have been visited
// Collect unmerged parent regions.
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions; std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) { for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg); auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op); CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]); auto parent_region = regions_->GetRegion(begin->args[0]);
if (!parent_region.defined()) continue; if (parent_region.defined()) {
mergeable_regions.insert(parent_region); mergeable_regions.insert(parent_region);
}
} }
// Propogate all the parent restrictions to the current region.
auto& region_restrictions = region_restrictions_[region->GetID()]; auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) { for (const auto& parent_region : mergeable_regions) {
// add all the parent restrictions to the current region
auto parent_restrictions = region_restrictions_[parent_region->GetID()]; auto parent_restrictions = region_restrictions_[parent_region->GetID()];
region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end()); region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
} }
for (const auto& parent_region : mergeable_regions) { for (const auto& parent_region : mergeable_regions) {
bool merged = false; // Skip the parent region with a different target.
// check the parent region has the same target if (parent_region->GetTarget() != compiler_attrs->compiler) {
if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) { region_restrictions.insert(parent_region->GetID());
// check the parent region isn't in the restrictions continue;
if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) { }
// merge the parent region into the current region
regions_->MergeRegions(parent_region, region); // Skip the parent region if it is in the restriction set.
// update the restrictions of all other regions to reflect the if (region_restrictions.find(parent_region->GetID()) != region_restrictions.end()) {
// change in id continue;
for (const auto& r : regions_) { }
auto& restrictions = region_restrictions_[r->GetID()];
if (restrictions.find(parent_region->GetID()) != restrictions.end()) { // Merge the parent region to the current one.
restrictions.erase(parent_region->GetID()); regions_->MergeRegions(parent_region, region);
restrictions.insert(region->GetID());
} // Replace the parent region ID with the current region for all
} // other regions' restriction sets.
merged = true; for (const auto& r : regions_) {
auto& restrictions = region_restrictions_[r->GetID()];
if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
restrictions.erase(parent_region->GetID());
restrictions.insert(region->GetID());
} }
} }
// if the parent wasn't merged, add it as a restriction to the current
// region
if (!merged) region_restrictions.insert(parent_region->GetID());
} }
merged_regions_.insert(region->GetID()); merged_regions_.insert(region->GetID());
} }
...@@ -328,42 +137,58 @@ class RegionMerger : public ExprVisitor { ...@@ -328,42 +137,58 @@ class RegionMerger : public ExprVisitor {
private: private:
AnnotatedRegionSet regions_; AnnotatedRegionSet regions_;
std::unordered_set<int> merged_regions_; std::unordered_set<int> merged_regions_;
std::map<int, std::unordered_set<int>> region_restrictions_; std::unordered_map<int, std::unordered_set<int>> region_restrictions_;
std::map<int, std::string> region_targets_;
}; };
Expr MergeCompilerRegions(const Expr& expr) { class MergeAnnotations : public ExprMutator {
// Annotate all the nodes that aren't annotated as 'default'. public:
AnnotateRestDefault anno_default(expr); explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
auto expr_all_annotated = anno_default.Annotate(expr);
Expr VisitExpr_(const CallNode* call) final {
// Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a
// compiler end and they're both in the same region.
if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]);
if (arg->op == compiler_end_op) {
auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg);
if (region1 == region2) {
return VisitExpr(arg->args[0]);
}
}
}
return ExprMutator::VisitExpr_(call);
}
private:
AnnotatedRegionSet regions_;
};
Expr MergeCompilerRegions(const Expr& expr) {
// Create regions using the annotations. // Create regions using the annotations.
AnnotatedRegionSet regions = AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op);
// By now, all the nodes have some sort of annotation. // Analyze the graph to explore the opportunities of merging regions.
// Region merger is an ExprVisitor that will update the
// AnnotatedRegionSet, merging all the regions that can be merged.
RegionMerger merger(regions); RegionMerger merger(regions);
merger.VisitExpr(expr_all_annotated); merger.VisitExpr(expr);
// This updates the expression to remove annotations that are now // Remove annotations that are not in the region boundaries.
// 'internal' to a merged region.
MergeAnnotations merge_anno(regions); MergeAnnotations merge_anno(regions);
return merge_anno.Mutate(expr_all_annotated); return merge_anno.Mutate(expr);
} }
} // namespace partitioning } // namespace merge_compiler_region
namespace transform { namespace transform {
Pass MergeCompilerRegions() { Pass MergeCompilerRegions() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(partitioning::MergeCompilerRegions(f)); return Downcast<Function>(merge_compiler_region::MergeCompilerRegions(f));
}; };
auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {}); auto merged = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
return Sequential({partitioned, InferType()}); return Sequential({merged, InferType()});
} }
TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
......
...@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator { ...@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator {
IRModule module_; IRModule module_;
}; };
class DefaultRemover : public ExprMutator {
public:
explicit DefaultRemover(const IRModule& module) : module_(module) {}
IRModule Remove() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
Expr VisitExpr_(const CallNode* call) final {
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
}
return ExprMutator::VisitExpr_(call);
}
private:
IRModule module_;
};
} // namespace partitioning } // namespace partitioning
namespace transform { namespace transform {
Pass PartitionGraph() { Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); }; [=](IRModule m, PassContext pc) {
// TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
// by treating them as un-annotated, but we don't have it yet. This workaround pass removes
// all "default" annotations and should be deleted in the future.
auto new_m = partitioning::DefaultRemover(m).Remove();
return partitioning::Partitioner(new_m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()}); return Sequential({partitioned, InferType()});
} }
......
...@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, ...@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory); read_from_dnnl_memory(out, dst_memory);
} }
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* out, float* new_mean, float* new_variance, int p_N_, int p_C_, float* variance, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) { int p_H_, int p_W_, int p_E_) {
// FIXME(@comaniac): BN has 3 outputs: out, new_mean and new_variance, but we do not update
// the rest two because no one cares about them for now. Should update it in the future.
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
......
...@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p ...@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, float* new_mean, float* new_variance, float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_); int p_h_, int p_w_);
......
...@@ -15,13 +15,15 @@ ...@@ -15,13 +15,15 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
import tvm
from tvm import relay from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
def check_region(region_set, args, nodes, rets): def check_region(region_set, target, args, nodes, rets):
region = region_set.get_region(args[0]) region = region_set.get_region(args[0])
assert region assert region
assert target == region.target
assert set(args) == set(region.args) assert set(args) == set(region.args)
assert set(nodes) == set(region.nodes) assert set(nodes) == set(region.nodes)
assert set(rets) == set(region.rets) assert set(rets) == set(region.rets)
...@@ -51,24 +53,28 @@ def test_region_set_creator_diamond(): ...@@ -51,24 +53,28 @@ def test_region_set_creator_diamond():
assert len(region_set) == 4 assert len(region_set) == 4
check_region( check_region(
region_set, region_set,
'test_target',
[cb_1], [cb_1],
[cb_1, O_1, ce_1, ce_2], [cb_1, O_1, ce_1, ce_2],
[ce_1, ce_2], [ce_1, ce_2],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_2], [cb_2],
[cb_2, O_2, ce_3], [cb_2, O_2, ce_3],
[ce_3], [ce_3],
) )
check_region( check_region(
region_set, region_set,
'default',
[cb_d], [cb_d],
[cb_d, X, ce_d], [cb_d, X, ce_d],
[ce_d], [ce_d],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_3, cb_4], [cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4], [cb_3, cb_4, O_3, ce_4],
[ce_4], [ce_4],
...@@ -88,7 +94,9 @@ def test_region_set_creator_merged(): ...@@ -88,7 +94,9 @@ def test_region_set_creator_merged():
cb_3 = compiler_begin(ce_3, 'test_target') cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(ce_d, 'test_target') cb_4 = compiler_begin(ce_d, 'test_target')
O_3 = relay.add(cb_3, cb_4) O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, 'test_target') O_4 = relay.add(cb_3, cb_4)
O_5 = relay.Tuple([O_3, O_4])
ce_4 = compiler_end(O_5, 'test_target')
merged = relay.Function([data], ce_4) merged = relay.Function([data], ce_4)
region_set = relay.analysis.AnnotatedRegionSet(merged, region_set = relay.analysis.AnnotatedRegionSet(merged,
...@@ -97,20 +105,23 @@ def test_region_set_creator_merged(): ...@@ -97,20 +105,23 @@ def test_region_set_creator_merged():
assert len(region_set) == 3 assert len(region_set) == 3
check_region( check_region(
region_set, region_set,
'test_target',
[cb_1], [cb_1],
[cb_1, O_1, O_2, ce_2, ce_3], [cb_1, O_1, O_2, ce_2, ce_3],
[ce_2, ce_3], [ce_2, ce_3],
) )
check_region( check_region(
region_set, region_set,
'default',
[cb_d], [cb_d],
[cb_d, X, ce_d], [cb_d, X, ce_d],
[ce_d], [ce_d],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_3, cb_4], [cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4], [cb_3, cb_4, O_3, O_4, O_5, ce_4],
[ce_4], [ce_4],
) )
...@@ -118,4 +129,3 @@ def test_region_set_creator_merged(): ...@@ -118,4 +129,3 @@ def test_region_set_creator_merged():
if __name__ == "__main__": if __name__ == "__main__":
test_region_set_creator_diamond() test_region_set_creator_diamond()
test_region_set_creator_merged() test_region_set_creator_merged()
...@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet(): ...@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
@reg.register("nn.relu", "target.test")
def relu(attrs, args):
return True
def test_multiple_ends(): def test_multiple_ends():
@reg.register("nn.relu", "target.test")
def relu(attrs, args): # pylint: disable=unused-variable
return True
def before(): def before():
x = relay.var("x", shape=(10, 10)) x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x) r = relay.nn.relu(x)
...@@ -208,10 +207,17 @@ def test_multiple_ends(): ...@@ -208,10 +207,17 @@ def test_multiple_ends():
r = relay.nn.relu(cb_1) r = relay.nn.relu(cb_1)
ce_1 = relay.annotation.compiler_end(r, "test") ce_1 = relay.annotation.compiler_end(r, "test")
ce_2 = relay.annotation.compiler_end(r, "test") ce_2 = relay.annotation.compiler_end(r, "test")
a_1 = relay.abs(ce_1) cb_2 = relay.annotation.compiler_begin(ce_1, "default")
a_2 = relay.abs(ce_2) cb_3 = relay.annotation.compiler_begin(ce_2, "default")
out = relay.add(a_1, a_2) a_1 = relay.abs(cb_2)
f = relay.Function([x], out) a_2 = relay.abs(cb_3)
ce_3 = relay.annotation.compiler_end(a_1, "default")
ce_4 = relay.annotation.compiler_end(a_2, "default")
cb_4 = relay.annotation.compiler_begin(ce_3, "default")
cb_5 = relay.annotation.compiler_begin(ce_4, "default")
out = relay.add(cb_4, cb_5)
ce_6 = relay.annotation.compiler_end(out, "default")
f = relay.Function([x], ce_6)
mod = tvm.IRModule.from_expr(f) mod = tvm.IRModule.from_expr(f)
return mod return mod
...@@ -220,6 +226,72 @@ def test_multiple_ends(): ...@@ -220,6 +226,72 @@ def test_multiple_ends():
assert tvm.ir.structural_equal(expected, result) assert tvm.ir.structural_equal(expected, result)
def test_type_propagation():
target = "test_type_propagation"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return args[0].checked_type.dtype == "float32"
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
out = relay.nn.relu(r)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
# If the type isn't propogated, then the relu checker function will fail to get the dtype.
assert transform.AnnotateTarget(target)(before())
def test_tuple():
target = "test_tuple"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("concatenate", "target." + target)
def concatenate(attrs, args): # pylint: disable=unused-variable
return True
"""Test that TupleNode is included in annotation when surrounded by supported nodes."""
def before():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.nn.relu(y)
out = relay.concatenate((a_1, a_2), axis=1)
f = relay.Function([x, y], out)
mod = tvm.IRModule.from_expr(f)
return mod
def after():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
cb_1 = relay.annotation.compiler_begin(x, target)
cb_2 = relay.annotation.compiler_begin(y, target)
a_1 = relay.nn.relu(cb_1)
a_2 = relay.nn.relu(cb_2)
ce_1 = relay.annotation.compiler_end(a_1, target)
ce_2 = relay.annotation.compiler_end(a_2, target)
cb_3 = relay.annotation.compiler_begin(ce_1, target)
cb_4 = relay.annotation.compiler_begin(ce_2, target)
tup = relay.Tuple([cb_3, cb_4])
ce_3 = relay.annotation.compiler_end(tup, target)
cb_3 = relay.annotation.compiler_begin(ce_3, target)
out = relay.op._make.concatenate(cb_3, 1)
ce_4 = relay.annotation.compiler_end(out, target)
f = relay.Function([x, y], ce_4)
mod = tvm.IRModule.from_expr(f)
return mod
result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)
def test_composite_function(): def test_composite_function():
def before(): def before():
a = relay.var('a', shape=(10, 10)) a = relay.var('a', shape=(10, 10))
...@@ -265,8 +337,37 @@ def test_composite_function(): ...@@ -265,8 +337,37 @@ def test_composite_function():
assert tvm.ir.structural_equal(expected, result) assert tvm.ir.structural_equal(expected, result)
def test_multiple_runs():
@reg.register("nn.relu", "target.A")
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("add", "target.B")
def add(attrs, args): # pylint: disable=unused-variable
return True
def before():
x = relay.var("x", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.abs(a_1)
a_3 = relay.nn.relu(a_1)
out = relay.add(a_2, a_3)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
mod = transform.AnnotateTarget("A")(before())
mod = transform.AnnotateTarget("B")(mod)
expected = transform.AnnotateTarget(["A", "B"])(before())
assert tvm.ir.structural_equal(expected, mod)
if __name__ == "__main__": if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl() test_extern_dnnl()
#test_extern_dnnl_mobilenet()
test_composite_function() test_composite_function()
#test_extern_dnnl_mobilenet()
test_multiple_ends()
test_type_propagation()
test_tuple()
test_multiple_runs()
...@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts(): ...@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts():
X = not supported by target X = not supported by target
O O O O
/ \ / \ / \\ / \\
O X --> O + + X O X --> O + + X
\ / \ / \\ / \\ /
O O O O
Note that we can't just merge the three supported operators together, Note that we can't just merge the three supported operators together,
...@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts(): ...@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts():
ce_1 = compiler_end(O_1, "test") ce_1 = compiler_end(O_1, "test")
ce_2 = compiler_end(O_1, "test") ce_2 = compiler_end(O_1, "test")
cb_2 = compiler_begin(ce_1, "test") cb_2 = compiler_begin(ce_1, "test")
cb_3 = compiler_begin(ce_2, "default")
O_2 = relay.nn.relu(cb_2) O_2 = relay.nn.relu(cb_2)
ce_3 = compiler_end(O_2, "test") ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2)
cb_3 = compiler_begin(ce_3, "test") X = relay.tanh(cb_3)
cb_4 = compiler_begin(X, "test") ce_4 = compiler_end(X, "default")
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_4) cb_4 = compiler_begin(ce_3, "test")
cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_4, cb_5)
ce_5 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_5)
return diamond return diamond
def expected(): def expected():
...@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts(): ...@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts():
O_2 = relay.nn.relu(O_1) O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, "test") ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2) cb_3 = compiler_begin(ce_2, "default")
X = relay.tanh(cb_3)
ce_4 = compiler_end(X, "default")
cb_3 = compiler_begin(ce_3, "test") cb_4 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(X, "test") cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_3, cb_4) O_3 = relay.add(cb_4, cb_5)
ce_4 = compiler_end(O_3, "test") ce_5 = compiler_end(O_3, "test")
func = relay.Function([data], ce_4) func = relay.Function([data], ce_5)
return func return func
result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions()) result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
...@@ -85,7 +90,7 @@ def test_example_graph(): ...@@ -85,7 +90,7 @@ def test_example_graph():
"""This tests the merging algorithm on the example used in the RFC. """This tests the merging algorithm on the example used in the RFC.
See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830 See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
Blue nodes are adds, red nodes are subtracts. Blue nodes are adds (target: test), red nodes are subtracts (target: default).
""" """
def annotated(): def annotated():
in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
...@@ -112,21 +117,30 @@ def test_example_graph(): ...@@ -112,21 +117,30 @@ def test_example_graph():
node2 = relay.add(begin4, begin5) node2 = relay.add(begin4, begin5)
end2 = compiler_end(node2, "test") end2 = compiler_end(node2, "test")
node3 = relay.subtract(in_5, in_6) dbegin0 = compiler_begin(in_5, "default")
node4 = relay.subtract(in_7, node3) dbegin1 = compiler_begin(in_6, "default")
node3 = relay.subtract(dbegin0, dbegin1)
dbegin2 = compiler_begin(in_7, "default")
dend1 = compiler_end(node3, "default")
dbegin3 = compiler_begin(dend1, "default")
node4 = relay.subtract(dbegin2, dbegin3)
dend2 = compiler_end(node4, "default")
begin6 = compiler_begin(end2, "test") begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node4, "test") begin7 = compiler_begin(dend2, "test")
node5 = relay.add(begin6, begin7) node5 = relay.add(begin6, begin7)
end3 = compiler_end(node5, "test") end3 = compiler_end(node5, "test")
end4 = compiler_end(node5, "test") end4 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end3) dbegin4 = compiler_begin(in_8, "default")
dbegin5 = compiler_begin(end3, "default")
node6 = relay.subtract(dbegin4, dbegin5)
begin8 = compiler_begin(in_9, "test") begin8 = compiler_begin(in_9, "test")
begin9 = compiler_begin(end4, "test") begin9 = compiler_begin(end4, "test")
node7 = relay.add(begin8, begin9) node7 = relay.add(begin8, begin9)
end5 = compiler_end(node7, "test") end5 = compiler_end(node7, "test")
begin10 = compiler_begin(node6, "test") dend3 = compiler_end(node6, "default")
begin10 = compiler_begin(dend3, "test")
begin11 = compiler_begin(end5, "test") begin11 = compiler_begin(end5, "test")
node8 = relay.add(begin10, begin11) node8 = relay.add(begin10, begin11)
end6 = compiler_end(node8, "test") end6 = compiler_end(node8, "test")
...@@ -159,20 +173,27 @@ def test_example_graph(): ...@@ -159,20 +173,27 @@ def test_example_graph():
node1 = relay.add(begin2, begin3) node1 = relay.add(begin2, begin3)
node2 = relay.add(node0, node1) node2 = relay.add(node0, node1)
node3 = relay.subtract(in_5, in_6) dbegin0 = compiler_begin(in_5, "default")
node4 = relay.subtract(in_7, node3) dbegin1 = compiler_begin(in_6, "default")
dbegin2 = compiler_begin(in_7, "default")
node3 = relay.subtract(dbegin0, dbegin1)
node4 = relay.subtract(dbegin2, node3)
dend0 = compiler_end(node4, "default")
begin4 = compiler_begin(node4, "test") begin4 = compiler_begin(dend0, "test")
begin5 = compiler_begin(in_9, "test") begin5 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin4) node5 = relay.add(node2, begin4)
end1 = compiler_end(node5, "test") end1 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end1) dbegin4 = compiler_begin(end1, "default")
dbegin5 = compiler_begin(in_8, "default")
node6 = relay.subtract(dbegin5, dbegin4)
dend1 = compiler_end(node6, "default")
node7 = relay.add(begin5, node5) node7 = relay.add(begin5, node5)
end2 = compiler_end(node7, "test") end2 = compiler_end(node7, "test")
begin6 = compiler_begin(end2, "test") begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node6, "test") begin7 = compiler_begin(dend1, "test")
node8 = relay.add(begin7, begin6) node8 = relay.add(begin7, begin6)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Unit tests for graph partitioning.""" """Unit tests for graph partitioning."""
import os import os
import sys import sys
import numpy as np import numpy as np
import pytest import pytest
...@@ -26,8 +27,12 @@ from tvm import relay ...@@ -26,8 +27,12 @@ from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay import transform
from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container
# Leverage the pass manager to write a simple white list based annotator # Leverage the pass manager to write a simple white list based annotator
@transform.function_pass(opt_level=0) @transform.function_pass(opt_level=0)
...@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
return lib return lib
def check_vm_result(): def check_vm_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params) exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save() code, lib = exe.save()
...@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result(): def check_graph_runtime_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params) json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib) lib = update_lib(lib)
...@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet(): ...@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload( mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32') batch_size=1, dtype='float32')
op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = WhiteListAnnotator(op_list, "dnnl")(mod) mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod) mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype) i_data = np.random.uniform(0, 1, ishape).astype(dtype)
...@@ -851,6 +858,7 @@ if __name__ == "__main__": ...@@ -851,6 +858,7 @@ if __name__ == "__main__":
test_extern_ccompiler_default_ops() test_extern_ccompiler_default_ops()
test_extern_ccompiler() test_extern_ccompiler()
test_extern_dnnl() test_extern_dnnl()
# TODO(@comaniac, @zhiics): Fix constant node and re-open this case.
#test_extern_dnnl_mobilenet() #test_extern_dnnl_mobilenet()
test_function_lifting() test_function_lifting()
test_function_lifting_inline() test_function_lifting_inline()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment