/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file eliminate_common_subexpr.cc
 * \brief Combine common subexpressions.
 *
 * This is an optimization pass that eliminates common subexpressions. During the pass, it tries
 * to replace a call with a previously seen call to the same operator that has the same inputs
 * and attributes. Calls to stateful operators are never merged, and the fskip callback argument
 * lets callers skip specific expressions.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <vector>
#include "pattern_util.h"

namespace tvm {
namespace relay {

class CommonSubexprEliminator : public ExprMutator {
 public:
  explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}

  Expr VisitExpr_(const CallNode* call) final {
    static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
    Expr new_expr = ExprMutator::VisitExpr_(call);
    const CallNode* new_call = new_expr.as<CallNode>();
    CHECK(new_call);
    const OpNode* op = new_call->op.as<OpNode>();
48
    StructuralEqual attrs_equal;
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78

    if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
      return new_expr;
    }
    if (fskip_ != nullptr && fskip_(new_expr)) {
      return new_expr;
    }

    auto it = expr_map_.find(new_call->op);
    if (it != expr_map_.end()) {
      for (const CallNode* candidate : it->second) {
        bool is_equivalent = true;
        if (!attrs_equal(new_call->attrs, candidate->attrs)) {
          continue;
        }
        for (size_t i = 0; i < new_call->args.size(); i++) {
          if (!new_call->args[i].same_as(candidate->args[i]) &&
              !IsEqualScalar(new_call->args[i], candidate->args[i])) {
            is_equivalent = false;
            break;
          }
        }
        if (!is_equivalent) continue;
        return GetRef<Call>(candidate);
      }
    }
    expr_map_[new_call->op].push_back(new_call);
    return new_expr;
  }

79
  std::unordered_map<Expr, std::vector<const CallNode*>, ObjectHash, ObjectEqual> expr_map_;
80 81 82 83 84 85 86
  runtime::TypedPackedFunc<bool(Expr)> fskip_;
};

/*!
 * \brief Run common subexpression elimination over \p expr.
 * \param expr The expression to rewrite.
 * \param callback Optional predicate used as the eliminator's fskip. Calls for
 *        which it returns true are left as they are and never merged.
 * \return The rewritten expression.
 */
Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
  CommonSubexprEliminator eliminator(callback);
  return eliminator(expr);
}

namespace transform {

Pass EliminateCommonSubexpr(PackedFunc fskip) {
90 91
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
92 93
      return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
  };
94
  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
95 96
}

97
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
98 99 100 101
.set_body_typed(EliminateCommonSubexpr);

}  // namespace transform

102 103
}  // namespace relay
}  // namespace tvm