Commit 93949456 by Wuwei Lin Committed by Tianqi Chen

[RELAY][PASS] Check Positiveness in FoldScaleAxis (#2220)

parent 166936cd
......@@ -150,13 +150,14 @@ def optimize(func, params=None):
func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
if cfg.pass_enabled("AlterOpLayout"):
......
......@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor {
// Per operator defs for FScaleAxisForward
//----------------------------------------------
// Helper functions
Expr GetForwardScale(const Expr& expr, AxesSet out) {
static const Op& multiply = Op::Get("multiply");
static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
const CallNode* call = expr.as<CallNode>();
if (!call) return NullValue<Expr>();
auto f = fprep.get(call->op, nullptr);
if (call->op.same_as(multiply)) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
return call->args[1];
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
return call->args[0];
} else {
return NullValue<Expr>();
}
} else if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
for (size_t i = 0; i < call->args.size(); i++) {
auto scale = GetForwardScale(call->args[i], in_axes[i]);
if (scale.defined()) {
return scale;
}
}
}
return NullValue<Expr>();
}
// Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
return {out};
Expr scale = GetForwardScale(call->args[0], out);
if (IsPositiveConstant(scale)) {
return {out};
}
return {NullValue<AxesSet>()};
}
Expr ReluForwardRewrite(const Call& ref_call,
......@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract")
RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
// Find relu in the backward path between multiply and conv2d
bool FindBackwardRelu(const Expr& expr) {
const CallNode* call = expr.as<CallNode>();
static const Op& conv2d = Op::Get("nn.conv2d");
static const Op& relu = Op::Get("nn.relu");
if (!call) return false;
if (call->op.same_as(relu)) return true;
if (call->op.same_as(conv2d)) return false;
for (size_t i = 0; i < call->args.size(); i++) {
if (FindBackwardRelu(call->args[i])) return true;
}
return false;
}
// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call,
......@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call,
// NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part.
Expr rhs = call->args[1];
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
(!FindBackwardRelu(call->args[0]) ||
IsPositiveConstant(call->args[1]))) {
return transformer->Transform(call->args[0], lhs_axes, rhs);
}
} else if (rhs_axes.defined() && rhs_axes.size() != 0) {
Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
(!FindBackwardRelu(call->args[1]) ||
IsPositiveConstant(call->args[0]))) {
return transformer->Transform(call->args[1], rhs_axes, lhs);
}
}
......
......@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
CHECK_EQ(tensor->ctx.device_type, kDLCPU);
CHECK(tensor->strides == nullptr);
CHECK_EQ(tensor->byte_offset, 0);
const T* data = static_cast<const T*>(tensor->data);
int64_t num_elems = 1;
for (int i = 0; i < tensor->ndim; ++i) {
num_elems *= tensor->shape[i];
}
for (int64_t i = 0; i < num_elems; i++) {
if (*data < value) {
return false;
}
data++;
}
return true;
}
inline bool IsPositiveConstant(const Expr& expr) {
const auto* constant = expr.as<ConstantNode>();
if (!constant) return false;
const auto& tensor = constant->data;
const auto& dtype = tensor->dtype;
if (dtype.lanes != 1) {
// pass
} else if (dtype.code == kDLFloat && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<float>(tensor, 0);
} else if (dtype.code == kDLFloat && dtype.bits == 64) {
return IsNDArrayAllGreaterEqual<double>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
}
LOG(WARNING) << "Unsupported data type (code = " << dtype.code
<< ", bits = " << dtype.bits << ", lanes = " << dtype.lanes
<< ")";
return false;
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
from tvm import relay
import numpy as np
def test_fold_fwd_simple():
"""Simple testcase."""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, in_bias]
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
x = relay.multiply(x, in_scale)
x = relay.nn.relu(x)
......@@ -18,8 +18,7 @@ def test_fold_fwd_simple():
def expected(x, conv_weight, in_bias, in_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, in_bias, in_scale]
in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, in_bias]
in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
x = relay.nn.relu(x)
......@@ -38,7 +37,7 @@ def test_fold_fwd_simple():
in_channels = shape[1]
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype('float32'))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -56,7 +55,7 @@ def test_fold_fwd_simple():
def test_fold_fwd_dual_path():
"""scale axis being consumed by two consumers"""
def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
args = [x, conv_weight, in_bias]
x = relay.multiply(in_scale, x)
x = relay.nn.relu(x)
x = relay.subtract(x, in_bias)
......@@ -78,7 +77,7 @@ def test_fold_fwd_dual_path():
return relay.Function(args, z)
def expected(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias, in_scale]
args = [x, conv_weight, in_bias]
x = relay.nn.relu(x)
in_bias = relay.divide(in_bias, in_scale)
x = relay.subtract(x, in_bias)
......@@ -108,7 +107,7 @@ def test_fold_fwd_dual_path():
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
......@@ -142,7 +141,7 @@ def test_fold_fwd_fail():
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
......@@ -151,11 +150,42 @@ def test_fold_fwd_fail():
check((2, 11, 10, 4), 4)
def test_fold_fwd_relu_fail():
"""testcase where we canont fold because scale can not pass relu"""
def before(x, conv_weight, in_bias, in_scale, channels):
x = relay.multiply(x, in_scale)
xx = relay.nn.relu(x)
y1 = relay.nn.conv2d(xx, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
padding=(1, 1))
z = relay.add(y1, x)
return relay.Function(relay.ir_pass.free_vars(z), z)
def check(shape, channels, in_scale):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
assert relay.ir_pass.alpha_equal(y1, y1_folded)
in_scale = relay.var("in_scale", shape=(4,))
check((2, 11, 10, 4), 4, in_scale)
in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0)).astype("float32")
check((2, 11, 10, 4), 4, in_scale)
def test_fold_bwd_simple():
"""Simple testcase."""
def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
......@@ -168,8 +198,7 @@ def test_fold_bwd_simple():
def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
conv_weight = relay.multiply(
......@@ -190,7 +219,7 @@ def test_fold_bwd_simple():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -208,9 +237,7 @@ def test_fold_bwd_simple():
def test_fold_bwd_dual_path():
"""Dual path testcase."""
def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
......@@ -227,8 +254,7 @@ def test_fold_bwd_dual_path():
def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
def fold_conv_weight():
......@@ -253,7 +279,7 @@ def test_fold_bwd_dual_path():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -270,8 +296,7 @@ def test_fold_bwd_dual_path():
def test_fold_bwd_dual_consumer():
def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
y0 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
......@@ -298,8 +323,7 @@ def test_fold_bwd_dual_consumer():
def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
def fold_conv_weight():
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
return relay.multiply(
......@@ -328,7 +352,7 @@ def test_fold_bwd_dual_consumer():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32"))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -346,8 +370,7 @@ def test_fold_bwd_dual_consumer():
def test_fold_bwd_fail():
"""Dual path testcase."""
def fail1(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
......@@ -367,8 +390,7 @@ def test_fold_bwd_fail():
return relay.Function(args, y)
def fail2(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
args = [x, conv_weight, out_bias]
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
......@@ -380,13 +402,12 @@ def test_fold_bwd_fail():
y = relay.add(y1, y2)
return relay.Function(args, y)
def check(shape, channels, fbefore):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
......@@ -396,11 +417,40 @@ def test_fold_bwd_fail():
check((4, 4, 10, 10), 4, fail2)
def test_fold_bwd_relu_fail():
"""testcase where we canont fold because scale can not pass relu"""
def before(x, conv_weight, out_scale, channels):
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NCHW",
padding=(1, 1))
y = relay.nn.relu(y)
y = relay.multiply(x, out_scale)
return relay.Function(relay.ir_pass.free_vars(y), y)
def check(shape, channels, out_scale):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
y1 = before(x, weight, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
assert relay.ir_pass.alpha_equal(y1, y1_folded)
out_scale = relay.var("in_scale", shape=(4, 1, 1))
check((4, 4, 10, 10), 4, out_scale)
out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0)).astype("float32")
check((4, 4, 10, 10), 4, out_scale)
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
test_fold_fwd_relu_fail()
test_fold_bwd_simple()
test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer()
test_fold_bwd_fail()
test_fold_bwd_relu_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment