Commit 2a871f35 by Wuwei Lin Committed by ziheng

[RELAY][PASS] Support Negative Scale in FoldScaleAxis (#2426)

* [RELAY][PASS] Support Negative Scale in FoldScaleAxis

* Fix comment
parent 52e3bd32
......@@ -174,7 +174,6 @@ def test_fold_fwd_relu_fail():
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
......@@ -182,10 +181,52 @@ def test_fold_fwd_relu_fail():
in_scale = relay.var("in_scale", shape=(4,))
check((2, 11, 10, 4), 4, in_scale)
in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0)).astype("float32")
in_scale = relay.const(-_get_positive_scale((4,)))
check((2, 11, 10, 4), 4, in_scale)
def test_fold_fwd_negative_scale():
"""Testcase of folding negative scale"""
def before(x, conv_weight, in_scale, channels):
args = [x, conv_weight]
x = relay.multiply(x, in_scale)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def expected(x, conv_weight, in_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight]
squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
y = relay.nn.conv2d(x,
conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1)))
weight = relay.var("weight")
y1 = before(x, weight, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
def test_fold_bwd_simple():
"""Simple testcase."""
def before(x, conv_weight, out_bias, out_scale, channels):
......@@ -223,7 +264,7 @@ def test_fold_bwd_simple():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -283,7 +324,7 @@ def test_fold_bwd_dual_path():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -356,7 +397,7 @@ def test_fold_bwd_dual_consumer():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels,1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -411,7 +452,7 @@ def test_fold_bwd_fail():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
......@@ -448,13 +489,55 @@ def test_fold_bwd_relu_fail():
check((4, 4, 10, 10), 4, out_scale)
def test_fold_bwd_negative_scale():
"""Testcase of folding negative scale"""
def before(x, conv_weight, out_scale, channels):
args = [x, conv_weight]
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.multiply(y, out_scale)
return relay.Function(args, y)
def expected(x, conv_weight, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight]
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
weight = relay.var("weight")
out_scale = relay.const(-_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
test_fold_fwd_relu_fail()
test_fold_fwd_negative_scale()
test_fold_bwd_simple()
test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer()
test_fold_bwd_fail()
test_fold_bwd_relu_fail()
test_fold_bwd_negative_scale()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment