Commit cf3f5bce by Wuwei Lin Committed by Tianqi Chen

[RELAY][PASS] Memoize FoldScaleAxis backward transform result (#2214)

parent 1a9df7be
...@@ -556,9 +556,7 @@ class BackwardTransformerNode : ...@@ -556,9 +556,7 @@ class BackwardTransformerNode :
* \return The result of transformation. * \return The result of transformation.
*/ */
Expr Transform(const Expr& expr, AxesSet axes, Expr scale) { Expr Transform(const Expr& expr, AxesSet axes, Expr scale) {
// NOTE: the result of Transform is not memoized. // NOTE: the result of Transform is memoized.
// However, in the current rule, Transform will
// only be called to expr that is referred once.
if (const CallNode* call_node = expr.as<CallNode>()) { if (const CallNode* call_node = expr.as<CallNode>()) {
return Transform(call_node, axes, scale); return Transform(call_node, axes, scale);
} else { } else {
...@@ -572,7 +570,14 @@ class BackwardTransformerNode : ...@@ -572,7 +570,14 @@ class BackwardTransformerNode :
* \return the result of the call Mutation. * \return the result of the call Mutation.
*/ */
/*!
 * \brief Run the normal (non-rewriting) mutation on a call node,
 *        memoizing the result so a node consumed by multiple
 *        consumers is only transformed once.
 * \param call_node The call node to mutate.
 * \return the result of the call Mutation.
 */
Expr NormalCallTransform(const CallNode* call_node) {
  const Call call = GetRef<Call>(call_node);
  // Return the cached result if this call was already transformed;
  // without this, a node with two consumers would be duplicated.
  const auto it = memo_.find(call);
  if (it != memo_.end()) {
    return it->second;
  }
  Expr new_expr = ExprMutator::VisitExpr_(call_node);
  memo_[call] = new_expr;
  return new_expr;
}
/*! /*!
* \brief Get the expected axes on expr. * \brief Get the expected axes on expr.
...@@ -620,10 +625,17 @@ Expr BackwardTransformerNode::Transform( ...@@ -620,10 +625,17 @@ Expr BackwardTransformerNode::Transform(
Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform"); Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr); auto f = ftransform.get(call_node->op, nullptr);
if (f != nullptr) { if (f != nullptr) {
return f(GetRef<Call>(call_node), const Call call = GetRef<Call>(call_node);
axes, const auto it = memo_.find(call);
scale, if (it != memo_.end()) {
GetRef<BackwardTransformer>(this)); return it->second;
}
Expr new_expr = f(GetRef<Call>(call_node),
axes,
scale,
GetRef<BackwardTransformer>(this));
memo_[call] = new_expr;
return new_expr;
} else { } else {
CHECK(!axes.defined()) << "outstanding scale"; CHECK(!axes.defined()) << "outstanding scale";
return NormalCallTransform(call_node); return NormalCallTransform(call_node);
......
...@@ -268,6 +268,81 @@ def test_fold_bwd_dual_path(): ...@@ -268,6 +268,81 @@ def test_fold_bwd_dual_path():
check((2, 4, 10, 10), 8) check((2, 4, 10, 10), 8)
def test_fold_bwd_dual_consumer():
    """Backward fold when one scaled conv output (y0) feeds two consumers.

    Exercises the memoization added to the backward FoldScaleAxis pass:
    y0 is consumed by both y1 and y2, so its transform result must be
    reused rather than recomputed/duplicated.
    """
    def before(x, conv_weight, out_bias, out_scale, channels):
        # y0 is multiplied by out_scale and then consumed twice (y1, y2).
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        y0 = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y0 = relay.multiply(y0, out_scale)
        y0 = relay.nn.relu(y0)
        y1 = relay.nn.conv2d(y0, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.multiply(y1, out_scale)
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(y0, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y2 = relay.multiply(y2, out_scale)
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def expected(x, conv_weight, out_bias, out_scale, channels):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        def fold_conv_weight():
            # The scale is folded into the conv weight along the output
            # channel axis, so the explicit multiply disappears.
            squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
            return relay.multiply(
                conv_weight ,
                relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
        y0 = relay.nn.conv2d(x, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y0 = relay.nn.relu(y0)
        y1 = relay.nn.conv2d(y0, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(y0, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def check(shape, channels):
        x = relay.var("x", shape=shape)
        in_channels = shape[1]
        weight = relay.var("weight")
        out_bias = relay.var("out_bias", shape=(channels,))
        out_scale = relay.var("out_scale", shape=(channels,))

        y1 = before(x, weight, out_bias, out_scale, channels)
        y1 = relay.ir_pass.infer_type(y1)
        type_dict = {x.name_hint:x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
        y1_expected = expected(x, weight, out_bias, out_scale, channels)

        y1_folded = relay.ir_pass.infer_type(y1_folded)
        y1_expected = relay.ir_pass.infer_type(y1_expected)
        # Folded program must be alpha-equivalent to the hand-written one.
        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 4)
def test_fold_bwd_fail(): def test_fold_bwd_fail():
"""Dual path testcase.""" """Dual path testcase."""
def fail1(x, conv_weight, out_bias, out_scale, channels): def fail1(x, conv_weight, out_bias, out_scale, channels):
...@@ -327,4 +402,5 @@ if __name__ == "__main__": ...@@ -327,4 +402,5 @@ if __name__ == "__main__":
test_fold_fwd_fail() test_fold_fwd_fail()
test_fold_bwd_simple() test_fold_bwd_simple()
test_fold_bwd_dual_path() test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer()
test_fold_bwd_fail() test_fold_bwd_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment