forward_rewrite.cc 5.6 KB
Newer Older
1 2 3 4 5 6 7 8 9
/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file forward_rewrite.cc
 * \brief Apply rewriting rules in a forward fashion.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
10
#include "pass_util.h"
11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44

namespace tvm {
namespace relay {

// Realizer class that realizes temporary expressions.
// Note that we can take advantage of its internal memo
// so that calling Realize repeatedly won't hurt performance.
class TempRealizer : private ExprMutator {
 public:
  /*!
   * \brief Realize an expression.
   * \param expr The expression to realize.
   * \return The result of visiting \p expr with this realizer.
   */
  Expr Realize(Expr expr) {
    return VisitExpr(expr);
  }

 private:
  /*!
   * \brief Memoized visit.
   *
   * An expression deriving from TempExprNode is replaced by the value of its
   * Realize() method; any other expression is dispatched through
   * ExprFunctor::VisitExpr.
   *
   * \param expr The expression to visit.
   * \return The memoized or freshly computed result.
   */
  Expr VisitExpr(const Expr& expr) final {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    }
    Expr res;
    if (const auto* temp = expr.as_derived<TempExprNode>()) {
      res = temp->Realize();
    } else {
      res = ExprFunctor::VisitExpr(expr);
    }
    // Key the memo by the visited input: the lookup above searches for
    // `expr`, so recording only `res` would miss whenever the result differs
    // from its input and the same input would be realized again on each call.
    memo_[expr] = res;
    // Also map the result to itself so that visiting an already-produced
    // result is answered from the memo without another traversal.
    memo_[res] = res;
    return res;
  }
};

class ForwardRewriter : private ExprMutator {
 public:
45
  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
46 47
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
48
      : rewrite_map_(rewrite_map),
49
        fcontext_(fcontext),
50 51 52 53 54 55 56 57 58
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

  ForwardRewriter(const FForwardRewrite* rewrite_func,
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_func_(rewrite_func),
        fcontext_(fcontext),
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

59 60 61

  // Transform expression.
  Expr Rewrite(Expr expr) {
62 63 64
    if (fmulti_ref_trigger_ != nullptr) {
      ref_counter_ = GetExprRefCount(expr);
    }
65 66 67 68 69
    return this->VisitExpr(expr);
  }

 private:
  // The rewrite rule.
70 71 72
  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
  const FForwardRewrite* rewrite_func_{nullptr};
  // The context.const
73
  std::function<NodeRef(const Call&)> fcontext_{nullptr};
74 75 76 77
  // The multiple reference trigger
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
  // Internal ref counter
  std::unordered_map<const Node*, size_t> ref_counter_;
78 79 80 81 82 83 84 85 86 87
  // internal realizer
  TempRealizer realizer_;

  Expr VisitExpr(const Expr& expr) final {
    // by default always realize.
    return realizer_.Realize(ExprMutator::VisitExpr(expr));
  }

  // Visit and allow non-realized version.
  Expr GetTempExpr(const Expr& expr)  {
88 89 90 91 92 93 94 95 96 97 98
    if (fmulti_ref_trigger_ != nullptr) {
      Expr ret = ExprMutator::VisitExpr(expr);
      auto it = ref_counter_.find(expr.get());
      CHECK(it != ref_counter_.end());
      if (it->second > 1) {
        ret = fmulti_ref_trigger_(ret);
      }
      return ret;
    } else {
      return ExprMutator::VisitExpr(expr);
    }
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
  }

  // Automatic fold TupleGetItem.
  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr tuple = this->GetTempExpr(op->tuple);
    if (const auto* ptuple = tuple.as<TupleNode>()) {
      return ptuple->fields[op->index];
    } else {
      if (tuple.same_as(op->tuple)) {
        return GetRef<Expr>(op);
      } else {
        return TupleGetItemNode::make(tuple, op->index);
      }
    }
  }

115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
  Expr VisitExpr_(const TupleNode* op) final {
    tvm::Array<Expr> fields;
    bool all_fields_unchanged = true;
    for (auto field : op->fields) {
      auto new_field = this->GetTempExpr(field);
      fields.push_back(new_field);
      all_fields_unchanged &= new_field.same_as(field);
    }

    if (all_fields_unchanged) {
      return GetRef<Expr>(op);
    } else {
      return TupleNode::make(fields);
    }
  }

131 132
  Expr VisitExpr_(const CallNode* call_node) final {
    const Call& ref_call = GetRef<Call>(call_node);
133 134 135 136 137 138 139
    PackedFunc frewrite;
    if (rewrite_func_) {
      frewrite = *rewrite_func_;
    } else {
      CHECK(rewrite_map_);
      frewrite = rewrite_map_->get(call_node->op, nullptr);
    }
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176

    auto new_op = this->Mutate(call_node->op);
    bool unchanged = call_node->op.same_as(new_op);

    Array<Expr> call_args;
    for (auto arg : call_node->args) {
      Expr new_arg = this->GetTempExpr(arg);
      if (frewrite == nullptr) {
        new_arg = realizer_.Realize(new_arg);
      }
      unchanged &= new_arg.same_as(arg);
      call_args.push_back(new_arg);
    }
    // try to rewrite.
    if (frewrite != nullptr) {
      Expr res = frewrite(
          ref_call, call_args,
          fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
      if (res.defined()) return res;
      // abort, use old rule
      for (size_t i = 0; i < call_args.size(); ++i) {
        Expr arg = call_args[i];
        Expr new_arg = realizer_.Realize(arg);
        if (!arg.same_as(new_arg)) {
          call_args.Set(i, new_arg);
          unchanged = false;
        }
      }
    }
    if (unchanged) return ref_call;
    return CallNode::make(
        new_op, call_args, call_node->attrs, call_node->type_args);
  }
};

/*!
 * \brief Apply forward rewriting using rules stored in an operator attribute map.
 * \param expr The expression to rewrite.
 * \param rewrite_map_name Name of the operator attribute holding the rewrite rules.
 * \param fcontext Optional callback producing the context passed to each rule.
 * \param fmulti_ref_trigger Optional callback forwarded to the rewriter for
 *        expressions that are referenced more than once.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // The map must outlive the rewriter, which only stores a pointer to it.
  auto rule_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
  ForwardRewriter rewriter(&rule_map, fcontext, fmulti_ref_trigger);
  return rewriter.Rewrite(expr);
}

/*!
 * \brief Apply forward rewriting using a single rule for every call.
 * \param expr The expression to rewrite.
 * \param rewrite_func The rewrite rule to apply.
 * \param fcontext Optional callback producing the context passed to the rule.
 * \param fmulti_ref_trigger Optional callback forwarded to the rewriter for
 *        expressions that are referenced more than once.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // The rewriter keeps only a pointer to the caller-owned rule.
  ForwardRewriter rewriter(&rewrite_func, fcontext, fmulti_ref_trigger);
  return rewriter.Rewrite(expr);
}
189 190


191 192
}  // namespace relay
}  // namespace tvm