Commit cf3f5bce by Wuwei Lin, committed by Tianqi Chen

[RELAY][PASS] Memoize FoldScaleAxis backward transform result (#2214)

parent 1a9df7be
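
This commit caches the result of the backward scale-axis transform per call node, so a sub-expression shared by several consumers is rewritten once and every later visit returns the cached node. Below is a minimal sketch of that memoization pattern in plain Python; the names and the stand-in transform body are illustrative only, not the Relay C++ API:

    memo = {}
    rewrites = 0

    def transform(node):
        # Look up the node first; later consumers reuse the cached result.
        global rewrites
        if node in memo:
            return memo[node]
        rewrites += 1                  # counts how often the real rewrite runs
        result = ("transformed", node)
        memo[node] = result            # cache before returning
        return result

    assert transform("y0") is transform("y0")  # second call is a cache hit
    assert rewrites == 1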
@@ -556,9 +556,7 @@ class BackwardTransformerNode :
   * \return The result of transformation.
   */
  Expr Transform(const Expr& expr, AxesSet axes, Expr scale) {
-    // NOTE: the result of Transform is not memoized.
-    // However, in the current rule, Transform will
-    // only be called to expr that is referred once.
+    // NOTE: the result of Transform is memoized.
    if (const CallNode* call_node = expr.as<CallNode>()) {
      return Transform(call_node, axes, scale);
    } else {
@@ -572,7 +570,14 @@
   * \return the result of the call Mutation.
   */
  Expr NormalCallTransform(const CallNode* call_node) {
-    return ExprMutator::VisitExpr_(call_node);
+    const Call call = GetRef<Call>(call_node);
+    const auto it = memo_.find(call);
+    if (it != memo_.end()) {
+      return it->second;
+    }
+    Expr new_expr = ExprMutator::VisitExpr_(call_node);
+    memo_[call] = new_expr;
+    return new_expr;
  }
  /*!
   * \brief Get the expected axes on expr.
@@ -620,10 +625,17 @@ Expr BackwardTransformerNode::Transform(
      Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
  auto f = ftransform.get(call_node->op, nullptr);
  if (f != nullptr) {
-    return f(GetRef<Call>(call_node),
-             axes,
-             scale,
-             GetRef<BackwardTransformer>(this));
+    const Call call = GetRef<Call>(call_node);
+    const auto it = memo_.find(call);
+    if (it != memo_.end()) {
+      return it->second;
+    }
+    Expr new_expr = f(GetRef<Call>(call_node),
+                      axes,
+                      scale,
+                      GetRef<BackwardTransformer>(this));
+    memo_[call] = new_expr;
+    return new_expr;
  } else {
    CHECK(!axes.defined()) << "outstanding scale";
    return NormalCallTransform(call_node);
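
The new test below builds exactly the case the cache exists for: y0 is consumed by two convolutions, so the backward transformer reaches the same call node along two paths. A hedged illustration in plain Python (hypothetical node names mirroring the test's y0/y1/y2 structure, not Relay itself) of how memoization keeps the visit count linear:

    # y0's subtree is shared by y1 and y2, as in test_fold_bwd_dual_consumer.
    children = {"y": ["y1", "y2"], "y1": ["y0"], "y2": ["y0"],
                "y0": ["x"], "x": []}

    def visits_uncached(node):
        return 1 + sum(visits_uncached(c) for c in children[node])

    def visits_memoized(node, memo):
        if node in memo:
            return 0                      # cache hit: subtree not rewritten again
        memo[node] = True
        return 1 + sum(visits_memoized(c, memo) for c in children[node])

    assert visits_uncached("y") == 7      # the shared y0 -> x chain is walked twice
    assert visits_memoized("y", {}) == 5  # each node is rewritten exactly once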
@@ -268,6 +268,81 @@ def test_fold_bwd_dual_path():
    check((2, 4, 10, 10), 8)


+def test_fold_bwd_dual_consumer():
+    def before(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        y0 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y0 = relay.multiply(y0, out_scale)
+        y0 = relay.nn.relu(y0)
+
+        y1 = relay.nn.conv2d(y0, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.multiply(y1, out_scale)
+        y1 = relay.nn.relu(y1)
+
+        y2 = relay.nn.conv2d(y0, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.multiply(y2, out_scale)
+        y2 = relay.nn.relu(y2)
+
+        y = relay.add(y1, y2)
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, out_bias, out_scale, channels):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        def fold_conv_weight():
+            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
+            return relay.multiply(
+                conv_weight,
+                relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+        y0 = relay.nn.conv2d(x, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y0 = relay.nn.relu(y0)
+        y1 = relay.nn.conv2d(y0, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.nn.relu(y1)
+        y2 = relay.nn.conv2d(y0, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.nn.relu(y2)
+        y = relay.add(y1, y2)
+        return relay.Function(args, y)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        out_scale = relay.var("out_scale", shape=(channels,))
+        y1 = before(x, weight, out_bias, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        y1_expected = expected(x, weight, out_bias, out_scale, channels)
+
+        y1_folded = relay.ir_pass.infer_type(y1_folded)
+        y1_expected = relay.ir_pass.infer_type(y1_expected)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10), 4)
+
+
def test_fold_bwd_fail():
    """Dual path testcase."""
    def fail1(x, conv_weight, out_bias, out_scale, channels):
@@ -327,4 +402,5 @@ if __name__ == "__main__":
    test_fold_fwd_fail()
    test_fold_bwd_simple()
    test_fold_bwd_dual_path()
+    test_fold_bwd_dual_consumer()
    test_fold_bwd_fail()