/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file fold_scale_axis.cc
 *
 * \brief Fold axis scaling into weights of
 *  conv/dense operators.
 */
#include <tvm/tir/data_layout.h>
Zhi committed
27
#include <tvm/relay/analysis.h>
28 29
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
30
#include <tvm/relay/transform.h>
31
#include "pattern_util.h"
32
#include "pass_util.h"
33

34 35 36 37 38 39 40 41 42 43 44 45 46

namespace tvm {
namespace relay {
/*!
 * \brief namespace of fold scale axis
 *
 * Use namespace to reduce potential naming conflict.
 */
namespace fold_scale_axis {

using runtime::TypedPackedFunc;


// FoldScaleAxis algorithm:
//
// The general idea is to transform Expr to tuple of
// (value, axes, scale), where the final result satisfies:
//
// result = value
// for i, k in enumerate(axes):
//    k-th dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that certain scale may never be consumed
// if there is no dense/conv2d that follows multiplication.
//
// In order to make sure all the scale we sent out can be consumed eventually,
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
//
// Forward folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation.
//
// Similarly, backward folding process is done in two steps:
// - Prepare phase: forward propagation of demand.
// - Transform phase: transformation by push down the axes scale signal to inputs.
//

/*!
 * \brief sorted array axis, can also be nullptr.
 *
 *  nullptr means no scaling request can be done.
 */
using AxesSet = Array<Integer>;

80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
class Message;

/*!
 * \brief Message propogated during the prepare phase.
 */
class MessageNode : public RelayNode {
 public:
  /*! \brief Axes for scaling */
  AxesSet axes;
  /*!
   * \brief Whether folding requires the scale to be positive constant. This is necessary if some
   *  operators (e.g. Relu) is present.
   */
  bool require_positive;

  static Message make(const AxesSet& axes, bool require_positive);

  static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
98
  TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode);
99 100
};

101 102 103 104
class Message : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode);
};
105 106

// Build a Message carrying the requested axes and positivity requirement.
Message MessageNode::make(const AxesSet& axes, bool require_positive) {
  auto node = make_object<MessageNode>();
  node->axes = axes;
  node->require_positive = require_positive;
  return Message(node);
}

113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
/*!
 * \brief Merge two axis set together by taking
 *  intersection.
 *
 * \note The axes in a AxesSet should be sorted.
 *
 * \param lhs The left axis.
 * \param rhs The right axis.
 * \return The result of the inersection.
 */
AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  // This code relies on axes in a AxesSet to be sorted.
  AxesSet ret;
  size_t i = 0, j = 0;
  while (i < lhs.size() && j < rhs.size()) {
    if (lhs[i]->value < rhs[j]->value) {
      ++i;
    } else if (lhs[i]->value > rhs[j]->value) {
      ++j;
    } else {
      ret.push_back(lhs[i]);
      ++i; ++j;
    }
  }
  return ret;
}

/*!
 * \brief Merge two messages together by taking intersection.
 *
 * \param lhs The lhs message.
 * \param rhs The rhs message.
 * \return The result of intersection.
 */
Message Intersect(const Message& lhs, const Message& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  // Positivity is required if either side requires it.
  return MessageNode::make(Intersect(lhs->axes, rhs->axes),
                           lhs->require_positive || rhs->require_positive);
}

/*!
 * \brief Preparation function for pass scale forward.
 * \param call The call node.
 * \param out_message Message from the output containing possible scaling on axes and whether
 *        positive scale is required.
 * \return The message containing the result scaling on axes of the input.
 */
using FForwardPrep = runtime::TypedPackedFunc<
    Array<Message>(const Call& call, const Message& out_message)>;

/*! \brief Axis scale tuple.  */
167
class ScaledExprNode : public TempExprNode {
168 169 170 171 172 173 174 175
 public:
  /*! \brief The value */
  Expr value;
  /*! \brief The axes to scale, can be nullptr(means no-scaling) */
  AxesSet axes = NullValue<AxesSet>();
  /*! \brief The scaling factor */
  Expr scale = NullValue<Expr>();

176 177 178 179 180 181
  Expr Realize() const final {
    CHECK(!axes.defined())
        << "outstanding scale";
    return value;
  }

182
  void VisitAttrs(AttrVisitor* v) {
183 184 185 186 187
    v->Visit("value", &value);
    v->Visit("axes", &axes);
    v->Visit("scale", &scale);
  }

188
  static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr";
189
  TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode);
190 191
};

192 193 194
using FForwardRewrite = TypedPackedFunc<
  Expr(const Call& ref_call,
       const Array<Expr>& new_args,
195
       const Message& message)>;
196 197 198 199

//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
200
class ForwardPrep : private ExprVisitor {
201
 public:
202
  std::unordered_map<const Object*, Message>
203
  Prepare(const Expr& body) {
204
    this->Update(body, NullValue<Message>());
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
    this->VisitExpr(body);
    // flist is added in the Post-DFS order
    // which is a special case of topological order.
    // We reversely traverse the list to invoke the lazy functions.
    // This act like a backprop of valid scale axis messages
    for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
      (*it)();
    }
    // return the created message;
    return std::move(message_);
  }

 private:
  // The invoke list
  std::vector<std::function<void()> > flist_;
  // The message on each node.
221
  std::unordered_map<const Object*, Message> message_;
222
  // Update the message stored at node.
223
  void Update(const Expr& node, const Message& message) {
224 225 226 227 228 229 230 231 232 233
    // We run intersection of messages:
    //
    // %y = multiply(%x, %scale)
    // %z1 = conv2d(%y, %w)
    // %z2 = exp(%y)
    //
    // Consider the above code example,
    // because %z2 will propagate null to %y,
    // the AxesSet on %y is also null,
    // and the forward folding won't be triggered.
234
    const Object* key = node.get();
235
    if (message_.count(key)) {
236
      message_[key] = Intersect(message_[key], message);
237
    } else {
238
      message_[key] = message;
239 240 241 242 243 244 245 246 247 248
    }
  }
  // Visitor pattern override.
  void VisitExpr_(const LetNode* call) {
    LOG(FATAL) << "FoldScaleAxis only accept dataflow-form";
  }

  void VisitExpr_(const FunctionNode* op) {
    ExprVisitor::VisitExpr_(op);
    auto flazy = [this, op] {
249
      this->Update(op->body, NullValue<Message>());
250 251 252 253 254 255 256 257 258 259 260 261
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    // function to be lazily invoked
    auto flazy = [this, call]() {
      static const auto& fprep =
        Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
      // find the message send to this node.
      auto it = message_.find(call);
262
      Message out_message;
263
      if (it != message_.end()) {
264
        out_message = it->second;
265
      } else {
266
        out_message = NullValue<Message>();
267 268
      }
      // pass the message back to all the children it references.
269
      auto f = fprep.get(call->op, nullptr);
270
      if (f != nullptr) {
271 272
        Array<Message> in_messages = f(GetRef<Call>(call), out_message);
        CHECK_EQ(in_messages.size(), call->args.size());
273
        for (size_t i = 0; i < call->args.size(); ++i) {
274
          this->Update(call->args[i], in_messages[i]);
275 276 277
        }
      } else {
        for (size_t i = 0; i < call->args.size(); ++i) {
278
          this->Update(call->args[i], NullValue<Message>());
279 280 281 282 283 284 285 286 287 288 289
        }
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const TupleNode* op) {
    ExprVisitor::VisitExpr_(op);
    // do not support pass scale through tuple for now.
    auto flazy = [this, op]() {
      for (const Expr& field : op->fields) {
290
        this->Update(field, NullValue<Message>());
291 292 293 294 295 296 297 298
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const IfNode* op) {
    ExprVisitor::VisitExpr_(op);
    // do pass through condition
299
    // by assigning NullValue<Message>
300 301 302
    // it means fuse signal cannot pass
    // through into these subexpressions.
    auto flazy = [this, op]() {
303 304 305
      this->Update(op->cond, NullValue<Message>());
      this->Update(op->true_branch, NullValue<Message>());
      this->Update(op->false_branch, NullValue<Message>());
306 307 308 309 310 311 312 313 314 315
    };
    flist_.push_back(flazy);
  }
};

//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------

// Intermediate operators
316 317 318 319 320
Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
  if (out_message.defined()) {
    return {MessageNode::make(out_message->axes, true)};
  }
  return {out_message};
321 322
}

323 324
Expr ReluForwardRewrite(const Call& ref_call,
                        const Array<Expr>& new_args,
325
                        const Message& message) {
326 327
  const auto* input = new_args[0].as<ScaledExprNode>();
  if (input == nullptr) return Expr(nullptr);
328
  // return transformed conv2d
329
  auto rnode = make_object<ScaledExprNode>();
330
  rnode->value = CallNode::make(
331 332 333 334
      ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
  rnode->scale = input->scale;
  rnode->axes = input->axes;
  return Expr(rnode);
335 336 337 338 339 340
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

// AddSub
350
Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
351 352
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
353 354 355 356 357 358 359
  auto none = NullValue<Message>();
  if (out_message.defined()) {
    if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) {
      return {out_message, none};
    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) {
      return {none, out_message};
    }
360
  }
361
  return {none, none};
362 363
}

364 365
Expr AddSubForwardRewrite(const Call& ref_call,
                          const Array<Expr>& new_args,
366
                          const Message& message) {
367 368 369
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  if (!slhs && !srhs) return Expr();
370 371
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
372
  auto rnode = make_object<ScaledExprNode>();
373

374 375 376
  if (slhs != nullptr) {
    CHECK(srhs == nullptr);
    CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
377
    Expr scale = ExpandBiasToMatchAxis(
378 379 380
        slhs->scale, tlhs->shape.size(), slhs->axes);
    Expr rhs = Divide(new_args[1], scale);
    rnode->value = CallNode::make(ref_call->op, {slhs->value, rhs},
381
                                  ref_call->attrs, ref_call->type_args);
382 383
    rnode->scale = slhs->scale;
    rnode->axes = slhs->axes;
384
  } else {
385
    CHECK(srhs != nullptr);
386
    CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
387
    Expr scale = ExpandBiasToMatchAxis(
388 389 390
        srhs->scale, trhs->shape.size(), srhs->axes);
    Expr lhs = Divide(new_args[0], scale);
    rnode->value = CallNode::make(ref_call->op, {lhs, srhs->value},
391
                                  ref_call->attrs, ref_call->type_args);
392 393
    rnode->scale = srhs->scale;
    rnode->axes = srhs->axes;
394
  }
395
  return Expr(rnode);
396 397 398 399 400 401
}

RELAY_REGISTER_OP("add")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);

RELAY_REGISTER_OP("subtract")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);

// Producer operators
// Multiply produces the scale-axis pair.
412 413
Expr MultiplyForwardRewrite(const Call& ref_call,
                            const Array<Expr>& new_args,
414 415 416 417
                            const Message& message) {
  if (!message.defined()) return Expr();
  const auto& expected_out_axes = message->axes;
  CHECK(expected_out_axes.defined() && expected_out_axes.size());
418 419
  // TODO(tvm-team) allow same axes accumulation
  // not as important because it is less common in nn.
420 421 422 423
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  CHECK(!slhs && !srhs);

424 425
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
426 427
  Expr lhs = new_args[0];
  Expr rhs = new_args[1];
428
  auto rnode = make_object<ScaledExprNode>();
429

430
  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
431
      (!message->require_positive || IsAllPositiveConstant(rhs))) {
432 433 434
    rnode->value = lhs;
    rnode->scale = rhs;
    rnode->axes = expected_out_axes;
435 436
    return Expr(rnode);
  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
437
             (!message->require_positive || IsAllPositiveConstant(lhs))) {
438 439 440
    rnode->value = rhs;
    rnode->scale = lhs;
    rnode->axes = expected_out_axes;
441 442 443
    return Expr(rnode);
  } else {
    return Expr();
444 445 446 447
  }
}

RELAY_REGISTER_OP("multiply")
448
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);
449 450 451

// Consumer operators
// Conv2D send out requirement of axis folding.
452
Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
453 454 455 456 457
  // TODO(tvm-team) support general data layout
  // by transforming weight
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout data_layout(param->data_layout);
458
  Layout kernel_layout(param->kernel_layout);
459 460
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));
461 462

  CHECK_GE(c_big_axis, 0);
463
  Message none = NullValue<Message>();
464 465 466 467 468 469 470 471
  AxesSet data_axes = NullValue<AxesSet>();
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
472
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
473
  if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
474 475 476 477
      c_small_axis < 0 &&
      (param->groups == 1 || is_depthwise_conv2d)) {
    data_axes = {c_big_axis};
  }
478 479 480 481
  if (data_axes.defined()) {
    return {MessageNode::make(data_axes, false), none};
  }
  return {none, none};
482 483 484
}

// Conv2D consumes the scale axis during transformation.
485 486
Expr Conv2DForwardRewrite(const Call& ref_call,
                          const Array<Expr>& new_args,
487
                          const Message& message) {
488
  // if data do not have scale, normal transform path.
489 490 491 492
  const auto* sdata = new_args[0].as<ScaledExprNode>();
  const auto* sweight = new_args[1].as<ScaledExprNode>();
  if (sdata == nullptr) return Expr();
  if (sweight != nullptr) return Expr();
493 494 495
  const auto* param = ref_call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout data_layout(param->data_layout);
496
  Layout kernel_layout(param->kernel_layout);
497
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
498 499 500
  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
501
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
502 503
  CHECK(sdata->axes.size() == 1 &&
        c_big_axis == sdata->axes[0]->value);
504 505
  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
506 507

  // Check it must be depthwise or full conv2d.
508
  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
509
  CHECK(param->groups == 1 || is_depthwise_conv2d);
510 511

  Expr weight = new_args[1];
512 513

  // match the ic_axis
514 515
  if (is_depthwise_conv2d) {
    Expr scale = ExpandBiasToMatchAxis(
516
        sdata->scale, kernel_layout.ndim(), {big_oc_axis});
517 518 519
    weight = Multiply(weight, scale);
  } else {
    Expr scale = ExpandBiasToMatchAxis(
520
        sdata->scale, kernel_layout.ndim(), {big_ic_axis});
521 522
    weight = Multiply(weight, scale);
  }
523
  // return transformed conv2d
524
  return CallNode::make(
525 526 527 528 529 530 531
      ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);


535
Expr ForwardFoldScaleAxis(const Expr& data) {
536
  auto message = ForwardPrep().Prepare(data);
537
  auto fcontext = [&](const Call& call) -> ObjectRef{
538 539
    auto it = message.find(call.get());
    if (it != message.end()) {
540 541
      return it->second;
    } else {
542
      return ObjectRef(nullptr);
543 544 545 546
    }
  };
  return ForwardRewrite(
      data, "FScaleAxisForwardRewrite", fcontext);
547 548
}

//----------------------------------------
// Implement backward transformations.
//----------------------------------------
class BackwardTransformer;

/*!
 * \brief Preparation function for pass scale backward.
 * \param call The call node.
 * \param in_messages Messages from the input containing allowed input scaling and whether
 *        positive scale is required.
 * \return Message containing the result scaling on axes of the input.
 */
using FBackwardPrep = TypedPackedFunc<
    Message(const Call& call, const Array<Message>& in_messages)>;

/*! \brief Transform function that pushes the axes scale signal down to inputs. */
using FBackwardTransform = TypedPackedFunc<
    Expr(const Call& call, const Message& message, const Expr& scale,
         const BackwardTransformer& transformer)>;

//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------

class BackwardPrep : private ExprVisitor {
 public:
  // The message on each node.
577
  std::unordered_map<const Object*, Message>
578 579 580 581 582 583 584 585
  Prepare(const Expr& body) {
    ref_counter_ = GetExprRefCount(body);
    this->VisitExpr(body);
    return std::move(message_);
  }

 private:
  // The message on each node.
586
  std::unordered_map<const Object*, Message> message_;
587
  // reference counter of an internal expr
588
  std::unordered_map<const Object*, size_t> ref_counter_;
589 590 591 592 593
  // Visit the expression.
  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    static const auto& fprep =
        Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
594
    auto f = fprep.get(call->op, nullptr);
595 596 597 598 599 600
    if (f == nullptr) return;
    auto rit = ref_counter_.find(call);
    CHECK(rit != ref_counter_.end());
    // We only allow propagation of scale backward
    // if the expression is only referred by a single parent.
    if (rit->second != 1) return;
601
    Array<Message> in_messages;
602 603 604
    for (Expr arg : call->args) {
      auto it = message_.find(arg.get());
      if (it != message_.end()) {
605
        in_messages.push_back(it->second);
606
      } else {
607
        in_messages.push_back(NullValue<Message>());
608 609
      }
    }
610 611 612
    Message out_message = f(GetRef<Call>(call), in_messages);
    if (out_message.defined()) {
      message_[call] = out_message;
613 614 615 616 617
    }
  }
};

class BackwardTransformerNode :
618
      public Object,
619 620 621 622
      private ExprMutator {
 public:
  // Run forward transform.
  Expr Fold(Expr expr) {
623
    message_ = BackwardPrep().Prepare(expr);
624 625 626 627 628 629 630 631 632 633
    return this->Mutate(expr);
  }
  /*!
   * \brief Transform the expr to consider the scaling.
   *
   * \param expr The input expression.
   * \param axes The axes to scale.
   * \param scale The scale applied to the axes.
   * \return The result of transformation.
   */
634
  Expr Transform(const Expr& expr, Message message, Expr scale) {
635
    // NOTE: the result of Transform is memoized.
636
    if (const CallNode* call_node = expr.as<CallNode>()) {
637
      return Transform(call_node, message, scale);
638
    } else {
639
      CHECK(!message.defined()) << "outstanding scale";
640 641 642 643 644 645 646 647 648
      return ExprMutator::VisitExpr(expr);
    }
  }
  /*!
   * \brief Normal way of mutating call node.
   * \param call_node The call node to be mutated.
   * \return the result of the call Mutation.
   */
  Expr NormalCallTransform(const CallNode* call_node) {
649 650 651 652 653 654 655 656
    const Call call = GetRef<Call>(call_node);
    const auto it = memo_.find(call);
    if (it != memo_.end()) {
      return it->second;
    }
    Expr new_expr = ExprMutator::VisitExpr_(call_node);
    memo_[call] = new_expr;
    return new_expr;
657 658
  }
  /*!
659
   * \brief Get the message propogated to the expr.
660
   * \param expr The expresison.
661
   * \return The message containing the expected axes and whether positive scale is required.
662
   */
663 664 665 666
  Message GetMessage(const Expr& expr) const {
    auto it = message_.find(expr.get());
    if (it != message_.end()) return it->second;
    return NullValue<Message>();
667 668 669
  }

  // solver is not serializable.
670
  void VisitAttrs(tvm::AttrVisitor* v) {}
671 672

  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
673
  TVM_DECLARE_FINAL_OBJECT_INFO(BackwardTransformerNode, Object);
674 675 676

 private:
  // Valid axes on each node.
677
  std::unordered_map<const Object*, Message> message_;
678 679
  // Override mutation of call.
  Expr VisitExpr_(const CallNode* call_node) final {
680
    return Transform(call_node, NullValue<Message>(), NullValue<Expr>());
681 682
  }
  // Transform of CallNode.
683
  Expr Transform(const CallNode* call_node, Message message, Expr scale);
684 685
};

686
class BackwardTransformer : public ObjectRef {
687 688 689
 public:
  BackwardTransformer() {}
  explicit BackwardTransformer(
690
      ::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
691 692
  }
  BackwardTransformerNode* operator->() const {
693
    return static_cast<BackwardTransformerNode*>(get_mutable());
694 695 696 697 698
  }
  using ContainerType = BackwardTransformerNode;
};

Expr BackwardTransformerNode::Transform(
    const CallNode* call_node, Message message, Expr scale) {
  static const auto& ftransform =
      Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
  auto f = ftransform.get(call_node->op, nullptr);
  if (f == nullptr) {
    // No registered transform: the scale must already have been consumed.
    CHECK(!message.defined()) << "outstanding scale";
    return NormalCallTransform(call_node);
  }
  // Memoized dispatch into the per-op transform function.
  const Call call = GetRef<Call>(call_node);
  const auto it = memo_.find(call);
  if (it != memo_.end()) {
    return it->second;
  }
  Expr new_expr = f(call, message, scale, GetRef<BackwardTransformer>(this));
  memo_[call] = new_expr;
  return new_expr;
}


//----------------------------------------------
// Per operator defs for FScaleAxisBackward
//----------------------------------------------

// Intermediate operators
727 728 729 730 731
Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  if (in_messages[0].defined()) {
    return MessageNode::make(in_messages[0]->axes, true);
  }
  return in_messages[0];
732 733 734
}

Expr ReluBackwardTransform(const Call& call,
                           const Message& message,
                           const Expr& scale,
                           const BackwardTransformer& transformer) {
  // Without a message there is nothing to fold; mutate normally.
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  // Push the scale into the input, then rebuild the call on top of it.
  Expr input = transformer->Transform(call->args[0], message, scale);
  return CallNode::make(call->op, {input}, call->attrs, call->type_args);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

// AddSub
759
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
760 761 762
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  AttrsEqual equal;
763 764 765 766 767 768 769 770 771
  if (in_messages[0].defined() &&
      MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
    return in_messages[0];
  } else if (in_messages[1].defined() &&
             MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
    return in_messages[1];
  } else if (in_messages[0].defined() &&
             in_messages[1].defined() &&
             equal(in_messages[0]->axes, in_messages[1]->axes) &&
772 773
             equal(tlhs->shape, trhs->shape)) {
    // add of two elements.
774
    return in_messages[0];
775
  } else {
776
    auto res = NullValue<Message>();
777
    return res;
778 779 780 781
  }
}

Expr AddSubBackwardTransform(const Call& call,
                             const Message& message,
                             const Expr& scale,
                             const BackwardTransformer& transformer) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  AttrsEqual equal;

  if (lhs_message.defined() && rhs_message.defined()) {
    // Both operands accept the scale: push it into each side.
    CHECK(equal(lhs_message->axes, rhs_message->axes));
    CHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (lhs_message.defined()) {
    // Only lhs folds the scale; rhs gets an explicit multiply.
    CHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(
        call->args[1], NullValue<Message>(), NullValue<Expr>());
    Expr rhs_scale = ExpandBiasToMatchAxis(
        scale, tlhs->shape.size(), message->axes);
    rhs = Multiply(rhs, rhs_scale);
    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else if (rhs_message.defined()) {
    // Only rhs folds the scale; lhs gets an explicit multiply.
    CHECK(equal(message->axes, rhs_message->axes));
    Expr lhs = transformer->Transform(
        call->args[0], NullValue<Message>(), NullValue<Expr>());
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    Expr lhs_scale = ExpandBiasToMatchAxis(
        scale, trhs->shape.size(), message->axes);
    lhs = Multiply(lhs, lhs_scale);
    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
  } else {
    LOG(FATAL) << "outstanding scale";
    return Expr();
  }
}

RELAY_REGISTER_OP("add")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

// Producer operators
// Multiply produces the scale-axis pair.
// Multiply is the producer of the (scale, axes) pair: if one operand is a
// per-axis scale that broadcasts onto the other operand's requested axes,
// drop the multiply and push the scale into the other operand's subtree.
//
// \param call         The multiply call being rewritten.
// \param message      Must be undefined: a multiply cannot itself absorb an
//                     outstanding scale coming from its consumer.
// \param scale        Unused (paired with the undefined message).
// \param transformer  Backward transformer used to recurse and query messages.
// \return The transformed expression, or the normally-transformed call when
//         neither operand can act as a foldable scale.
Expr MultiplyBackwardTransform(const Call& call,
                               const Message& message,
                               const Expr& scale,
                               const BackwardTransformer& transformer) {
  CHECK(!message.defined()) << "outstanding scale";
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  if (lhs_message.defined()) {
    CHECK(lhs_message->axes.defined() && lhs_message->axes.size());
    // NOTE we won't recursively call mutating on scale part.
    // since there  won't be scale chance within scale part.
    Expr rhs = call->args[1];
    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) &&
        (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) {
      return transformer->Transform(call->args[0], lhs_message, rhs);
    }
  } else if (rhs_message.defined()) {
    CHECK(rhs_message->axes.defined() && rhs_message->axes.size());
    Expr lhs = call->args[0];
    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) &&
        (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) {
      return transformer->Transform(call->args[1], rhs_message, lhs);
    }
  }
  return transformer->NormalCallTransform(call.operator->());
}

RELAY_REGISTER_OP("multiply")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);

// Consumer operators
// Conv2D send out requirement of axis folding.
872
Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
873 874
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
875 876
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
877 878
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
879 880 881 882 883 884 885 886 887

  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
888
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
889 890
  if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
  kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
891 892
      c_small_axis < 0 &&
      (param->groups == 1 || is_depthwise_conv2d)) {
893
    return MessageNode::make({c_big_axis}, false);
894
  } else {
895
    return NullValue<Message>();
896 897 898 899 900
  }
}

// Conv2D consumes the scale axis during transformation.
// Conv2D consumes the scale during the backward transform by folding it
// into the kernel along the big output-channel axis ('O').
//
// \param call         The conv2d call node.
// \param message      The axes/positivity requirement to fold, or undefined.
// \param scale        The scale tensor to fold into the weight.
// \param transformer  Backward transformer used to recurse into arguments.
// \return The rewritten conv2d call with the scale multiplied into weight.
Expr Conv2DBackwardTransform(const Call& call,
                             const Message& message,
                             const Expr& scale,
                             const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
  CHECK(message->axes.size() == 1 &&
        c_big_axis == message->axes[0]->value);

  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  // Check it must be depthwise or full conv2d.
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
  CHECK(param->groups == 1 || is_depthwise_conv2d);

  Expr data = transformer->Transform(
      call->args[0], NullValue<Message>(), NullValue<Expr>());
  Expr weight = transformer->Transform(
      call->args[1], NullValue<Message>(), NullValue<Expr>());
  // scale on input for depthwise.
  Expr wscale = ExpandBiasToMatchAxis(
      scale, kernel_layout.ndim(), {big_oc_axis});
  weight = Multiply(weight, wscale);
  return CallNode::make(
      call->op, {data, weight}, call->attrs, call->type_args);
}

// Register the backward prep/transform hooks for conv2d.
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);

943
Expr BackwardFoldScaleAxis(const Expr& data) {
944
  return make_object<BackwardTransformerNode>()->Fold(data);
945 946
}

}  // namespace fold_scale_axis

namespace transform {

// Creates the ForwardFoldScaleAxis function pass (opt level 3).
// Requires InferType to have run so that type information is available.
Pass ForwardFoldScaleAxis() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
      return Downcast<Function>(
          relay::fold_scale_axis::ForwardFoldScaleAxis(f));
  };
  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
                            {tir::StringImmNode::make("InferType")});
}

961
TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
Zhi committed
962 963
.set_body_typed(ForwardFoldScaleAxis);

964
Pass BackwardFoldScaleAxis() {
965 966
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
967 968 969 970
      return Downcast<Function>(
          relay::fold_scale_axis::BackwardFoldScaleAxis(f));
    };
  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
971
                            {tir::StringImmNode::make("InferType")});
972 973
}

974
TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
Zhi committed
975 976
.set_body_typed(BackwardFoldScaleAxis);

977 978 979 980 981 982 983 984 985
Pass FoldScaleAxis() {
  // FoldScaleAxis pass contains the following three passes. Therefore, we can
  // register it as a sequential pass.
  Pass pass = Sequential(
      {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
      "FoldScaleAxis");
  return pass;
}

986
TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis")
987 988 989 990
.set_body_typed(FoldScaleAxis);

}  // namespace transform

}  // namespace relay
}  // namespace tvm