Unverified Commit f506c8b1 by Cody Yu Committed by GitHub

[BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)

* add target to region

* refactor annotate_target

* Make all unit test working

* quick fix

* enable BN, unit test failed

* Fix vm test, unit test. Refactor annotate_target a bit.

* quick fix fusion

* revert fusion change

* style fix

* Refactor merge region pass

* format

* minor fix

* Skip e2e test

* lint

* support AnnotateTarget multiple runs

* Add HasAttr and revert DNNL codegen

* address comment

Co-authored-by: Zhi Chen <chzhi@amazon.com>
parent 5795539c
...@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True): ...@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d") _register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense") _register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu") _register_external_op_helper("nn.relu")
_register_external_op_helper("add") _register_external_op_helper("add")
_register_external_op_helper("subtract") _register_external_op_helper("subtract")
_register_external_op_helper("multiply") _register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
"""
return False
...@@ -587,14 +587,14 @@ def PartitionGraph(): ...@@ -587,14 +587,14 @@ def PartitionGraph():
def AnnotateTarget(target): def AnnotateTarget(targets):
"""Annotate ops in an experession with a provied compiler/target and then """Annotate ops in an experession with a provied compiler/target and then
use it for codegen. use it for codegen.
Parameters Parameters
---------- ----------
target : String targets : str or List[str]
The target compiler used for codegen. The list of target compilers used for codegen.
Returns Returns
------- -------
...@@ -602,7 +602,9 @@ def AnnotateTarget(target): ...@@ -602,7 +602,9 @@ def AnnotateTarget(target):
The annotated pass that wrapps ops with subgraph_start and The annotated pass that wrapps ops with subgraph_start and
subgraph_end. subgraph_end.
""" """
return _ffi_api.AnnotateTarget(target) if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
def Inline(): def Inline():
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <tvm/runtime/container.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -31,7 +32,7 @@ namespace relay { ...@@ -31,7 +32,7 @@ namespace relay {
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) { for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) { if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
return candidate; return candidate;
} }
} }
...@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, ...@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
} }
// Merge src to dest and erase src. // Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end()); dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
for (const auto& input : src->ins) { for (const auto& input : src->ins_) {
dest->ins.push_back(input); dest->ins_.push_back(input);
} }
for (const auto& output : src->outs) { for (const auto& output : src->outs_) {
dest->outs.push_back(output); dest->outs_.push_back(output);
} }
// if any of the outputs of src are inputs of dest, they become internal nodes // if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs // so remove them from outs
std::vector<Expr> ins_to_remove; std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) { for (const auto& input : dest->ins_) {
auto call = Downcast<Call>(input); auto call = Downcast<Call>(input);
auto it = src->nodes.find(call->args[0]); auto it = src->nodes_.find(call->args[0]);
if (it != src->nodes.end()) { if (it != src->nodes_.end()) {
dest->outs.remove(*it); dest->outs_.remove(*it);
ins_to_remove.push_back(input); ins_to_remove.push_back(input);
} }
} }
for (const auto& input : ins_to_remove) { for (const auto& input : ins_to_remove) {
dest->ins.remove(input); dest->ins_.remove(input);
} }
regions_.erase(src); regions_.erase(src);
} }
...@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) ...@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
if (src.defined()) { if (src.defined()) {
MergeRegions(src, dest); MergeRegions(src, dest);
} else { } else {
dest->nodes.insert(expr); dest->nodes_.insert(expr);
} }
} }
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() { AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
auto ret = regions_.emplace(AnnotatedRegion()); auto ret = regions_.emplace(AnnotatedRegion());
(*ret.first)->id = region_id_++; (*ret.first)->id_ = region_id_++;
(*ret.first)->target_ = target;
return *ret.first; return *ret.first;
} }
class AnnotatedRegionSet::Creator : public ExprVisitor { class AnnotatedRegionSet::Creator : public ExprVisitor {
public: public:
Creator(const Op& region_begin_op, const Op& region_end_op) : Creator(const Op& region_begin_op, const Op& region_end_op)
begin_op_(region_begin_op), end_op_(region_end_op) {} : begin_op_(region_begin_op), end_op_(region_end_op) {}
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
void VisitExpr_(const CallNode* call) { void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>(); auto op_node = call->op.as<OpNode>();
...@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { ...@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
<< "Cannot find the corresponding region for start annotation:\n" << "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false)); << AsText(GetRef<Call>(call), false));
} }
region->ins.push_back(GetRef<Call>(call)); region->ins_.push_back(GetRef<Call>(call));
} else { } else {
CHECK_EQ(call->op, end_op_); CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument. // The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U); CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Check if the argument already belongs to a region // Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]); auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) { if (!region.defined()) {
region = region_set_->MakeRegion(); // Create a new region if the argument is not belonged to any regions yet.
region->nodes.insert(call->args[0]); region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
} else {
// If the argument is belonged to a region, it must have the same target.
// Otherwise we should see a region_begin op.
CHECK_EQ(region->GetTarget(), target);
} }
region->nodes.insert(GetRef<Call>(call)); region->nodes_.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call)); region->outs_.push_back(GetRef<Call>(call));
} }
ExprVisitor::VisitExpr_(call); ExprVisitor::VisitExpr_(call);
} }
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
void VisitExpr_(const TupleNode* op) { void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op)); auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) { if (region.defined()) {
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <string> #include <string>
...@@ -49,33 +50,39 @@ class AnnotatedRegionSet; ...@@ -49,33 +50,39 @@ class AnnotatedRegionSet;
class AnnotatedRegionNode : public Object { class AnnotatedRegionNode : public Object {
public: public:
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("id", &id); v->Visit("id", &id_);
Array<Expr> nodes_array(nodes.begin(), nodes.end()); v->Visit("target", &target_);
Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
v->Visit("nodes", &nodes_array); v->Visit("nodes", &nodes_array);
Array<Expr> args_array(ins.begin(), ins.end()); Array<Expr> args_array(ins_.begin(), ins_.end());
v->Visit("args", &args_array); v->Visit("args", &args_array);
Array<Expr> rets_array(outs.begin(), outs.end()); Array<Expr> rets_array(outs_.begin(), outs_.end());
v->Visit("rets", &rets_array); v->Visit("rets", &rets_array);
} }
/*! \brief Get the region ID. */ /*! \brief Get the region ID. */
int GetID() const { int GetID() const {
return id; return id_;
}
/*! \brief Get the region target. */
std::string GetTarget() const {
return target_;
} }
/*! \brief Get the region's inputs. */ /*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const { std::list<Expr> GetInputs() const {
return ins; return ins_;
} }
/*! \brief Get the region's outputs. */ /*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const { std::list<Expr> GetOutputs() const {
return outs; return outs_;
} }
/*! \brief Get the region's nodes. */ /*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const { std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
return nodes; return nodes_;
} }
static constexpr const char* _type_key = "relay.AnnotatedRegion"; static constexpr const char* _type_key = "relay.AnnotatedRegion";
...@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object { ...@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object {
protected: protected:
/*! \brief The region ID. */ /*! \brief The region ID. */
int id{-1}; int id_{-1};
/*! \brief The target for this region. */
std::string target_ = "default";
/*! \brief The inputs to this region. */ /*! \brief The inputs to this region. */
std::list<Expr> ins; std::list<Expr> ins_;
/*! \brief The outputs of this region */ /*! \brief The outputs of this region */
std::list<Expr> outs; std::list<Expr> outs_;
/*! \brief Nodes in this region. */ /*! \brief Nodes in this region. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes; std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;
friend class AnnotatedRegionSet; friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode; friend class AnnotatedRegionSetNode;
...@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object { ...@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
void AddToRegion(AnnotatedRegion dest, const Expr& expr); void AddToRegion(AnnotatedRegion dest, const Expr& expr);
/*! /*!
* \brief Make a new region. * \brief Make a new region for a target.
* *
* \return The new region. * \return The new region.
*/ */
AnnotatedRegion MakeRegion(); AnnotatedRegion MakeRegion(const std::string& target);
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_; std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */ /*! \brief The next region ID to assign. */
......
...@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
void VisitExpr_(const TupleGetItemNode* op) final { void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple); // Do nothing
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
} }
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream; std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID // Args: ID
std::vector<std::string> args; std::vector<std::string> args;
...@@ -103,45 +96,20 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -103,45 +96,20 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
} }
// Analyze the output buffers // Analyze the output buffer
std::vector<Type> out_types; auto type_node = call->checked_type().as<TensorTypeNode>();
if (call->checked_type()->IsInstance<TupleTypeNode>()) { CHECK(type_node);
auto type_node = call->checked_type().as<TupleTypeNode>(); const auto& dtype = GetDtypeString(type_node);
for (auto field : type_node->fields) { std::string out = "buf_" + std::to_string(buf_idx_++);
CHECK(field->IsInstance<TensorTypeNode>()); auto out_shape = GetShape(call->checked_type());
out_types.push_back(field); int out_size = 1;
} for (size_t i = 0; i < out_shape.size(); ++i) {
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) { out_size *= out_shape[i];
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}
out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(out_type);
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Update output buffer
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
} }
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Attach attribute arguments // Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
...@@ -149,6 +117,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ...@@ -149,6 +117,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
} }
decl_stream << ");"; decl_stream << ");";
ext_func_body.push_back(decl_stream.str()); ext_func_body.push_back(decl_stream.str());
// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
} }
std::string JIT(void) { std::string JIT(void) {
......
...@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this // Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation // pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op` // pass. This is because memory allocation pass will insert `invoke_tvm_op`
...@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen. // external codegen.
pass_seqs.push_back(transform::Inline()); pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions. // Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
...@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator { ...@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator {
IRModule module_; IRModule module_;
}; };
class DefaultRemover : public ExprMutator {
public:
explicit DefaultRemover(const IRModule& module) : module_(module) {}
IRModule Remove() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
Expr VisitExpr_(const CallNode* call) final {
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
}
return ExprMutator::VisitExpr_(call);
}
private:
IRModule module_;
};
} // namespace partitioning } // namespace partitioning
namespace transform { namespace transform {
Pass PartitionGraph() { Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); }; [=](IRModule m, PassContext pc) {
// TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
// by treating them as un-annotated, but we don't have it yet. This workaround pass removes
// all "default" annotations and should be deleted in the future.
auto new_m = partitioning::DefaultRemover(m).Remove();
return partitioning::Partitioner(new_m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()}); return Sequential({partitioned, InferType()});
} }
......
...@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, ...@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory); read_from_dnnl_memory(out, dst_memory);
} }
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* out, float* new_mean, float* new_variance, int p_N_, int p_C_, float* variance, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) { int p_H_, int p_W_, int p_E_) {
// FIXME(@comaniac): BN has 3 outputs: out, new_mean and new_variance, but we do not update
// the rest two because no one cares about them for now. Should update it in the future.
using tag = memory::format_tag; using tag = memory::format_tag;
using dt = memory::data_type; using dt = memory::data_type;
......
...@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p ...@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, float* new_mean, float* new_variance, float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_); int p_h_, int p_w_);
......
...@@ -15,13 +15,15 @@ ...@@ -15,13 +15,15 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
import tvm
from tvm import relay from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
def check_region(region_set, args, nodes, rets): def check_region(region_set, target, args, nodes, rets):
region = region_set.get_region(args[0]) region = region_set.get_region(args[0])
assert region assert region
assert target == region.target
assert set(args) == set(region.args) assert set(args) == set(region.args)
assert set(nodes) == set(region.nodes) assert set(nodes) == set(region.nodes)
assert set(rets) == set(region.rets) assert set(rets) == set(region.rets)
...@@ -51,24 +53,28 @@ def test_region_set_creator_diamond(): ...@@ -51,24 +53,28 @@ def test_region_set_creator_diamond():
assert len(region_set) == 4 assert len(region_set) == 4
check_region( check_region(
region_set, region_set,
'test_target',
[cb_1], [cb_1],
[cb_1, O_1, ce_1, ce_2], [cb_1, O_1, ce_1, ce_2],
[ce_1, ce_2], [ce_1, ce_2],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_2], [cb_2],
[cb_2, O_2, ce_3], [cb_2, O_2, ce_3],
[ce_3], [ce_3],
) )
check_region( check_region(
region_set, region_set,
'default',
[cb_d], [cb_d],
[cb_d, X, ce_d], [cb_d, X, ce_d],
[ce_d], [ce_d],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_3, cb_4], [cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4], [cb_3, cb_4, O_3, ce_4],
[ce_4], [ce_4],
...@@ -88,7 +94,9 @@ def test_region_set_creator_merged(): ...@@ -88,7 +94,9 @@ def test_region_set_creator_merged():
cb_3 = compiler_begin(ce_3, 'test_target') cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(ce_d, 'test_target') cb_4 = compiler_begin(ce_d, 'test_target')
O_3 = relay.add(cb_3, cb_4) O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, 'test_target') O_4 = relay.add(cb_3, cb_4)
O_5 = relay.Tuple([O_3, O_4])
ce_4 = compiler_end(O_5, 'test_target')
merged = relay.Function([data], ce_4) merged = relay.Function([data], ce_4)
region_set = relay.analysis.AnnotatedRegionSet(merged, region_set = relay.analysis.AnnotatedRegionSet(merged,
...@@ -97,20 +105,23 @@ def test_region_set_creator_merged(): ...@@ -97,20 +105,23 @@ def test_region_set_creator_merged():
assert len(region_set) == 3 assert len(region_set) == 3
check_region( check_region(
region_set, region_set,
'test_target',
[cb_1], [cb_1],
[cb_1, O_1, O_2, ce_2, ce_3], [cb_1, O_1, O_2, ce_2, ce_3],
[ce_2, ce_3], [ce_2, ce_3],
) )
check_region( check_region(
region_set, region_set,
'default',
[cb_d], [cb_d],
[cb_d, X, ce_d], [cb_d, X, ce_d],
[ce_d], [ce_d],
) )
check_region( check_region(
region_set, region_set,
'test_target',
[cb_3, cb_4], [cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4], [cb_3, cb_4, O_3, O_4, O_5, ce_4],
[ce_4], [ce_4],
) )
...@@ -118,4 +129,3 @@ def test_region_set_creator_merged(): ...@@ -118,4 +129,3 @@ def test_region_set_creator_merged():
if __name__ == "__main__": if __name__ == "__main__":
test_region_set_creator_diamond() test_region_set_creator_diamond()
test_region_set_creator_merged() test_region_set_creator_merged()
...@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet(): ...@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
@reg.register("nn.relu", "target.test")
def relu(attrs, args):
return True
def test_multiple_ends(): def test_multiple_ends():
@reg.register("nn.relu", "target.test")
def relu(attrs, args): # pylint: disable=unused-variable
return True
def before(): def before():
x = relay.var("x", shape=(10, 10)) x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x) r = relay.nn.relu(x)
...@@ -208,10 +207,17 @@ def test_multiple_ends(): ...@@ -208,10 +207,17 @@ def test_multiple_ends():
r = relay.nn.relu(cb_1) r = relay.nn.relu(cb_1)
ce_1 = relay.annotation.compiler_end(r, "test") ce_1 = relay.annotation.compiler_end(r, "test")
ce_2 = relay.annotation.compiler_end(r, "test") ce_2 = relay.annotation.compiler_end(r, "test")
a_1 = relay.abs(ce_1) cb_2 = relay.annotation.compiler_begin(ce_1, "default")
a_2 = relay.abs(ce_2) cb_3 = relay.annotation.compiler_begin(ce_2, "default")
out = relay.add(a_1, a_2) a_1 = relay.abs(cb_2)
f = relay.Function([x], out) a_2 = relay.abs(cb_3)
ce_3 = relay.annotation.compiler_end(a_1, "default")
ce_4 = relay.annotation.compiler_end(a_2, "default")
cb_4 = relay.annotation.compiler_begin(ce_3, "default")
cb_5 = relay.annotation.compiler_begin(ce_4, "default")
out = relay.add(cb_4, cb_5)
ce_6 = relay.annotation.compiler_end(out, "default")
f = relay.Function([x], ce_6)
mod = tvm.IRModule.from_expr(f) mod = tvm.IRModule.from_expr(f)
return mod return mod
...@@ -220,6 +226,72 @@ def test_multiple_ends(): ...@@ -220,6 +226,72 @@ def test_multiple_ends():
assert tvm.ir.structural_equal(expected, result) assert tvm.ir.structural_equal(expected, result)
def test_type_propagation():
target = "test_type_propagation"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return args[0].checked_type.dtype == "float32"
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
out = relay.nn.relu(r)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
# If the type isn't propogated, then the relu checker function will fail to get the dtype.
assert transform.AnnotateTarget(target)(before())
def test_tuple():
target = "test_tuple"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("concatenate", "target." + target)
def concatenate(attrs, args): # pylint: disable=unused-variable
return True
"""Test that TupleNode is included in annotation when surrounded by supported nodes."""
def before():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.nn.relu(y)
out = relay.concatenate((a_1, a_2), axis=1)
f = relay.Function([x, y], out)
mod = tvm.IRModule.from_expr(f)
return mod
def after():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
cb_1 = relay.annotation.compiler_begin(x, target)
cb_2 = relay.annotation.compiler_begin(y, target)
a_1 = relay.nn.relu(cb_1)
a_2 = relay.nn.relu(cb_2)
ce_1 = relay.annotation.compiler_end(a_1, target)
ce_2 = relay.annotation.compiler_end(a_2, target)
cb_3 = relay.annotation.compiler_begin(ce_1, target)
cb_4 = relay.annotation.compiler_begin(ce_2, target)
tup = relay.Tuple([cb_3, cb_4])
ce_3 = relay.annotation.compiler_end(tup, target)
cb_3 = relay.annotation.compiler_begin(ce_3, target)
out = relay.op._make.concatenate(cb_3, 1)
ce_4 = relay.annotation.compiler_end(out, target)
f = relay.Function([x, y], ce_4)
mod = tvm.IRModule.from_expr(f)
return mod
result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)
def test_composite_function(): def test_composite_function():
def before(): def before():
a = relay.var('a', shape=(10, 10)) a = relay.var('a', shape=(10, 10))
...@@ -265,8 +337,37 @@ def test_composite_function(): ...@@ -265,8 +337,37 @@ def test_composite_function():
assert tvm.ir.structural_equal(expected, result) assert tvm.ir.structural_equal(expected, result)
def test_multiple_runs():
@reg.register("nn.relu", "target.A")
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("add", "target.B")
def add(attrs, args): # pylint: disable=unused-variable
return True
def before():
x = relay.var("x", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.abs(a_1)
a_3 = relay.nn.relu(a_1)
out = relay.add(a_2, a_3)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
mod = transform.AnnotateTarget("A")(before())
mod = transform.AnnotateTarget("B")(mod)
expected = transform.AnnotateTarget(["A", "B"])(before())
assert tvm.ir.structural_equal(expected, mod)
if __name__ == "__main__": if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl() test_extern_dnnl()
#test_extern_dnnl_mobilenet()
test_composite_function() test_composite_function()
#test_extern_dnnl_mobilenet()
test_multiple_ends()
test_type_propagation()
test_tuple()
test_multiple_runs()
...@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts(): ...@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts():
X = not supported by target X = not supported by target
O O O O
/ \ / \ / \\ / \\
O X --> O + + X O X --> O + + X
\ / \ / \\ / \\ /
O O O O
Note that we can't just merge the three supported operators together, Note that we can't just merge the three supported operators together,
...@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts(): ...@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts():
ce_1 = compiler_end(O_1, "test") ce_1 = compiler_end(O_1, "test")
ce_2 = compiler_end(O_1, "test") ce_2 = compiler_end(O_1, "test")
cb_2 = compiler_begin(ce_1, "test") cb_2 = compiler_begin(ce_1, "test")
cb_3 = compiler_begin(ce_2, "default")
O_2 = relay.nn.relu(cb_2) O_2 = relay.nn.relu(cb_2)
ce_3 = compiler_end(O_2, "test") ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2)
cb_3 = compiler_begin(ce_3, "test") X = relay.tanh(cb_3)
cb_4 = compiler_begin(X, "test") ce_4 = compiler_end(X, "default")
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_4) cb_4 = compiler_begin(ce_3, "test")
cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_4, cb_5)
ce_5 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_5)
return diamond return diamond
def expected(): def expected():
...@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts(): ...@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts():
O_2 = relay.nn.relu(O_1) O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, "test") ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2) cb_3 = compiler_begin(ce_2, "default")
X = relay.tanh(cb_3)
ce_4 = compiler_end(X, "default")
cb_3 = compiler_begin(ce_3, "test") cb_4 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(X, "test") cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_3, cb_4) O_3 = relay.add(cb_4, cb_5)
ce_4 = compiler_end(O_3, "test") ce_5 = compiler_end(O_3, "test")
func = relay.Function([data], ce_4) func = relay.Function([data], ce_5)
return func return func
result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions()) result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
...@@ -85,7 +90,7 @@ def test_example_graph(): ...@@ -85,7 +90,7 @@ def test_example_graph():
"""This tests the merging algorithm on the example used in the RFC. """This tests the merging algorithm on the example used in the RFC.
See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830 See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
Blue nodes are adds, red nodes are subtracts. Blue nodes are adds (target: test), red nodes are subtracts (target: default).
""" """
def annotated(): def annotated():
in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
...@@ -112,21 +117,30 @@ def test_example_graph(): ...@@ -112,21 +117,30 @@ def test_example_graph():
node2 = relay.add(begin4, begin5) node2 = relay.add(begin4, begin5)
end2 = compiler_end(node2, "test") end2 = compiler_end(node2, "test")
node3 = relay.subtract(in_5, in_6) dbegin0 = compiler_begin(in_5, "default")
node4 = relay.subtract(in_7, node3) dbegin1 = compiler_begin(in_6, "default")
node3 = relay.subtract(dbegin0, dbegin1)
dbegin2 = compiler_begin(in_7, "default")
dend1 = compiler_end(node3, "default")
dbegin3 = compiler_begin(dend1, "default")
node4 = relay.subtract(dbegin2, dbegin3)
dend2 = compiler_end(node4, "default")
begin6 = compiler_begin(end2, "test") begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node4, "test") begin7 = compiler_begin(dend2, "test")
node5 = relay.add(begin6, begin7) node5 = relay.add(begin6, begin7)
end3 = compiler_end(node5, "test") end3 = compiler_end(node5, "test")
end4 = compiler_end(node5, "test") end4 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end3) dbegin4 = compiler_begin(in_8, "default")
dbegin5 = compiler_begin(end3, "default")
node6 = relay.subtract(dbegin4, dbegin5)
begin8 = compiler_begin(in_9, "test") begin8 = compiler_begin(in_9, "test")
begin9 = compiler_begin(end4, "test") begin9 = compiler_begin(end4, "test")
node7 = relay.add(begin8, begin9) node7 = relay.add(begin8, begin9)
end5 = compiler_end(node7, "test") end5 = compiler_end(node7, "test")
begin10 = compiler_begin(node6, "test") dend3 = compiler_end(node6, "default")
begin10 = compiler_begin(dend3, "test")
begin11 = compiler_begin(end5, "test") begin11 = compiler_begin(end5, "test")
node8 = relay.add(begin10, begin11) node8 = relay.add(begin10, begin11)
end6 = compiler_end(node8, "test") end6 = compiler_end(node8, "test")
...@@ -159,20 +173,27 @@ def test_example_graph(): ...@@ -159,20 +173,27 @@ def test_example_graph():
node1 = relay.add(begin2, begin3) node1 = relay.add(begin2, begin3)
node2 = relay.add(node0, node1) node2 = relay.add(node0, node1)
node3 = relay.subtract(in_5, in_6) dbegin0 = compiler_begin(in_5, "default")
node4 = relay.subtract(in_7, node3) dbegin1 = compiler_begin(in_6, "default")
dbegin2 = compiler_begin(in_7, "default")
node3 = relay.subtract(dbegin0, dbegin1)
node4 = relay.subtract(dbegin2, node3)
dend0 = compiler_end(node4, "default")
begin4 = compiler_begin(node4, "test") begin4 = compiler_begin(dend0, "test")
begin5 = compiler_begin(in_9, "test") begin5 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin4) node5 = relay.add(node2, begin4)
end1 = compiler_end(node5, "test") end1 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end1) dbegin4 = compiler_begin(end1, "default")
dbegin5 = compiler_begin(in_8, "default")
node6 = relay.subtract(dbegin5, dbegin4)
dend1 = compiler_end(node6, "default")
node7 = relay.add(begin5, node5) node7 = relay.add(begin5, node5)
end2 = compiler_end(node7, "test") end2 = compiler_end(node7, "test")
begin6 = compiler_begin(end2, "test") begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node6, "test") begin7 = compiler_begin(dend1, "test")
node8 = relay.add(begin7, begin6) node8 = relay.add(begin7, begin6)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Unit tests for graph partitioning.""" """Unit tests for graph partitioning."""
import os import os
import sys import sys
import numpy as np import numpy as np
import pytest import pytest
...@@ -26,8 +27,12 @@ from tvm import relay ...@@ -26,8 +27,12 @@ from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay import transform
from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container
# Leverage the pass manager to write a simple white list based annotator # Leverage the pass manager to write a simple white list based annotator
@transform.function_pass(opt_level=0) @transform.function_pass(opt_level=0)
...@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
return lib return lib
def check_vm_result(): def check_vm_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params) exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save() code, lib = exe.save()
...@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result(): def check_graph_runtime_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params) json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib) lib = update_lib(lib)
...@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet(): ...@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload( mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32') batch_size=1, dtype='float32')
op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = WhiteListAnnotator(op_list, "dnnl")(mod) mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod) mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype) i_data = np.random.uniform(0, 1, ishape).astype(dtype)
...@@ -851,6 +858,7 @@ if __name__ == "__main__": ...@@ -851,6 +858,7 @@ if __name__ == "__main__":
test_extern_ccompiler_default_ops() test_extern_ccompiler_default_ops()
test_extern_ccompiler() test_extern_ccompiler()
test_extern_dnnl() test_extern_dnnl()
# TODO(@comaniac, @zhiics): Fix constant node and re-open this case.
#test_extern_dnnl_mobilenet() #test_extern_dnnl_mobilenet()
test_function_lifting() test_function_lifting()
test_function_lifting_inline() test_function_lifting_inline()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment