Commit f23a7a54 by Wuwei Lin Committed by Tianqi Chen

[RELAY] Stop_fusion annotation (#2624)

parent 8b1d07ff
......@@ -29,3 +29,19 @@ def on_device(data, device):
raise ValueError("device is expected to be the type of TVMContext or "
"str, but received %s" % (type(device)))
return _make.on_device(data, device)
def stop_fusion(data):
    """Mark an expression so it is not fused with the expressions before it.

    Parameters
    ----------
    data : tvm.relay.Expr
        Expression that should receive the annotation.

    Returns
    -------
    result : tvm.relay.Expr
        The annotated expression.
    """
    annotated = _make.stop_fusion(data)
    return annotated
......@@ -9,6 +9,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
......@@ -37,6 +38,31 @@ RELAY_REGISTER_OP("on_device")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout);
/*!
 * \brief Build a call to the "annotation.stop_fusion" operator.
 * \param data The expression passed as the single call argument.
 * \return A call node applying annotation.stop_fusion to data, with default
 *         attributes and no type arguments.
 */
Expr StopFusion(Expr data) {
  // Function-local static: the operator lookup happens once, on first use.
  static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");
  return CallNode::make(stop_fusion_op, {data}, Attrs{}, {});
}
// Registers StopFusion under the global name used by the Python-side
// `_make.stop_fusion` call.
TVM_REGISTER_API("relay.op.annotation._make.stop_fusion")
.set_body_typed<Expr(Expr)>([](Expr expr) {
  return StopFusion(expr);
});
// Operator registration for "annotation.stop_fusion": a single-input operator
// typed through IdentityRel and computed with topi::identity.
RELAY_REGISTER_OP("annotation.stop_fusion")
.describe(
  R"code(Annotate an expression to prevent it being fused with previous expressions.)code"
  TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
.add_type_rel("Identity", IdentityRel)
.set_support_level(10)
// Registered with the kOpaque pattern and flagged as not stateful.
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>(
  "FTVMCompute",
  [](const Attrs& /*attrs*/, const Array<Tensor>& args,
     const Type& /*out_dtype*/, const Target& /*target*/) -> Array<Tensor> {
    // Only the first input is read; the other parameters are unused here.
    return {topi::identity(args[0])};
  });
} // namespace relay
} // namespace tvm
......@@ -741,10 +741,14 @@ class FuseMutator : private ExprMutator {
}
// Transform calls.
Expr VisitExpr_(const CallNode* call) {
static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
if (call->op.as<OpNode>()) {
// If it is a primitive op call
// then we must have a group assignment for it already.
CHECK(gmap_.count(call));
if (call->op.same_as(stop_fusion)) {
return ExprMutator::VisitExpr(call->args[0]);
}
auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> new_args = GetNewArguments(call->args, ret_group);
......
......@@ -329,6 +329,8 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
/*! \brief Wrap data in a call to the "annotation.stop_fusion" operator. */
Expr StopFusion(Expr data);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
......@@ -220,9 +220,41 @@ def test_tuple_strided_slice():
print(zz.astext())
def test_stop_fusion():
    """Check the result of fuse_ops on a graph containing stop_fusion.

    The input graph is add -> stop_fusion -> exp. The reference graph holds
    two separate functions (one for the add, one for the exp) called in
    sequence, with no stop_fusion call in it.
    """
    def before(dshape):
        data = relay.var("x", shape=dshape)
        plus_one = relay.add(data, relay.const(1, "float32"))
        fenced = relay.annotation.stop_fusion(plus_one)
        return relay.Function([data], relay.exp(fenced))

    def expected(dshape):
        # Function wrapping the add that precedes the annotation.
        p0 = relay.var("p0", shape=dshape)
        add_fn = relay.Function([p0], relay.add(p0, relay.const(1, "float32")))

        # Function wrapping the exp that follows the annotation.
        p01 = relay.var("p01", shape=dshape)
        exp_fn = relay.Function([p01], relay.exp(p01))

        # Outer function chaining the two calls.
        data = relay.var("x", shape=dshape)
        add_out = relay.Call(add_fn, [data])
        return relay.Function([data], relay.Call(exp_fn, [add_out]))

    dshape = (10, 20)
    fused = relay.ir_pass.infer_type(before(dshape))
    fused = relay.ir_pass.fuse_ops(fused)
    fused = relay.ir_pass.infer_type(fused)
    reference = relay.ir_pass.infer_type(expected(dshape))
    assert relay.ir_pass.alpha_equal(fused, reference)
if __name__ == "__main__":
    # Run the listed tests in order when the file is executed as a script.
    for run_case in (
            test_fuse_simple,
            test_conv2d_fuse,
            test_concatenate,
            test_tuple_root,
            test_tuple_strided_slice,
            test_stop_fusion,
    ):
        run_case()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment