/*! * Copyright (c) 2016 by Contributors * SSA related checks and pass. * * SSA requires each varaible to be only defined once. * \file ssa.cc */ #include <tvm/ir.h> #include <tvm/ir_visitor.h> #include <tvm/ir_mutator.h> #include <tvm/ir_pass.h> #include <unordered_set> #include <unordered_map> #include <vector> namespace tvm { namespace ir { namespace { class IRVerifySSA final : public IRVisitor { public: bool is_ssa{true}; void Visit(const NodeRef& n) final { if (!is_ssa) return; IRVisitor::Visit(n); } void Visit_(const Let* op) final { MarkDef(op->var.get()); IRVisitor::Visit_(op); } void Visit_(const LetStmt* op) final { MarkDef(op->var.get()); IRVisitor::Visit_(op); } void Visit_(const For* op) final { MarkDef(op->loop_var.get()); IRVisitor::Visit_(op); } void Visit_(const Allocate* op) final { MarkDef(op->buffer_var.get()); IRVisitor::Visit_(op); } private: void MarkDef(const Variable* v) { if (defined_.count(v) != 0) { is_ssa = false; return; } else { defined_[v] = 1; } } std::unordered_map<const Variable*, int> defined_; }; class IRConvertSSA final : public IRMutator { public: Expr Mutate_(const Variable* op, const Expr& e) final { if (scope_.count(op)) { return scope_[op].back(); } else { return e; } } Expr Mutate_(const Let* op, const Expr& e) final { const VarExpr& v = op->var; if (defined_.count(v.get())) { Expr value = IRMutator::Mutate(op->value); VarExpr new_var = Variable::make(v.type(), v->name_hint); scope_[v.get()].push_back(new_var); Expr body = IRMutator::Mutate(op->body); scope_[v.get()].pop_back(); return Let::make(new_var, value, body); } else { defined_.insert(v.get()); return IRMutator::Mutate_(op, e); } } Expr Mutate_(const Load* op, const Expr& e) final { Expr expr = IRMutator::Mutate_(op, e); op = expr.as<Load>(); if (scope_.count(op->buffer_var.get())) { return Load::make( op->type, scope_[op->buffer_var.get()].back(), op->index, op->predicate); } else { return expr; } } Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as<Store>(); if (scope_.count(op->buffer_var.get())) { return Store::make( scope_[op->buffer_var.get()].back(), op->value, op->index, op->predicate); } else { return stmt; } } Stmt Mutate_(const LetStmt* op, const Stmt& s) final { const VarExpr& v = op->var; if (defined_.count(v.get())) { Expr value = IRMutator::Mutate(op->value); VarExpr new_var = Variable::make(v.type(), v->name_hint); scope_[v.get()].push_back(new_var); Stmt body = IRMutator::Mutate(op->body); scope_[v.get()].pop_back(); return LetStmt::make(new_var, value, body); } else { defined_.insert(v.get()); return IRMutator::Mutate_(op, s); } } Stmt Mutate_(const For* op, const Stmt& s) final { const VarExpr& v = op->loop_var; if (defined_.count(v.get())) { VarExpr new_var = Variable::make(v.type(), v->name_hint); scope_[v.get()].push_back(new_var); Stmt stmt = IRMutator::Mutate_(op, s); scope_[v.get()].pop_back(); op = stmt.as<For>(); return For::make( new_var, op->min, op->extent, op->for_type, op->device_api, op->body); } else { defined_.insert(v.get()); return IRMutator::Mutate_(op, s); } } Stmt Mutate_(const Allocate* op, const Stmt& s) final { const VarExpr& v = op->buffer_var; if (defined_.count(v.get())) { VarExpr new_var = Variable::make(v.type(), v->name_hint); scope_[v.get()].push_back(new_var); Stmt stmt = IRMutator::Mutate_(op, s); scope_[v.get()].pop_back(); op = stmt.as<Allocate>(); return Allocate::make( new_var, op->type, op->extents, op->condition, op->body, op->new_expr, op->free_function); } else { defined_.insert(v.get()); return IRMutator::Mutate_(op, s); } } Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { if (const Variable* v = op->node.as<Variable>()) { if (op->attr_key == attr::storage_scope) { const Allocate* alloc = op->body.as<Allocate>(); if (alloc && op->node.same_as(alloc->buffer_var)) { Stmt new_alloc = Mutate(op->body); if (new_alloc.same_as(op->body)) return s; alloc = new_alloc.as<Allocate>(); CHECK(alloc); return AttrStmt::make( alloc->buffer_var, op->attr_key, op->value, new_alloc); } } Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as<AttrStmt>(); if (scope_.count(v) && scope_[v].size() != 0) { return AttrStmt::make( scope_[v].back(), op->attr_key, op->value, op->body); } else { return stmt; } } else { return IRMutator::Mutate_(op, s); } } private: std::unordered_map<const Variable*, std::vector<VarExpr> > scope_; std::unordered_set<const Variable*> defined_; }; } // namespace bool VerifySSA(const Stmt& ir) { IRVerifySSA v; v.Visit(ir); return v.is_ssa; } Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA().Mutate(stmt); } } // namespace ir } // namespace tvm