Unverified commit f506c8b1 by Cody Yu, committed by GitHub

[BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)

* add target to region

* refactor annotate_target

* Make all unit tests pass

* quick fix

* Enable BN (unit test still failing)

* Fix VM and unit tests; refactor annotate_target a bit.

* quick fix fusion

* revert fusion change

* style fix

* Refactor merge region pass

* format

* minor fix

* Skip e2e test

* lint

* Support multiple runs of AnnotateTarget

* Add HasAttr and revert DNNL codegen

* Address review comments

Co-authored-by: Zhi Chen <chzhi@amazon.com>
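
In brief, the refined flow keeps the three-stage BYOC pipeline (annotate, merge, partition) while letting AnnotateTarget accept multiple targets and making merged regions target-aware. A minimal sketch of the user-facing flow, pieced together from the tests in this diff (op checkers for the target must be registered first, as shown further below):

import tvm
from tvm import relay
from tvm.relay import transform

# A toy module; in practice the "dnnl" checkers registered below decide
# which ops get annotated as supported.
x = relay.var("x", shape=(10, 10))
func = relay.Function([x], relay.nn.relu(relay.nn.relu(x)))
mod = tvm.IRModule.from_expr(func)

mod = transform.AnnotateTarget(["dnnl"])(mod)    # wrap supported ops
mod = transform.MergeCompilerRegions()(mod)      # merge same-target regions
mod = transform.PartitionGraph()(mod)            # lift regions into functions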
parent 5795539c
......@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True):
return _func_wrapper
_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turned off due to lack of support for multiple outputs.
"""
return False
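
For context, the body of `_register_external_op_helper` is collapsed above the fold; presumably it follows the usual pattern of registering a checker that returns a fixed verdict, roughly as sketched here (not part of this diff):

def _register_external_op_helper(op_name, supported=True):
    """Helper to register an op as supported (or not) by the DNNL codegen."""
    @reg.register(op_name, "target.dnnl")
    def _func_wrapper(attrs, args):
        # Report a fixed verdict regardless of attributes or arguments.
        return supported
    return _func_wrapper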
......@@ -587,14 +587,14 @@ def PartitionGraph():
def AnnotateTarget(target):
def AnnotateTarget(targets):
"""Annotate ops in an experession with a provied compiler/target and then
use it for codegen.
Parameters
----------
target : String
The target compiler used for codegen.
targets : str or List[str]
The list of target compilers used for codegen.
Returns
-------
......@@ -602,7 +602,9 @@ def AnnotateTarget(target):
The annotation pass that wraps ops with subgraph_start and
subgraph_end.
"""
return _ffi_api.AnnotateTarget(target)
if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
def Inline():
......
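
With the new signature, a plain string is still accepted and is normalized to a one-element list before crossing the FFI boundary. A small usage sketch; per test_multiple_runs further below, annotating in two single-target runs is expected to be structurally equal to one multi-target run:

from tvm.relay import transform

annotate_single = transform.AnnotateTarget("dnnl")     # str form
annotate_multi = transform.AnnotateTarget(["A", "B"])  # list form

# Two runs compose: AnnotateTarget("B")(AnnotateTarget("A")(mod))
# should match AnnotateTarget(["A", "B"])(mod).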
......@@ -21,6 +21,7 @@
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/runtime/container.h>
#include <unordered_map>
#include <vector>
......@@ -31,7 +32,7 @@ namespace relay {
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) {
if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
return candidate;
}
}
......@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
}
// Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end());
for (const auto& input : src->ins) {
dest->ins.push_back(input);
dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
for (const auto& input : src->ins_) {
dest->ins_.push_back(input);
}
for (const auto& output : src->outs) {
dest->outs.push_back(output);
for (const auto& output : src->outs_) {
dest->outs_.push_back(output);
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) {
for (const auto& input : dest->ins_) {
auto call = Downcast<Call>(input);
auto it = src->nodes.find(call->args[0]);
if (it != src->nodes.end()) {
dest->outs.remove(*it);
auto it = src->nodes_.find(call->args[0]);
if (it != src->nodes_.end()) {
dest->outs_.remove(*it);
ins_to_remove.push_back(input);
}
}
for (const auto& input : ins_to_remove) {
dest->ins.remove(input);
dest->ins_.remove(input);
}
regions_.erase(src);
}
......@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
if (src.defined()) {
MergeRegions(src, dest);
} else {
dest->nodes.insert(expr);
dest->nodes_.insert(expr);
}
}
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
auto ret = regions_.emplace(AnnotatedRegion());
(*ret.first)->id = region_id_++;
(*ret.first)->id_ = region_id_++;
(*ret.first)->target_ = target;
return *ret.first;
}
class AnnotatedRegionSet::Creator : public ExprVisitor {
public:
Creator(const Op& region_begin_op, const Op& region_end_op) :
begin_op_(region_begin_op), end_op_(region_end_op) {}
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
Creator(const Op& region_begin_op, const Op& region_end_op)
: begin_op_(region_begin_op), end_op_(region_end_op) {}
void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();
......@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
region->ins.push_back(GetRef<Call>(call));
region->ins_.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on an edge, so it must have exactly one argument.
CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
region = region_set_->MakeRegion();
region->nodes.insert(call->args[0]);
// Create a new region if the argument does not belong to any region yet.
region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
} else {
// If the argument already belongs to a region, it must have the same target.
// Otherwise we should see a region_begin op.
CHECK_EQ(region->GetTarget(), target);
}
region->nodes.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call));
region->nodes_.insert(GetRef<Call>(call));
region->outs_.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
......
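
Since regions now record their target, the result is observable from Python through the analysis API, as the region-set tests below do. A sketch, assuming the begin/end annotation ops are looked up by name as in those tests:

import tvm
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end

data = relay.var("data", shape=(10, 10))
cb = compiler_begin(data, "test_target")
ce = compiler_end(relay.nn.relu(cb), "test_target")
func = relay.Function([data], ce)

region_set = relay.analysis.AnnotatedRegionSet(
    func,
    relay.op.get("annotation.compiler_begin"),
    relay.op.get("annotation.compiler_end"))
region = region_set.get_region(cb)
assert region.target == "test_target"  # the target is now part of the region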
......@@ -32,6 +32,7 @@
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h>
#include <string>
......@@ -49,33 +50,39 @@ class AnnotatedRegionSet;
class AnnotatedRegionNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("id", &id);
Array<Expr> nodes_array(nodes.begin(), nodes.end());
v->Visit("id", &id_);
v->Visit("target", &target_);
Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
v->Visit("nodes", &nodes_array);
Array<Expr> args_array(ins.begin(), ins.end());
Array<Expr> args_array(ins_.begin(), ins_.end());
v->Visit("args", &args_array);
Array<Expr> rets_array(outs.begin(), outs.end());
Array<Expr> rets_array(outs_.begin(), outs_.end());
v->Visit("rets", &rets_array);
}
/*! \brief Get the region ID. */
int GetID() const {
return id;
return id_;
}
/*! \brief Get the region target. */
std::string GetTarget() const {
return target_;
}
/*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const {
return ins;
return ins_;
}
/*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const {
return outs;
return outs_;
}
/*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
return nodes;
return nodes_;
}
static constexpr const char* _type_key = "relay.AnnotatedRegion";
......@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object {
protected:
/*! \brief The region ID. */
int id{-1};
int id_{-1};
/*! \brief The target for this region. */
std::string target_ = "default";
/*! \brief The inputs to this region. */
std::list<Expr> ins;
std::list<Expr> ins_;
/*! \brief The outputs of this region */
std::list<Expr> outs;
std::list<Expr> outs_;
/*! \brief Nodes in this region. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;
friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode;
......@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
void AddToRegion(AnnotatedRegion dest, const Expr& expr);
/*!
* \brief Make a new region.
* \brief Make a new region for a target.
*
* \return The new region.
*/
AnnotatedRegion MakeRegion();
AnnotatedRegion MakeRegion(const std::string& target);
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */
......
......@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));
// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be required for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
// Do nothing
}
void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;
......@@ -103,45 +96,20 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
}
// Analyze the output buffers
std::vector<Type> out_types;
if (call->checked_type()->IsInstance<TupleTypeNode>()) {
auto type_node = call->checked_type().as<TupleTypeNode>();
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}
out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(out_type);
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Update output buffer
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
......@@ -149,6 +117,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
}
std::string JIT(void) {
......
......@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because the memory allocation pass will insert `invoke_tvm_op`
......@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen.
pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
......
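
The hunk above only moves ManifestAlloc, FoldConstant, and FuseOps to run after Inline; none of the passes themselves change. A reduced sketch of the new relative order (ManifestAlloc is driven internally by the VM compiler, so it appears only as a comment here):

import tvm
from tvm.relay import transform

pass_seqs = tvm.transform.Sequential([
    transform.Inline(),        # inline lifted functions first, so the
                               # allocation ops inserted next do not cross
                               # external-codegen function boundaries
    # ManifestAlloc(target_host) would run here inside the VM compiler
    transform.FoldConstant(),  # fold constants introduced by allocation
    transform.FuseOps(),       # fuse the shape functions
])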
......@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator {
IRModule module_;
};
class DefaultRemover : public ExprMutator {
public:
explicit DefaultRemover(const IRModule& module) : module_(module) {}
IRModule Remove() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
Expr VisitExpr_(const CallNode* call) final {
auto attrs = call->attrs.as<CompilerAttrs>();
if (attrs != nullptr && attrs->compiler == "default") {
return VisitExpr(call->args[0]);
}
return ExprMutator::VisitExpr_(call);
}
private:
IRModule module_;
};
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
[=](IRModule m, PassContext pc) {
// TODO(@comaniac, @zhiics): We should also handle annotations with the "default" attribute
// by treating them as un-annotated, but that is not supported yet. This workaround pass
// removes all "default" annotations and should be deleted in the future.
auto new_m = partitioning::DefaultRemover(m).Remove();
return partitioning::Partitioner(new_m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
......
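
In effect, a "default"-annotated expression is treated as if it were never annotated: the workaround strips the begin/end calls before the partitioner runs, so no external function is produced for it. A hedged sketch of the behavior:

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.op.annotation import compiler_begin, compiler_end

x = relay.var("x", shape=(10, 10))
body = compiler_end(relay.abs(compiler_begin(x, "default")), "default")
mod = tvm.IRModule.from_expr(relay.Function([x], body))

# DefaultRemover drops the "default" annotations first, so the
# partitioner sees a plain abs and leaves the module unpartitioned.
mod = transform.PartitionGraph()(mod)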
......@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory);
}
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) {
// FIXME(@comaniac): BN has 3 outputs: out, new_mean, and new_variance, but we do not update
// the other two because nothing consumes them for now. Should update this in the future.
using tag = memory::format_tag;
using dt = memory::data_type;
......
......@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, float* new_mean, float* new_variance,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
int p_e_);
extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_);
......
......@@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
import tvm
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end
def check_region(region_set, args, nodes, rets):
def check_region(region_set, target, args, nodes, rets):
region = region_set.get_region(args[0])
assert region
assert target == region.target
assert set(args) == set(region.args)
assert set(nodes) == set(region.nodes)
assert set(rets) == set(region.rets)
......@@ -51,24 +53,28 @@ def test_region_set_creator_diamond():
assert len(region_set) == 4
check_region(
region_set,
'test_target',
[cb_1],
[cb_1, O_1, ce_1, ce_2],
[ce_1, ce_2],
)
check_region(
region_set,
'test_target',
[cb_2],
[cb_2, O_2, ce_3],
[ce_3],
)
check_region(
region_set,
'default',
[cb_d],
[cb_d, X, ce_d],
[ce_d],
)
check_region(
region_set,
'test_target',
[cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4],
[ce_4],
......@@ -88,7 +94,9 @@ def test_region_set_creator_merged():
cb_3 = compiler_begin(ce_3, 'test_target')
cb_4 = compiler_begin(ce_d, 'test_target')
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, 'test_target')
O_4 = relay.add(cb_3, cb_4)
O_5 = relay.Tuple([O_3, O_4])
ce_4 = compiler_end(O_5, 'test_target')
merged = relay.Function([data], ce_4)
region_set = relay.analysis.AnnotatedRegionSet(merged,
......@@ -97,20 +105,23 @@ def test_region_set_creator_merged():
assert len(region_set) == 3
check_region(
region_set,
'test_target',
[cb_1],
[cb_1, O_1, O_2, ce_2, ce_3],
[ce_2, ce_3],
)
check_region(
region_set,
'default',
[cb_d],
[cb_d, X, ce_d],
[ce_d],
)
check_region(
region_set,
'test_target',
[cb_3, cb_4],
[cb_3, cb_4, O_3, ce_4],
[cb_3, cb_4, O_3, O_4, O_5, ce_4],
[ce_4],
)
......@@ -118,4 +129,3 @@ def test_region_set_creator_merged():
if __name__ == "__main__":
test_region_set_creator_diamond()
test_region_set_creator_merged()
......@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
@reg.register("nn.relu", "target.test")
def relu(attrs, args):
return True
def test_multiple_ends():
@reg.register("nn.relu", "target.test")
def relu(attrs, args): # pylint: disable=unused-variable
return True
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
......@@ -208,10 +207,17 @@ def test_multiple_ends():
r = relay.nn.relu(cb_1)
ce_1 = relay.annotation.compiler_end(r, "test")
ce_2 = relay.annotation.compiler_end(r, "test")
a_1 = relay.abs(ce_1)
a_2 = relay.abs(ce_2)
out = relay.add(a_1, a_2)
f = relay.Function([x], out)
cb_2 = relay.annotation.compiler_begin(ce_1, "default")
cb_3 = relay.annotation.compiler_begin(ce_2, "default")
a_1 = relay.abs(cb_2)
a_2 = relay.abs(cb_3)
ce_3 = relay.annotation.compiler_end(a_1, "default")
ce_4 = relay.annotation.compiler_end(a_2, "default")
cb_4 = relay.annotation.compiler_begin(ce_3, "default")
cb_5 = relay.annotation.compiler_begin(ce_4, "default")
out = relay.add(cb_4, cb_5)
ce_6 = relay.annotation.compiler_end(out, "default")
f = relay.Function([x], ce_6)
mod = tvm.IRModule.from_expr(f)
return mod
......@@ -220,6 +226,72 @@ def test_multiple_ends():
assert tvm.ir.structural_equal(expected, result)
def test_type_propagation():
target = "test_type_propagation"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return args[0].checked_type.dtype == "float32"
def before():
x = relay.var("x", shape=(10, 10))
r = relay.nn.relu(x)
out = relay.nn.relu(r)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
# If the type isn't propagated, then the relu checker function will fail to get the dtype.
assert transform.AnnotateTarget(target)(before())
def test_tuple():
target = "test_tuple"
@reg.register("nn.relu", "target." + target)
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("concatenate", "target." + target)
def concatenate(attrs, args): # pylint: disable=unused-variable
return True
"""Test that TupleNode is included in annotation when surrounded by supported nodes."""
def before():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.nn.relu(y)
out = relay.concatenate((a_1, a_2), axis=1)
f = relay.Function([x, y], out)
mod = tvm.IRModule.from_expr(f)
return mod
def after():
x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5))
cb_1 = relay.annotation.compiler_begin(x, target)
cb_2 = relay.annotation.compiler_begin(y, target)
a_1 = relay.nn.relu(cb_1)
a_2 = relay.nn.relu(cb_2)
ce_1 = relay.annotation.compiler_end(a_1, target)
ce_2 = relay.annotation.compiler_end(a_2, target)
cb_3 = relay.annotation.compiler_begin(ce_1, target)
cb_4 = relay.annotation.compiler_begin(ce_2, target)
tup = relay.Tuple([cb_3, cb_4])
ce_3 = relay.annotation.compiler_end(tup, target)
cb_3 = relay.annotation.compiler_begin(ce_3, target)
out = relay.op._make.concatenate(cb_3, 1)
ce_4 = relay.annotation.compiler_end(out, target)
f = relay.Function([x, y], ce_4)
mod = tvm.IRModule.from_expr(f)
return mod
result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)
def test_composite_function():
def before():
a = relay.var('a', shape=(10, 10))
......@@ -265,8 +337,37 @@ def test_composite_function():
assert tvm.ir.structural_equal(expected, result)
def test_multiple_runs():
@reg.register("nn.relu", "target.A")
def relu(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("add", "target.B")
def add(attrs, args): # pylint: disable=unused-variable
return True
def before():
x = relay.var("x", shape=(10, 5))
a_1 = relay.nn.relu(x)
a_2 = relay.abs(a_1)
a_3 = relay.nn.relu(a_1)
out = relay.add(a_2, a_3)
f = relay.Function([x], out)
mod = tvm.IRModule.from_expr(f)
return mod
mod = transform.AnnotateTarget("A")(before())
mod = transform.AnnotateTarget("B")(mod)
expected = transform.AnnotateTarget(["A", "B"])(before())
assert tvm.ir.structural_equal(expected, mod)
if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl()
#test_extern_dnnl_mobilenet()
test_composite_function()
#test_extern_dnnl_mobilenet()
test_multiple_ends()
test_type_propagation()
test_tuple()
test_multiple_runs()
......@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts():
X = not supported by target
       O         O
      / \       / \
     O   X --> O + + X
      \ /       \ /
       O         O
Note that we can't just merge the three supported operators together,
......@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts():
ce_1 = compiler_end(O_1, "test")
ce_2 = compiler_end(O_1, "test")
cb_2 = compiler_begin(ce_1, "test")
cb_3 = compiler_begin(ce_2, "default")
O_2 = relay.nn.relu(cb_2)
ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2)
cb_3 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(X, "test")
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test")
X = relay.tanh(cb_3)
ce_4 = compiler_end(X, "default")
diamond = relay.Function([data], ce_4)
cb_4 = compiler_begin(ce_3, "test")
cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_4, cb_5)
ce_5 = compiler_end(O_3, "test")
diamond = relay.Function([data], ce_5)
return diamond
def expected():
......@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts():
O_2 = relay.nn.relu(O_1)
ce_3 = compiler_end(O_2, "test")
X = relay.tanh(ce_2)
cb_3 = compiler_begin(ce_2, "default")
X = relay.tanh(cb_3)
ce_4 = compiler_end(X, "default")
cb_3 = compiler_begin(ce_3, "test")
cb_4 = compiler_begin(X, "test")
O_3 = relay.add(cb_3, cb_4)
ce_4 = compiler_end(O_3, "test")
cb_4 = compiler_begin(ce_3, "test")
cb_5 = compiler_begin(ce_4, "test")
O_3 = relay.add(cb_4, cb_5)
ce_5 = compiler_end(O_3, "test")
func = relay.Function([data], ce_4)
func = relay.Function([data], ce_5)
return func
result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
......@@ -85,7 +90,7 @@ def test_example_graph():
"""This tests the merging algorithm on the example used in the RFC.
See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
Blue nodes are adds, red nodes are subtracts.
Blue nodes are adds (target: test), red nodes are subtracts (target: default).
"""
def annotated():
in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
......@@ -112,21 +117,30 @@ def test_example_graph():
node2 = relay.add(begin4, begin5)
end2 = compiler_end(node2, "test")
node3 = relay.subtract(in_5, in_6)
node4 = relay.subtract(in_7, node3)
dbegin0 = compiler_begin(in_5, "default")
dbegin1 = compiler_begin(in_6, "default")
node3 = relay.subtract(dbegin0, dbegin1)
dbegin2 = compiler_begin(in_7, "default")
dend1 = compiler_end(node3, "default")
dbegin3 = compiler_begin(dend1, "default")
node4 = relay.subtract(dbegin2, dbegin3)
dend2 = compiler_end(node4, "default")
begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node4, "test")
begin7 = compiler_begin(dend2, "test")
node5 = relay.add(begin6, begin7)
end3 = compiler_end(node5, "test")
end4 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end3)
dbegin4 = compiler_begin(in_8, "default")
dbegin5 = compiler_begin(end3, "default")
node6 = relay.subtract(dbegin4, dbegin5)
begin8 = compiler_begin(in_9, "test")
begin9 = compiler_begin(end4, "test")
node7 = relay.add(begin8, begin9)
end5 = compiler_end(node7, "test")
begin10 = compiler_begin(node6, "test")
dend3 = compiler_end(node6, "default")
begin10 = compiler_begin(dend3, "test")
begin11 = compiler_begin(end5, "test")
node8 = relay.add(begin10, begin11)
end6 = compiler_end(node8, "test")
......@@ -159,20 +173,27 @@ def test_example_graph():
node1 = relay.add(begin2, begin3)
node2 = relay.add(node0, node1)
node3 = relay.subtract(in_5, in_6)
node4 = relay.subtract(in_7, node3)
dbegin0 = compiler_begin(in_5, "default")
dbegin1 = compiler_begin(in_6, "default")
dbegin2 = compiler_begin(in_7, "default")
node3 = relay.subtract(dbegin0, dbegin1)
node4 = relay.subtract(dbegin2, node3)
dend0 = compiler_end(node4, "default")
begin4 = compiler_begin(node4, "test")
begin4 = compiler_begin(dend0, "test")
begin5 = compiler_begin(in_9, "test")
node5 = relay.add(node2, begin4)
end1 = compiler_end(node5, "test")
node6 = relay.subtract(in_8, end1)
dbegin4 = compiler_begin(end1, "default")
dbegin5 = compiler_begin(in_8, "default")
node6 = relay.subtract(dbegin5, dbegin4)
dend1 = compiler_end(node6, "default")
node7 = relay.add(begin5, node5)
end2 = compiler_end(node7, "test")
begin6 = compiler_begin(end2, "test")
begin7 = compiler_begin(node6, "test")
begin7 = compiler_begin(dend1, "test")
node8 = relay.add(begin7, begin6)
......
......@@ -17,6 +17,7 @@
"""Unit tests for graph partitioning."""
import os
import sys
import numpy as np
import pytest
......@@ -26,8 +27,12 @@ from tvm import relay
from tvm import runtime
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay import transform
from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.runtime import container
# Leverage the pass manager to write a simple white list based annotator
@transform.function_pass(opt_level=0)
......@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
return lib
def check_vm_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
......@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
compile_engine.get().clear()
with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
......@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')
op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = WhiteListAnnotator(op_list, "dnnl")(mod)
mod = transform.AnnotateTarget(["dnnl"])(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
......@@ -851,6 +858,7 @@ if __name__ == "__main__":
test_extern_ccompiler_default_ops()
test_extern_ccompiler()
test_extern_dnnl()
# TODO(@comaniac, @zhiics): Fix constant node and re-open this case.
#test_extern_dnnl_mobilenet()
test_function_lifting()
test_function_lifting_inline()
......