Unverified Commit 04499665 by mbaret Committed by GitHub

[RELAY] Fixes to MergeCompilerRegions (#5195)

* [RELAY] Fixed issues with MergeCompilerRegions

This PR addresses a few outstanding issues with
the implementation of MergeCompilerRegions. In
particular, it now handles TupleGetItem nodes properly
and other minor bugs related to region merging have
been fixed.

Change-Id: I07783afc56183a6f798a510209f23b0a5f252255

* Fixed issue using pre-merged regions

Change-Id: I0a844ac59bda1089ae0c67cef52f0b0c7ab2cbd7

* Removed some debugging logic

Change-Id: Ib6f2eede6f38bbb270073eb8d4c4dc19f60832c6

* Remove default annotations

Change-Id: I9b7696a51c95871491cbea33c40f92ec327e417f

* Annotate default 'if's

Change-Id: I0098bd1bf6788dd6366810dcefa84f1ebbffaab0

* Clang format

Change-Id: I944365cd3080a97a9261f643a8f1efa5a63cf82b

* Use src/dest in merge

Change-Id: Ie43113492bda8f1ce63eaf9615cb645bb9e2ee86

* Fixed partition test

Change-Id: I46f9e349b1a813a9140f7e4f8a2241687e2df73b

* Removed comments

Change-Id: I309afdd1951d7e796e41d13788aa487707e0ac4c
parent 2f41a396
...@@ -70,12 +70,12 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, ...@@ -70,12 +70,12 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
regions_.erase(src); regions_.erase(src);
} }
void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) { void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) {
auto region2 = GetRegion(expr); auto src = GetRegion(expr);
if (region2.defined()) { if (src.defined()) {
MergeRegions(region, region2); MergeRegions(src, dest);
} else { } else {
region->nodes.insert(expr); dest->nodes.insert(expr);
} }
} }
......
...@@ -178,10 +178,10 @@ class AnnotatedRegionSetNode : public Object { ...@@ -178,10 +178,10 @@ class AnnotatedRegionSetNode : public Object {
/*! /*!
* \brief Add an expression to a region. * \brief Add an expression to a region.
* *
* \param region The region to add the expression to. * \param dest The region to add the expression to.
* \param expr The expression. * \param expr The expression.
*/ */
void AddToRegion(AnnotatedRegion region, const Expr& expr); void AddToRegion(AnnotatedRegion dest, const Expr& expr);
/*! /*!
* \brief Make a new region. * \brief Make a new region.
......
...@@ -32,6 +32,9 @@ namespace tvm { ...@@ -32,6 +32,9 @@ namespace tvm {
namespace relay { namespace relay {
namespace annotate_target { namespace annotate_target {
// Cache compiler_begin op for equivalence check.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
// A helper class to insert annotation boundaries for a program region that will // A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler. // be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator { class AnnotateTargetWrapper : public ExprMutator {
...@@ -52,6 +55,13 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -52,6 +55,13 @@ class AnnotateTargetWrapper : public ExprMutator {
return fannotate[op](call->attrs, call->args); return fannotate[op](call->attrs, call->args);
} }
} }
if (expr->IsInstance<TupleGetItemNode>()) {
TupleGetItem get = Downcast<TupleGetItem>(expr);
if (get->tuple->IsInstance<CallNode>() &&
get->tuple.as<CallNode>()->op == compiler_begin_op) {
return true;
}
}
return false; return false;
} }
...@@ -110,9 +120,14 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -110,9 +120,14 @@ class AnnotateTargetWrapper : public ExprMutator {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto get = Downcast<TupleGetItem>(new_e); auto get = Downcast<TupleGetItem>(new_e);
return TupleGetItem( if (IsSupported(get->tuple)) {
InsertEnd(get->tuple), const auto* begin_op =
get->index); runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index);
} else {
return TupleGetItem(InsertEnd(get->tuple), get->index);
}
} }
Expr VisitExpr_(const FunctionNode* op) { Expr VisitExpr_(const FunctionNode* op) {
......
...@@ -30,10 +30,10 @@ ...@@ -30,10 +30,10 @@
* as external functions. * as external functions.
*/ */
#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
...@@ -44,7 +44,6 @@ ...@@ -44,7 +44,6 @@
#include "../analysis/annotated_region_set.h" #include "../analysis/annotated_region_set.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace partitioning { namespace partitioning {
...@@ -63,7 +62,7 @@ static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); ...@@ -63,7 +62,7 @@ static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
class AnnotateRestDefault : public ExprMutator { class AnnotateRestDefault : public ExprMutator {
public: public:
explicit AnnotateRestDefault(const Expr& expr) { explicit AnnotateRestDefault(const Expr& expr) {
regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
} }
Expr Annotate(const Expr& expr) { Expr Annotate(const Expr& expr) {
...@@ -71,141 +70,158 @@ class AnnotateRestDefault : public ExprMutator { ...@@ -71,141 +70,158 @@ class AnnotateRestDefault : public ExprMutator {
func_ = Downcast<Function>(expr); func_ = Downcast<Function>(expr);
// Corner Case CC1 : If the last node does not belong // Corner Case CC1 : If the last node does not belong
// to a region nede to add a compiler_end // to a region node to add a compiler_end
auto region = regions_->GetRegion(func_->body); auto region = regions_->GetRegion(func_->body);
auto mutated_expr = this->VisitExpr(expr); auto mutated_expr = this->VisitExpr(expr);
if (!region.defined()) { if (!region.defined()) {
func_ = Downcast<Function>(mutated_expr); func_ = Downcast<Function>(mutated_expr);
// CC1 : add that compiler end after mutation // CC1 : add that compiler end after mutation
auto body = AddCompilerEnd_(func_->body); auto body = InsertEnd(func_->body);
func_ = Function(func_->params, body, func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs());
body->checked_type_, {}, DictAttrs());
return Downcast<Expr>(func_); return Downcast<Expr>(func_);
} }
return mutated_expr; return mutated_expr;
} }
/*! \brief This function adds compiler ends to nodes that /*! \brief This function adds compiler ends to nodes that
* have a region AND they should not be arguments of the * don't belong to a region already (default).
* original function
* \param expr The expression to add a compiler end to. * \param expr The expression to add a compiler end to.
* \return expr The expression with or without a compiler end added. * \return expr The expression with or without a compiler end added.
*/ */
Expr AddCompilerEnd(const Expr& expr) { Expr InsertEnd(const Expr& expr) {
auto region = regions_->GetRegion(expr); if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance<VarNode>() &&
auto visited_expr = VisitExpr(expr); !expr->IsInstance<ConstantNode>()) {
const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
// The compiler ends are added to nodes that does have a region CHECK(end_op);
// AND they should not be arguments of the original function Expr end = (*end_op)(expr, target_);
if (!region.defined() && return end;
std::find(func_->params.begin(),
func_->params.end(), visited_expr)
== func_->params.end()) {
return AddCompilerEnd_(visited_expr);
} }
return visited_expr; return expr;
} }
Expr AddCompilerEnd_(const Expr& expr) { /*! \brief This function adds compiler begins to nodes that
const auto* end_op = * don't belong to a region already (default).
runtime::Registry::Get("relay.op.annotation._make.compiler_end"); * \param expr The expression to add a compiler begin to.
CHECK(end_op); * \return expr The expression with or without a compiler begin added.
Expr end = (*end_op)(expr, target_); */
return end; Expr InsertBegin(const Expr& expr) {
const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(expr, target_);
annotated_nodes_.insert(begin);
return begin;
} }
Expr VisitExpr_(const CallNode* call) final { Expr VisitExpr_(const CallNode* cn) final {
auto op_node = call->op.as<OpNode>(); auto region = regions_->GetRegion(GetRef<Call>(cn));
auto ret = GetRef<Call>(call); auto new_e = ExprMutator::VisitExpr_(cn);
Call call = Downcast<Call>(new_e);
// Add compiler ends if the parent isn't annotated
Array<Expr> args; Array<Expr> args;
// Add compiler ends if the parent is supported
for (auto arg : call->args) { for (auto arg : call->args) {
args.push_back(AddCompilerEnd(arg)); args.push_back(InsertEnd(arg));
} }
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { Expr updated_call = Call(call->op, args, call->attrs);
// Skip annotatation ops, only add default compiler to actual compute nodes if (!region.defined()) {
// if the current node does not belong to annotated region
auto region = regions_->GetRegion(ret); // annotate the all incoming edges (args)
if (!region.defined()) { // with "default" compiler_begin annotations.
// if the current node does not belong to annotated region Array<Expr> compiler_begins;
// annotate the all incoming edges (args) for (auto arg : args) {
// with "default" compile_begin and compiler_end annotations. compiler_begins.push_back(InsertBegin(arg));
tvm::Array<tvm::relay::Expr> compiler_begins;
for (auto arg : args) {
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(arg, target_);
compiler_begins.push_back(begin);
}
Expr update_call = Call(call->op, compiler_begins, call->attrs);
return update_call;
} }
updated_call = Call(call->op, compiler_begins, call->attrs);
} else {
annotated_nodes_.insert(updated_call);
} }
return Call(call->op, args, call->attrs); return updated_call;
}; };
Expr VisitExpr_(const TupleNode *op) { Expr VisitExpr_(const TupleNode* op) {
auto region = regions_->GetRegion(GetRef<Tuple>(op));
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto tup = Downcast<Tuple>(new_e); Tuple tup = Downcast<Tuple>(new_e);
Array<Expr> new_fields;
Array<Expr> fields;
for (auto field : tup->fields) { for (auto field : tup->fields) {
new_fields.push_back(AddCompilerEnd(field)); fields.push_back(InsertEnd(field));
} }
return Tuple(new_fields);
Expr updated_tuple = Tuple(fields);
if (!region.defined()) {
Array<Expr> compiler_begins;
for (const auto& field : fields) {
compiler_begins.push_back(InsertBegin(field));
}
updated_tuple = Tuple(compiler_begins);
} else {
annotated_nodes_.insert(updated_tuple);
}
return updated_tuple;
} }
Expr VisitExpr_(const TupleGetItemNode *op) { Expr VisitExpr_(const TupleGetItemNode* op) {
auto region = regions_->GetRegion(GetRef<TupleGetItem>(op));
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto get = Downcast<TupleGetItem>(new_e); auto get = Downcast<TupleGetItem>(new_e);
return TupleGetItem(AddCompilerEnd(get->tuple), get->index);
auto updated_tuple = InsertEnd(get->tuple);
Expr updated_get = TupleGetItem(updated_tuple, get->index);
if (!region.defined()) {
updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index);
} else {
annotated_nodes_.insert(updated_get);
}
return updated_get;
} }
Expr VisitExpr_(const LetNode *op) { Expr VisitExpr_(const IfNode* op) {
auto region = regions_->GetRegion(GetRef<If>(op));
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto let = Downcast<Let>(new_e); auto iff = Downcast<If>(new_e);
return Let(
let->var, if (!region.defined()) {
AddCompilerEnd(let->value), return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)),
AddCompilerEnd(let->body)); InsertBegin(InsertEnd(iff->false_branch)));
} else {
Expr updated_iff =
If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch));
annotated_nodes_.insert(updated_iff);
return updated_iff;
}
} }
Expr VisitExpr_(const IfNode *op) { Expr VisitExpr_(const LetNode* op) {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto iff = Downcast<If>(new_e); auto let = Downcast<Let>(new_e);
return If( return Let(let->var, InsertEnd(let->value), InsertEnd(let->body));
AddCompilerEnd(iff->cond),
AddCompilerEnd(iff->true_branch),
AddCompilerEnd(iff->false_branch));
} }
Expr VisitExpr_(const RefCreateNode *op) { Expr VisitExpr_(const RefCreateNode* op) {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto create = Downcast<RefCreate>(new_e); auto create = Downcast<RefCreate>(new_e);
return RefCreate(AddCompilerEnd(create->value)); return RefCreate(InsertEnd(create->value));
} }
Expr VisitExpr_(const RefReadNode *op) { Expr VisitExpr_(const RefReadNode* op) {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto read = Downcast<RefRead>(new_e); auto read = Downcast<RefRead>(new_e);
return RefRead(AddCompilerEnd(read->ref)); return RefRead(InsertEnd(read->ref));
} }
Expr VisitExpr_(const RefWriteNode *op) { Expr VisitExpr_(const RefWriteNode* op) {
auto new_e = ExprMutator::VisitExpr_(op); auto new_e = ExprMutator::VisitExpr_(op);
auto write = Downcast<RefWrite>(new_e); auto write = Downcast<RefWrite>(new_e);
return RefWrite( return RefWrite(InsertEnd(write->ref), InsertEnd(write->value));
AddCompilerEnd(write->ref),
AddCompilerEnd(write->value));
} }
private: private:
AnnotatedRegionSet regions_; AnnotatedRegionSet regions_;
const std::string target_ = "default"; const std::string target_ = "default";
Function func_; Function func_;
std::unordered_set<Expr, ObjectHash, ObjectEqual> annotated_nodes_;
}; };
class MergeAnnotations : public ExprMutator { class MergeAnnotations : public ExprMutator {
...@@ -213,6 +229,14 @@ class MergeAnnotations : public ExprMutator { ...@@ -213,6 +229,14 @@ class MergeAnnotations : public ExprMutator {
explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
Expr VisitExpr_(const CallNode* call) final { Expr VisitExpr_(const CallNode* call) final {
// remove 'default' annotations
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
}
// Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a
// compiler end and they're both in the same region.
if (call->op == compiler_begin_op) { if (call->op == compiler_begin_op) {
if (call->args[0]->IsInstance<CallNode>()) { if (call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]); auto arg = Downcast<Call>(call->args[0]);
...@@ -220,7 +244,7 @@ class MergeAnnotations : public ExprMutator { ...@@ -220,7 +244,7 @@ class MergeAnnotations : public ExprMutator {
auto region1 = regions_->GetRegion(GetRef<Call>(call)); auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg); auto region2 = regions_->GetRegion(arg);
if (region1 == region2) { if (region1 == region2) {
return ExprMutator::VisitExpr(arg->args[0]); return VisitExpr(arg->args[0]);
} }
} }
} }
...@@ -242,7 +266,6 @@ class RegionMerger : public ExprVisitor { ...@@ -242,7 +266,6 @@ class RegionMerger : public ExprVisitor {
// set the region target // set the region target
auto compiler_attrs = call->attrs.as<CompilerAttrs>(); auto compiler_attrs = call->attrs.as<CompilerAttrs>();
region_targets_[region->GetID()] = compiler_attrs->compiler; region_targets_[region->GetID()] = compiler_attrs->compiler;
std::vector<AnnotatedRegion> mergeable_regions;
// first look at the region args to determine the parent regions // first look at the region args to determine the parent regions
for (const auto& arg : region->GetInputs()) { for (const auto& arg : region->GetInputs()) {
// all args should be begin annotations // all args should be begin annotations
...@@ -256,14 +279,21 @@ class RegionMerger : public ExprVisitor { ...@@ -256,14 +279,21 @@ class RegionMerger : public ExprVisitor {
if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
VisitExpr(begin->args[0]); VisitExpr(begin->args[0]);
} }
}
// get the mergeable regions now all the parents have been visited
std::vector<AnnotatedRegion> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]);
if (!parent_region.defined()) continue;
mergeable_regions.push_back(parent_region); mergeable_regions.push_back(parent_region);
} }
auto& region_restrictions = region_restrictions_[region->GetID()]; auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) { for (const auto& parent_region : mergeable_regions) {
// add all the parent restrictions to the current region // add all the parent restrictions to the current region
auto parent_restrictions = region_restrictions_[parent_region->GetID()]; auto parent_restrictions = region_restrictions_[parent_region->GetID()];
region_restrictions.insert(parent_restrictions.begin(), region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
parent_restrictions.end());
} }
for (const auto& parent_region : mergeable_regions) { for (const auto& parent_region : mergeable_regions) {
bool merged = false; bool merged = false;
...@@ -273,7 +303,8 @@ class RegionMerger : public ExprVisitor { ...@@ -273,7 +303,8 @@ class RegionMerger : public ExprVisitor {
if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) { if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) {
// merge the parent region into the current region // merge the parent region into the current region
regions_->MergeRegions(parent_region, region); regions_->MergeRegions(parent_region, region);
// update the restrictions of all other regions to reflect the change in id // update the restrictions of all other regions to reflect the
// change in id
for (const auto& r : regions_) { for (const auto& r : regions_) {
auto& restrictions = region_restrictions_[r->GetID()]; auto& restrictions = region_restrictions_[r->GetID()];
if (restrictions.find(parent_region->GetID()) != restrictions.end()) { if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
...@@ -284,9 +315,9 @@ class RegionMerger : public ExprVisitor { ...@@ -284,9 +315,9 @@ class RegionMerger : public ExprVisitor {
merged = true; merged = true;
} }
} }
// if the parent wasn't merged, add it as a restriction to the current region // if the parent wasn't merged, add it as a restriction to the current
if (!merged) // region
region_restrictions.insert(parent_region->GetID()); if (!merged) region_restrictions.insert(parent_region->GetID());
} }
merged_regions_.insert(region->GetID()); merged_regions_.insert(region->GetID());
} }
...@@ -300,15 +331,14 @@ class RegionMerger : public ExprVisitor { ...@@ -300,15 +331,14 @@ class RegionMerger : public ExprVisitor {
std::map<int, std::string> region_targets_; std::map<int, std::string> region_targets_;
}; };
Expr MergeCompilerRegions(const Expr& expr) { Expr MergeCompilerRegions(const Expr& expr) {
// Annotate all the nodes that aren't annotated as 'default'. // Annotate all the nodes that aren't annotated as 'default'.
AnnotateRestDefault anno_default(expr); AnnotateRestDefault anno_default(expr);
auto expr_all_annotated = anno_default.Annotate(expr); auto expr_all_annotated = anno_default.Annotate(expr);
// Create regions using the annotations. // Create regions using the annotations.
AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr_all_annotated, AnnotatedRegionSet regions =
compiler_begin_op, compiler_end_op); AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op);
// By now, all the nodes have some sort of annotation. // By now, all the nodes have some sort of annotation.
// Region merger is an ExprVisitor that will update the // Region merger is an ExprVisitor that will update the
...@@ -336,7 +366,7 @@ Pass MergeCompilerRegions() { ...@@ -336,7 +366,7 @@ Pass MergeCompilerRegions() {
} }
TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
.set_body_typed(transform::MergeCompilerRegions); .set_body_typed(transform::MergeCompilerRegions);
} // namespace transform } // namespace transform
......
...@@ -113,7 +113,8 @@ def test_extern_dnnl(): ...@@ -113,7 +113,8 @@ def test_extern_dnnl():
padding=(1, 1), padding=(1, 1),
groups=32) groups=32)
end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl") end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
begin2 = relay.annotation.compiler_begin(end0, "dnnl") end1 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
begin2 = relay.annotation.compiler_begin(end1, "dnnl")
begin3 = relay.annotation.compiler_begin(end0, "dnnl") begin3 = relay.annotation.compiler_begin(end0, "dnnl")
begin4 = relay.annotation.compiler_begin(weight1, "dnnl") begin4 = relay.annotation.compiler_begin(weight1, "dnnl")
depthwise_conv2d_2 = relay.nn.conv2d(begin3, depthwise_conv2d_2 = relay.nn.conv2d(begin3,
...@@ -121,11 +122,11 @@ def test_extern_dnnl(): ...@@ -121,11 +122,11 @@ def test_extern_dnnl():
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
groups=32) groups=32)
end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl") end2 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
begin5 = relay.annotation.compiler_begin(end1, "dnnl") begin5 = relay.annotation.compiler_begin(end2, "dnnl")
out = relay.add(begin2, begin5) out = relay.add(begin2, begin5)
end2 = relay.annotation.compiler_end(out, "dnnl") end3 = relay.annotation.compiler_end(out, "dnnl")
f = relay.Function([data, weight1], end2) f = relay.Function([data, weight1], end3)
mod = tvm.IRModule.from_expr(f) mod = tvm.IRModule.from_expr(f)
return mod return mod
...@@ -137,7 +138,7 @@ def test_extern_dnnl(): ...@@ -137,7 +138,7 @@ def test_extern_dnnl():
mod = annotated(dtype, ishape, w1shape) mod = annotated(dtype, ishape, w1shape)
mod = transform.AnnotateTarget("dnnl")(mod) mod = transform.AnnotateTarget("dnnl")(mod)
ref_mod = expected(dtype, ishape, w1shape) ref_mod = expected(dtype, ishape, w1shape)
# tvm.ir.assert_structural_equal(mod, ref_mod) tvm.ir.assert_structural_equal(mod, ref_mod)
def test_run(): def test_run():
if not tvm.get_global_func("relay.ext.dnnl", True): if not tvm.get_global_func("relay.ext.dnnl", True):
......
...@@ -66,13 +66,10 @@ def test_diamond_graph_fanouts(): ...@@ -66,13 +66,10 @@ def test_diamond_graph_fanouts():
O_2 = relay.nn.relu(O_1) O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, "test") ce_3 = compiler_end(O_2, "test")
cb_x = compiler_begin(ce_2, "default") X = relay.tanh(ce_2)
X = relay.tanh(cb_x)
ce_x1 = compiler_end(X, "default")
ce_x2 = compiler_end(X, "default")
cb_3 = compiler_begin(ce_3, "test") cb_3 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(ce_x1, "test") cb_4 = compiler_begin(X, "test")
O_3 = relay.add(cb_3, cb_4) O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test") ce_4 = compiler_end(O_3, "test")
...@@ -162,36 +159,28 @@ def test_example_graph(): ...@@ -162,36 +159,28 @@ def test_example_graph():
node1 = relay.add(begin2, begin3) node1 = relay.add(begin2, begin3)
node2 = relay.add(node0, node1) node2 = relay.add(node0, node1)
begin4 = compiler_begin(in_5, "default") node3 = relay.subtract(in_5, in_6)
begin5 = compiler_begin(in_6, "default") node4 = relay.subtract(in_7, node3)
begin6 = compiler_begin(in_7, "default")
node3 = relay.subtract(begin4, begin5)
node4 = relay.subtract(begin6, node3)
end0 = compiler_end(node4, "default")
begin7 = compiler_begin(end0, "test")
begin8 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin7) begin4 = compiler_begin(node4, "test")
begin5 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin4)
end1 = compiler_end(node5, "test") end1 = compiler_end(node5, "test")
begin9 = compiler_begin(end1, "default") node6 = relay.subtract(in_8, end1)
begin10 = compiler_begin(in_8, "default")
node6 = relay.subtract(begin10, begin9)
end2 = compiler_end(node6, "default")
node7 = relay.add(begin8, node5) node7 = relay.add(begin5, node5)
end3 = compiler_end(node7, "test") end2 = compiler_end(node7, "test")
begin11 = compiler_begin(end3, "test") begin6 = compiler_begin(end2, "test")
begin12 = compiler_begin(end2, "test") begin7 = compiler_begin(node6, "test")
node8 = relay.add(begin12, begin11) node8 = relay.add(begin7, begin6)
begin13 = compiler_begin(in_10, "test") begin8 = compiler_begin(in_10, "test")
node9 = relay.add(begin13, node8) node9 = relay.add(begin8, node8)
end4 = compiler_end(node9, "test") end3 = compiler_end(node9, "test")
f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4) f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end3)
mod = tvm.IRModule.from_expr(f) mod = tvm.IRModule.from_expr(f)
return mod return mod
......
...@@ -725,12 +725,12 @@ def test_multiple_outputs(): ...@@ -725,12 +725,12 @@ def test_multiple_outputs():
mod = tvm.IRModule() mod = tvm.IRModule()
# function 0 # function 0
data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32")) data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32")) weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32")) bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32")) bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32"))
bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32")) bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32"))
bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32")) bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32"))
conv_o = relay.nn.conv2d( conv_o = relay.nn.conv2d(
data=data, data=data,
...@@ -743,7 +743,7 @@ def test_multiple_outputs(): ...@@ -743,7 +743,7 @@ def test_multiple_outputs():
bn_var) bn_var)
relu_o = relay.nn.relu(bn_o[0]) relu_o = relay.nn.relu(bn_o[0])
tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o))
func0 = relay.Function([data, weight, bn_gamma, bn_beta, func0 = relay.Function([data, weight, bn_gamma, bn_beta,
bn_mean, bn_var], tuple_o) bn_mean, bn_var], tuple_o)
...@@ -752,8 +752,8 @@ def test_multiple_outputs(): ...@@ -752,8 +752,8 @@ def test_multiple_outputs():
func0 = func0.with_attr("Compiler", func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target")) tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol", func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_0")) tvm.tir.StringImm("test_target_2"))
gv0 = relay.GlobalVar("test_target_0") gv0 = relay.GlobalVar("test_target_2")
mod[gv0] = func0 mod[gv0] = func0
# body # body
...@@ -765,9 +765,9 @@ def test_multiple_outputs(): ...@@ -765,9 +765,9 @@ def test_multiple_outputs():
bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32")) bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var) f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
f0_relu_o = relay.TupleGetItem(f0_o, 0) f0_relu_o = relay.TupleGetItem(f0_o, 2)
f0_mean_o = relay.TupleGetItem(f0_o, 1) f0_mean_o = relay.TupleGetItem(f0_o, 1)
f0_var_o = relay.TupleGetItem(f0_o, 2) f0_var_o = relay.TupleGetItem(f0_o, 0)
f0_mean_abs = relay.abs(f0_mean_o) f0_mean_abs = relay.abs(f0_mean_o)
f0_var_abs = relay.abs(f0_var_o) f0_var_abs = relay.abs(f0_var_o)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment