Commit 330d49f8 by Tianqi Chen Committed by GitHub

[IR] Update new version of HalideIR (#116)

parent d3c8256b
Subproject commit 398edacd956c6de82185821ffd9f482598182e51 Subproject commit 4fffc62c124651c1cde18f31957db413b677d601
...@@ -174,6 +174,14 @@ namespace intrinsic { ...@@ -174,6 +174,14 @@ namespace intrinsic {
/*! /*!
* \brief See pesudo code * \brief See pesudo code
* *
* Handle tvm_address_of(Load *op) {
* return &op->buffer_var[index];
* }
*/
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
* \brief See pesudo code
*
* Type tvm_struct_get(StructType* arr, int index, int field_id) { * Type tvm_struct_get(StructType* arr, int index, int field_id) {
* return arr[index]->field; * return arr[index]->field;
* } * }
...@@ -355,6 +363,7 @@ using Halide::Internal::Realize; ...@@ -355,6 +363,7 @@ using Halide::Internal::Realize;
using Halide::Internal::Block; using Halide::Internal::Block;
using Halide::Internal::IfThenElse; using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate; using Halide::Internal::Evaluate;
using Halide::Internal::Shuffle;
// ir functions // ir functions
using Halide::Internal::is_const_power_of_two_integer; using Halide::Internal::is_const_power_of_two_integer;
......
...@@ -98,6 +98,7 @@ class IRMutator { ...@@ -98,6 +98,7 @@ class IRMutator {
virtual Expr Mutate_(const UIntImm* op, const Expr& e); virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e); virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const StringImm* op, const Expr& e);
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
}; };
} // namespace ir } // namespace ir
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#define TVM_IR_PASS_H_ #define TVM_IR_PASS_H_
#include <ir/IREquality.h> #include <ir/IREquality.h>
#include <pass/Simplify.h> #include <arithmetic/Simplify.h>
#include <tvm/ir_functor.h> #include <tvm/ir_functor.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
......
...@@ -26,6 +26,26 @@ TVM_REGISTER_API("make.For") ...@@ -26,6 +26,26 @@ TVM_REGISTER_API("make.For")
args[5]); args[5]);
}); });
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Type t = args[0];
if (args.size() == 3) {
*ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
} else {
*ret = Load::make(t, args[1], args[2], args[3]);
}
});
TVM_REGISTER_API("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Expr value = args[1];
if (args.size() == 3) {
*ret = Store::make(args[0], value, args[2], const_true(value.type().lanes()));
} else {
*ret = Store::make(args[0], value, args[2], args[3]);
}
});
TVM_REGISTER_API("make.Realize") TVM_REGISTER_API("make.Realize")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Realize::make(args[0], *ret = Realize::make(args[0],
...@@ -47,15 +67,6 @@ TVM_REGISTER_API("make.Call") ...@@ -47,15 +67,6 @@ TVM_REGISTER_API("make.Call")
args[5]); args[5]);
}); });
TVM_REGISTER_API("make.Allocate")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Allocate::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
TVM_REGISTER_API("make.CommReducer") TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]); *ret = CommReducerNode::make(args[0], args[1], args[2]);
...@@ -87,6 +98,12 @@ TVM_REGISTER_API("make.CommReducer") ...@@ -87,6 +98,12 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3]); \ *ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \ }) \
#define REGISTER_MAKE5(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \ #define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API("make."#Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
...@@ -125,8 +142,7 @@ REGISTER_MAKE3(Let); ...@@ -125,8 +142,7 @@ REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt); REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Load); REGISTER_MAKE5(Allocate);
REGISTER_MAKE3(Store);
REGISTER_MAKE4(Provide); REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free); REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block); REGISTER_MAKE2(Block);
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <pass/Interval.h> #include <arithmetic/Interval.h>
#include <limits> #include <limits>
namespace tvm { namespace tvm {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <pass/Interval.h> #include <arithmetic/Interval.h>
#include <unordered_map> #include <unordered_map>
#include "./compute_expr.h" #include "./compute_expr.h"
#include "./int_set_internal.h" #include "./int_set_internal.h"
......
...@@ -471,7 +471,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -471,7 +471,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
PrintBinaryIntrinsitc(op, " << ", os, this); PrintBinaryIntrinsitc(op, " << ", os, this);
} else if (op->is_intrinsic(Call::shift_right)) { } else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, " >> ", os, this); PrintBinaryIntrinsitc(op, " >> ", os, this);
} else if (op->is_intrinsic(Call::address_of)) { } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
os << "(("; os << "((";
...@@ -535,6 +535,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -535,6 +535,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index); std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index);
os << ref; os << ref;
} else { } else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
Expr base; Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) { if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base); std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
...@@ -575,6 +577,8 @@ void CodeGenC::VisitStmt_(const Store* op) { ...@@ -575,6 +577,8 @@ void CodeGenC::VisitStmt_(const Store* op) {
this->PrintIndent(); this->PrintIndent();
stream << ref << " = " << value << ";\n"; stream << ref << " = " << value << ";\n";
} else { } else {
CHECK(is_one(op->predicate))
<< "Predicated store is not supported";
Expr base; Expr base;
if (TryGetRamp1Base(op->index, t.lanes(), &base)) { if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
std::string value = this->PrintExpr(op->value); std::string value = this->PrintExpr(op->value);
......
...@@ -702,7 +702,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { ...@@ -702,7 +702,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return builder_->CreateLShr( return builder_->CreateLShr(
MakeValue(op->args[0]), MakeValue(op->args[1])); MakeValue(op->args[0]), MakeValue(op->args[1]));
} }
} else if (op->is_intrinsic(Call::address_of)) { } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
return CreateBufferPtr( return CreateBufferPtr(
...@@ -752,7 +752,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { ...@@ -752,7 +752,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
} else { } else {
LOG(FATAL) << "Unknown stack alloca type " << type; LOG(FATAL) << "Unknown stack alloca type " << type;
} }
} else if (op->is_intrinsic(Call::null_handle)) { } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_); return llvm::Constant::getNullValue(t_void_p_);
} else { } else {
LOG(FATAL) << "Unknown intrinstic " << op->name; LOG(FATAL) << "Unknown intrinstic " << op->name;
...@@ -1077,6 +1077,8 @@ llvm::Value* CodeGenLLVM::CreateVecConcat( ...@@ -1077,6 +1077,8 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(
} }
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
CHECK(is_one(op->predicate))
<< "Predicated Load is not supported";
Type t = op->type; Type t = op->type;
const Ramp* ramp = op->index.as<Ramp>(); const Ramp* ramp = op->index.as<Ramp>();
llvm::Value* buf = GetVarValue(op->buffer_var.get()); llvm::Value* buf = GetVarValue(op->buffer_var.get());
...@@ -1135,12 +1137,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { ...@@ -1135,12 +1137,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
t, op->buffer_var, t, op->buffer_var,
Ramp::make(arith::ComputeExpr<Add>( Ramp::make(arith::ComputeExpr<Add>(
ramp->base, make_const(bt, first_shift)), ramp->base, make_const(bt, first_shift)),
make_const(bt, 1), ramp->lanes))); make_const(bt, 1), ramp->lanes),
const_true(t.lanes())));
llvm::Value* next = MakeValue(Load::make( llvm::Value* next = MakeValue(Load::make(
t, op->buffer_var, t, op->buffer_var,
Ramp::make(arith::ComputeExpr<Add>( Ramp::make(arith::ComputeExpr<Add>(
ramp->base, make_const(bt, ramp->lanes + next_shift)), ramp->base, make_const(bt, ramp->lanes + next_shift)),
make_const(bt, 1), ramp->lanes))); make_const(bt, 1), ramp->lanes),
const_true(t.lanes())));
// shuffle // shuffle
std::vector<llvm::Constant*> indices; std::vector<llvm::Constant*> indices;
int target_index = 0; int target_index = 0;
...@@ -1170,7 +1174,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { ...@@ -1170,7 +1174,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
make_const(ramp->base.type(), 1), make_const(ramp->base.type(), 1),
lanes); lanes);
// load value then flip // load value then flip
llvm::Value* v = MakeValue(Load::make(t, op->buffer_var, neg_ramp)); llvm::Value* v = MakeValue(
Load::make(t, op->buffer_var, neg_ramp, const_true(t.lanes())));
return CreateVecFlip(v); return CreateVecFlip(v);
} else { } else {
llvm::Value* ret = llvm::UndefValue::get(LLVMType(t)); llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
...@@ -1187,6 +1192,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { ...@@ -1187,6 +1192,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
// stmts // stmts
void CodeGenLLVM::VisitStmt_(const Store* op) { void CodeGenLLVM::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate))
<< "Predicated Load is not supported";
llvm::Value* value = MakeValue(op->value); llvm::Value* value = MakeValue(op->value);
Type t = op->value.type(); Type t = op->value.type();
const Ramp* ramp = op->index.as<Ramp>(); const Ramp* ramp = op->index.as<Ramp>();
......
...@@ -121,7 +121,7 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) { ...@@ -121,7 +121,7 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) {
} }
void CodeGenStackVM::VisitExpr_(const Call* op) { void CodeGenStackVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(Call::address_of)) { if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
...@@ -129,8 +129,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) { ...@@ -129,8 +129,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes()); this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD); this->PushOp(StackVM::ADDR_ADD);
} else if (op->is_intrinsic(Call::null_handle)) { } else if (op->is_intrinsic(Call::reinterpret)) {
this->PushOp(StackVM::PUSH_I64, 0); this->Push(op->args[0]);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value; int kind = op->args[2].as<IntImm>()->value;
......
...@@ -217,11 +217,13 @@ class PipelineExtractor: public IRVisitor { ...@@ -217,11 +217,13 @@ class PipelineExtractor: public IRVisitor {
if (is_zero(op->index) && load) { if (is_zero(op->index) && load) {
compute->body = Store::make( compute->body = Store::make(
op->buffer_var, op->buffer_var,
Load::make(load->type, load->buffer_var, repl.Mutate(load->index)), Load::make(load->type, load->buffer_var,
op->index); repl.Mutate(load->index), op->predicate),
op->index, op->predicate);
} else { } else {
compute->body = Store::make( compute->body = Store::make(
op->buffer_var, repl.Mutate(op->value), repl.Mutate(op->index)); op->buffer_var, repl.Mutate(op->value),
repl.Mutate(op->index), op->predicate);
} }
compute->inputs = repl.inputs_; compute->inputs = repl.inputs_;
pipeline_->stages.push_back(ComputeBlock(compute)); pipeline_->stages.push_back(ComputeBlock(compute));
......
...@@ -49,13 +49,16 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) { ...@@ -49,13 +49,16 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr Buffer::MakeLoad(Array<Expr> index) const { Expr Buffer::MakeLoad(Array<Expr> index) const {
const BufferNode* n = operator->(); const BufferNode* n = operator->();
return ir::Load::make(n->dtype, n->data, BufferOffset(n, index)); return ir::Load::make(
n->dtype, n->data, BufferOffset(n, index),
const_true(n->dtype.lanes()));
} }
Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const { Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const BufferNode* n = operator->(); const BufferNode* n = operator->();
CHECK_EQ(value.type(), n->dtype); CHECK_EQ(value.type(), n->dtype);
return ir::Store::make(n->data, value, BufferOffset(n, index)); return ir::Store::make(n->data, value, BufferOffset(n, index),
const_true(n->dtype.lanes()));
} }
Buffer BufferNode::make(std::string name, Buffer BufferNode::make(std::string name,
......
...@@ -254,19 +254,21 @@ Stmt MakeCrossThreadReduction( ...@@ -254,19 +254,21 @@ Stmt MakeCrossThreadReduction(
} }
} }
} }
Type t = reduce->type;
Expr pred = const_true(t.lanes());
Stmt reduce_body = Store::make(res_handle, Stmt reduce_body = Store::make(res_handle,
Call::make( Call::make(
reduce->type, reduce->type,
ir::intrinsic::tvm_thread_allreduce, ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic), freduce_args, Call::Intrinsic),
0); 0, pred);
reduce_body = AttrStmt::make( reduce_body = AttrStmt::make(
reduce->combiner, reduce->combiner,
attr::reduce_scope, attr::reduce_scope,
make_zero(reduce->type), make_zero(reduce->type),
reduce_body); reduce_body);
Stmt assign_body = Provide::make( Stmt assign_body = Provide::make(
stage->op, 0, Load::make(reduce->type, res_handle, 0), args); stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body); assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Allocate::make( Stmt body = Allocate::make(
......
...@@ -152,11 +152,7 @@ class VTInjector : public IRMutator { ...@@ -152,11 +152,7 @@ class VTInjector : public IRMutator {
return e; return e;
} }
Expr RewriteIndex(Expr index, Expr alloc_extent) const { Expr RewriteIndex(Expr index, Expr alloc_extent) const {
if (index_rewrite_strategy_ == 0) { return index + var_ * alloc_extent;
return index * num_threads_ + var_;
} else {
return index + var_ * alloc_extent;
}
} }
// Load // Load
Expr Mutate_(const Load* op, const Expr& e) final { Expr Mutate_(const Load* op, const Expr& e) final {
...@@ -168,7 +164,8 @@ class VTInjector : public IRMutator { ...@@ -168,7 +164,8 @@ class VTInjector : public IRMutator {
auto it = touched_alloc_.find(op->buffer_var.get()); auto it = touched_alloc_.find(op->buffer_var.get());
if (it != touched_alloc_.end()) { if (it != touched_alloc_.end()) {
return Load::make(op->type, op->buffer_var, return Load::make(op->type, op->buffer_var,
RewriteIndex(op->index, it->second)); RewriteIndex(op->index, it->second),
op->predicate);
} else { } else {
return expr; return expr;
} }
...@@ -184,7 +181,8 @@ class VTInjector : public IRMutator { ...@@ -184,7 +181,8 @@ class VTInjector : public IRMutator {
if (it != touched_alloc_.end()) { if (it != touched_alloc_.end()) {
return Store::make(op->buffer_var, return Store::make(op->buffer_var,
op->value, op->value,
RewriteIndex(op->index, it->second)); RewriteIndex(op->index, it->second),
op->predicate);
} else { } else {
return stmt; return stmt;
} }
...@@ -307,6 +305,9 @@ class VTInjector : public IRMutator { ...@@ -307,6 +305,9 @@ class VTInjector : public IRMutator {
for (size_t i = 1; i < extents.size(); ++i) { for (size_t i = 1; i < extents.size(); ++i) {
stride = arith::ComputeExpr<Mul>(stride, extents[i]); stride = arith::ComputeExpr<Mul>(stride, extents[i]);
} }
if (op->type.lanes() != 0) {
stride = stride * op->type.lanes();
}
Array<Expr> other; Array<Expr> other;
other.push_back(num_threads_); other.push_back(num_threads_);
for (Expr e : extents) { for (Expr e : extents) {
...@@ -368,8 +369,6 @@ class VTInjector : public IRMutator { ...@@ -368,8 +369,6 @@ class VTInjector : public IRMutator {
Var var_; Var var_;
// the threads/lanes // the threads/lanes
int num_threads_; int num_threads_;
// Index rewriting strategy
int index_rewrite_strategy_{1};
// whethe the loop is already injected. // whethe the loop is already injected.
bool vt_loop_injected_{false}; bool vt_loop_injected_{false};
// whether current expression get touched. // whether current expression get touched.
......
...@@ -143,10 +143,11 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { ...@@ -143,10 +143,11 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) { Expr pred = this->Mutate(op->predicate);
if (value.same_as(op->value) && index.same_as(op->index) && pred.same_as(op->predicate)) {
return s; return s;
} else { } else {
return Store::make(op->buffer_var, value, index); return Store::make(op->buffer_var, value, index, pred);
} }
} }
...@@ -263,10 +264,11 @@ Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { ...@@ -263,10 +264,11 @@ Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
Expr IRMutator::Mutate_(const Load *op, const Expr& e) { Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
if (index.same_as(op->index)) { Expr pred = this->Mutate(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return e; return e;
} else { } else {
return Load::make(op->type, op->buffer_var, index); return Load::make(op->type, op->buffer_var, index, pred);
} }
} }
...@@ -383,6 +385,15 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { ...@@ -383,6 +385,15 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) {
auto new_vec = MutateArray(op->vectors, this);
if (new_vec.same_as(op->vectors)) {
return e;
} else {
return Shuffle::make(new_vec, op->indices);
}
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \ Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
return e; \ return e; \
...@@ -422,7 +433,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) ...@@ -422,7 +433,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(IntImm) .DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm) .DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm) .DISPATCH_TO_MUTATE_EXPR(FloatImm)
.DISPATCH_TO_MUTATE_EXPR(StringImm); .DISPATCH_TO_MUTATE_EXPR(StringImm)
.DISPATCH_TO_MUTATE_EXPR(Shuffle);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -111,8 +111,10 @@ inline Expr TVMStructGet( ...@@ -111,8 +111,10 @@ inline Expr TVMStructGet(
*/ */
inline Expr AddressOffset(Var handle, Type dtype, int offset) { inline Expr AddressOffset(Var handle, Type dtype, int offset) {
return Call::make( return Call::make(
Handle(), Call::address_of, Handle(), intrinsic::tvm_address_of,
{Load::make(dtype, handle, make_const(Int(32), offset))}, Call::PureIntrinsic); {Load::make(dtype, handle, make_const(Int(32), offset * dtype.lanes()),
const_true(dtype.lanes()))},
Call::PureIntrinsic);
} }
/*! /*!
......
...@@ -81,11 +81,13 @@ void IRVisitor::Visit_(const Allocate *op) { ...@@ -81,11 +81,13 @@ void IRVisitor::Visit_(const Allocate *op) {
void IRVisitor::Visit_(const Load *op) { void IRVisitor::Visit_(const Load *op) {
this->Visit(op->index); this->Visit(op->index);
this->Visit(op->predicate);
} }
void IRVisitor::Visit_(const Store *op) { void IRVisitor::Visit_(const Store *op) {
this->Visit(op->value); this->Visit(op->value);
this->Visit(op->index); this->Visit(op->index);
this->Visit(op->predicate);
} }
void IRVisitor::Visit_(const IfThenElse *op) { void IRVisitor::Visit_(const IfThenElse *op) {
......
...@@ -99,7 +99,7 @@ class PackedCallBuilder : public IRMutator { ...@@ -99,7 +99,7 @@ class PackedCallBuilder : public IRMutator {
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back( prep_seq_.emplace_back(
Store::make(stack_shape_, Convert(Int(64), op->args[i]), Store::make(stack_shape_, Convert(Int(64), op->args[i]),
ConstInt32(stack_begin +i))); ConstInt32(stack_begin +i), const_true(1)));
} }
return AddressOffset(stack_shape_, Int(64), stack_begin); return AddressOffset(stack_shape_, Int(64), stack_begin);
} }
...@@ -169,7 +169,7 @@ class PackedCallBuilder : public IRMutator { ...@@ -169,7 +169,7 @@ class PackedCallBuilder : public IRMutator {
prep_seq_.emplace_back( prep_seq_.emplace_back(
Store::make(stack_tcode_, Store::make(stack_tcode_,
ConstInt32(arg_tcode), ConstInt32(arg_tcode),
stack_index)); stack_index, const_true(1)));
} }
// UPDATE stack value // UPDATE stack value
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
......
...@@ -143,9 +143,10 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -143,9 +143,10 @@ class ThreadAllreduceBuilder : public IRMutator {
int threadx_extent = 1; int threadx_extent = 1;
Expr reduce_index = FlattenThread(vred, &reduce_extent); Expr reduce_index = FlattenThread(vred, &reduce_extent);
Expr group_index = FlattenThread(vpar, &group_extent); Expr group_index = FlattenThread(vpar, &group_extent);
Expr pred = const_true(value.type().lanes());
if (reduce_extent == 1) { if (reduce_extent == 1) {
// special case, no reduction is needed. // special case, no reduction is needed.
return Store::make(op->buffer_var, value, 0); return Store::make(op->buffer_var, value, 0, pred);
} }
// Whether the threadIdx.x is involved in reduction. // Whether the threadIdx.x is involved in reduction.
if (vred[0].scope.dim_index == 0) { if (vred[0].scope.dim_index == 0) {
...@@ -155,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -155,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
std::vector<Stmt> seq; std::vector<Stmt> seq;
seq.emplace_back(Store::make( seq.emplace_back(Store::make(
shared_buf, value, shared_buf, value,
BufIndex(reduce_index, group_index, reduce_extent))); BufIndex(reduce_index, group_index, reduce_extent), pred));
seq.emplace_back(SyncThread("shared")); seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce( seq.emplace_back(MakeBufAllreduce(
combiner, value.type(), shared_buf, combiner, value.type(), shared_buf,
...@@ -164,11 +165,12 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -164,11 +165,12 @@ class ThreadAllreduceBuilder : public IRMutator {
load_remap_[op->buffer_var.get()] = load_remap_[op->buffer_var.get()] =
Load::make( Load::make(
value.type(), shared_buf, value.type(), shared_buf,
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent)); BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent),
pred);
alloc_remap_[op->buffer_var.get()] = alloc_remap_[op->buffer_var.get()] =
Allocate::make(shared_buf, value.type(), Allocate::make(shared_buf, value.type(),
{Expr(group_extent), Expr(reduce_extent)}, {Expr(group_extent), Expr(reduce_extent)},
const_true(), Evaluate::make(0)); pred, Evaluate::make(0));
return MergeSeq(seq); return MergeSeq(seq);
} }
// make allreduce. // make allreduce.
...@@ -192,9 +194,9 @@ class ThreadAllreduceBuilder : public IRMutator { ...@@ -192,9 +194,9 @@ class ThreadAllreduceBuilder : public IRMutator {
auto freduce = [&](int offset) { auto freduce = [&](int offset) {
Expr b = Load::make( Expr b = Load::make(
type, shared_buf, type, shared_buf,
BufIndex(reduce_index + offset, group_index, reduce_extent)); BufIndex(reduce_index + offset, group_index, reduce_extent), const_true());
Expr a = Load::make(type, shared_buf, buf_index); Expr a = Load::make(type, shared_buf, buf_index, const_true());
return Store::make(shared_buf, (*combiner)(a, b), buf_index); return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
}; };
// Step one, check for // Step one, check for
if (reduce_align > reduce_extent) { if (reduce_align > reduce_extent) {
......
...@@ -122,7 +122,8 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -122,7 +122,8 @@ LoweredFunc MakeAPI(Stmt body,
Var tcode(v_arg->name_hint + ".code", Int(32)); Var tcode(v_arg->name_hint + ".code", Int(32));
seq_init.emplace_back(LetStmt::make( seq_init.emplace_back(LetStmt::make(
tcode, Load::make( tcode, Load::make(
Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i)), nop)); Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)),
nop));
Type t = v_arg.type(); Type t = v_arg.type();
if (t.is_handle()) { if (t.is_handle()) {
std::ostringstream msg; std::ostringstream msg;
...@@ -191,7 +192,7 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -191,7 +192,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push(buf->shape[k], f_push(buf->shape[k],
cast(buf->shape[k].type(), cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_shape, Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k))), IntImm::make(Int(32), k), const_true(1))),
field_name.str()); field_name.str());
} }
// strides field // strides field
...@@ -212,7 +213,7 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -212,7 +213,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push(buf->strides[k], f_push(buf->strides[k],
cast(buf->shape[k].type(), cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_strides, Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k))), IntImm::make(Int(32), k), const_true(1))),
field_name.str()); field_name.str());
} }
} }
......
...@@ -75,7 +75,8 @@ class ChannelAccessIndexRewriter : public IRMutator { ...@@ -75,7 +75,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op = expr.as<Load>(); op = expr.as<Load>();
if (read_access_ && buf_var_ == op->buffer_var.get()) { if (read_access_ && buf_var_ == op->buffer_var.get()) {
return Load::make( return Load::make(
op->type, op->buffer_var, ir::Simplify(op->index - min_)); op->type, op->buffer_var, ir::Simplify(op->index - min_),
op->predicate);
} else { } else {
return expr; return expr;
} }
...@@ -85,7 +86,8 @@ class ChannelAccessIndexRewriter : public IRMutator { ...@@ -85,7 +86,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op = stmt.as<Store>(); op = stmt.as<Store>();
if (!read_access_ && buf_var_ == op->buffer_var.get()) { if (!read_access_ && buf_var_ == op->buffer_var.get()) {
return Store::make( return Store::make(
op->buffer_var, op->value, ir::Simplify(op->index - min_)); op->buffer_var, op->value, ir::Simplify(op->index - min_),
op->predicate);
} else { } else {
return stmt; return stmt;
} }
......
...@@ -170,12 +170,13 @@ class StageSplitter : public IRMutator { ...@@ -170,12 +170,13 @@ class StageSplitter : public IRMutator {
Expr index = Mutate(op->index); Expr index = Mutate(op->index);
Stmt provide = Store::make( Stmt provide = Store::make(
ch->handle_var, ch->handle_var,
Load::make(op->type, op->buffer_var, index), 0); Load::make(op->type, op->buffer_var, index, op->predicate),
0, op->predicate);
Stmt temp = nest_.back(); nest_.pop_back(); Stmt temp = nest_.back(); nest_.pop_back();
stages_.emplace_back(BuildStage(provide, ch)); stages_.emplace_back(BuildStage(provide, ch));
nest_.push_back(temp); nest_.push_back(temp);
fifo_map_[ch->handle_var.get()] = ch; fifo_map_[ch->handle_var.get()] = ch;
return Load::make(op->type, ch->handle_var, 0); return Load::make(op->type, ch->handle_var, 0, op->predicate);
} }
Stmt Split(Stmt stmt, const ProducerConsumer* env) { Stmt Split(Stmt stmt, const ProducerConsumer* env) {
......
...@@ -33,7 +33,7 @@ class StorageFlattener : public IRMutator { ...@@ -33,7 +33,7 @@ class StorageFlattener : public IRMutator {
op = stmt.as<Store>(); op = stmt.as<Store>();
auto it = extern_buf_remap_.find(op->buffer_var.get()); auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) { if (it != extern_buf_remap_.end()) {
return Store::make(it->second, op->value, op->index); return Store::make(it->second, op->value, op->index, op->predicate);
} else { } else {
return stmt; return stmt;
} }
...@@ -115,7 +115,7 @@ class StorageFlattener : public IRMutator { ...@@ -115,7 +115,7 @@ class StorageFlattener : public IRMutator {
op = expr.as<Load>(); op = expr.as<Load>();
auto it = extern_buf_remap_.find(op->buffer_var.get()); auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) { if (it != extern_buf_remap_.end()) {
return Load::make(op->type, it->second, op->index); return Load::make(op->type, it->second, op->index, op->predicate);
} else { } else {
return expr; return expr;
} }
......
...@@ -194,14 +194,14 @@ class StoragePlanRewriter : public IRMutator { ...@@ -194,14 +194,14 @@ class StoragePlanRewriter : public IRMutator {
op = stmt.as<Store>(); op = stmt.as<Store>();
auto it = alloc_map_.find(op->buffer_var.get()); auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt; if (it == alloc_map_.end()) return stmt;
return Store::make(it->second->alloc_var, op->value, op->index); return Store::make(it->second->alloc_var, op->value, op->index, op->predicate);
} }
Expr Mutate_(const Load* op, const Expr& e) final { Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e); Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>(); op = expr.as<Load>();
auto it = alloc_map_.find(op->buffer_var.get()); auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr; if (it == alloc_map_.end()) return expr;
return Load::make(op->type, it->second->alloc_var, op->index); return Load::make(op->type, it->second->alloc_var, op->index, op->predicate);
} }
Expr Mutate_(const Variable* op, const Expr& e) final { Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = alloc_map_.find(op); auto it = alloc_map_.find(op);
......
...@@ -100,7 +100,7 @@ class StorageSyncPlanner : public IRVisitor { ...@@ -100,7 +100,7 @@ class StorageSyncPlanner : public IRVisitor {
} }
} }
void Visit_(const Call* op) final { void Visit_(const Call* op) final {
if (op->is_intrinsic(Call::address_of)) { if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
IRVisitor::Visit_(l); IRVisitor::Visit_(l);
} else { } else {
......
...@@ -34,7 +34,8 @@ class VecAllocAccess : public IRMutator { ...@@ -34,7 +34,8 @@ class VecAllocAccess : public IRMutator {
op = expr.as<Load>(); op = expr.as<Load>();
if (op->buffer_var.get() == buf_) { if (op->buffer_var.get() == buf_) {
return Load::make(op->type, op->buffer_var, return Load::make(op->type, op->buffer_var,
op->index * var_lanes_ + var_); op->index * var_lanes_ + var_,
op->predicate);
} else { } else {
return expr; return expr;
} }
...@@ -46,7 +47,8 @@ class VecAllocAccess : public IRMutator { ...@@ -46,7 +47,8 @@ class VecAllocAccess : public IRMutator {
if (op->buffer_var.get() == buf_) { if (op->buffer_var.get() == buf_) {
return Store::make(op->buffer_var, return Store::make(op->buffer_var,
op->value, op->value,
op->index * var_lanes_ + var_); op->index * var_lanes_ + var_,
op->predicate);
} else { } else {
return stmt; return stmt;
} }
...@@ -160,11 +162,16 @@ class Vectorizer : public IRMutator { ...@@ -160,11 +162,16 @@ class Vectorizer : public IRMutator {
// Load // Load
Expr Mutate_(const Load* op, const Expr& e) final { Expr Mutate_(const Load* op, const Expr& e) final {
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
if (index.same_as(op->index)) { Expr pred = this->Mutate(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return e; return e;
} else { } else {
return Load::make(op->type.with_lanes(index.type().lanes()), int lanes = std::max(index.type().lanes(), pred.type().lanes());
op->buffer_var, index); return Load::make(
op->type.with_lanes(lanes),
op->buffer_var,
BroadcastTo(index, lanes),
BroadcastTo(pred, lanes));
} }
} }
// Let // Let
...@@ -201,13 +208,16 @@ class Vectorizer : public IRMutator { ...@@ -201,13 +208,16 @@ class Vectorizer : public IRMutator {
Stmt Mutate_(const Store* op, const Stmt& s) final { Stmt Mutate_(const Store* op, const Stmt& s) final {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate);
if (value.same_as(op->value) && index.same_as(op->index)) { if (value.same_as(op->value) && index.same_as(op->index)) {
return s; return s;
} else { } else {
int lanes = std::max(value.type().lanes(), index.type().lanes()); int lanes = std::max(value.type().lanes(), index.type().lanes());
lanes = std::max(lanes, pred.type().lanes());
return Store::make(op->buffer_var, return Store::make(op->buffer_var,
BroadcastTo(value, lanes), BroadcastTo(value, lanes),
BroadcastTo(index, lanes)); BroadcastTo(index, lanes),
BroadcastTo(pred, lanes));
} }
} }
// For // For
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <pass/CSE.h>
// Exercises the common-subexpression-elimination (CSE) pass shipped with
// HalideIR; cse_test() is self-checking and aborts on failure.
TEST(IR_PASS, CSE) {
  // Call the pass's built-in test directly with full qualification instead
  // of pulling the whole Halide::Internal namespace into scope.
  Halide::Internal::cse_test();
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <pass/Simplify.h> #include <arithmetic/Simplify.h>
TEST(IRSIMPLIFY, Basic) { TEST(IRSIMPLIFY, Basic) {
using namespace Halide::Internal; using namespace Halide::Internal;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment