Commit 330d49f8 by Tianqi Chen Committed by GitHub

[IR] Update new version of HalideIR (#116)

parent d3c8256b
Subproject commit 398edacd956c6de82185821ffd9f482598182e51
Subproject commit 4fffc62c124651c1cde18f31957db413b677d601
......@@ -174,6 +174,14 @@ namespace intrinsic {
/*!
* \brief See pseudo code
*
* Handle tvm_address_of(Load *op) {
* return &op->buffer_var[index];
* }
*/
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
* \brief See pseudo code
*
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
* return arr[index]->field;
* }
......@@ -355,6 +363,7 @@ using Halide::Internal::Realize;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
using Halide::Internal::Shuffle;
// ir functions
using Halide::Internal::is_const_power_of_two_integer;
......
......@@ -98,6 +98,7 @@ class IRMutator {
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
};
} // namespace ir
......
......@@ -10,7 +10,7 @@
#define TVM_IR_PASS_H_
#include <ir/IREquality.h>
#include <pass/Simplify.h>
#include <arithmetic/Simplify.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
......
......@@ -26,6 +26,26 @@ TVM_REGISTER_API("make.For")
args[5]);
});
// Packed-function constructor for Load nodes, registered as "make.Load".
// Arguments: [0] result type, [1] buffer variable, [2] index,
// [3] predicate (used whenever the argument count is not exactly three).
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    Type dtype = args[0];
    if (args.size() != 3) {
      // The caller passed the predicate explicitly as the fourth argument.
      *ret = Load::make(dtype, args[1], args[2], args[3]);
      return;
    }
    // Three-argument form: use an all-true predicate with as many lanes
    // as the loaded type.
    *ret = Load::make(dtype, args[1], args[2], const_true(dtype.lanes()));
  });
// Packed-function constructor for Store nodes, registered as "make.Store".
// Arguments: [0] buffer variable, [1] value to store, [2] index,
// [3] predicate (used whenever the argument count is not exactly three).
TVM_REGISTER_API("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    Expr stored = args[1];
    if (args.size() != 3) {
      // The caller passed the predicate explicitly as the fourth argument.
      *ret = Store::make(args[0], stored, args[2], args[3]);
      return;
    }
    // Three-argument form: use an all-true predicate with as many lanes
    // as the stored value.
    *ret = Store::make(args[0], stored, args[2],
                       const_true(stored.type().lanes()));
  });
TVM_REGISTER_API("make.Realize")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Realize::make(args[0],
......@@ -47,15 +67,6 @@ TVM_REGISTER_API("make.Call")
args[5]);
});
// Registers "make.Allocate": forwards the five packed arguments, in order,
// to Allocate::make and stores the resulting node in *ret.
TVM_REGISTER_API("make.Allocate")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = Allocate::make(args[0], args[1], args[2], args[3], args[4]);
  });
TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]);
......@@ -87,6 +98,12 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
// Registers a packed function named "make.<Node>" that forwards its first
// five arguments, in order, to Node::make and stores the result in *ret.
// The last line intentionally has no trailing backslash: a dangling line
// continuation would splice whatever source line follows into the macro body.
#define REGISTER_MAKE5(Node)                                           \
  TVM_REGISTER_API("make."#Node)                                       \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                       \
      *ret = Node::make(args[0], args[1], args[2], args[3], args[4]);  \
    })
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
......@@ -125,8 +142,7 @@ REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Load);
REGISTER_MAKE3(Store);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
......
......@@ -8,7 +8,7 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <pass/Interval.h>
#include <arithmetic/Interval.h>
#include <limits>
namespace tvm {
......
......@@ -6,7 +6,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <pass/Interval.h>
#include <arithmetic/Interval.h>
#include <unordered_map>
#include "./compute_expr.h"
#include "./int_set_internal.h"
......
......@@ -471,7 +471,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
PrintBinaryIntrinsitc(op, " << ", os, this);
} else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, " >> ", os, this);
} else if (op->is_intrinsic(Call::address_of)) {
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
os << "((";
......@@ -535,6 +535,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index);
os << ref;
} else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
......@@ -575,6 +577,8 @@ void CodeGenC::VisitStmt_(const Store* op) {
this->PrintIndent();
stream << ref << " = " << value << ";\n";
} else {
CHECK(is_one(op->predicate))
<< "Predicated store is not supported";
Expr base;
if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
std::string value = this->PrintExpr(op->value);
......
......@@ -702,7 +702,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return builder_->CreateLShr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
}
} else if (op->is_intrinsic(Call::address_of)) {
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
return CreateBufferPtr(
......@@ -752,7 +752,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
} else if (op->is_intrinsic(Call::null_handle)) {
} else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_);
} else {
LOG(FATAL) << "Unknown intrinstic " << op->name;
......@@ -1077,6 +1077,8 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
CHECK(is_one(op->predicate))
<< "Predicated Load is not supported";
Type t = op->type;
const Ramp* ramp = op->index.as<Ramp>();
llvm::Value* buf = GetVarValue(op->buffer_var.get());
......@@ -1135,12 +1137,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
t, op->buffer_var,
Ramp::make(arith::ComputeExpr<Add>(
ramp->base, make_const(bt, first_shift)),
make_const(bt, 1), ramp->lanes)));
make_const(bt, 1), ramp->lanes),
const_true(t.lanes())));
llvm::Value* next = MakeValue(Load::make(
t, op->buffer_var,
Ramp::make(arith::ComputeExpr<Add>(
ramp->base, make_const(bt, ramp->lanes + next_shift)),
make_const(bt, 1), ramp->lanes)));
make_const(bt, 1), ramp->lanes),
const_true(t.lanes())));
// shuffle
std::vector<llvm::Constant*> indices;
int target_index = 0;
......@@ -1170,7 +1174,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
make_const(ramp->base.type(), 1),
lanes);
// load value then flip
llvm::Value* v = MakeValue(Load::make(t, op->buffer_var, neg_ramp));
llvm::Value* v = MakeValue(
Load::make(t, op->buffer_var, neg_ramp, const_true(t.lanes())));
return CreateVecFlip(v);
} else {
llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
......@@ -1187,6 +1192,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
// stmts
void CodeGenLLVM::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate))
<< "Predicated Load is not supported";
llvm::Value* value = MakeValue(op->value);
Type t = op->value.type();
const Ramp* ramp = op->index.as<Ramp>();
......
......@@ -121,7 +121,7 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) {
}
void CodeGenStackVM::VisitExpr_(const Call* op) {
if (op->is_intrinsic(Call::address_of)) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
......@@ -129,8 +129,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
} else if (op->is_intrinsic(Call::null_handle)) {
this->PushOp(StackVM::PUSH_I64, 0);
} else if (op->is_intrinsic(Call::reinterpret)) {
this->Push(op->args[0]);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value;
......
......@@ -217,11 +217,13 @@ class PipelineExtractor: public IRVisitor {
if (is_zero(op->index) && load) {
compute->body = Store::make(
op->buffer_var,
Load::make(load->type, load->buffer_var, repl.Mutate(load->index)),
op->index);
Load::make(load->type, load->buffer_var,
repl.Mutate(load->index), op->predicate),
op->index, op->predicate);
} else {
compute->body = Store::make(
op->buffer_var, repl.Mutate(op->value), repl.Mutate(op->index));
op->buffer_var, repl.Mutate(op->value),
repl.Mutate(op->index), op->predicate);
}
compute->inputs = repl.inputs_;
pipeline_->stages.push_back(ComputeBlock(compute));
......
......@@ -49,13 +49,16 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr Buffer::MakeLoad(Array<Expr> index) const {
const BufferNode* n = operator->();
return ir::Load::make(n->dtype, n->data, BufferOffset(n, index));
return ir::Load::make(
n->dtype, n->data, BufferOffset(n, index),
const_true(n->dtype.lanes()));
}
Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const BufferNode* n = operator->();
CHECK_EQ(value.type(), n->dtype);
return ir::Store::make(n->data, value, BufferOffset(n, index));
return ir::Store::make(n->data, value, BufferOffset(n, index),
const_true(n->dtype.lanes()));
}
Buffer BufferNode::make(std::string name,
......
......@@ -254,19 +254,21 @@ Stmt MakeCrossThreadReduction(
}
}
}
Type t = reduce->type;
Expr pred = const_true(t.lanes());
Stmt reduce_body = Store::make(res_handle,
Call::make(
reduce->type,
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic),
0);
0, pred);
reduce_body = AttrStmt::make(
reduce->combiner,
attr::reduce_scope,
make_zero(reduce->type),
reduce_body);
Stmt assign_body = Provide::make(
stage->op, 0, Load::make(reduce->type, res_handle, 0), args);
stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Allocate::make(
......
......@@ -152,12 +152,8 @@ class VTInjector : public IRMutator {
return e;
}
// Rewrites a buffer index using the virtual-thread variable var_.
// When index_rewrite_strategy_ is 0 the result is
//   index * num_threads_ + var_
// otherwise it is
//   index + var_ * alloc_extent
Expr RewriteIndex(Expr index, Expr alloc_extent) const {
  const bool scale_by_thread_count = (index_rewrite_strategy_ == 0);
  return scale_by_thread_count ? index * num_threads_ + var_
                               : index + var_ * alloc_extent;
}
// Load
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
......@@ -168,7 +164,8 @@ class VTInjector : public IRMutator {
auto it = touched_alloc_.find(op->buffer_var.get());
if (it != touched_alloc_.end()) {
return Load::make(op->type, op->buffer_var,
RewriteIndex(op->index, it->second));
RewriteIndex(op->index, it->second),
op->predicate);
} else {
return expr;
}
......@@ -184,7 +181,8 @@ class VTInjector : public IRMutator {
if (it != touched_alloc_.end()) {
return Store::make(op->buffer_var,
op->value,
RewriteIndex(op->index, it->second));
RewriteIndex(op->index, it->second),
op->predicate);
} else {
return stmt;
}
......@@ -307,6 +305,9 @@ class VTInjector : public IRMutator {
for (size_t i = 1; i < extents.size(); ++i) {
stride = arith::ComputeExpr<Mul>(stride, extents[i]);
}
if (op->type.lanes() != 0) {
stride = stride * op->type.lanes();
}
Array<Expr> other;
other.push_back(num_threads_);
for (Expr e : extents) {
......@@ -368,8 +369,6 @@ class VTInjector : public IRMutator {
Var var_;
// the threads/lanes
int num_threads_;
// Index rewriting strategy
int index_rewrite_strategy_{1};
// whether the loop is already injected.
bool vt_loop_injected_{false};
// whether current expression get touched.
......
......@@ -143,10 +143,11 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
Expr pred = this->Mutate(op->predicate);
if (value.same_as(op->value) && index.same_as(op->index) && pred.same_as(op->predicate)) {
return s;
} else {
return Store::make(op->buffer_var, value, index);
return Store::make(op->buffer_var, value, index, pred);
}
}
......@@ -263,10 +264,11 @@ Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
Expr index = this->Mutate(op->index);
if (index.same_as(op->index)) {
Expr pred = this->Mutate(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return e;
} else {
return Load::make(op->type, op->buffer_var, index);
return Load::make(op->type, op->buffer_var, index, pred);
}
}
......@@ -383,6 +385,15 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
}
}
// Default mutation for Shuffle: mutates the input vectors through
// MutateArray. If the resulting array differs from op->vectors, a new
// Shuffle is built from it with the original indices; otherwise the
// incoming expression e is returned as-is.
Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) {
  auto mutated_vectors = MutateArray(op->vectors, this);
  if (!mutated_vectors.same_as(op->vectors)) {
    return Shuffle::make(mutated_vectors, op->indices);
  }
  return e;
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
return e; \
......@@ -422,7 +433,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm)
.DISPATCH_TO_MUTATE_EXPR(StringImm);
.DISPATCH_TO_MUTATE_EXPR(StringImm)
.DISPATCH_TO_MUTATE_EXPR(Shuffle);
} // namespace ir
} // namespace tvm
......@@ -111,8 +111,10 @@ inline Expr TVMStructGet(
*/
inline Expr AddressOffset(Var handle, Type dtype, int offset) {
return Call::make(
Handle(), Call::address_of,
{Load::make(dtype, handle, make_const(Int(32), offset))}, Call::PureIntrinsic);
Handle(), intrinsic::tvm_address_of,
{Load::make(dtype, handle, make_const(Int(32), offset * dtype.lanes()),
const_true(dtype.lanes()))},
Call::PureIntrinsic);
}
/*!
......
......@@ -81,11 +81,13 @@ void IRVisitor::Visit_(const Allocate *op) {
// Default traversal for Load: visits the index expression first,
// then the predicate expression.
void IRVisitor::Visit_(const Load *op) {
this->Visit(op->index);
this->Visit(op->predicate);
}
// Default traversal for Store: visits the stored value, then the index,
// then the predicate, in that order.
void IRVisitor::Visit_(const Store *op) {
this->Visit(op->value);
this->Visit(op->index);
this->Visit(op->predicate);
}
void IRVisitor::Visit_(const IfThenElse *op) {
......
......@@ -99,7 +99,7 @@ class PackedCallBuilder : public IRMutator {
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
Store::make(stack_shape_, Convert(Int(64), op->args[i]),
ConstInt32(stack_begin +i)));
ConstInt32(stack_begin +i), const_true(1)));
}
return AddressOffset(stack_shape_, Int(64), stack_begin);
}
......@@ -169,7 +169,7 @@ class PackedCallBuilder : public IRMutator {
prep_seq_.emplace_back(
Store::make(stack_tcode_,
ConstInt32(arg_tcode),
stack_index));
stack_index, const_true(1)));
}
// UPDATE stack value
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
......
......@@ -143,9 +143,10 @@ class ThreadAllreduceBuilder : public IRMutator {
int threadx_extent = 1;
Expr reduce_index = FlattenThread(vred, &reduce_extent);
Expr group_index = FlattenThread(vpar, &group_extent);
Expr pred = const_true(value.type().lanes());
if (reduce_extent == 1) {
// special case, no reduction is needed.
return Store::make(op->buffer_var, value, 0);
return Store::make(op->buffer_var, value, 0, pred);
}
// Whether the threadIdx.x is involved in reduction.
if (vred[0].scope.dim_index == 0) {
......@@ -155,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
std::vector<Stmt> seq;
seq.emplace_back(Store::make(
shared_buf, value,
BufIndex(reduce_index, group_index, reduce_extent)));
BufIndex(reduce_index, group_index, reduce_extent), pred));
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(
combiner, value.type(), shared_buf,
......@@ -164,11 +165,12 @@ class ThreadAllreduceBuilder : public IRMutator {
load_remap_[op->buffer_var.get()] =
Load::make(
value.type(), shared_buf,
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent));
BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent),
pred);
alloc_remap_[op->buffer_var.get()] =
Allocate::make(shared_buf, value.type(),
{Expr(group_extent), Expr(reduce_extent)},
const_true(), Evaluate::make(0));
pred, Evaluate::make(0));
return MergeSeq(seq);
}
// make allreduce.
......@@ -192,9 +194,9 @@ class ThreadAllreduceBuilder : public IRMutator {
auto freduce = [&](int offset) {
Expr b = Load::make(
type, shared_buf,
BufIndex(reduce_index + offset, group_index, reduce_extent));
Expr a = Load::make(type, shared_buf, buf_index);
return Store::make(shared_buf, (*combiner)(a, b), buf_index);
BufIndex(reduce_index + offset, group_index, reduce_extent), const_true());
Expr a = Load::make(type, shared_buf, buf_index, const_true());
return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
};
// Step one, check for
if (reduce_align > reduce_extent) {
......
......@@ -122,7 +122,8 @@ LoweredFunc MakeAPI(Stmt body,
Var tcode(v_arg->name_hint + ".code", Int(32));
seq_init.emplace_back(LetStmt::make(
tcode, Load::make(
Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i)), nop));
Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)),
nop));
Type t = v_arg.type();
if (t.is_handle()) {
std::ostringstream msg;
......@@ -191,7 +192,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push(buf->shape[k],
cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k))),
IntImm::make(Int(32), k), const_true(1))),
field_name.str());
}
// strides field
......@@ -212,7 +213,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push(buf->strides[k],
cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k))),
IntImm::make(Int(32), k), const_true(1))),
field_name.str());
}
}
......
......@@ -75,7 +75,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op = expr.as<Load>();
if (read_access_ && buf_var_ == op->buffer_var.get()) {
return Load::make(
op->type, op->buffer_var, ir::Simplify(op->index - min_));
op->type, op->buffer_var, ir::Simplify(op->index - min_),
op->predicate);
} else {
return expr;
}
......@@ -85,7 +86,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op = stmt.as<Store>();
if (!read_access_ && buf_var_ == op->buffer_var.get()) {
return Store::make(
op->buffer_var, op->value, ir::Simplify(op->index - min_));
op->buffer_var, op->value, ir::Simplify(op->index - min_),
op->predicate);
} else {
return stmt;
}
......
......@@ -170,12 +170,13 @@ class StageSplitter : public IRMutator {
Expr index = Mutate(op->index);
Stmt provide = Store::make(
ch->handle_var,
Load::make(op->type, op->buffer_var, index), 0);
Load::make(op->type, op->buffer_var, index, op->predicate),
0, op->predicate);
Stmt temp = nest_.back(); nest_.pop_back();
stages_.emplace_back(BuildStage(provide, ch));
nest_.push_back(temp);
fifo_map_[ch->handle_var.get()] = ch;
return Load::make(op->type, ch->handle_var, 0);
return Load::make(op->type, ch->handle_var, 0, op->predicate);
}
Stmt Split(Stmt stmt, const ProducerConsumer* env) {
......
......@@ -33,7 +33,7 @@ class StorageFlattener : public IRMutator {
op = stmt.as<Store>();
auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) {
return Store::make(it->second, op->value, op->index);
return Store::make(it->second, op->value, op->index, op->predicate);
} else {
return stmt;
}
......@@ -115,7 +115,7 @@ class StorageFlattener : public IRMutator {
op = expr.as<Load>();
auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) {
return Load::make(op->type, it->second, op->index);
return Load::make(op->type, it->second, op->index, op->predicate);
} else {
return expr;
}
......
......@@ -194,14 +194,14 @@ class StoragePlanRewriter : public IRMutator {
op = stmt.as<Store>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt;
return Store::make(it->second->alloc_var, op->value, op->index);
return Store::make(it->second->alloc_var, op->value, op->index, op->predicate);
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
return Load::make(op->type, it->second->alloc_var, op->index);
return Load::make(op->type, it->second->alloc_var, op->index, op->predicate);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = alloc_map_.find(op);
......
......@@ -100,7 +100,7 @@ class StorageSyncPlanner : public IRVisitor {
}
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(Call::address_of)) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
IRVisitor::Visit_(l);
} else {
......
......@@ -34,7 +34,8 @@ class VecAllocAccess : public IRMutator {
op = expr.as<Load>();
if (op->buffer_var.get() == buf_) {
return Load::make(op->type, op->buffer_var,
op->index * var_lanes_ + var_);
op->index * var_lanes_ + var_,
op->predicate);
} else {
return expr;
}
......@@ -46,7 +47,8 @@ class VecAllocAccess : public IRMutator {
if (op->buffer_var.get() == buf_) {
return Store::make(op->buffer_var,
op->value,
op->index * var_lanes_ + var_);
op->index * var_lanes_ + var_,
op->predicate);
} else {
return stmt;
}
......@@ -160,11 +162,16 @@ class Vectorizer : public IRMutator {
// Load
Expr Mutate_(const Load* op, const Expr& e) final {
Expr index = this->Mutate(op->index);
if (index.same_as(op->index)) {
Expr pred = this->Mutate(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return e;
} else {
return Load::make(op->type.with_lanes(index.type().lanes()),
op->buffer_var, index);
int lanes = std::max(index.type().lanes(), pred.type().lanes());
return Load::make(
op->type.with_lanes(lanes),
op->buffer_var,
BroadcastTo(index, lanes),
BroadcastTo(pred, lanes));
}
}
// Let
......@@ -201,13 +208,16 @@ class Vectorizer : public IRMutator {
// Mutates a Store's value, index and predicate. If none of the three
// changed, the original statement s is returned. Otherwise a new Store is
// built in which every operand is passed through BroadcastTo with a common
// lane count: the maximum of the lane counts of the three mutated operands.
Stmt Mutate_(const Store* op, const Stmt& s) final {
  Expr value = this->Mutate(op->value);
  Expr index = this->Mutate(op->index);
  Expr pred = this->Mutate(op->predicate);
  // The predicate must take part in the unchanged-check as well; otherwise
  // a Store whose only change is its predicate would be returned as-is.
  const bool unchanged = value.same_as(op->value) &&
                         index.same_as(op->index) &&
                         pred.same_as(op->predicate);
  if (unchanged) {
    return s;
  }
  int lanes = std::max(value.type().lanes(), index.type().lanes());
  lanes = std::max(lanes, pred.type().lanes());
  return Store::make(op->buffer_var,
                     BroadcastTo(value, lanes),
                     BroadcastTo(index, lanes),
                     BroadcastTo(pred, lanes));
}
// For
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <pass/CSE.h>
// gtest case IR_PASS.CSE: calls cse_test(), looked up through the
// Halide::Internal using-directive.
// NOTE(review): cse_test is not declared in this file; presumably it comes
// from <pass/CSE.h> -- confirm.
TEST(IR_PASS, CSE) {
using namespace Halide::Internal;
cse_test();
}
// Test-binary entry point: initializes GoogleTest from the command line,
// selects the "threadsafe" death-test style, runs every registered test
// and returns the status reported by RUN_ALL_TESTS().
int main(int argc, char ** argv) {
  testing::InitGoogleTest(&argc, argv);
  // Set the death-test style before any test runs.
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  const int exit_status = RUN_ALL_TESTS();
  return exit_status;
}
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <pass/Simplify.h>
#include <arithmetic/Simplify.h>
TEST(IRSIMPLIFY, Basic) {
using namespace Halide::Internal;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment