/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file lower_intrin.cc
 * \brief Lower intrinsic calls and operators to device-specific IR when possible.
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>

#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "ir_util.h"
#include "../../arith/pattern_match.h"
#include "../../arith/ir_mutator_with_analyzer.h"

namespace tvm {
35
namespace tir {
36

37
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
38
 public:
39 40
  using IRMutatorWithAnalyzer::VisitStmt_;
  using IRMutatorWithAnalyzer::VisitExpr_;
41 42 43

  IntrinInjecter(arith::Analyzer* analyzer, std::string target)
      : IRMutatorWithAnalyzer(analyzer) {
44 45 46 47
    std::istringstream is(target);
    std::string starget;
    is >> starget;
    patterns_.push_back("tvm.intrin.rule." + starget + ".");
48
    patterns_.push_back("tvm.intrin.rule.default.");
49
    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
50 51 52
    if (target == "stackvm") {
      support_bitwise_op_ = false;
    }
53 54
  }

55
  PrimExpr VisitExpr_(const CallNode* op) final {
56 57
    if (op->call_type == CallNode::Intrinsic ||
        op->call_type == CallNode::PureIntrinsic) {
58
      PrimExpr r = ApplyPattern(op->name, GetRef<PrimExpr>(op));
59
      if (r.defined()) return r;
60
    }
61
    return IRMutatorWithAnalyzer::VisitExpr_(op);
62 63
  }

64
  PrimExpr VisitExpr_(const AddNode* op) final {
65
    if (const MulNode* mb = op->b.as<MulNode>()) {
66
      return MakeFMA(mb->a, mb->b, op->a, op);
67
    } else if (const MulNode* ma = op->a.as<MulNode>()) {
68
      return MakeFMA(ma->a, ma->b, op->b, op);
69
    }
70
    return IRMutatorWithAnalyzer::VisitExpr_(op);
71 72
  }

73 74
  // We use floordiv for integer analysis,
  // but will need to lower them to native truncdiv instructions
75 76 77
  PrimExpr VisitExpr_(const FloorDivNode* op) final {
    auto e = GetRef<PrimExpr>(op);
    PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
78
    op = ret.as<FloorDivNode>();
79 80
    if (op == nullptr) return ret;
    int shift;
81
    const DataType& dtype = op->dtype;
82
    CHECK(dtype.is_int() || dtype.is_uint());
83

84 85
    if (support_bitwise_op_ &&
        is_const_power_of_two_integer(op->b, &shift)) {
86 87 88 89 90 91 92 93 94 95 96
      // lower to right shift if possible.
      return op->a >> make_const(dtype, shift);
    }

    if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
      // Common path, positive divisor
      if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
          analyzer_->CanProveGreaterEqual(e, 0)) {
        return truncdiv(op->a, op->b);
      } else {
        DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
97 98
        PrimExpr rdiv = truncdiv(op->a, op->b);
        PrimExpr rmod = truncmod(op->a, op->b);
99 100 101
        // condition on b >= 0.
        // truncmod(a, b) < 0 will implies ceildiv,
        // So we need to correct these cases.
102
        if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
103 104 105
          // equivalent to rdiv + (rmod >= 0 ? 0: -1);
          return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
        } else {
106
          return tir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
107 108 109 110 111 112 113
        }
      }
    } else {
      // uncommon case
      DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
      // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
      // b < 0  => (rmod <= 0 ? rdiv : rdiv - 1)
114 115
      PrimExpr rdiv = truncdiv(op->a, op->b);
      PrimExpr rmod = truncmod(op->a, op->b);
116
      return tir::SelectNode::make(
117 118 119 120 121
          (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
          rdiv, rdiv - make_const(dtype, 1));
    }
  }

122 123
  PrimExpr VisitExpr_(const FloorModNode* op) final {
    PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
124
    op = ret.as<FloorModNode>();
125 126 127
    if (op == nullptr) return ret;
    // Lower floordiv to native truncdiv.
    int shift;
128
    const DataType& dtype = op->dtype;
129
    CHECK(dtype.is_int() || dtype.is_uint());
130

131 132
    if (support_bitwise_op_ &&
        is_const_power_of_two_integer(op->b, &shift)) {
133 134 135 136 137 138 139 140
      // lower to masking if possible.
      int64_t mask = (
          static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
      return op->a & make_const(dtype, mask);
    }

    if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
      // Common pass, positive divisor
141
      if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
142 143 144 145 146 147
        return truncmod(op->a, op->b);
      } else {
        DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
        // NOTE:condition on b >= 0.
        // mod(a, b) < 0 will imply we are doing ceildiv,
        // So we need to correct these cases.
148
        PrimExpr rmod = truncmod(op->a, op->b);
149
        if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
150 151 152 153 154
          // (rmod >> shift) & b
          // -> (rmod >= 0 ? 0: -1) & b
          // -> rmod >= 0 ? 0 : b
          return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
        } else {
155
          return tir::SelectNode::make(rmod >= 0, rmod, rmod + op->b);
156 157 158 159 160
        }
      }
    } else {
      // uncommon case
      DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
161
      PrimExpr rmod = truncmod(op->a, op->b);
162 163 164 165
      // b > 0 && rmod >= 0 -> rmod
      // b > 0 && rmod < 0  -> rmod + b
      // b < 0 && rmod < 0 -> rmod
      // b < 0 && rmod > 0 -> rmod + b
166
      return tir::SelectNode::make(
167 168 169 170 171
          (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
          rmod, rmod + op->b);
    }
  }

172
  PrimExpr VisitExpr_(const MaxNode* op) final {
173
    using namespace arith;
174
    PVar<PrimExpr> x, y;
175
    PVar<IntImm> c;
176
    auto e = GetRef<PrimExpr>(op);
177 178 179
    if (max(floordiv(x, y), c).Match(e) &&
        c.Eval()->value >= 0 &&
        analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
180
      return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
181
    }
182
    return IRMutatorWithAnalyzer::VisitExpr_(op);
183 184
  }

185
  PrimExpr VisitExpr_(const EQNode* op) final {
186
    using namespace arith;
187 188
    PVar<PrimExpr> x, y;
    auto e = GetRef<PrimExpr>(op);
189
    if ((floormod(x, y) == 0).Match(e)) {
190
      return VisitExpr((truncmod(x, y) == 0).Eval());
191
    }
192
    return IRMutatorWithAnalyzer::VisitExpr_(op);
193 194
  }

195
  PrimExpr VisitExpr_(const NENode* op) final {
196
    using namespace arith;
197 198
    PVar<PrimExpr> x, y;
    auto e = GetRef<PrimExpr>(op);
199
    if ((floormod(x, y) != 0).Match(e)) {
200
      return VisitExpr((truncmod(x, y) != 0).Eval());
201
    }
202
    return IRMutatorWithAnalyzer::VisitExpr_(op);
203 204
  }

205
 private:
206
  PrimExpr SwapBroadcastCast(const PrimExpr& e) {
207 208 209 210
    // Try to change broadcast(cast(x)) to cast(broadcast(x))
    // For some targets, LLVM will generate more efficient FMA
    // instruction with the latter. For example, vmla vs. vmlal
    // on ARM.
211 212
    if (const BroadcastNode* bcast = e.as<BroadcastNode>()) {
      if (const CastNode* cast = bcast->value.as<CastNode>()) {
213 214
        auto should_swap = [&]() {
          // Maintain behaviour (int8 -> int16, fp16 -> fp32).
215
          if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
216 217 218
            return true;
          }
          // Check both operands are integer-like.
219
          if (!cast->dtype.is_uint() && !cast->dtype.is_int()) {
220 221
            return false;
          }
222
          if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) {
223 224 225
            return false;
          }
          // If both are integer-like, swap if we have a widening cast.
226
          return cast->dtype.bits() > cast->value.dtype().bits();
227 228 229
        };

        if (should_swap()) {
230
          PrimExpr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
231
          return CastNode::make(bcast->dtype, new_bcast);
232 233 234 235 236 237
        }
      }
    }
    return e;
  }

238
  PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c,
239
               const AddNode* op) {
240
    // emit fma instruction: a * b + c
241 242
    PrimExpr lhs = SwapBroadcastCast(a);
    PrimExpr rhs = SwapBroadcastCast(b);
243

244
    if (fma_ != nullptr && op->dtype.is_float()) {
245
      PrimExpr r = (*fma_)(CallNode::make(
246
          op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
247
      if (r.defined()) return this->VisitExpr(r);
248 249
    } else {
      if (!lhs.same_as(a) || !rhs.same_as(b)) {
250
        PrimExpr mul = this->VisitExpr(MulNode::make(lhs, rhs));
251
        return AddNode::make(mul, this->VisitExpr(c));
252
      }
253
    }
254
    return IRMutatorWithAnalyzer::VisitExpr_(op);
255 256
  }

257
  PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) {
258 259 260 261 262 263 264 265 266
    for (size_t i = 0; i < patterns_.size(); ++i) {
      std::string& p = patterns_[i];
      size_t psize = p.length();
      p.resize(psize + name.length());
      name.copy(&p[0] + psize, name.length());
      const runtime::PackedFunc* f = runtime::Registry::Get(p);
      p.resize(psize);
      // if pattern exists.
      if (f != nullptr) {
267
        PrimExpr r = (*f)(e);
268 269
        CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
        if (!r.same_as(e)) {
270
          return this->VisitExpr(r);
271 272 273
        }
      }
    }
274
    return PrimExpr();
275
  }
276

277
  // patterns
278
  std::vector<std::string> patterns_;
279
  const PackedFunc* fma_{nullptr};
280
  bool support_bitwise_op_{true};
281 282
};

283 284
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
  arith::Analyzer analyzer;
285
  return IntrinInjecter(&analyzer, target)(std::move(stmt));
286 287
}

288 289
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
290
  auto n = make_object<LoweredFuncNode>(*f.operator->());
291
  n->body = LowerIntrinStmt(n->body, target);
292 293 294
  return LoweredFunc(n);
}

295
// Register the api only for test purposes
296
TVM_REGISTER_GLOBAL("ir_pass._LowerIntrinStmt")
297 298
.set_body_typed(LowerIntrinStmt);

299
}  // namespace tir
300
}  // namespace tvm