Unverified Commit f506c8b1 by Cody Yu Committed by GitHub

[BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)

* add target to region

* refactor annotate_target

* Make all unit test working

* quick fix

* enable BN, unit test failed

* Fix vm test, unit test. Refactor annotate_target a bit.

* quick fix fusion

* revert fusion change

* style fix

* Refactor merge region pass

* format

* minor fix

* Skip e2e test

* lint

* support AnnotateTarget multiple runs

* Add HasAttr and revert DNNL codegen

* address comment

Co-authored-by: Zhi Chen <chzhi@amazon.com>
parent 5795539c
...@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True): ...@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d") _register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense") _register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu") _register_external_op_helper("nn.relu")
_register_external_op_helper("add") _register_external_op_helper("add")
_register_external_op_helper("subtract") _register_external_op_helper("subtract")
_register_external_op_helper("multiply") _register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
    """Decide whether nn.batch_norm is offloaded to the external DNNL codegen.

    FIXME(@zhiics, @comaniac): Offloading is turned off for now because
    multiple outputs are not supported.
    """
    return False
...@@ -587,14 +587,14 @@ def PartitionGraph(): ...@@ -587,14 +587,14 @@ def PartitionGraph():
def AnnotateTarget(target): def AnnotateTarget(targets):
"""Annotate ops in an expression with a provided compiler/target and then """Annotate ops in an expression with a provided compiler/target and then
use it for codegen. use it for codegen.
Parameters Parameters
---------- ----------
target : String targets : str or List[str]
The target compiler used for codegen. The list of target compilers used for codegen.
Returns Returns
------- -------
...@@ -602,7 +602,9 @@ def AnnotateTarget(target): ...@@ -602,7 +602,9 @@ def AnnotateTarget(target):
The annotated pass that wraps ops with subgraph_start and The annotated pass that wraps ops with subgraph_start and
subgraph_end. subgraph_end.
""" """
return _ffi_api.AnnotateTarget(target) if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
def Inline(): def Inline():
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <tvm/runtime/container.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -31,7 +32,7 @@ namespace relay { ...@@ -31,7 +32,7 @@ namespace relay {
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) { for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) { if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
return candidate; return candidate;
} }
} }
...@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, ...@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
} }
// Merge src to dest and erase src. // Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end()); dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
for (const auto& input : src->ins) { for (const auto& input : src->ins_) {
dest->ins.push_back(input); dest->ins_.push_back(input);
} }
for (const auto& output : src->outs) { for (const auto& output : src->outs_) {
dest->outs.push_back(output); dest->outs_.push_back(output);
} }
// if any of the outputs of src are inputs of dest, they become internal nodes // if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs // so remove them from outs
std::vector<Expr> ins_to_remove; std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) { for (const auto& input : dest->ins_) {
auto call = Downcast<Call>(input); auto call = Downcast<Call>(input);
auto it = src->nodes.find(call->args[0]); auto it = src->nodes_.find(call->args[0]);
if (it != src->nodes.end()) { if (it != src->nodes_.end()) {
dest->outs.remove(*it); dest->outs_.remove(*it);
ins_to_remove.push_back(input); ins_to_remove.push_back(input);
} }
} }
for (const auto& input : ins_to_remove) { for (const auto& input : ins_to_remove) {
dest->ins.remove(input); dest->ins_.remove(input);
} }
regions_.erase(src); regions_.erase(src);
} }
...@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) ...@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
if (src.defined()) { if (src.defined()) {
MergeRegions(src, dest); MergeRegions(src, dest);
} else { } else {
dest->nodes.insert(expr); dest->nodes_.insert(expr);
} }
} }
/*! \brief Add a new region to the set, give it the next region ID and the given target. */
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
  const auto inserted = regions_.emplace(AnnotatedRegion());
  const AnnotatedRegion& region = *inserted.first;
  region->id_ = region_id_++;
  region->target_ = target;
  return region;
}
class AnnotatedRegionSet::Creator : public ExprVisitor { class AnnotatedRegionSet::Creator : public ExprVisitor {
public: public:
Creator(const Op& region_begin_op, const Op& region_end_op) : Creator(const Op& region_begin_op, const Op& region_end_op)
begin_op_(region_begin_op), end_op_(region_end_op) {} : begin_op_(region_begin_op), end_op_(region_end_op) {}
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
void VisitExpr_(const CallNode* call) { void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>(); auto op_node = call->op.as<OpNode>();
...@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { ...@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
<< "Cannot find the corresponding region for start annotation:\n" << "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false)); << AsText(GetRef<Call>(call), false));
} }
region->ins.push_back(GetRef<Call>(call)); region->ins_.push_back(GetRef<Call>(call));
} else { } else {
CHECK_EQ(call->op, end_op_); CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument. // The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U); CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Check if the argument already belongs to a region // Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]); auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) { if (!region.defined()) {
region = region_set_->MakeRegion(); // Create a new region if the argument is not belonged to any regions yet.
region->nodes.insert(call->args[0]); region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
} else {
// If the argument is belonged to a region, it must have the same target.
// Otherwise we should see a region_begin op.
CHECK_EQ(region->GetTarget(), target);
} }
region->nodes.insert(GetRef<Call>(call)); region->nodes_.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call)); region->outs_.push_back(GetRef<Call>(call));
} }
ExprVisitor::VisitExpr_(call); ExprVisitor::VisitExpr_(call);
} }
/*! \brief Visit \p expr to populate the region set, then hand the set over to the caller. */
AnnotatedRegionSet Create(const Expr& expr) {
  this->VisitExpr(expr);
  return std::move(region_set_);
}
void VisitExpr_(const TupleNode* op) { void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op)); auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) { if (region.defined()) {
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <string> #include <string>
...@@ -49,33 +50,39 @@ class AnnotatedRegionSet; ...@@ -49,33 +50,39 @@ class AnnotatedRegionSet;
class AnnotatedRegionNode : public Object { class AnnotatedRegionNode : public Object {
public: public:
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("id", &id); v->Visit("id", &id_);
Array<Expr> nodes_array(nodes.begin(), nodes.end()); v->Visit("target", &target_);
Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
v->Visit("nodes", &nodes_array); v->Visit("nodes", &nodes_array);
Array<Expr> args_array(ins.begin(), ins.end()); Array<Expr> args_array(ins_.begin(), ins_.end());
v->Visit("args", &args_array); v->Visit("args", &args_array);
Array<Expr> rets_array(outs.begin(), outs.end()); Array<Expr> rets_array(outs_.begin(), outs_.end());
v->Visit("rets", &rets_array); v->Visit("rets", &rets_array);
} }
/*! \brief Get the region ID. */ /*! \brief Get the region ID. */
int GetID() const { int GetID() const {
return id; return id_;
}
/*! \brief Get the region target. */
std::string GetTarget() const {
return target_;
} }
/*! \brief Get the region's inputs. */ /*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const { std::list<Expr> GetInputs() const {
return ins; return ins_;
} }
/*! \brief Get the region's outputs. */ /*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const { std::list<Expr> GetOutputs() const {
return outs; return outs_;
} }
/*! \brief Get the region's nodes. */ /*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const { std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
return nodes; return nodes_;
} }
static constexpr const char* _type_key = "relay.AnnotatedRegion"; static constexpr const char* _type_key = "relay.AnnotatedRegion";
...@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object { ...@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object {
protected: protected:
/*! \brief The region ID. */ /*! \brief The region ID. */
int id{-1}; int id_{-1};
/*! \brief The target for this region. */
std::string target_ = "default";
/*! \brief The inputs to this region. */ /*! \brief The inputs to this region. */
std::list<Expr> ins; std::list<Expr> ins_;
/*! \brief The outputs of this region */ /*! \brief The outputs of this region */
std::list<Expr> outs; std::list<Expr> outs_;
/*! \brief Nodes in this region. */ /*! \brief Nodes in this region. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes; std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;
friend class AnnotatedRegionSet; friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode; friend class AnnotatedRegionSetNode;
...@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object { ...@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
void AddToRegion(AnnotatedRegion dest, const Expr& expr); void AddToRegion(AnnotatedRegion dest, const Expr& expr);
/*! /*!
* \brief Make a new region. * \brief Make a new region for a target.
* *
* \return The new region. * \return The new region.
*/ */
AnnotatedRegion MakeRegion(); AnnotatedRegion MakeRegion(const std::string& target);
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_; std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */ /*! \brief The next region ID to assign. */
......
...@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
void VisitExpr_(const TupleGetItemNode* op) final { void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple); // Do nothing
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
} }
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream; std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID // Args: ID
std::vector<std::string> args; std::vector<std::string> args;
...@@ -103,38 +96,30 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -103,38 +96,30 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
} }
// Analyze the output buffers // Analyze the output buffer
std::vector<Type> out_types; auto type_node = call->checked_type().as<TensorTypeNode>();
if (call->checked_type()->IsInstance<TupleTypeNode>()) { CHECK(type_node);
auto type_node = call->checked_type().as<TupleTypeNode>(); const auto& dtype = GetDtypeString(type_node);
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}
out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
std::string out = "buf_" + std::to_string(buf_idx_++); std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(out_type); auto out_shape = GetShape(call->checked_type());
int out_size = 1; int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) { for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i]; out_size *= out_shape[i];
} }
this->PrintIndents(); this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str()); buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out; decl_stream << ", " << out;
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
// Update output buffer // Update output buffer
out_.clear();
Output output; Output output;
output.name = out; output.name = out;
output.dtype = dtype; output.dtype = dtype;
...@@ -143,14 +128,6 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -143,14 +128,6 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
out_.push_back(output); out_.push_back(output);
} }
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
}
std::string JIT(void) { std::string JIT(void) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_); return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
} }
......
...@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this // Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation // pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op` // pass. This is because memory allocation pass will insert `invoke_tvm_op`
...@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen. // external codegen.
pass_seqs.push_back(transform::Inline()); pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions. // Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
...@@ -19,132 +19,203 @@ ...@@ -19,132 +19,203 @@
/*! /*!
* \file src/relay/transforms/annotate_target.cc * \file src/relay/transforms/annotate_target.cc
* \brief Wraps a call with compiler_begin and compiler_end to indicate that * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
* the op of this call node will use external compiler. * this expr should be handled by the external compiler.
*/ */
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace annotate_target { namespace annotate_target {
// Cache compiler_begin op for equivalence check.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
// A helper class to insert annotation boundaries for a program region that will // A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler. // be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator { class AnnotateTargetWrapper : public ExprMutator {
public: public:
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {} explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {}
/*!
 * \brief Wrap every argument with a compiler_end followed by a compiler_begin annotation.
 *
 * The compiler_end carries the target of the argument while the compiler_begin carries the
 * target of the consuming node. If \p target is empty and all arguments go to the same
 * target, that target is used for the node; as soon as two arguments disagree the node falls
 * back to "default". All argument exprs produced by this pass must already be recorded in
 * op_expr_to_target_ before calling this function.
 *
 * \param args An array of arguments of the given node.
 * \param target The target of the current node, or "" to derive it from the arguments.
 * \return A pair of the resolved target and the annotated argument expressions.
 */
std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
                                                 const std::string& target = "") {
  std::string ref_target = "";
  Array<Expr> compiler_ends;
  for (auto arg : args) {
    // Arguments without a known target (e.g., input vars) are treated as "default".
    std::string arg_target = "default";
    const CallNode* call = arg.as<CallNode>();
    if (call && call->op == compiler_begin_op) {
      // Argument is already compiler begin node meaning that this is not the first time
      // running this pass, so we simply remove it and will add a new one later.
      CHECK_EQ(call->args.size(), 1U);
      const CallNode* end = call->args[0].as<CallNode>();
      // The expr wrapped by compiler_begin is not necessarily a call, so check for null
      // before looking at its op.
      if (end && end->op == compiler_end_op) {
        arg_target = end->attrs.as<CompilerAttrs>()->compiler;
      }
      compiler_ends.push_back(call->args[0]);
    } else {
      auto it = op_expr_to_target_.find(arg);
      if (it != op_expr_to_target_.end()) {
        arg_target = it->second;
        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
      } else {
        // Input vars.
        compiler_ends.push_back(arg);
      }
    }

    // Maintain reference target in case the target of the current node is unassigned.
    if (ref_target == "") {
      ref_target = arg_target;
    } else if (ref_target != arg_target) {
      ref_target = "default";
    }
  }

  // Determine compiler begin target.
  std::string op_target = (target == "") ? ref_target : target;

  Array<Expr> compiler_begins;
  for (const auto& end : compiler_ends) {
    compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
  }

  return {op_target, compiler_begins};
}
/*!
 * \brief Wrap an expression with an annotation call created by the given packed function.
 *
 * \param expr The expression to be wrapped.
 * \param target The target (compiler) name passed to the annotation constructor.
 * \param ann_op The packed function that creates the annotation call.
 * \return The annotation call, carrying the same checked type as \p expr.
 */
Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
  // The pointer comes from a registry lookup; fail with a message instead of dereferencing
  // a null pointer when the annotation constructor has not been registered.
  CHECK(ann_op) << "The annotation op constructor is not registered";
  Expr new_op = (*ann_op)(expr, target);
  new_op->checked_type_ = expr->checked_type_;
  return new_op;
}
/*!
 * \brief Decide the target of a call node and annotate its arguments accordingly.
 *
 * An existing compiler_begin call is bypassed and an existing compiler_end call is re-created
 * with the target recorded for its input, so the pass can run on an already annotated graph.
 * Any other call is assigned the first entry of its supported target list, which always ends
 * with "default".
 *
 * \param cn The call node being visited.
 * \return The rewritten call whose arguments are wrapped with annotations.
 */
Expr VisitExpr_(const CallNode* cn) final {
  // Supported targets for this node. The order implies the priority.
  std::vector<std::string> supported_targets;

  auto op_node = cn->op.as<OpNode>();

  // This graph has annotations, meaning that this is not the first time running this pass.
  if (op_node && cn->op == compiler_begin_op) {
    // Bypass compiler begin due to lack of target information. It will be processed
    // when the following op handling arguments.
    CHECK_EQ(cn->args.size(), 1U);
    return VisitExpr(cn->args[0]);
  } else if (op_node && cn->op == compiler_end_op) {
    // Override compiler end with the new target.
    CHECK_EQ(cn->args.size(), 1U);
    auto input_expr = VisitExpr(cn->args[0]);
    CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
    return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
  }

  // Peek the first argument. If it is compiler begin then this node had annotated by
  // another target before, so we also consider that target as a supported target.
  // A call may have no arguments at all, so only peek when there is something to look at.
  if (cn->args.size() > 0) {
    const CallNode* first_arg_call = cn->args[0].as<CallNode>();
    if (first_arg_call && first_arg_call->op == compiler_begin_op) {
      std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
      if (arg_target != "default") {
        supported_targets.push_back(arg_target);
      }
    }
  }

  // Check which targets this op can be offloaded.
  if (op_node) {
    // TVM operators: Check target specific op checking function and add to supported_targets
    // if it is supported.
    Op op = Downcast<Op>(cn->op);
    CHECK(op.defined());
    for (const auto& target : this->targets_) {
      const std::string attr_name = "target." + std::string(target);
      if (!Op::HasAttr(attr_name)) {
        continue;
      }
      auto fannotate = Op::GetAttr<FTVMAnnotateTarget>(attr_name);
      if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) {
        supported_targets.push_back(target);
      }
    }
  } else if (cn->op->IsInstance<FunctionNode>()) {
    // Composite function: Add the target of a composite function to supported_targets
    // if it is in the target list.
    Function func = Downcast<Function>(cn->op);
    CHECK(func.defined());
    auto comp_name = func->GetAttr<String>(attr::kComposite);
    if (comp_name.defined()) {
      std::string comp_name_str = comp_name;
      size_t i = comp_name_str.find('.');
      if (i != std::string::npos) {
        std::string comp_target = comp_name_str.substr(0, i);
        for (const auto& target : this->targets_) {
          if (std::string(target) == comp_target) {
            supported_targets.push_back(comp_target);
            break;
          }
        }
      }
    }
  }
  supported_targets.push_back("default");  // Make default as the last option.

  // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
  // the highest priority, but we should preserve all supported targets so that
  // we can make a better decision.
  std::string target = supported_targets[0];

  // Visit and mutate arguments after the target of this op has been determined.
  auto new_call = Downcast<Call>(ExprMutator::VisitExpr_(cn));

  // Add annotations to each arg.
  auto target_n_args = AnnotateArgs(new_call->args, target);
  Array<Expr> compiler_begins = std::get<1>(target_n_args);
  Call call = Call(new_call->op, compiler_begins, new_call->attrs);
  call->checked_type_ = cn->checked_type_;

  // Update the target map.
  op_expr_to_target_[call] = target;

  return std::move(call);
}
/*! \brief Annotate each field of a tuple and record the target resolved from the fields. */
Expr VisitExpr_(const TupleNode* op) final {
  auto mutated = Downcast<Tuple>(ExprMutator::VisitExpr_(op));

  // first: resolved target, second: annotated fields.
  auto annotated = AnnotateArgs(mutated->fields);
  auto result = Tuple(annotated.second);
  op_expr_to_target_[result] = annotated.first;
  return std::move(result);
}
/*! \brief Annotate the tuple operand of a TupleGetItem and record the resolved target. */
Expr VisitExpr_(const TupleGetItemNode* op) final {
  auto mutated = Downcast<TupleGetItem>(ExprMutator::VisitExpr_(op));

  // first: resolved target, second: the single annotated operand.
  auto annotated = AnnotateArgs(Array<Expr>({mutated->tuple}));
  auto result = TupleGetItem(annotated.second[0], mutated->index);
  op_expr_to_target_[result] = annotated.first;
  return std::move(result);
}
Expr VisitExpr_(const FunctionNode* fn) { Expr VisitExpr_(const FunctionNode* fn) final {
Function func; Function func;
Expr new_body; Expr new_body;
// don't step into composite functions // don't step into composite functions
...@@ -154,84 +225,93 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -154,84 +225,93 @@ class AnnotateTargetWrapper : public ExprMutator {
} else { } else {
auto new_e = ExprMutator::VisitExpr_(fn); auto new_e = ExprMutator::VisitExpr_(fn);
func = Downcast<Function>(new_e); func = Downcast<Function>(new_e);
new_body = InsertEnd(func->body); new_body = func->body;
if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
op_expr_to_target_[new_body] = op_expr_to_target_[func->body];
} }
}
return Function( return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
func->params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
} }
/*! \brief Annotate the value and the body of a let binding and record the resolved target. */
Expr VisitExpr_(const LetNode* op) final {
  auto mutated = Downcast<Let>(ExprMutator::VisitExpr_(op));

  // first: resolved target, second: {annotated value, annotated body}.
  auto annotated = AnnotateArgs({mutated->value, mutated->body});
  auto result = Let(mutated->var, annotated.second[0], annotated.second[1]);
  op_expr_to_target_[result] = annotated.first;
  return std::move(result);
}
/*! \brief Annotate the condition and both branches of an if and record the resolved target. */
Expr VisitExpr_(const IfNode* op) final {
  auto mutated = Downcast<If>(ExprMutator::VisitExpr_(op));

  // first: resolved target, second: {cond, true branch, false branch} after annotation.
  auto annotated =
      AnnotateArgs({mutated->cond, mutated->true_branch, mutated->false_branch});
  CHECK_EQ(annotated.second.size(), 3U);
  auto result = If(annotated.second[0], annotated.second[1], annotated.second[2]);
  op_expr_to_target_[result] = annotated.first;
  return std::move(result);
}
/*! \brief Annotate the initial value of a RefCreate and record the resolved target. */
Expr VisitExpr_(const RefCreateNode* op) final {
  auto mutated = Downcast<RefCreate>(ExprMutator::VisitExpr_(op));

  // first: resolved target, second: the single annotated operand.
  auto annotated = AnnotateArgs(Array<Expr>({mutated->value}));
  auto result = RefCreate(annotated.second[0]);
  op_expr_to_target_[result] = annotated.first;
  return std::move(result);
}
/*! \brief Annotate the reference operand of a RefRead and record the resolved target. */
Expr VisitExpr_(const RefReadNode* op) final {
  auto mutated = Downcast<RefRead>(ExprMutator::VisitExpr_(op));

  // first: resolved target, second: the single annotated operand.
  auto annotated = AnnotateArgs(Array<Expr>({mutated->ref}));
  auto result = RefRead(annotated.second[0]);
  op_expr_to_target_[result] = annotated.first;
  return std::move(result);
}
/*! \brief Annotate the reference and the value of a RefWrite and record the resolved target. */
Expr VisitExpr_(const RefWriteNode* op) final {
  auto mutated = Downcast<RefWrite>(ExprMutator::VisitExpr_(op));

  // first: resolved target, second: {annotated ref, annotated value}.
  auto annotated = AnnotateArgs(Array<Expr>({mutated->ref, mutated->value}));
  auto result = RefWrite(annotated.second[0], annotated.second[1]);
  op_expr_to_target_[result] = annotated.first;
  return std::move(result);
}
private: private:
std::string target_; /*! \brief The target backends for annotation. */
Array<runtime::String> targets_;
/*! \brief Maintain the decision of the target for each op expr. */
std::unordered_map<Expr, std::string, ObjectHash, ObjectEqual> op_expr_to_target_;
}; };
Expr AnnotateTarget(const Expr& expr, const std::string& target) { Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
return AnnotateTargetWrapper(target).Annotate(expr); return AnnotateTargetWrapper(targets).Mutate(expr);
} }
} // namespace annotate_target } // namespace annotate_target
namespace transform { namespace transform {
Pass AnnotateTarget(const std::string& target) { Pass AnnotateTarget(const Array<runtime::String>& targets) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target)); return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
}; };
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
{"InferType"}); {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
} }
TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget") TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget").set_body_typed(AnnotateTarget);
.set_body_typed(AnnotateTarget);
} // namespace transform } // namespace transform
......
...@@ -46,182 +46,98 @@ ...@@ -46,182 +46,98 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace partitioning { namespace merge_compiler_region {
// Cache compiler_begin and compiler_end annotation ops for equivalence check to // Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead. // reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*! \brief This is a pre-requisite pass to merge-supported pass. class RegionMerger : public ExprVisitor {
* The AnnotateRestDefault pass will put "default" Compiler Annotations to
* nodes that are not annotated already. This is there to ensure that the
* user will not leave un-annotated nodes when the MergeCompilerRegions pass is run.
* Why? Because, MergeCompilerRegions pass assumes every node to be annotated.
*/
class AnnotateRestDefault : public ExprMutator {
public: public:
explicit AnnotateRestDefault(const Expr& expr) { explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
}
Expr Annotate(const Expr& expr) {
// It's a function that is being passed on to annotate
func_ = Downcast<Function>(expr);
// Corner Case CC1 : If the last node does not belong void VisitExpr_(const CallNode* call) final {
// to a region node to add a compiler_end if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(func_->body); auto region = regions_->GetRegion(GetRef<Call>(call));
auto mutated_expr = this->VisitExpr(expr);
if (!region.defined()) {
func_ = Downcast<Function>(mutated_expr);
// CC1 : add that compiler end after mutation
auto body = InsertEnd(func_->body);
func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs());
return Downcast<Expr>(func_);
}
return mutated_expr;
}
/*! \brief This function adds compiler ends to nodes that // Skip this region if it has been merged to the other region.
* don't belong to a region already (default). if (merged_regions_.find(region->GetID()) != merged_regions_.end()) {
* \param expr The expression to add a compiler end to. return;
* \return expr The expression with or without a compiler end added.
*/
Expr InsertEnd(const Expr& expr) {
if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance<VarNode>() &&
!expr->IsInstance<ConstantNode>()) {
const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(expr, target_);
return end;
}
return expr;
} }
/*! \brief This function adds compiler begins to nodes that // Check the region target.
* don't belong to a region already (default). auto compiler_attrs = call->attrs.as<CompilerAttrs>();
* \param expr The expression to add a compiler begin to. CHECK_EQ(region->GetTarget(), compiler_attrs->compiler);
* \return expr The expression with or without a compiler begin added.
*/
Expr InsertBegin(const Expr& expr) {
const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(expr, target_);
annotated_nodes_.insert(begin);
return begin;
}
Expr VisitExpr_(const CallNode* cn) final { // Visit the unmerged parent regions.
auto region = regions_->GetRegion(GetRef<Call>(cn)); for (const auto& arg : region->GetInputs()) {
auto new_e = ExprMutator::VisitExpr_(cn); // Region inputs must be begin annotation, and the region of
Call call = Downcast<Call>(new_e); // the begin annotation's argument is the parent region.
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]);
// Add compiler ends if the parent isn't annotated // Skip this region if it has been merged.
Array<Expr> args; if (!parent_region.defined()) {
for (auto arg : call->args) { continue;
args.push_back(InsertEnd(arg)); } else if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
VisitExpr(begin->args[0]);
}
} }
Expr updated_call = Call(call->op, args, call->attrs); // Collect unmerged parent regions.
if (!region.defined()) { std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
// if the current node does not belong to annotated region for (const auto& arg : region->GetInputs()) {
// annotate the all incoming edges (args) auto begin = Downcast<Call>(arg);
// with "default" compiler_begin annotations. CHECK_EQ(begin->op, compiler_begin_op);
Array<Expr> compiler_begins; auto parent_region = regions_->GetRegion(begin->args[0]);
for (auto arg : args) { if (parent_region.defined()) {
compiler_begins.push_back(InsertBegin(arg)); mergeable_regions.insert(parent_region);
} }
updated_call = Call(call->op, compiler_begins, call->attrs);
} else {
annotated_nodes_.insert(updated_call);
} }
return updated_call;
};
Expr VisitExpr_(const TupleNode* op) { // Propagate all the parent restrictions to the current region.
auto region = regions_->GetRegion(GetRef<Tuple>(op)); auto& region_restrictions = region_restrictions_[region->GetID()];
auto new_e = ExprMutator::VisitExpr_(op); for (const auto& parent_region : mergeable_regions) {
Tuple tup = Downcast<Tuple>(new_e); auto parent_restrictions = region_restrictions_[parent_region->GetID()];
region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
Array<Expr> fields;
for (auto field : tup->fields) {
fields.push_back(InsertEnd(field));
} }
Expr updated_tuple = Tuple(fields); for (const auto& parent_region : mergeable_regions) {
if (!region.defined()) { // Skip the parent region with a different target.
Array<Expr> compiler_begins; if (parent_region->GetTarget() != compiler_attrs->compiler) {
for (const auto& field : fields) { region_restrictions.insert(parent_region->GetID());
compiler_begins.push_back(InsertBegin(field)); continue;
}
updated_tuple = Tuple(compiler_begins);
} else {
annotated_nodes_.insert(updated_tuple);
}
return updated_tuple;
} }
Expr VisitExpr_(const TupleGetItemNode* op) { // Skip the parent region if it is in the restriction set.
auto region = regions_->GetRegion(GetRef<TupleGetItem>(op)); if (region_restrictions.find(parent_region->GetID()) != region_restrictions.end()) {
auto new_e = ExprMutator::VisitExpr_(op); continue;
auto get = Downcast<TupleGetItem>(new_e);
auto updated_tuple = InsertEnd(get->tuple);
Expr updated_get = TupleGetItem(updated_tuple, get->index);
if (!region.defined()) {
updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index);
} else {
annotated_nodes_.insert(updated_get);
}
return updated_get;
} }
Expr VisitExpr_(const IfNode* op) { // Merge the parent region to the current one.
auto region = regions_->GetRegion(GetRef<If>(op)); regions_->MergeRegions(parent_region, region);
auto new_e = ExprMutator::VisitExpr_(op);
auto iff = Downcast<If>(new_e);
if (!region.defined()) { // Replace the parent region ID with the current region for all
return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)), // other regions' restriction sets.
InsertBegin(InsertEnd(iff->false_branch))); for (const auto& r : regions_) {
} else { auto& restrictions = region_restrictions_[r->GetID()];
Expr updated_iff = if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch)); restrictions.erase(parent_region->GetID());
annotated_nodes_.insert(updated_iff); restrictions.insert(region->GetID());
return updated_iff;
} }
} }
Expr VisitExpr_(const LetNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto let = Downcast<Let>(new_e);
return Let(let->var, InsertEnd(let->value), InsertEnd(let->body));
} }
merged_regions_.insert(region->GetID());
Expr VisitExpr_(const RefCreateNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto create = Downcast<RefCreate>(new_e);
return RefCreate(InsertEnd(create->value));
}
Expr VisitExpr_(const RefReadNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto read = Downcast<RefRead>(new_e);
return RefRead(InsertEnd(read->ref));
} }
ExprVisitor::VisitExpr_(call);
Expr VisitExpr_(const RefWriteNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
auto write = Downcast<RefWrite>(new_e);
return RefWrite(InsertEnd(write->ref), InsertEnd(write->value));
} }
private: private:
AnnotatedRegionSet regions_; AnnotatedRegionSet regions_;
const std::string target_ = "default"; std::unordered_set<int> merged_regions_;
Function func_; std::unordered_map<int, std::unordered_set<int>> region_restrictions_;
std::unordered_set<Expr, ObjectHash, ObjectEqual> annotated_nodes_;
}; };
class MergeAnnotations : public ExprMutator { class MergeAnnotations : public ExprMutator {
...@@ -229,16 +145,10 @@ class MergeAnnotations : public ExprMutator { ...@@ -229,16 +145,10 @@ class MergeAnnotations : public ExprMutator {
explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
Expr VisitExpr_(const CallNode* call) final { Expr VisitExpr_(const CallNode* call) final {
// remove 'default' annotations
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
}
// Merge annotations which are now internal to a region. // Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a // This happens if we see a compiler begin next to a
// compiler end and they're both in the same region. // compiler end and they're both in the same region.
if (call->op == compiler_begin_op) { if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
if (call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]); auto arg = Downcast<Call>(call->args[0]);
if (arg->op == compiler_end_op) { if (arg->op == compiler_end_op) {
auto region1 = regions_->GetRegion(GetRef<Call>(call)); auto region1 = regions_->GetRegion(GetRef<Call>(call));
...@@ -248,7 +158,6 @@ class MergeAnnotations : public ExprMutator { ...@@ -248,7 +158,6 @@ class MergeAnnotations : public ExprMutator {
} }
} }
} }
}
return ExprMutator::VisitExpr_(call); return ExprMutator::VisitExpr_(call);
} }
...@@ -256,114 +165,30 @@ class MergeAnnotations : public ExprMutator { ...@@ -256,114 +165,30 @@ class MergeAnnotations : public ExprMutator {
AnnotatedRegionSet regions_; AnnotatedRegionSet regions_;
}; };
class RegionMerger : public ExprVisitor {
public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) {
auto region = regions_->GetRegion(GetRef<Call>(call));
if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
// set the region target
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
region_targets_[region->GetID()] = compiler_attrs->compiler;
// first look at the region args to determine the parent regions
for (const auto& arg : region->GetInputs()) {
// all args should be begin annotations
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
// the arguments of the begin annotations will be in the parent regions
auto parent_region = regions_->GetRegion(begin->args[0]);
// if there is no parent region, move on
if (!parent_region.defined()) continue;
// merge the parent region if it hasn't been done already
if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
VisitExpr(begin->args[0]);
}
}
// get the mergeable regions now all the parents have been visited
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
auto parent_region = regions_->GetRegion(begin->args[0]);
if (!parent_region.defined()) continue;
mergeable_regions.insert(parent_region);
}
auto& region_restrictions = region_restrictions_[region->GetID()];
for (const auto& parent_region : mergeable_regions) {
// add all the parent restrictions to the current region
auto parent_restrictions = region_restrictions_[parent_region->GetID()];
region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
}
for (const auto& parent_region : mergeable_regions) {
bool merged = false;
// check the parent region has the same target
if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) {
// check the parent region isn't in the restrictions
if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) {
// merge the parent region into the current region
regions_->MergeRegions(parent_region, region);
// update the restrictions of all other regions to reflect the
// change in id
for (const auto& r : regions_) {
auto& restrictions = region_restrictions_[r->GetID()];
if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
restrictions.erase(parent_region->GetID());
restrictions.insert(region->GetID());
}
}
merged = true;
}
}
// if the parent wasn't merged, add it as a restriction to the current
// region
if (!merged) region_restrictions.insert(parent_region->GetID());
}
merged_regions_.insert(region->GetID());
}
ExprVisitor::VisitExpr_(call);
}
private:
AnnotatedRegionSet regions_;
std::unordered_set<int> merged_regions_;
std::map<int, std::unordered_set<int>> region_restrictions_;
std::map<int, std::string> region_targets_;
};
Expr MergeCompilerRegions(const Expr& expr) { Expr MergeCompilerRegions(const Expr& expr) {
// Annotate all the nodes that aren't annotated as 'default'.
AnnotateRestDefault anno_default(expr);
auto expr_all_annotated = anno_default.Annotate(expr);
// Create regions using the annotations. // Create regions using the annotations.
AnnotatedRegionSet regions = AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op);
// By now, all the nodes have some sort of annotation. // Analyze the graph to explore the opportunities of merging regions.
// Region merger is an ExprVisitor that will update the
// AnnotatedRegionSet, merging all the regions that can be merged.
RegionMerger merger(regions); RegionMerger merger(regions);
merger.VisitExpr(expr_all_annotated); merger.VisitExpr(expr);
// This updates the expression to remove annotations that are now // Remove annotations that are not in the region boundaries.
// 'internal' to a merged region.
MergeAnnotations merge_anno(regions); MergeAnnotations merge_anno(regions);
return merge_anno.Mutate(expr_all_annotated); return merge_anno.Mutate(expr);
} }
} // namespace partitioning } // namespace merge_compiler_region
namespace transform { namespace transform {
Pass MergeCompilerRegions() { Pass MergeCompilerRegions() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(partitioning::MergeCompilerRegions(f)); return Downcast<Function>(merge_compiler_region::MergeCompilerRegions(f));
}; };
auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {}); auto merged = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
return Sequential({partitioned, InferType()}); return Sequential({merged, InferType()});
} }
TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
......
...@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator { ...@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator {
IRModule module_; IRModule module_;
}; };
class DefaultRemover : public ExprMutator {
public:
explicit DefaultRemover(const IRModule& module) : module_(module) {}
IRModule Remove() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
Expr VisitExpr_(const CallNode* call) final {
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
}
return ExprMutator::VisitExpr_(call);
}
private:
IRModule module_;
};
} // namespace partitioning } // namespace partitioning
namespace transform { namespace transform {
Pass PartitionGraph() { Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); }; [=](IRModule m, PassContext pc) {
// TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
// by treating them as un-annotated, but we don't have it yet. This workaround pass removes
// all "default" annotations and should be deleted in the future.
auto new_m = partitioning::DefaultRemover(m).Remove();
return partitioning::Partitioner(new_m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()}); return Sequential({partitioned, InferType()});
} }
......
...@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, ...@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory); read_from_dnnl_memory(out, dst_memory);
} }
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* out, float* new_mean, float* new_variance, int p_N_, int p_C_, float* variance, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) { int p_H_, int p_W_, int p_E_) {
// FIXME(@comaniac): BN has 3 outputs: out, new_mean and new_variance, but we do not update
// the rest two because no one cares about them for now. Should update it in the future.
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
......
...@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p ...@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, float* new_mean, float* new_variance, float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_); int p_h_, int p_w_);
......
...@@ -15,13 +15,15 @@ ...@@ -15,13 +15,15 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
import tvm
from tvm import relay from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
def check_region(region_set, args, nodes, rets): def check_region(region_set, target, args, nodes, rets):
region = region_set.get_region(args[0]) region = region_set.get_region(args[0])
assert region assert region
assert target == region.target
assert set(args) == set(region.args) assert set(args) == set(region.args)
assert set(nodes) == set(region.nodes) assert set(nodes) == set(region.nodes)
assert set(rets) == set(region.rets) assert set(rets) == set(region.rets)
...@@ -51,24 +53,28 @@ def test_region_set_creator_diamond(): ...@@ -51,24 +53,28 @@ def test_region_set_creator_diamond():
assert len(region_set) == 4 assert len(region_set) == 4
check_region( check_region(
region_set, region_set,
'test_target',
[cb_1], [cb_1],
[cb_1, O_1, ce_1, ce_2], [cb_1, O_1, ce_1, ce_2],
[ce_1, ce_2], [ce_1, ce_2],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_2], [cb_2],
[cb_2, O_2, ce_3], [cb_2, O_2, ce_3],
[ce_3], [ce_3],
) )
check_region( check_region(
region_set, region_set,
'default',
[cb_d], [cb_d],
[cb_d, X, ce_d], [cb_d, X, ce_d],
[ce_d], [ce_d],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_3, cb_4], [cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4], [cb_3, cb_4, O_3, ce_4],
[ce_4], [ce_4],
...@@ -88,7 +94,9 @@ def test_region_set_creator_merged(): ...@@ -88,7 +94,9 @@ def test_region_set_creator_merged():
cb_3 = compiler_begin(ce_3, 'test_target') cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(ce_d, 'test_target') cb_4 = compiler_begin(ce_d, 'test_target')
O_3 = relay.add(cb_3, cb_4) O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, 'test_target') O_4 = relay.add(cb_3, cb_4)
O_5 = relay.Tuple([O_3, O_4])
ce_4 = compiler_end(O_5, 'test_target')
merged = relay.Function([data], ce_4) merged = relay.Function([data], ce_4)
region_set = relay.analysis.AnnotatedRegionSet(merged, region_set = relay.analysis.AnnotatedRegionSet(merged,
...@@ -97,20 +105,23 @@ def test_region_set_creator_merged(): ...@@ -97,20 +105,23 @@ def test_region_set_creator_merged():
assert len(region_set) == 3 assert len(region_set) == 3
check_region( check_region(
region_set, region_set,
'test_target',
[cb_1], [cb_1],
[cb_1, O_1, O_2, ce_2, ce_3], [cb_1, O_1, O_2, ce_2, ce_3],
[ce_2, ce_3], [ce_2, ce_3],
) )
check_region( check_region(
region_set, region_set,
'default',
[cb_d], [cb_d],
[cb_d, X, ce_d], [cb_d, X, ce_d],
[ce_d], [ce_d],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_3, cb_4], [cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4], [cb_3, cb_4, O_3, O_4, O_5, ce_4],
[ce_4], [ce_4],
) )
...@@ -118,4 +129,3 @@ def test_region_set_creator_merged(): ...@@ -118,4 +129,3 @@ def test_region_set_creator_merged():
if __name__ == "__main__": if __name__ == "__main__":
test_region_set_creator_diamond() test_region_set_creator_diamond()
test_region_set_creator_merged() test_region_set_creator_merged()
...@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet(): ...@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
@reg.register("nn.relu", "target.test") def test_multiple_ends():
def relu(attrs, args): @reg.register("nn.relu", "target.test")
def relu(attrs, args): # pylint: disable=unused-variable
return True return True
def test_multiple_ends():
def before(): def before():
x = relay.var("x", shape=(10, 10)) x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x) r = relay.nn.relu(x)
...@@ -208,10 +207,17 @@ def test_multiple_ends(): ...@@ -208,10 +207,17 @@ def test_multiple_ends():
r = relay.nn.relu(cb_1) r = relay.nn.relu(cb_1)
ce_1 = relay.annotation.compiler_end(r, "test") ce_1 = relay.annotation.compiler_end(r, "test")
ce_2 = relay.annotation.compiler_end(r, "test") ce_2 = relay.annotation.compiler_end(r, "test")
a_1 = relay.abs(ce_1) cb_2 = relay.annotation.compiler_begin(ce_1, "default")
a_2 = relay.abs(ce_2) cb_3 = relay.annotation.compiler_begin(ce_2, "default")
out = relay.add(a_1, a_2) a_1 = relay.abs(cb_2)
f = relay.Function([x], out) a_2 = relay.abs(cb_3)
ce_3 = relay.annotation.compiler_end(a_1, "default")
ce_4 = relay.annotation.compiler_end(a_2, "default")
cb_4 = relay.annotation.compiler_begin(ce_3, "default")
cb_5 = relay.annotation.compiler_begin(ce_4, "default")
out = relay.add(cb_4, cb_5)
ce_6 = relay.annotation.compiler_end(out, "default")
f = relay.Function([x], ce_6)
mod = tvm.IRModule.from_expr(f) mod = tvm.IRModule.from_expr(f)
return mod return mod
...@@ -220,6 +226,72 @@ def test_multiple_ends(): ...@@ -220,6 +226,72 @@ def test_multiple_ends():
assert tvm.ir.structural_equal(expected, result) assert tvm.ir.structural_equal(expected, result)
def test_type_propagation():
target = "test_type_propagation"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return args[0].checked_type.dtype == "float32"
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
out = relay.nn.relu(r)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
# If the type isn't propagated, then the relu checker function will fail to get the dtype.
assert transform.AnnotateTarget(target)(before())
def test_tuple():
target = "test_tuple"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("concatenate", "target." + target)
def concatenate(attrs, args): # pylint: disable=unused-variable
return True
"""Test that TupleNode is included in annotation when surrounded by supported nodes."""
def before():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.nn.relu(y)
out = relay.concatenate((a_1, a_2), axis=1)
f = relay.Function([x, y], out)
mod = tvm.IRModule.from_expr(f)
return mod
def after():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
cb_1 = relay.annotation.compiler_begin(x, target)
cb_2 = relay.annotation.compiler_begin(y, target)
a_1 = relay.nn.relu(cb_1)
a_2 = relay.nn.relu(cb_2)
ce_1 = relay.annotation.compiler_end(a_1, target)
ce_2 = relay.annotation.compiler_end(a_2, target)
cb_3 = relay.annotation.compiler_begin(ce_1, target)
cb_4 = relay.annotation.compiler_begin(ce_2, target)
tup = relay.Tuple([cb_3, cb_4])
ce_3 = relay.annotation.compiler_end(tup, target)
cb_3 = relay.annotation.compiler_begin(ce_3, target)
out = relay.op._make.concatenate(cb_3, 1)
ce_4 = relay.annotation.compiler_end(out, target)
f = relay.Function([x, y], ce_4)
mod = tvm.IRModule.from_expr(f)
return mod
result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)
def test_composite_function(): def test_composite_function():
def before(): def before():
a = relay.var('a', shape=(10, 10)) a = relay.var('a', shape=(10, 10))
...@@ -265,8 +337,37 @@ def test_composite_function(): ...@@ -265,8 +337,37 @@ def test_composite_function():
assert tvm.ir.structural_equal(expected, result) assert tvm.ir.structural_equal(expected, result)
def test_multiple_runs():
@reg.register("nn.relu", "target.A")
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("add", "target.B")
def add(attrs, args): # pylint: disable=unused-variable
return True
def before():
x = relay.var("x", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.abs(a_1)
a_3 = relay.nn.relu(a_1)
out = relay.add(a_2, a_3)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
mod = transform.AnnotateTarget("A")(before())
mod = transform.AnnotateTarget("B")(mod)
expected = transform.AnnotateTarget(["A", "B"])(before())
assert tvm.ir.structural_equal(expected, mod)
if __name__ == "__main__": if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl() test_extern_dnnl()
#test_extern_dnnl_mobilenet()
test_composite_function() test_composite_function()
#test_extern_dnnl_mobilenet()
test_multiple_ends()
test_type_propagation()
test_tuple()
test_multiple_runs()
...@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts(): ...@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts():
X = not supported by target X = not supported by target
O O O O
/ \ / \ / \\ / \\
O X --> O + + X O X --> O + + X
\ / \ / \\ / \\ /
O O O O
Note that we can't just merge the three supported operators together, Note that we can't just merge the three supported operators together,
...@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts(): ...@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts():
ce_1 = compiler_end(O_1, "test") ce_1 = compiler_end(O_1, "test")
ce_2 = compiler_end(O_1, "test") ce_2 = compiler_end(O_1, "test")
cb_2 = compiler_begin(ce_1, "test") cb_2 = compiler_begin(ce_1, "test")
cb_3 = compiler_begin(ce_2, "default")
O_2 = relay.nn.relu(cb_2) O_2 = relay.nn.relu(cb_2)
ce_3 = compiler_end(O_2, "test") ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2)
cb_3 = compiler_begin(ce_3, "test") X = relay.tanh(cb_3)
cb_4 = compiler_begin(X, "test") ce_4 = compiler_end(X, "default")
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_4) cb_4 = compiler_begin(ce_3, "test")
cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_4, cb_5)
ce_5 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_5)
return diamond return diamond
def expected(): def expected():
...@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts(): ...@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts():
O_2 = relay.nn.relu(O_1) O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, "test") ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2) cb_3 = compiler_begin(ce_2, "default")
X = relay.tanh(cb_3)
ce_4 = compiler_end(X, "default")
cb_3 = compiler_begin(ce_3, "test") cb_4 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(X, "test") cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_3, cb_4) O_3 = relay.add(cb_4, cb_5)
ce_4 = compiler_end(O_3, "test") ce_5 = compiler_end(O_3, "test")
func = relay.Function([data], ce_4) func = relay.Function([data], ce_5)
return func return func
result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions()) result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
...@@ -85,7 +90,7 @@ def test_example_graph(): ...@@ -85,7 +90,7 @@ def test_example_graph():
"""This tests the merging algorithm on the example used in the RFC. """This tests the merging algorithm on the example used in the RFC.
See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830 See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
Blue nodes are adds, red nodes are subtracts. Blue nodes are adds (target: test), red nodes are subtracts (target: default).
""" """
def annotated(): def annotated():
in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
...@@ -112,21 +117,30 @@ def test_example_graph(): ...@@ -112,21 +117,30 @@ def test_example_graph():
node2 = relay.add(begin4, begin5) node2 = relay.add(begin4, begin5)
end2 = compiler_end(node2, "test") end2 = compiler_end(node2, "test")
node3 = relay.subtract(in_5, in_6) dbegin0 = compiler_begin(in_5, "default")
node4 = relay.subtract(in_7, node3) dbegin1 = compiler_begin(in_6, "default")
node3 = relay.subtract(dbegin0, dbegin1)
dbegin2 = compiler_begin(in_7, "default")
dend1 = compiler_end(node3, "default")
dbegin3 = compiler_begin(dend1, "default")
node4 = relay.subtract(dbegin2, dbegin3)
dend2 = compiler_end(node4, "default")
begin6 = compiler_begin(end2, "test") begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node4, "test") begin7 = compiler_begin(dend2, "test")
node5 = relay.add(begin6, begin7) node5 = relay.add(begin6, begin7)
end3 = compiler_end(node5, "test") end3 = compiler_end(node5, "test")
end4 = compiler_end(node5, "test") end4 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end3) dbegin4 = compiler_begin(in_8, "default")
dbegin5 = compiler_begin(end3, "default")
node6 = relay.subtract(dbegin4, dbegin5)
begin8 = compiler_begin(in_9, "test") begin8 = compiler_begin(in_9, "test")
begin9 = compiler_begin(end4, "test") begin9 = compiler_begin(end4, "test")
node7 = relay.add(begin8, begin9) node7 = relay.add(begin8, begin9)
end5 = compiler_end(node7, "test") end5 = compiler_end(node7, "test")
begin10 = compiler_begin(node6, "test") dend3 = compiler_end(node6, "default")
begin10 = compiler_begin(dend3, "test")
begin11 = compiler_begin(end5, "test") begin11 = compiler_begin(end5, "test")
node8 = relay.add(begin10, begin11) node8 = relay.add(begin10, begin11)
end6 = compiler_end(node8, "test") end6 = compiler_end(node8, "test")
...@@ -159,20 +173,27 @@ def test_example_graph(): ...@@ -159,20 +173,27 @@ def test_example_graph():
node1 = relay.add(begin2, begin3) node1 = relay.add(begin2, begin3)
node2 = relay.add(node0, node1) node2 = relay.add(node0, node1)
node3 = relay.subtract(in_5, in_6) dbegin0 = compiler_begin(in_5, "default")
node4 = relay.subtract(in_7, node3) dbegin1 = compiler_begin(in_6, "default")
dbegin2 = compiler_begin(in_7, "default")
node3 = relay.subtract(dbegin0, dbegin1)
node4 = relay.subtract(dbegin2, node3)
dend0 = compiler_end(node4, "default")
begin4 = compiler_begin(node4, "test") begin4 = compiler_begin(dend0, "test")
begin5 = compiler_begin(in_9, "test") begin5 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin4) node5 = relay.add(node2, begin4)
end1 = compiler_end(node5, "test") end1 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end1) dbegin4 = compiler_begin(end1, "default")
dbegin5 = compiler_begin(in_8, "default")
node6 = relay.subtract(dbegin5, dbegin4)
dend1 = compiler_end(node6, "default")
node7 = relay.add(begin5, node5) node7 = relay.add(begin5, node5)
end2 = compiler_end(node7, "test") end2 = compiler_end(node7, "test")
begin6 = compiler_begin(end2, "test") begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node6, "test") begin7 = compiler_begin(dend1, "test")
node8 = relay.add(begin7, begin6) node8 = relay.add(begin7, begin6)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Unit tests for graph partitioning.""" """Unit tests for graph partitioning."""
import os import os
import sys import sys
import numpy as np import numpy as np
import pytest import pytest
...@@ -26,8 +27,12 @@ from tvm import relay ...@@ -26,8 +27,12 @@ from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay import transform
from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container
# Leverage the pass manager to write a simple white list based annotator # Leverage the pass manager to write a simple white list based annotator
@transform.function_pass(opt_level=0) @transform.function_pass(opt_level=0)
...@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
return lib return lib
def check_vm_result(): def check_vm_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params) exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save() code, lib = exe.save()
...@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result(): def check_graph_runtime_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params) json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib) lib = update_lib(lib)
...@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet(): ...@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload( mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32') batch_size=1, dtype='float32')
op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = WhiteListAnnotator(op_list, "dnnl")(mod) mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod) mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype) i_data = np.random.uniform(0, 1, ishape).astype(dtype)
...@@ -851,6 +858,7 @@ if __name__ == "__main__": ...@@ -851,6 +858,7 @@ if __name__ == "__main__":
test_extern_ccompiler_default_ops() test_extern_ccompiler_default_ops()
test_extern_ccompiler() test_extern_ccompiler()
test_extern_dnnl() test_extern_dnnl()
# TODO(@comaniac, @zhiics): Fix constant node and re-open this case.
#test_extern_dnnl_mobilenet() #test_extern_dnnl_mobilenet()
test_function_lifting() test_function_lifting()
test_function_lifting_inline() test_function_lifting_inline()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment