Unverified Commit d2de35eb by mbaret Committed by GitHub

[RELAY][BYOC] Add support for composite functions in BYOC (#5261)

* [RELAY] Add 'check' functions to MergeComposite

Currently, MergeComposite can only perform structural
matches. This patch introduces the ability to specify
a 'check' function alongside the pattern which can include
custom logic to determine whether an extracted pattern
should be merged.

For example, if you only want to merge 'NHWC' convolutions,
you can specify a 'check' function which queries the
data_layout value of the extracted pattern (see the test).

Change-Id: I9337ce39f10997051a286d888be38ed0d410d340

* [RELAY] Reformat merge_composite.cc

Run clang-format on merge_composite.cc

Change-Id: I1736bff798cc6d93e57519b08ab3362869098779

* [RELAY][BYOC] Support composite functions in AnnotateTarget

This patch introduces support to annotate composite functions
in the AnnotateTarget pass. In order for a composite function
to be annotated, you should name it according to the style:

{codegen}.{name}
eg. dnnl.add_relu

Change-Id: I74d6c0b506153d866f6d1feb203b32dad59f2871
parent 53a4ad35
...@@ -378,9 +378,12 @@ def MergeComposite(pattern_table): ...@@ -378,9 +378,12 @@ def MergeComposite(pattern_table):
Parameters Parameters
---------- ----------
pattern_table : list(tuple) pattern_table : list(tuple)
A list of (pattern_name, pattern) tuples. A list of (pattern_name, pattern, check) tuples.
The order of the patterns in the list will determine the order The order of the patterns in the list will determine the order
of priority in which they are matched. of priority in which they are matched.
'check' is a function to check whether an extracted pattern matches.
It can be implemented by pattern writer but if not specified it will
always return True.
Returns Returns
------- -------
...@@ -390,11 +393,19 @@ def MergeComposite(pattern_table): ...@@ -390,11 +393,19 @@ def MergeComposite(pattern_table):
""" """
pattern_names = [] pattern_names = []
patterns = [] patterns = []
for pattern_name, pattern in pattern_table: checks = []
for tup in pattern_table:
if len(tup) == 2:
pattern_name, pattern = tup
check = lambda extract: True
elif len(tup) == 3:
pattern_name, pattern, check = tup
pattern_names.append(pattern_name) pattern_names.append(pattern_name)
patterns.append(pattern) patterns.append(pattern)
checks.append(check)
return _ffi_api.MergeComposite(pattern_names, patterns) return _ffi_api.MergeComposite(pattern_names, patterns, *checks)
def MergeCompilerRegions(): def MergeCompilerRegions():
......
...@@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator {
if (expr->IsInstance<CallNode>()) { if (expr->IsInstance<CallNode>()) {
Call call = Downcast<Call>(expr); Call call = Downcast<Call>(expr);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_); auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op); if (call->op->IsInstance<OpNode>()) {
CHECK(op.defined()); Op op = Downcast<Op>(call->op);
if (fannotate.count(op)) { CHECK(op.defined());
return fannotate[op](call->attrs, call->args); if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
}
} else if (call->op->IsInstance<FunctionNode>()) {
// handle composite functions
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
if (comp_name.defined()) {
size_t i = comp_name->value.find('.');
if (i != std::string::npos) {
std::string target = comp_name->value.substr(0, i);
if (target == target_) return true;
}
}
} }
} }
if (expr->IsInstance<TupleGetItemNode>()) { if (expr->IsInstance<TupleGetItemNode>()) {
...@@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator {
} }
Expr VisitExpr_(const CallNode* cn) { Expr VisitExpr_(const CallNode* cn) {
// TODO(@zhiics, @comaniac) Handle composite functions.
auto new_e = ExprMutator::VisitExpr_(cn); auto new_e = ExprMutator::VisitExpr_(cn);
Call call = Downcast<Call>(new_e); Call call = Downcast<Call>(new_e);
...@@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator {
} }
} }
Expr VisitExpr_(const FunctionNode* op) { Expr VisitExpr_(const FunctionNode* fn) {
auto new_e = ExprMutator::VisitExpr_(op); Function func;
Expr new_body;
// don't step into composite functions
if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
auto new_e = ExprMutator::VisitExpr_(fn);
func = Downcast<Function>(new_e);
new_body = InsertEnd(func->body);
}
auto func = Downcast<Function>(new_e);
return Function( return Function(
func->params, func->params,
InsertEnd(func->body), new_body,
func->ret_type, func->ret_type,
func->type_params, func->type_params,
func->attrs); func->attrs);
......
...@@ -25,11 +25,11 @@ ...@@ -25,11 +25,11 @@
* Relay operators map to a single external operator. * Relay operators map to a single external operator.
*/ */
#include <tvm/te/operation.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/te/operation.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -37,11 +37,12 @@ namespace merge_composite { ...@@ -37,11 +37,12 @@ namespace merge_composite {
class MergeCompositeWrapper : public ExprMutator { class MergeCompositeWrapper : public ExprMutator {
public: public:
explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern) explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern,
: pattern_name_(pattern_name), pattern_(pattern) {} const PackedFunc& check)
: pattern_name_(pattern_name), pattern_(pattern), check_(check) {}
Expr ExtractPattern(const Var& pattern, const Expr& root, Expr ExtractPattern(const Var& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map) { Map<std::string, Array<Expr>>* var_map) {
if (var_map->find(pattern->name_hint()) == var_map->end()) { if (var_map->find(pattern->name_hint()) == var_map->end()) {
// if we haven't encountered this var yet, make a new free var and associate // if we haven't encountered this var yet, make a new free var and associate
// it with the value at 'root' // it with the value at 'root'
...@@ -62,12 +63,12 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -62,12 +63,12 @@ class MergeCompositeWrapper : public ExprMutator {
} }
Expr ExtractPattern(const Constant& pattern, const Expr& root, Expr ExtractPattern(const Constant& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map) { Map<std::string, Array<Expr>>* var_map) {
return root; return root;
} }
Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root, Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) { Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
if (!root->IsInstance<TupleGetItemNode>()) { if (!root->IsInstance<TupleGetItemNode>()) {
return Expr(); return Expr();
} }
...@@ -75,14 +76,12 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -75,14 +76,12 @@ class MergeCompositeWrapper : public ExprMutator {
if (pattern->index != root_node->index) { if (pattern->index != root_node->index) {
return Expr(); return Expr();
} }
if (pattern->tuple->IsInstance<CallNode>() && if (pattern->tuple->IsInstance<CallNode>() && root_node->tuple->IsInstance<CallNode>()) {
root_node->tuple->IsInstance<CallNode>()) {
Expr new_arg; Expr new_arg;
if (call_map->find(pattern->tuple) != call_map->end()) { if (call_map->find(pattern->tuple) != call_map->end()) {
new_arg = (*call_map)[pattern->tuple]; new_arg = (*call_map)[pattern->tuple];
} else { } else {
new_arg = ExtractPattern(Downcast<Call>(pattern->tuple), new_arg = ExtractPattern(Downcast<Call>(pattern->tuple), Downcast<Call>(root_node->tuple),
Downcast<Call>(root_node->tuple),
var_map, call_map); var_map, call_map);
call_map->Set(pattern->tuple, new_arg); call_map->Set(pattern->tuple, new_arg);
} }
...@@ -104,20 +103,18 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -104,20 +103,18 @@ class MergeCompositeWrapper : public ExprMutator {
* and free variables. The free variables indicate where the pattern can 'attach' in your * and free variables. The free variables indicate where the pattern can 'attach' in your
* graph. This function takes the final call node of the pattern and the call node currently * graph. This function takes the final call node of the pattern and the call node currently
* being traversed in the Relay graph. It traverses through the pattern in lockstep with call node * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
* from the graph (referred to as the 'root' node here) to check they're identical. If at any point * from the graph (referred to as the 'root' node here) to check they're identical. If at any
* they differ, an empty expression is returned to signify the extract failed. If a free var is * point they differ, an empty expression is returned to signify the extract failed. If a free var
* reached in the pattern, the corresponding value in the root is associated with the name of the * is reached in the pattern, the corresponding value in the root is associated with the name of
* free var (via the var_map) so that when we construct the composite function, the inputs match * the free var (via the var_map) so that when we construct the composite function, the inputs
* up correctly with the rest of the graph. The return value of this function when successful is * match up correctly with the rest of the graph. The return value of this function when
* a new Relay expression ready to be wrapped into a composite function. * successful is a new Relay expression ready to be wrapped into a composite function.
*/ */
Expr ExtractPattern(const Call& pattern, const Call& root, Expr ExtractPattern(const Call& pattern, const Call& root, Map<std::string, Array<Expr>>* var_map,
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) { Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions) // check to make sure both calls are to operators (not functions)
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>()) if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>()) return Expr();
return Expr(); if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name) return Expr();
if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
return Expr();
unsigned int i = 0; unsigned int i = 0;
Array<Expr> new_args; Array<Expr> new_args;
...@@ -133,27 +130,20 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -133,27 +130,20 @@ class MergeCompositeWrapper : public ExprMutator {
return Expr(); return Expr();
} }
// if it's a call node, recursively call this function // if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg), new_arg =
Downcast<Call>(root->args[i]), ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]), var_map, call_map);
var_map, call_map);
call_map->Set(arg, new_arg); call_map->Set(arg, new_arg);
} }
} else if (arg->IsInstance<VarNode>()) { } else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var // if there's a var in the pattern, it must be a free var
// so call the function to update the var_map // so call the function to update the var_map
new_arg = ExtractPattern(Downcast<Var>(arg), new_arg = ExtractPattern(Downcast<Var>(arg), root->args[i], var_map);
root->args[i],
var_map);
} else if (arg->IsInstance<ConstantNode>()) { } else if (arg->IsInstance<ConstantNode>()) {
// if there's a constant, simply get the corresponding // if there's a constant, simply get the corresponding
// value of the constant from the root // value of the constant from the root
new_arg = ExtractPattern(Downcast<Constant>(arg), new_arg = ExtractPattern(Downcast<Constant>(arg), root->args[i], var_map);
root->args[i],
var_map);
} else if (arg->IsInstance<TupleGetItemNode>()) { } else if (arg->IsInstance<TupleGetItemNode>()) {
new_arg = ExtractPattern(Downcast<TupleGetItem>(arg), new_arg = ExtractPattern(Downcast<TupleGetItem>(arg), root->args[i], var_map, call_map);
root->args[i],
var_map, call_map);
} }
if (!new_arg.defined()) { if (!new_arg.defined()) {
return Expr(); return Expr();
...@@ -169,8 +159,7 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -169,8 +159,7 @@ class MergeCompositeWrapper : public ExprMutator {
if (call->op->IsInstance<FunctionNode>()) { if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op); Function func = Downcast<Function>(call->op);
CHECK(func.defined()); CHECK(func.defined());
const auto name_node = const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
func->GetAttr<tir::StringImm>(attr::kComposite);
// don't step into existing composite functions // don't step into existing composite functions
if (name_node.defined() && name_node->value != "") { if (name_node.defined() && name_node->value != "") {
tvm::Array<tvm::relay::Expr> new_args; tvm::Array<tvm::relay::Expr> new_args;
...@@ -184,8 +173,7 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -184,8 +173,7 @@ class MergeCompositeWrapper : public ExprMutator {
Expr expr = ExprMutator::VisitExpr_(cn); Expr expr = ExprMutator::VisitExpr_(cn);
call = Downcast<Call>(expr); call = Downcast<Call>(expr);
if (!call->op->IsInstance<OpNode>()) if (!call->op->IsInstance<OpNode>()) return std::move(call);
return std::move(call);
// only call patterns are supported // only call patterns are supported
Call pattern = Downcast<Call>(pattern_); Call pattern = Downcast<Call>(pattern_);
...@@ -193,7 +181,7 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -193,7 +181,7 @@ class MergeCompositeWrapper : public ExprMutator {
Map<std::string, Array<Expr>> args_map; Map<std::string, Array<Expr>> args_map;
Map<Expr, Expr> call_map; Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map); auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) { if (extract.defined() && static_cast<bool>(check_(extract))) {
auto free_vars = FreeVars(extract); auto free_vars = FreeVars(extract);
// make the composite function // make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs()); auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
...@@ -215,17 +203,20 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -215,17 +203,20 @@ class MergeCompositeWrapper : public ExprMutator {
std::string pattern_name_; std::string pattern_name_;
/*! \brief The pattern to match */ /*! \brief The pattern to match */
Expr pattern_; Expr pattern_;
/*! \brief The function to check whether an extract is supported */
PackedFunc check_;
}; };
Expr MergeComposite(const Expr& expr, Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) { const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size()); CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr; Expr merged_expr = expr;
// merge the patterns one-by-one in order // merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) { for (size_t i = 0; i < patterns.size(); i++) {
std::string pattern_name = pattern_names[i]->value; std::string pattern_name = pattern_names[i]->value;
Expr pattern = patterns[i]; Expr pattern = patterns[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr); PackedFunc check = checks[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
} }
return merged_expr; return merged_expr;
} }
...@@ -235,18 +226,25 @@ Expr MergeComposite(const Expr& expr, ...@@ -235,18 +226,25 @@ Expr MergeComposite(const Expr& expr,
namespace transform { namespace transform {
Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names, Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
const tvm::Array<Expr>& patterns) { const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>( return Downcast<Function>(
relay::merge_composite::MergeComposite(f, pattern_names, patterns)); relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks));
}; };
auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {}); auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
return func_pass; return func_pass;
} }
TVM_REGISTER_GLOBAL("relay._transform.MergeComposite") TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
.set_body_typed(MergeComposite); tvm::Array<tir::StringImm> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
std::vector<PackedFunc> checks;
for (int i = 2; i < args.size(); i++) {
checks.push_back(args[i]);
}
*rv = MergeComposite(pattern_names, patterns, checks);
});
} // namespace transform } // namespace transform
......
...@@ -219,7 +219,53 @@ def test_multiple_ends(): ...@@ -219,7 +219,53 @@ def test_multiple_ends():
assert tvm.ir.structural_equal(expected, result) assert tvm.ir.structural_equal(expected, result)
def test_composite_function():
def before():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
# add_relu function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
# merged function
r = relay.Call(add_relu, [a, b])
f = relay.Function([a, b], r)
mod = tvm.IRModule.from_expr(f)
return mod
def after():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
# add_relu function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
# merged function
cb_1 = relay.annotation.compiler_begin(a, "test")
cb_2 = relay.annotation.compiler_begin(b, "test")
r = relay.Call(add_relu, [cb_1, cb_2])
ce_1 = relay.annotation.compiler_end(r, "test")
f = relay.Function([a, b], ce_1)
mod = tvm.IRModule.from_expr(f)
return mod
result = transform.AnnotateTarget("test")(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)
if __name__ == "__main__": if __name__ == "__main__":
test_multiple_ends() test_multiple_ends()
test_extern_dnnl() test_extern_dnnl()
test_extern_dnnl_mobilenet() test_extern_dnnl_mobilenet()
test_composite_function()
...@@ -732,6 +732,43 @@ def test_tuple_get_item_merge(): ...@@ -732,6 +732,43 @@ def test_tuple_get_item_merge():
assert tvm.ir.structural_equal(result, expected, map_free_vars=True) assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_pattern_with_check():
def before():
x = relay.var('x', shape=(1, 10, 10, 10))
w = relay.var('w', shape=(10, 10, 3, 3))
b = relay.var('b', shape=(8,))
conv = relay.nn.conv2d(x,
w,
kernel_size=(3, 3),
kernel_layout="OIHW",
data_layout="NHWC")
bias = relay.nn.bias_add(conv, b)
relu = relay.nn.relu(bias)
return relay.Function([x, w, b], relu)
def _check_true(extract):
conv = extract.args[0].args[0]
return conv.attrs.data_layout == "NHWC"
def _check_false(extract):
conv = extract.args[0].args[0]
return conv.attrs.data_layout == "NCHW"
pattern_table_true = [
("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)
]
pattern_table_false = [
("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)
]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false))
expected = run_opt_pass(before(), relay.transform.InferType())
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true))
assert result.body.op.attrs["Composite"] == "conv_bias_relu"
if __name__ == "__main__": if __name__ == "__main__":
test_simple_merge() test_simple_merge()
test_branch_merge() test_branch_merge()
...@@ -741,3 +778,4 @@ if __name__ == "__main__": ...@@ -741,3 +778,4 @@ if __name__ == "__main__":
test_multiple_input_subgraphs() test_multiple_input_subgraphs()
test_reuse_call_merge() test_reuse_call_merge()
test_tuple_get_item_merge() test_tuple_get_item_merge()
test_pattern_with_check()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment