Commit 93949456 by Wuwei Lin Committed by Tianqi Chen

[RELAY][PASS] Check Positiveness in FoldScaleAxis (#2220)

parent 166936cd
...@@ -150,13 +150,14 @@ def optimize(func, params=None): ...@@ -150,13 +150,14 @@ def optimize(func, params=None):
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func) func = ir_pass.combine_parallel_conv2d(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
if cfg.pass_enabled("FoldScaleAxis"): if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func) func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func) func = ir_pass.forward_fold_scale_axis(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func) func = ir_pass.fold_constant(func)
if cfg.pass_enabled("AlterOpLayout"): if cfg.pass_enabled("AlterOpLayout"):
......
...@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor { ...@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor {
// Per operator defs for FScaleAxisForward // Per operator defs for FScaleAxisForward
//---------------------------------------------- //----------------------------------------------
// Helper functions
// Search `expr` and the calls feeding it for the scale operand of a "multiply".
//
// Behaviour by case:
//  - `expr` is not a call: returns a null Expr.
//  - `expr` is a multiply call: returns the right operand if
//    MatchBroadcastToLeftAxes(lhs type, rhs type, out) holds, else the left
//    operand if the mirrored match holds, else a null Expr.
//  - `expr` is another call whose op registers "FScaleAxisForwardPrep": that
//    function is invoked with `out` to obtain one AxesSet per argument, and
//    the arguments are searched recursively in order; the first defined
//    result is returned.
//  - anything else: returns a null Expr.
Expr GetForwardScale(const Expr& expr, AxesSet out) {
  static const Op& multiply = Op::Get("multiply");
  static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");

  const CallNode* node = expr.as<CallNode>();
  if (node == nullptr) {
    return NullValue<Expr>();
  }
  auto prep_fn = fprep.get(node->op, nullptr);

  if (node->op.same_as(multiply)) {
    const Expr lhs = node->args[0];
    const Expr rhs = node->args[1];
    const auto* lhs_type = lhs->type_as<TensorTypeNode>();
    const auto* rhs_type = rhs->type_as<TensorTypeNode>();
    // Try "rhs is the scale" first, then the mirrored arrangement.
    if (MatchBroadcastToLeftAxes(lhs_type, rhs_type, out)) {
      return rhs;
    }
    if (MatchBroadcastToLeftAxes(rhs_type, lhs_type, out)) {
      return lhs;
    }
    return NullValue<Expr>();
  }

  if (prep_fn == nullptr) {
    return NullValue<Expr>();
  }

  // Ask the operator which axes each of its inputs would carry, then look
  // through the inputs one by one.
  Array<AxesSet> arg_axes = prep_fn(GetRef<Call>(node), out);
  for (size_t idx = 0; idx < node->args.size(); ++idx) {
    Expr found = GetForwardScale(node->args[idx], arg_axes[idx]);
    if (found.defined()) {
      return found;
    }
  }
  return NullValue<Expr>();
}
// Intermediate operators // Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) { Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
return {out}; Expr scale = GetForwardScale(call->args[0], out);
if (IsPositiveConstant(scale)) {
return {out};
}
return {NullValue<AxesSet>()};
} }
Expr ReluForwardRewrite(const Call& ref_call, Expr ReluForwardRewrite(const Call& ref_call,
...@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract") ...@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract")
RELAY_REGISTER_OP("subtract") RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform); .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
// Find relu in the backward path between multiply and conv2d
bool FindBackwardRelu(const Expr& expr) {
const CallNode* call = expr.as<CallNode>();
static const Op& conv2d = Op::Get("nn.conv2d");
static const Op& relu = Op::Get("nn.relu");
if (!call) return false;
if (call->op.same_as(relu)) return true;
if (call->op.same_as(conv2d)) return false;
for (size_t i = 0; i < call->args.size(); i++) {
if (FindBackwardRelu(call->args[i])) return true;
}
return false;
}
// Producer operators // Producer operators
// Multiply produces the scale-axis pair. // Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call, Expr MultiplyBackwardTransform(const Call& call,
...@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call, ...@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call,
// NOTE we won't recursively call mutating on scale part. // NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part. // since there won't be scale chance within scale part.
Expr rhs = call->args[1]; Expr rhs = call->args[1];
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) { if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
(!FindBackwardRelu(call->args[0]) ||
IsPositiveConstant(call->args[1]))) {
return transformer->Transform(call->args[0], lhs_axes, rhs); return transformer->Transform(call->args[0], lhs_axes, rhs);
} }
} else if (rhs_axes.defined() && rhs_axes.size() != 0) { } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
Expr lhs = call->args[0]; Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) { if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
(!FindBackwardRelu(call->args[1]) ||
IsPositiveConstant(call->args[0]))) {
return transformer->Transform(call->args[1], rhs_axes, lhs); return transformer->Transform(call->args[1], rhs_axes, lhs);
} }
} }
......
...@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis); ...@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides); Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
// Return true when every element of `tensor`, read as type T, is >= `value`.
//
// The array must be on the CPU, have no explicit strides and a zero byte
// offset (each enforced with a CHECK); its buffer is then scanned linearly
// over the product of all shape extents. The caller is responsible for
// choosing a T that matches the tensor's dtype.
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
  CHECK(tensor->strides == nullptr);
  CHECK_EQ(tensor->byte_offset, 0);

  int64_t total = 1;
  for (int dim = 0; dim < tensor->ndim; ++dim) {
    total *= tensor->shape[dim];
  }

  const T* cursor = static_cast<const T*>(tensor->data);
  for (int64_t remaining = total; remaining > 0; --remaining, ++cursor) {
    if (*cursor < value) {
      return false;
    }
  }
  return true;
}
// Check whether `expr` is a constant whose elements are all non-negative.
//
// Despite the name, zero is accepted: the comparison is "every element >= 0".
// Returns false when `expr` is not a ConstantNode. Scalar-lane float32/float64
// and 8/16/32/64-bit signed or unsigned integer tensors are inspected
// element by element; any other dtype (including multi-lane types) logs a
// warning and returns false.
inline bool IsPositiveConstant(const Expr& expr) {
  const auto* constant = expr.as<ConstantNode>();
  if (!constant) return false;
  const auto& tensor = constant->data;
  const auto& dtype = tensor->dtype;
  if (dtype.lanes != 1) {
    // Vector element types are not handled; fall through to the warning.
  } else if (dtype.code == kDLFloat && dtype.bits == 32) {
    return IsNDArrayAllGreaterEqual<float>(tensor, 0);
  } else if (dtype.code == kDLFloat && dtype.bits == 64) {
    return IsNDArrayAllGreaterEqual<double>(tensor, 0);
  } else if (dtype.code == kDLInt && dtype.bits == 8) {
    return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
  } else if (dtype.code == kDLInt && dtype.bits == 16) {
    return IsNDArrayAllGreaterEqual<int16_t>(tensor, 0);
  } else if (dtype.code == kDLInt && dtype.bits == 32) {
    return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
  } else if (dtype.code == kDLInt && dtype.bits == 64) {
    return IsNDArrayAllGreaterEqual<int64_t>(tensor, 0);
  } else if (dtype.code == kDLUInt && dtype.bits == 8) {
    return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
  } else if (dtype.code == kDLUInt && dtype.bits == 16) {
    return IsNDArrayAllGreaterEqual<uint16_t>(tensor, 0);
  } else if (dtype.code == kDLUInt && dtype.bits == 32) {
    return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
  } else if (dtype.code == kDLUInt && dtype.bits == 64) {
    return IsNDArrayAllGreaterEqual<uint64_t>(tensor, 0);
  }
  // code and bits are 8-bit fields; widen them so the stream prints numbers
  // instead of raw (often unprintable) characters.
  LOG(WARNING) << "Unsupported data type (code = " << static_cast<int>(dtype.code)
               << ", bits = " << static_cast<int>(dtype.bits)
               << ", lanes = " << static_cast<int>(dtype.lanes)
               << ")";
  return false;
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
from tvm import relay from tvm import relay
import numpy as np
def test_fold_fwd_simple(): def test_fold_fwd_simple():
"""Simple testcase.""" """Simple testcase."""
def before(x, conv_weight, in_bias, in_scale, channels): def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale] args = [x, conv_weight, in_bias]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
x = relay.multiply(x, in_scale) x = relay.multiply(x, in_scale)
x = relay.nn.relu(x) x = relay.nn.relu(x)
...@@ -18,8 +18,7 @@ def test_fold_fwd_simple(): ...@@ -18,8 +18,7 @@ def test_fold_fwd_simple():
def expected(x, conv_weight, in_bias, in_scale, channels): def expected(x, conv_weight, in_bias, in_scale, channels):
# use a fixed order of args so alpha equal check can pass # use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, in_bias, in_scale] args = [x, conv_weight, in_bias]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
x = relay.nn.relu(x) x = relay.nn.relu(x)
...@@ -38,7 +37,7 @@ def test_fold_fwd_simple(): ...@@ -38,7 +37,7 @@ def test_fold_fwd_simple():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,)) in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,)) in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype('float32'))
y1 = before(x, weight, in_bias, in_scale, channels) y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
...@@ -56,7 +55,7 @@ def test_fold_fwd_simple(): ...@@ -56,7 +55,7 @@ def test_fold_fwd_simple():
def test_fold_fwd_dual_path(): def test_fold_fwd_dual_path():
"""scale axis being consumed by two consumers""" """scale axis being consumed by two consumers"""
def before(x, conv_weight, in_bias, in_scale, channels): def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale] args = [x, conv_weight, in_bias]
x = relay.multiply(in_scale, x) x = relay.multiply(in_scale, x)
x = relay.nn.relu(x) x = relay.nn.relu(x)
x = relay.subtract(x, in_bias) x = relay.subtract(x, in_bias)
...@@ -78,7 +77,7 @@ def test_fold_fwd_dual_path(): ...@@ -78,7 +77,7 @@ def test_fold_fwd_dual_path():
return relay.Function(args, z) return relay.Function(args, z)
def expected(x, conv_weight, in_bias, in_scale, channels): def expected(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale] args = [x, conv_weight, in_bias]
x = relay.nn.relu(x) x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, in_scale) in_bias = relay.divide(in_bias, in_scale)
x = relay.subtract(x, in_bias) x = relay.subtract(x, in_bias)
...@@ -108,7 +107,7 @@ def test_fold_fwd_dual_path(): ...@@ -108,7 +107,7 @@ def test_fold_fwd_dual_path():
assert in_channels == channels assert in_channels == channels
weight = relay.var("weight") weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,)) in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,)) in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
y1 = before(x, weight, in_bias, in_scale, channels) y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
...@@ -142,7 +141,7 @@ def test_fold_fwd_fail(): ...@@ -142,7 +141,7 @@ def test_fold_fwd_fail():
assert in_channels == channels assert in_channels == channels
weight = relay.var("weight") weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,)) in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,)) in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
y1 = before(x, weight, in_bias, in_scale, channels) y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
...@@ -151,11 +150,42 @@ def test_fold_fwd_fail(): ...@@ -151,11 +150,42 @@ def test_fold_fwd_fail():
check((2, 11, 10, 4), 4) check((2, 11, 10, 4), 4)
def test_fold_fwd_relu_fail():
    """testcase where we cannot fold because scale can not pass relu

    Forward folding must leave the function untouched both when the scale is
    a free variable (sign unknown) and when it is a negative constant.
    """
    def before(x, conv_weight, in_bias, in_scale, channels):
        x = relay.multiply(x, in_scale)
        xx = relay.nn.relu(x)
        y1 = relay.nn.conv2d(xx, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             data_layout="NHWC",
                             padding=(1, 1))
        z = relay.add(y1, x)
        return relay.Function(relay.ir_pass.free_vars(z), z)

    def check(shape, channels, in_scale):
        x = relay.var("x", shape=shape)
        in_channels = shape[-1]
        # test depthwise
        assert in_channels == channels
        weight = relay.var("weight")
        in_bias = relay.var("in_bias", shape=(in_channels,))
        # Use the scale supplied by the caller; replacing it here would mean
        # the constant case below is never actually exercised.
        y1 = before(x, weight, in_bias, in_scale, channels)
        y1 = relay.ir_pass.infer_type(y1)
        y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
        assert relay.ir_pass.alpha_equal(y1, y1_folded)

    # Scale of unknown sign.
    in_scale = relay.var("in_scale", shape=(4,))
    check((2, 11, 10, 4), 4, in_scale)
    # Negative constant scale: convert the numpy data to float32 before
    # wrapping it, so the result is a constant rather than a cast expression.
    in_scale = relay.const(
        np.random.uniform(size=(4,), low=-1.0, high=0.0).astype("float32"))
    check((2, 11, 10, 4), 4, in_scale)
def test_fold_bwd_simple(): def test_fold_bwd_simple():
"""Simple testcase.""" """Simple testcase."""
def before(x, conv_weight, out_bias, out_scale, channels): def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y = relay.nn.conv2d(x, conv_weight, y = relay.nn.conv2d(x, conv_weight,
channels=channels, channels=channels,
...@@ -168,8 +198,7 @@ def test_fold_bwd_simple(): ...@@ -168,8 +198,7 @@ def test_fold_bwd_simple():
def expected(x, conv_weight, out_bias, out_scale, channels): def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass # use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
conv_weight = relay.multiply( conv_weight = relay.multiply(
...@@ -190,7 +219,7 @@ def test_fold_bwd_simple(): ...@@ -190,7 +219,7 @@ def test_fold_bwd_simple():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,)) out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
y1 = before(x, weight, out_bias, out_scale, channels) y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
...@@ -208,9 +237,7 @@ def test_fold_bwd_simple(): ...@@ -208,9 +237,7 @@ def test_fold_bwd_simple():
def test_fold_bwd_dual_path(): def test_fold_bwd_dual_path():
"""Dual path testcase.""" """Dual path testcase."""
def before(x, conv_weight, out_bias, out_scale, channels): def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight, y1 = relay.nn.conv2d(x, conv_weight,
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
...@@ -227,8 +254,7 @@ def test_fold_bwd_dual_path(): ...@@ -227,8 +254,7 @@ def test_fold_bwd_dual_path():
def expected(x, conv_weight, out_bias, out_scale, channels): def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass # use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
def fold_conv_weight(): def fold_conv_weight():
...@@ -253,7 +279,7 @@ def test_fold_bwd_dual_path(): ...@@ -253,7 +279,7 @@ def test_fold_bwd_dual_path():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,)) out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
y1 = before(x, weight, out_bias, out_scale, channels) y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
...@@ -270,8 +296,7 @@ def test_fold_bwd_dual_path(): ...@@ -270,8 +296,7 @@ def test_fold_bwd_dual_path():
def test_fold_bwd_dual_consumer(): def test_fold_bwd_dual_consumer():
def before(x, conv_weight, out_bias, out_scale, channels): def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
y0 = relay.nn.conv2d(x, conv_weight, y0 = relay.nn.conv2d(x, conv_weight,
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
...@@ -298,8 +323,7 @@ def test_fold_bwd_dual_consumer(): ...@@ -298,8 +323,7 @@ def test_fold_bwd_dual_consumer():
def expected(x, conv_weight, out_bias, out_scale, channels): def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass # use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
def fold_conv_weight(): def fold_conv_weight():
squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
return relay.multiply( return relay.multiply(
...@@ -328,7 +352,7 @@ def test_fold_bwd_dual_consumer(): ...@@ -328,7 +352,7 @@ def test_fold_bwd_dual_consumer():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,)) out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32"))
y1 = before(x, weight, out_bias, out_scale, channels) y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
...@@ -346,8 +370,7 @@ def test_fold_bwd_dual_consumer(): ...@@ -346,8 +370,7 @@ def test_fold_bwd_dual_consumer():
def test_fold_bwd_fail(): def test_fold_bwd_fail():
"""Dual path testcase.""" """Dual path testcase."""
def fail1(x, conv_weight, out_bias, out_scale, channels): def fail1(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight, y1 = relay.nn.conv2d(x, conv_weight,
channels=channels, channels=channels,
...@@ -367,8 +390,7 @@ def test_fold_bwd_fail(): ...@@ -367,8 +390,7 @@ def test_fold_bwd_fail():
return relay.Function(args, y) return relay.Function(args, y)
def fail2(x, conv_weight, out_bias, out_scale, channels): def fail2(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale] args = [x, conv_weight, out_bias]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight, y1 = relay.nn.conv2d(x, conv_weight,
channels=channels, channels=channels,
...@@ -380,13 +402,12 @@ def test_fold_bwd_fail(): ...@@ -380,13 +402,12 @@ def test_fold_bwd_fail():
y = relay.add(y1, y2) y = relay.add(y1, y2)
return relay.Function(args, y) return relay.Function(args, y)
def check(shape, channels, fbefore): def check(shape, channels, fbefore):
x = relay.var("x", shape=shape) x = relay.var("x", shape=shape)
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,)) out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
y1 = fbefore(x, weight, out_bias, out_scale, channels) y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
...@@ -396,11 +417,40 @@ def test_fold_bwd_fail(): ...@@ -396,11 +417,40 @@ def test_fold_bwd_fail():
check((4, 4, 10, 10), 4, fail2) check((4, 4, 10, 10), 4, fail2)
def test_fold_bwd_relu_fail():
    """testcase where we cannot fold because scale can not pass relu

    Backward folding must leave the function untouched both when the scale is
    a free variable (sign unknown) and when it is a negative constant.
    """
    def before(x, conv_weight, out_scale, channels):
        y = relay.nn.conv2d(x, conv_weight,
                            channels=channels,
                            kernel_size=(3, 3),
                            data_layout="NCHW",
                            padding=(1, 1))
        y = relay.nn.relu(y)
        # Scale the relu output; multiplying the raw input instead would drop
        # the conv2d/relu chain from the function entirely.
        y = relay.multiply(y, out_scale)
        return relay.Function(relay.ir_pass.free_vars(y), y)

    def check(shape, channels, out_scale):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        y1 = before(x, weight, out_scale, channels)
        y1 = relay.ir_pass.infer_type(y1)
        # This is the backward test, so run the backward pass.
        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
        assert relay.ir_pass.alpha_equal(y1, y1_folded)

    # Scale of unknown sign.
    out_scale = relay.var("in_scale", shape=(4, 1, 1))
    check((4, 4, 10, 10), 4, out_scale)
    # Negative constant scale: convert the numpy data to float32 before
    # wrapping it, so the result is a constant rather than a cast expression.
    out_scale = relay.const(
        np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0).astype("float32"))
    check((4, 4, 10, 10), 4, out_scale)
if __name__ == "__main__": if __name__ == "__main__":
test_fold_fwd_simple() test_fold_fwd_simple()
test_fold_fwd_dual_path() test_fold_fwd_dual_path()
test_fold_fwd_fail() test_fold_fwd_fail()
test_fold_fwd_relu_fail()
test_fold_bwd_simple() test_fold_bwd_simple()
test_fold_bwd_dual_path() test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer() test_fold_bwd_dual_consumer()
test_fold_bwd_fail() test_fold_bwd_fail()
test_fold_bwd_relu_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment