Commit 87a37684 by Tianqi Chen Committed by ziheng

[PASS] Avoid recursion in FoldScaleAxis (#2299)

* [PASS] Avoid recursion in FoldScaleAxis

* remove GetForwardScale
parent e9e12f03
...@@ -246,44 +246,9 @@ class ForwardPrep : private ExprVisitor { ...@@ -246,44 +246,9 @@ class ForwardPrep : private ExprVisitor {
// Per operator defs for FScaleAxisForward // Per operator defs for FScaleAxisForward
//---------------------------------------------- //----------------------------------------------
// Helper functions
Expr GetForwardScale(const Expr& expr, AxesSet out) {
static const Op& multiply = Op::Get("multiply");
static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
const CallNode* call = expr.as<CallNode>();
if (!call) return NullValue<Expr>();
auto f = fprep.get(call->op, nullptr);
if (call->op.same_as(multiply)) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
return call->args[1];
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
return call->args[0];
} else {
return NullValue<Expr>();
}
} else if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
for (size_t i = 0; i < call->args.size(); i++) {
auto scale = GetForwardScale(call->args[i], in_axes[i]);
if (scale.defined()) {
return scale;
}
}
}
return NullValue<Expr>();
}
// Intermediate operators // Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) { Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
Expr scale = GetForwardScale(call->args[0], out); return {out};
if (IsPositiveConstant(scale)) {
return {out};
}
return {NullValue<AxesSet>()};
} }
Expr ReluForwardRewrite(const Call& ref_call, Expr ReluForwardRewrite(const Call& ref_call,
...@@ -391,16 +356,21 @@ Expr MultiplyForwardRewrite(const Call& ref_call, ...@@ -391,16 +356,21 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
Expr lhs = new_args[0]; Expr lhs = new_args[0];
Expr rhs = new_args[1]; Expr rhs = new_args[1];
auto rnode = make_node<ScaledExprNode>(); auto rnode = make_node<ScaledExprNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) { if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
IsAllPositiveConstant(rhs)) {
rnode->value = lhs; rnode->value = lhs;
rnode->scale = rhs; rnode->scale = rhs;
rnode->axes = expected_out_axes; rnode->axes = expected_out_axes;
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) { return Expr(rnode);
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
IsAllPositiveConstant(lhs)) {
rnode->value = rhs; rnode->value = rhs;
rnode->scale = lhs; rnode->scale = lhs;
rnode->axes = expected_out_axes; rnode->axes = expected_out_axes;
return Expr(rnode);
} else {
return Expr();
} }
return Expr(rnode);
} }
RELAY_REGISTER_OP("multiply") RELAY_REGISTER_OP("multiply")
...@@ -790,22 +760,6 @@ RELAY_REGISTER_OP("subtract") ...@@ -790,22 +760,6 @@ RELAY_REGISTER_OP("subtract")
RELAY_REGISTER_OP("subtract") RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform); .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
// Find relu in the backward path between multiply and conv2d
bool FindBackwardRelu(const Expr& expr) {
const CallNode* call = expr.as<CallNode>();
static const Op& conv2d = Op::Get("nn.conv2d");
static const Op& relu = Op::Get("nn.relu");
if (!call) return false;
if (call->op.same_as(relu)) return true;
if (call->op.same_as(conv2d)) return false;
for (size_t i = 0; i < call->args.size(); i++) {
if (FindBackwardRelu(call->args[i])) return true;
}
return false;
}
// Producer operators // Producer operators
// Multiply produces the scale-axis pair. // Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call, Expr MultiplyBackwardTransform(const Call& call,
...@@ -821,16 +775,16 @@ Expr MultiplyBackwardTransform(const Call& call, ...@@ -821,16 +775,16 @@ Expr MultiplyBackwardTransform(const Call& call,
// NOTE we won't recursively call mutating on scale part. // NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part. // since there won't be scale chance within scale part.
Expr rhs = call->args[1]; Expr rhs = call->args[1];
// Only propagate positive scaling.
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) && if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
(!FindBackwardRelu(call->args[0]) || IsAllPositiveConstant(rhs)) {
IsPositiveConstant(call->args[1]))) {
return transformer->Transform(call->args[0], lhs_axes, rhs); return transformer->Transform(call->args[0], lhs_axes, rhs);
} }
} else if (rhs_axes.defined() && rhs_axes.size() != 0) { } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
// Only propagate positive scaling.
Expr lhs = call->args[0]; Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) && if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
(!FindBackwardRelu(call->args[1]) || IsAllPositiveConstant(lhs)) {
IsPositiveConstant(call->args[0]))) {
return transformer->Transform(call->args[1], rhs_axes, lhs); return transformer->Transform(call->args[1], rhs_axes, lhs);
} }
} }
......
...@@ -22,6 +22,15 @@ namespace relay { ...@@ -22,6 +22,15 @@ namespace relay {
std::unordered_map<const Node*, size_t> std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body); GetExprRefCount(const Expr& body);
/*!
* \brief Check if expr is positive constant.
* \param expr The expression to be checked.
* \return Whether all elements of expr is positive constant.
*/
bool IsAllPositiveConstant(const Expr& expr);
/*! /*!
* \brief Substitute var with subst. * \brief Substitute var with subst.
* \param type The type to be substituted. * \param type The type to be substituted.
......
...@@ -190,57 +190,6 @@ Expr MakeConcatenate(Expr data, int axis); ...@@ -190,57 +190,6 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides); Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
CHECK_EQ(tensor->ctx.device_type, kDLCPU);
CHECK(tensor->strides == nullptr);
CHECK_EQ(tensor->byte_offset, 0);
const T* data = static_cast<const T*>(tensor->data);
int64_t num_elems = 1;
for (int i = 0; i < tensor->ndim; ++i) {
num_elems *= tensor->shape[i];
}
for (int64_t i = 0; i < num_elems; i++) {
if (*data < value) {
return false;
}
data++;
}
return true;
}
inline bool IsPositiveConstant(const Expr& expr) {
const auto* constant = expr.as<ConstantNode>();
if (!constant) return false;
const auto& tensor = constant->data;
const auto& dtype = tensor->dtype;
if (dtype.lanes != 1) {
// pass
} else if (dtype.code == kDLFloat && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<float>(tensor, 0);
} else if (dtype.code == kDLFloat && dtype.bits == 64) {
return IsNDArrayAllGreaterEqual<double>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
}
LOG(WARNING) << "Unsupported data type (code = " << dtype.code
<< ", bits = " << dtype.bits << ", lanes = " << dtype.lanes
<< ")";
return false;
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
...@@ -146,5 +146,67 @@ GetExprRefCount(const Expr& body) { ...@@ -146,5 +146,67 @@ GetExprRefCount(const Expr& body) {
return ExprRefCounter().Get(body); return ExprRefCounter().Get(body);
} }
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
CHECK_EQ(tensor->ctx.device_type, kDLCPU);
CHECK(tensor->strides == nullptr);
CHECK_EQ(tensor->byte_offset, 0);
const T* data = static_cast<const T*>(tensor->data);
int64_t num_elems = 1;
for (int i = 0; i < tensor->ndim; ++i) {
num_elems *= tensor->shape[i];
}
for (int64_t i = 0; i < num_elems; i++) {
if (*data < value) {
return false;
}
data++;
}
return true;
}
bool IsAllPositiveConstant(const Expr& expr) {
// peel through a few common transform ops.
static const auto& expand_dims = Op::Get("expand_dims");
static const auto& reshape = Op::Get("reshape");
static const auto& transpose = Op::Get("transpose");
static const auto& squeeze = Op::Get("squeeze");
if (const auto* constant = expr.as<ConstantNode>()) {
const auto& tensor = constant->data;
const auto& dtype = tensor->dtype;
if (dtype.lanes != 1) {
return false;
} else if (dtype.code == kDLFloat && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<float>(tensor, 0);
} else if (dtype.code == kDLFloat && dtype.bits == 64) {
return IsNDArrayAllGreaterEqual<double>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
} else {
return false;
}
} else if (const auto* op = expr.as<CallNode>()) {
// tail recursion.
if (op->op.same_as(expand_dims) ||
op->op.same_as(reshape) ||
op->op.same_as(transpose) ||
op->op.same_as(squeeze)) {
return IsAllPositiveConstant(op->args[0]);
} else {
return false;
}
} else {
return false;
}
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
from tvm import relay from tvm import relay
import numpy as np import numpy as np
def _get_positive_scale(size):
return np.random.uniform(0.5, 1, size=size).astype('float32')
def test_fold_fwd_simple(): def test_fold_fwd_simple():
"""Simple testcase.""" """Simple testcase."""
...@@ -14,6 +17,7 @@ def test_fold_fwd_simple(): ...@@ -14,6 +17,7 @@ def test_fold_fwd_simple():
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1)) padding=(1, 1))
return relay.Function(args, y) return relay.Function(args, y)
def expected(x, conv_weight, in_bias, in_scale, channels): def expected(x, conv_weight, in_bias, in_scale, channels):
...@@ -37,14 +41,14 @@ def test_fold_fwd_simple(): ...@@ -37,14 +41,14 @@ def test_fold_fwd_simple():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,)) in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype('float32')) in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
y1 = before(x, weight, in_bias, in_scale, channels) y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params} type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"]) weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_bias, in_scale, channels) y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded) y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected) y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
...@@ -107,7 +111,7 @@ def test_fold_fwd_dual_path(): ...@@ -107,7 +111,7 @@ def test_fold_fwd_dual_path():
assert in_channels == channels assert in_channels == channels
weight = relay.var("weight") weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,)) in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32")) in_scale = relay.const(_get_positive_scale(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels) y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
...@@ -141,7 +145,7 @@ def test_fold_fwd_fail(): ...@@ -141,7 +145,7 @@ def test_fold_fwd_fail():
assert in_channels == channels assert in_channels == channels
weight = relay.var("weight") weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,)) in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32")) in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
y1 = before(x, weight, in_bias, in_scale, channels) y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment