/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file dead_code.cc
 *
 * \brief Remove code that does not affect the program result.
 *
 * The algorithm is implemented by three visitors:
 * FindDef maps every let-bound variable to the value it is bound to,
 * CalcDep counts the uses of each variable reachable from the expression,
 * Eliminator drops let bindings that are never used (and, when inline_once
 * is set, inlines bindings that are used exactly once).
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "let_list.h"

namespace tvm {
namespace relay {

// Map from a relay Var to a value of type T, hashed with NodeHash and
// compared with NodeEqual.
template <typename T>
using VarMap = std::unordered_map<Var, T, NodeHash, NodeEqual>;

// Set of relay Vars, using the same hashing/equality as VarMap.
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;

class CalcDep;
class FindDef : private ExprVisitor {
 private:
  VarMap<Expr> expr_map_;

  void VisitExpr_(const LetNode* l) final {
    CHECK_EQ(expr_map_.count(l->var), 0);
    expr_map_[l->var] = l->value;
    VisitExpr(l->value);
    VisitExpr(l->body);
  }

  friend CalcDep;
};

// Rewrites an expression using the definitions collected by FindDef and the
// use counts computed by CalcDep. A let binding is dropped when its variable
// has no uses, and, if inline_once_ is set, also when it has exactly one use;
// in that case the (rewritten) bound value is substituted at the use site.
class Eliminator : private ExprMutator {
 private:
  // let-bound variable -> bound value (from FindDef)
  VarMap<Expr> expr_map_;
  // variable -> number of uses (from CalcDep)
  VarMap<size_t> use_map_;
  // whether single-use bindings are inlined instead of kept
  bool inline_once_;
  explicit Eliminator(const VarMap<Expr>& expr_map,
                      const VarMap<size_t>& use_map,
                      bool inline_once) :
    expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { }
  friend CalcDep;

  // Whether the let binding for `v` must be kept in the output.
  // Uses find() so that asking about a variable with no recorded uses does
  // not default-insert a zero entry into use_map_.
  bool HasLet(const Var& v) const {
    auto it = use_map_.find(v);
    const size_t uses = (it == use_map_.end()) ? 0 : it->second;
    switch (uses) {
    case 0:
      return false;
    case 1:
      return !inline_once_;
    default:
      return true;
    }
  }

  // A use of a let-bound variable whose binding is being dropped is replaced
  // by the rewritten bound value; any other variable is returned unchanged.
  Expr VisitExpr_(const VarNode* op) final {
    Var v = GetRef<Var>(op);
    auto it = expr_map_.find(v);
    if (it == expr_map_.end() || HasLet(v)) {
      return v;
    }
    return VisitExpr(it->second);
  }

  // Keep the binding (with rewritten value and body) only when HasLet says
  // so; otherwise emit just the rewritten body.
  Expr VisitExpr_(const LetNode* op) final {
    Var v = op->var;
    if (HasLet(v)) {
      return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
    } else {
      return VisitExpr(op->body);
    }
  }
};

95
// calculate the dependency graph from expression
雾雨魔理沙 committed
96
class CalcDep : private ExprVisitor {
97
 public:
雾雨魔理沙 committed
98
  static Expr Eliminate(const Expr& e, bool inline_once) {
99 100 101 102 103
    FindDef fd;
    fd(e);
    CalcDep cd(fd.expr_map_);
    cd(e);
    Eliminator el(fd.expr_map_, cd.use_map_, inline_once);
雾雨魔理沙 committed
104
    return el(e);
105 106 107
  }

 private:
108
  explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
雾雨魔理沙 committed
109 110
  VarMap<Expr> expr_map_;
  VarMap<size_t> use_map_;
111 112 113

  void VisitExpr(const Expr& e) final {
    return ExprFunctor<void(const Expr& e)>::VisitExpr(e);
雾雨魔理沙 committed
114 115 116 117
  }

  void VisitExpr_(const LetNode* l) final {
    VisitExpr(l->body);
118 119
  }

雾雨魔理沙 committed
120 121
  void VisitExpr_(const VarNode* v) final {
    Var var = GetRef<Var>(v);
122 123 124
    ++use_map_[var];
    if (use_map_[var] == 1 && expr_map_.count(var) > 0) {
      VisitExpr(expr_map_[var]);
雾雨魔理沙 committed
125 126
    }
  }
127 128
};

雾雨魔理沙 committed
129 130
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
  return CalcDep::Eliminate(e, inline_once);
131 132
}

133 134
namespace transform {

雾雨魔理沙 committed
135
Pass DeadCodeElimination(bool inline_once) {
136 137
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
雾雨魔理沙 committed
138
    return Downcast<Function>(DeadCodeElimination(f, inline_once));
139
  };
140
  return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
141 142
}

143 144 145
TVM_REGISTER_API("relay._transform.DeadCodeElimination")
.set_body_typed(DeadCodeElimination);

146 147
}  // namespace transform

148 149
}  // namespace relay
}  // namespace tvm