Commit 87a37684 by Tianqi Chen, committed by ziheng

[PASS] Avoid recursion in FoldScaleAxis (#2299)

* [PASS] Avoid recursion in FoldScaleAxis

* remove GetForwardScale
parent e9e12f03
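
For context: FoldScaleAxis folds a constant axis-wise scale applied around a conv2d into the convolution weight. A minimal sketch of invoking the forward pass, mirroring the relay.ir_pass API exercised by the tests below (shapes and variable names are illustrative, not from this commit):

import numpy as np
from tvm import relay

# conv2d(multiply(x, scale), weight): a positive per-channel scale on the
# conv input, which forward_fold_scale_axis can fold into the weight.
x = relay.var("x", shape=(2, 4, 10, 10))
weight = relay.var("weight", shape=(8, 4, 3, 3))
scale = relay.const(np.random.uniform(0.5, 1, size=(4, 1, 1)).astype("float32"))
y = relay.nn.conv2d(relay.multiply(x, scale), weight,
                    channels=8, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function([x, weight], y)
func = relay.ir_pass.infer_type(func)
# After this commit the pass checks positivity of the scale directly
# (IsAllPositiveConstant) instead of recursively searching the graph.
folded = relay.ir_pass.forward_fold_scale_axis(func)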
@@ -246,44 +246,9 @@ class ForwardPrep : private ExprVisitor {
// Per operator defs for FScaleAxisForward
//----------------------------------------------
// Helper functions
Expr GetForwardScale(const Expr& expr, AxesSet out) {
static const Op& multiply = Op::Get("multiply");
static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
const CallNode* call = expr.as<CallNode>();
if (!call) return NullValue<Expr>();
auto f = fprep.get(call->op, nullptr);
if (call->op.same_as(multiply)) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
return call->args[1];
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
return call->args[0];
} else {
return NullValue<Expr>();
}
} else if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
for (size_t i = 0; i < call->args.size(); i++) {
auto scale = GetForwardScale(call->args[i], in_axes[i]);
if (scale.defined()) {
return scale;
}
}
}
return NullValue<Expr>();
}
// Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
Expr scale = GetForwardScale(call->args[0], out);
if (IsPositiveConstant(scale)) {
return {out};
}
return {NullValue<AxesSet>()};
}
Expr ReluForwardRewrite(const Call& ref_call,
@@ -391,16 +356,21 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
Expr lhs = new_args[0];
Expr rhs = new_args[1];
auto rnode = make_node<ScaledExprNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) {
if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
IsAllPositiveConstant(rhs)) {
rnode->value = lhs;
rnode->scale = rhs;
rnode->axes = expected_out_axes;
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) {
return Expr(rnode);
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
IsAllPositiveConstant(lhs)) {
rnode->value = rhs;
rnode->scale = lhs;
rnode->axes = expected_out_axes;
}
return Expr(rnode);
} else {
return Expr();
}
}
RELAY_REGISTER_OP("multiply")
@@ -790,22 +760,6 @@ RELAY_REGISTER_OP("subtract")
RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
// Find relu in the backward path between multiply and conv2d
bool FindBackwardRelu(const Expr& expr) {
const CallNode* call = expr.as<CallNode>();
static const Op& conv2d = Op::Get("nn.conv2d");
static const Op& relu = Op::Get("nn.relu");
if (!call) return false;
if (call->op.same_as(relu)) return true;
if (call->op.same_as(conv2d)) return false;
for (size_t i = 0; i < call->args.size(); i++) {
if (FindBackwardRelu(call->args[i])) return true;
}
return false;
}
// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call,
@@ -821,16 +775,16 @@ Expr MultiplyBackwardTransform(const Call& call,
// NOTE we won't recursively call mutate on the scale part,
// since there is no folding opportunity within the scale part.
Expr rhs = call->args[1];
// Only propagate positive scaling.
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
(!FindBackwardRelu(call->args[0]) ||
IsPositiveConstant(call->args[1]))) {
IsAllPositiveConstant(rhs)) {
return transformer->Transform(call->args[0], lhs_axes, rhs);
}
} else if (rhs_axes.defined() && rhs_axes.size() != 0) {
// Only propagate positive scaling.
Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
(!FindBackwardRelu(call->args[1]) ||
IsPositiveConstant(call->args[0]))) {
IsAllPositiveConstant(lhs)) {
return transformer->Transform(call->args[1], rhs_axes, lhs);
}
}
@@ -22,6 +22,15 @@ namespace relay {
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body);
/*!
* \brief Check if expr is a positive constant.
* \param expr The expression to be checked.
* \return Whether all elements of expr are positive constants.
*/
bool IsAllPositiveConstant(const Expr& expr);
/*!
* \brief Substitute var with subst.
* \param type The type to be substituted.
@@ -190,57 +190,6 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
CHECK_EQ(tensor->ctx.device_type, kDLCPU);
CHECK(tensor->strides == nullptr);
CHECK_EQ(tensor->byte_offset, 0);
const T* data = static_cast<const T*>(tensor->data);
int64_t num_elems = 1;
for (int i = 0; i < tensor->ndim; ++i) {
num_elems *= tensor->shape[i];
}
for (int64_t i = 0; i < num_elems; i++) {
if (*data < value) {
return false;
}
data++;
}
return true;
}
inline bool IsPositiveConstant(const Expr& expr) {
const auto* constant = expr.as<ConstantNode>();
if (!constant) return false;
const auto& tensor = constant->data;
const auto& dtype = tensor->dtype;
if (dtype.lanes != 1) {
// pass
} else if (dtype.code == kDLFloat && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<float>(tensor, 0);
} else if (dtype.code == kDLFloat && dtype.bits == 64) {
return IsNDArrayAllGreaterEqual<double>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
}
LOG(WARNING) << "Unsupported data type (code = " << dtype.code
<< ", bits = " << dtype.bits << ", lanes = " << dtype.lanes
<< ")";
return false;
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
@@ -146,5 +146,67 @@ GetExprRefCount(const Expr& body) {
return ExprRefCounter().Get(body);
}
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
CHECK_EQ(tensor->ctx.device_type, kDLCPU);
CHECK(tensor->strides == nullptr);
CHECK_EQ(tensor->byte_offset, 0);
const T* data = static_cast<const T*>(tensor->data);
int64_t num_elems = 1;
for (int i = 0; i < tensor->ndim; ++i) {
num_elems *= tensor->shape[i];
}
for (int64_t i = 0; i < num_elems; i++) {
if (*data < value) {
return false;
}
data++;
}
return true;
}
bool IsAllPositiveConstant(const Expr& expr) {
// peel through a few common transform ops.
static const auto& expand_dims = Op::Get("expand_dims");
static const auto& reshape = Op::Get("reshape");
static const auto& transpose = Op::Get("transpose");
static const auto& squeeze = Op::Get("squeeze");
if (const auto* constant = expr.as<ConstantNode>()) {
const auto& tensor = constant->data;
const auto& dtype = tensor->dtype;
if (dtype.lanes != 1) {
return false;
} else if (dtype.code == kDLFloat && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<float>(tensor, 0);
} else if (dtype.code == kDLFloat && dtype.bits == 64) {
return IsNDArrayAllGreaterEqual<double>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
} else if (dtype.code == kDLInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 8) {
return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
} else if (dtype.code == kDLUInt && dtype.bits == 32) {
return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
} else {
return false;
}
} else if (const auto* op = expr.as<CallNode>()) {
// tail recursion.
if (op->op.same_as(expand_dims) ||
op->op.same_as(reshape) ||
op->op.same_as(transpose) ||
op->op.same_as(squeeze)) {
return IsAllPositiveConstant(op->args[0]);
} else {
return false;
}
} else {
return false;
}
}
} // namespace relay
} // namespace tvm
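
To make the helper's behavior concrete, here is an illustrative Python rendering of the same check (not part of the commit; the relay Python classes used here are an assumption): peel off a few shape-only ops, then require a constant whose elements are all non-negative.

import numpy as np
from tvm import relay

PEEL_OPS = {"expand_dims", "reshape", "transpose", "squeeze"}

def is_all_positive_constant(expr):
    # Shape-only ops preserve element values, so unwrap them first,
    # mirroring the tail-recursive CallNode case above.
    while isinstance(expr, relay.Call) and expr.op.name in PEEL_OPS:
        expr = expr.args[0]
    if isinstance(expr, relay.Constant):
        # Mirrors IsNDArrayAllGreaterEqual(tensor, 0): every element >= 0.
        return bool((expr.data.asnumpy() >= 0).all())
    return False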
from tvm import relay
import numpy as np
def _get_positive_scale(size):
return np.random.uniform(0.5, 1, size=size).astype('float32')
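# Why the tests now draw scales from [0.5, 1): folding a scale through relu
# relies on relu(x * s) == relu(x) * s, which holds only for s >= 0.
# Quick illustrative check (not part of the original tests):
def _check_relu_needs_positive_scale():
    x = np.array([-1.0, 2.0], dtype='float32')
    s = -0.5
    lhs = np.maximum(x * s, 0)   # relu(x * s) -> [0.5, 0.0]
    rhs = np.maximum(x, 0) * s   # relu(x) * s -> [-0.0, -1.0]
    assert not np.allclose(lhs, rhs)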
def test_fold_fwd_simple():
"""Simple testcase."""
@@ -14,6 +17,7 @@ def test_fold_fwd_simple():
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def expected(x, conv_weight, in_bias, in_scale, channels):
@@ -37,14 +41,14 @@ def test_fold_fwd_simple():
in_channels = shape[1]
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype('float32'))
in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
@@ -107,7 +111,7 @@ def test_fold_fwd_dual_path():
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
in_scale = relay.const(_get_positive_scale((in_channels,)))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
@@ -141,7 +145,7 @@ def test_fold_fwd_fail():
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)