Commit 8a66ac23 by Tianqi Chen Committed by GitHub

[PASS/OP/REFACTOR] IRDeepCompare, isolate computeop part, allow fuzzy bind (#218)

parent 8e2ea2c4
......@@ -137,6 +137,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
......
......@@ -23,14 +23,6 @@
namespace tvm {
namespace ir {
inline bool Equal(Expr a, Expr b) {
return Halide::Internal::equal(a, b);
}
inline bool Equal(Stmt a, Stmt b) {
return Halide::Internal::equal(a, b);
}
inline Expr Simplify(Expr a) {
return Halide::Internal::simplify(a);
}
......@@ -40,6 +32,22 @@ inline Stmt Simplify(Stmt a) {
}
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Expr& lhs, const Expr& rhs);
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Stmt& lhs, const Stmt& rhs);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
......
......@@ -9,6 +9,7 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./compute_op.h"
#include "./op_util.h"
#include "../schedule/message_passing.h"
......@@ -242,124 +243,6 @@ void MakeReduction(const ComputeOpNode* op,
}
}
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
for (const auto& kv : value_map) {
temp.Set(kv.first->var, kv.second);
}
return ir::Substitute(s, temp);
}
// Cross Thread reduction
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0;
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
return thread_red != 0;
}
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
Array<Expr> args;
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
auto conds = op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map);
size_t size = self->body.size();
CHECK_GT(size, 0);
std::vector<const Reduce*> reduces(size);
for (size_t i = 0; i < size; ++i) {
const Reduce* reduce = self->body[i].as<Reduce>();
CHECK(reduce);
reduces[i] = reduce;
}
Expr cond = reduces[0]->condition;
for (Expr v : conds) {
cond = cond && v;
}
Array<Expr> freduce_args;
freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
for (size_t i = 0; i < size; ++i) {
freduce_args.push_back(reduces[0]->source[i]);
}
freduce_args.push_back(cond);
std::vector<Var> res_handles(size);
for (size_t idx = 0; idx < size; ++idx) {
res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
freduce_args.push_back(res_handles[idx]);
}
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
IterVar tv = (*it).second->bind_thread;
freduce_args.push_back(tv->var);
}
}
}
// Checks for the thread.
std::vector<Expr> thread_head_check;
if (stage->store_predicate.defined()) {
thread_head_check.emplace_back(stage->store_predicate);
}
Stmt reduce_body = Evaluate::make(Call::make(
Handle(),
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic));
reduce_body = AttrStmt::make(
reduces[0]->combiner,
attr::reduce_scope,
make_zero(Handle()),
reduce_body);
std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
Type t = reduces[idx]->type;
assigns[idx] = Provide::make(
stage->op, idx,
Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = Block::make(assigns);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Block::make(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
body = Allocate::make(
res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
body = AttrStmt::make(
res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
}
body = Substitute(body, value_map);
return MergeNest(nest, body);
}
// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
......@@ -370,27 +253,56 @@ Stmt MakeProvide(const ComputeOpNode* op,
return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}
// loop nest structure for general compute
// This the the loop nest structured used in compute.
// Does not include the loop body.
struct ComputeLoopNest {
// The common number of loops between init and main
size_t num_common_loop;
// predicates for the initialize loop
std::vector<Expr> init_predicates;
// Initialization nest involved.
std::vector<std::vector<Stmt> > init_nest;
// Value map for the init code
std::unordered_map<IterVar, Expr> init_vmap;
// Predicates for the main update loop
std::vector<Expr> main_predicates;
// The general loop nest
std::vector<std::vector<Stmt> > main_nest;
// Value map for the IterVar.
std::unordered_map<IterVar, Expr> main_vmap;
};
Stmt MakeComputeStmt(const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
// grab the nest structure
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map);
// Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
if (self->reduce_axis.size() != 0) {
// make reduction.
Stmt init, provide;
Array<Tensor> source;
for (size_t i = 0; i < self->body.size(); ++i) {
source.push_back(stage->op.output(i));
}
MakeReduction(self, source, &init, &provide);
init = op::Substitute(init, n.init_vmap);
init = MergeNest(n.init_nest, init);
// common nest
std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > reduce(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = op::Substitute(provide, n.main_vmap);
provide = MergeNest(reduce, provide);
return MergeNest(common, Block::make(init, provide));
} else {
std::vector<Stmt> provides;
for (size_t i = 0; i < self->body.size(); ++i) {
provides.emplace_back(MakeProvide(self, stage->op.output(i)));
}
Stmt provide = op::Substitute(Block::make(provides), n.main_vmap);
return MergeNest(n.main_nest, provide);
}
}
ComputeLoopNest MakeComputeLoopNest(
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
if (IsCrossThreadReduction(this, stage)) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
} else {
return MakeComputeStmt(this, stage, dom_map);
}
}
ComputeLoopNest ComputeLoopNest::make(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
......@@ -446,51 +358,10 @@ ComputeLoopNest MakeComputeLoopNest(
e = likely(e);
}
} else {
ret.num_common_loop = ret.main_nest.size() - 1;
CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
ret.num_common_loop = stage->leaf_iter_vars.size();
}
// copy elison here.
return ret;
}
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
if (IsCrossThreadReduction(this, stage)) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
}
// grab the nest structure
ComputeLoopNest n = MakeComputeLoopNest(this, stage, dom_map);
// Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
if (this->reduce_axis.size() != 0) {
// make reduction.
Stmt init, provide;
Array<Tensor> source;
for (size_t i = 0; i < this->body.size(); ++i) {
source.push_back(stage->op.output(i));
}
MakeReduction(this, source, &init, &provide);
init = Substitute(init, n.init_vmap);
init = MergeNest(n.init_nest, init);
// common nest
std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > reduce(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = Substitute(provide, n.main_vmap);
provide = MergeNest(reduce, provide);
return MergeNest(common, Block::make(init, provide));
} else {
std::vector<Stmt> provides;
for (size_t i = 0; i < this->body.size(); ++i) {
provides.emplace_back(MakeProvide(this, stage->op.output(i)));
}
Stmt provide = Substitute(Block::make(provides), n.main_vmap);
return MergeNest(n.main_nest, provide);
}
}
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \brief Helper utilities to implement compute_op.
* \file compute_op.h
*/
#ifndef TVM_OP_COMPUTE_OP_H_
#define TVM_OP_COMPUTE_OP_H_
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <vector>
#include <unordered_map>
namespace tvm {
// loop nest structure for general compute
// This the the loop nest structured used in compute.
// Does not include the loop body.
struct ComputeLoopNest {
// The common number of loops between init and main
size_t num_common_loop;
// predicates for the initialize loop
std::vector<Expr> init_predicates;
// Initialization nest involved.
std::vector<std::vector<Stmt> > init_nest;
// Value map for the init code
std::unordered_map<IterVar, Expr> init_vmap;
// Predicates for the main update loop
std::vector<Expr> main_predicates;
// The general loop nest
std::vector<std::vector<Stmt> > main_nest;
// Value map for the IterVar.
std::unordered_map<IterVar, Expr> main_vmap;
/*!
* \brief constructor to build ComputeOpNest
* \param self The pointer to compute op.
* \param stage The scxhedule stage.
* \param dom_map The domain map.
* \return The constructed loop nest
*/
static ComputeLoopNest make(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
};
/*!
* \brief Whether compute op is a cross thread reduction structure.
* \param self The pointer to ComputeOpNode
* \param stage the schedule stage.
*/
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage);
/*!
* \brief Build body of compute for cross thread reduction pattern.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \return The created statement.
*/
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map);
} // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
/*!
* Copyright (c) 2017 by Contributors
* \brief Logics related to cross thread reduction, used by ComputeOpNode.
* \file cross_thread_reduction.cc
*/
#include <tvm/ir_pass.h>
#include "./compute_op.h"
#include "./op_util.h"
namespace tvm {
using namespace ir;
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0;
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
++thread_red;
} else {
++normal_red;
}
} else {
CHECK_EQ(thread_red, 0)
<< "Cross thread reduce cannot swap with normal data axis";
}
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
return thread_red != 0;
}
Stmt MakeCrossThreadReduction(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) {
Array<Expr> args;
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
auto conds = op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map);
size_t size = self->body.size();
CHECK_GT(size, 0);
std::vector<const Reduce*> reduces(size);
for (size_t i = 0; i < size; ++i) {
const Reduce* reduce = self->body[i].as<Reduce>();
CHECK(reduce);
reduces[i] = reduce;
}
Expr cond = reduces[0]->condition;
for (Expr v : conds) {
cond = cond && v;
}
Array<Expr> freduce_args;
freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
for (size_t i = 0; i < size; ++i) {
freduce_args.push_back(reduces[0]->source[i]);
}
freduce_args.push_back(cond);
std::vector<Var> res_handles(size);
for (size_t idx = 0; idx < size; ++idx) {
res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
freduce_args.push_back(res_handles[idx]);
}
for (IterVar iv : stage->leaf_iter_vars) {
if (iv->iter_type == kCommReduce) {
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end() &&
(*it).second->bind_thread.defined()) {
IterVar tv = (*it).second->bind_thread;
freduce_args.push_back(tv->var);
}
}
}
// Checks for the thread.
std::vector<Expr> thread_head_check;
if (stage->store_predicate.defined()) {
thread_head_check.emplace_back(stage->store_predicate);
}
Stmt reduce_body = Evaluate::make(Call::make(
Handle(),
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic));
reduce_body = AttrStmt::make(
reduces[0]->combiner,
attr::reduce_scope,
make_zero(Handle()),
reduce_body);
std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
Type t = reduces[idx]->type;
assigns[idx] = Provide::make(
stage->op, idx,
Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = Block::make(assigns);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Block::make(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
body = Allocate::make(
res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
body = AttrStmt::make(
res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
}
body = op::Substitute(body, value_map);
return MergeNest(nest, body);
}
} // namespace tvm
......@@ -223,7 +223,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
public:
......@@ -263,5 +262,16 @@ Expr ReplaceTensor(Expr expr,
Expr ret = repl.Mutate(expr);
return repl.found ? ret : expr;
}
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
std::unordered_map<const Variable*, Expr> init;
for (const auto& kv : value_map) {
init[kv.first->var.get()] = kv.second;
}
return ir::Substitute(s, init);
}
} // namespace op
} // namespace tvm
......@@ -12,6 +12,7 @@
#include <unordered_set>
#include <vector>
#include "../pass/ir_util.h"
#include "../pass/arg_binder.h"
namespace tvm {
namespace op {
......@@ -74,6 +75,15 @@ Stmt ReplaceTensor(Stmt stmt,
Expr ReplaceTensor(Expr expr,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Substitute the variables of stmt by value map.
* \param stmt the statment
* \param value_map The value map.
* \return Substituted result.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<IterVar, Expr>& value_map);
} // namespace op
} // namespace tvm
#endif // TVM_OP_OP_UTIL_H_
......@@ -75,13 +75,38 @@ void ArgBinder::BindArray(const Array<Expr>& arg,
void ArgBinder::BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name) {
const std::string& arg_name,
bool fuzzy_match) {
CHECK_EQ(arg->scope, value->scope)
<< "Argument " << arg_name
<< " Buffer bind scope mismatch";
this->Bind(arg->data, value->data, arg_name + ".data");
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
if (arg->shape.size() > value->shape.size()) {
CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
size_t diff = arg->shape.size() - value->shape.size();
for (size_t i = 0; i < diff; ++i) {
CHECK(is_one(arg->shape[i]))
<< "Argument " << arg_name << " shape mismatch"
<< arg->shape << " vs " << value->shape;
}
for (size_t i = 0; i < value->shape.size(); ++i) {
std::ostringstream os;
os << arg_name << ".shape[" << i << "]";
this->Bind(arg->shape[i + diff], value->shape[i], os.str());
}
if (arg->strides.size() != 0) {
CHECK_EQ(arg->strides.size(), arg->shape.size());
CHECK_EQ(value->strides.size(), value->shape.size());
for (size_t i = 0; i < value->strides.size(); ++i) {
std::ostringstream os;
os << arg_name << ".strides[" << i << "]";
this->Bind(arg->strides[i + diff], value->strides[i], os.str());
}
}
} else {
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
}
this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset");
}
......
......@@ -71,10 +71,12 @@ class ArgBinder {
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
* \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1.
*/
void BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name);
const std::string& arg_name,
bool fuzzy_match);
/*!
* \brief Bind symbolic buffer to a DLTensor handle.
* \param buffer The argument buffer to be binded.
......
......@@ -195,7 +195,7 @@ class StorageFlattener : public IRMutator {
}
// start binding
ArgBinder binder(&var_remap_);
binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name);
binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name, true);
// Apply the remaps
Stmt body = MergeNest(binder.asserts(), op->body);
body = MergeNest(binder.init_nest(), body);
......
import tvm
def test_equal_expr():
x = tvm.var('x')
y = tvm.var('y')
def func1():
return x + y + 1
def func2():
return tvm.exp((x + y + 1) * y / 4)
assert tvm.ir_pass.Equal(func1(), func1())
assert tvm.ir_pass.Equal(func2(), func2())
assert not tvm.ir_pass.Equal(func2(), func1())
def test_equal_compute():
x = tvm.var('x')
y = tvm.var('y')
n = 128
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((n, n), name='B')
ii = tvm.var('i')
jj = tvm.var('j')
def func1():
k = tvm.reduce_axis((0, n), name='k')
return tvm.sum(A[ii, k] * B[jj, k], axis=k)
Ab = tvm.decl_buffer((n,), name='A')
n = tvm.var("n")
def func2():
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
A[j] = A[j] + 2
return ib.get()
assert tvm.ir_pass.Equal(func1(), func1())
assert tvm.ir_pass.Equal(func2(), func2())
if __name__ == "__main__":
test_equal_expr()
test_equal_compute()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment