/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file forward_rewrite.cc
 * \brief Apply rewriting rules in a forward fashion: each call is
 *        rewritten after its arguments have already been rewritten.
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

#include "pass_util.h"

namespace tvm {
namespace relay {

// Realizer class that turns TempExpr nodes back into ordinary expressions.
//
// Results are cached in the inherited ExprMutator memo. The cache is keyed
// both by the input expression and by the realized result. Calling Realize
// repeatedly on the same expression therefore returns the first result
// instead of invoking TempExprNode::Realize() again. Re-realizing an
// already-realized result is a cache hit as well.
class TempRealizer : private ExprMutator {
 public:
  /*!
   * \brief Realize an expression, replacing every TempExpr it contains.
   * \param expr The expression to realize.
   * \return The realized expression (cached across calls).
   */
  Expr Realize(Expr expr) {
    return VisitExpr(expr);
  }

 private:
  Expr VisitExpr(const Expr& expr) final {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;
    }
    Expr res;
    if (const auto* temp = expr.as<TempExprNode>()) {
      // A TempExpr materializes itself.
      res = temp->Realize();
    } else {
      // Rebuild the node, realizing any TempExpr among its children.
      res = ExprFunctor::VisitExpr(expr);
    }
    // Cache by the input so the same expression is realized only once,
    // and by the result so realizing it again is a no-op.
    memo_[expr] = res;
    memo_[res] = res;
    return res;
  }
};

class ForwardRewriter : private ExprMutator {
 public:
64
  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
65 66
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
67
      : rewrite_map_(rewrite_map),
68
        fcontext_(fcontext),
69 70 71 72 73 74 75 76 77
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

  ForwardRewriter(const FForwardRewrite* rewrite_func,
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_func_(rewrite_func),
        fcontext_(fcontext),
        fmulti_ref_trigger_(fmulti_ref_trigger) {}

78 79 80

  // Transform expression.
  Expr Rewrite(Expr expr) {
81 82 83
    if (fmulti_ref_trigger_ != nullptr) {
      ref_counter_ = GetExprRefCount(expr);
    }
84 85 86 87 88
    return this->VisitExpr(expr);
  }

 private:
  // The rewrite rule.
89 90 91
  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
  const FForwardRewrite* rewrite_func_{nullptr};
  // The context.const
92
  std::function<NodeRef(const Call&)> fcontext_{nullptr};
93 94 95 96
  // The multiple reference trigger
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
  // Internal ref counter
  std::unordered_map<const Node*, size_t> ref_counter_;
97 98 99 100 101 102 103 104 105 106
  // internal realizer
  TempRealizer realizer_;

  Expr VisitExpr(const Expr& expr) final {
    // by default always realize.
    return realizer_.Realize(ExprMutator::VisitExpr(expr));
  }

  // Visit and allow non-realized version.
  Expr GetTempExpr(const Expr& expr)  {
107 108 109 110 111 112 113 114 115 116 117
    if (fmulti_ref_trigger_ != nullptr) {
      Expr ret = ExprMutator::VisitExpr(expr);
      auto it = ref_counter_.find(expr.get());
      CHECK(it != ref_counter_.end());
      if (it->second > 1) {
        ret = fmulti_ref_trigger_(ret);
      }
      return ret;
    } else {
      return ExprMutator::VisitExpr(expr);
    }
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
  }

  // Automatic fold TupleGetItem.
  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr tuple = this->GetTempExpr(op->tuple);
    if (const auto* ptuple = tuple.as<TupleNode>()) {
      return ptuple->fields[op->index];
    } else {
      if (tuple.same_as(op->tuple)) {
        return GetRef<Expr>(op);
      } else {
        return TupleGetItemNode::make(tuple, op->index);
      }
    }
  }

134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
  Expr VisitExpr_(const TupleNode* op) final {
    tvm::Array<Expr> fields;
    bool all_fields_unchanged = true;
    for (auto field : op->fields) {
      auto new_field = this->GetTempExpr(field);
      fields.push_back(new_field);
      all_fields_unchanged &= new_field.same_as(field);
    }

    if (all_fields_unchanged) {
      return GetRef<Expr>(op);
    } else {
      return TupleNode::make(fields);
    }
  }

150 151
  Expr VisitExpr_(const CallNode* call_node) final {
    const Call& ref_call = GetRef<Call>(call_node);
152 153 154 155 156 157 158
    PackedFunc frewrite;
    if (rewrite_func_) {
      frewrite = *rewrite_func_;
    } else {
      CHECK(rewrite_map_);
      frewrite = rewrite_map_->get(call_node->op, nullptr);
    }
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195

    auto new_op = this->Mutate(call_node->op);
    bool unchanged = call_node->op.same_as(new_op);

    Array<Expr> call_args;
    for (auto arg : call_node->args) {
      Expr new_arg = this->GetTempExpr(arg);
      if (frewrite == nullptr) {
        new_arg = realizer_.Realize(new_arg);
      }
      unchanged &= new_arg.same_as(arg);
      call_args.push_back(new_arg);
    }
    // try to rewrite.
    if (frewrite != nullptr) {
      Expr res = frewrite(
          ref_call, call_args,
          fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
      if (res.defined()) return res;
      // abort, use old rule
      for (size_t i = 0; i < call_args.size(); ++i) {
        Expr arg = call_args[i];
        Expr new_arg = realizer_.Realize(arg);
        if (!arg.same_as(new_arg)) {
          call_args.Set(i, new_arg);
          unchanged = false;
        }
      }
    }
    if (unchanged) return ref_call;
    return CallNode::make(
        new_op, call_args, call_node->attrs, call_node->type_args);
  }
};

/*!
 * \brief Rewrite an expression using the operator attribute rewrite_map_name
 *        as the per-operator rule table.
 * \param expr The expression to rewrite.
 * \param rewrite_map_name Name of the FForwardRewrite operator attribute.
 * \param fcontext Optional callback building the context passed to a rule.
 * \param fmulti_ref_trigger Optional callback for multiply-referenced arguments.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const std::string& rewrite_map_name,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  // The attribute map is a local; the rewriter below only keeps a pointer
  // to it and is destroyed before this function returns.
  auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
  return ForwardRewriter(&rewrite_map, std::move(fcontext),
                         std::move(fmulti_ref_trigger))
      .Rewrite(expr);
}

/*!
 * \brief Rewrite an expression by applying rewrite_func to every call.
 * \param expr The expression to rewrite.
 * \param rewrite_func The rule applied to each call node.
 * \param fcontext Optional callback building the context passed to the rule.
 * \param fmulti_ref_trigger Optional callback for multiply-referenced arguments.
 * \return The rewritten expression.
 */
Expr ForwardRewrite(const Expr& expr,
                    const FForwardRewrite& rewrite_func,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  return ForwardRewriter(&rewrite_func, std::move(fcontext),
                         std::move(fmulti_ref_trigger))
      .Rewrite(expr);
}

}  // namespace relay
}  // namespace tvm