/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file fold_scale_axis.cc
 *
 * \brief Fold axis scaling into weights of
 *  conv/dense operators.
 */
#include <tvm/data_layout.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>

#include "pass_util.h"
#include "pattern_util.h"


namespace tvm {
namespace relay {
/*!
 * \brief namespace of fold scale axis
 *
 * Use namespace to reduce potential naming conflict.
 */
namespace fold_scale_axis {

using runtime::TypedPackedFunc;


// FoldScaleAxis algorithm:
//
// The general idea is to transform Expr to a tuple of
// (value, axes, scale), where the final result satisfies:
//
// result = value
// for i, k in enumerate(axes):
//    k-th dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that certain scale may never be consumed
// if there is no dense/conv2d that follows multiplication.
//
// In order to make sure all the scale we sent out can be consumed eventually,
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
//
// Forward folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation.
//
// Similarly, backward folding process is done in two steps:
// - Prepare phase: forward propagation of demand.
// - Transform phase: transformation by pushing the axes scale signal down to inputs.
//

/*!
 * \brief Sorted array of scaling axes; may also be undefined (nullptr).
 *
 *  An undefined AxesSet means no scaling request can be done.
 */
using AxesSet = Array<Integer>;

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
class Message;

/*!
 * \brief Message propogated during the prepare phase.
 */
class MessageNode : public RelayNode {
 public:
  /*! \brief Axes for scaling */
  AxesSet axes;
  /*!
   * \brief Whether folding requires the scale to be positive constant. This is necessary if some
   *  operators (e.g. Relu) is present.
   */
  bool require_positive;

  static Message make(const AxesSet& axes, bool require_positive);

  static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
  TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode);
};

RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef);

Message MessageNode::make(const AxesSet& axes, bool require_positive)  {
  auto n = make_node<MessageNode>();
  n->axes = axes;
  n->require_positive = require_positive;
  return Message(n);
}

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
/*!
 * \brief Merge two axis set together by taking
 *  intersection.
 *
 * \note The axes in a AxesSet should be sorted.
 *
 * \param lhs The left axis.
 * \param rhs The right axis.
 * \return The result of the inersection.
 */
AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  // This code relies on axes in a AxesSet to be sorted.
  AxesSet ret;
  size_t i = 0, j = 0;
  while (i < lhs.size() && j < rhs.size()) {
    if (lhs[i]->value < rhs[j]->value) {
      ++i;
    } else if (lhs[i]->value > rhs[j]->value) {
      ++j;
    } else {
      ret.push_back(lhs[i]);
      ++i; ++j;
    }
  }
  return ret;
}

/*!
 * \brief Merge two messages together by taking intersection.
 *
 * \param lhs The lhs message.
 * \param rhs The rhs message.
 * \return The result of intersection.
 */
Message Intersect(const Message& lhs, const Message& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  auto axes = Intersect(lhs->axes, rhs->axes);
  // Positivity is required if either side requires it.
  return MessageNode::make(axes, lhs->require_positive || rhs->require_positive);
}

/*!
 * \brief Preparation function for pass scale forward.
 * \param call The call node.
 * \param out_message Message from the output containing possible scaling on axes and whether
 *        positive scale is required.
 * \return The message containing the result scaling on axes of the input.
 */
using FForwardPrep = runtime::TypedPackedFunc<
  Array<Message> (const Call& call, const Message& out_message)>;

/*! \brief Axis scale tuple.  */
166
class ScaledExprNode : public TempExprNode {
167 168 169 170 171 172 173 174
 public:
  /*! \brief The value */
  Expr value;
  /*! \brief The axes to scale, can be nullptr(means no-scaling) */
  AxesSet axes = NullValue<AxesSet>();
  /*! \brief The scaling factor */
  Expr scale = NullValue<Expr>();

175 176 177 178 179 180
  Expr Realize() const final {
    CHECK(!axes.defined())
        << "outstanding scale";
    return value;
  }

181 182 183 184 185 186
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("value", &value);
    v->Visit("axes", &axes);
    v->Visit("scale", &scale);
  }

187 188
  static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr";
  TVM_DECLARE_NODE_TYPE_INFO(ScaledExprNode, TempExprNode);
189 190
};

191 192 193
using FForwardRewrite = TypedPackedFunc<
  Expr(const Call& ref_call,
       const Array<Expr>& new_args,
194
       const Message& message)>;
195 196 197 198

//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
199
class ForwardPrep : private ExprVisitor {
200
 public:
201
  std::unordered_map<const Node*, Message>
202
  Prepare(const Expr& body) {
203
    this->Update(body, NullValue<Message>());
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
    this->VisitExpr(body);
    // flist is added in the Post-DFS order
    // which is a special case of topological order.
    // We reversely traverse the list to invoke the lazy functions.
    // This act like a backprop of valid scale axis messages
    for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
      (*it)();
    }
    // return the created message;
    return std::move(message_);
  }

 private:
  // The invoke list
  std::vector<std::function<void()> > flist_;
  // The message on each node.
220
  std::unordered_map<const Node*, Message> message_;
221
  // Update the message stored at node.
222
  void Update(const Expr& node, const Message& message) {
223 224 225 226 227 228 229 230 231 232 233 234
    // We run intersection of messages:
    //
    // %y = multiply(%x, %scale)
    // %z1 = conv2d(%y, %w)
    // %z2 = exp(%y)
    //
    // Consider the above code example,
    // because %z2 will propagate null to %y,
    // the AxesSet on %y is also null,
    // and the forward folding won't be triggered.
    const Node* key = node.get();
    if (message_.count(key)) {
235
      message_[key] = Intersect(message_[key], message);
236
    } else {
237
      message_[key] = message;
238 239 240 241 242 243 244 245 246 247
    }
  }
  // Visitor pattern override.
  void VisitExpr_(const LetNode* call) {
    LOG(FATAL) << "FoldScaleAxis only accept dataflow-form";
  }

  void VisitExpr_(const FunctionNode* op) {
    ExprVisitor::VisitExpr_(op);
    auto flazy = [this, op] {
248
      this->Update(op->body, NullValue<Message>());
249 250 251 252 253 254 255 256 257 258 259 260
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    // function to be lazily invoked
    auto flazy = [this, call]() {
      static const auto& fprep =
        Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
      // find the message send to this node.
      auto it = message_.find(call);
261
      Message out_message;
262
      if (it != message_.end()) {
263
        out_message = it->second;
264
      } else {
265
        out_message = NullValue<Message>();
266 267
      }
      // pass the message back to all the children it references.
268
      auto f = fprep.get(call->op, nullptr);
269
      if (f != nullptr) {
270 271
        Array<Message> in_messages = f(GetRef<Call>(call), out_message);
        CHECK_EQ(in_messages.size(), call->args.size());
272
        for (size_t i = 0; i < call->args.size(); ++i) {
273
          this->Update(call->args[i], in_messages[i]);
274 275 276
        }
      } else {
        for (size_t i = 0; i < call->args.size(); ++i) {
277
          this->Update(call->args[i], NullValue<Message>());
278 279 280 281 282 283 284 285 286 287 288
        }
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const TupleNode* op) {
    ExprVisitor::VisitExpr_(op);
    // do not support pass scale through tuple for now.
    auto flazy = [this, op]() {
      for (const Expr& field : op->fields) {
289
        this->Update(field, NullValue<Message>());
290 291 292 293 294 295 296 297
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const IfNode* op) {
    ExprVisitor::VisitExpr_(op);
    // do pass through condition
298
    // by assigning NullValue<Message>
299 300 301
    // it means fuse signal cannot pass
    // through into these subexpressions.
    auto flazy = [this, op]() {
302 303 304
      this->Update(op->cond, NullValue<Message>());
      this->Update(op->true_branch, NullValue<Message>());
      this->Update(op->false_branch, NullValue<Message>());
305 306 307 308 309 310 311 312 313 314
    };
    flist_.push_back(flazy);
  }
};

//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------

// Intermediate operators
315 316 317 318 319
Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
  if (out_message.defined()) {
    return {MessageNode::make(out_message->axes, true)};
  }
  return {out_message};
320 321
}

322 323
Expr ReluForwardRewrite(const Call& ref_call,
                        const Array<Expr>& new_args,
324
                        const Message& message) {
325 326
  const auto* input = new_args[0].as<ScaledExprNode>();
  if (input == nullptr) return Expr(nullptr);
327
  // return transformed conv2d
328
  auto rnode = make_node<ScaledExprNode>();
329
  rnode->value = CallNode::make(
330 331 332 333
      ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
  rnode->scale = input->scale;
  rnode->axes = input->axes;
  return Expr(rnode);
334 335 336 337 338 339
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

// AddSub
349
Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
350 351
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
352 353 354 355 356 357 358
  auto none = NullValue<Message>();
  if (out_message.defined()) {
    if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) {
      return {out_message, none};
    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) {
      return {none, out_message};
    }
359
  }
360
  return {none, none};
361 362
}

363 364
Expr AddSubForwardRewrite(const Call& ref_call,
                          const Array<Expr>& new_args,
365
                          const Message& message) {
366 367 368
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  if (!slhs && !srhs) return Expr();
369 370
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
371
  auto rnode = make_node<ScaledExprNode>();
372

373 374 375
  if (slhs != nullptr) {
    CHECK(srhs == nullptr);
    CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
376
    Expr scale = ExpandBiasToMatchAxis(
377 378 379
        slhs->scale, tlhs->shape.size(), slhs->axes);
    Expr rhs = Divide(new_args[1], scale);
    rnode->value = CallNode::make(ref_call->op, {slhs->value, rhs},
380
                                  ref_call->attrs, ref_call->type_args);
381 382
    rnode->scale = slhs->scale;
    rnode->axes = slhs->axes;
383
  } else {
384
    CHECK(srhs != nullptr);
385
    CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
386
    Expr scale = ExpandBiasToMatchAxis(
387 388 389
        srhs->scale, trhs->shape.size(), srhs->axes);
    Expr lhs = Divide(new_args[0], scale);
    rnode->value = CallNode::make(ref_call->op, {lhs, srhs->value},
390
                                  ref_call->attrs, ref_call->type_args);
391 392
    rnode->scale = srhs->scale;
    rnode->axes = srhs->axes;
393
  }
394
  return Expr(rnode);
395 396 397 398 399 400
}

RELAY_REGISTER_OP("add")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);

RELAY_REGISTER_OP("subtract")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

RELAY_REGISTER_OP("subtract")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);

// Producer operators
// Multiply produces the scale-axis pair.
411 412
Expr MultiplyForwardRewrite(const Call& ref_call,
                            const Array<Expr>& new_args,
413 414 415 416
                            const Message& message) {
  if (!message.defined()) return Expr();
  const auto& expected_out_axes = message->axes;
  CHECK(expected_out_axes.defined() && expected_out_axes.size());
417 418
  // TODO(tvm-team) allow same axes accumulation
  // not as important because it is less common in nn.
419 420 421 422
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  CHECK(!slhs && !srhs);

423 424
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
425 426 427
  Expr lhs = new_args[0];
  Expr rhs = new_args[1];
  auto rnode = make_node<ScaledExprNode>();
428

429
  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
430
      (!message->require_positive || IsAllPositiveConstant(rhs))) {
431 432 433
    rnode->value = lhs;
    rnode->scale = rhs;
    rnode->axes = expected_out_axes;
434 435
    return Expr(rnode);
  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
436
             (!message->require_positive || IsAllPositiveConstant(lhs))) {
437 438 439
    rnode->value = rhs;
    rnode->scale = lhs;
    rnode->axes = expected_out_axes;
440 441 442
    return Expr(rnode);
  } else {
    return Expr();
443 444 445 446
  }
}

RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);

// Consumer operators
// Conv2D send out requirement of axis folding.
451
Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
452 453 454 455 456
  // TODO(tvm-team) support general data layout
  // by transforming weight
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout data_layout(param->data_layout);
457
  Layout kernel_layout(param->kernel_layout);
458 459
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));
460 461

  CHECK_GE(c_big_axis, 0);
462
  Message none = NullValue<Message>();
463 464 465 466 467 468 469 470
  AxesSet data_axes = NullValue<AxesSet>();
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
471
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
472
  if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
473 474 475 476
      c_small_axis < 0 &&
      (param->groups == 1 || is_depthwise_conv2d)) {
    data_axes = {c_big_axis};
  }
477 478 479 480
  if (data_axes.defined()) {
    return {MessageNode::make(data_axes, false), none};
  }
  return {none, none};
481 482 483
}

// Conv2D consumes the scale axis during transformation.
484 485
Expr Conv2DForwardRewrite(const Call& ref_call,
                          const Array<Expr>& new_args,
486
                          const Message& message) {
487
  // if data do not have scale, normal transform path.
488 489 490 491
  const auto* sdata = new_args[0].as<ScaledExprNode>();
  const auto* sweight = new_args[1].as<ScaledExprNode>();
  if (sdata == nullptr) return Expr();
  if (sweight != nullptr) return Expr();
492 493 494
  const auto* param = ref_call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout data_layout(param->data_layout);
495
  Layout kernel_layout(param->kernel_layout);
496
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
497 498 499
  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
500
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
501 502
  CHECK(sdata->axes.size() == 1 &&
        c_big_axis == sdata->axes[0]->value);
503 504
  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
505 506

  // Check it must be depthwise or full conv2d.
507
  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
508
  CHECK(param->groups == 1 || is_depthwise_conv2d);
509 510

  Expr weight = new_args[1];
511 512

  // match the ic_axis
513 514
  if (is_depthwise_conv2d) {
    Expr scale = ExpandBiasToMatchAxis(
515
        sdata->scale, kernel_layout.ndim(), {big_oc_axis});
516 517 518
    weight = Multiply(weight, scale);
  } else {
    Expr scale = ExpandBiasToMatchAxis(
519
        sdata->scale, kernel_layout.ndim(), {big_ic_axis});
520 521
    weight = Multiply(weight, scale);
  }
522
  // return transformed conv2d
523
  return CallNode::make(
524 525 526 527 528 529 530
      ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);


534
Expr ForwardFoldScaleAxis(const Expr& data) {
535
  auto message = ForwardPrep().Prepare(data);
536
  auto fcontext = [&](const Call& call) -> NodeRef{
537 538
    auto it = message.find(call.get());
    if (it != message.end()) {
539 540 541 542 543 544 545
      return it->second;
    } else {
      return NodeRef(nullptr);
    }
  };
  return ForwardRewrite(
      data, "FScaleAxisForwardRewrite", fcontext);
546 547 548 549 550 551
}

// Expose the forward FoldScaleAxis pass to the Python frontend.
TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
.set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);

//----------------------------------------
// Implement backward transformations.
//----------------------------------------
class BackwardTransformer;

/*!
 * \brief Preparation function for pass scale backward.
 * \param call The call node.
 * \param in_messages Messages from the input containing allowed input scaling and whether
 *        positive scale is required.
 * \return Message containing the result scaling on axes of the input.
 */
using FBackwardPrep = TypedPackedFunc<
  Message(const Call& call, const Array<Message>& in_messages)>;
566 567 568

using FBackwardTransform = TypedPackedFunc<
  Expr(const Call& call,
569
       const Message& message,
570 571 572 573 574 575 576 577 578 579
       const Expr& scale,
       const BackwardTransformer& transformer)>;

//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------

class BackwardPrep : private ExprVisitor {
 public:
  // The message on each node.
580
  std::unordered_map<const Node*, Message>
581 582 583 584 585 586 587 588
  Prepare(const Expr& body) {
    ref_counter_ = GetExprRefCount(body);
    this->VisitExpr(body);
    return std::move(message_);
  }

 private:
  // The message on each node.
589
  std::unordered_map<const Node*, Message> message_;
590 591 592 593 594 595 596
  // reference counter of an internal expr
  std::unordered_map<const Node*, size_t> ref_counter_;
  // Visit the expression.
  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    static const auto& fprep =
        Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
597
    auto f = fprep.get(call->op, nullptr);
598 599 600 601 602 603
    if (f == nullptr) return;
    auto rit = ref_counter_.find(call);
    CHECK(rit != ref_counter_.end());
    // We only allow propagation of scale backward
    // if the expression is only referred by a single parent.
    if (rit->second != 1) return;
604
    Array<Message> in_messages;
605 606 607
    for (Expr arg : call->args) {
      auto it = message_.find(arg.get());
      if (it != message_.end()) {
608
        in_messages.push_back(it->second);
609
      } else {
610
        in_messages.push_back(NullValue<Message>());
611 612
      }
    }
613 614 615
    Message out_message = f(GetRef<Call>(call), in_messages);
    if (out_message.defined()) {
      message_[call] = out_message;
616 617 618 619 620 621 622 623 624 625
    }
  }
};

class BackwardTransformerNode :
      public Node,
      private ExprMutator {
 public:
  // Run forward transform.
  Expr Fold(Expr expr) {
626
    message_ = BackwardPrep().Prepare(expr);
627 628 629 630 631 632 633 634 635 636
    return this->Mutate(expr);
  }
  /*!
   * \brief Transform the expr to consider the scaling.
   *
   * \param expr The input expression.
   * \param axes The axes to scale.
   * \param scale The scale applied to the axes.
   * \return The result of transformation.
   */
637
  Expr Transform(const Expr& expr, Message message, Expr scale) {
638
    // NOTE: the result of Transform is memoized.
639
    if (const CallNode* call_node = expr.as<CallNode>()) {
640
      return Transform(call_node, message, scale);
641
    } else {
642
      CHECK(!message.defined()) << "outstanding scale";
643 644 645 646 647 648 649 650 651
      return ExprMutator::VisitExpr(expr);
    }
  }
  /*!
   * \brief Normal way of mutating call node.
   * \param call_node The call node to be mutated.
   * \return the result of the call Mutation.
   */
  Expr NormalCallTransform(const CallNode* call_node) {
652 653 654 655 656 657 658 659
    const Call call = GetRef<Call>(call_node);
    const auto it = memo_.find(call);
    if (it != memo_.end()) {
      return it->second;
    }
    Expr new_expr = ExprMutator::VisitExpr_(call_node);
    memo_[call] = new_expr;
    return new_expr;
660 661
  }
  /*!
662
   * \brief Get the message propogated to the expr.
663
   * \param expr The expresison.
664
   * \return The message containing the expected axes and whether positive scale is required.
665
   */
666 667 668 669
  Message GetMessage(const Expr& expr) const {
    auto it = message_.find(expr.get());
    if (it != message_.end()) return it->second;
    return NullValue<Message>();
670 671 672 673 674 675 676 677 678 679
  }

  // solver is not serializable.
  void VisitAttrs(tvm::AttrVisitor* v) final {}

  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
  TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);

 private:
  // Valid axes on each node.
680
  std::unordered_map<const Node*, Message> message_;
681 682
  // Override mutation of call.
  Expr VisitExpr_(const CallNode* call_node) final {
683
    return Transform(call_node, NullValue<Message>(), NullValue<Expr>());
684 685
  }
  // Transform of CallNode.
686
  Expr Transform(const CallNode* call_node, Message message, Expr scale);
687 688 689 690 691 692 693 694 695 696 697 698 699 700 701
};

/*!
 * \brief Reference wrapper for BackwardTransformerNode, passed to the
 *  per-operator FBackwardTransform callbacks.
 */
class BackwardTransformer : public NodeRef {
 public:
  BackwardTransformer() {}
  explicit BackwardTransformer(
      ::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {
  }
  BackwardTransformerNode* operator->() const {
    return static_cast<BackwardTransformerNode*>(node_.get());
  }
  using ContainerType = BackwardTransformerNode;
};

// Dispatch a call to the registered backward transform if one exists;
// otherwise the call must carry no outstanding scale.
Expr BackwardTransformerNode::Transform(
    const CallNode* call_node, Message message, Expr scale) {
  static const auto& ftransform =
      Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
  auto f = ftransform.get(call_node->op, nullptr);
  if (f != nullptr) {
    const Call call = GetRef<Call>(call_node);
    // Memoize: the same call may be transformed more than once.
    const auto it = memo_.find(call);
    if (it != memo_.end()) {
      return it->second;
    }
    Expr new_expr = f(GetRef<Call>(call_node),
                      message,
                      scale,
                      GetRef<BackwardTransformer>(this));
    memo_[call] = new_expr;
    return new_expr;
  } else {
    CHECK(!message.defined()) << "outstanding scale";
    return NormalCallTransform(call_node);
  }
}


//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------

// Intermediate operators
730 731 732 733 734
Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  if (in_messages[0].defined()) {
    return MessageNode::make(in_messages[0]->axes, true);
  }
  return in_messages[0];
735 736 737
}

// Push the scale signal through a unary relu-like op unchanged.
Expr ReluBackwardTransform(const Call& call,
                           const Message& message,
                           const Expr& scale,
                           const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Expr input = transformer->Transform(
      call->args[0], message, scale);
  return CallNode::make(call->op, {input}, call->attrs, call->type_args);
}

// Register backward prep/transform for the relu-family operators.
RELAY_REGISTER_OP("nn.relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);

RELAY_REGISTER_OP("nn.relu")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

// AddSub
762
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
763 764 765
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  AttrsEqual equal;
766 767 768 769 770 771 772 773 774
  if (in_messages[0].defined() &&
      MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
    return in_messages[0];
  } else if (in_messages[1].defined() &&
             MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
    return in_messages[1];
  } else if (in_messages[0].defined() &&
             in_messages[1].defined() &&
             equal(in_messages[0]->axes, in_messages[1]->axes) &&
775 776
             equal(tlhs->shape, trhs->shape)) {
    // add of two elements.
777
    return in_messages[0];
778
  } else {
779
    auto res = NullValue<Message>();
780
    return res;
781 782 783 784
  }
}

// Push the scale down through add/subtract. The side that cannot absorb
// the scale is explicitly multiplied by it so the result stays consistent.
Expr AddSubBackwardTransform(const Call& call,
                             const Message& message,
                             const Expr& scale,
                             const BackwardTransformer& transformer) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  AttrsEqual equal;

  if (lhs_message.defined() && rhs_message.defined()) {
    CHECK(equal(lhs_message->axes, rhs_message->axes));
    CHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (lhs_message.defined()) {
    CHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(
        call->args[1], NullValue<Message>(), NullValue<Expr>());
    Expr rhs_scale = ExpandBiasToMatchAxis(
        scale, tlhs->shape.size(), message->axes);
    rhs = Multiply(rhs, rhs_scale);
    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (rhs_message.defined()) {
    CHECK(equal(message->axes, rhs_message->axes));
    Expr lhs = transformer->Transform(
        call->args[0], NullValue<Message>(), NullValue<Expr>());
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    Expr lhs_scale = ExpandBiasToMatchAxis(
        scale, trhs->shape.size(), message->axes);
    lhs = Multiply(lhs, lhs_scale);
    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else {
    LOG(FATAL) << "outstanding scale";
    return Expr();
  }
}

// Register backward prep/transform for add/subtract.
RELAY_REGISTER_OP("add")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);

RELAY_REGISTER_OP("add")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);

RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

// Producer operators
// Multiply produces the scale-axis pair.
/*!
 * \brief Backward transform for multiply, which *produces* the scale.
 *
 * If one operand carries a backward message, try to interpret the other
 * operand as a per-axis scale that broadcasts onto the messaged axes, and
 * push that scale down into the messaged operand. Otherwise fall back to
 * the normal call transform.
 *
 * \param call The multiply call being rewritten.
 * \param message The incoming message; must be undefined since multiply
 *        itself is the producer of scales, not a consumer.
 * \param scale The incoming scale (unused; see message precondition).
 * \param transformer The backward transformer driving the rewrite.
 * \return The transformed expression.
 */
Expr MultiplyBackwardTransform(const Call& call,
                               const Message& message,
                               const Expr& scale,
                               const BackwardTransformer& transformer) {
  CHECK(!message.defined()) << "outstanding scale";
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  if (lhs_message.defined()) {
    CHECK(lhs_message->axes.defined() && lhs_message->axes.size());
    // NOTE we won't recursively call mutating on scale part.
    // since there won't be scale chance within scale part.
    Expr rhs = call->args[1];
    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
        (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
      return transformer->Transform(call->args[0], lhs_message, rhs);
    }
  } else if (rhs_message.defined()) {
    CHECK(rhs_message->axes.defined() && rhs_message->axes.size());
    Expr lhs = call->args[0];
    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
        (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
      return transformer->Transform(call->args[1], rhs_message, lhs);
    }
  }
  return transformer->NormalCallTransform(call.operator->());
}

// multiply only needs the transform hook: it produces scales rather than
// consuming them, so it never emits a backward-prep message of its own.
RELAY_REGISTER_OP("multiply")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);

// Consumer operators
// Conv2D send out requirement of axis folding.
875
Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
876 877
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
878 879
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
880 881
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
882 883 884 885 886 887 888 889 890

  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
891
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
892 893
  if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
  kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
894 895
      c_small_axis < 0 &&
      (param->groups == 1 || is_depthwise_conv2d)) {
896
    return MessageNode::make({c_big_axis}, false);
897
  } else {
898
    return NullValue<Message>();
899 900 901 902 903
  }
}

/*!
 * \brief Backward transform for conv2d: consume the scale by folding it
 *        into the convolution weight along the output-channel axis.
 *
 * \param call The conv2d call.
 * \param message The backward message; must name exactly the big `C` axis
 *        of the output layout, otherwise the normal transform is used.
 * \param scale The per-channel scale to fold into the weight.
 * \param transformer The backward transformer driving the rewrite.
 * \return The rewritten conv2d call with the scaled weight.
 */
Expr Conv2DBackwardTransform(const Call& call,
                             const Message& message,
                             const Expr& scale,
                             const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
  CHECK(message->axes.size() == 1 &&
        c_big_axis == message->axes[0]->value);

  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  // Check it must be depthwise or full conv2d.
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
  CHECK(param->groups == 1 || is_depthwise_conv2d);

  Expr data = transformer->Transform(
      call->args[0], NullValue<Message>(), NullValue<Expr>());
  Expr weight = transformer->Transform(
      call->args[1], NullValue<Message>(), NullValue<Expr>());
  // scale on input for depthwise.
  Expr wscale = ExpandBiasToMatchAxis(
      scale, kernel_layout.ndim(), {big_oc_axis});
  weight = Multiply(weight, wscale);
  return CallNode::make(
      call->op, {data, weight}, call->attrs, call->type_args);
}

// Register conv2d as a backward scale consumer: the prep hook requests
// folding along the output channel, and the transform hook folds the
// scale into the weight.
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);

946
Expr BackwardFoldScaleAxis(const Expr& data) {
947 948 949 950 951 952
  return make_node<BackwardTransformerNode>()->Fold(data);
}

// Expose the backward rewrite through the legacy _ir_pass FFI namespace.
TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis")
.set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);

}  // namespace fold_scale_axis

namespace transform {

Pass ForwardFoldScaleAxis() {
  // Lift the expression-level forward rewrite into a function-level pass.
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> transform_func =
      [=](Function f, Module m, PassContext pc) {
        Expr folded = relay::fold_scale_axis::ForwardFoldScaleAxis(f);
        return Downcast<Function>(folded);
      };
  return CreateFunctionPass(transform_func, 3, "ForwardFoldScaleAxis",
                            {ir::StringImm::make("InferType")});
}

Pass BackwardFoldScaleAxis() {
  // Lift the expression-level backward rewrite into a function-level pass.
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> transform_func =
      [=](Function f, Module m, PassContext pc) {
        Expr folded = relay::fold_scale_axis::BackwardFoldScaleAxis(f);
        return Downcast<Function>(folded);
      };
  return CreateFunctionPass(transform_func, 3, "BackwardFoldScaleAxis",
                            {ir::StringImm::make("InferType")});
}

Pass FoldScaleAxis() {
  // FoldScaleAxis is the sequential composition of the two directional
  // passes followed by constant folding, registered as one pass.
  return Sequential(
      {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
      "FoldScaleAxis");
}

TVM_REGISTER_API("relay._transform.FoldScaleAxis")
.set_body_typed(FoldScaleAxis);

}  // namespace transform

}  // namespace relay
}  // namespace tvm