Unverified Commit d5103bbc by Tianqi Chen Committed by GitHub

[RELAY][PASS] FoldScaleAxis Backward (#2024)

parent 25e4dc51
...@@ -135,9 +135,9 @@ class ExprVisitor ...@@ -135,9 +135,9 @@ class ExprVisitor
void VisitExpr_(const TupleGetItemNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t); virtual void VisitType(const Type& t);
private: protected:
// internal visited flag. // Internal visiting counter
std::unordered_set<const Node*> visited_; std::unordered_map<const Node*, size_t> visit_counter_;
}; };
/*! /*!
......
...@@ -31,6 +31,29 @@ def infer_type(expr, env=None): ...@@ -31,6 +31,29 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env) return _ir_pass.infer_type(expr, env)
def backward_fold_scale_axis(expr):
    """Fold axis scaling backward into the weights of conv2d/dense.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to transform. Its types are expected to be
        fully inferred by infer_type beforehand.

    Returns
    -------
    folded_expr : tvm.relay.Expr
        The expression produced by the folding transformation.

    Note
    ----
    Calling backward_fold_scale_axis before forward_fold_scale_axis
    is recommended, because backward folding targets the common
    conv-bn pattern.
    """
    folded_expr = _ir_pass.backward_fold_scale_axis(expr)
    return folded_expr
def forward_fold_scale_axis(expr): def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense. """Fold the scaling of axis into weights of conv2d/dense.
...@@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr): ...@@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr):
------- -------
folded_expr : tvm.relay.Expr folded_expr : tvm.relay.Expr
The folded expression after transformation. The folded expression after transformation.
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
""" """
return _ir_pass.forward_fold_scale_axis(expr) return _ir_pass.forward_fold_scale_axis(expr)
......
...@@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { ...@@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type ExprMutator::VisitType(const Type& t) { return t; } Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::VisitExpr(const Expr& expr) { void ExprVisitor::VisitExpr(const Expr& expr) {
if (visited_.count(expr.get())) return; auto it = visit_counter_.find(expr.get());
using TParent = ExprFunctor<void(const Expr&)>; if (it != visit_counter_.end()) {
TParent::VisitExpr(expr); ++it->second;
visited_.insert(expr.get()); } else {
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(expr);
visit_counter_.insert({expr.get(), 1});
}
} }
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include "../op/nn/layout.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -100,11 +101,31 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, ...@@ -100,11 +101,31 @@ inline Expr ExpandBiasToMatchAxis(Expr bias,
return bias; return bias;
} }
/*!
 * \brief Check if the call is depthwise conv2d.
 *
 * The weight shape is taken from the tensor type of the call's second
 * argument and converted from weight_layout to OIHW. The call is reported
 * as depthwise when the first dimension of that OIHW shape is the constant
 * param->groups and the second dimension is the constant 1.
 *
 * \param call The conv2d call.
 * \param param The conv2d attributes.
 * \param weight_layout The layout in which the weight shape is expressed.
 * \return Whether it is depthwise_conv2d.
 */
inline bool IsDepthwiseConv2D(const Call& call,
                              const Conv2DAttrs* param,
                              const Layout& weight_layout) {
  static const Layout kOIHW("OIHW");
  const auto* weight_type = call->args[1]->type_as<TensorTypeNode>();
  auto oihw_shape = ConvertLayout(weight_type->shape, weight_layout, kOIHW);
  // First dimension of the OIHW shape must match the group count.
  if (!is_const_int(oihw_shape[0], param->groups)) {
    return false;
  }
  // Second dimension of the OIHW shape must be exactly one.
  return is_const_int(oihw_shape[1], 1);
}
inline Expr Multiply(Expr lhs, Expr rhs) { inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply"); static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
} }
inline Expr Divide(Expr lhs, Expr rhs) { inline Expr Divide(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("divide"); static const Op& op = Op::Get("divide");
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
...@@ -116,8 +137,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { ...@@ -116,8 +137,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
} }
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
...@@ -62,14 +62,14 @@ def test_fold_fwd_dual_path(): ...@@ -62,14 +62,14 @@ def test_fold_fwd_dual_path():
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWOI", weight_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
y2 = relay.nn.conv2d(x, conv_weight, y2 = relay.nn.conv2d(x, conv_weight,
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWOI", weight_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
z = relay.add(y1, y2) z = relay.add(y1, y2)
...@@ -85,7 +85,7 @@ def test_fold_fwd_dual_path(): ...@@ -85,7 +85,7 @@ def test_fold_fwd_dual_path():
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWOI", weight_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
y2 = relay.nn.conv2d(x, y2 = relay.nn.conv2d(x,
...@@ -93,7 +93,7 @@ def test_fold_fwd_dual_path(): ...@@ -93,7 +93,7 @@ def test_fold_fwd_dual_path():
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
data_layout="NHWC", data_layout="NHWC",
weight_layout="HWOI", weight_layout="HWIO",
groups=channels, groups=channels,
padding=(1, 1)) padding=(1, 1))
z = relay.add(y1, y2) z = relay.add(y1, y2)
...@@ -147,7 +147,176 @@ def test_fold_fwd_fail(): ...@@ -147,7 +147,176 @@ def test_fold_fwd_fail():
check((2, 11, 10, 4), 4) check((2, 11, 10, 4), 4)
def test_fold_bwd_simple():
    """Simple testcase.

    Builds conv2d -> add(bias) -> relu -> multiply(scale) and checks that
    backward_fold_scale_axis produces the graph in which the scale has been
    multiplied into the conv2d weight and the bias instead.
    """
    def before(x, conv_weight, out_bias, out_scale, channels):
        # Graph handed to the pass: the scale is applied after the relu.
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
        y = relay.nn.conv2d(x, conv_weight,
                            channels=channels,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.add(y, out_bias)
        y = relay.nn.relu(y)
        y = relay.multiply(y, out_scale)
        return relay.Function(args, y)

    def expected(x, conv_weight, out_bias, out_scale, channels):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
        squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
        # The scale is multiplied into the weight ...
        conv_weight = relay.multiply(
            conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
        y = relay.nn.conv2d(x, conv_weight,
                            channels=channels,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        # ... and into the bias, so no trailing multiply remains.
        out_bias = relay.multiply(
            out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
        y = relay.add(y, out_bias)
        y = relay.nn.relu(y)
        return relay.Function(args, y)

    def check(shape, channels):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        out_bias = relay.var("out_bias", shape=(channels,))
        out_scale = relay.var("out_scale", shape=(channels,))
        y1 = before(x, weight, out_bias, out_scale, channels)
        y1 = relay.ir_pass.infer_type(y1)
        # Re-create the weight var with its inferred type so the expected
        # graph is built from a typed weight.
        type_dict = {param.name_hint: param.checked_type for param in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
        y1_expected = expected(x, weight, out_bias, out_scale, channels)
        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 8)
def test_fold_bwd_dual_path():
    """Dual path testcase.

    Two conv2d -> relu branches are added together and then multiplied by a
    scale; the expected result has the scale multiplied into the weight of
    both conv2d calls and no trailing multiply.
    """
    def before(x, conv_weight, out_bias, out_scale, channels):
        # out_bias stays in the parameter list but is not used in the body.
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        y1 = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        y = relay.multiply(y, out_scale)
        return relay.Function(args, y)

    def expected(x, conv_weight, out_bias, out_scale, channels):
        # use a fixed order of args so alpha equal check can pass
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])

        def fold_conv_weight():
            # Builds a fresh scaled-weight expression on every call, so each
            # branch gets its own multiply node.
            return relay.multiply(
                conv_weight,
                relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))

        y1 = relay.nn.conv2d(x, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(x, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def check(shape, channels):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        out_bias = relay.var("out_bias", shape=(channels,))
        out_scale = relay.var("out_scale", shape=(channels,))
        y1 = before(x, weight, out_bias, out_scale, channels)
        y1 = relay.ir_pass.infer_type(y1)
        # Re-create the weight var with its inferred type so the expected
        # graph is built from a typed weight.
        type_dict = {param.name_hint: param.checked_type for param in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
        y1_expected = expected(x, weight, out_bias, out_scale, channels)
        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)

    check((2, 4, 10, 10), 8)
def test_fold_bwd_fail():
    """Testcases where backward folding must leave the graph unchanged."""
    def fail1(x, conv_weight, out_bias, out_scale, channels):
        # out_bias stays in the parameter list but is not used in the body.
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        y1 = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.nn.relu(y1)
        # fold will fail because the axes from the two paths
        # differ from each other (y2 uses out_layout="CNHW").
        y2 = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             out_layout="CNHW")
        y2 = relay.nn.relu(y2)
        y = relay.add(y1, y2)
        y = relay.multiply(y, out_scale)
        return relay.Function(args, y)

    def fail2(x, conv_weight, out_bias, out_scale, channels):
        # out_bias stays in the parameter list but is not used in the body.
        args = [x, conv_weight, out_bias, out_scale]
        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
        y1 = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y2 = relay.nn.relu(y1)
        # fold will fail because y1 is referred also by y2
        y1 = relay.multiply(y1, out_scale)
        y = relay.add(y1, y2)
        return relay.Function(args, y)

    def check(shape, channels, fbefore):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        out_bias = relay.var("out_bias", shape=(channels,))
        out_scale = relay.var("out_scale", shape=(channels,))
        y1 = fbefore(x, weight, out_bias, out_scale, channels)
        y1 = relay.ir_pass.infer_type(y1)
        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
        # The pass is expected to return a graph alpha-equal to its input.
        assert relay.ir_pass.alpha_equal(y1_folded, y1)

    check((4, 4, 10, 10), 4, fail1)
    check((4, 4, 10, 10), 4, fail2)
if __name__ == "__main__": if __name__ == "__main__":
test_fold_fwd_simple() test_fold_fwd_simple()
test_fold_fwd_dual_path() test_fold_fwd_dual_path()
test_fold_fwd_fail() test_fold_fwd_fail()
test_fold_bwd_simple()
test_fold_bwd_dual_path()
test_fold_bwd_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment