dead_code.cc 4.21 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

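/*!
 *
 * \file dead_code.cc
 *
 * \brief Remove code that does not affect the program result.
 *
 * The algorithm is implemented by three visitors:
 * FindDef records the value bound to every let-bound variable,
 * CalcDep counts how many times each variable is used, following a
 * variable's bound value only once the variable itself has been reached,
 * Eliminator rebuilds the expression, dropping bindings whose variable is
 * unused and, when requested, inlining bindings whose variable is used once.
 */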
Zhi committed
30
#include <tvm/relay/analysis.h>
31
#include <tvm/relay/expr_functor.h>
Zhi committed
32
#include <tvm/relay/transform.h>
33 34 35 36 37
#include "let_list.h"

namespace tvm {
namespace relay {

// Hash map keyed by Relay variables (hashed/compared with ObjectHash and
// ObjectEqual).
template <typename ValueT>
using VarMap = std::unordered_map<Var, ValueT, ObjectHash, ObjectEqual>;
// Hash set of Relay variables, using the same hashing/equality as VarMap.
using VarSet = std::unordered_set<Var, ObjectHash, ObjectEqual>;

class CalcDep;
class FindDef : private ExprVisitor {
 private:
  VarMap<Expr> expr_map_;

  void VisitExpr_(const LetNode* l) final {
    CHECK_EQ(expr_map_.count(l->var), 0);
    expr_map_[l->var] = l->value;
    VisitExpr(l->value);
    VisitExpr(l->body);
  }

  friend CalcDep;
};

/*!
 * \brief Rebuilds an expression, dropping or inlining let bindings according
 * to the per-variable use counts it is constructed with.
 *
 * A binding whose variable has 0 uses is removed. A binding whose variable
 * has exactly 1 use is removed and its value inlined at the use site when
 * inline_once is set, and kept otherwise. Any other binding is kept.
 */
class Eliminator : private ExprMutator {
 private:
  // Let-bound variable -> its bound value.
  VarMap<Expr> expr_map_;
  // Variable -> number of uses; variables missing from the map count as 0.
  VarMap<size_t> use_map_;
  // Whether bindings used exactly once are inlined.
  bool inline_once_;

  explicit Eliminator(const VarMap<Expr>& expr_map,
                      const VarMap<size_t>& use_map,
                      bool inline_once)
      : expr_map_(expr_map),
        use_map_(use_map),
        inline_once_(inline_once) { }
  friend CalcDep;

  // True when the binding of `v` must survive as a let node.
  bool HasLet(const Var& v) {
    const size_t uses = use_map_[v];
    return uses > 1 || (uses == 1 && !inline_once_);
  }

  Expr VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    auto def = expr_map_.find(var);
    // Free variables and variables whose binding is kept stay as references.
    if (def == expr_map_.end() || HasLet(var)) {
      return var;
    }
    // Otherwise the binding was dropped: substitute the (rewritten) value.
    return VisitExpr(def->second);
  }

  Expr VisitExpr_(const LetNode* op) final {
    const Var& bound = op->var;
    if (!HasLet(bound)) {
      // Binding removed; any remaining use is inlined by the VarNode visitor.
      return VisitExpr(op->body);
    }
    Expr new_value = VisitExpr(op->value);
    Expr new_body = VisitExpr(op->body);
    return LetNode::make(bound, new_value, new_body);
  }
};

94
// calculate the dependency graph from expression
雾雨魔理沙 committed
95
class CalcDep : private ExprVisitor {
96
 public:
雾雨魔理沙 committed
97
  static Expr Eliminate(const Expr& e, bool inline_once) {
98 99 100 101 102
    FindDef fd;
    fd(e);
    CalcDep cd(fd.expr_map_);
    cd(e);
    Eliminator el(fd.expr_map_, cd.use_map_, inline_once);
雾雨魔理沙 committed
103
    return el(e);
104 105 106
  }

 private:
107
  explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
雾雨魔理沙 committed
108 109
  VarMap<Expr> expr_map_;
  VarMap<size_t> use_map_;
110 111

  void VisitExpr(const Expr& e) final {
112 113 114 115 116 117 118 119 120
    visit_counter_[e.get()]++;
    // The dce code seprate variable into three parts:
    // used 0 times (remove)
    // used 1 times (inline)
    // used 2 times (dont do anything).
    if (visit_counter_[e.get()] <= 2) {
      using TParent = ExprFunctor<void(const Expr&)>;
      TParent::VisitExpr(e);
    }
雾雨魔理沙 committed
121 122 123 124
  }

  void VisitExpr_(const LetNode* l) final {
    VisitExpr(l->body);
125 126
  }

雾雨魔理沙 committed
127 128
  void VisitExpr_(const VarNode* v) final {
    Var var = GetRef<Var>(v);
129 130 131
    ++use_map_[var];
    if (use_map_[var] == 1 && expr_map_.count(var) > 0) {
      VisitExpr(expr_map_[var]);
雾雨魔理沙 committed
132 133
    }
  }
134 135
};

雾雨魔理沙 committed
136 137
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
  return CalcDep::Eliminate(e, inline_once);
138 139
}

140 141
namespace transform {

雾雨魔理沙 committed
142
Pass DeadCodeElimination(bool inline_once) {
143 144
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
雾雨魔理沙 committed
145
    return Downcast<Function>(DeadCodeElimination(f, inline_once));
146
  };
147
  return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
148 149
}

150
TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination")
151 152
.set_body_typed(DeadCodeElimination);

153 154
}  // namespace transform

155 156
}  // namespace relay
}  // namespace tvm