/*!
 *  Copyright (c) 2016 by Contributors
 * \file simple_passes.cc
 * \brief Implementation of simple passes
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>

namespace tvm {
namespace ir {

/*!
 * \brief Visitor that detects whether an IR tree contains a Call whose
 *  is_pure() is false.
 *
 *  Once such a call has been found, Visit() stops descending into any
 *  further nodes.
 */
class IRSideEffect : public IRVisitor {
 public:
  void Visit(const NodeRef& e) final {
    // Short-circuit: the answer is already known, skip the rest of the tree.
    if (has_side_effect_) return;
    IRVisitor::Visit(e);
  }

  void Visit_(const Call* op) final {
    if (!op->is_pure()) {
      has_side_effect_ = true;
      return;
    }
    // Pure call: keep searching inside it.
    IRVisitor::Visit_(op);
  }

  /*! \brief Set to true once an impure call has been encountered. */
  bool has_side_effect_{false};
};

bool HasSideEffect(const Expr& e) {
  IRSideEffect v;
  v.Visit(e);
  return v.has_side_effect_;
}

class IRSubstitue : public IRMutator {
 public:
40 41 42 43 44
  explicit IRSubstitue(
      const std::unordered_map<const Variable*, Expr>& smap)
      : smap_(smap) {
  }

45
  Expr Mutate_(const Variable* op, const Expr& e) final {
46 47
    auto it = smap_.find(op);
    if (it != smap_.end()) {
48 49 50 51 52
      return it->second;
    } else {
      return e;
    }
  }
53 54 55

 private:
  const std::unordered_map<const Variable*, Expr>& smap_;
56 57
};

/*!
 * \brief Replace variables in a statement according to value_map.
 * \param stmt The statement to rewrite.
 * \param value_map Maps each variable to its replacement expression.
 * \return The rewritten statement, or stmt itself when value_map is empty.
 */
Stmt Substitute(Stmt stmt,
                const std::unordered_map<const Variable*, Expr>& value_map) {
  // Nothing to substitute: skip the traversal entirely.
  if (value_map.empty()) return stmt;
  return IRSubstitue(value_map).Mutate(stmt);
}

/*!
 * \brief Replace variables in an expression according to value_map.
 * \param expr The expression to rewrite.
 * \param value_map Maps each variable to its replacement expression.
 * \return The rewritten expression, or expr itself when value_map is empty.
 */
Expr Substitute(Expr expr,
                const std::unordered_map<const Variable*, Expr>& value_map) {
  if (!value_map.empty()) {
    IRSubstitue mutator(value_map);
    return mutator.Mutate(expr);
  }
  return expr;
}

/*!
 * \brief Replace variables in a statement, with the mapping given as a
 *  Map keyed by Var.
 * \param stmt The statement to rewrite.
 * \param value_map Maps each Var to its replacement expression.
 * \return The rewritten statement.
 */
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
  // Re-key by raw Variable pointer, the key type IRSubstitue looks up.
  std::unordered_map<const Variable*, Expr> vmap;
  for (const auto& kv : value_map) {
    vmap[kv.first.get()] = kv.second;
  }
  return Substitute(stmt, vmap);
}

/*!
 * \brief Replace variables in an expression, with the mapping given as a
 *  Map keyed by Var.
 * \param expr The expression to rewrite.
 * \param value_map Maps each Var to its replacement expression.
 * \return The rewritten expression.
 */
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
  // Re-key by raw Variable pointer, the key type IRSubstitue looks up.
  std::unordered_map<const Variable*, Expr> vmap;
  for (const auto& kv : value_map) {
    vmap[kv.first.get()] = kv.second;
  }
  return Substitute(expr, vmap);
}

class VarTouchVisitor : public IRVisitor {
87 88 89 90 91 92 93
 public:
  void Visit(const NodeRef& e) final {
    if (use_var_) return;
    IRVisitor::Visit(e);
  }

  void Visit_(const Variable* op) final {
94
    Handle(op);
95 96 97
  }

  void Visit_(const Load* op) final {
98
    Handle(op->buffer_var.get());
99 100 101
    IRVisitor::Visit_(op);
  }

102 103
  virtual void Handle(const Variable* var) = 0;

104 105 106
  bool use_var_{false};
};

/*!
 * \brief VarTouchVisitor that sets use_var_ when one specific variable
 *  is touched.
 */
class ExprUseVarVisitor : public VarTouchVisitor {
 public:
  explicit ExprUseVarVisitor(const Variable* var) : var_(var) {}

  void Handle(const Variable* var) final {
    // Ignore every variable except the one being searched for.
    if (var != var_) return;
    use_var_ = true;
  }

 private:
  /*! \brief The variable being searched for (not owned). */
  const Variable* var_;
};

class ExprUseVSetVisitor : public VarTouchVisitor {
 public:
  explicit ExprUseVSetVisitor(
      const std::unordered_set<const Variable*>& vset)
      : vset_(vset) {}

  void Handle(const Variable* var) final {
    if (vset_.count(var)) use_var_ = true;
  }
 private:
  const std::unordered_set<const Variable*>& vset_;
};

bool ExprUseVar(const Expr& e, const Var& v) {
  ExprUseVarVisitor visitor(v.get());
  visitor.Visit(e);
  return visitor.use_var_;
}

bool ExprUseVar(const Expr& e,
                const std::unordered_set<const Variable*>& vset) {
  ExprUseVSetVisitor visitor(vset);
  visitor.Visit(e);
  return visitor.use_var_;
}

}  // namespace ir
}  // namespace tvm