/*!
 *  Copyright (c) 2017 by Contributors
 * \file ir_deep_compare.cc
 * \brief Deep structural comparison of IR expressions and statements.
 */
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>

#include <string>
#include <unordered_map>

namespace tvm {
namespace ir {

using ExprComparator = ExprFunctor<void(const Expr& n, const Expr &other)>;
using StmtComparator = StmtFunctor<void(const Stmt& n, const Stmt &other)>;

// Defines the visitor for a binary expression node OP: compares the two
// operands `a` and `b` in order, stopping at the first difference.
#define DEFINE_BIOP_EXPR_CMP_(OP)                                 \
  void VisitExpr_(const OP* op, const Expr& other) final {        \
    const OP* rhs = other.as<OP>();                               \
    if (CompareExpr(op->a, rhs->a) != 0) return;                  \
    if (CompareExpr(op->b, rhs->b) != 0) return;                  \
  }

// Deep comparison to check if two IR graphs are equivalent.
//
// The comparison walks both trees in lock-step. The first difference found
// is recorded in order_ (-1 / +1); every later comparison short-circuits once
// order_ is non-zero.
class IRDeepCompare :
      public ExprComparator, public StmtComparator {
 public:
  // Equality comparison of two statements.
  // Variable definitions are tied (see tie_def_).
  // Undefined handles are accepted: two undefined statements are equal,
  // an undefined and a defined statement are not.
  bool Equal(const Stmt& lhs, const Stmt& rhs) {
    tie_def_ = true;
    // Go through CompareStmt so that undefined handles are never dereferenced.
    CompareStmt(lhs, rhs);
    return order_ == 0;
  }

  // Equality comparison of two expressions.
  // Variable definitions are tied (see tie_def_).
  // Undefined handles are accepted, as in the Stmt overload.
  bool Equal(const Expr& lhs, const Expr& rhs) {
    tie_def_ = true;
    // Go through CompareExpr so that undefined handles are never dereferenced.
    CompareExpr(lhs, rhs);
    return order_ == 0;
  }

  // Ordering comparison of two expressions.
  // Variable definitions are NOT tied, variables are compared by address.
  // \return -1 if lhs orders before rhs, +1 if after, 0 if equal.
  int Compare(const Expr& lhs, const Expr& rhs) {
    tie_def_ = false;
    return CompareExpr(lhs, rhs);
  }

  // Dispatch entry for expressions: compares node kind and data type first,
  // then forwards to the node-specific VisitExpr_ overload.
  void VisitExpr(const Expr& n, const Expr& other) override {
    if (order_ != 0) return;
    // Identical handles are trivially equal.
    if (n.same_as(other)) return;
    if (CompareValue(n->type_index(), other->type_index()) != 0) return;
    if (CompareType(n.type(), other.type()) != 0) return;
    ExprComparator::VisitExpr(n, other);
  }

  // Dispatch entry for statements: compares node kind first, then forwards
  // to the node-specific VisitStmt_ overload.
  void VisitStmt(const Stmt& n, const Stmt& other) override {
    if (order_ != 0) return;
    // Identical handles are trivially equal.
    if (n.same_as(other)) return;
    if (CompareValue(n->type_index(), other->type_index()) != 0) return;
    StmtComparator::VisitStmt(n, other);
  }

  // Stmt
  // In every VisitStmt_/VisitExpr_ below, `other` has already been checked by
  // VisitStmt/VisitExpr to have the same type_index as `op`.
  void VisitStmt_(const LetStmt* op, const Stmt& other) final {
    const LetStmt* rhs = other.as<LetStmt>();
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (tie_def_) {
      vmap_[op->var.get()] = rhs->var.get();
    } else {
      if (CompareExpr(op->var, rhs->var) != 0) return;
    }
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }

  void VisitStmt_(const AttrStmt* op, const Stmt& other) final {
    const AttrStmt* rhs = other.as<AttrStmt>();
    if (CompareString(op->attr_key, rhs->attr_key) != 0) return;
    if (CompareNodeRef(op->node, rhs->node) != 0) return;
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }

  void VisitStmt_(const IfThenElse* op, const Stmt& other) final {
    const IfThenElse* rhs = other.as<IfThenElse>();
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareStmt(op->then_case, rhs->then_case) != 0) return;
    if (CompareStmt(op->else_case, rhs->else_case) != 0) return;
  }

  void VisitStmt_(const For* op, const Stmt& other) final {
    const For* rhs = other.as<For>();
    if (CompareExpr(op->min, rhs->min) != 0) return;
    if (CompareExpr(op->extent, rhs->extent) != 0) return;
    if (tie_def_) {
      vmap_[op->loop_var.get()] = rhs->loop_var.get();
    } else {
      if (CompareExpr(op->loop_var, rhs->loop_var) != 0) return;
    }
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }

  void VisitStmt_(const Allocate* op, const Stmt& other) final {
    const Allocate* rhs = other.as<Allocate>();
    if (tie_def_) {
      vmap_[op->buffer_var.get()] = rhs->buffer_var.get();
    } else {
      if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
    }
    if (CompareType(op->type, rhs->type) != 0) return;
    if (CompareArray(op->extents, rhs->extents) != 0) return;
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
    if (CompareExpr(op->new_expr, rhs->new_expr) != 0) return;
    if (CompareString(op->free_function, rhs->free_function) != 0) return;
  }

  void VisitStmt_(const Store* op, const Stmt& other) final {
    const Store* rhs = other.as<Store>();
    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareExpr(op->index, rhs->index) != 0) return;
    if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
  }

  void VisitStmt_(const Free* op, const Stmt& other) final {
    const Free* rhs = other.as<Free>();
    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
  }

  void VisitStmt_(const AssertStmt* op, const Stmt& other) final {
    const AssertStmt* rhs = other.as<AssertStmt>();
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareExpr(op->message, rhs->message) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }

  void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
    const ProducerConsumer* rhs = other.as<ProducerConsumer>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->is_producer, rhs->is_producer) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }

  void VisitStmt_(const Provide* op, const Stmt& other) final {
    const Provide* rhs = other.as<Provide>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareArray(op->args, rhs->args) != 0) return;
  }

  void VisitStmt_(const Realize* op, const Stmt& other) final {
    const Realize* rhs = other.as<Realize>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    if (CompareType(op->type, rhs->type) != 0) return;
    if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }

  void VisitStmt_(const Prefetch* op, const Stmt& other) final {
    const Prefetch* rhs = other.as<Prefetch>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    if (CompareType(op->type, rhs->type) != 0) return;
    if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
  }

  void VisitStmt_(const Block* op, const Stmt& other) final {
    const Block* rhs = other.as<Block>();
    if (CompareStmt(op->first, rhs->first) != 0) return;
    if (CompareStmt(op->rest, rhs->rest) != 0) return;
  }

  void VisitStmt_(const Evaluate* op, const Stmt& other) final {
    const Evaluate* rhs = other.as<Evaluate>();
    CompareExpr(op->value, rhs->value);
  }

  // Exprs
  // Variables are compared by node address, after remapping the left-hand
  // variable through vmap_ when a tied definition was recorded for it.
  void VisitExpr_(const Variable* op, const Expr& other) final {
    const Variable* rhs = other.as<Variable>();
    auto it = vmap_.find(op);
    if (it != vmap_.end()) op = it->second;
    if (op < rhs) {
      order_ = -1;
    } else if (op > rhs) {
      order_ = +1;
    }
  }

  void VisitExpr_(const Load* op, const Expr& other) final {
    const Load* rhs = other.as<Load>();
    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
    if (CompareExpr(op->index, rhs->index) != 0) return;
    if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
  }

  void VisitExpr_(const Let* op, const Expr& other) final {
    const Let* rhs = other.as<Let>();
    if (tie_def_) {
      vmap_[op->var.get()] = rhs->var.get();
    } else {
      if (CompareExpr(op->var, rhs->var) != 0) return;
    }
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareExpr(op->body, rhs->body) != 0) return;
  }

  void VisitExpr_(const Call* op, const Expr& other) final {
    const Call* rhs = other.as<Call>();
    if (CompareString(op->name, rhs->name) != 0) return;
    if (CompareArray(op->args, rhs->args) != 0) return;
    if (CompareValue(op->call_type, rhs->call_type) != 0) return;
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
  }

  void VisitExpr_(const Reduce *op, const Expr& other) final {
    const Reduce* rhs = other.as<Reduce>();
    if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
    // The axis count must match before the axes are indexed pairwise below.
    if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    for (size_t i = 0; i < op->axis.size(); ++i) {
      if (CompareExpr(op->axis[i]->dom->min, rhs->axis[i]->dom->min) != 0) return;
      if (CompareExpr(op->axis[i]->dom->extent, rhs->axis[i]->dom->extent) != 0) return;
      if (tie_def_) {
        vmap_[op->axis[i]->var.get()] = rhs->axis[i]->var.get();
      } else {
        if (CompareExpr(op->axis[i]->var, rhs->axis[i]->var) != 0) return;
      }
    }
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareArray(op->source, rhs->source) != 0) return;
  }

  void VisitExpr_(const IntImm *op, const Expr& other) final {
    CompareValue(op->value, other.as<IntImm>()->value);
  }

  void VisitExpr_(const UIntImm *op, const Expr& other) final {
    CompareValue(op->value, other.as<UIntImm>()->value);
  }

  void VisitExpr_(const FloatImm *op, const Expr& other) final {
    // NOTE(review): if `value` is a floating-point type, a NaN operand makes
    // both `<` and `>` false in CompareValue, so it would compare as equal to
    // any other value -- confirm whether this matters for callers.
    CompareValue(op->value, other.as<FloatImm>()->value);
  }

  void VisitExpr_(const StringImm *op, const Expr& other) final {
    CompareString(op->value, other.as<StringImm>()->value);
  }

  void VisitExpr_(const Cast *op, const Expr& other) final {
    CompareExpr(op->value, other.as<Cast>()->value);
  }

  void VisitExpr_(const Not *op, const Expr& other) final {
    CompareExpr(op->a, other.as<Not>()->a);
  }

  void VisitExpr_(const Select *op, const Expr& other) final {
    const Select* rhs = other.as<Select>();
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
    if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
  }

  void VisitExpr_(const Ramp *op, const Expr& other) final {
    const Ramp* rhs = other.as<Ramp>();
    if (CompareExpr(op->base, rhs->base) != 0) return;
    if (CompareExpr(op->stride, rhs->stride) != 0) return;
    if (CompareValue(op->lanes, rhs->lanes) != 0) return;
  }

  void VisitExpr_(const Broadcast *op, const Expr& other) final {
    const Broadcast* rhs = other.as<Broadcast>();
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareValue(op->lanes, rhs->lanes) != 0) return;
  }

  void VisitExpr_(const Shuffle *op, const Expr& other) final {
    const Shuffle* rhs = other.as<Shuffle>();
    if (CompareArray(op->vectors, rhs->vectors) != 0) return;
    if (CompareArray(op->indices, rhs->indices) != 0) return;
  }

  DEFINE_BIOP_EXPR_CMP_(Add)
  DEFINE_BIOP_EXPR_CMP_(Sub)
  DEFINE_BIOP_EXPR_CMP_(Mul)
  DEFINE_BIOP_EXPR_CMP_(Div)
  DEFINE_BIOP_EXPR_CMP_(Mod)
  DEFINE_BIOP_EXPR_CMP_(Min)
  DEFINE_BIOP_EXPR_CMP_(Max)
  DEFINE_BIOP_EXPR_CMP_(EQ)
  DEFINE_BIOP_EXPR_CMP_(NE)
  DEFINE_BIOP_EXPR_CMP_(LT)
  DEFINE_BIOP_EXPR_CMP_(LE)
  DEFINE_BIOP_EXPR_CMP_(GT)
  DEFINE_BIOP_EXPR_CMP_(GE)
  DEFINE_BIOP_EXPR_CMP_(And)
  DEFINE_BIOP_EXPR_CMP_(Or)

 private:
  // All Compare* helpers share one contract: they do nothing once order_ is
  // non-zero, otherwise they update order_ on a difference, and they always
  // return the current order_.

  // Compare two expressions; an undefined expression orders before a
  // defined one, two undefined expressions are equal.
  int CompareExpr(const Expr& lhs, const Expr& rhs) {
    if (order_ != 0) return order_;
    if (!lhs.defined() && rhs.defined()) {
      order_ = -1; return order_;
    }
    if (!rhs.defined() && lhs.defined()) {
      order_ = +1; return order_;
    }
    VisitExpr(lhs, rhs);
    return order_;
  }

  // Compare two statements; an undefined statement orders before a
  // defined one, two undefined statements are equal.
  int CompareStmt(const Stmt& lhs, const Stmt& rhs) {
    if (order_ != 0) return order_;
    if (!lhs.defined() && rhs.defined()) {
      order_ = -1; return order_;
    }
    if (!rhs.defined() && lhs.defined()) {
      order_ = +1; return order_;
    }
    VisitStmt(lhs, rhs);
    return order_;
  }

  // Compare two expression arrays: first by length, then element-wise.
  int CompareArray(const Array<Expr>& lhs, const Array<Expr>& rhs) {
    if (order_ != 0) return order_;
    if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
    for (size_t i = 0; i < lhs.size(); ++i) {
      if (CompareExpr(lhs[i], rhs[i]) != 0) return order_;
    }
    return order_;
  }

  // Compare two regions: first by length, then by min/extent of each range.
  int CompareRegion(const HalideIR::Internal::Region& lhs,
                    const HalideIR::Internal::Region& rhs) {
    if (order_ != 0) return order_;
    if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
    for (size_t i = 0; i < lhs.size(); ++i) {
      if (CompareExpr(lhs[i]->min, rhs[i]->min) != 0) return order_;
      if (CompareExpr(lhs[i]->extent, rhs[i]->extent) != 0) return order_;
    }
    return order_;
  }

  // Compare two generic node references by node address only.
  int CompareNodeRef(const NodeRef& lhs, const NodeRef& rhs) {
    if (order_ != 0) return order_;
    if (lhs.get() < rhs.get()) {
      order_ = -1; return order_;
    }
    if (lhs.get() > rhs.get()) {
      order_ = +1; return order_;
    }
    return order_;
  }

  // Compare two data types by code, then bits, then lanes.
  int CompareType(const Type& lhs, const Type& rhs) {
    if (order_ != 0) return order_;
    if (lhs == rhs) return order_;
    if (CompareValue(lhs.code(), rhs.code()) != 0) return order_;
    if (CompareValue(lhs.bits(), rhs.bits()) != 0) return order_;
    if (CompareValue(lhs.lanes(), rhs.lanes()) != 0) return order_;
    return order_;
  }

  // Compare two strings lexicographically.
  int CompareString(const std::string& lhs, const std::string& rhs) {
    if (order_ != 0) return order_;
    // std::string::compare only guarantees the sign of its result;
    // normalise it so that order_ is always one of -1, 0, +1.
    const int cmp = lhs.compare(rhs);
    if (cmp < 0) {
      order_ = -1;
    } else if (cmp > 0) {
      order_ = +1;
    }
    return order_;
  }

  // Compare two values of the same type using operator< and operator>.
  template<typename T>
  int CompareValue(const T& lhs, const T& rhs) {
    if (order_ != 0) return order_;
    if (lhs < rhs) {
      order_ = -1; return order_;
    } else if (lhs > rhs) {
      order_ = +1; return order_;
    }
    return order_;
  }

  // Compare two commutative reducers: the argument counts first, then the
  // argument variables (tied or compared, depending on tie_def_), then the
  // result expressions using a nested comparator.
  int CompareCommReducer(const CommReducer& lhs, const CommReducer& rhs) {
    if (order_ != 0) return order_;
    if (lhs == rhs) return order_;
    if (CompareValue(lhs->lhs.size(), rhs->lhs.size()) != 0) return order_;
    if (CompareValue(lhs->rhs.size(), rhs->rhs.size()) != 0) return order_;
    // NOTE(review): the nested comparator keeps its default tie_def_ (false),
    // so definitions inside `result` are not tied even when tie_def_ is set
    // on this comparator -- confirm this is intended.
    IRDeepCompare cmp;
    if (tie_def_) {
      // Tie the reducer arguments inside the nested comparator only, so the
      // mapping does not leak into the enclosing comparison.
      for (size_t i = 0; i < lhs->lhs.size(); ++i) {
        cmp.vmap_[lhs->lhs[i].get()] = rhs->lhs[i].get();
      }
      for (size_t i = 0; i < lhs->rhs.size(); ++i) {
        cmp.vmap_[lhs->rhs[i].get()] = rhs->rhs[i].get();
      }
    } else {
      for (size_t i = 0; i < lhs->lhs.size(); ++i) {
        if (CompareExpr(lhs->lhs[i], rhs->lhs[i]) != 0) return order_;
      }
      // Bounded by the size of the array actually being indexed.
      for (size_t i = 0; i < lhs->rhs.size(); ++i) {
        if (CompareExpr(lhs->rhs[i], rhs->rhs[i]) != 0) return order_;
      }
    }
    order_ = cmp.CompareArray(lhs->result, rhs->result);
    return order_;
  }

  // The order flag, smaller: -1, bigger: +1, equal: 0
  int order_{0};
  // Whether to tie intermediate definitions.
  // This allows us to tie definitions of two variables together.
  // This enables us to assert equality between (let x in x + 1) and
  // (let y in y + 1).
  // However, the comparison is no longer in total order.
  // Only equality/non-equality information is valid.
  bool tie_def_{false};
  // Variable remap, if any: left-hand variable -> tied right-hand variable.
  std::unordered_map<const Variable*, const Variable*> vmap_;
};

#undef DEFINE_BIOP_EXPR_CMP_

// Deep equality of two statements, with variable definitions tied.
bool Equal(const Stmt& lhs, const Stmt& rhs) {
  return IRDeepCompare().Equal(lhs, rhs);
}

// Deep equality of two expressions, with variable definitions tied.
bool Equal(const Expr& lhs, const Expr& rhs) {
  return IRDeepCompare().Equal(lhs, rhs);
}

// Deep ordering of two expressions: -1, 0 or +1.
int Compare(const Expr& lhs, const Expr& rhs) {
  return IRDeepCompare().Compare(lhs, rhs);
}

}  // namespace ir
}  // namespace tvm