/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 *  SSA related checks and pass.
 *
 *  SSA requires each variable to be only defined once.
 * \file ssa.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace ir {
namespace {
38
class IRVerifySSA final : public IRVisitor {
tqchen committed
39 40 41
 public:
  bool is_ssa{true};

42
  void Visit(const NodeRef& n) final {
tqchen committed
43 44 45
    if (!is_ssa) return;
    IRVisitor::Visit(n);
  }
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
  void Visit_(const Let* op) final {
    MarkDef(op->var.get());
    IRVisitor::Visit_(op);
  }
  void Visit_(const LetStmt* op) final {
    MarkDef(op->var.get());
    IRVisitor::Visit_(op);
  }
  void Visit_(const For* op) final {
    MarkDef(op->loop_var.get());
    IRVisitor::Visit_(op);
  }
  void Visit_(const Allocate* op) final {
    MarkDef(op->buffer_var.get());
    IRVisitor::Visit_(op);
  }
tqchen committed
62 63

 private:
64 65 66 67 68 69 70
  void MarkDef(const Variable* v) {
    if (defined_.count(v) != 0) {
      is_ssa = false; return;
    } else {
      defined_[v] = 1;
    }
  }
tqchen committed
71 72 73
  std::unordered_map<const Variable*, int> defined_;
};

74
class IRConvertSSA final : public IRMutator {
tqchen committed
75
 public:
76 77 78 79 80 81 82 83 84 85 86 87
  Expr Mutate_(const Variable* op, const Expr& e) final {
    if (scope_.count(op)) {
      return scope_[op].back();
    } else {
      return e;
    }
  }
  Expr Mutate_(const Let* op, const Expr& e) final {
    const VarExpr& v = op->var;
    if (defined_.count(v.get())) {
      Expr value = IRMutator::Mutate(op->value);
      VarExpr new_var = Variable::make(v.type(), v->name_hint);
tqchen committed
88
      scope_[v.get()].push_back(new_var);
89
      Expr body = IRMutator::Mutate(op->body);
tqchen committed
90
      scope_[v.get()].pop_back();
91
      return Let::make(new_var, value, body);
tqchen committed
92
    } else {
93 94
      defined_.insert(v.get());
      return IRMutator::Mutate_(op, e);
tqchen committed
95 96
    }
  }
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    if (scope_.count(op->buffer_var.get())) {
      return Load::make(
          op->type, scope_[op->buffer_var.get()].back(),
          op->index, op->predicate);
    } else {
      return expr;
    }
  }
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    if (scope_.count(op->buffer_var.get())) {
      return Store::make(
          scope_[op->buffer_var.get()].back(), op->value,
          op->index, op->predicate);
    } else {
      return stmt;
    }
  }
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    const VarExpr& v = op->var;
    if (defined_.count(v.get())) {
      Expr value = IRMutator::Mutate(op->value);
      VarExpr new_var = Variable::make(v.type(), v->name_hint);
tqchen committed
124
      scope_[v.get()].push_back(new_var);
125
      Stmt body = IRMutator::Mutate(op->body);
tqchen committed
126
      scope_[v.get()].pop_back();
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
      return LetStmt::make(new_var, value, body);
    } else {
      defined_.insert(v.get());
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const For* op, const Stmt& s) final {
    const VarExpr& v = op->loop_var;
    if (defined_.count(v.get())) {
      VarExpr new_var = Variable::make(v.type(), v->name_hint);
      scope_[v.get()].push_back(new_var);
      Stmt stmt = IRMutator::Mutate_(op, s);
      scope_[v.get()].pop_back();
      op = stmt.as<For>();
      return For::make(
          new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
    } else {
      defined_.insert(v.get());
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    const VarExpr& v = op->buffer_var;
    if (defined_.count(v.get())) {
      VarExpr new_var = Variable::make(v.type(), v->name_hint);
      scope_[v.get()].push_back(new_var);
      Stmt stmt = IRMutator::Mutate_(op, s);
      scope_[v.get()].pop_back();
      op = stmt.as<Allocate>();
      return Allocate::make(
          new_var, op->type, op->extents, op->condition,
          op->body, op->new_expr, op->free_function);
    } else {
      defined_.insert(v.get());
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (const Variable* v = op->node.as<Variable>()) {
      if (op->attr_key == attr::storage_scope) {
        const Allocate* alloc = op->body.as<Allocate>();
        if (alloc && op->node.same_as(alloc->buffer_var)) {
          Stmt new_alloc = Mutate(op->body);
          if (new_alloc.same_as(op->body)) return s;
          alloc = new_alloc.as<Allocate>();
          CHECK(alloc);
          return AttrStmt::make(
              alloc->buffer_var, op->attr_key, op->value, new_alloc);
        }
      }
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<AttrStmt>();
      if (scope_.count(v) && scope_[v].size() != 0) {
        return AttrStmt::make(
            scope_[v].back(), op->attr_key, op->value, op->body);
tqchen committed
182
      } else {
183
        return stmt;
tqchen committed
184 185
      }
    } else {
186
      return IRMutator::Mutate_(op, s);
tqchen committed
187 188 189 190 191 192 193 194 195 196
    }
  }

 private:
  std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
  std::unordered_set<const Variable*> defined_;
};

}  // namespace

tqchen committed
197
bool VerifySSA(const Stmt& ir) {
tqchen committed
198 199 200 201 202
  IRVerifySSA v;
  v.Visit(ir);
  return v.is_ssa;
}

tqchen committed
203
Stmt ConvertSSA(Stmt stmt) {
tqchen committed
204 205 206 207 208
  return IRConvertSSA().Mutate(stmt);
}

}  // namespace ir
}  // namespace tvm