Commit 2a871f35 by Wuwei Lin Committed by ziheng

[RELAY][PASS] Support Negative Scale in FoldScaleAxis (#2426)

* [RELAY][PASS] Support Negative Scale in FoldScaleAxis

* Fix comment
parent 52e3bd32
...@@ -59,6 +59,36 @@ using runtime::TypedPackedFunc; ...@@ -59,6 +59,36 @@ using runtime::TypedPackedFunc;
*/ */
using AxesSet = Array<Integer>; using AxesSet = Array<Integer>;
class Message;
/*!
* \brief Message propogated during the prepare phase.
*/
class MessageNode : public RelayNode {
public:
/*! \brief Axes for scaling */
AxesSet axes;
/*!
* \brief Whether folding requires the scale to be positive constant. This is necessary if some
* operators (e.g. Relu) is present.
*/
bool require_positive;
static Message make(const AxesSet& axes, bool require_positive);
static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode);
};
RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef);
Message MessageNode::make(const AxesSet& axes, bool require_positive) {
auto n = make_node<MessageNode>();
n->axes = axes;
n->require_positive = require_positive;
return Message(n);
}
/*! /*!
* \brief Merge two axis set together by taking * \brief Merge two axis set together by taking
* intersection. * intersection.
...@@ -89,13 +119,28 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { ...@@ -89,13 +119,28 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
} }
/*! /*!
* \brief Merge two messages together by taking intersection.
*
* \param lhs The lhs message.
* \param rhs The rhs message.
* \return The result of intersection.
*/
Message Intersect(const Message& lhs, const Message& rhs) {
if (!lhs.defined()) return lhs;
if (!rhs.defined()) return rhs;
auto axes = Intersect(lhs->axes, rhs->axes);
return MessageNode::make(axes, lhs->require_positive || rhs->require_positive);
}
/*!
* \brief Preparation function for pass scale forward. * \brief Preparation function for pass scale forward.
* \param call The call node. * \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output. * \param out_message Message from the output containing possible scaling on axes and whether
* \return The result scaling on axes of the input. * positive scale is required.
* \return The message containing the result scaling on axes of the input.
*/ */
using FForwardPrep = runtime::TypedPackedFunc< using FForwardPrep = runtime::TypedPackedFunc<
Array<AxesSet> (const Call& call, const AxesSet& out_scale_axes)>; Array<Message> (const Call& call, const Message& out_message)>;
/*! \brief Axis scale tuple. */ /*! \brief Axis scale tuple. */
class ScaledExprNode : public TempExprNode { class ScaledExprNode : public TempExprNode {
...@@ -126,16 +171,16 @@ class ScaledExprNode : public TempExprNode { ...@@ -126,16 +171,16 @@ class ScaledExprNode : public TempExprNode {
using FForwardRewrite = TypedPackedFunc< using FForwardRewrite = TypedPackedFunc<
Expr(const Call& ref_call, Expr(const Call& ref_call,
const Array<Expr>& new_args, const Array<Expr>& new_args,
const AxesSet& expeced_out_axes)>; const Message& message)>;
//---------------------------------------------- //----------------------------------------------
// Generic Visitors for FScaleAxisForward // Generic Visitors for FScaleAxisForward
//---------------------------------------------- //----------------------------------------------
class ForwardPrep : private ExprVisitor { class ForwardPrep : private ExprVisitor {
public: public:
std::unordered_map<const Node*, AxesSet> std::unordered_map<const Node*, Message>
Prepare(const Expr& body) { Prepare(const Expr& body) {
this->Update(body, NullValue<AxesSet>()); this->Update(body, NullValue<Message>());
this->VisitExpr(body); this->VisitExpr(body);
// flist is added in the Post-DFS order // flist is added in the Post-DFS order
// which is a special case of topological order. // which is a special case of topological order.
...@@ -152,9 +197,9 @@ class ForwardPrep : private ExprVisitor { ...@@ -152,9 +197,9 @@ class ForwardPrep : private ExprVisitor {
// The invoke list // The invoke list
std::vector<std::function<void()> > flist_; std::vector<std::function<void()> > flist_;
// The message on each node. // The message on each node.
std::unordered_map<const Node*, AxesSet> message_; std::unordered_map<const Node*, Message> message_;
// Update the message stored at node. // Update the message stored at node.
void Update(const Expr& node, const AxesSet& axes) { void Update(const Expr& node, const Message& message) {
// We run intersection of messages: // We run intersection of messages:
// //
// %y = multiply(%x, %scale) // %y = multiply(%x, %scale)
...@@ -167,9 +212,9 @@ class ForwardPrep : private ExprVisitor { ...@@ -167,9 +212,9 @@ class ForwardPrep : private ExprVisitor {
// and the forward folding won't be triggered. // and the forward folding won't be triggered.
const Node* key = node.get(); const Node* key = node.get();
if (message_.count(key)) { if (message_.count(key)) {
message_[key] = Intersect(message_[key], axes); message_[key] = Intersect(message_[key], message);
} else { } else {
message_[key] = axes; message_[key] = message;
} }
} }
// Visitor pattern override. // Visitor pattern override.
...@@ -180,7 +225,7 @@ class ForwardPrep : private ExprVisitor { ...@@ -180,7 +225,7 @@ class ForwardPrep : private ExprVisitor {
void VisitExpr_(const FunctionNode* op) { void VisitExpr_(const FunctionNode* op) {
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
auto flazy = [this, op] { auto flazy = [this, op] {
this->Update(op->body, NullValue<AxesSet>()); this->Update(op->body, NullValue<Message>());
}; };
flist_.push_back(flazy); flist_.push_back(flazy);
} }
...@@ -193,23 +238,23 @@ class ForwardPrep : private ExprVisitor { ...@@ -193,23 +238,23 @@ class ForwardPrep : private ExprVisitor {
Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep"); Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
// find the message send to this node. // find the message send to this node.
auto it = message_.find(call); auto it = message_.find(call);
AxesSet out_axes; Message out_message;
if (it != message_.end()) { if (it != message_.end()) {
out_axes = it->second; out_message = it->second;
} else { } else {
out_axes = NullValue<AxesSet>(); out_message = NullValue<Message>();
} }
// pass the message back to all the children it references. // pass the message back to all the children it references.
auto f = fprep.get(call->op, nullptr); auto f = fprep.get(call->op, nullptr);
if (f != nullptr) { if (f != nullptr) {
Array<AxesSet> in_axes = f(GetRef<Call>(call), out_axes); Array<Message> in_messages = f(GetRef<Call>(call), out_message);
CHECK_EQ(in_axes.size(), call->args.size()); CHECK_EQ(in_messages.size(), call->args.size());
for (size_t i = 0; i < call->args.size(); ++i) { for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], in_axes[i]); this->Update(call->args[i], in_messages[i]);
} }
} else { } else {
for (size_t i = 0; i < call->args.size(); ++i) { for (size_t i = 0; i < call->args.size(); ++i) {
this->Update(call->args[i], NullValue<AxesSet>()); this->Update(call->args[i], NullValue<Message>());
} }
} }
}; };
...@@ -221,7 +266,7 @@ class ForwardPrep : private ExprVisitor { ...@@ -221,7 +266,7 @@ class ForwardPrep : private ExprVisitor {
// do not support pass scale through tuple for now. // do not support pass scale through tuple for now.
auto flazy = [this, op]() { auto flazy = [this, op]() {
for (const Expr& field : op->fields) { for (const Expr& field : op->fields) {
this->Update(field, NullValue<AxesSet>()); this->Update(field, NullValue<Message>());
} }
}; };
flist_.push_back(flazy); flist_.push_back(flazy);
...@@ -230,13 +275,13 @@ class ForwardPrep : private ExprVisitor { ...@@ -230,13 +275,13 @@ class ForwardPrep : private ExprVisitor {
void VisitExpr_(const IfNode* op) { void VisitExpr_(const IfNode* op) {
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
// do pass through condition // do pass through condition
// by assigning NullValue<AxesSet> // by assigning NullValue<Message>
// it means fuse signal cannot pass // it means fuse signal cannot pass
// through into these subexpressions. // through into these subexpressions.
auto flazy = [this, op]() { auto flazy = [this, op]() {
this->Update(op->cond, NullValue<AxesSet>()); this->Update(op->cond, NullValue<Message>());
this->Update(op->true_branch, NullValue<AxesSet>()); this->Update(op->true_branch, NullValue<Message>());
this->Update(op->false_branch, NullValue<AxesSet>()); this->Update(op->false_branch, NullValue<Message>());
}; };
flist_.push_back(flazy); flist_.push_back(flazy);
} }
...@@ -247,13 +292,16 @@ class ForwardPrep : private ExprVisitor { ...@@ -247,13 +292,16 @@ class ForwardPrep : private ExprVisitor {
//---------------------------------------------- //----------------------------------------------
// Intermediate operators // Intermediate operators
Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) { Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
return {out}; if (out_message.defined()) {
return {MessageNode::make(out_message->axes, true)};
}
return {out_message};
} }
Expr ReluForwardRewrite(const Call& ref_call, Expr ReluForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args, const Array<Expr>& new_args,
const AxesSet& expected_axes) { const Message& message) {
const auto* input = new_args[0].as<ScaledExprNode>(); const auto* input = new_args[0].as<ScaledExprNode>();
if (input == nullptr) return Expr(nullptr); if (input == nullptr) return Expr(nullptr);
// return transformed conv2d // return transformed conv2d
...@@ -278,23 +326,23 @@ RELAY_REGISTER_OP("nn.leaky_relu") ...@@ -278,23 +326,23 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite); .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
// AddSub // AddSub
Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) { Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
auto none = NullValue<Message>();
auto none = NullValue<AxesSet>(); if (out_message.defined()) {
if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) { if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) {
return {out_axes, none}; return {out_message, none};
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) { } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) {
return {none, out_axes}; return {none, out_message};
} else { }
return {none, none};
} }
return {none, none};
} }
Expr AddSubForwardRewrite(const Call& ref_call, Expr AddSubForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args, const Array<Expr>& new_args,
const AxesSet& expected_out_axes) { const Message& message) {
const auto* slhs = new_args[0].as<ScaledExprNode>(); const auto* slhs = new_args[0].as<ScaledExprNode>();
const auto* srhs = new_args[1].as<ScaledExprNode>(); const auto* srhs = new_args[1].as<ScaledExprNode>();
if (!slhs && !srhs) return Expr(); if (!slhs && !srhs) return Expr();
...@@ -342,9 +390,10 @@ RELAY_REGISTER_OP("subtract") ...@@ -342,9 +390,10 @@ RELAY_REGISTER_OP("subtract")
// Multiply produces the scale-axis pair. // Multiply produces the scale-axis pair.
Expr MultiplyForwardRewrite(const Call& ref_call, Expr MultiplyForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args, const Array<Expr>& new_args,
const AxesSet& expected_out_axes) { const Message& message) {
if (!expected_out_axes.defined()) return Expr(); if (!message.defined()) return Expr();
if (expected_out_axes.size() == 0) return Expr(); const auto& expected_out_axes = message->axes;
CHECK(expected_out_axes.defined() && expected_out_axes.size());
// TODO(tvm-team) allow same axes accumulation // TODO(tvm-team) allow same axes accumulation
// not as important because it is less common in nn. // not as important because it is less common in nn.
const auto* slhs = new_args[0].as<ScaledExprNode>(); const auto* slhs = new_args[0].as<ScaledExprNode>();
...@@ -356,14 +405,15 @@ Expr MultiplyForwardRewrite(const Call& ref_call, ...@@ -356,14 +405,15 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
Expr lhs = new_args[0]; Expr lhs = new_args[0];
Expr rhs = new_args[1]; Expr rhs = new_args[1];
auto rnode = make_node<ScaledExprNode>(); auto rnode = make_node<ScaledExprNode>();
if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) && if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
IsAllPositiveConstant(rhs)) { (!message->require_positive || IsAllPositiveConstant(rhs))) {
rnode->value = lhs; rnode->value = lhs;
rnode->scale = rhs; rnode->scale = rhs;
rnode->axes = expected_out_axes; rnode->axes = expected_out_axes;
return Expr(rnode); return Expr(rnode);
} else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) && } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
IsAllPositiveConstant(lhs)) { (!message->require_positive || IsAllPositiveConstant(lhs))) {
rnode->value = rhs; rnode->value = rhs;
rnode->scale = lhs; rnode->scale = lhs;
rnode->axes = expected_out_axes; rnode->axes = expected_out_axes;
...@@ -378,7 +428,7 @@ RELAY_REGISTER_OP("multiply") ...@@ -378,7 +428,7 @@ RELAY_REGISTER_OP("multiply")
// Consumer operators // Consumer operators
// Conv2D send out requirement of axis folding. // Conv2D send out requirement of axis folding.
Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) { Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
// by transforming weight // by transforming weight
const auto* param = call->attrs.as<Conv2DAttrs>(); const auto* param = call->attrs.as<Conv2DAttrs>();
...@@ -389,6 +439,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) { ...@@ -389,6 +439,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
int c_small_axis = data_layout.Indexof('c'); int c_small_axis = data_layout.Indexof('c');
CHECK_GE(c_big_axis, 0); CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>();
AxesSet data_axes = NullValue<AxesSet>(); AxesSet data_axes = NullValue<AxesSet>();
// For now, we only support simple pattern (no folded weight/data) // For now, we only support simple pattern (no folded weight/data)
// More general layout can be supported under the current framework. // More general layout can be supported under the current framework.
...@@ -403,13 +454,16 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) { ...@@ -403,13 +454,16 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis}; data_axes = {c_big_axis};
} }
return {data_axes, NullValue<AxesSet>()}; if (data_axes.defined()) {
return {MessageNode::make(data_axes, false), none};
}
return {none, none};
} }
// Conv2D consumes the scale axis during transformation. // Conv2D consumes the scale axis during transformation.
Expr Conv2DForwardRewrite(const Call& ref_call, Expr Conv2DForwardRewrite(const Call& ref_call,
const Array<Expr>& new_args, const Array<Expr>& new_args,
const AxesSet& expected_axes) { const Message& message) {
// if data do not have scale, normal transform path. // if data do not have scale, normal transform path.
const auto* sdata = new_args[0].as<ScaledExprNode>(); const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>(); const auto* sweight = new_args[1].as<ScaledExprNode>();
...@@ -458,11 +512,10 @@ RELAY_REGISTER_OP("nn.conv2d") ...@@ -458,11 +512,10 @@ RELAY_REGISTER_OP("nn.conv2d")
Expr ForwardFoldScaleAxis(Expr data) { Expr ForwardFoldScaleAxis(Expr data) {
auto expected_scale_axes = auto message = ForwardPrep().Prepare(data);
ForwardPrep().Prepare(data);
auto fcontext = [&](const Call& call) -> NodeRef{ auto fcontext = [&](const Call& call) -> NodeRef{
auto it = expected_scale_axes.find(call.get()); auto it = message.find(call.get());
if (it != expected_scale_axes.end()) { if (it != message.end()) {
return it->second; return it->second;
} else { } else {
return NodeRef(nullptr); return NodeRef(nullptr);
...@@ -484,15 +537,16 @@ class BackwardTransformer; ...@@ -484,15 +537,16 @@ class BackwardTransformer;
/*! /*!
* \brief Preparation function for for pass scale backward. * \brief Preparation function for for pass scale backward.
* \param call The call node. * \param call The call node.
* \param in_scale_axes Allowed input scaling. * \param in_messages Messages from the input containing allowed input scaling and whether
* \return The result scaling on axes of the input. * positive scale is required.
* \return Message containing the result scaling on axes of the input.
*/ */
using FBackwardPrep = TypedPackedFunc< using FBackwardPrep = TypedPackedFunc<
AxesSet(const Call& call, const Array<AxesSet>& in_scale_axes)>; Message(const Call& call, const Array<Message>& in_messages)>;
using FBackwardTransform = TypedPackedFunc< using FBackwardTransform = TypedPackedFunc<
Expr(const Call& call, Expr(const Call& call,
const AxesSet& axes, const Message& message,
const Expr& scale, const Expr& scale,
const BackwardTransformer& transformer)>; const BackwardTransformer& transformer)>;
...@@ -503,7 +557,7 @@ using FBackwardTransform = TypedPackedFunc< ...@@ -503,7 +557,7 @@ using FBackwardTransform = TypedPackedFunc<
class BackwardPrep : private ExprVisitor { class BackwardPrep : private ExprVisitor {
public: public:
// The message on each node. // The message on each node.
std::unordered_map<const Node*, AxesSet> std::unordered_map<const Node*, Message>
Prepare(const Expr& body) { Prepare(const Expr& body) {
ref_counter_ = GetExprRefCount(body); ref_counter_ = GetExprRefCount(body);
this->VisitExpr(body); this->VisitExpr(body);
...@@ -512,7 +566,7 @@ class BackwardPrep : private ExprVisitor { ...@@ -512,7 +566,7 @@ class BackwardPrep : private ExprVisitor {
private: private:
// The message on each node. // The message on each node.
std::unordered_map<const Node*, AxesSet> message_; std::unordered_map<const Node*, Message> message_;
// reference counter of an internal expr // reference counter of an internal expr
std::unordered_map<const Node*, size_t> ref_counter_; std::unordered_map<const Node*, size_t> ref_counter_;
// Visit the expression. // Visit the expression.
...@@ -527,18 +581,18 @@ class BackwardPrep : private ExprVisitor { ...@@ -527,18 +581,18 @@ class BackwardPrep : private ExprVisitor {
// We only allow propagation of scale backward // We only allow propagation of scale backward
// if the expression is only referred by a single parent. // if the expression is only referred by a single parent.
if (rit->second != 1) return; if (rit->second != 1) return;
Array<AxesSet> in_axes; Array<Message> in_messages;
for (Expr arg : call->args) { for (Expr arg : call->args) {
auto it = message_.find(arg.get()); auto it = message_.find(arg.get());
if (it != message_.end()) { if (it != message_.end()) {
in_axes.push_back(it->second); in_messages.push_back(it->second);
} else { } else {
in_axes.push_back(NullValue<AxesSet>()); in_messages.push_back(NullValue<Message>());
} }
} }
AxesSet out_axes = f(GetRef<Call>(call), in_axes); Message out_message = f(GetRef<Call>(call), in_messages);
if (out_axes.defined()) { if (out_message.defined()) {
message_[call] = out_axes; message_[call] = out_message;
} }
} }
}; };
...@@ -549,7 +603,7 @@ class BackwardTransformerNode : ...@@ -549,7 +603,7 @@ class BackwardTransformerNode :
public: public:
// Run forward transform. // Run forward transform.
Expr Fold(Expr expr) { Expr Fold(Expr expr) {
expected_scale_axes_ = BackwardPrep().Prepare(expr); message_ = BackwardPrep().Prepare(expr);
return this->Mutate(expr); return this->Mutate(expr);
} }
/*! /*!
...@@ -560,12 +614,12 @@ class BackwardTransformerNode : ...@@ -560,12 +614,12 @@ class BackwardTransformerNode :
* \param scale The scale applied to the axes. * \param scale The scale applied to the axes.
* \return The result of transformation. * \return The result of transformation.
*/ */
Expr Transform(const Expr& expr, AxesSet axes, Expr scale) { Expr Transform(const Expr& expr, Message message, Expr scale) {
// NOTE: the result of Transform is memoized. // NOTE: the result of Transform is memoized.
if (const CallNode* call_node = expr.as<CallNode>()) { if (const CallNode* call_node = expr.as<CallNode>()) {
return Transform(call_node, axes, scale); return Transform(call_node, message, scale);
} else { } else {
CHECK(!axes.defined()) << "outstanding scale"; CHECK(!message.defined()) << "outstanding scale";
return ExprMutator::VisitExpr(expr); return ExprMutator::VisitExpr(expr);
} }
} }
...@@ -585,14 +639,14 @@ class BackwardTransformerNode : ...@@ -585,14 +639,14 @@ class BackwardTransformerNode :
return new_expr; return new_expr;
} }
/*! /*!
* \brief Get the expected axes on expr. * \brief Get the message propogated to the expr.
* \param expr The expresison. * \param expr The expresison.
* \return The expected axes. * \return The message containing the expected axes and whether positive scale is required.
*/ */
AxesSet GetExpectedAxes(const Expr& expr) const { Message GetMessage(const Expr& expr) const {
auto it = expected_scale_axes_.find(expr.get()); auto it = message_.find(expr.get());
if (it != expected_scale_axes_.end()) return it->second; if (it != message_.end()) return it->second;
return NullValue<AxesSet>(); return NullValue<Message>();
} }
// solver is not serializable. // solver is not serializable.
...@@ -603,13 +657,13 @@ class BackwardTransformerNode : ...@@ -603,13 +657,13 @@ class BackwardTransformerNode :
private: private:
// Valid axes on each node. // Valid axes on each node.
std::unordered_map<const Node*, AxesSet> expected_scale_axes_; std::unordered_map<const Node*, Message> message_;
// Override mutation of call. // Override mutation of call.
Expr VisitExpr_(const CallNode* call_node) final { Expr VisitExpr_(const CallNode* call_node) final {
return Transform(call_node, NullValue<AxesSet>(), NullValue<Expr>()); return Transform(call_node, NullValue<Message>(), NullValue<Expr>());
} }
// Transform of CallNode. // Transform of CallNode.
Expr Transform(const CallNode* call_node, AxesSet axes, Expr scale); Expr Transform(const CallNode* call_node, Message message, Expr scale);
}; };
class BackwardTransformer : public NodeRef { class BackwardTransformer : public NodeRef {
...@@ -625,7 +679,7 @@ class BackwardTransformer : public NodeRef { ...@@ -625,7 +679,7 @@ class BackwardTransformer : public NodeRef {
}; };
Expr BackwardTransformerNode::Transform( Expr BackwardTransformerNode::Transform(
const CallNode* call_node, AxesSet axes, Expr scale) { const CallNode* call_node, Message message, Expr scale) {
static const auto& ftransform = static const auto& ftransform =
Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform"); Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr); auto f = ftransform.get(call_node->op, nullptr);
...@@ -636,13 +690,13 @@ Expr BackwardTransformerNode::Transform( ...@@ -636,13 +690,13 @@ Expr BackwardTransformerNode::Transform(
return it->second; return it->second;
} }
Expr new_expr = f(GetRef<Call>(call_node), Expr new_expr = f(GetRef<Call>(call_node),
axes, message,
scale, scale,
GetRef<BackwardTransformer>(this)); GetRef<BackwardTransformer>(this));
memo_[call] = new_expr; memo_[call] = new_expr;
return new_expr; return new_expr;
} else { } else {
CHECK(!axes.defined()) << "outstanding scale"; CHECK(!message.defined()) << "outstanding scale";
return NormalCallTransform(call_node); return NormalCallTransform(call_node);
} }
} }
...@@ -653,19 +707,22 @@ Expr BackwardTransformerNode::Transform( ...@@ -653,19 +707,22 @@ Expr BackwardTransformerNode::Transform(
//---------------------------------------------- //----------------------------------------------
// Intermediate operators // Intermediate operators
AxesSet ReluBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
return in_axes[0]; if (in_messages[0].defined()) {
return MessageNode::make(in_messages[0]->axes, true);
}
return in_messages[0];
} }
Expr ReluBackwardTransform(const Call& call, Expr ReluBackwardTransform(const Call& call,
const AxesSet& axes, const Message& message,
const Expr& scale, const Expr& scale,
const BackwardTransformer& transformer) { const BackwardTransformer& transformer) {
if (!axes.defined()) { if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->()); return transformer->NormalCallTransform(call.operator->());
} }
Expr input = transformer->Transform( Expr input = transformer->Transform(
call->args[0], axes, scale); call->args[0], message, scale);
return CallNode::make(call->op, {input}, call->attrs, call->type_args); return CallNode::make(call->op, {input}, call->attrs, call->type_args);
} }
...@@ -682,64 +739,63 @@ RELAY_REGISTER_OP("nn.leaky_relu") ...@@ -682,64 +739,63 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform); .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
// AddSub // AddSub
AxesSet AddSubBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AttrsEqual equal; AttrsEqual equal;
if (in_axes[0].defined() && if (in_messages[0].defined() &&
MatchBroadcastToLeftAxes(tlhs, trhs, in_axes[0])) { MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
return in_axes[0]; return in_messages[0];
} else if (in_axes[1].defined() && } else if (in_messages[1].defined() &&
MatchBroadcastToLeftAxes(trhs, tlhs, in_axes[1])) { MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
return in_axes[1]; return in_messages[1];
} else if (in_axes[0].defined() && } else if (in_messages[0].defined() &&
in_axes[1].defined() && in_messages[1].defined() &&
equal(in_axes[0], in_axes[1]) && equal(in_messages[0]->axes, in_messages[1]->axes) &&
equal(tlhs->shape, trhs->shape)) { equal(tlhs->shape, trhs->shape)) {
// add of two elements. // add of two elements.
return in_axes[0]; return in_messages[0];
} else { } else {
auto res = NullValue<AxesSet>(); auto res = NullValue<Message>();
CHECK(!res.defined());
return res; return res;
} }
} }
Expr AddSubBackwardTransform(const Call& call, Expr AddSubBackwardTransform(const Call& call,
const AxesSet& axes, const Message& message,
const Expr& scale, const Expr& scale,
const BackwardTransformer& transformer) { const BackwardTransformer& transformer) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
if (!axes.defined()) { if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->()); return transformer->NormalCallTransform(call.operator->());
} }
AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]); Message lhs_message = transformer->GetMessage(call->args[0]);
AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]); Message rhs_message = transformer->GetMessage(call->args[1]);
AttrsEqual equal; AttrsEqual equal;
if (lhs_axes.defined() && rhs_axes.defined()) { if (lhs_message.defined() && rhs_message.defined()) {
CHECK(equal(lhs_axes, rhs_axes)); CHECK(equal(lhs_message->axes, rhs_message->axes));
CHECK(equal(axes, lhs_axes)); CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], axes, scale); Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform(call->args[1], axes, scale); Expr rhs = transformer->Transform(call->args[1], message, scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (lhs_axes.defined()) { } else if (lhs_message.defined()) {
CHECK(equal(axes, lhs_axes)); CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], axes, scale); Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform( Expr rhs = transformer->Transform(
call->args[1], NullValue<AxesSet>(), NullValue<Expr>()); call->args[1], NullValue<Message>(), NullValue<Expr>());
Expr rhs_scale = ExpandBiasToMatchAxis( Expr rhs_scale = ExpandBiasToMatchAxis(
scale, tlhs->shape.size(), axes); scale, tlhs->shape.size(), message->axes);
rhs = Multiply(rhs, rhs_scale); rhs = Multiply(rhs, rhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (rhs_axes.defined()) { } else if (rhs_message.defined()) {
CHECK(equal(axes, rhs_axes)); CHECK(equal(message->axes, rhs_message->axes));
Expr lhs = transformer->Transform( Expr lhs = transformer->Transform(
call->args[0], NullValue<AxesSet>(), NullValue<Expr>()); call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr rhs = transformer->Transform(call->args[1], axes, scale); Expr rhs = transformer->Transform(call->args[1], message, scale);
Expr lhs_scale = ExpandBiasToMatchAxis( Expr lhs_scale = ExpandBiasToMatchAxis(
scale, trhs->shape.size(), axes); scale, trhs->shape.size(), message->axes);
lhs = Multiply(lhs, lhs_scale); lhs = Multiply(lhs, lhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else { } else {
...@@ -763,29 +819,29 @@ RELAY_REGISTER_OP("subtract") ...@@ -763,29 +819,29 @@ RELAY_REGISTER_OP("subtract")
// Producer operators // Producer operators
// Multiply produces the scale-axis pair. // Multiply produces the scale-axis pair.
Expr MultiplyBackwardTransform(const Call& call, Expr MultiplyBackwardTransform(const Call& call,
const AxesSet& axes, const Message& message,
const Expr& scale, const Expr& scale,
const BackwardTransformer& transformer) { const BackwardTransformer& transformer) {
CHECK(!axes.defined()) << "outstanding scale"; CHECK(!message.defined()) << "outstanding scale";
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]); Message lhs_message = transformer->GetMessage(call->args[0]);
AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]); Message rhs_message = transformer->GetMessage(call->args[1]);
if (lhs_axes.defined() && lhs_axes.size() != 0) { if (lhs_message.defined()) {
CHECK(lhs_message->axes.defined() && lhs_message->axes.size());
// NOTE we won't recursively call mutating on scale part. // NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part. // since there won't be scale chance within scale part.
Expr rhs = call->args[1]; Expr rhs = call->args[1];
// Only propagate positive scaling. if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) && (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
IsAllPositiveConstant(rhs)) { return transformer->Transform(call->args[0], lhs_message, rhs);
return transformer->Transform(call->args[0], lhs_axes, rhs);
} }
} else if (rhs_axes.defined() && rhs_axes.size() != 0) { } else if (rhs_message.defined()) {
// Only propagate positive scaling. CHECK(rhs_message->axes.defined() && rhs_message->axes.size());
Expr lhs = call->args[0]; Expr lhs = call->args[0];
if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) && if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
IsAllPositiveConstant(lhs)) { (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
return transformer->Transform(call->args[1], rhs_axes, lhs); return transformer->Transform(call->args[1], rhs_message, lhs);
} }
} }
return transformer->NormalCallTransform(call.operator->()); return transformer->NormalCallTransform(call.operator->());
...@@ -796,7 +852,7 @@ RELAY_REGISTER_OP("multiply") ...@@ -796,7 +852,7 @@ RELAY_REGISTER_OP("multiply")
// Consumer operators // Consumer operators
// Conv2D send out requirement of axis folding. // Conv2D send out requirement of axis folding.
AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* param = call->attrs.as<Conv2DAttrs>(); const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout); Layout kernel_layout(param->kernel_layout);
...@@ -817,18 +873,18 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) { ...@@ -817,18 +873,18 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
kernel_layout.Indexof('i') < 0 && kernel_layout.Indexof('i') < 0 &&
c_small_axis < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) { (param->groups == 1 || is_depthwise_conv2d)) {
return {c_big_axis}; return MessageNode::make({c_big_axis}, false);
} else { } else {
return NullValue<AxesSet>(); return NullValue<Message>();
} }
} }
// Conv2D consumes the scale axis during transformation. // Conv2D consumes the scale axis during transformation.
Expr Conv2DBackwardTransform(const Call& call, Expr Conv2DBackwardTransform(const Call& call,
const AxesSet& axes, const Message& message,
const Expr& scale, const Expr& scale,
const BackwardTransformer& transformer) { const BackwardTransformer& transformer) {
if (!axes.defined()) { if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->()); return transformer->NormalCallTransform(call.operator->());
} }
const auto* param = call->attrs.as<Conv2DAttrs>(); const auto* param = call->attrs.as<Conv2DAttrs>();
...@@ -841,8 +897,8 @@ Expr Conv2DBackwardTransform(const Call& call, ...@@ -841,8 +897,8 @@ Expr Conv2DBackwardTransform(const Call& call,
// TODO(tvm-team) support general data layout // TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.Indexof('o'), -1); CHECK_EQ(kernel_layout.Indexof('o'), -1);
CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK(axes.size() == 1 && CHECK(message->axes.size() == 1 &&
c_big_axis == axes[0]->value); c_big_axis == message->axes[0]->value);
int big_oc_axis = kernel_layout.Indexof('O'); int big_oc_axis = kernel_layout.Indexof('O');
// Check it must be depthwise or full conv2d. // Check it must be depthwise or full conv2d.
...@@ -850,9 +906,9 @@ Expr Conv2DBackwardTransform(const Call& call, ...@@ -850,9 +906,9 @@ Expr Conv2DBackwardTransform(const Call& call,
CHECK(param->groups == 1 || is_depthwise_conv2d); CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr data = transformer->Transform( Expr data = transformer->Transform(
call->args[0], NullValue<AxesSet>(), NullValue<Expr>()); call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr weight = transformer->Transform( Expr weight = transformer->Transform(
call->args[1], NullValue<AxesSet>(), NullValue<Expr>()); call->args[1], NullValue<Message>(), NullValue<Expr>());
// scale on input for deptwise. // scale on input for deptwise.
Expr wscale = ExpandBiasToMatchAxis( Expr wscale = ExpandBiasToMatchAxis(
scale, kernel_layout.ndim(), {big_oc_axis}); scale, kernel_layout.ndim(), {big_oc_axis});
......
...@@ -174,7 +174,6 @@ def test_fold_fwd_relu_fail(): ...@@ -174,7 +174,6 @@ def test_fold_fwd_relu_fail():
assert in_channels == channels assert in_channels == channels
weight = relay.var("weight") weight = relay.var("weight")
in_bias = relay.var("in_bias", shape=(in_channels,)) in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.var("in_scale", shape=(in_channels,))
y1 = before(x, weight, in_bias, in_scale, channels) y1 = before(x, weight, in_bias, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
...@@ -182,10 +181,52 @@ def test_fold_fwd_relu_fail(): ...@@ -182,10 +181,52 @@ def test_fold_fwd_relu_fail():
in_scale = relay.var("in_scale", shape=(4,)) in_scale = relay.var("in_scale", shape=(4,))
check((2, 11, 10, 4), 4, in_scale) check((2, 11, 10, 4), 4, in_scale)
in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0)).astype("float32") in_scale = relay.const(-_get_positive_scale((4,)))
check((2, 11, 10, 4), 4, in_scale) check((2, 11, 10, 4), 4, in_scale)
def test_fold_fwd_negative_scale():
"""Testcase of folding negative scale"""
def before(x, conv_weight, in_scale, channels):
args = [x, conv_weight]
x = relay.multiply(x, in_scale)
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def expected(x, conv_weight, in_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight]
squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
y = relay.nn.conv2d(x,
conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[1]
in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1)))
weight = relay.var("weight")
y1 = before(x, weight, in_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
y1_expected = expected(x, weight, in_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
def test_fold_bwd_simple(): def test_fold_bwd_simple():
"""Simple testcase.""" """Simple testcase."""
def before(x, conv_weight, out_bias, out_scale, channels): def before(x, conv_weight, out_bias, out_scale, channels):
...@@ -223,7 +264,7 @@ def test_fold_bwd_simple(): ...@@ -223,7 +264,7 @@ def test_fold_bwd_simple():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels) y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
...@@ -283,7 +324,7 @@ def test_fold_bwd_dual_path(): ...@@ -283,7 +324,7 @@ def test_fold_bwd_dual_path():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels) y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
...@@ -356,7 +397,7 @@ def test_fold_bwd_dual_consumer(): ...@@ -356,7 +397,7 @@ def test_fold_bwd_dual_consumer():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32")) out_scale = relay.const(_get_positive_scale((channels,1, 1)))
y1 = before(x, weight, out_bias, out_scale, channels) y1 = before(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
...@@ -411,7 +452,7 @@ def test_fold_bwd_fail(): ...@@ -411,7 +452,7 @@ def test_fold_bwd_fail():
in_channels = shape[1] in_channels = shape[1]
weight = relay.var("weight") weight = relay.var("weight")
out_bias = relay.var("out_bias", shape=(channels,)) out_bias = relay.var("out_bias", shape=(channels,))
out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
y1 = fbefore(x, weight, out_bias, out_scale, channels) y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1) y1 = relay.ir_pass.infer_type(y1)
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
...@@ -448,13 +489,55 @@ def test_fold_bwd_relu_fail(): ...@@ -448,13 +489,55 @@ def test_fold_bwd_relu_fail():
check((4, 4, 10, 10), 4, out_scale) check((4, 4, 10, 10), 4, out_scale)
def test_fold_bwd_negative_scale():
"""Testcase of folding negative scale"""
def before(x, conv_weight, out_scale, channels):
args = [x, conv_weight]
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
y = relay.multiply(y, out_scale)
return relay.Function(args, y)
def expected(x, conv_weight, out_scale, channels):
# use a fixed order of args so alpha equal check can pass
args = [x, conv_weight]
squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
conv_weight = relay.multiply(
conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
y = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
padding=(1, 1))
return relay.Function(args, y)
def check(shape, channels):
x = relay.var("x", shape=shape)
weight = relay.var("weight")
out_scale = relay.const(-_get_positive_scale((channels, 1, 1)))
y1 = before(x, weight, out_scale, channels)
y1 = relay.ir_pass.infer_type(y1)
type_dict = {x.name_hint:x.checked_type for x in y1.params}
weight = relay.var("weight", type_dict["weight"])
y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
y1_expected = expected(x, weight, out_scale, channels)
y1_folded = relay.ir_pass.infer_type(y1_folded)
y1_expected = relay.ir_pass.infer_type(y1_expected)
assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
if __name__ == "__main__": if __name__ == "__main__":
test_fold_fwd_simple() test_fold_fwd_simple()
test_fold_fwd_dual_path() test_fold_fwd_dual_path()
test_fold_fwd_fail() test_fold_fwd_fail()
test_fold_fwd_relu_fail() test_fold_fwd_relu_fail()
test_fold_fwd_negative_scale()
test_fold_bwd_simple() test_fold_bwd_simple()
test_fold_bwd_dual_path() test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer() test_fold_bwd_dual_consumer()
test_fold_bwd_fail() test_fold_bwd_fail()
test_fold_bwd_relu_fail() test_fold_bwd_relu_fail()
test_fold_bwd_negative_scale()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment