Unverified Commit d5103bbc by Tianqi Chen Committed by GitHub

[RELAY][PASS] FoldScaleAxis Backward (#2024)

parent 25e4dc51
......@@ -135,9 +135,9 @@ class ExprVisitor
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);
private:
// internal visited flag.
std::unordered_set<const Node*> visited_;
protected:
// Internal visiting counter
std::unordered_map<const Node*, size_t> visit_counter_;
};
/*!
......
......@@ -31,6 +31,29 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env)
def backward_fold_scale_axis(expr):
"""Backward fold axis scaling into weights of conv2d/dense.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
"""
return _ir_pass.backward_fold_scale_axis(expr)
def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
......@@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr):
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
"""
return _ir_pass.forward_fold_scale_axis(expr)
......
......@@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::VisitExpr(const Expr& expr) {
if (visited_.count(expr.get())) return;
auto it = visit_counter_.find(expr.get());
if (it != visit_counter_.end()) {
++it->second;
} else {
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(expr);
visited_.insert(expr.get());
visit_counter_.insert({expr.get(), 1});
}
}
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
......
......@@ -24,9 +24,9 @@ namespace fold_scale_axis {
using runtime::TypedPackedFunc;
// FoldScaleAxisFoward algorithm:
// FoldScaleAxis algorithm:
//
// The general idea is that we transform Expr to tuple of
// The general idea is to transform Expr to tuple of
// (value, axes, scale), where the final result satiesfies:
//
// result = value
......@@ -41,9 +41,14 @@ using runtime::TypedPackedFunc;
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
//
// The folding process is done in two steps:
// Forward folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation,
//
// Similarly, backward folding process is done in two steps:
// - Prepare phase: forward propagation of demand.
// - Transform phase: transformation by push down the axes scale signal to inputs.
//
/*!
* \brief sorted array axis, can also be nullptr.
......@@ -99,7 +104,7 @@ ValueType GetFunc(const OpMap<ValueType>& op_map,
}
/*!
* \brief Preparation function for for pass scale forward.
* \brief Preparation function for pass scale forward.
* \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output.
* \return The result scaling on axes of the input.
......@@ -144,7 +149,7 @@ using FForwardTransform = TypedPackedFunc<
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class FScaleAxisForwardPrep : private ExprVisitor {
class ForwardPrep : private ExprVisitor {
public:
std::unordered_map<const Node*, AxesSet>
Prepare(const Expr& body) {
......@@ -255,12 +260,12 @@ class FScaleAxisForwardPrep : private ExprVisitor {
}
};
class FScaleAxisForwardTransform : private ExprMutator {
class ForwardTransformer : private ExprMutator {
public:
// Transform expression.
Expr Transform(Expr expr) {
Expr Fold(Expr expr) {
expected_scale_axes_ =
FScaleAxisForwardPrep().Prepare(expr);
ForwardPrep().Prepare(expr);
return this->Mutate(expr);
}
......@@ -352,7 +357,7 @@ STuple ReluForwardTransform(const Call& ref_call,
// return transformed conv2d
auto rnode = make_node<STupleNode>();
rnode->value = CallNode::make(
ref_call->op, {sargs[0]->value}, ref_call->attrs, {});
ref_call->op, {sargs[0]->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = sargs[0]->scale;
rnode->axes = sargs[0]->axes;
return STuple(rnode);
......@@ -474,8 +479,6 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
Layout weight_layout(param->weight_layout);
int c_big_axis = data_layout.indexof('C');
int c_small_axis = data_layout.indexof('c');
const auto* tdata = call->args[0]->type_as<TensorTypeNode>();
CHECK(tdata) << "require checked type";
CHECK_GE(c_big_axis, 0);
AxesSet data_axes = NullValue<AxesSet>();
......@@ -486,8 +489,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d =
is_const_int(tdata->shape[c_big_axis], param->groups);
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
if (weight_layout.indexof('i') < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
......@@ -515,18 +517,24 @@ STuple Conv2DForwardTransform(const Call& ref_call,
CHECK_EQ(weight_layout.indexof('i'), -1);
CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value);
int big_oc_axis = weight_layout.indexof('O');
int big_ic_axis = weight_layout.indexof('I');
const auto* tdata = ref_call->args[0]->type_as<TensorTypeNode>();
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d =
is_const_int(tdata->shape[c_big_axis], param->groups);
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr weight = sargs[1]->value;
// match the ic_axis
if (is_depthwise_conv2d) {
Expr scale = ExpandBiasToMatchAxis(
sdata->scale, weight_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, scale);
} else {
Expr scale = ExpandBiasToMatchAxis(
sdata->scale, weight_layout.ndim(), {big_ic_axis});
Expr weight = Multiply(sargs[1]->value, scale);
weight = Multiply(weight, scale);
}
// return transformed conv2d
auto rnode = make_node<STupleNode>();
rnode->value = CallNode::make(
......@@ -542,13 +550,416 @@ RELAY_REGISTER_OP("nn.conv2d")
Expr ForwardFoldScaleAxis(Expr data) {
return FScaleAxisForwardTransform().Transform(data);
return ForwardTransformer().Fold(data);
}
// Expose the FoldScaleAxisFoward
TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
.set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);
//----------------------------------------
// Implement backward transformations.
//----------------------------------------
class BackwardTransformer;
/*!
* \brief Preparation function for for pass scale backward.
* \param call The call node.
* \param in_scale_axes Allowed input scaling.
* \return The result scaling on axes of the input.
*/
using FBackwardPrep = TypedPackedFunc<
AxesSet(const Call& call, const Array<AxesSet>& in_scale_axes)>;
using FBackwardTransform = TypedPackedFunc<
Expr(const Call& call,
const AxesSet& axes,
const Expr& scale,
const BackwardTransformer& transformer)>;
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}
class BackwardPrep : private ExprVisitor {
public:
// The message on each node.
std::unordered_map<const Node*, AxesSet>
Prepare(const Expr& body) {
ref_counter_ = GetExprRefCount(body);
this->VisitExpr(body);
return std::move(message_);
}
private:
// The message on each node.
std::unordered_map<const Node*, AxesSet> message_;
// reference counter of an internal expr
std::unordered_map<const Node*, size_t> ref_counter_;
// Visit the expression.
void VisitExpr_(const CallNode* call) {
ExprVisitor::VisitExpr_(call);
static const auto& fprep =
Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
auto f = GetFunc(fprep, call->op);
if (f == nullptr) return;
auto rit = ref_counter_.find(call);
CHECK(rit != ref_counter_.end());
// We only allow propagation of scale backward
// if the expression is only referred by a single parent.
if (rit->second != 1) return;
Array<AxesSet> in_axes;
for (Expr arg : call->args) {
auto it = message_.find(arg.get());
if (it != message_.end()) {
in_axes.push_back(it->second);
} else {
in_axes.push_back(NullValue<AxesSet>());
}
}
AxesSet out_axes = f(GetRef<Call>(call), in_axes);
if (out_axes.defined()) {
message_[call] = out_axes;
}
}
};
class BackwardTransformerNode :
public Node,
private ExprMutator {
public:
// Run forward transform.
Expr Fold(Expr expr) {
expected_scale_axes_ = BackwardPrep().Prepare(expr);
return this->Mutate(expr);
}
/*!
* \brief Transform the expr to consider the scaling.
*
* \param expr The input expression.
* \param axes The axes to scale.
* \param scale The scale applied to the axes.
* \return The result of transformation.
*/
Expr Transform(const Expr& expr, AxesSet axes, Expr scale) {
// NOTE: the result of Transform is not memoized.
// However, in the current rule, Transform will
// only be called to expr that is referred once.
if (const CallNode* call_node = expr.as<CallNode>()) {
return Transform(call_node, axes, scale);
} else {
CHECK(!axes.defined()) << "outstanding scale";
return ExprMutator::VisitExpr(expr);
}
}
/*!
* \brief Normal way of mutating call node.
* \param call_node The call node to be mutated.
* \return the result of the call Mutation.
*/
Expr NormalCallTransform(const CallNode* call_node) {
return ExprMutator::VisitExpr_(call_node);
}
/*!
* \brief Get the expected axes on expr.
* \param expr The expresison.
* \return The expected axes.
*/
AxesSet GetExpectedAxes(const Expr& expr) const {
auto it = expected_scale_axes_.find(expr.get());
if (it != expected_scale_axes_.end()) return it->second;
return NullValue<AxesSet>();
}
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);
private:
// Valid axes on each node.
std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
// Override mutation of call.
Expr VisitExpr_(const CallNode* call_node) final {
return Transform(call_node, NullValue<AxesSet>(), NullValue<Expr>());
}
// Transform of CallNode.
Expr Transform(const CallNode* call_node, AxesSet axes, Expr scale);
};
class BackwardTransformer : public NodeRef {
public:
BackwardTransformer() {}
explicit BackwardTransformer(
::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {
}
BackwardTransformerNode* operator->() const {
return static_cast<BackwardTransformerNode*>(node_.get());
}
using ContainerType = BackwardTransformerNode;
};
Expr BackwardTransformerNode::Transform(
const CallNode* call_node, AxesSet axes, Expr scale) {
static const auto& ftransform =
Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = GetFunc(ftransform, call_node->op);
if (f != nullptr) {
return f(GetRef<Call>(call_node),
axes,
scale,
GetRef<BackwardTransformer>(this));
} else {
CHECK(!axes.defined()) << "outstanding scale";
return NormalCallTransform(call_node);
}
}
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
// Intermediate operators
AxesSet ReluBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
return in_axes[0];
}
Expr ReluBackwardTransform(const Call& call,
const AxesSet& axes,
const Expr& scale,
const BackwardTransformer& transformer) {
if (!axes.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
Expr input = transformer->Transform(
call->args[0], axes, scale);
return CallNode::make(call->op, {input}, call->attrs, call->type_args);
}
RELAY_REGISTER_OP("nn.relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
RELAY_REGISTER_OP("nn.relu")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
// AddSub
AxesSet AddSubBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AttrsEqual equal;
if (in_axes[0].defined() &&
MatchBroadcastToLeftAxes(tlhs, trhs, in_axes[0])) {
return in_axes[0];
} else if (in_axes[1].defined() &&
MatchBroadcastToLeftAxes(trhs, tlhs, in_axes[1])) {
return in_axes[1];
} else if (in_axes[0].defined() &&
in_axes[1].defined() &&
equal(in_axes[0], in_axes[1]) &&
equal(tlhs->shape, trhs->shape)) {
// add of two elements.
return in_axes[0];
} else {
return NullValue<AxesSet>();
}
}
Expr AddSubBackwardTransform(const Call& call,
const AxesSet& axes,
const Expr& scale,
const BackwardTransformer& transformer) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
if (!axes.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
AttrsEqual equal;
if (lhs_axes.defined() && rhs_axes.defined()) {
CHECK(equal(lhs_axes, rhs_axes));
CHECK(equal(axes, lhs_axes));
Expr lhs = transformer->Transform(call->args[0], axes, scale);
Expr rhs = transformer->Transform(call->args[1], axes, scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (lhs_axes.defined()) {
CHECK(equal(axes, lhs_axes));
Expr lhs = transformer->Transform(call->args[0], axes, scale);
Expr rhs = transformer->Transform(
call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
Expr rhs_scale = ExpandBiasToMatchAxis(
scale, tlhs->shape.size(), axes);
rhs = Multiply(rhs, rhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (rhs_axes.defined()) {
CHECK(equal(axes, rhs_axes));
Expr lhs = transformer->Transform(
call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
Expr rhs = transformer->Transform(call->args[1], axes, scale);
Expr lhs_scale = ExpandBiasToMatchAxis(
scale, trhs->shape.size(), axes);
lhs = Multiply(lhs, lhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else {
LOG(FATAL) << "outstanding scale";
return Expr();
}
}
RELAY_REGISTER_OP("add")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
RELAY_REGISTER_OP("add")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call,
const AxesSet& axes,
const Expr& scale,
const BackwardTransformer& transformer) {
CHECK(!axes.defined()) << "outstanding scale";
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
if (lhs_axes.defined()) {
// NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part.
Expr rhs = call->args[1];
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
return transformer->Transform(call->args[0], lhs_axes, rhs);
}
} else if (rhs_axes.defined()) {
Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
return transformer->Transform(call->args[1], rhs_axes, lhs);
}
}
return transformer->NormalCallTransform(call.operator->());
}
RELAY_REGISTER_OP("multiply")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
// Consumer operators
// Conv2D send out requirement of axis folding.
AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout out_layout(param->out_layout);
if (!out_layout.defined()) {
out_layout = Layout(param->data_layout);
}
Layout weight_layout(param->weight_layout);
int c_big_axis = out_layout.indexof('C');
int c_small_axis = out_layout.indexof('c');
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// More general layout can be supported under the current framework.
// By using a unified layout transformation.
// We only need to change the Prep and Mutate function.
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
if (weight_layout.indexof('o') < 0 &&
weight_layout.indexof('i') < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return {c_big_axis};
} else {
return NullValue<AxesSet>();
}
}
// Conv2D consumes the scale axis during transformation.
Expr Conv2DBackwardTransform(const Call& call,
const AxesSet& axes,
const Expr& scale,
const BackwardTransformer& transformer) {
if (!axes.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout out_layout(param->out_layout);
if (!out_layout.defined()) {
out_layout = Layout(param->data_layout);
}
Layout weight_layout(param->weight_layout);
int c_big_axis = out_layout.indexof('C');
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(weight_layout.indexof('o'), -1);
CHECK_EQ(weight_layout.indexof('i'), -1);
CHECK(axes.size() == 1 &&
c_big_axis == axes[0]->value);
int big_oc_axis = weight_layout.indexof('O');
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr data = transformer->Transform(
call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
Expr weight = transformer->Transform(
call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
// scale on input for deptwise.
Expr wscale = ExpandBiasToMatchAxis(
scale, weight_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, wscale);
return CallNode::make(
call->op, {data, weight}, call->attrs, call->type_args);
}
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
Expr BackwardFoldScaleAxis(Expr data) {
return make_node<BackwardTransformerNode>()->Fold(data);
}
TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis")
.set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);
} // namespace fold_scale_axis
} // namespace relay
} // namespace tvm
......@@ -11,6 +11,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
#include "../op/nn/layout.h"
namespace tvm {
namespace relay {
......@@ -100,11 +101,31 @@ inline Expr ExpandBiasToMatchAxis(Expr bias,
return bias;
}
/*!
* \brief Check if the call is depthwise conv2d.
*
* \param call The conv2d call.
* \param param The conv2d attributes.
* \return Whether it is depthwise_conv2d.
*/
inline bool IsDepthwiseConv2D(const Call& call,
const Conv2DAttrs* param,
const Layout& weight_layout) {
static const Layout kOIHW("OIHW");
auto wshape = ConvertLayout(
call->args[1]->type_as<TensorTypeNode>()->shape,
weight_layout, kOIHW);
return is_const_int(wshape[0], param->groups) &&
is_const_int(wshape[1], 1);
}
inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Divide(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("divide");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
......@@ -116,8 +137,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
......@@ -62,14 +62,14 @@ def test_fold_fwd_dual_path():
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
weight_layout="HWIO",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
weight_layout="HWIO",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
......@@ -85,7 +85,7 @@ def test_fold_fwd_dual_path():
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
weight_layout="HWIO",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x,
......@@ -93,7 +93,7 @@ def test_fold_fwd_dual_path():
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
weight_layout="HWOI",
weight_layout="HWIO",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
......@@ -147,7 +147,176 @@ def test_fold_fwd_fail():
check((2, 11, 10, 4), 4)
def test_fold_bwd_simple():
"""Simple testcase."""
def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.add(y, out_bias)
y = relay.nn.relu(y)
y = relay.multiply(y, out_scale)
return relay.Function(args, y)
def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
out_bias = relay.multiply(out_bias,
relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
y = relay.add(y, out_bias)
y = relay.nn.relu(y)
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_bias, out_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
def test_fold_bwd_dual_path():
"""Dual path testcase."""
def before(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y1 = relay.nn.relu(y1)
y2 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y2 = relay.nn.relu(y2)
y = relay.add(y1, y2)
y = relay.multiply(y, out_scale)
return relay.Function(args, y)
def expected(x, conv_weight, out_bias, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
def fold_conv_weight():
return relay.multiply(
conv_weight ,
relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
y1 = relay.nn.conv2d(x, fold_conv_weight(),
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y1 = relay.nn.relu(y1)
y2 = relay.nn.conv2d(x, fold_conv_weight(),
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y2 = relay.nn.relu(y2)
y = relay.add(y1, y2)
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_bias, out_scale, channels)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
def test_fold_bwd_fail():
"""Dual path testcase."""
def fail1(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y1 = relay.nn.relu(y1)
y2 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1),
out_layout="CNHW")
# fold will fail because the axis from two path
# differs from each other.
y2 = relay.nn.relu(y2)
y = relay.add(y1, y2)
y = relay.multiply(y, out_scale)
return relay.Function(args, y)
def fail2(x, conv_weight, out_bias, out_scale, channels):
args = [x, conv_weight, out_bias, out_scale]
out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
y1 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y2 = relay.nn.relu(y1)
# fold will fail because y1 is referred also by y2
y1 = relay.multiply(y1, out_scale)
y = relay.add(y1, y2)
return relay.Function(args, y)
def check(shape, channels, fbefore):
x = relay.var("x", shape=shape)
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.var("out_scale", shape=(channels,))
y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
assert relay.ir_pass.alpha_equal(y1_folded, y1)
check((4, 4, 10, 10), 4, fail1)
check((4, 4, 10, 10), 4, fail2)
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
test_fold_bwd_simple()
test_fold_bwd_dual_path()
test_fold_bwd_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment