/*! * Copyright (c) 2017 by Contributors * \file unsafe_select_rewrite.cc * \brief Rewrite uinsafe select expression. */ #include <tvm/ir.h> #include <tvm/ir_functor_ext.h> #include <tvm/ir_mutator.h> #include <tvm/ir_pass.h> namespace tvm { namespace ir { // For now, rewrite unsafe select expression to if_then_else // TODO(tqchen) pattern matching to support masked load class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> { public: // select itself is always considered safe if condition is safe // Because we will issue guard to make sure it is. bool VisitExpr_(const Select* op) { return VisitExpr(op->condition); } bool VisitExpr_(const Call* op) { if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { return VisitExpr(op->args[0]); } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load* l = op->args[0].as<Load>(); return this->VisitExpr(l->index); } else if (op->is_pure()) { for (Expr e : op->args) { if (VisitExpr(e)) return true; } return false; } else { return true; } } bool VisitExpr_(const Load* op) { // Load is considered unsafe. return true; } bool VisitExpr_(const Add* op) final { return BinaryOp(op); } bool VisitExpr_(const Sub* op) final { return BinaryOp(op); } bool VisitExpr_(const Mul* op) final { return BinaryOp(op); } bool VisitExpr_(const Div* op) final { return BinaryOp(op); } bool VisitExpr_(const Mod* op) final { return BinaryOp(op); } bool VisitExpr_(const Min* op) final { return BinaryOp(op); } bool VisitExpr_(const Max* op) final { return BinaryOp(op); } bool VisitExpr_(const EQ* op) final { return BinaryOp(op); } bool VisitExpr_(const NE* op) final { return BinaryOp(op); } bool VisitExpr_(const LT* op) final { return BinaryOp(op); } bool VisitExpr_(const LE* op) final { return BinaryOp(op); } bool VisitExpr_(const GT* op) final { return BinaryOp(op); } bool VisitExpr_(const GE* op) final { return BinaryOp(op); } bool VisitExpr_(const And* op) final { return BinaryOp(op); } bool VisitExpr_(const Or* op) final { return BinaryOp(op); } bool VisitExpr_(const Not* op) final { return VisitExpr(op->a); } bool VisitExpr_(const Let* op) final { return VisitExpr(op->body) || VisitExpr(op->value); } bool VisitExpr_(const Cast* op) final { return VisitExpr(op->value); } bool VisitExpr_(const Broadcast* op) final { return VisitExpr(op->value); } bool VisitExpr_(const Ramp* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); } bool VisitExpr_(const Shuffle* op) final { for (Expr e : op->vectors) { if (VisitExpr(e)) return true; } return false; } bool VisitExpr_(const Variable* op) final { return false; } bool VisitExpr_(const UIntImm* op) final { return false; } bool VisitExpr_(const IntImm* op) final { return false; } bool VisitExpr_(const FloatImm* op) final { return false; } bool VisitExpr_(const StringImm* op) final { return false; } private: template<typename T> bool BinaryOp(const T* op) { return VisitExpr(op->a) || VisitExpr(op->b); } }; class UnsafeSelectRewriter : public IRMutator { public: Expr Mutate_(const Select* op, const Expr& e) { Expr expr = IRMutator::Mutate_(op, e); op = expr.as<Select>(); UnsafeExprDetector unsafe; if (unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) { return Call::make( op->type, intrinsic::tvm_if_then_else, {op->condition, op->true_value, op->false_value}, Call::Intrinsic); } else { return expr; } } }; Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter().Mutate(stmt); } } // namespace ir } // namespace tvm