Commit d2de35eb by mbaret, committed by GitHub

[RELAY][BYOC] Add support for composite functions in BYOC (#5261)

* [RELAY] Add 'check' functions to MergeComposite

Currently, MergeComposite can only perform structural
matches. This patch introduces the ability to specify
a 'check' function alongside the pattern, which can include
custom logic to determine whether an extracted pattern
should be merged.

For example, if you only want to merge 'NHWC' convolutions,
you can specify a 'check' function which queries the
data_layout value of the extracted pattern (see the test).
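
For illustration, a minimal sketch of a pattern table entry with a
check (check_nhwc is a hypothetical name; make_conv_bias_relu_pattern
is the helper used in the new test and is assumed to be defined as it
is there, and 'mod' is an existing tvm.IRModule):

    from tvm import relay

    def check_nhwc(extract):
        # 'extract' is the matched subgraph rooted at the relu;
        # walk down to the conv2d and inspect its layout
        conv = extract.args[0].args[0]
        return conv.attrs.data_layout == "NHWC"

    pattern_table = [
        ("conv_bias_relu", make_conv_bias_relu_pattern(), check_nhwc),
    ]
    mod = relay.transform.MergeComposite(pattern_table)(mod)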

Change-Id: I9337ce39f10997051a286d888be38ed0d410d340

* [RELAY] Reformat merge_composite.cc

Run clang-format on merge_composite.cc

Change-Id: I1736bff798cc6d93e57519b08ab3362869098779

* [RELAY][BYOC] Support composite functions in AnnotateTarget

This patch introduces support for annotating composite functions
in the AnnotateTarget pass. In order for a composite function
to be annotated, you should name it according to the style:

{codegen}.{name}
e.g. dnnl.add_relu
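
A minimal sketch of tagging a composite function so this pass picks
it up (mirrors the new test; 'dnnl' is just an example codegen name,
and 'mod' is an existing tvm.IRModule):

    import tvm
    from tvm import relay

    in_1 = relay.var('in_1', shape=(10, 10))
    in_2 = relay.var('in_2', shape=(10, 10))
    add_relu = relay.Function([in_1, in_2], relay.nn.relu(relay.add(in_1, in_2)))
    add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("dnnl.add_relu"))
    # calls to add_relu will now be annotated by AnnotateTarget("dnnl")
    mod = relay.transform.AnnotateTarget("dnnl")(mod)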

Change-Id: I74d6c0b506153d866f6d1feb203b32dad59f2871
parent 53a4ad35
@@ -378,9 +378,12 @@ def MergeComposite(pattern_table):
     Parameters
     ----------
     pattern_table : list(tuple)
-        A list of (pattern_name, pattern) tuples.
+        A list of (pattern_name, pattern, check) tuples.
         The order of the patterns in the list will determine the order
         of priority in which they are matched.
+        'check' is a function to check whether an extracted pattern matches.
+        It can be implemented by the pattern writer; if not specified, it
+        always returns True.

     Returns
     -------
@@ -390,11 +393,19 @@ def MergeComposite(pattern_table):
     """
     pattern_names = []
     patterns = []
-    for pattern_name, pattern in pattern_table:
+    checks = []
+    for tup in pattern_table:
+        if len(tup) == 2:
+            pattern_name, pattern = tup
+            check = lambda extract: True
+        elif len(tup) == 3:
+            pattern_name, pattern, check = tup
+
         pattern_names.append(pattern_name)
         patterns.append(pattern)
+        checks.append(check)

-    return _ffi_api.MergeComposite(pattern_names, patterns)
+    return _ffi_api.MergeComposite(pattern_names, patterns, *checks)


 def MergeCompilerRegions():
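Note that existing two-tuple pattern tables remain valid: as the loop
above shows, an entry without a 'check' falls back to a lambda that
always returns True. A sketch (reusing the pattern helper from the
tests below):

    # no check supplied; the pattern merges on any structural match
    pattern_table = [
        ("conv_bias_relu", make_conv_bias_relu_pattern()),
    ]
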
@@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator {
     if (expr->IsInstance<CallNode>()) {
       Call call = Downcast<Call>(expr);
       auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
-      Op op = Downcast<Op>(call->op);
-      CHECK(op.defined());
-      if (fannotate.count(op)) {
-        return fannotate[op](call->attrs, call->args);
-      }
+      if (call->op->IsInstance<OpNode>()) {
+        Op op = Downcast<Op>(call->op);
+        CHECK(op.defined());
+        if (fannotate.count(op)) {
+          return fannotate[op](call->attrs, call->args);
+        }
+      } else if (call->op->IsInstance<FunctionNode>()) {
+        // handle composite functions
+        Function func = Downcast<Function>(call->op);
+        CHECK(func.defined());
+        auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
+        if (comp_name.defined()) {
+          size_t i = comp_name->value.find('.');
+          if (i != std::string::npos) {
+            std::string target = comp_name->value.substr(0, i);
+            if (target == target_) return true;
+          }
+        }
+      }
     }
     if (expr->IsInstance<TupleGetItemNode>()) {
@@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator {
   }

   Expr VisitExpr_(const CallNode* cn) {
-    // TODO(@zhiics, @comaniac) Handle composite functions.
     auto new_e = ExprMutator::VisitExpr_(cn);
     Call call = Downcast<Call>(new_e);
@@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator {
     }
   }

-  Expr VisitExpr_(const FunctionNode* op) {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto func = Downcast<Function>(new_e);
+  Expr VisitExpr_(const FunctionNode* fn) {
+    Function func;
+    Expr new_body;
+    // don't step into composite functions
+    if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
+      func = GetRef<Function>(fn);
+      new_body = func->body;
+    } else {
+      auto new_e = ExprMutator::VisitExpr_(fn);
+      func = Downcast<Function>(new_e);
+      new_body = InsertEnd(func->body);
+    }
     return Function(
       func->params,
-      InsertEnd(func->body),
+      new_body,
       func->ret_type,
       func->type_params,
       func->attrs);
   }
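Taken together, the two passes now compose for BYOC: MergeComposite
wraps a matched pattern in a composite function named after the
pattern, and AnnotateTarget hands that function to the codegen whose
name matches the prefix before the '.'. A hedged end-to-end sketch
(make_add_relu_pattern is hypothetical; the "test" target matches the
tests below):

    pattern_table = [("test.add_relu", make_add_relu_pattern())]
    mod = relay.transform.MergeComposite(pattern_table)(mod)
    mod = relay.transform.AnnotateTarget("test")(mod)
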
@@ -219,7 +219,53 @@ def test_multiple_ends():
     assert tvm.ir.structural_equal(expected, result)


+def test_composite_function():
+    def before():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+        add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+
+        # merged function
+        r = relay.Call(add_relu, [a, b])
+        f = relay.Function([a, b], r)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    def after():
+        a = relay.var('a', shape=(10, 10))
+        b = relay.var('b', shape=(10, 10))
+
+        # add_relu function
+        in_1 = relay.var('in_1', shape=(10, 10))
+        in_2 = relay.var('in_2', shape=(10, 10))
+        add_node = relay.add(in_1, in_2)
+        relu_node = relay.nn.relu(add_node)
+        add_relu = relay.Function([in_1, in_2], relu_node)
+        add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+
+        # annotated call to the merged function
+        cb_1 = relay.annotation.compiler_begin(a, "test")
+        cb_2 = relay.annotation.compiler_begin(b, "test")
+        r = relay.Call(add_relu, [cb_1, cb_2])
+        ce_1 = relay.annotation.compiler_end(r, "test")
+        f = relay.Function([a, b], ce_1)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    result = transform.AnnotateTarget("test")(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
 if __name__ == "__main__":
     test_multiple_ends()
     test_extern_dnnl()
     test_extern_dnnl_mobilenet()
+    test_composite_function()
@@ -732,6 +732,43 @@ def test_tuple_get_item_merge():
     assert tvm.ir.structural_equal(result, expected, map_free_vars=True)


+def test_pattern_with_check():
+    def before():
+        x = relay.var('x', shape=(1, 10, 10, 10))
+        w = relay.var('w', shape=(10, 10, 3, 3))
+        b = relay.var('b', shape=(8,))
+        conv = relay.nn.conv2d(x,
+                               w,
+                               kernel_size=(3, 3),
+                               kernel_layout="OIHW",
+                               data_layout="NHWC")
+        bias = relay.nn.bias_add(conv, b)
+        relu = relay.nn.relu(bias)
+        return relay.Function([x, w, b], relu)
+
+    def _check_true(extract):
+        conv = extract.args[0].args[0]
+        return conv.attrs.data_layout == "NHWC"
+
+    def _check_false(extract):
+        conv = extract.args[0].args[0]
+        return conv.attrs.data_layout == "NCHW"
+
+    pattern_table_true = [
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)
+    ]
+    pattern_table_false = [
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)
+    ]
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false))
+    expected = run_opt_pass(before(), relay.transform.InferType())
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true))
+    assert result.body.op.attrs["Composite"] == "conv_bias_relu"
+
+
 if __name__ == "__main__":
     test_simple_merge()
     test_branch_merge()
@@ -741,3 +778,4 @@ if __name__ == "__main__":
     test_multiple_input_subgraphs()
     test_reuse_call_merge()
     test_tuple_get_item_merge()
+    test_pattern_with_check()