forward_rewrite.cc 6.36 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file forward_rewrite.cc
 * \brief Apply rewriting rules in a forward fashion.
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
Zhi committed
27
#include <tvm/relay/transform.h>
28
#include "pass_util.h"
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48

namespace tvm {
namespace relay {

// Realizer class that realizes the expression.
// Results are memoized in the inherited memo_ table, keyed by both the
// input expression and its realized form, so calling Realize repeatedly
// on the same expression does not redo the work.
class TempRealizer : private ExprMutator {
 public:
  /*!
   * \brief Realize an expression.
   * \param expr The expression to realize.
   * \return The result of visiting expr with this realizer.
   */
  Expr Realize(Expr expr) {
    return VisitExpr(expr);
  }

 private:
  // Replaces a TempExprNode by the result of its Realize() method and
  // dispatches every other node to the regular per-node visitors.
  Expr VisitExpr(const Expr& expr) final {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    }

    Expr res;
    if (const auto* temp = expr.as<TempExprNode>()) {
      res = temp->Realize();
    } else {
      res = ExprFunctor::VisitExpr(expr);
    }
    // Cache under the input so a later visit of the same expression is a
    // lookup instead of a second realization; previously only the result
    // was used as key, so inputs that changed were never cached.
    memo_[expr] = res;
    // A realized result is treated as already realized.
    memo_[res] = res;
    return res;
  }
};

class ForwardRewriter : private ExprMutator {
 public:
63
  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
64
                  std::function<ObjectRef(const Call&)> fcontext,
65
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
66
      : rewrite_map_(rewrite_map),
67
        fcontext_(fcontext),
68 69 70
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

  ForwardRewriter(const FForwardRewrite* rewrite_func,
71
                  std::function<ObjectRef(const Call&)> fcontext,
72 73 74 75 76
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_func_(rewrite_func),
        fcontext_(fcontext),
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

77 78 79

  // Transform expression.
  Expr Rewrite(Expr expr) {
80 81 82
    if (fmulti_ref_trigger_ != nullptr) {
      ref_counter_ = GetExprRefCount(expr);
    }
83 84 85 86 87
    return this->VisitExpr(expr);
  }

 private:
  // The rewrite rule.
88 89 90
  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
  const FForwardRewrite* rewrite_func_{nullptr};
  // The context.const
91
  std::function<ObjectRef(const Call&)> fcontext_{nullptr};
92 93 94
  // The multiple reference trigger
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
  // Internal ref counter
95
  std::unordered_map<const Object*, size_t> ref_counter_;
96 97 98 99 100 101 102 103 104 105
  // internal realizer
  TempRealizer realizer_;

  Expr VisitExpr(const Expr& expr) final {
    // by default always realize.
    return realizer_.Realize(ExprMutator::VisitExpr(expr));
  }

  // Visit and allow non-realized version.
  Expr GetTempExpr(const Expr& expr)  {
106 107 108 109 110 111 112 113 114 115 116
    if (fmulti_ref_trigger_ != nullptr) {
      Expr ret = ExprMutator::VisitExpr(expr);
      auto it = ref_counter_.find(expr.get());
      CHECK(it != ref_counter_.end());
      if (it->second > 1) {
        ret = fmulti_ref_trigger_(ret);
      }
      return ret;
    } else {
      return ExprMutator::VisitExpr(expr);
    }
117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
  }

  // Automatic fold TupleGetItem.
  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr tuple = this->GetTempExpr(op->tuple);
    if (const auto* ptuple = tuple.as<TupleNode>()) {
      return ptuple->fields[op->index];
    } else {
      if (tuple.same_as(op->tuple)) {
        return GetRef<Expr>(op);
      } else {
        return TupleGetItemNode::make(tuple, op->index);
      }
    }
  }

133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
  Expr VisitExpr_(const TupleNode* op) final {
    tvm::Array<Expr> fields;
    bool all_fields_unchanged = true;
    for (auto field : op->fields) {
      auto new_field = this->GetTempExpr(field);
      fields.push_back(new_field);
      all_fields_unchanged &= new_field.same_as(field);
    }

    if (all_fields_unchanged) {
      return GetRef<Expr>(op);
    } else {
      return TupleNode::make(fields);
    }
  }

149 150
  Expr VisitExpr_(const CallNode* call_node) final {
    const Call& ref_call = GetRef<Call>(call_node);
151 152 153 154 155 156 157
    PackedFunc frewrite;
    if (rewrite_func_) {
      frewrite = *rewrite_func_;
    } else {
      CHECK(rewrite_map_);
      frewrite = rewrite_map_->get(call_node->op, nullptr);
    }
158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174

    auto new_op = this->Mutate(call_node->op);
    bool unchanged = call_node->op.same_as(new_op);

    Array<Expr> call_args;
    for (auto arg : call_node->args) {
      Expr new_arg = this->GetTempExpr(arg);
      if (frewrite == nullptr) {
        new_arg = realizer_.Realize(new_arg);
      }
      unchanged &= new_arg.same_as(arg);
      call_args.push_back(new_arg);
    }
    // try to rewrite.
    if (frewrite != nullptr) {
      Expr res = frewrite(
          ref_call, call_args,
175
          fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr));
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
      if (res.defined()) return res;
      // abort, use old rule
      for (size_t i = 0; i < call_args.size(); ++i) {
        Expr arg = call_args[i];
        Expr new_arg = realizer_.Realize(arg);
        if (!arg.same_as(new_arg)) {
          call_args.Set(i, new_arg);
          unchanged = false;
        }
      }
    }
    if (unchanged) return ref_call;
    return CallNode::make(
        new_op, call_args, call_node->attrs, call_node->type_args);
  }
};

// Forward rewrite using the rules registered under the operator
// attribute named rewrite_map_name.
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<ObjectRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // The map must outlive the rewriter, which only keeps a pointer to it.
  auto rule_table = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
  ForwardRewriter rewriter(&rule_table, fcontext, fmulti_ref_trigger);
  return rewriter.Rewrite(expr);
}

// Forward rewrite using a single rule, rewrite_func, for every call.
Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<ObjectRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // The rewriter keeps only a pointer to the caller-owned rule.
  ForwardRewriter rewriter(&rewrite_func, fcontext, fmulti_ref_trigger);
  return rewriter.Rewrite(expr);
}
207

208 209
}  // namespace relay
}  // namespace tvm