rewrite_unsafe_select.cc 4.82 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rewrite_unsafe_select.cc
 * \brief Rewrite unsafe select expressions into guarded if_then_else calls.
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
27 28

namespace tvm {
29
namespace tir {
30 31 32 33


// For now, rewrite unsafe select expression to if_then_else
// TODO(tqchen) pattern matching to support masked load
34
class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
35 36 37
 public:
  // select itself is always considered safe if condition is safe
  // Because we will issue guard to make sure it is.
38
  bool VisitExpr_(const SelectNode* op) {
39 40
    return VisitExpr(op->condition);
  }
41
  bool VisitExpr_(const CallNode* op) {
42 43 44
    if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
      return VisitExpr(op->args[0]);
    } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
45
      const LoadNode* l = op->args[0].as<LoadNode>();
46 47
      return this->VisitExpr(l->index);
    } else if (op->is_pure()) {
48
      for (PrimExpr e : op->args) {
49 50 51 52 53 54 55
        if (VisitExpr(e)) return true;
      }
      return false;
    } else {
      return true;
    }
  }
56
  bool VisitExpr_(const LoadNode* op) {
57 58 59
    // Load is considered unsafe.
    return true;
  }
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
  bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const FloorDivNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const EQNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const NENode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const LTNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const LENode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const GTNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const GENode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); }
  bool VisitExpr_(const NotNode* op) final {
78 79
    return VisitExpr(op->a);
  }
80
  bool VisitExpr_(const LetNode* op) final {
81
    return VisitExpr(op->body) || VisitExpr(op->value);
82
  }
83
  bool VisitExpr_(const CastNode* op) final {
84 85
    return VisitExpr(op->value);
  }
86
  bool VisitExpr_(const BroadcastNode* op) final {
87 88
    return VisitExpr(op->value);
  }
89
  bool VisitExpr_(const RampNode* op) final {
90 91
    return VisitExpr(op->base) && VisitExpr(op->stride);
  }
92
  bool VisitExpr_(const ShuffleNode* op) final {
93
    for (PrimExpr e : op->vectors) {
94 95 96 97
      if (VisitExpr(e)) return true;
    }
    return false;
  }
98 99 100 101
  bool VisitExpr_(const VarNode* op) final { return false; }
  bool VisitExpr_(const IntImmNode* op) final { return false; }
  bool VisitExpr_(const FloatImmNode* op) final { return false; }
  bool VisitExpr_(const StringImmNode* op) final { return false; }
102 103 104 105

 private:
  template<typename T>
  bool BinaryOp(const T* op) {
106
    return VisitExpr(op->a) || VisitExpr(op->b);
107 108 109
  }
};

110
class UnsafeSelectRewriter : public StmtExprMutator {
111
 public:
112 113
  PrimExpr VisitExpr_(const SelectNode* op) {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
114
    op = expr.as<SelectNode>();
115
    UnsafeExprDetector unsafe;
116
    bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
117 118 119
    if ((unsafe.VisitExpr(op->true_value) ||
        unsafe.VisitExpr(op->false_value)) &&
        cond_is_scalar_bool) {
120
      return CallNode::make(
121
          op->dtype,
122 123
          intrinsic::tvm_if_then_else,
          {op->condition, op->true_value, op->false_value},
124
          CallNode::Intrinsic);
125 126 127 128 129 130 131
    } else {
      return expr;
    }
  }
};

/*!
 * \brief Guard unsafe selects in a statement.
 *
 * Every Select with a scalar boolean condition and a branch flagged by
 * UnsafeExprDetector is replaced by a tvm_if_then_else intrinsic call.
 *
 * \param stmt The statement to transform (taken by value and moved in).
 * \return The transformed statement.
 */
Stmt RewriteUnsafeSelect(Stmt stmt) {
  UnsafeSelectRewriter rewriter;
  return rewriter(std::move(stmt));
}

135
}  // namespace tir
136
}  // namespace tvm