Commit 090468aa by Tianqi Chen Committed by GitHub

[PASS] RewriteUnsafeSelect lowers unsafe select to condition expr (#335)

parent 25ded693
...@@ -218,6 +218,14 @@ namespace intrinsic { ...@@ -218,6 +218,14 @@ namespace intrinsic {
*/ */
constexpr const char* tvm_address_of = "tvm_address_of"; constexpr const char* tvm_address_of = "tvm_address_of";
/*! /*!
* \brief Same as select, used for unsafe memory access.
*
* Type tvm_if_then_else(cond, a, b) {
* return cond ? a : b;
* }
*/
constexpr const char* tvm_if_then_else = "tvm_if_then_else";
/*!
* \brief Get head access address with memory access pattern info. * \brief Get head access address with memory access pattern info.
* *
* This operator also marks range of the memory access * This operator also marks range of the memory access
......
...@@ -267,6 +267,13 @@ Stmt CoProcSync(Stmt stmt); ...@@ -267,6 +267,13 @@ Stmt CoProcSync(Stmt stmt);
Stmt LiftAttrScope(Stmt stmt, std::string attr_key); Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
/*! /*!
* \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statment to be rewritten.
* \return Transformed stmt.
*/
Stmt RewriteUnsafeSelect(Stmt stmt);
/*!
* \brief Lower attached storage access information. * \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish. * Do this pass after all storage access analysis finish.
* *
......
...@@ -211,6 +211,7 @@ def lower(sch, ...@@ -211,6 +211,7 @@ def lower(sch,
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt) stmt = ir_pass.RemoveNoOp(stmt)
stmt = ir_pass.RewriteUnsafeSelect(stmt)
if simple_mode: if simple_mode:
return stmt return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
...@@ -85,6 +85,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") ...@@ -85,6 +85,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA); REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(RewriteUnsafeSelect);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS3(StorageFlatten); REGISTER_PASS3(StorageFlatten);
REGISTER_PASS1(VectorizeLoop); REGISTER_PASS1(VectorizeLoop);
......
...@@ -482,6 +482,14 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -482,6 +482,14 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
PrintBinaryIntrinsitc(op, " << ", os, this); PrintBinaryIntrinsitc(op, " << ", os, this);
} else if (op->is_intrinsic(Call::shift_right)) { } else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, " >> ", os, this); PrintBinaryIntrinsitc(op, " >> ", os, this);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
os << "(";
PrintExpr(op->args[0], os);
os << " ? ";
PrintExpr(op->args[1], os);
os << " : ";
PrintExpr(op->args[2], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) { } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
......
...@@ -1028,6 +1028,31 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -1028,6 +1028,31 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
llvm::Value* ptr = MakeValue(op->args[0]); llvm::Value* ptr = MakeValue(op->args[0]);
return builder_->CreateICmpEQ( return builder_->CreateICmpEQ(
ptr, llvm::Constant::getNullValue(ptr->getType())); ptr, llvm::Constant::getNullValue(ptr->getType()));
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
using llvm::BasicBlock;
CHECK_EQ(op->args.size(), 3U);
llvm::Value* cond = MakeValue(op->args[0]);
BasicBlock* then_block = BasicBlock::Create(
*ctx_, "if_then", function_);
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "if_end", function_);
builder_->CreateCondBr(cond, then_block, else_block);
// Then
builder_->SetInsertPoint(then_block);
llvm::Value* then_value = MakeValue(op->args[1]);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block);
// else
llvm::Value* else_value = MakeValue(op->args[2]);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(end_block);
// phi
llvm::PHINode* phi = builder_->CreatePHI(then_value->getType(), 2);
phi->addIncoming(then_value, then_block);
phi->addIncoming(else_value, else_block);
return phi;
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value; int kind = op->args[2].as<IntImm>()->value;
......
/*!
* Copyright (c) 2017 by Contributors
* \file unsafe_select_rewrite.cc
* \brief Rewrite uinsafe select expression.
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
// For now, rewrite unsafe select expression to if_then_else
// TODO(tqchen) pattern matching to support masked load
class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
public:
// select itself is always considered safe if condition is safe
// Because we will issue guard to make sure it is.
bool VisitExpr_(const Select* op) {
return VisitExpr(op->condition);
}
bool VisitExpr_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
return VisitExpr(op->args[0]);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load* l = op->args[0].as<Load>();
return this->VisitExpr(l->index);
} else if (op->is_pure()) {
for (Expr e : op->args) {
if (VisitExpr(e)) return true;
}
return false;
} else {
return true;
}
}
bool VisitExpr_(const Load* op) {
// Load is considered unsafe.
return true;
}
bool VisitExpr_(const Add* op) final { return BinaryOp(op); }
bool VisitExpr_(const Sub* op) final { return BinaryOp(op); }
bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
bool VisitExpr_(const NE* op) final { return BinaryOp(op); }
bool VisitExpr_(const LT* op) final { return BinaryOp(op); }
bool VisitExpr_(const LE* op) final { return BinaryOp(op); }
bool VisitExpr_(const GT* op) final { return BinaryOp(op); }
bool VisitExpr_(const GE* op) final { return BinaryOp(op); }
bool VisitExpr_(const And* op) final { return BinaryOp(op); }
bool VisitExpr_(const Or* op) final { return BinaryOp(op); }
bool VisitExpr_(const Not* op) final {
return VisitExpr(op->a);
}
bool VisitExpr_(const Let* op) final {
return VisitExpr(op->body) && VisitExpr(op->value);
}
bool VisitExpr_(const Cast* op) final {
return VisitExpr(op->value);
}
bool VisitExpr_(const Broadcast* op) final {
return VisitExpr(op->value);
}
bool VisitExpr_(const Ramp* op) final {
return VisitExpr(op->base) && VisitExpr(op->stride);
}
bool VisitExpr_(const Shuffle* op) final {
for (Expr e : op->vectors) {
if (VisitExpr(e)) return true;
}
return false;
}
bool VisitExpr_(const Variable* op) final { return false; }
bool VisitExpr_(const IntImm* op) final { return false; }
bool VisitExpr_(const FloatImm* op) final { return false; }
bool VisitExpr_(const StringImm* op) final { return false; }
private:
template<typename T>
bool BinaryOp(const T* op) {
return VisitExpr(op->a) && VisitExpr(op->b);
}
};
class UnsafeSelectRewriter : public IRMutator {
public:
Expr Mutate_(const Select* op, const Expr& e) {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Select>();
UnsafeExprDetector unsafe;
if (unsafe.VisitExpr(op->true_value) ||
unsafe.VisitExpr(op->false_value)) {
return Call::make(
op->type,
intrinsic::tvm_if_then_else,
{op->condition, op->true_value, op->false_value},
Call::Intrinsic);
} else {
return expr;
}
}
};
Stmt RewriteUnsafeSelect(Stmt stmt) {
return UnsafeSelectRewriter().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
...@@ -201,7 +201,28 @@ def test_multiple_func(): ...@@ -201,7 +201,28 @@ def test_multiple_func():
check_llvm() check_llvm()
def test_llvm_select():
def check_llvm(n, offset):
if not tvm.module.enabled("llvm"):
return
A = tvm.placeholder((n, ), name='A')
C = tvm.compute((n,), lambda i: tvm.select(i >= offset, A[i], 0.0), name='C')
s = tvm.create_schedule(C.op)
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
f(a, c)
c_np = a.asnumpy()
c_np[:offset] = 0
np.testing.assert_allclose(c.asnumpy(), c_np)
check_llvm(64, 8)
if __name__ == "__main__": if __name__ == "__main__":
test_llvm_select()
test_llvm_vadd_pipeline() test_llvm_vadd_pipeline()
test_llvm_add_pipeline() test_llvm_add_pipeline()
test_llvm_intrin() test_llvm_intrin()
......
import tvm
def test_rewrite_select():
ib = tvm.ir_builder.create()
A = ib.allocate("float32", 100, name="A", scope="global")
i = tvm.var("i")
y = tvm.select(i > 1, A[i-1], 1.0)
yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(y)).value
z = tvm.select(tvm.select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
a = tvm.select(i>10, y, z)
aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
assert isinstance(aa, tvm.expr.Select)
if __name__ == "__main__":
test_rewrite_select()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment