Commit 93949456 by Wuwei Lin Committed by Tianqi Chen

[RELAY][PASS] Check Positiveness in FoldScaleAxis (#2220)

parent 166936cd
...@@ -150,13 +150,14 @@ def optimize(func, params=None): ...@@ -150,13 +150,14 @@ def optimize(func, params=None):
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func) func = ir_pass.combine_parallel_conv2d(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)
if cfg.pass_enabled("FoldScaleAxis"): if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func) func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func) func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func) func = ir_pass.forward_fold_scale_axis(func)
if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func) func = ir_pass.fold_constant(func)
if cfg.pass_enabled("AlterOpLayout"): if cfg.pass_enabled("AlterOpLayout"):
......
...@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor { ...@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor {
// Per operator defs for FScaleAxisForward // Per operator defs for FScaleAxisForward
//---------------------------------------------- //----------------------------------------------
// Helper functions
// Walk forward from `expr` looking for a multiply whose one operand
// broadcasts to the axes in `out`; return that operand as the scale.
// Returns an undefined Expr when no such scale is found.
Expr GetForwardScale(const Expr& expr, AxesSet out) {
  static const Op& multiply = Op::Get("multiply");
  static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
  const auto* call = expr.as<CallNode>();
  // Only call nodes can carry a scale.
  if (call == nullptr) return NullValue<Expr>();
  if (call->op.same_as(multiply)) {
    const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
    const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
    // Whichever side broadcasts onto the other along `out` is the scale.
    if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) return call->args[1];
    if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) return call->args[0];
    return NullValue<Expr>();
  }
  auto prep = fprep.get(call->op, nullptr);
  if (prep != nullptr) {
    // The op knows how `out` maps onto its inputs; recurse into each arg.
    Array<AxesSet> in_axes = prep(GetRef<Call>(call), out);
    for (size_t idx = 0; idx < call->args.size(); ++idx) {
      Expr scale = GetForwardScale(call->args[idx], in_axes[idx]);
      if (scale.defined()) return scale;
    }
  }
  return NullValue<Expr>();
}
// Intermediate operators // Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) { Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
return {out}; Expr scale = GetForwardScale(call->args[0], out);
if (IsPositiveConstant(scale)) {
return {out};
}
return {NullValue<AxesSet>()};
} }
Expr ReluForwardRewrite(const Call& ref_call, Expr ReluForwardRewrite(const Call& ref_call,
...@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract") ...@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract")
RELAY_REGISTER_OP("subtract") RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform); .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
// Find relu in the backward path between multiply and conv2d
bool FindBackwardRelu(const Expr& expr) {
const CallNode* call = expr.as<CallNode>();
static const Op& conv2d = Op::Get("nn.conv2d");
static const Op& relu = Op::Get("nn.relu");
if (!call) return false;
if (call->op.same_as(relu)) return true;
if (call->op.same_as(conv2d)) return false;
for (size_t i = 0; i < call->args.size(); i++) {
if (FindBackwardRelu(call->args[i])) return true;
}
return false;
}
// Producer operators // Producer operators
// Multiply produces the scale-axis pair. // Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call, Expr MultiplyBackwardTransform(const Call& call,
...@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call, ...@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call,
// NOTE we won't recursively call mutating on scale part. // NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part. // since there won't be scale chance within scale part.
Expr rhs = call->args[1]; Expr rhs = call->args[1];
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) { if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
(!FindBackwardRelu(call->args[0]) ||
IsPositiveConstant(call->args[1]))) {
return transformer->Transform(call->args[0], lhs_axes, rhs); return transformer->Transform(call->args[0], lhs_axes, rhs);
} }
} else if (rhs_axes.defined() && rhs_axes.size() != 0) { } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
Expr lhs = call->args[0]; Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) { if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
(!FindBackwardRelu(call->args[1]) ||
IsPositiveConstant(call->args[0]))) {
return transformer->Transform(call->args[1], rhs_axes, lhs); return transformer->Transform(call->args[1], rhs_axes, lhs);
} }
} }
......
...@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis); ...@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides); Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
// Return true iff every element of `tensor` (interpreted as T) is >= value.
// Requires a compact CPU tensor: no strides and zero byte offset.
template <typename T>
bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
  CHECK(tensor->strides == nullptr);
  CHECK_EQ(tensor->byte_offset, 0);
  // Total element count is the product of all dimensions.
  int64_t size = 1;
  for (int dim = 0; dim < tensor->ndim; ++dim) {
    size *= tensor->shape[dim];
  }
  const T* elems = static_cast<const T*>(tensor->data);
  for (int64_t idx = 0; idx < size; ++idx) {
    if (elems[idx] < value) {
      return false;
    }
  }
  return true;
}
// Return true iff `expr` is a constant tensor whose elements are all
// non-negative (the comparisons below use >= 0, despite the name).
// Non-constants, vector lane types, and unsupported dtypes yield false.
inline bool IsPositiveConstant(const Expr& expr) {
  const auto* constant = expr.as<ConstantNode>();
  if (!constant) return false;
  const auto& tensor = constant->data;
  const auto& dtype = tensor->dtype;
  // Only scalar-lane tensors are supported; dispatch on dtype code/bits.
  if (dtype.lanes == 1) {
    if (dtype.code == kDLFloat) {
      if (dtype.bits == 32) return IsNDArrayAllGreaterEqual<float>(tensor, 0);
      if (dtype.bits == 64) return IsNDArrayAllGreaterEqual<double>(tensor, 0);
    } else if (dtype.code == kDLInt) {
      if (dtype.bits == 8) return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
      if (dtype.bits == 32) return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
    } else if (dtype.code == kDLUInt) {
      if (dtype.bits == 8) return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
      if (dtype.bits == 32) return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
    }
  }
  // Fall-through: dtype we do not handle; be conservative and warn.
  LOG(WARNING) << "Unsupported data type (code = " << dtype.code
               << ", bits = " << dtype.bits << ", lanes = " << dtype.lanes
               << ")";
  return false;
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment