lower_intrin.cc 10.5 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Lower intrinsic calls and ops to device-specific IR when possible.
 * \file lower_intrin.cc
 */

#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
26
#include <tvm/tir/transform.h>
27 28
#include <tvm/runtime/registry.h>

29
#include <tvm/tir/op.h>
30
#include <tvm/target/target.h>
31
#include <unordered_set>
32 33
#include "../../arith/pattern_match.h"
#include "../../arith/ir_mutator_with_analyzer.h"
34 35

namespace tvm {
36
namespace tir {
37

38
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
39
 public:
40 41
  using IRMutatorWithAnalyzer::VisitStmt_;
  using IRMutatorWithAnalyzer::VisitExpr_;
42

43
  IntrinInjecter(arith::Analyzer* analyzer, std::string target_name)
44
      : IRMutatorWithAnalyzer(analyzer) {
45
    patterns_.push_back("tvm.intrin.rule." + target_name + ".");
46
    patterns_.push_back("tvm.intrin.rule.default.");
47
    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
48
    if (target_name == "stackvm") {
49 50
      support_bitwise_op_ = false;
    }
51 52
  }

53
  PrimExpr VisitExpr_(const CallNode* op) final {
54 55
    if (op->call_type == CallNode::Intrinsic ||
        op->call_type == CallNode::PureIntrinsic) {
56
      PrimExpr r = ApplyPattern(op->name, GetRef<PrimExpr>(op));
57
      if (r.defined()) return r;
58
    }
59
    return IRMutatorWithAnalyzer::VisitExpr_(op);
60 61
  }

62
  PrimExpr VisitExpr_(const AddNode* op) final {
63
    if (const MulNode* mb = op->b.as<MulNode>()) {
64
      return MakeFMA(mb->a, mb->b, op->a, op);
65
    } else if (const MulNode* ma = op->a.as<MulNode>()) {
66
      return MakeFMA(ma->a, ma->b, op->b, op);
67
    }
68
    return IRMutatorWithAnalyzer::VisitExpr_(op);
69 70
  }

71 72
  // We use floordiv for integer analysis,
  // but will need to lower them to native truncdiv instructions
73 74 75
  PrimExpr VisitExpr_(const FloorDivNode* op) final {
    auto e = GetRef<PrimExpr>(op);
    PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
76
    op = ret.as<FloorDivNode>();
77 78
    if (op == nullptr) return ret;
    int shift;
79
    const DataType& dtype = op->dtype;
80
    CHECK(dtype.is_int() || dtype.is_uint());
81

82 83
    if (support_bitwise_op_ &&
        is_const_power_of_two_integer(op->b, &shift)) {
84 85 86 87 88 89 90 91 92 93 94
      // lower to right shift if possible.
      return op->a >> make_const(dtype, shift);
    }

    if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
      // Common path, positive divisor
      if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
          analyzer_->CanProveGreaterEqual(e, 0)) {
        return truncdiv(op->a, op->b);
      } else {
        DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
95 96
        PrimExpr rdiv = truncdiv(op->a, op->b);
        PrimExpr rmod = truncmod(op->a, op->b);
97 98 99
        // condition on b >= 0.
        // truncmod(a, b) < 0 will implies ceildiv,
        // So we need to correct these cases.
100
        if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
101 102 103
          // equivalent to rdiv + (rmod >= 0 ? 0: -1);
          return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
        } else {
104
          return tir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
105 106 107 108 109 110 111
        }
      }
    } else {
      // uncommon case
      DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
      // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
      // b < 0  => (rmod <= 0 ? rdiv : rdiv - 1)
112 113
      PrimExpr rdiv = truncdiv(op->a, op->b);
      PrimExpr rmod = truncmod(op->a, op->b);
114
      return tir::SelectNode::make(
115 116 117 118 119
          (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
          rdiv, rdiv - make_const(dtype, 1));
    }
  }

120 121
  PrimExpr VisitExpr_(const FloorModNode* op) final {
    PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
122
    op = ret.as<FloorModNode>();
123 124 125
    if (op == nullptr) return ret;
    // Lower floordiv to native truncdiv.
    int shift;
126
    const DataType& dtype = op->dtype;
127
    CHECK(dtype.is_int() || dtype.is_uint());
128

129 130
    if (support_bitwise_op_ &&
        is_const_power_of_two_integer(op->b, &shift)) {
131 132 133 134 135 136 137 138
      // lower to masking if possible.
      int64_t mask = (
          static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
      return op->a & make_const(dtype, mask);
    }

    if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
      // Common pass, positive divisor
139
      if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
140 141 142 143 144 145
        return truncmod(op->a, op->b);
      } else {
        DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
        // NOTE:condition on b >= 0.
        // mod(a, b) < 0 will imply we are doing ceildiv,
        // So we need to correct these cases.
146
        PrimExpr rmod = truncmod(op->a, op->b);
147
        if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
148 149 150 151 152
          // (rmod >> shift) & b
          // -> (rmod >= 0 ? 0: -1) & b
          // -> rmod >= 0 ? 0 : b
          return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
        } else {
153
          return tir::SelectNode::make(rmod >= 0, rmod, rmod + op->b);
154 155 156 157 158
        }
      }
    } else {
      // uncommon case
      DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
159
      PrimExpr rmod = truncmod(op->a, op->b);
160 161 162 163
      // b > 0 && rmod >= 0 -> rmod
      // b > 0 && rmod < 0  -> rmod + b
      // b < 0 && rmod < 0 -> rmod
      // b < 0 && rmod > 0 -> rmod + b
164
      return tir::SelectNode::make(
165 166 167 168 169
          (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
          rmod, rmod + op->b);
    }
  }

170
  PrimExpr VisitExpr_(const MaxNode* op) final {
171
    using namespace arith;
172
    PVar<PrimExpr> x, y;
173
    PVar<IntImm> c;
174
    auto e = GetRef<PrimExpr>(op);
175 176 177
    if (max(floordiv(x, y), c).Match(e) &&
        c.Eval()->value >= 0 &&
        analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
178
      return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
179
    }
180
    return IRMutatorWithAnalyzer::VisitExpr_(op);
181 182
  }

183
  PrimExpr VisitExpr_(const EQNode* op) final {
184
    using namespace arith;
185 186
    PVar<PrimExpr> x, y;
    auto e = GetRef<PrimExpr>(op);
187
    if ((floormod(x, y) == 0).Match(e)) {
188
      return VisitExpr((truncmod(x, y) == 0).Eval());
189
    }
190
    return IRMutatorWithAnalyzer::VisitExpr_(op);
191 192
  }

193
  PrimExpr VisitExpr_(const NENode* op) final {
194
    using namespace arith;
195 196
    PVar<PrimExpr> x, y;
    auto e = GetRef<PrimExpr>(op);
197
    if ((floormod(x, y) != 0).Match(e)) {
198
      return VisitExpr((truncmod(x, y) != 0).Eval());
199
    }
200
    return IRMutatorWithAnalyzer::VisitExpr_(op);
201 202
  }

203
 private:
204
  PrimExpr SwapBroadcastCast(const PrimExpr& e) {
205 206 207 208
    // Try to change broadcast(cast(x)) to cast(broadcast(x))
    // For some targets, LLVM will generate more efficient FMA
    // instruction with the latter. For example, vmla vs. vmlal
    // on ARM.
209 210
    if (const BroadcastNode* bcast = e.as<BroadcastNode>()) {
      if (const CastNode* cast = bcast->value.as<CastNode>()) {
211 212
        auto should_swap = [&]() {
          // Maintain behaviour (int8 -> int16, fp16 -> fp32).
213
          if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
214 215 216
            return true;
          }
          // Check both operands are integer-like.
217
          if (!cast->dtype.is_uint() && !cast->dtype.is_int()) {
218 219
            return false;
          }
220
          if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) {
221 222 223
            return false;
          }
          // If both are integer-like, swap if we have a widening cast.
224
          return cast->dtype.bits() > cast->value.dtype().bits();
225 226 227
        };

        if (should_swap()) {
228
          PrimExpr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
229
          return CastNode::make(bcast->dtype, new_bcast);
230 231 232 233 234 235
        }
      }
    }
    return e;
  }

236
  PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c,
237
               const AddNode* op) {
238
    // emit fma instruction: a * b + c
239 240
    PrimExpr lhs = SwapBroadcastCast(a);
    PrimExpr rhs = SwapBroadcastCast(b);
241

242
    if (fma_ != nullptr && op->dtype.is_float()) {
243
      PrimExpr r = (*fma_)(CallNode::make(
244
          op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
245
      if (r.defined()) return this->VisitExpr(r);
246 247
    } else {
      if (!lhs.same_as(a) || !rhs.same_as(b)) {
248
        PrimExpr mul = this->VisitExpr(MulNode::make(lhs, rhs));
249
        return AddNode::make(mul, this->VisitExpr(c));
250
      }
251
    }
252
    return IRMutatorWithAnalyzer::VisitExpr_(op);
253 254
  }

255
  PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) {
256 257 258 259 260 261 262 263 264
    for (size_t i = 0; i < patterns_.size(); ++i) {
      std::string& p = patterns_[i];
      size_t psize = p.length();
      p.resize(psize + name.length());
      name.copy(&p[0] + psize, name.length());
      const runtime::PackedFunc* f = runtime::Registry::Get(p);
      p.resize(psize);
      // if pattern exists.
      if (f != nullptr) {
265
        PrimExpr r = (*f)(e);
266 267
        CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
        if (!r.same_as(e)) {
268
          return this->VisitExpr(r);
269 270 271
        }
      }
    }
272
    return PrimExpr();
273
  }
274

275
  // patterns
276
  std::vector<std::string> patterns_;
277
  const PackedFunc* fma_{nullptr};
278
  bool support_bitwise_op_{true};
279 280
};

281
/*!
 * \brief Lower intrinsic calls and floor division/modulo in a statement.
 * \param stmt The statement to lower.
 * \param target_name Target name used to select the intrinsic rules.
 * \return The lowered statement.
 */
Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) {
  // Analyzer local to this call; handed to the mutator by pointer.
  arith::Analyzer analyzer;
  return IntrinInjecter(&analyzer, target_name)(std::move(stmt));
}

namespace transform {

/*!
 * \brief Create the "tir.LowerIntrin" PrimFunc pass.
 *
 *  Every PrimFunc must carry the tvm::attr::kTarget attribute; the target's
 *  name selects which intrinsic rules are applied to the function body.
 * \return The pass, created at opt level 0 with no required passes.
 */
Pass LowerIntrin() {
  auto lower_func = [](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    auto* node = func.CopyOnWrite();
    auto target = func->GetAttr<Target>(tvm::attr::kTarget);
    CHECK(target.defined()) << "LowerIntrin: Require the target attribute";
    arith::Analyzer analyzer;
    IntrinInjecter injecter(&analyzer, target->target_name);
    node->body = injecter(std::move(node->body));
    return func;
  };
  return CreatePrimFuncPass(lower_func, 0, "tir.LowerIntrin", {});
}

// Expose the pass constructor to the FFI registry.
TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin);

}  // namespace transform
306

307
}  // namespace tir
308
}  // namespace tvm