fold_scale_axis.cc 33.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27
/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file fold_scale_axis.cc
 *
 * \brief Fold axis scaling into weights of
 *  conv/dense operators.
 */
28
#include <tvm/data_layout.h>
Zhi committed
29
#include <tvm/relay/analysis.h>
30 31
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
32
#include <tvm/relay/transform.h>
33
#include "pattern_util.h"
34
#include "pass_util.h"
35

36 37 38 39 40 41 42 43 44 45 46 47 48

namespace tvm {
namespace relay {
/*!
 * \brief namespace of fold scale axis
 *
 * Use namespace to reduce potential naming conflict.
 */
namespace fold_scale_axis {

using runtime::TypedPackedFunc;


49
// FoldScaleAxis algorithm:
//
// The general idea is to transform an Expr into a tuple of
// (value, axes, scale), where the final result satisfies:
//
// result = value
// for i, k in enumerate(axes):
//    k-th dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that a certain scale may never be consumed
// if there is no dense/conv2d that follows the multiplication.
//
// In order to make sure all the scales we send out can be consumed eventually,
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
//
// Forward folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation.
//
// Similarly, backward folding process is done in two steps:
// - Prepare phase: forward propagation of demand.
// - Transform phase: transformation by pushing down the axes scale signal to inputs.
//
74 75 76 77 78 79 80 81

/*!
 * \brief Sorted array of axis indices; it can also be undefined (nullptr).
 *
 *  An undefined AxesSet means that no scaling request can be made.
 */
using AxesSet = Array<Integer>;

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
class Message;

/*!
 * \brief Message propagated during the prepare phase.
 *
 *  A message carries the set of axes on which a scale may be folded,
 *  together with a flag telling whether that scale has to be positive.
 */
class MessageNode : public RelayNode {
 public:
  /*! \brief Axes for scaling */
  AxesSet axes;
  /*!
   * \brief Whether folding requires the scale to be a positive constant. This is necessary if
   *  some operators (e.g. Relu) are present.
   *
   *  Defaults to false so that the flag has a well-defined value
   *  regardless of how the node is allocated.
   */
  bool require_positive{false};

  /*!
   * \brief Create a Message holding the given axes and positivity requirement.
   * \param axes The axes on which scaling is allowed.
   * \param require_positive Whether the scale must be a positive constant.
   * \return The constructed Message reference.
   */
  static Message make(const AxesSet& axes, bool require_positive);

  static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
  TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode);
};

RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef);

// Allocate a fresh MessageNode, copy both fields into it and hand it back
// wrapped in a Message reference.
Message MessageNode::make(const AxesSet& axes, bool require_positive) {
  auto node = make_node<MessageNode>();
  node->require_positive = require_positive;
  node->axes = axes;
  return Message(node);
}

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
/*!
 * \brief Merge two axis sets by taking their intersection.
 *
 * \note Both inputs are expected to be sorted; the merge below
 *       walks them in lock-step and relies on that ordering.
 *
 * \param lhs The left axis set.
 * \param rhs The right axis set.
 * \return The intersection. If either input is undefined, that
 *         undefined input is returned unchanged.
 */
AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;

  AxesSet common;
  size_t li = 0;
  size_t ri = 0;
  while (li < lhs.size() && ri < rhs.size()) {
    const auto left_axis = lhs[li]->value;
    const auto right_axis = rhs[ri]->value;
    if (left_axis == right_axis) {
      // Present on both sides: keep it and advance both cursors.
      common.push_back(lhs[li]);
      ++li;
      ++ri;
    } else if (left_axis < right_axis) {
      ++li;
    } else {
      ++ri;
    }
  }
  return common;
}

/*!
 * \brief Merge two messages by intersecting their axes.
 *
 *  The merged message requires a positive scale as soon as
 *  either of the inputs does.
 *
 * \param lhs The lhs message.
 * \param rhs The rhs message.
 * \return The merged message. If either input is undefined, that
 *         undefined input is returned unchanged.
 */
Message Intersect(const Message& lhs, const Message& rhs) {
  if (!lhs.defined()) return lhs;
  if (!rhs.defined()) return rhs;
  const bool need_positive = lhs->require_positive || rhs->require_positive;
  return MessageNode::make(Intersect(lhs->axes, rhs->axes), need_positive);
}

/*!
156
 * \brief Preparation function for pass scale forward.
157
 * \param call The call node.
158 159 160
 * \param out_message Message from the output containing possible scaling on axes and whether
 *        positive scale is required.
 * \return The message containing the result scaling on axes of the input.
161 162
 */
using FForwardPrep = runtime::TypedPackedFunc<
163
  Array<Message> (const Call& call, const Message& out_message)>;
164 165

/*! \brief Axis scale tuple.  */
166
class ScaledExprNode : public TempExprNode {
167 168 169 170 171 172 173 174
 public:
  /*! \brief The value */
  Expr value;
  /*! \brief The axes to scale, can be nullptr(means no-scaling) */
  AxesSet axes = NullValue<AxesSet>();
  /*! \brief The scaling factor */
  Expr scale = NullValue<Expr>();

175 176 177 178 179 180
  Expr Realize() const final {
    CHECK(!axes.defined())
        << "outstanding scale";
    return value;
  }

181 182 183 184 185 186
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("value", &value);
    v->Visit("axes", &axes);
    v->Visit("scale", &scale);
  }

187 188
  static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr";
  TVM_DECLARE_NODE_TYPE_INFO(ScaledExprNode, TempExprNode);
189 190
};

191 192 193
/*!
 * \brief Signature of the per-operator rewrite function used by forward folding.
 * \param ref_call The original call node.
 * \param new_args The already rewritten arguments of the call.
 * \param message The message attached to the call; it may be undefined.
 * \return The rewritten expression.
 */
using FForwardRewrite = TypedPackedFunc<
  Expr(const Call& ref_call, const Array<Expr>& new_args, const Message& message)>;
195 196 197 198

//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
199
class ForwardPrep : private ExprVisitor {
200
 public:
201
  std::unordered_map<const Node*, Message>
202
  Prepare(const Expr& body) {
203
    this->Update(body, NullValue<Message>());
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
    this->VisitExpr(body);
    // flist is added in the Post-DFS order
    // which is a special case of topological order.
    // We reversely traverse the list to invoke the lazy functions.
    // This act like a backprop of valid scale axis messages
    for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
      (*it)();
    }
    // return the created message;
    return std::move(message_);
  }

 private:
  // The invoke list
  std::vector<std::function<void()> > flist_;
  // The message on each node.
220
  std::unordered_map<const Node*, Message> message_;
221
  // Update the message stored at node.
222
  void Update(const Expr& node, const Message& message) {
223 224 225 226 227 228 229 230 231 232 233 234
    // We run intersection of messages:
    //
    // %y = multiply(%x, %scale)
    // %z1 = conv2d(%y, %w)
    // %z2 = exp(%y)
    //
    // Consider the above code example,
    // because %z2 will propagate null to %y,
    // the AxesSet on %y is also null,
    // and the forward folding won't be triggered.
    const Node* key = node.get();
    if (message_.count(key)) {
235
      message_[key] = Intersect(message_[key], message);
236
    } else {
237
      message_[key] = message;
238 239 240 241 242 243 244 245 246 247
    }
  }
  // Visitor pattern override.
  void VisitExpr_(const LetNode* call) {
    LOG(FATAL) << "FoldScaleAxis only accept dataflow-form";
  }

  void VisitExpr_(const FunctionNode* op) {
    ExprVisitor::VisitExpr_(op);
    auto flazy = [this, op] {
248
      this->Update(op->body, NullValue<Message>());
249 250 251 252 253 254 255 256 257 258 259 260
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    // function to be lazily invoked
    auto flazy = [this, call]() {
      static const auto& fprep =
        Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
      // find the message send to this node.
      auto it = message_.find(call);
261
      Message out_message;
262
      if (it != message_.end()) {
263
        out_message = it->second;
264
      } else {
265
        out_message = NullValue<Message>();
266 267
      }
      // pass the message back to all the children it references.
268
      auto f = fprep.get(call->op, nullptr);
269
      if (f != nullptr) {
270 271
        Array<Message> in_messages = f(GetRef<Call>(call), out_message);
        CHECK_EQ(in_messages.size(), call->args.size());
272
        for (size_t i = 0; i < call->args.size(); ++i) {
273
          this->Update(call->args[i], in_messages[i]);
274 275 276
        }
      } else {
        for (size_t i = 0; i < call->args.size(); ++i) {
277
          this->Update(call->args[i], NullValue<Message>());
278 279 280 281 282 283 284 285 286 287 288
        }
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const TupleNode* op) {
    ExprVisitor::VisitExpr_(op);
    // do not support pass scale through tuple for now.
    auto flazy = [this, op]() {
      for (const Expr& field : op->fields) {
289
        this->Update(field, NullValue<Message>());
290 291 292 293 294 295 296 297
      }
    };
    flist_.push_back(flazy);
  }

  void VisitExpr_(const IfNode* op) {
    ExprVisitor::VisitExpr_(op);
    // do pass through condition
298
    // by assigning NullValue<Message>
299 300 301
    // it means fuse signal cannot pass
    // through into these subexpressions.
    auto flazy = [this, op]() {
302 303 304
      this->Update(op->cond, NullValue<Message>());
      this->Update(op->true_branch, NullValue<Message>());
      this->Update(op->false_branch, NullValue<Message>());
305 306 307 308 309 310 311 312 313 314
    };
    flist_.push_back(flazy);
  }
};

//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------

// Intermediate operators
315 316 317 318 319
/*!
 * \brief Forward prepare rule for relu-like operators.
 *
 *  A defined output message is passed to the input with the same axes but
 *  with require_positive forced to true; an undefined one is passed as is.
 *
 * \param call The call node.
 * \param out_message The message attached to the output.
 * \return A single-element array holding the message for the input.
 */
Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
  if (!out_message.defined()) {
    return {out_message};
  }
  return {MessageNode::make(out_message->axes, true)};
}

322 323
/*!
 * \brief Forward rewrite rule for relu-like operators.
 *
 *  When the input is a scaled expression, the operator is applied to the
 *  unscaled value and the scale/axes pair is carried over to the result.
 *
 * \param ref_call The original call.
 * \param new_args The rewritten arguments.
 * \param message The message attached to the call (unused here).
 * \return The scaled result, or a null Expr when the input carries no scale.
 */
Expr ReluForwardRewrite(const Call& ref_call,
                        const Array<Expr>& new_args,
                        const Message& message) {
  const auto* scaled_input = new_args[0].as<ScaledExprNode>();
  if (scaled_input == nullptr) {
    return Expr(nullptr);
  }
  // Re-create the call on the unscaled value and keep the scale attached.
  auto result = make_node<ScaledExprNode>();
  result->value = CallNode::make(
      ref_call->op, {scaled_input->value}, ref_call->attrs, ref_call->type_args);
  result->scale = scaled_input->scale;
  result->axes = scaled_input->axes;
  return Expr(result);
}

// Forward folding rules shared by relu and leaky_relu.
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);
347 348

// AddSub
349
/*!
 * \brief Forward prepare rule for add/subtract.
 *
 *  The output message is forwarded to the operand onto which the other
 *  operand broadcasts along the requested axes; the other operand gets an
 *  undefined message. If neither direction matches, both get undefined.
 *
 * \param call The call node.
 * \param out_message The message attached to the output.
 * \return Two messages, for the lhs and rhs arguments respectively.
 */
Array<Message> AddSubForwardPrep(const Call& call, const Message& out_message) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  const Message none = NullValue<Message>();
  if (!out_message.defined()) {
    return {none, none};
  }
  if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) {
    return {out_message, none};
  }
  if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) {
    return {none, out_message};
  }
  return {none, none};
}

363 364
/*!
 * \brief Forward rewrite rule for add/subtract.
 *
 *  Exactly one operand may carry a scale. The other operand is divided by
 *  that scale (expanded to the scaled operand's rank), the operator is
 *  re-applied, and the scale stays attached to the result.
 *
 * \param ref_call The original call.
 * \param new_args The rewritten arguments.
 * \param message The message attached to the call (unused here).
 * \return The scaled result, or a null Expr when no operand carries a scale.
 */
Expr AddSubForwardRewrite(const Call& ref_call,
                          const Array<Expr>& new_args,
                          const Message& message) {
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  if (slhs == nullptr && srhs == nullptr) {
    return Expr();
  }
  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  auto rnode = make_node<ScaledExprNode>();

  if (slhs == nullptr) {
    // Only the rhs is scaled: compensate on the lhs.
    CHECK(srhs != nullptr);
    CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
    Expr lhs = Divide(
        new_args[0],
        ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes));
    rnode->value = CallNode::make(
        ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
    rnode->scale = srhs->scale;
    rnode->axes = srhs->axes;
  } else {
    // Only the lhs is scaled: compensate on the rhs.
    CHECK(srhs == nullptr);
    CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
    Expr rhs = Divide(
        new_args[1],
        ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes));
    rnode->value = CallNode::make(
        ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
    rnode->scale = slhs->scale;
    rnode->axes = slhs->axes;
  }
  return Expr(rnode);
}

// Forward folding rules shared by add and subtract.
RELAY_REGISTER_OP("add")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);

RELAY_REGISTER_OP("subtract")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);
408 409 410

// Producer operators
// Multiply produces the scale-axis pair.
411 412
/*!
 * \brief Forward rewrite rule for multiply, which produces the scale-axis pair.
 *
 *  If one operand broadcasts onto the other along the axes requested by the
 *  message (and is a positive constant when the message requires it), that
 *  operand becomes the scale and the other one becomes the value.
 *
 * \param ref_call The original call.
 * \param new_args The rewritten arguments; neither may already carry a scale.
 * \param message The message attached to the call.
 * \return The scaled expression, or a null Expr when no folding is possible.
 */
Expr MultiplyForwardRewrite(const Call& ref_call,
                            const Array<Expr>& new_args,
                            const Message& message) {
  if (!message.defined()) return Expr();
  const auto& expected_out_axes = message->axes;
  CHECK(expected_out_axes.defined() && expected_out_axes.size());
  // TODO(tvm-team) allow same axes accumulation
  // not as important because it is less common in nn.
  const auto* slhs = new_args[0].as<ScaledExprNode>();
  const auto* srhs = new_args[1].as<ScaledExprNode>();
  CHECK(!slhs && !srhs);

  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
  auto rnode = make_node<ScaledExprNode>();

  // Each attempt works on its own copy of the candidate scale: the matcher
  // receives it as an out-parameter and may rewrite it, and a rewritten
  // operand from a rejected attempt must not end up as the value below.
  // NOTE(review): the exact rewrite done by MatchBroadcastToLeftAxes is
  // defined in pattern_util.h and is not visible from here.
  Expr rhs_scale = new_args[1];
  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs_scale) &&
      (!message->require_positive || IsAllPositiveConstant(rhs_scale))) {
    rnode->value = new_args[0];
    rnode->scale = rhs_scale;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  }

  Expr lhs_scale = new_args[0];
  if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs_scale) &&
      (!message->require_positive || IsAllPositiveConstant(lhs_scale))) {
    rnode->value = new_args[1];
    rnode->scale = lhs_scale;
    rnode->axes = expected_out_axes;
    return Expr(rnode);
  }
  return Expr();
}

// multiply only has a forward rewrite rule; no forward prepare rule is registered here.
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>(
    "FScaleAxisForwardRewrite", MultiplyForwardRewrite);
448 449 450

// Consumer operators
// Conv2D send out requirement of axis folding.
451
/*!
 * \brief Forward prepare rule for conv2d, which sends out the axis folding request.
 *
 *  The data input is asked for a scale on its 'C' axis when neither layout
 *  uses a split channel axis ('c' in data, 'i' in kernel) and the
 *  convolution is either ungrouped or depthwise. The weight never receives
 *  a request.
 *
 * \param call The conv2d call.
 * \param out_message The message attached to the output (unused here).
 * \return Two messages, for the data and weight arguments respectively.
 */
Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
  // TODO(tvm-team) support general data layout
  // by transforming weight
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout kernel_layout(param->kernel_layout);
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));

  CHECK_GE(c_big_axis, 0);
  const Message none = NullValue<Message>();
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
  const bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
  const bool can_fold =
      kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
      c_small_axis < 0 &&
      (param->groups == 1 || is_depthwise_conv2d);
  if (!can_fold) {
    return {none, none};
  }
  const AxesSet data_axes = {c_big_axis};
  return {MessageNode::make(data_axes, false), none};
}

// Conv2D consumes the scale axis during transformation.
484 485
/*!
 * \brief Forward rewrite rule for conv2d, which consumes the scale on its data input.
 *
 *  The scale attached to the data is multiplied into the weight along the
 *  kernel's 'O' axis for depthwise conv2d and along its 'I' axis otherwise,
 *  and the call is rebuilt on the unscaled data.
 *
 * \param ref_call The original conv2d call.
 * \param new_args The rewritten arguments.
 * \param message The message attached to the call (unused here).
 * \return The rewritten conv2d, or a null Expr when the data carries no scale
 *         or the weight is itself a scaled expression.
 */
Expr Conv2DForwardRewrite(const Call& ref_call,
                          const Array<Expr>& new_args,
                          const Message& message) {
  // if data do not have scale, normal transform path.
  const auto* sdata = new_args[0].as<ScaledExprNode>();
  const auto* sweight = new_args[1].as<ScaledExprNode>();
  if (sdata == nullptr || sweight != nullptr) {
    return Expr();
  }
  const auto* param = ref_call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
  Layout data_layout(param->data_layout);
  Layout kernel_layout(param->kernel_layout);
  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
  CHECK(sdata->axes.size() == 1 &&
        c_big_axis == sdata->axes[0]->value);
  const int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
  const int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));

  // Check it must be depthwise or full conv2d.
  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
  CHECK(param->groups == 1 || is_depthwise_conv2d);

  // Pick the kernel axis that lines up with the scaled data channel.
  const int kernel_scale_axis = is_depthwise_conv2d ? big_oc_axis : big_ic_axis;
  Expr expanded_scale = ExpandBiasToMatchAxis(
      sdata->scale, kernel_layout.ndim(), {kernel_scale_axis});
  Expr weight = Multiply(new_args[1], expanded_scale);

  // return transformed conv2d
  return CallNode::make(
      ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

// Forward folding rules for conv2d.
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep)
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
532 533


534
/*!
 * \brief Run forward scale-axis folding on an expression.
 *
 *  Runs the prepare phase, then invokes ForwardRewrite with the
 *  "FScaleAxisForwardRewrite" rules, supplying each call with the message
 *  computed for it (or a null reference when there is none).
 *
 * \param data The expression to transform.
 * \return The transformed expression.
 */
Expr ForwardFoldScaleAxis(const Expr& data) {
  auto messages = ForwardPrep().Prepare(data);
  auto fcontext = [&messages](const Call& call) -> NodeRef {
    const auto found = messages.find(call.get());
    if (found == messages.end()) {
      return NodeRef(nullptr);
    }
    return found->second;
  };
  return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
}

548 549 550 551 552 553 554 555
//----------------------------------------
// Implement backward transformations.
//----------------------------------------
class BackwardTransformer;

/*!
 * \brief Signature of the per-operator preparation function used by backward folding.
 * \param call The call node.
 * \param in_messages Messages from the inputs containing allowed input scaling and whether
 *        positive scale is required.
 * \return Message describing the scaling that the call itself can accept.
 */
using FBackwardPrep = TypedPackedFunc<
  Message(const Call& call, const Array<Message>& in_messages)>;

/*!
 * \brief Signature of the per-operator transform function used by backward folding.
 * \param call The call node to transform.
 * \param message The message pushed down onto the call; it may be undefined.
 * \param scale The scale pushed down onto the call.
 * \param transformer The transformer driving the mutation.
 * \return The transformed expression.
 */
using FBackwardTransform = TypedPackedFunc<
  Expr(const Call& call,
       const Message& message,
       const Expr& scale,
       const BackwardTransformer& transformer)>;

//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------

class BackwardPrep : private ExprVisitor {
 public:
  // The message on each node.
576
  std::unordered_map<const Node*, Message>
577 578 579 580 581 582 583 584
  Prepare(const Expr& body) {
    ref_counter_ = GetExprRefCount(body);
    this->VisitExpr(body);
    return std::move(message_);
  }

 private:
  // The message on each node.
585
  std::unordered_map<const Node*, Message> message_;
586 587 588 589 590 591 592
  // reference counter of an internal expr
  std::unordered_map<const Node*, size_t> ref_counter_;
  // Visit the expression.
  void VisitExpr_(const CallNode* call) {
    ExprVisitor::VisitExpr_(call);
    static const auto& fprep =
        Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
593
    auto f = fprep.get(call->op, nullptr);
594 595 596 597 598 599
    if (f == nullptr) return;
    auto rit = ref_counter_.find(call);
    CHECK(rit != ref_counter_.end());
    // We only allow propagation of scale backward
    // if the expression is only referred by a single parent.
    if (rit->second != 1) return;
600
    Array<Message> in_messages;
601 602 603
    for (Expr arg : call->args) {
      auto it = message_.find(arg.get());
      if (it != message_.end()) {
604
        in_messages.push_back(it->second);
605
      } else {
606
        in_messages.push_back(NullValue<Message>());
607 608
      }
    }
609 610 611
    Message out_message = f(GetRef<Call>(call), in_messages);
    if (out_message.defined()) {
      message_[call] = out_message;
612 613 614 615 616 617 618 619 620 621
    }
  }
};

class BackwardTransformerNode :
      public Node,
      private ExprMutator {
 public:
  // Run forward transform.
  Expr Fold(Expr expr) {
622
    message_ = BackwardPrep().Prepare(expr);
623 624 625 626 627 628 629 630 631 632
    return this->Mutate(expr);
  }
  /*!
   * \brief Transform the expr to consider the scaling.
   *
   * \param expr The input expression.
   * \param axes The axes to scale.
   * \param scale The scale applied to the axes.
   * \return The result of transformation.
   */
633
  Expr Transform(const Expr& expr, Message message, Expr scale) {
634
    // NOTE: the result of Transform is memoized.
635
    if (const CallNode* call_node = expr.as<CallNode>()) {
636
      return Transform(call_node, message, scale);
637
    } else {
638
      CHECK(!message.defined()) << "outstanding scale";
639 640 641 642 643 644 645 646 647
      return ExprMutator::VisitExpr(expr);
    }
  }
  /*!
   * \brief Normal way of mutating call node.
   * \param call_node The call node to be mutated.
   * \return the result of the call Mutation.
   */
  Expr NormalCallTransform(const CallNode* call_node) {
648 649 650 651 652 653 654 655
    const Call call = GetRef<Call>(call_node);
    const auto it = memo_.find(call);
    if (it != memo_.end()) {
      return it->second;
    }
    Expr new_expr = ExprMutator::VisitExpr_(call_node);
    memo_[call] = new_expr;
    return new_expr;
656 657
  }
  /*!
658
   * \brief Get the message propogated to the expr.
659
   * \param expr The expresison.
660
   * \return The message containing the expected axes and whether positive scale is required.
661
   */
662 663 664 665
  Message GetMessage(const Expr& expr) const {
    auto it = message_.find(expr.get());
    if (it != message_.end()) return it->second;
    return NullValue<Message>();
666 667 668 669 670 671 672 673 674 675
  }

  // solver is not serializable.
  void VisitAttrs(tvm::AttrVisitor* v) final {}

  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
  TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);

 private:
  // Valid axes on each node.
676
  std::unordered_map<const Node*, Message> message_;
677 678
  // Override mutation of call.
  Expr VisitExpr_(const CallNode* call_node) final {
679
    return Transform(call_node, NullValue<Message>(), NullValue<Expr>());
680 681
  }
  // Transform of CallNode.
682
  Expr Transform(const CallNode* call_node, Message message, Expr scale);
683 684 685 686 687 688 689 690 691 692 693 694 695 696 697
};

/*!
 * \brief Reference class for BackwardTransformerNode.
 */
class BackwardTransformer : public NodeRef {
 public:
  using ContainerType = BackwardTransformerNode;

  /*! \brief Construct a reference through the default NodeRef constructor. */
  BackwardTransformer() {}
  /*! \brief Construct a reference from an existing node pointer. */
  explicit BackwardTransformer(::tvm::NodePtr<::tvm::Node> n)
      : NodeRef(n) {}
  /*! \brief Access the underlying node with mutable access. */
  BackwardTransformerNode* operator->() const {
    return static_cast<BackwardTransformerNode*>(node_.get());
  }
};

// Transform a call node under the given message/scale.
//
// Calls whose operator has an "FScaleAxisBackwardTransform" rule are handed
// to that rule and the result is memoized per call. All other calls must not
// receive a message and go through the normal mutation path.
Expr BackwardTransformerNode::Transform(
    const CallNode* call_node, Message message, Expr scale) {
  static const auto& ftransform =
      Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
  auto f = ftransform.get(call_node->op, nullptr);
  if (f == nullptr) {
    CHECK(!message.defined()) << "outstanding scale";
    return NormalCallTransform(call_node);
  }

  const Call call = GetRef<Call>(call_node);
  const auto cached = memo_.find(call);
  if (cached != memo_.end()) {
    return cached->second;
  }
  Expr transformed = f(call, message, scale, GetRef<BackwardTransformer>(this));
  memo_[call] = transformed;
  return transformed;
}


//----------------------------------------------
// Per operator defs for FScaleAxisBackward
//----------------------------------------------

// Intermediate operators
726 727 728 729 730
/*!
 * \brief Backward prepare rule for relu-like operators.
 *
 *  A defined input message is passed on with the same axes but with
 *  require_positive forced to true; an undefined one is passed on as is.
 *
 * \param call The call node.
 * \param in_messages The messages of the arguments; only the first is used.
 * \return The message for the call.
 */
Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  const Message input_message = in_messages[0];
  if (!input_message.defined()) {
    return input_message;
  }
  return MessageNode::make(input_message->axes, true);
}

/*!
 * \brief Backward transform rule for relu-like operators.
 *
 *  With a defined message, the message and scale are pushed down onto the
 *  single input and the operator is re-applied on the transformed input.
 *  Without one, the call is mutated the normal way.
 *
 * \param call The call to transform.
 * \param message The message pushed down onto the call.
 * \param scale The scale pushed down onto the call.
 * \param transformer The transformer driving the mutation.
 * \return The transformed call.
 */
Expr ReluBackwardTransform(const Call& call,
                           const Message& message,
                           const Expr& scale,
                           const BackwardTransformer& transformer) {
  if (message.defined()) {
    Expr new_input = transformer->Transform(call->args[0], message, scale);
    return CallNode::make(call->op, {new_input}, call->attrs, call->type_args);
  }
  return transformer->NormalCallTransform(call.operator->());
}

// Backward folding rules shared by relu and leaky_relu.
RELAY_REGISTER_OP("nn.relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

// AddSub
758
/*!
 * \brief Backward prepare rule for add/subtract.
 *
 *  Tried in order:
 *   - the lhs message, if the rhs broadcasts onto the lhs along its axes;
 *   - the rhs message, if the lhs broadcasts onto the rhs along its axes;
 *   - the lhs message, if both messages have equal axes and both operands
 *     have equal shapes;
 *   - otherwise an undefined message.
 *
 * \param call The call node.
 * \param in_messages The messages of the lhs and rhs arguments.
 * \return The message for the call.
 */
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  AttrsEqual equal;

  const bool has_lhs = in_messages[0].defined();
  const bool has_rhs = in_messages[1].defined();

  if (has_lhs && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
    return in_messages[0];
  }
  if (has_rhs && MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) {
    return in_messages[1];
  }
  if (has_lhs && has_rhs &&
      equal(in_messages[0]->axes, in_messages[1]->axes) &&
      equal(tlhs->shape, trhs->shape)) {
    // add of two elements.
    return in_messages[0];
  }
  return NullValue<Message>();
}

/*!
 * \brief Backward transform rule for add/subtract.
 *
 *  With a defined message, the scale is pushed down into every operand that
 *  has a message of its own; an operand without one is transformed normally
 *  and then multiplied by the expanded scale. Without a message the call is
 *  mutated the normal way.
 *
 * \param call The call to transform.
 * \param message The message pushed down onto the call.
 * \param scale The scale pushed down onto the call.
 * \param transformer The transformer driving the mutation.
 * \return The transformed call.
 */
Expr AddSubBackwardTransform(const Call& call,
                             const Message& message,
                             const Expr& scale,
                             const BackwardTransformer& transformer) {
  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }
  Message lhs_message = transformer->GetMessage(call->args[0]);
  Message rhs_message = transformer->GetMessage(call->args[1]);
  AttrsEqual equal;

  // Re-create the call on a new pair of operands.
  auto rebuild = [&call](const Expr& new_lhs, const Expr& new_rhs) {
    return CallNode::make(call->op, {new_lhs, new_rhs}, call->attrs, call->type_args);
  };

  if (lhs_message.defined() && rhs_message.defined()) {
    // Both operands absorb the scale.
    CHECK(equal(lhs_message->axes, rhs_message->axes));
    CHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    return rebuild(lhs, rhs);
  }
  if (lhs_message.defined()) {
    // Only the lhs absorbs the scale; the rhs is multiplied explicitly.
    CHECK(equal(message->axes, lhs_message->axes));
    Expr lhs = transformer->Transform(call->args[0], message, scale);
    Expr rhs = transformer->Transform(
        call->args[1], NullValue<Message>(), NullValue<Expr>());
    rhs = Multiply(
        rhs, ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes));
    return rebuild(lhs, rhs);
  }
  if (rhs_message.defined()) {
    // Only the rhs absorbs the scale; the lhs is multiplied explicitly.
    CHECK(equal(message->axes, rhs_message->axes));
    Expr lhs = transformer->Transform(
        call->args[0], NullValue<Message>(), NullValue<Expr>());
    Expr rhs = transformer->Transform(call->args[1], message, scale);
    lhs = Multiply(
        lhs, ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes));
    return rebuild(lhs, rhs);
  }
  LOG(FATAL) << "outstanding scale";
  return Expr();
}

// Backward folding rules shared by add and subtract.
RELAY_REGISTER_OP("add")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

RELAY_REGISTER_OP("subtract")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

// Producer operators
//
// Multiply is where a (scale, axes) pair originates during backward folding.
//
// The call itself must not carry an incoming message. If the left operand has a
// message, the right operand is treated as the scale candidate; otherwise, if the
// right operand has a message, the left operand is the candidate. The candidate is
// accepted only when MatchBroadcastToLeftAxes succeeds for the message axes and,
// when the message requires a positive scale, IsAllPositiveConstant holds for it.
// On acceptance the multiply is replaced by the transformed data operand, with the
// candidate handed down as its scale. In every other case the call goes through
// NormalCallTransform.
//
// \param call The multiply call being visited.
// \param message Message attached to this call; must be undefined.
// \param scale Not read by this function.
// \param transformer The backward transformer driving the rewrite.
// \return The rewritten expression.
Expr MultiplyBackwardTransform(const Call& call,
                               const Message& message,
                               const Expr& scale,
                               const BackwardTransformer& transformer) {
  CHECK(!message.defined()) << "outstanding scale";

  const Expr lhs_operand = call->args[0];
  const Expr rhs_operand = call->args[1];
  const auto* tlhs = lhs_operand->type_as<TensorTypeNode>();
  const auto* trhs = rhs_operand->type_as<TensorTypeNode>();
  Message lhs_message = transformer->GetMessage(lhs_operand);
  Message rhs_message = transformer->GetMessage(rhs_operand);

  // Shared fallback: leave the multiply in place and transform it as an ordinary call.
  auto keep_multiply = [&]() {
    return transformer->NormalCallTransform(call.operator->());
  };

  if (lhs_message.defined()) {
    CHECK(lhs_message->axes.defined() && lhs_message->axes.size());
    // NOTE: the scale operand is deliberately not transformed recursively;
    // no further scale folding is expected inside the scale itself.
    Expr scale_candidate = rhs_operand;
    if (!MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &scale_candidate)) {
      return keep_multiply();
    }
    if (lhs_message->require_positive && !IsAllPositiveConstant(scale_candidate)) {
      return keep_multiply();
    }
    return transformer->Transform(lhs_operand, lhs_message, scale_candidate);
  }

  if (rhs_message.defined()) {
    CHECK(rhs_message->axes.defined() && rhs_message->axes.size());
    // Mirror image of the case above: the left operand is the scale candidate.
    Expr scale_candidate = lhs_operand;
    if (!MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &scale_candidate)) {
      return keep_multiply();
    }
    if (rhs_message->require_positive && !IsAllPositiveConstant(scale_candidate)) {
      return keep_multiply();
    }
    return transformer->Transform(rhs_operand, rhs_message, scale_candidate);
  }

  return keep_multiply();
}

// Attach the multiply handler as the backward scale-axis transform of "multiply".
RELAY_REGISTER_OP("multiply").set_attr<FBackwardTransform>(
    "FScaleAxisBackwardTransform", MultiplyBackwardTransform);

// Consumer operators
// Conv2D send out requirement of axis folding.
871
Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
872 873
  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);
874 875
  Layout kernel_layout(param->kernel_layout);
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
876 877
  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
878 879 880 881 882 883 884 885 886

  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // More general layout can be supported under the current framework.
  // By using a unified layout transformation.
  // We only need to change the Prep and Mutate function.
  //
  // only handle depthwise or full conv2d.
  // TODO(tvm-team) handle grouped conv by reshape + bcast
887
  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
888 889
  if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
  kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
890 891
      c_small_axis < 0 &&
      (param->groups == 1 || is_depthwise_conv2d)) {
892
    return MessageNode::make({c_big_axis}, false);
893
  } else {
894
    return NullValue<Message>();
895 896 897 898 899
  }
}

// Conv2D absorbs the pending scale by multiplying it into its weight.
//
// Without a message the call goes through NormalCallTransform. With a message,
// the function checks that the layouts match the simple pattern (no 'o'/'i'
// kernel sub-axes), that the message names exactly the output 'C' axis, and that
// the convolution is either full (groups == 1) or depthwise. Both arguments are
// then transformed with no message, the scale is expanded along the kernel's 'O'
// axis, and the weight is multiplied by it.
//
// \param call The conv2d call.
// \param message Message attached to this call; may be undefined.
// \param scale The scale expression to fold into the weight.
// \param transformer The backward transformer driving the rewrite.
// \return The rewritten conv2d call.
Expr Conv2DBackwardTransform(const Call& call,
                             const Message& message,
                             const Expr& scale,
                             const BackwardTransformer& transformer) {
  if (!message.defined()) {
    return transformer->NormalCallTransform(call.operator->());
  }

  const auto* param = call->attrs.as<Conv2DAttrs>();
  CHECK(param != nullptr);

  Layout kernel_layout(param->kernel_layout);
  // An empty out_layout means the output uses the data layout.
  Layout out_layout(param->out_layout.empty() ? param->data_layout : param->out_layout);

  const int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
  CHECK_GE(c_big_axis, 0);
  // For now, we only support simple pattern (no folded weight/data)
  // TODO(tvm-team) support general data layout
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
  CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value);

  const int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));

  // The convolution must be depthwise or full conv2d.
  const bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
  CHECK(param->groups == 1 || is_depthwise_conv2d);

  // Neither argument receives a scale of its own: it is consumed right here.
  const Message no_message = NullValue<Message>();
  const Expr no_scale = NullValue<Expr>();
  Expr new_data = transformer->Transform(call->args[0], no_message, no_scale);
  Expr new_weight = transformer->Transform(call->args[1], no_message, no_scale);

  // Expand the scale so that it lines up with the kernel's 'O' axis, then fold it in.
  Expr weight_scale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis});
  Expr scaled_weight = Multiply(new_weight, weight_scale);

  return CallNode::make(call->op, {new_data, scaled_weight}, call->attrs, call->type_args);
}

// Backward scale-axis folding hooks for "nn.conv2d": prep and transform callbacks
// attached in one chained registration.
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep)
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);

942
/*!
 * \brief Run backward scale-axis folding on an expression.
 * \param data The expression to fold.
 * \return The result of Fold on a freshly created BackwardTransformerNode.
 */
Expr BackwardFoldScaleAxis(const Expr& data) {
  auto backward_transformer = make_node<BackwardTransformerNode>();
  return backward_transformer->Fold(data);
}

946
}  // namespace fold_scale_axis
947 948 949 950 951 952 953 954 955 956 957 958 959

namespace transform {

/*!
 * \brief Create the function pass that applies forward scale-axis folding.
 *
 * NOTE(review): the literal 3 and the "InferType" entry are presumably the pass
 * optimization level and its required passes -- confirm against CreateFunctionPass.
 *
 * \return The created pass, named "ForwardFoldScaleAxis".
 */
Pass ForwardFoldScaleAxis() {
  using FoldCallback = runtime::TypedPackedFunc<Function(Function, Module, PassContext)>;
  // Only the function is used; the module and the pass context are ignored.
  FoldCallback pass_func = [=](Function func, Module mod, PassContext ctx) {
    Expr folded = relay::fold_scale_axis::ForwardFoldScaleAxis(func);
    return Downcast<Function>(folded);
  };
  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
                            {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis);

/*!
 * \brief Create the function pass that applies backward scale-axis folding.
 *
 * NOTE(review): the literal 3 and the "InferType" entry are presumably the pass
 * optimization level and its required passes -- confirm against CreateFunctionPass.
 *
 * \return The created pass, named "BackwardFoldScaleAxis".
 */
Pass BackwardFoldScaleAxis() {
  using FoldCallback = runtime::TypedPackedFunc<Function(Function, Module, PassContext)>;
  // Only the function is used; the module and the pass context are ignored.
  FoldCallback pass_func = [=](Function func, Module mod, PassContext ctx) {
    Expr folded = relay::fold_scale_axis::BackwardFoldScaleAxis(func);
    return Downcast<Function>(folded);
  };
  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
                            {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.BackwardFoldScaleAxis").set_body_typed(BackwardFoldScaleAxis);

/*!
 * \brief Create the combined scale-axis folding pass.
 *
 * The pass is a sequence of three passes, in this order: backward folding,
 * forward folding, and constant folding.
 *
 * \return The sequential pass, named "FoldScaleAxis".
 */
Pass FoldScaleAxis() {
  return Sequential(
      {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
      "FoldScaleAxis");
}

TVM_REGISTER_API("relay._transform.FoldScaleAxis").set_body_typed(FoldScaleAxis);

}  // namespace transform

990 991
}  // namespace relay
}  // namespace tvm