Commit 2a871f35 by Wuwei Lin Committed by ziheng

[RELAY][PASS] Support Negative Scale in FoldScaleAxis (#2426)

* [RELAY][PASS] Support Negative Scale in FoldScaleAxis

* Fix comment
parent 52e3bd32
......@@ -59,6 +59,36 @@ using runtime::TypedPackedFunc;
*/
using AxesSet = Array<Integer>;
class Message;
/*!
* \brief Message propogated during the prepare phase.
*/
class MessageNode : public RelayNode {
public:
/*! \brief Axes for scaling */
AxesSet axes;
/*!
* \brief Whether folding requires the scale to be positive constant. This is necessary if some
* operators (e.g. Relu) is present.
*/
bool require_positive;
static Message make(const AxesSet& axes, bool require_positive);
static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode);
};
RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef);
Message MessageNode::make(const AxesSet& axes, bool require_positive) {
auto n = make_node<MessageNode>();
n->axes = axes;
n->require_positive = require_positive;
return Message(n);
}
/*!
* \brief Merge two axis set together by taking
* intersection.
......@@ -89,13 +119,28 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
}
/*!
* \brief Merge two messages together by taking intersection.
*
* \param lhs The lhs message.
* \param rhs The rhs message.
* \return The result of intersection.
*/
Message Intersect(const Message& lhs, const Message& rhs) {
if (!lhs.defined()) return lhs;
if (!rhs.defined()) return rhs;
auto axes = Intersect(lhs->axes, rhs->axes);
return MessageNode::make(axes, lhs->require_positive || rhs->require_positive);
}
/*!
* \brief Preparation function for pass scale forward.
* \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output.
* \return The result scaling on axes of the input.
* \param out_message Message from the output containing possible scaling on axes and whether
* positive scale is required.
* \return The message containing the result scaling on axes of the input.
*/
using FForwardPrep = runtime::TypedPackedFunc<
Array<AxesSet> (const Call& call, const AxesSet& out_scale_axes)>;
Array<Message> (const Call& call, const Message& out_message)>;
/*! \brief Axis scale tuple. */
class ScaledExprNode : public TempExprNode {
......@@ -126,16 +171,16 @@ class ScaledExprNode : public TempExprNode {
using FForwardRewrite = TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expeced_out_axes)>;
const Message& message)>;
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class ForwardPrep : private ExprVisitor {
public:
std::unordered_map<const Node*, AxesSet>
std::unordered_map<const Node*, Message>
Prepare(const Expr& body) {
this->Update(body, NullValue<AxesSet>());
this->Update(body, NullValue<Message>());
this->VisitExpr(body);
// flist is added in the Post-DFS order
// which is a special case of topological order.
......@@ -152,9 +197,9 @@ class ForwardPrep : private ExprVisitor {
// The invoke list
std::vector<std::function<void()> > flist_;
// The message on each node.
std::unordered_map<const Node*, AxesSet> message_;
std::unordered_map<const Node*, Message> message_;
// Update the message stored at node.
void Update(const Expr& node, const AxesSet& axes) {
void Update(const Expr& node, const Message& message) {
// We run intersection of messages:
//
// %y = multiply(%x, %scale)
......@@ -167,9 +212,9 @@ class ForwardPrep : private ExprVisitor {
// and the forward folding won't be triggered.
const Node* key = node.get();
if (message_.count(key)) {
message_[key] = Intersect(message_[key], axes);
message_[key] = Intersect(message_[key], message);
} else {
message_[key] = axes;
message_[key] = message;
}
}
// Visitor pattern override.
......@@ -180,7 +225,7 @@ class ForwardPrep : private ExprVisitor {
void VisitExpr_(const FunctionNode* op) {
ExprVisitor::VisitExpr_(op);
auto flazy = [this, op] {
this->Update(op->body, NullValue<AxesSet>());
this->Update(op->body, NullValue<Message>());
};
flist_.push_back(flazy);
}
......@@ -193,23 +238,23 @@ class ForwardPrep : private ExprVisitor {
Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
// find the message send to this node.
auto it = message_.find(call);
AxesSet out_axes;
Message out_message;
if (it != message_.end()) {
out_axes = it->second;
out_message = it->second;
} else {
out_axes = NullValue<AxesSet>();
out_message = NullValue<Message>();
}
// pass the message back to all the children it references.
auto f = fprep.get(call->op, nullptr);
if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out_axes);
CHECK_EQ(in_axes.size(), call->args.size());
Array<Message> in_messages = f(GetRef<Call>(call), out_message);
CHECK_EQ(in_messages.size(), call->args.size());
for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], in_axes[i]);
this->Update(call->args[i], in_messages[i]);
}
} else {
for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], NullValue<AxesSet>());
this->Update(call->args[i], NullValue<Message>());
}
}
};
......@@ -221,7 +266,7 @@ class ForwardPrep : private ExprVisitor {
// do not support pass scale through tuple for now.
auto flazy = [this, op]() {
for (const Expr& field : op->fields) {
this->Update(field, NullValue<AxesSet>());
this->Update(field, NullValue<Message>());
}
};
flist_.push_back(flazy);
......@@ -230,13 +275,13 @@ class ForwardPrep : private ExprVisitor {
void VisitExpr_(const IfNode* op) {
ExprVisitor::VisitExpr_(op);
// do pass through condition
// by assigning NullValue<AxesSet>
// by assigning NullValue<Message>
// it means fuse signal cannot pass
// through into these subexpressions.
auto flazy = [this, op]() {
this->Update(op->cond, NullValue<AxesSet>());
this->Update(op->true_branch, NullValue<AxesSet>());
this->Update(op->false_branch, NullValue<AxesSet>());
this->Update(op->cond, NullValue<Message>());
this->Update(op->true_branch, NullValue<Message>());
this->Update(op->false_branch, NullValue<Message>());
};
flist_.push_back(flazy);
}
......@@ -247,13 +292,16 @@ class ForwardPrep : private ExprVisitor {
//----------------------------------------------
// Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
return {out};
Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
if (out_message.defined()) {
return {MessageNode::make(out_message->axes, true)};
}
return {out_message};
}
Expr ReluForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_axes) {
const Message& message) {
const auto* input = new_args[0].as<ScaledExprNode>();
if (input == nullptr) return Expr(nullptr);
// return transformed conv2d
......@@ -278,23 +326,23 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
// AddSub
Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
auto none = NullValue<AxesSet>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) {
return {out_axes, none};
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) {
return {none, out_axes};
} else {
return {none, none};
auto none = NullValue<Message>();
if (out_message.defined()) {
if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) {
return {out_message, none};
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) {
return {none, out_message};
}
}
return {none, none};
}
Expr AddSubForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_out_axes) {
const Message& message) {
const auto* slhs = new_args[0].as<ScaledExprNode>();
const auto* srhs = new_args[1].as<ScaledExprNode>();
if (!slhs && !srhs) return Expr();
......@@ -342,9 +390,10 @@ RELAY_REGISTER_OP("subtract")
// Multiply produces the scale-axis pair.
Expr MultiplyForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_out_axes) {
if (!expected_out_axes.defined()) return Expr();
if (expected_out_axes.size() == 0) return Expr();
const Message& message) {
if (!message.defined()) return Expr();
const auto& expected_out_axes = message->axes;
CHECK(expected_out_axes.defined() && expected_out_axes.size());
// TODO(tvm-team) allow same axes accumulation
// not as important because it is less common in nn.
const auto* slhs = new_args[0].as<ScaledExprNode>();
......@@ -356,14 +405,15 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
Expr lhs = new_args[0];
Expr rhs = new_args[1];
auto rnode = make_node<ScaledExprNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
IsAllPositiveConstant(rhs)) {
(!message->require_positive || IsAllPositiveConstant(rhs))) {
rnode->value = lhs;
rnode->scale = rhs;
rnode->axes = expected_out_axes;
return Expr(rnode);
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
IsAllPositiveConstant(lhs)) {
(!message->require_positive || IsAllPositiveConstant(lhs))) {
rnode->value = rhs;
rnode->scale = lhs;
rnode->axes = expected_out_axes;
......@@ -378,7 +428,7 @@ RELAY_REGISTER_OP("multiply")
// Consumer operators
// Conv2D send out requirement of axis folding.
Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// TODO(tvm-team) support general data layout
// by transforming weight
const auto* param = call->attrs.as<Conv2DAttrs>();
......@@ -389,6 +439,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
int c_small_axis = data_layout.Indexof('c');
CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>();
AxesSet data_axes = NullValue<AxesSet>();
// For now, we only support simple pattern (no folded weight/data)
// More general layout can be supported under the current framework.
......@@ -403,13 +454,16 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
}
return {data_axes, NullValue<AxesSet>()};
if (data_axes.defined()) {
return {MessageNode::make(data_axes, false), none};
}
return {none, none};
}
// Conv2D consumes the scale axis during transformation.
Expr Conv2DForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const AxesSet& expected_axes) {
const Message& message) {
// if data do not have scale, normal transform path.
const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>();
......@@ -458,11 +512,10 @@ RELAY_REGISTER_OP("nn.conv2d")
Expr ForwardFoldScaleAxis(Expr data) {
auto expected_scale_axes =
ForwardPrep().Prepare(data);
auto message = ForwardPrep().Prepare(data);
auto fcontext = [&](const Call& call) -> NodeRef{
auto it = expected_scale_axes.find(call.get());
if (it != expected_scale_axes.end()) {
auto it = message.find(call.get());
if (it != message.end()) {
return it->second;
} else {
return NodeRef(nullptr);
......@@ -484,15 +537,16 @@ class BackwardTransformer;
/*!
* \brief Preparation function for for pass scale backward.
* \param call The call node.
* \param in_scale_axes Allowed input scaling.
* \return The result scaling on axes of the input.
* \param in_messages Messages from the input containing allowed input scaling and whether
* positive scale is required.
* \return Message containing the result scaling on axes of the input.
*/
using FBackwardPrep = TypedPackedFunc<
AxesSet(const Call& call, const Array<AxesSet>& in_scale_axes)>;
Message(const Call& call, const Array<Message>& in_messages)>;
using FBackwardTransform = TypedPackedFunc<
Expr(const Call& call,
const AxesSet& axes,
const Message& message,
const Expr& scale,
const BackwardTransformer& transformer)>;
......@@ -503,7 +557,7 @@ using FBackwardTransform = TypedPackedFunc<
class BackwardPrep : private ExprVisitor {
public:
// The message on each node.
std::unordered_map<const Node*, AxesSet>
std::unordered_map<const Node*, Message>
Prepare(const Expr& body) {
ref_counter_ = GetExprRefCount(body);
this->VisitExpr(body);
......@@ -512,7 +566,7 @@ class BackwardPrep : private ExprVisitor {
private:
// The message on each node.
std::unordered_map<const Node*, AxesSet> message_;
std::unordered_map<const Node*, Message> message_;
// reference counter of an internal expr
std::unordered_map<const Node*, size_t> ref_counter_;
// Visit the expression.
......@@ -527,18 +581,18 @@ class BackwardPrep : private ExprVisitor {
// We only allow propagation of scale backward
// if the expression is only referred by a single parent.
if (rit->second != 1) return;
Array<AxesSet> in_axes;
Array<Message> in_messages;
for (Expr arg : call->args) {
auto it = message_.find(arg.get());
if (it != message_.end()) {
in_axes.push_back(it->second);
in_messages.push_back(it->second);
} else {
in_axes.push_back(NullValue<AxesSet>());
in_messages.push_back(NullValue<Message>());
}
}
AxesSet out_axes = f(GetRef<Call>(call), in_axes);
if (out_axes.defined()) {
message_[call] = out_axes;
Message out_message = f(GetRef<Call>(call), in_messages);
if (out_message.defined()) {
message_[call] = out_message;
}
}
};
......@@ -549,7 +603,7 @@ class BackwardTransformerNode :
public:
// Run forward transform.
Expr Fold(Expr expr) {
expected_scale_axes_ = BackwardPrep().Prepare(expr);
message_ = BackwardPrep().Prepare(expr);
return this->Mutate(expr);
}
/*!
......@@ -560,12 +614,12 @@ class BackwardTransformerNode :
* \param scale The scale applied to the axes.
* \return The result of transformation.
*/
Expr Transform(const Expr& expr, AxesSet axes, Expr scale) {
Expr Transform(const Expr& expr, Message message, Expr scale) {
// NOTE: the result of Transform is memoized.
if (const CallNode* call_node = expr.as<CallNode>()) {
return Transform(call_node, axes, scale);
return Transform(call_node, message, scale);
} else {
CHECK(!axes.defined()) << "outstanding scale";
CHECK(!message.defined()) << "outstanding scale";
return ExprMutator::VisitExpr(expr);
}
}
......@@ -585,14 +639,14 @@ class BackwardTransformerNode :
return new_expr;
}
/*!
* \brief Get the expected axes on expr.
* \brief Get the message propogated to the expr.
* \param expr The expresison.
* \return The expected axes.
* \return The message containing the expected axes and whether positive scale is required.
*/
AxesSet GetExpectedAxes(const Expr& expr) const {
auto it = expected_scale_axes_.find(expr.get());
if (it != expected_scale_axes_.end()) return it->second;
return NullValue<AxesSet>();
Message GetMessage(const Expr& expr) const {
auto it = message_.find(expr.get());
if (it != message_.end()) return it->second;
return NullValue<Message>();
}
// solver is not serializable.
......@@ -603,13 +657,13 @@ class BackwardTransformerNode :
private:
// Valid axes on each node.
std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
std::unordered_map<const Node*, Message> message_;
// Override mutation of call.
Expr VisitExpr_(const CallNode* call_node) final {
return Transform(call_node, NullValue<AxesSet>(), NullValue<Expr>());
return Transform(call_node, NullValue<Message>(), NullValue<Expr>());
}
// Transform of CallNode.
Expr Transform(const CallNode* call_node, AxesSet axes, Expr scale);
Expr Transform(const CallNode* call_node, Message message, Expr scale);
};
class BackwardTransformer : public NodeRef {
......@@ -625,7 +679,7 @@ class BackwardTransformer : public NodeRef {
};
Expr BackwardTransformerNode::Transform(
const CallNode* call_node, AxesSet axes, Expr scale) {
const CallNode* call_node, Message message, Expr scale) {
static const auto& ftransform =
Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr);
......@@ -636,13 +690,13 @@ Expr BackwardTransformerNode::Transform(
return it->second;
}
Expr new_expr = f(GetRef<Call>(call_node),
axes,
message,
scale,
GetRef<BackwardTransformer>(this));
memo_[call] = new_expr;
return new_expr;
} else {
CHECK(!axes.defined()) << "outstanding scale";
CHECK(!message.defined()) << "outstanding scale";
return NormalCallTransform(call_node);
}
}
......@@ -653,19 +707,22 @@ Expr BackwardTransformerNode::Transform(
//----------------------------------------------
// Intermediate operators
AxesSet ReluBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
return in_axes[0];
Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
if (in_messages[0].defined()) {
return MessageNode::make(in_messages[0]->axes, true);
}
return in_messages[0];
}
Expr ReluBackwardTransform(const Call& call,
const AxesSet& axes,
const Message& message,
const Expr& scale,
const BackwardTransformer& transformer) {
if (!axes.defined()) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
Expr input = transformer->Transform(
call->args[0], axes, scale);
call->args[0], message, scale);
return CallNode::make(call->op, {input}, call->attrs, call->type_args);
}
......@@ -682,64 +739,63 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
// AddSub
AxesSet AddSubBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AttrsEqual equal;
if (in_axes[0].defined() &&
MatchBroadcastToLeftAxes(tlhs, trhs, in_axes[0])) {
return in_axes[0];
} else if (in_axes[1].defined() &&
MatchBroadcastToLeftAxes(trhs, tlhs, in_axes[1])) {
return in_axes[1];
} else if (in_axes[0].defined() &&
in_axes[1].defined() &&
equal(in_axes[0], in_axes[1]) &&
if (in_messages[0].defined() &&
MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
return in_messages[0];
} else if (in_messages[1].defined() &&
MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
return in_messages[1];
} else if (in_messages[0].defined() &&
in_messages[1].defined() &&
equal(in_messages[0]->axes, in_messages[1]->axes) &&
equal(tlhs->shape, trhs->shape)) {
// add of two elements.
return in_axes[0];
return in_messages[0];
} else {
auto res = NullValue<AxesSet>();
CHECK(!res.defined());
auto res = NullValue<Message>();
return res;
}
}
Expr AddSubBackwardTransform(const Call& call,
const AxesSet& axes,
const Message& message,
const Expr& scale,
const BackwardTransformer& transformer) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
if (!axes.defined()) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
AttrsEqual equal;
if (lhs_axes.defined() && rhs_axes.defined()) {
CHECK(equal(lhs_axes, rhs_axes));
CHECK(equal(axes, lhs_axes));
Expr lhs = transformer->Transform(call->args[0], axes, scale);
Expr rhs = transformer->Transform(call->args[1], axes, scale);
if (lhs_message.defined() && rhs_message.defined()) {
CHECK(equal(lhs_message->axes, rhs_message->axes));
CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform(call->args[1], message, scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (lhs_axes.defined()) {
CHECK(equal(axes, lhs_axes));
Expr lhs = transformer->Transform(call->args[0], axes, scale);
} else if (lhs_message.defined()) {
CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform(
call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
call->args[1], NullValue<Message>(), NullValue<Expr>());
Expr rhs_scale = ExpandBiasToMatchAxis(
scale, tlhs->shape.size(), axes);
scale, tlhs->shape.size(), message->axes);
rhs = Multiply(rhs, rhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (rhs_axes.defined()) {
CHECK(equal(axes, rhs_axes));
} else if (rhs_message.defined()) {
CHECK(equal(message->axes, rhs_message->axes));
Expr lhs = transformer->Transform(
call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
Expr rhs = transformer->Transform(call->args[1], axes, scale);
call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr rhs = transformer->Transform(call->args[1], message, scale);
Expr lhs_scale = ExpandBiasToMatchAxis(
scale, trhs->shape.size(), axes);
scale, trhs->shape.size(), message->axes);
lhs = Multiply(lhs, lhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else {
......@@ -763,29 +819,29 @@ RELAY_REGISTER_OP("subtract")
// Producer operators
// Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call,
const AxesSet& axes,
const Message& message,
const Expr& scale,
const BackwardTransformer& transformer) {
CHECK(!axes.defined()) << "outstanding scale";
CHECK(!message.defined()) << "outstanding scale";
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
if (lhs_axes.defined() && lhs_axes.size() != 0) {
Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
if (lhs_message.defined()) {
CHECK(lhs_message->axes.defined() && lhs_message->axes.size());
// NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part.
Expr rhs = call->args[1];
// Only propagate positive scaling.
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
IsAllPositiveConstant(rhs)) {
return transformer->Transform(call->args[0], lhs_axes, rhs);
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
(!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
return transformer->Transform(call->args[0], lhs_message, rhs);
}
} else if (rhs_axes.defined() && rhs_axes.size() != 0) {
// Only propagate positive scaling.
} else if (rhs_message.defined()) {
CHECK(rhs_message->axes.defined() && rhs_message->axes.size());
Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
IsAllPositiveConstant(lhs)) {
return transformer->Transform(call->args[1], rhs_axes, lhs);
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
(!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
return transformer->Transform(call->args[1], rhs_message, lhs);
}
}
return transformer->NormalCallTransform(call.operator->());
......@@ -796,7 +852,7 @@ RELAY_REGISTER_OP("multiply")
// Consumer operators
// Conv2D send out requirement of axis folding.
AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
......@@ -817,18 +873,18 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
kernel_layout.Indexof('i') < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return {c_big_axis};
return MessageNode::make({c_big_axis}, false);
} else {
return NullValue<AxesSet>();
return NullValue<Message>();
}
}
// Conv2D consumes the scale axis during transformation.
Expr Conv2DBackwardTransform(const Call& call,
const AxesSet& axes,
const Message& message,
const Expr& scale,
const BackwardTransformer& transformer) {
if (!axes.defined()) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
const auto* param = call->attrs.as<Conv2DAttrs>();
......@@ -841,8 +897,8 @@ Expr Conv2DBackwardTransform(const Call& call,
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('o'), -1);
CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK(axes.size() == 1 &&
c_big_axis == axes[0]->value);
CHECK(message->axes.size() == 1 &&
c_big_axis == message->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O');
// Check it must be depthwise or full conv2d.
......@@ -850,9 +906,9 @@ Expr Conv2DBackwardTransform(const Call& call,
CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr data = transformer->Transform(
call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr weight = transformer->Transform(
call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
call->args[1], NullValue<Message>(), NullValue<Expr>());
// scale on input for deptwise.
Expr wscale = ExpandBiasToMatchAxis(
scale, kernel_layout.ndim(), {big_oc_axis});
......
......@@ -174,7 +174,6 @@ def test_fold_fwd_relu_fail():
assert in_channels == channels
weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
......@@ -182,10 +181,52 @@ def test_fold_fwd_relu_fail():
in_scale = relay.var("in_scale", shape=(4,))
check((2, 11, 10, 4), 4, in_scale)
in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0)).astype("float32")
in_scale = relay.const(-_get_positive_scale((4,)))
check((2, 11, 10, 4), 4, in_scale)
def test_fold_fwd_negative_scale():
"""Testcase of folding negative scale"""
def before(x, conv_weight, in_scale, channels):
args = [x, conv_weight]
x = relay.multiply(x, in_scale)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def expected(x, conv_weight, in_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight]
squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
y = relay.nn.conv2d(x,
conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1)))
weight = relay.var("weight")
y1 = before(x, weight, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
def test_fold_bwd_simple():
"""Simple testcase."""
def before(x, conv_weight, out_bias, out_scale, channels):
......@@ -223,7 +264,7 @@ def test_fold_bwd_simple():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -283,7 +324,7 @@ def test_fold_bwd_dual_path():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -356,7 +397,7 @@ def test_fold_bwd_dual_consumer():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels,1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
......@@ -411,7 +452,7 @@ def test_fold_bwd_fail():
in_channels = shape[1]
weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
......@@ -448,13 +489,55 @@ def test_fold_bwd_relu_fail():
check((4, 4, 10, 10), 4, out_scale)
def test_fold_bwd_negative_scale():
"""Testcase of folding negative scale"""
def before(x, conv_weight, out_scale, channels):
args = [x, conv_weight]
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.multiply(y, out_scale)
return relay.Function(args, y)
def expected(x, conv_weight, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight]
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
weight = relay.var("weight")
out_scale = relay.const(-_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
if __name__ == "__main__":
test_fold_fwd_simple()
test_fold_fwd_dual_path()
test_fold_fwd_fail()
test_fold_fwd_relu_fail()
test_fold_fwd_negative_scale()
test_fold_bwd_simple()
test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer()
test_fold_bwd_fail()
test_fold_bwd_relu_fail()
test_fold_bwd_negative_scale()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment