Commit 8a66ac23 by Tianqi Chen Committed by GitHub

[PASS/OP/REFACTOR] IRDeepCompare, isolate computeop part, allow fuzzy bind (#218)

parent 8e2ea2c4
......@@ -137,6 +137,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
......
......@@ -23,14 +23,6 @@
namespace tvm {
namespace ir {
// Structural equality of two expressions, delegated to Halide's deep
// comparison.  No variable-definition tying: (let x in x+1) != (let y in y+1).
inline bool Equal(Expr a, Expr b) {
  return Halide::Internal::equal(a, b);
}
// Structural equality of two statements, delegated to Halide's deep
// comparison.  No variable-definition tying.
inline bool Equal(Stmt a, Stmt b) {
  return Halide::Internal::equal(a, b);
}
// Algebraically simplify an expression using Halide's simplifier.
inline Expr Simplify(Expr a) {
  return Halide::Internal::simplify(a);
}
......@@ -40,6 +32,22 @@ inline Stmt Simplify(Stmt a) {
}
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Expr& lhs, const Expr& rhs);
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Stmt& lhs, const Stmt& rhs);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
......
......@@ -9,6 +9,7 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./compute_op.h"
#include "./op_util.h"
#include "../schedule/message_passing.h"
......@@ -242,124 +243,6 @@ void MakeReduction(const ComputeOpNode* op,
}
}
// Substitute IterVar-keyed bindings into a statement.
// Converts the map to the Var-keyed form expected by ir::Substitute
// and delegates to it.
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  Map<Var, Expr> var_values;
  for (auto it = value_map.begin(); it != value_map.end(); ++it) {
    var_values.Set(it->first->var, it->second);
  }
  return ir::Substitute(s, var_values);
}
// Cross Thread reduction
// Decide whether this compute op must be lowered as a cross-thread
// (allreduce) reduction: true iff at least one reduction axis in the
// leaf iteration order is bound to a thread IterVar.
bool IsCrossThreadReduction(const ComputeOpNode* self,
                            const Stage& stage) {
  // Verify correctness of leaf nest.
  // Count serial (normal_red) vs thread-bound (thread_red) reduction axes.
  int normal_red = 0, thread_red = 0;
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        ++thread_red;
      } else {
        ++normal_red;
      }
    } else {
      // A data-parallel axis may not appear inside (after) a thread
      // reduction in the leaf ordering.
      CHECK_EQ(thread_red, 0)
          << "Cross thread reduce cannot swap with normal data axis";
    }
  }
  // The two reduction styles cannot be combined within one stage.
  CHECK(normal_red == 0 || thread_red == 0)
      << "Cannot mix normal reduction with thread reduce";
  return thread_red != 0;
}
// Build the lowered body of a cross-thread (allreduce) reduction:
//   loop nest over data-parallel axes
//     tvm_thread_allreduce(...) -> per-output local temporaries
//     guarded Provide statements copying temporaries to the outputs
Stmt MakeCrossThreadReduction(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) {
  // Output store indices: one per data-parallel axis.
  Array<Expr> args;
  for (IterVar iv : self->axis) {
    args.push_back(iv->var);
  }
  // Build the surrounding loop nest and its bound-check predicates.
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
  auto conds = op::MakeBoundCheck(
      stage, dom_map, false,
      std::unordered_set<IterVar>(), value_map);
  size_t size = self->body.size();
  CHECK_GT(size, 0);
  // Every output of the op must be a Reduce node.
  std::vector<const Reduce*> reduces(size);
  for (size_t i = 0; i < size; ++i) {
    const Reduce* reduce = self->body[i].as<Reduce>();
    CHECK(reduce);
    reduces[i] = reduce;
  }
  // Combine the reduce condition with all bound-check predicates.
  // NOTE(review): only reduces[0]'s condition/source/combiner are used;
  // this assumes all outputs come from one fused Reduce — confirm.
  Expr cond = reduces[0]->condition;
  for (Expr v : conds) {
    cond = cond && v;
  }
  // Intrinsic argument layout:
  //   #outputs, sources..., condition, result handles..., thread vars...
  Array<Expr> freduce_args;
  freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
  for (size_t i = 0; i < size; ++i) {
    freduce_args.push_back(reduces[0]->source[i]);
  }
  freduce_args.push_back(cond);
  // One local temporary per output to receive the reduced value.
  std::vector<Var> res_handles(size);
  for (size_t idx = 0; idx < size; ++idx) {
    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
    freduce_args.push_back(res_handles[idx]);
  }
  // Append the thread IterVars that carry the reduction.
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        IterVar tv = (*it).second->bind_thread;
        freduce_args.push_back(tv->var);
      }
    }
  }
  // Checks for the thread.
  std::vector<Expr> thread_head_check;
  if (stage->store_predicate.defined()) {
    thread_head_check.emplace_back(stage->store_predicate);
  }
  // The allreduce call, wrapped in a reduce_scope attribute carrying
  // the combiner for the codegen backend.
  Stmt reduce_body = Evaluate::make(Call::make(
      Handle(),
      ir::intrinsic::tvm_thread_allreduce,
      freduce_args, Call::Intrinsic));
  reduce_body = AttrStmt::make(
      reduces[0]->combiner,
      attr::reduce_scope,
      make_zero(Handle()),
      reduce_body);
  // Store each reduced temporary into its corresponding output.
  std::vector<Stmt> assigns(size);
  for (size_t idx = 0; idx < size; ++idx) {
    Type t = reduces[idx]->type;
    assigns[idx] = Provide::make(
        stage->op, idx,
        Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
  }
  // Guard the stores by the store predicate and the bound checks.
  Stmt assign_body = Block::make(assigns);
  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
  Stmt body = Block::make(reduce_body, assign_body);
  // Allocate the "local"-scope temporaries, innermost first.
  for (size_t idx = size; idx != 0; --idx) {
    body = Allocate::make(
        res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
    body = AttrStmt::make(
        res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
  }
  // Bind loop variables to their concrete values and wrap in the nest.
  body = Substitute(body, value_map);
  return MergeNest(nest, body);
}
// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
......@@ -370,27 +253,56 @@ Stmt MakeProvide(const ComputeOpNode* op,
return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}
// loop nest structure for general compute
// This is the loop nest structure used in compute.
// Does not include the loop body.
struct ComputeLoopNest {
  // The common number of loops between init and main
  size_t num_common_loop;
  // predicates for the initialize loop
  std::vector<Expr> init_predicates;
  // Initialization nest involved.
  std::vector<std::vector<Stmt> > init_nest;
  // Value map for the init code
  std::unordered_map<IterVar, Expr> init_vmap;
  // Predicates for the main update loop
  std::vector<Expr> main_predicates;
  // The general loop nest
  std::vector<std::vector<Stmt> > main_nest;
  // Value map for the IterVar.
  std::unordered_map<IterVar, Expr> main_vmap;
};
// Build the full compute statement (loop nest + body) for an ordinary
// (non-cross-thread) compute op.
Stmt MakeComputeStmt(const ComputeOpNode* self,
                     const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map) {
  // grab the nest structure
  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map);
  // Normal loop structure
  // Append the predicate if-nests as an extra innermost "loop" level.
  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
  if (self->reduce_axis.size() != 0) {
    // make reduction.
    Stmt init, provide;
    Array<Tensor> source;
    for (size_t i = 0; i < self->body.size(); ++i) {
      source.push_back(stage->op.output(i));
    }
    MakeReduction(self, source, &init, &provide);
    init = op::Substitute(init, n.init_vmap);
    init = MergeNest(n.init_nest, init);
    // common nest
    // Loops shared by init and update go outside; init is emitted once
    // before the reduction loops.  The +1 accounts for the predicate
    // level appended above.
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > reduce(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
    provide = op::Substitute(provide, n.main_vmap);
    provide = MergeNest(reduce, provide);
    return MergeNest(common, Block::make(init, provide));
  } else {
    // No reduction: one Provide per output, inside the full nest.
    std::vector<Stmt> provides;
    for (size_t i = 0; i < self->body.size(); ++i) {
      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
    }
    Stmt provide = op::Substitute(Block::make(provides), n.main_vmap);
    return MergeNest(n.main_nest, provide);
  }
}
ComputeLoopNest MakeComputeLoopNest(
// implement the provide utility.
// Dispatch to the cross-thread-reduction lowering when the stage binds
// reduction axes to threads; otherwise emit the normal compute nest.
Stmt ComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) const {
  // The stage must correspond to this operation.
  CHECK_EQ(stage->op.operator->(), this);
  if (!IsCrossThreadReduction(this, stage)) {
    // Ordinary compute: full loop nest with the body inside.
    return MakeComputeStmt(this, stage, dom_map);
  }
  // specially handle cross thread reduction.
  return MakeCrossThreadReduction(this, stage, dom_map);
}
ComputeLoopNest ComputeLoopNest::make(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
......@@ -446,51 +358,10 @@ ComputeLoopNest MakeComputeLoopNest(
e = likely(e);
}
} else {
ret.num_common_loop = ret.main_nest.size() - 1;
CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
ret.num_common_loop = stage->leaf_iter_vars.size();
}
// copy elison here.
return ret;
}
// implement the provide utility.
// Entry point that lowers this compute op into a statement for the
// given stage: cross-thread allreduce when reduction axes are bound to
// threads, otherwise the normal loop-nest construction inlined below.
Stmt ComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) const {
  // The stage must correspond to this operation.
  CHECK_EQ(stage->op.operator->(), this);
  if (IsCrossThreadReduction(this, stage)) {
    // specially handle cross thread reduction.
    return MakeCrossThreadReduction(this, stage, dom_map);
  }
  // grab the nest structure
  ComputeLoopNest n = MakeComputeLoopNest(this, stage, dom_map);
  // Normal loop structure
  // Append the predicate if-nests as an extra innermost "loop" level.
  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
  if (this->reduce_axis.size() != 0) {
    // make reduction.
    Stmt init, provide;
    Array<Tensor> source;
    for (size_t i = 0; i < this->body.size(); ++i) {
      source.push_back(stage->op.output(i));
    }
    MakeReduction(this, source, &init, &provide);
    init = Substitute(init, n.init_vmap);
    init = MergeNest(n.init_nest, init);
    // common nest
    // Loops shared by init and update stay outside; the +1 accounts for
    // the predicate level appended above.
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > reduce(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
    provide = Substitute(provide, n.main_vmap);
    provide = MergeNest(reduce, provide);
    return MergeNest(common, Block::make(init, provide));
  } else {
    // No reduction: one Provide per output, inside the full nest.
    std::vector<Stmt> provides;
    for (size_t i = 0; i < this->body.size(); ++i) {
      provides.emplace_back(MakeProvide(this, stage->op.output(i)));
    }
    Stmt provide = Substitute(Block::make(provides), n.main_vmap);
    return MergeNest(n.main_nest, provide);
  }
}
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \brief Helper utilities to implement compute_op.
* \file compute_op.h
*/
#ifndef TVM_OP_COMPUTE_OP_H_
#define TVM_OP_COMPUTE_OP_H_
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <vector>
#include <unordered_map>
namespace tvm {
// loop nest structure for general compute
// This is the loop nest structure used in compute.
// Does not include the loop body.
struct ComputeLoopNest {
  // The common number of loops between init and main
  size_t num_common_loop;
  // predicates for the initialize loop
  std::vector<Expr> init_predicates;
  // Initialization nest involved.
  std::vector<std::vector<Stmt> > init_nest;
  // Value map for the init code
  std::unordered_map<IterVar, Expr> init_vmap;
  // Predicates for the main update loop
  std::vector<Expr> main_predicates;
  // The general loop nest
  std::vector<std::vector<Stmt> > main_nest;
  // Value map for the IterVar.
  std::unordered_map<IterVar, Expr> main_vmap;
  /*!
   * \brief constructor to build ComputeOpNest
   * \param self The pointer to compute op.
   * \param stage The schedule stage.
   * \param dom_map The domain map.
   * \return The constructed loop nest
   */
  static ComputeLoopNest make(
      const ComputeOpNode* self,
      const Stage& stage,
      const std::unordered_map<IterVar, Range>& dom_map);
};
/*!
* \brief Whether compute op is a cross thread reduction structure.
* \param self The pointer to ComputeOpNode
* \param stage the schedule stage.
*/
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage);
/*!
* \brief Build body of compute for cross thread reduction pattern.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \return The created statement.
*/
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
} // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
/*!
* Copyright (c) 2017 by Contributors
* \brief Logics related to cross thread reduction, used by ComputeOpNode.
* \file cross_thread_reduction.cc
*/
#include <tvm/ir_pass.h>
#include "./compute_op.h"
#include "./op_util.h"
namespace tvm {
using namespace ir;
// Decide whether this compute op must be lowered as a cross-thread
// (allreduce) reduction: true iff at least one reduction axis in the
// leaf iteration order is bound to a thread IterVar.  Also validates
// the ordering of the leaf nest.
bool IsCrossThreadReduction(const ComputeOpNode* self,
                            const Stage& stage) {
  int num_serial_red = 0;
  int num_thread_red = 0;
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type != kCommReduce) {
      // A data-parallel axis may not appear inside a thread reduction.
      CHECK_EQ(num_thread_red, 0)
          << "Cross thread reduce cannot swap with normal data axis";
      continue;
    }
    auto attr_it = stage->iter_var_attrs.find(iv);
    const bool bound_to_thread =
        attr_it != stage->iter_var_attrs.end() &&
        (*attr_it).second->bind_thread.defined();
    if (bound_to_thread) {
      ++num_thread_red;
    } else {
      ++num_serial_red;
    }
  }
  // The two reduction styles cannot be combined within one stage.
  CHECK(num_serial_red == 0 || num_thread_red == 0)
      << "Cannot mix normal reduction with thread reduce";
  return num_thread_red != 0;
}
// Build the lowered body of a cross-thread (allreduce) reduction:
//   loop nest over data-parallel axes
//     tvm_thread_allreduce(...) -> per-output local temporaries
//     guarded Provide statements copying temporaries to the outputs
Stmt MakeCrossThreadReduction(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) {
  // Output store indices: one per data-parallel axis.
  Array<Expr> args;
  for (IterVar iv : self->axis) {
    args.push_back(iv->var);
  }
  // Build the surrounding loop nest and its bound-check predicates.
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
  auto conds = op::MakeBoundCheck(
      stage, dom_map, false,
      std::unordered_set<IterVar>(), value_map);
  size_t size = self->body.size();
  CHECK_GT(size, 0);
  // Every output of the op must be a Reduce node.
  std::vector<const Reduce*> reduces(size);
  for (size_t i = 0; i < size; ++i) {
    const Reduce* reduce = self->body[i].as<Reduce>();
    CHECK(reduce);
    reduces[i] = reduce;
  }
  // Combine the reduce condition with all bound-check predicates.
  // NOTE(review): only reduces[0]'s condition/source/combiner are used;
  // this assumes all outputs come from one fused Reduce — confirm.
  Expr cond = reduces[0]->condition;
  for (Expr v : conds) {
    cond = cond && v;
  }
  // Intrinsic argument layout:
  //   #outputs, sources..., condition, result handles..., thread vars...
  Array<Expr> freduce_args;
  freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
  for (size_t i = 0; i < size; ++i) {
    freduce_args.push_back(reduces[0]->source[i]);
  }
  freduce_args.push_back(cond);
  // One local temporary per output to receive the reduced value.
  std::vector<Var> res_handles(size);
  for (size_t idx = 0; idx < size; ++idx) {
    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
    freduce_args.push_back(res_handles[idx]);
  }
  // Append the thread IterVars that carry the reduction.
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        IterVar tv = (*it).second->bind_thread;
        freduce_args.push_back(tv->var);
      }
    }
  }
  // Checks for the thread.
  std::vector<Expr> thread_head_check;
  if (stage->store_predicate.defined()) {
    thread_head_check.emplace_back(stage->store_predicate);
  }
  // The allreduce call, wrapped in a reduce_scope attribute carrying
  // the combiner for the codegen backend.
  Stmt reduce_body = Evaluate::make(Call::make(
      Handle(),
      ir::intrinsic::tvm_thread_allreduce,
      freduce_args, Call::Intrinsic));
  reduce_body = AttrStmt::make(
      reduces[0]->combiner,
      attr::reduce_scope,
      make_zero(Handle()),
      reduce_body);
  // Store each reduced temporary into its corresponding output.
  std::vector<Stmt> assigns(size);
  for (size_t idx = 0; idx < size; ++idx) {
    Type t = reduces[idx]->type;
    assigns[idx] = Provide::make(
        stage->op, idx,
        Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
  }
  // Guard the stores by the store predicate and the bound checks.
  Stmt assign_body = Block::make(assigns);
  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
  Stmt body = Block::make(reduce_body, assign_body);
  // Allocate the "local"-scope temporaries, innermost first.
  for (size_t idx = size; idx != 0; --idx) {
    body = Allocate::make(
        res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
    body = AttrStmt::make(
        res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
  }
  // Bind loop variables to their concrete values and wrap in the nest.
  body = op::Substitute(body, value_map);
  return MergeNest(nest, body);
}
} // namespace tvm
......@@ -223,7 +223,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
public:
......@@ -263,5 +262,16 @@ Expr ReplaceTensor(Expr expr,
Expr ret = repl.Mutate(expr);
return repl.found ? ret : expr;
}
// Substitute IterVar-keyed bindings into a statement.
// Flattens the IterVar keys down to their underlying Variable nodes,
// which is the form the generic ir::Substitute pass expects.
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  std::unordered_map<const Variable*, Expr> var_values;
  for (auto it = value_map.begin(); it != value_map.end(); ++it) {
    var_values[it->first->var.get()] = it->second;
  }
  return ir::Substitute(s, var_values);
}
} // namespace op
} // namespace tvm
......@@ -12,6 +12,7 @@
#include <unordered_set>
#include <vector>
#include "../pass/ir_util.h"
#include "../pass/arg_binder.h"
namespace tvm {
namespace op {
......@@ -74,6 +75,15 @@ Stmt ReplaceTensor(Stmt stmt,
Expr ReplaceTensor(Expr expr,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Substitute the variables of stmt by value map.
* \param stmt the statment
* \param value_map The value map.
* \return Substituted result.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<IterVar, Expr>& value_map);
} // namespace op
} // namespace tvm
#endif // TVM_OP_OP_UTIL_H_
......@@ -75,13 +75,38 @@ void ArgBinder::BindArray(const Array<Expr>& arg,
void ArgBinder::BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name) {
const std::string& arg_name,
bool fuzzy_match) {
CHECK_EQ(arg->scope, value->scope)
<< "Argument " << arg_name
<< " Buffer bind scope mismatch";
this->Bind(arg->data, value->data, arg_name + ".data");
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
if (arg->shape.size() > value->shape.size()) {
CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
size_t diff = arg->shape.size() - value->shape.size();
for (size_t i = 0; i < diff; ++i) {
CHECK(is_one(arg->shape[i]))
<< "Argument " << arg_name << " shape mismatch"
<< arg->shape << " vs " << value->shape;
}
for (size_t i = 0; i < value->shape.size(); ++i) {
std::ostringstream os;
os << arg_name << ".shape[" << i << "]";
this->Bind(arg->shape[i + diff], value->shape[i], os.str());
}
if (arg->strides.size() != 0) {
CHECK_EQ(arg->strides.size(), arg->shape.size());
CHECK_EQ(value->strides.size(), value->shape.size());
for (size_t i = 0; i < value->strides.size(); ++i) {
std::ostringstream os;
os << arg_name << ".strides[" << i << "]";
this->Bind(arg->strides[i + diff], value->strides[i], os.str());
}
}
} else {
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
}
this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset");
}
......
......@@ -71,10 +71,12 @@ class ArgBinder {
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
* \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1.
*/
void BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name);
const std::string& arg_name,
bool fuzzy_match);
/*!
* \brief Bind symbolic buffer to a DLTensor handle.
* \param buffer The argument buffer to be binded.
......
/*!
* Copyright (c) 2017 by Contributors
* \file ir_deep_compare.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
namespace tvm {
namespace ir {
using ExprComparator = ExprFunctor<void(const Expr& n, const Expr &other)>;
using StmtComparator = StmtFunctor<void(const Stmt& n, const Stmt &other)>;
// Helper macro: defines the comparator override for a binary
// arithmetic/logic expression node by comparing operand a, then b.
#define DEFINE_BIOP_EXPR_CMP_(OP)                            \
  void VisitExpr_(const OP* op, const Expr& other) final {   \
    const OP* rhs = other.as<OP>();                          \
    if (CompareExpr(op->a, rhs->a) != 0) return;             \
    if (CompareExpr(op->b, rhs->b) != 0) return;             \
  }
// Deep comparison to check if two IR graph are equivalent
//
// The comparator defines an order over IR nodes (order_ in {-1, 0, +1});
// Equal() reports whether the order is 0.  When tie_def_ is set, variable
// definitions on the left are mapped to the corresponding definitions on
// the right (vmap_), so alpha-equivalent terms such as (let x in x + 1)
// and (let y in y + 1) compare equal; in that mode only the
// equality/non-equality result is meaningful, not the ordering.
class IRDeepCompare :
      public ExprComparator, public StmtComparator {
 public:
  // Equality comparison
  /*! \brief Deep (alpha-)equality of two statements. */
  bool Equal(const Stmt& lhs, const Stmt& rhs) {
    // Reset state so one comparator instance can be reused safely
    // (previously stale order_/vmap_ leaked across calls).
    order_ = 0;
    vmap_.clear();
    tie_def_ = true;
    VisitStmt(lhs, rhs);
    return order_ == 0;
  }
  /*! \brief Deep (alpha-)equality of two expressions. */
  bool Equal(const Expr& lhs, const Expr& rhs) {
    order_ = 0;
    vmap_.clear();
    tie_def_ = true;
    VisitExpr(lhs, rhs);
    return order_ == 0;
  }
  void VisitExpr(const Expr& n, const Expr& other) override {
    if (order_ != 0) return;  // an earlier mismatch already decided the order
    if (CompareValue(n->type_index(), other->type_index()) != 0) return;
    if (CompareType(n.type(), other.type()) != 0) return;
    ExprComparator::VisitExpr(n, other);
  }
  void VisitStmt(const Stmt& n, const Stmt& other) override {
    if (order_ != 0) return;
    if (CompareValue(n->type_index(), other->type_index()) != 0) return;
    StmtComparator::VisitStmt(n, other);
  }
  // Stmt
  void VisitStmt_(const LetStmt* op, const Stmt& other) final {
    const LetStmt* rhs = other.as<LetStmt>();
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (tie_def_) {
      // Tie the two bound variables so the bodies compare modulo renaming.
      vmap_[op->var.get()] = rhs->var.get();
    } else {
      if (CompareExpr(op->var, rhs->var) != 0) return;
    }
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }
  void VisitStmt_(const AttrStmt* op, const Stmt& other) final {
    const AttrStmt* rhs = other.as<AttrStmt>();
    if (CompareString(op->attr_key, rhs->attr_key) != 0) return;
    if (CompareNodeRef(op->node, rhs->node) != 0) return;
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }
  void VisitStmt_(const IfThenElse* op, const Stmt& other) final {
    const IfThenElse* rhs = other.as<IfThenElse>();
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareStmt(op->then_case, rhs->then_case) != 0) return;
    if (CompareStmt(op->else_case, rhs->else_case) != 0) return;
  }
  void VisitStmt_(const For* op, const Stmt& other) final {
    const For* rhs = other.as<For>();
    if (CompareExpr(op->min, rhs->min) != 0) return;
    if (CompareExpr(op->extent, rhs->extent) != 0) return;
    if (tie_def_) {
      vmap_[op->loop_var.get()] = rhs->loop_var.get();
    } else {
      if (CompareExpr(op->loop_var, rhs->loop_var) != 0) return;
    }
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }
  void VisitStmt_(const Allocate* op, const Stmt& other) final {
    const Allocate* rhs = other.as<Allocate>();
    if (tie_def_) {
      vmap_[op->buffer_var.get()] = rhs->buffer_var.get();
    } else {
      if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
    }
    if (CompareType(op->type, rhs->type) != 0) return;
    if (CompareArray(op->extents, rhs->extents) != 0) return;
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
    if (CompareExpr(op->new_expr, rhs->new_expr) != 0) return;
    if (CompareString(op->free_function, rhs->free_function) != 0) return;
  }
  void VisitStmt_(const Store* op, const Stmt& other) final {
    const Store* rhs = other.as<Store>();
    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareExpr(op->index, rhs->index) != 0) return;
    if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
  }
  void VisitStmt_(const Free* op, const Stmt& other) final {
    const Free* rhs = other.as<Free>();
    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
  }
  void VisitStmt_(const AssertStmt* op, const Stmt& other) final {
    const AssertStmt* rhs = other.as<AssertStmt>();
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareExpr(op->message, rhs->message) != 0) return;
  }
  void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
    const ProducerConsumer* rhs = other.as<ProducerConsumer>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->is_producer, rhs->is_producer) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }
  void VisitStmt_(const Provide* op, const Stmt& other) final {
    const Provide* rhs = other.as<Provide>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareArray(op->args, rhs->args) != 0) return;
  }
  void VisitStmt_(const Realize* op, const Stmt& other) final {
    const Realize* rhs = other.as<Realize>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    if (CompareType(op->type, rhs->type) != 0) return;
    if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
    if (CompareStmt(op->body, rhs->body) != 0) return;
  }
  void VisitStmt_(const Prefetch* op, const Stmt& other) final {
    const Prefetch* rhs = other.as<Prefetch>();
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    if (CompareType(op->type, rhs->type) != 0) return;
    if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
  }
  void VisitStmt_(const Block* op, const Stmt& other) final {
    const Block* rhs = other.as<Block>();
    if (CompareStmt(op->first, rhs->first) != 0) return;
    if (CompareStmt(op->rest, rhs->rest) != 0) return;
  }
  void VisitStmt_(const Evaluate* op, const Stmt& other) final {
    const Evaluate* rhs = other.as<Evaluate>();
    CompareExpr(op->value, rhs->value);
  }
  // Exprs
  void VisitExpr_(const Variable* op, const Expr& other) final {
    const Variable* rhs = other.as<Variable>();
    // Apply the definition-tying remap before the pointer comparison.
    auto it = vmap_.find(op);
    if (it != vmap_.end()) op = it->second;
    if (op < rhs) {
      order_ = -1;
    } else if (op > rhs) {
      order_ = +1;
    }
  }
  void VisitExpr_(const Load* op, const Expr& other) final {
    const Load* rhs = other.as<Load>();
    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
    if (CompareExpr(op->index, rhs->index) != 0) return;
    if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
  }
  void VisitExpr_(const Let* op, const Expr& other) final {
    const Let* rhs = other.as<Let>();
    if (tie_def_) {
      vmap_[op->var.get()] = rhs->var.get();
    } else {
      if (CompareExpr(op->var, rhs->var) != 0) return;
    }
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareExpr(op->body, rhs->body) != 0) return;
  }
  void VisitExpr_(const Call* op, const Expr& other) final {
    const Call* rhs = other.as<Call>();
    // Use the explicit `!= 0` convention of the rest of the file
    // (behavior unchanged; CompareString/CompareArray return order_).
    if (CompareString(op->name, rhs->name) != 0) return;
    if (CompareArray(op->args, rhs->args) != 0) return;
    if (CompareValue(op->call_type, rhs->call_type) != 0) return;
    if (CompareNodeRef(op->func, rhs->func) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
  }
  void VisitExpr_(const Reduce *op, const Expr& other) final {
    const Reduce* rhs = other.as<Reduce>();
    if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
    if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
    for (size_t i = 0; i < op->axis.size(); ++i) {
      if (CompareExpr(op->axis[i]->dom->min, rhs->axis[i]->dom->min) != 0) return;
      if (CompareExpr(op->axis[i]->dom->extent, rhs->axis[i]->dom->extent) != 0) return;
      if (tie_def_) {
        vmap_[op->axis[i]->var.get()] = rhs->axis[i]->var.get();
      } else {
        if (CompareExpr(op->axis[i]->var, rhs->axis[i]->var) != 0) return;
      }
    }
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareArray(op->source, rhs->source) != 0) return;
  }
  void VisitExpr_(const IntImm *op, const Expr& other) final {
    CompareValue(op->value, other.as<IntImm>()->value);
  }
  void VisitExpr_(const UIntImm *op, const Expr& other) final {
    CompareValue(op->value, other.as<UIntImm>()->value);
  }
  void VisitExpr_(const FloatImm *op, const Expr& other) final {
    CompareValue(op->value, other.as<FloatImm>()->value);
  }
  void VisitExpr_(const StringImm *op, const Expr& other) final {
    CompareString(op->value, other.as<StringImm>()->value);
  }
  void VisitExpr_(const Cast *op, const Expr& other) final {
    CompareExpr(op->value, other.as<Cast>()->value);
  }
  void VisitExpr_(const Not *op, const Expr& other) final {
    CompareExpr(op->a, other.as<Not>()->a);
  }
  void VisitExpr_(const Select *op, const Expr& other) final {
    const Select* rhs = other.as<Select>();
    if (CompareExpr(op->condition, rhs->condition) != 0) return;
    if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
    if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
  }
  void VisitExpr_(const Ramp *op, const Expr& other) final {
    const Ramp* rhs = other.as<Ramp>();
    if (CompareExpr(op->base, rhs->base) != 0) return;
    if (CompareExpr(op->stride, rhs->stride) != 0) return;
    if (CompareValue(op->lanes, rhs->lanes) != 0) return;
  }
  void VisitExpr_(const Broadcast *op, const Expr& other) final {
    const Broadcast* rhs = other.as<Broadcast>();
    if (CompareExpr(op->value, rhs->value) != 0) return;
    if (CompareValue(op->lanes, rhs->lanes) != 0) return;
  }
  void VisitExpr_(const Shuffle *op, const Expr& other) final {
    const Shuffle* rhs = other.as<Shuffle>();
    if (CompareArray(op->vectors, rhs->vectors) != 0) return;
    if (CompareArray(op->indices, rhs->indices) != 0) return;
  }
  DEFINE_BIOP_EXPR_CMP_(Add)
  DEFINE_BIOP_EXPR_CMP_(Sub)
  DEFINE_BIOP_EXPR_CMP_(Mul)
  DEFINE_BIOP_EXPR_CMP_(Div)
  DEFINE_BIOP_EXPR_CMP_(Mod)
  DEFINE_BIOP_EXPR_CMP_(Min)
  DEFINE_BIOP_EXPR_CMP_(Max)
  DEFINE_BIOP_EXPR_CMP_(EQ)
  DEFINE_BIOP_EXPR_CMP_(NE)
  DEFINE_BIOP_EXPR_CMP_(LT)
  DEFINE_BIOP_EXPR_CMP_(LE)
  DEFINE_BIOP_EXPR_CMP_(GT)
  DEFINE_BIOP_EXPR_CMP_(GE)
  DEFINE_BIOP_EXPR_CMP_(And)
  DEFINE_BIOP_EXPR_CMP_(Or)

 private:
  // Compare two sub-expressions; undefined sorts before defined.
  int CompareExpr(const Expr& lhs, const Expr& rhs) {
    if (order_ != 0) return order_;
    if (!lhs.defined() && rhs.defined()) {
      order_ = -1; return order_;
    }
    if (!rhs.defined() && lhs.defined()) {
      order_ = +1; return order_;
    }
    VisitExpr(lhs, rhs);
    return order_;
  }
  // Compare two sub-statements; undefined sorts before defined.
  int CompareStmt(const Stmt& lhs, const Stmt& rhs) {
    if (order_ != 0) return order_;
    if (!lhs.defined() && rhs.defined()) {
      order_ = -1; return order_;
    }
    if (!rhs.defined() && lhs.defined()) {
      order_ = +1; return order_;
    }
    VisitStmt(lhs, rhs);
    return order_;
  }
  // Lexicographic comparison: size first, then element-wise.
  int CompareArray(const Array<Expr>& lhs, const Array<Expr>& rhs) {
    if (order_ != 0) return order_;
    if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
    for (size_t i = 0; i < lhs.size(); ++i) {
      if (CompareExpr(lhs[i], rhs[i]) != 0) return order_;
    }
    return order_;
  }
  int CompareRegion(const Halide::Internal::Region& lhs,
                    const Halide::Internal::Region& rhs) {
    if (order_ != 0) return order_;
    if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
    for (size_t i = 0; i < lhs.size(); ++i) {
      if (CompareExpr(lhs[i]->min, rhs[i]->min) != 0) return order_;
      if (CompareExpr(lhs[i]->extent, rhs[i]->extent) != 0) return order_;
    }
    return order_;
  }
  // Node references compare by pointer identity only.
  int CompareNodeRef(const NodeRef& lhs, const NodeRef& rhs) {
    if (order_ != 0) return order_;
    if (lhs.get() < rhs.get()) {
      order_ = -1; return order_;
    }
    if (lhs.get() > rhs.get()) {
      order_ = +1; return order_;
    }
    return order_;
  }
  int CompareType(const Type& lhs, const Type& rhs) {
    if (order_ != 0) return order_;
    if (lhs == rhs) return order_;
    if (CompareValue(lhs.code(), rhs.code()) != 0) return order_;
    if (CompareValue(lhs.bits(), rhs.bits()) != 0) return order_;
    if (CompareValue(lhs.lanes(), rhs.lanes()) != 0) return order_;
    return order_;
  }
  int CompareString(const std::string& lhs, const std::string& rhs) {
    if (order_ != 0) return order_;
    order_ = lhs.compare(rhs);
    return order_;
  }
  template<typename T>
  int CompareValue(const T& lhs, const T& rhs) {
    if (order_ != 0) return order_;
    if (lhs < rhs) {
      order_ = -1; return order_;
    } else if (lhs > rhs) {
      order_ = +1; return order_;
    }
    return order_;
  }
  int CompareCommReducer(const CommReducer& lhs, const CommReducer& rhs) {
    if (order_ != 0) return order_;
    if (lhs == rhs) return order_;
    if (CompareValue(lhs->lhs.size(), rhs->lhs.size()) != 0) return order_;
    if (CompareValue(lhs->rhs.size(), rhs->rhs.size()) != 0) return order_;
    // Use a nested comparator so the placeholder remapping does not
    // pollute this instance's vmap_.
    IRDeepCompare cmp;
    // BUGFIX: propagate the tying mode so definitions inside the reducer
    // result expressions are matched consistently with the outer compare
    // (previously the nested comparator always ran with tie_def_ = false).
    cmp.tie_def_ = tie_def_;
    if (tie_def_) {
      for (size_t i = 0; i < lhs->lhs.size(); ++i) {
        cmp.vmap_[lhs->lhs[i].get()] = rhs->lhs[i].get();
      }
      for (size_t i = 0; i < lhs->rhs.size(); ++i) {
        cmp.vmap_[lhs->rhs[i].get()] = rhs->rhs[i].get();
      }
    } else {
      for (size_t i = 0; i < lhs->lhs.size(); ++i) {
        if (CompareExpr(lhs->lhs[i], rhs->lhs[i]) != 0) return order_;
      }
      // BUGFIX: iterate the rhs placeholder arrays with their own size
      // (the original loop bound was lhs->lhs.size()).
      for (size_t i = 0; i < lhs->rhs.size(); ++i) {
        if (CompareExpr(lhs->rhs[i], rhs->rhs[i]) != 0) return order_;
      }
    }
    order_ = cmp.CompareArray(lhs->result, rhs->result);
    return order_;
  }
  // The order flag, smaller, -1, bigger: +1, equal: 0
  int order_{0};
  // Whether tie intermediate definitions.
  // This allows use to tie definitions of two variables together.
  // This enables us to assert equal between (let x in x + 1), (let y in y + 1)
  // However, the comparison is no longer in total order.
  // Only equality/non-equality information is valid.
  bool tie_def_{false};
  // varaible remap if any
  std::unordered_map<const Variable*, const Variable*> vmap_;
};
bool Equal(const Stmt& lhs, const Stmt& rhs) {
return IRDeepCompare().Equal(lhs, rhs);
}
bool Equal(const Expr& lhs, const Expr& rhs) {
return IRDeepCompare().Equal(lhs, rhs);
}
} // namespace ir
} // namespace tvm
......@@ -195,7 +195,7 @@ class StorageFlattener : public IRMutator {
}
// start binding
ArgBinder binder(&var_remap_);
binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name);
binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name, true);
// Apply the remaps
Stmt body = MergeNest(binder.asserts(), op->body);
body = MergeNest(binder.init_nest(), body);
......
import tvm
def test_equal_expr():
    """Expr-level deep compare: identical builds are Equal, different are not."""
    x = tvm.var('x')
    y = tvm.var('y')

    def make_simple():
        # Plain arithmetic over x and y.
        return x + y + 1

    def make_compound():
        # Nested expression exercising call/mul/div nodes.
        return tvm.exp((x + y + 1) * y / 4)

    assert tvm.ir_pass.Equal(make_simple(), make_simple())
    assert tvm.ir_pass.Equal(make_compound(), make_compound())
    assert not tvm.ir_pass.Equal(make_compound(), make_simple())
def test_equal_compute():
    # Stmt/Reduce-level deep compare: two independently built copies of
    # the same IR must compare Equal (alpha-equivalence of fresh vars).
    x = tvm.var('x')
    y = tvm.var('y')
    n = 128
    A = tvm.placeholder((n, n), name='A')
    B = tvm.placeholder((n, n), name='B')
    ii = tvm.var('i')
    jj = tvm.var('j')
    def func1():
        # Matmul-style dot product; a fresh reduce axis on every call,
        # so Equal must tie the two axis definitions together.
        k = tvm.reduce_axis((0, n), name='k')
        return tvm.sum(A[ii, k] * B[jj, k], axis=k)
    Ab = tvm.decl_buffer((n,), name='A')
    n = tvm.var("n")  # NOTE: rebinds n to a symbolic var used as func2's loop extent
    def func2():
        # Imperative IR via ir_builder: two sequential loops updating one buffer.
        ib = tvm.ir_builder.create()
        A = ib.buffer_ptr(Ab)
        with ib.for_range(0, n, name="i") as i:
            A[i] = A[i] + 1
        with ib.for_range(0, 10, name="j") as j:
            A[j] = A[j] + 2
        return ib.get()
    assert tvm.ir_pass.Equal(func1(), func1())
    assert tvm.ir_pass.Equal(func2(), func2())
# Allow running this test file directly.
if __name__ == "__main__":
    test_equal_expr()
    test_equal_compute()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment