/*! * Copyright (c) 2018 by Contributors * * \file fold_scale_axis.cc * * \brief Fold axis scaling into weights of * conv/dense operators. */ #include <tvm/relay/pass.h> #include <tvm/relay/attrs/nn.h> #include <tvm/relay/expr_functor.h> #include "pattern_util.h" #include "pass_util.h" #include "../op/layout.h" namespace tvm { namespace relay { /*! * \brief namespace of fold scale axis * * Use namespace to reduce potential naming conflict. */ namespace fold_scale_axis { using runtime::TypedPackedFunc; // FoldScaleAxis algorithm: // // The general idea is to transform Expr to tuple of // (value, axes, scale), where the final result satisfies: // // result = value // for i, k in enumerate(axes): // k-th dimension of result *= i-th dimension of scale // // Then we can propagate this signal along and fold the scale if necessary. // However, it is possible that certain scale may never be consumed // if there is no dense/conv2d that follows multiplication. // // In order to make sure all the scale we sent out can be consumed eventually, // we run a backward "preparation phase", which propagates the demand // of the potential axes scaling back to its input. // // Forward folding process is done in two steps: // - Prepare phase: backward propagation of demand. // - Transform phase: forward transformation, // // Similarly, backward folding process is done in two steps: // - Prepare phase: forward propagation of demand. // - Transform phase: transformation by push down the axes scale signal to inputs. // /*! * \brief sorted array axis, can also be nullptr. * * nullptr means no scaling request can be done. */ using AxesSet = Array<Integer>; class Message; /*! * \brief Message propogated during the prepare phase. */ class MessageNode : public RelayNode { public: /*! \brief Axes for scaling */ AxesSet axes; /*! * \brief Whether folding requires the scale to be positive constant. This is necessary if some * operators (e.g. Relu) is present. */ bool require_positive; static Message make(const AxesSet& axes, bool require_positive); static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message"; TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode); }; RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef); Message MessageNode::make(const AxesSet& axes, bool require_positive) { auto n = make_node<MessageNode>(); n->axes = axes; n->require_positive = require_positive; return Message(n); } /*! * \brief Merge two axis set together by taking * intersection. * * \note The axes in a AxesSet should be sorted. * * \param lhs The left axis. * \param rhs The right axis. * \return The result of the inersection. */ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { if (!lhs.defined()) return lhs; if (!rhs.defined()) return rhs; // This code relies on axes in a AxesSet to be sorted. AxesSet ret; size_t i = 0, j = 0; while (i < lhs.size() && j < rhs.size()) { if (lhs[i]->value < rhs[j]->value) { ++i; } else if (lhs[i]->value > rhs[j]->value) { ++j; } else { ret.push_back(lhs[i]); ++i; ++j; } } return ret; } /*! * \brief Merge two messages together by taking intersection. * * \param lhs The lhs message. * \param rhs The rhs message. * \return The result of intersection. */ Message Intersect(const Message& lhs, const Message& rhs) { if (!lhs.defined()) return lhs; if (!rhs.defined()) return rhs; auto axes = Intersect(lhs->axes, rhs->axes); return MessageNode::make(axes, lhs->require_positive || rhs->require_positive); } /*! * \brief Preparation function for pass scale forward. * \param call The call node. * \param out_message Message from the output containing possible scaling on axes and whether * positive scale is required. * \return The message containing the result scaling on axes of the input. */ using FForwardPrep = runtime::TypedPackedFunc< Array<Message> (const Call& call, const Message& out_message)>; /*! \brief Axis scale tuple. */ class ScaledExprNode : public TempExprNode { public: /*! \brief The value */ Expr value; /*! \brief The axes to scale, can be nullptr(means no-scaling) */ AxesSet axes = NullValue<AxesSet>(); /*! \brief The scaling factor */ Expr scale = NullValue<Expr>(); Expr Realize() const final { CHECK(!axes.defined()) << "outstanding scale"; return value; } void VisitAttrs(AttrVisitor* v) final { v->Visit("value", &value); v->Visit("axes", &axes); v->Visit("scale", &scale); } static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr"; TVM_DECLARE_NODE_TYPE_INFO(ScaledExprNode, TempExprNode); }; using FForwardRewrite = TypedPackedFunc< Expr(const Call& ref_call, const Array<Expr>& new_args, const Message& message)>; //---------------------------------------------- // Generic Visitors for FScaleAxisForward //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: std::unordered_map<const Node*, Message> Prepare(const Expr& body) { this->Update(body, NullValue<Message>()); this->VisitExpr(body); // flist is added in the Post-DFS order // which is a special case of topological order. // We reversely traverse the list to invoke the lazy functions. // This act like a backprop of valid scale axis messages for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) { (*it)(); } // return the created message; return std::move(message_); } private: // The invoke list std::vector<std::function<void()> > flist_; // The message on each node. std::unordered_map<const Node*, Message> message_; // Update the message stored at node. void Update(const Expr& node, const Message& message) { // We run intersection of messages: // // %y = multiply(%x, %scale) // %z1 = conv2d(%y, %w) // %z2 = exp(%y) // // Consider the above code example, // because %z2 will propagate null to %y, // the AxesSet on %y is also null, // and the forward folding won't be triggered. const Node* key = node.get(); if (message_.count(key)) { message_[key] = Intersect(message_[key], message); } else { message_[key] = message; } } // Visitor pattern override. void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; } void VisitExpr_(const FunctionNode* op) { ExprVisitor::VisitExpr_(op); auto flazy = [this, op] { this->Update(op->body, NullValue<Message>()); }; flist_.push_back(flazy); } void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); // function to be lazily invoked auto flazy = [this, call]() { static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep"); // find the message send to this node. auto it = message_.find(call); Message out_message; if (it != message_.end()) { out_message = it->second; } else { out_message = NullValue<Message>(); } // pass the message back to all the children it references. auto f = fprep.get(call->op, nullptr); if (f != nullptr) { Array<Message> in_messages = f(GetRef<Call>(call), out_message); CHECK_EQ(in_messages.size(), call->args.size()); for (size_t i = 0; i < call->args.size(); ++i) { this->Update(call->args[i], in_messages[i]); } } else { for (size_t i = 0; i < call->args.size(); ++i) { this->Update(call->args[i], NullValue<Message>()); } } }; flist_.push_back(flazy); } void VisitExpr_(const TupleNode* op) { ExprVisitor::VisitExpr_(op); // do not support pass scale through tuple for now. auto flazy = [this, op]() { for (const Expr& field : op->fields) { this->Update(field, NullValue<Message>()); } }; flist_.push_back(flazy); } void VisitExpr_(const IfNode* op) { ExprVisitor::VisitExpr_(op); // do pass through condition // by assigning NullValue<Message> // it means fuse signal cannot pass // through into these subexpressions. auto flazy = [this, op]() { this->Update(op->cond, NullValue<Message>()); this->Update(op->true_branch, NullValue<Message>()); this->Update(op->false_branch, NullValue<Message>()); }; flist_.push_back(flazy); } }; //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- // Intermediate operators Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) { if (out_message.defined()) { return {MessageNode::make(out_message->axes, true)}; } return {out_message}; } Expr ReluForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const Message& message) { const auto* input = new_args[0].as<ScaledExprNode>(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d auto rnode = make_node<ScaledExprNode>(); rnode->value = CallNode::make( ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); rnode->scale = input->scale; rnode->axes = input->axes; return Expr(rnode); } RELAY_REGISTER_OP("nn.relu") .set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep); RELAY_REGISTER_OP("nn.relu") .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite); RELAY_REGISTER_OP("nn.leaky_relu") .set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite); // AddSub Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) { const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); auto none = NullValue<Message>(); if (out_message.defined()) { if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) { return {out_message, none}; } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) { return {none, out_message}; } } return {none, none}; } Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const Message& message) { const auto* slhs = new_args[0].as<ScaledExprNode>(); const auto* srhs = new_args[1].as<ScaledExprNode>(); if (!slhs && !srhs) return Expr(); const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>(); const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>(); auto rnode = make_node<ScaledExprNode>(); if (slhs != nullptr) { CHECK(srhs == nullptr); CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); Expr scale = ExpandBiasToMatchAxis( slhs->scale, tlhs->shape.size(), slhs->axes); Expr rhs = Divide(new_args[1], scale); rnode->value = CallNode::make(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; rnode->axes = slhs->axes; } else { CHECK(srhs != nullptr); CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); Expr scale = ExpandBiasToMatchAxis( srhs->scale, trhs->shape.size(), srhs->axes); Expr lhs = Divide(new_args[0], scale); rnode->value = CallNode::make(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; rnode->axes = srhs->axes; } return Expr(rnode); } RELAY_REGISTER_OP("add") .set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep); RELAY_REGISTER_OP("add") .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite); RELAY_REGISTER_OP("subtract") .set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep); RELAY_REGISTER_OP("subtract") .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite); // Producer operators // Multiply produces the scale-axis pair. Expr MultiplyForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const Message& message) { if (!message.defined()) return Expr(); const auto& expected_out_axes = message->axes; CHECK(expected_out_axes.defined() && expected_out_axes.size()); // TODO(tvm-team) allow same axes accumulation // not as important because it is less common in nn. const auto* slhs = new_args[0].as<ScaledExprNode>(); const auto* srhs = new_args[1].as<ScaledExprNode>(); CHECK(!slhs && !srhs); const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>(); const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>(); Expr lhs = new_args[0]; Expr rhs = new_args[1]; auto rnode = make_node<ScaledExprNode>(); if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) && (!message->require_positive || IsAllPositiveConstant(rhs))) { rnode->value = lhs; rnode->scale = rhs; rnode->axes = expected_out_axes; return Expr(rnode); } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) && (!message->require_positive || IsAllPositiveConstant(lhs))) { rnode->value = rhs; rnode->scale = lhs; rnode->axes = expected_out_axes; return Expr(rnode); } else { return Expr(); } } RELAY_REGISTER_OP("multiply") .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite); // Consumer operators // Conv2D send out requirement of axis folding. Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) { // TODO(tvm-team) support general data layout // by transforming weight const auto* param = call->attrs.as<Conv2DAttrs>(); CHECK(param != nullptr); Layout data_layout(param->data_layout); Layout kernel_layout(param->kernel_layout); int c_big_axis = data_layout.Indexof('C'); int c_small_axis = data_layout.Indexof('c'); CHECK_GE(c_big_axis, 0); Message none = NullValue<Message>(); AxesSet data_axes = NullValue<AxesSet>(); // For now, we only support simple pattern (no folded weight/data) // More general layout can be supported under the current framework. // By using a unified layout transformation. // We only need to change the Prep and Mutate function. // // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); if (kernel_layout.Indexof('i') < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { data_axes = {c_big_axis}; } if (data_axes.defined()) { return {MessageNode::make(data_axes, false), none}; } return {none, none}; } // Conv2D consumes the scale axis during transformation. Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const Message& message) { // if data do not have scale, normal transform path. const auto* sdata = new_args[0].as<ScaledExprNode>(); const auto* sweight = new_args[1].as<ScaledExprNode>(); if (sdata == nullptr) return Expr(); if (sweight != nullptr) return Expr(); const auto* param = ref_call->attrs.as<Conv2DAttrs>(); CHECK(param != nullptr); Layout data_layout(param->data_layout); Layout kernel_layout(param->kernel_layout); int c_big_axis = data_layout.Indexof('C'); CHECK_GE(c_big_axis, 0); // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value); int big_oc_axis = kernel_layout.Indexof('O'); int big_ic_axis = kernel_layout.Indexof('I'); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); Expr weight = new_args[1]; // match the ic_axis if (is_depthwise_conv2d) { Expr scale = ExpandBiasToMatchAxis( sdata->scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, scale); } else { Expr scale = ExpandBiasToMatchAxis( sdata->scale, kernel_layout.ndim(), {big_ic_axis}); weight = Multiply(weight, scale); } // return transformed conv2d return CallNode::make( ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); } RELAY_REGISTER_OP("nn.conv2d") .set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep); RELAY_REGISTER_OP("nn.conv2d") .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite); Expr ForwardFoldScaleAxis(Expr data) { auto message = ForwardPrep().Prepare(data); auto fcontext = [&](const Call& call) -> NodeRef{ auto it = message.find(call.get()); if (it != message.end()) { return it->second; } else { return NodeRef(nullptr); } }; return ForwardRewrite( data, "FScaleAxisForwardRewrite", fcontext); } // Expose the FoldScaleAxisFoward TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis") .set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis); //---------------------------------------- // Implement backward transformations. //---------------------------------------- class BackwardTransformer; /*! * \brief Preparation function for for pass scale backward. * \param call The call node. * \param in_messages Messages from the input containing allowed input scaling and whether * positive scale is required. * \return Message containing the result scaling on axes of the input. */ using FBackwardPrep = TypedPackedFunc< Message(const Call& call, const Array<Message>& in_messages)>; using FBackwardTransform = TypedPackedFunc< Expr(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer)>; //---------------------------------------------- // Generic Visitors for FScaleAxisBackward //---------------------------------------------- class BackwardPrep : private ExprVisitor { public: // The message on each node. std::unordered_map<const Node*, Message> Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); return std::move(message_); } private: // The message on each node. std::unordered_map<const Node*, Message> message_; // reference counter of an internal expr std::unordered_map<const Node*, size_t> ref_counter_; // Visit the expression. void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); static const auto& fprep = Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep"); auto f = fprep.get(call->op, nullptr); if (f == nullptr) return; auto rit = ref_counter_.find(call); CHECK(rit != ref_counter_.end()); // We only allow propagation of scale backward // if the expression is only referred by a single parent. if (rit->second != 1) return; Array<Message> in_messages; for (Expr arg : call->args) { auto it = message_.find(arg.get()); if (it != message_.end()) { in_messages.push_back(it->second); } else { in_messages.push_back(NullValue<Message>()); } } Message out_message = f(GetRef<Call>(call), in_messages); if (out_message.defined()) { message_[call] = out_message; } } }; class BackwardTransformerNode : public Node, private ExprMutator { public: // Run forward transform. Expr Fold(Expr expr) { message_ = BackwardPrep().Prepare(expr); return this->Mutate(expr); } /*! * \brief Transform the expr to consider the scaling. * * \param expr The input expression. * \param axes The axes to scale. * \param scale The scale applied to the axes. * \return The result of transformation. */ Expr Transform(const Expr& expr, Message message, Expr scale) { // NOTE: the result of Transform is memoized. if (const CallNode* call_node = expr.as<CallNode>()) { return Transform(call_node, message, scale); } else { CHECK(!message.defined()) << "outstanding scale"; return ExprMutator::VisitExpr(expr); } } /*! * \brief Normal way of mutating call node. * \param call_node The call node to be mutated. * \return the result of the call Mutation. */ Expr NormalCallTransform(const CallNode* call_node) { const Call call = GetRef<Call>(call_node); const auto it = memo_.find(call); if (it != memo_.end()) { return it->second; } Expr new_expr = ExprMutator::VisitExpr_(call_node); memo_[call] = new_expr; return new_expr; } /*! * \brief Get the message propogated to the expr. * \param expr The expresison. * \return The message containing the expected axes and whether positive scale is required. */ Message GetMessage(const Expr& expr) const { auto it = message_.find(expr.get()); if (it != message_.end()) return it->second; return NullValue<Message>(); } // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer"; TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node); private: // Valid axes on each node. std::unordered_map<const Node*, Message> message_; // Override mutation of call. Expr VisitExpr_(const CallNode* call_node) final { return Transform(call_node, NullValue<Message>(), NullValue<Expr>()); } // Transform of CallNode. Expr Transform(const CallNode* call_node, Message message, Expr scale); }; class BackwardTransformer : public NodeRef { public: BackwardTransformer() {} explicit BackwardTransformer( ::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { } BackwardTransformerNode* operator->() const { return static_cast<BackwardTransformerNode*>(node_.get()); } using ContainerType = BackwardTransformerNode; }; Expr BackwardTransformerNode::Transform( const CallNode* call_node, Message message, Expr scale) { static const auto& ftransform = Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform"); auto f = ftransform.get(call_node->op, nullptr); if (f != nullptr) { const Call call = GetRef<Call>(call_node); const auto it = memo_.find(call); if (it != memo_.end()) { return it->second; } Expr new_expr = f(GetRef<Call>(call_node), message, scale, GetRef<BackwardTransformer>(this)); memo_[call] = new_expr; return new_expr; } else { CHECK(!message.defined()) << "outstanding scale"; return NormalCallTransform(call_node); } } //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- // Intermediate operators Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) { if (in_messages[0].defined()) { return MessageNode::make(in_messages[0]->axes, true); } return in_messages[0]; } Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } Expr input = transformer->Transform( call->args[0], message, scale); return CallNode::make(call->op, {input}, call->attrs, call->type_args); } RELAY_REGISTER_OP("nn.relu") .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep); RELAY_REGISTER_OP("nn.relu") .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform); RELAY_REGISTER_OP("nn.leaky_relu") .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform); // AddSub Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) { const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); AttrsEqual equal; if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { return in_messages[0]; } else if (in_messages[1].defined() && MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) { return in_messages[1]; } else if (in_messages[0].defined() && in_messages[1].defined() && equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) { // add of two elements. return in_messages[0]; } else { auto res = NullValue<Message>(); return res; } } Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } Message lhs_message = transformer->GetMessage(call->args[0]); Message rhs_message = transformer->GetMessage(call->args[1]); AttrsEqual equal; if (lhs_message.defined() && rhs_message.defined()) { CHECK(equal(lhs_message->axes, rhs_message->axes)); CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); Expr rhs = transformer->Transform(call->args[1], message, scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (lhs_message.defined()) { CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); Expr rhs = transformer->Transform( call->args[1], NullValue<Message>(), NullValue<Expr>()); Expr rhs_scale = ExpandBiasToMatchAxis( scale, tlhs->shape.size(), message->axes); rhs = Multiply(rhs, rhs_scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (rhs_message.defined()) { CHECK(equal(message->axes, rhs_message->axes)); Expr lhs = transformer->Transform( call->args[0], NullValue<Message>(), NullValue<Expr>()); Expr rhs = transformer->Transform(call->args[1], message, scale); Expr lhs_scale = ExpandBiasToMatchAxis( scale, trhs->shape.size(), message->axes); lhs = Multiply(lhs, lhs_scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { LOG(FATAL) << "outstanding scale"; return Expr(); } } RELAY_REGISTER_OP("add") .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep); RELAY_REGISTER_OP("add") .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform); RELAY_REGISTER_OP("subtract") .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep); RELAY_REGISTER_OP("subtract") .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform); // Producer operators // Multiply produces the scale-axis pair. Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { CHECK(!message.defined()) << "outstanding scale"; const auto* tlhs = call->args[0]->type_as<TensorTypeNode>(); const auto* trhs = call->args[1]->type_as<TensorTypeNode>(); Message lhs_message = transformer->GetMessage(call->args[0]); Message rhs_message = transformer->GetMessage(call->args[1]); if (lhs_message.defined()) { CHECK(lhs_message->axes.defined() && lhs_message->axes.size()); // NOTE we won't recursively call mutating on scale part. // since there won't be scale chance within scale part. Expr rhs = call->args[1]; if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) && (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) { return transformer->Transform(call->args[0], lhs_message, rhs); } } else if (rhs_message.defined()) { CHECK(rhs_message->axes.defined() && rhs_message->axes.size()); Expr lhs = call->args[0]; if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) && (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) { return transformer->Transform(call->args[1], rhs_message, lhs); } } return transformer->NormalCallTransform(call.operator->()); } RELAY_REGISTER_OP("multiply") .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform); // Consumer operators // Conv2D send out requirement of axis folding. Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) { const auto* param = call->attrs.as<Conv2DAttrs>(); CHECK(param != nullptr); Layout kernel_layout(param->kernel_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); int c_big_axis = out_layout.Indexof('C'); int c_small_axis = out_layout.Indexof('c'); CHECK_GE(c_big_axis, 0); // For now, we only support simple pattern (no folded weight/data) // More general layout can be supported under the current framework. // By using a unified layout transformation. // We only need to change the Prep and Mutate function. // // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); if (kernel_layout.Indexof('o') < 0 && kernel_layout.Indexof('i') < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { return MessageNode::make({c_big_axis}, false); } else { return NullValue<Message>(); } } // Conv2D consumes the scale axis during transformation. Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } const auto* param = call->attrs.as<Conv2DAttrs>(); CHECK(param != nullptr); Layout kernel_layout(param->kernel_layout); Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); int c_big_axis = out_layout.Indexof('C'); CHECK_GE(c_big_axis, 0); // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.Indexof('o'), -1); CHECK_EQ(kernel_layout.Indexof('i'), -1); CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value); int big_oc_axis = kernel_layout.Indexof('O'); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); Expr data = transformer->Transform( call->args[0], NullValue<Message>(), NullValue<Expr>()); Expr weight = transformer->Transform( call->args[1], NullValue<Message>(), NullValue<Expr>()); // scale on input for deptwise. Expr wscale = ExpandBiasToMatchAxis( scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, wscale); return CallNode::make( call->op, {data, weight}, call->attrs, call->type_args); } RELAY_REGISTER_OP("nn.conv2d") .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep); RELAY_REGISTER_OP("nn.conv2d") .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform); Expr BackwardFoldScaleAxis(Expr data) { return make_node<BackwardTransformerNode>()->Fold(data); } TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis") .set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis); } // namespace fold_scale_axis } // namespace relay } // namespace tvm