/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file expr_operator.cc
 */
#include <tvm/base.h>
#include <tvm/ir.h>
25
#include <tvm/expr_operator.h>
26
#include <cmath>
27 28
// Centralized header for constant folders.
#include "../arithmetic/const_fold.h"
29 30 31

namespace tvm {

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
// Minimal cast helper: hands `value` back untouched when it already has
// type `t`, otherwise wraps it in an ir::Cast node. No constant folding
// or lane handling is attempted here.
inline Expr SimpleCast(const Type& t, Expr value) {
  if (value.type() != t) {
    return ir::Cast::make(t, value);
  }
  return value;
}

// Public entry point that rewrites `lhs` and `rhs` in place so that both
// end up with the same type; returns immediately when they already agree.
void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) {  // NOLINT(*)
  if (lhs.type() == rhs.type()) return;
  // Snapshot of the incoming types; also used verbatim in error messages.
  Type ltype = lhs.type();
  Type rtype = rhs.type();

  // Step 1: unify lanes by broadcasting a scalar operand to the width of
  // the other one. Two multi-lane operands must already have equal lanes.
  const bool lhs_is_scalar = (ltype.lanes() == 1);
  const bool rhs_is_scalar = (rtype.lanes() == 1);
  if (lhs_is_scalar && !rhs_is_scalar) {
    lhs = ir::Broadcast::make(lhs, rtype.lanes());
  } else if (rhs_is_scalar && !lhs_is_scalar) {
    rhs = ir::Broadcast::make(rhs, ltype.lanes());
  } else {
    CHECK(ltype.lanes() == rtype.lanes())
        << "Cannot match type " << ltype << " vs " << rtype;
  }
  if (lhs.type() == rhs.type()) return;

  // Step 2: only very simple type conversions are performed:
  //   int -> float, int(32) -> int(64)
  // The types are required to be relatively consistent. This reduces the
  // amount of code generated by operators and also helps users find
  // potential type conversion problems.
  const Type lhs_t = lhs.type();
  const Type rhs_t = rhs.type();
  const bool lhs_float = lhs_t.is_float();
  const bool rhs_float = rhs_t.is_float();

  if (lhs_float != rhs_float) {
    // Exactly one side is float: convert the non-float side to it.
    if (rhs_float) {
      lhs = cast(rhs_t, lhs);
    } else {
      rhs = cast(lhs_t, rhs);
    }
    return;
  }

  const bool both_int = lhs_t.is_int() && rhs_t.is_int();
  const bool both_uint = lhs_t.is_uint() && rhs_t.is_uint();
  if (both_int || both_uint) {
    // Same signedness: promote the narrower operand to the wider type.
    if (lhs_t.bits() < rhs_t.bits()) {
      lhs = cast(rhs_t, lhs);
    } else {
      rhs = cast(lhs_t, rhs);
    }
    return;
  }

  const bool mixed_sign = (lhs_t.is_int() && rhs_t.is_uint()) ||
                          (lhs_t.is_uint() && rhs_t.is_int());
  if (mixed_sign) {
    // Mixed signed/unsigned: both become signed ints of the larger width.
    int bits = std::max(lhs_t.bits(), rhs_t.bits());
    lhs = SimpleCast(Int(bits, lhs_t.lanes()), lhs);
    rhs = SimpleCast(Int(bits, rhs_t.lanes()), rhs);
    return;
  }

  LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
}


// Tests whether `val` is a positive power of two.
//
// For val <= 0 the function returns false and leaves *shift untouched.
// For val > 0, *shift is set to the number of trailing zero bits of `val`
// (log2(val) when the result is true), and the return value tells whether
// exactly one bit was set.
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
  if (!(val > 0)) return false;
  int trailing_zeros = 0;
  // val > 0 guarantees a set bit exists, so this loop terminates.
  for (; (val & 1) == 0; val = val >> 1) {
    ++trailing_zeros;
  }
  *shift = trailing_zeros;
  return val == 1;
}

// Returns true when `x` is an IntImm or UIntImm whose value is a positive
// power of two; the bit position is reported through `shift` as described
// for ConstPowerHelper. Any other expression yields false.
bool is_const_power_of_two_integer(const Expr& x, int* shift) {
  const auto* signed_imm = x.as<ir::IntImm>();
  if (signed_imm != nullptr) {
    return ConstPowerHelper(signed_imm->value, shift);
  }
  const auto* unsigned_imm = x.as<ir::UIntImm>();
  if (unsigned_imm != nullptr) {
    return ConstPowerHelper(unsigned_imm->value, shift);
  }
  return false;
}

// Casts `value` to type `t`, folding immediates where possible.
//
// - Same type: `value` is returned unchanged.
// - Scalar target: Int/UInt/Float immediates are folded with make_const
//   (they are used in index computations); anything else becomes ir::Cast.
// - Vector target, scalar value: the value is first converted to the
//   element type of `t` (folding immediates), then broadcast to t.lanes().
// - Vector target, vector value: lanes must match; emits ir::Cast.
Expr cast(const Type& t, Expr value) {
  using ir::IntImm;
  using ir::UIntImm;
  using ir::FloatImm;
  if (value.type() == t) return value;

  // Converts `v` to the scalar type `st`: immediates are folded into a new
  // constant, everything else is wrapped in an ir::Cast node. All three
  // immediate kinds go through the same path so the scalar and broadcast
  // cases below stay consistent with each other.
  auto to_scalar_type = [](const Type& st, const Expr& v) -> Expr {
    if (const IntImm* op = v.as<IntImm>()) {
      return make_const(st, op->value);
    }
    if (const UIntImm* op = v.as<UIntImm>()) {
      return make_const(st, op->value);
    }
    if (const FloatImm* op = v.as<FloatImm>()) {
      return make_const(st, op->value);
    }
    return ir::Cast::make(st, v);
  };

  if (t.lanes() == 1) {
    return to_scalar_type(t, value);
  }
  if (value.type().lanes() == 1) {
    // Manually unroll the cast: convert the scalar to the element type,
    // then broadcast it across all lanes.
    Type vtype = t.element_of();
    if (value.type() != vtype) {
      value = to_scalar_type(vtype, value);
    }
    return ir::Broadcast::make(value, t.lanes());
  }
  CHECK(value.type().lanes() == t.lanes());
  return ir::Cast::make(t, value);
}

// Reinterprets `value` as type `t` through the ir::Call::reinterpret pure
// intrinsic; a value that already has type `t` is returned as is.
Expr reinterpret(const Type& t, Expr value) {
  if (value.type() != t) {
    return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
  }
  return value;
}

// a + b: unifies operand types, returns the constant-folded result when
// TryConstFold produces one, otherwise builds an ir::Add node.
Expr operator+(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Add>(a, b);
  if (!folded.defined()) {
    return ir::Add::make(a, b);
  }
  return folded;
}

156
// negation
157 158
Expr operator-(Expr a) {
  using ir::IntImm;
159
  using ir::FloatImm;
160
  const IntImm* pa = a.as<IntImm>();
161 162 163
  const FloatImm* fa = a.as<FloatImm>();
  if (pa) return ir::IntImm::make(a.type(), -pa->value);
  if (fa) return ir::FloatImm::make(a.type(), -fa->value);
164 165 166 167
  return make_zero(a.type()) - a;
}

// a - b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::Sub node.
Expr operator-(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Sub>(a, b);
  return folded.defined() ? folded : ir::Sub::make(a, b);
}

// a * b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::Mul node.
Expr operator*(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Mul>(a, b);
  if (!folded.defined()) {
    return ir::Mul::make(a, b);
  }
  return folded;
}

181
// Division: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::Div node.
Expr div(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Div>(a, b);
  return folded.defined() ? folded : ir::Div::make(a, b);
}

188 189 190 191 192 193
// Integer-only division: both operands must be int or uint (checked),
// after which the work is delegated to div().
Expr truncdiv(Expr a, Expr b) {
  CHECK(a.type().is_int() || a.type().is_uint());
  CHECK(b.type().is_int() || b.type().is_uint());
  Expr quotient = div(a, b);
  return quotient;
}

194
// Modulo: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::Mod node.
Expr truncmod(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Mod>(a, b);
  if (!folded.defined()) {
    return ir::Mod::make(a, b);
  }
  return folded;
}

201
// a / b: thin forwarder to div().
Expr operator/(Expr a, Expr b) {
  Expr quotient = div(a, b);
  return quotient;
}

// a % b: thin forwarder to truncmod().
Expr operator%(Expr a, Expr b) {
  Expr remainder = truncmod(a, b);
  return remainder;
}

209 210
// Division used for index computations; forwards to floordiv().
Expr indexdiv(Expr a, Expr b) {
  Expr quotient = floordiv(a, b);
  return quotient;
}

// Modulo used for index computations; forwards to floormod().
Expr indexmod(Expr a, Expr b) {
  Expr remainder = floormod(a, b);
  return remainder;
}

218
// Floor division on integer operands. Both operands must be int or uint
// (checked before type unification). Returns the constant-folded result
// when available, otherwise an ir::FloorDiv node.
Expr floordiv(Expr a, Expr b) {
  CHECK(a.type().is_int() || a.type().is_uint());
  CHECK(b.type().is_int() || b.type().is_uint());
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::FloorDiv>(a, b);
  return folded.defined() ? folded : ir::FloorDiv::make(a, b);
}

// Floor modulo on integer operands. Both operands must be int or uint
// (checked before type unification). Returns the constant-folded result
// when available, otherwise an ir::FloorMod node.
Expr floormod(Expr a, Expr b) {
  CHECK(a.type().is_int() || a.type().is_uint());
  CHECK(b.type().is_int() || b.type().is_uint());
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::FloorMod>(a, b);
  if (!folded.defined()) {
    return ir::FloorMod::make(a, b);
  }
  return folded;
}
235

236
// Binary minimum. Infinite operands are resolved first, before any type
// unification; otherwise the operand types are unified and the result is
// either the constant-folded value or an ir::Min node.
Expr min(Expr a, Expr b) {
  // inf-aware simplification (the order of these checks is significant
  // when both operands are infinite)
  if (arith::is_pos_inf(a)) return b;
  if (arith::is_neg_inf(a)) return a;
  if (arith::is_pos_inf(b)) return a;
  if (arith::is_neg_inf(b)) return b;

  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Min>(a, b);
  return folded.defined() ? folded : ir::Min::make(a, b);
}

// Binary maximum. Infinite operands are resolved first, before any type
// unification; otherwise the operand types are unified and the result is
// either the constant-folded value or an ir::Max node.
Expr max(Expr a, Expr b) {
  // inf-aware simplification (the order of these checks is significant
  // when both operands are infinite)
  if (arith::is_pos_inf(a)) return a;
  if (arith::is_neg_inf(a)) return b;
  if (arith::is_pos_inf(b)) return b;
  if (arith::is_neg_inf(b)) return a;

  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::Max>(a, b);
  return folded.defined() ? folded : ir::Max::make(a, b);
}

264
// Conditional expression. `cond` must be a scalar boolean (checked). The
// two value operands are type-unified first; a constant condition (UIntImm
// or IntImm) selects one branch directly, otherwise a tvm_if_then_else
// pure-intrinsic call is emitted with the type of `true_value`.
Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
  CHECK(cond.type() == Bool(1))
      << "if_then_else only accept the condition to be boolean type.";
  BinaryOpMatchTypes(true_value, false_value);

  if (const ir::UIntImm* ucond = cond.as<ir::UIntImm>()) {
    return (ucond->value != 0) ? true_value : false_value;
  }
  if (const ir::IntImm* icond = cond.as<ir::IntImm>()) {
    return (icond->value != 0) ? true_value : false_value;
  }
  return ir::Call::make(
      true_value.type(),
      ir::intrinsic::tvm_if_then_else,
      {cond, true_value, false_value},
      ir::Call::PureIntrinsic);
}

// Wraps `cond` in the ir::Call::likely pure intrinsic; constant conditions
// are returned unchanged.
Expr likely(Expr cond) {
  if (!is_const(cond)) {
    return ir::Call::make(cond.type(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
  }
  return cond;
}

// a > b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::GT node.
Expr operator>(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::GT>(a, b);
  return folded.defined() ? folded : ir::GT::make(a, b);
}

// a >= b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::GE node.
Expr operator>=(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::GE>(a, b);
  if (!folded.defined()) {
    return ir::GE::make(a, b);
  }
  return folded;
}

// a < b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::LT node.
Expr operator<(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::LT>(a, b);
  return folded.defined() ? folded : ir::LT::make(a, b);
}

// a <= b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::LE node.
Expr operator<=(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::LE>(a, b);
  if (!folded.defined()) {
    return ir::LE::make(a, b);
  }
  return folded;
}

// a == b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::EQ node.
Expr operator==(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::EQ>(a, b);
  return folded.defined() ? folded : ir::EQ::make(a, b);
}

// a != b: unifies operand types, returns the constant-folded result when
// available, otherwise builds an ir::NE node.
Expr operator!=(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  Expr folded = arith::TryConstFold<ir::NE>(a, b);
  if (!folded.defined()) {
    return ir::NE::make(a, b);
  }
  return folded;
}

// Logical AND. Both operands must be boolean (checked; no type unification
// is performed). Returns the constant-folded result when available,
// otherwise an ir::And node.
Expr operator&&(Expr a, Expr b) {
  CHECK(a.type().is_bool());
  CHECK(b.type().is_bool());
  Expr folded = arith::TryConstFold<ir::And>(a, b);
  return folded.defined() ? folded : ir::And::make(a, b);
}

// Logical OR. Both operands must be boolean (checked; no type unification
// is performed). Returns the constant-folded result when available,
// otherwise an ir::Or node.
Expr operator||(Expr a, Expr b) {
  CHECK(a.type().is_bool());
  CHECK(b.type().is_bool());
  Expr folded = arith::TryConstFold<ir::Or>(a, b);
  if (!folded.defined()) {
    return ir::Or::make(a, b);
  }
  return folded;
}

// Logical NOT. The operand must be boolean (checked). Returns the
// constant-folded result when available, otherwise an ir::Not node.
Expr operator!(Expr a) {
  CHECK(a.type().is_bool());
  Expr folded = arith::TryConstFold<ir::Not>(a);
  return folded.defined() ? folded : ir::Not::make(a);
}

// a >> b: unifies operand types, then tries the shortcuts inside
// TVM_INDEX_CONST_PROPAGATION (which supplies `pa`/`pb`; defined outside
// this file): two constants are folded into an IntImm, and a shift by a
// constant zero returns `a`. Otherwise a shift_right pure-intrinsic call
// is emitted.
Expr operator>>(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa && pb) return IntImm::make(a.type(), (pa->value >> pb->value));
      if (pb && pb->value == 0) return a;
    });
  return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
}

// a << b: unifies operand types, then tries the shortcuts inside
// TVM_INDEX_CONST_PROPAGATION (which supplies `pa`/`pb`; defined outside
// this file): two constants are folded into an IntImm, and a shift by a
// constant zero returns `a`. Otherwise a shift_left pure-intrinsic call
// is emitted.
Expr operator<<(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa && pb) return IntImm::make(a.type(), (pa->value << pb->value));
      if (pb && pb->value == 0) return a;
    });
  return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
}

// a & b: unifies operand types; inside TVM_INDEX_CONST_PROPAGATION (which
// supplies `pa`/`pb`) two constants are folded into an IntImm. Otherwise
// a bitwise_and pure-intrinsic call is emitted.
Expr operator&(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa != nullptr && pb != nullptr) {
        return IntImm::make(a.type(), (pa->value & pb->value));
      }
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
}

// a | b: unifies operand types; inside TVM_INDEX_CONST_PROPAGATION (which
// supplies `pa`/`pb`) two constants are folded into an IntImm. Otherwise
// a bitwise_or pure-intrinsic call is emitted.
Expr operator|(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa != nullptr && pb != nullptr) {
        return IntImm::make(a.type(), (pa->value | pb->value));
      }
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
}

// a ^ b: unifies operand types; inside TVM_INDEX_CONST_PROPAGATION (which
// supplies `pa`/`pb`) two constants are folded into an IntImm. Otherwise
// a bitwise_xor pure-intrinsic call is emitted.
Expr operator^(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
      if (pa != nullptr && pb != nullptr) {
        return IntImm::make(a.type(), (pa->value ^ pb->value));
      }
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
}

// Bitwise NOT. The operand must be int or uint (checked); emits a
// bitwise_not pure-intrinsic call of the operand's type.
Expr operator~(Expr a) {
  CHECK(a.type().is_int() || a.type().is_uint());
  return ir::Call::make(
      a.type(),
      ir::Call::bitwise_not,
      { a },
      ir::Call::PureIntrinsic);
}

// pow(x, y): operand types are unified first, then the (unified) type is
// required to be float. Emits a "pow" pure-intrinsic call.
Expr pow(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.type().is_float()) << "power only applies to float";
  return ir::Call::make(
      x.type(),
      "pow",
      { x, y },
      ir::Call::PureIntrinsic);
}

// Absolute value.
//   int   : IntImm is folded with std::abs; otherwise select(x >= 0, x, -x).
//   float : FloatImm is folded with std::fabs; otherwise a "fabs" call.
//   uint  : returned unchanged.
// Any other type is reported through LOG(FATAL).
Expr abs(Expr x) {
  if (x.type().is_int()) {
    if (const ir::IntImm* int_imm = x.as<ir::IntImm>()) {
      return ir::IntImm::make(x.type(), std::abs(int_imm->value));
    }
    return ir::Select::make(x >= make_zero(x.type()), x, -x);
  }
  if (x.type().is_float()) {
    if (const ir::FloatImm* float_imm = x.as<ir::FloatImm>()) {
      return ir::FloatImm::make(x.type(), std::fabs(float_imm->value));
    }
    return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
  }
  if (x.type().is_uint()) {
    return x;
  }
  LOG(FATAL) << "Data type " << x.type()
             <<" not supported for absolute op. Skipping absolute op...";
  return x;
}

446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469
// NaN test producing a boolean with the same number of lanes as `x`.
//   int / uint : constant false.
//   float      : FloatImm is folded with std::isnan; 16-bit floats are
//                first cast to 32-bit float before the isnan intrinsic is
//                emitted; other widths are passed to the intrinsic directly.
// Any other type is reported through LOG(FATAL).
Expr isnan(Expr x) {
  Type t = Bool(x.type().lanes());
  if (x.type().is_int() || x.type().is_uint()) {
    return make_const(t, false);
  }
  if (x.type().is_float()) {
    if (const ir::FloatImm* float_imm = x.as<ir::FloatImm>()) {
      return make_const(t, std::isnan(float_imm->value));
    }
    if (x.type().bits() != 16) {
      return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
    }
    Expr widened = cast(Float(32, t.lanes()), std::move(x));
    return ir::Call::make(t, ir::Call::isnan, {widened}, ir::Call::PureIntrinsic);
  }
  LOG(FATAL) << "Data type " << x.type()
             <<" not supported for isnan op. Skipping isnan op...";
  return x;
}

470
// Sum reduction of `source` over `rdom`: the combiner is Add on two
// placeholder variables, the identity element is zero of the source type,
// and the reduce condition is constant true.
Expr sum(Expr source, Array<IterVar> rdom) {
  Var lhs_var("x", source.type());
  Var rhs_var("y", source.type());
  Expr combined = ir::Add::make(lhs_var, rhs_var);
  Expr identity = make_zero(source.type());
  ir::CommReducer reducer =
    ir::CommReducerNode::make({lhs_var}, {rhs_var}, {combined}, {identity});
  Expr always_true = make_const(Bool(1), true);
  return ir::Reduce::make(reducer, {source}, rdom, always_true, 0);
}

479 480 481 482 483 484 485 486 487 488
// Logical-AND reduction of a boolean `source` (checked) over `rdom`: the
// combiner is And on two placeholder variables, the identity element is
// true, and the reduce condition is constant true.
Expr all(Expr source, Array<IterVar> rdom) {
  CHECK(source.type().is_bool());
  Var lhs_var("x", source.type());
  Var rhs_var("y", source.type());
  Expr combined = ir::And::make(lhs_var, rhs_var);
  Expr identity = make_const(source.type(), true);
  ir::CommReducer reducer =
    ir::CommReducerNode::make({lhs_var}, {rhs_var}, {combined}, {identity});
  Expr always_true = make_const(Bool(1), true);
  return ir::Reduce::make(reducer, {source}, rdom, always_true, 0);
}

489
// Max reduction of `source` over `rdom`: the combiner is Max on two
// placeholder variables, the identity element is the minimum value of the
// source type, and the reduce condition is constant true.
Expr max(Expr source, Array<IterVar> rdom) {
  Var lhs_var("x", source.type());
  Var rhs_var("y", source.type());
  Expr combined = ir::Max::make(lhs_var, rhs_var);
  Expr identity = source.type().min();
  ir::CommReducer reducer =
    ir::CommReducerNode::make({lhs_var}, {rhs_var}, {combined}, {identity});
  Expr always_true = make_const(Bool(1), true);
  return ir::Reduce::make(reducer, {source}, rdom, always_true, 0);
}

// Min reduction of `source` over `rdom`: the combiner is Min on two
// placeholder variables, the identity element is the maximum value of the
// source type, and the reduce condition is constant true.
Expr min(Expr source, Array<IterVar> rdom) {
  Var lhs_var("x", source.type());
  Var rhs_var("y", source.type());
  Expr combined = ir::Min::make(lhs_var, rhs_var);
  Expr identity = source.type().max();
  ir::CommReducer reducer =
    ir::CommReducerNode::make({lhs_var}, {rhs_var}, {combined}, {identity});
  Expr always_true = make_const(Bool(1), true);
  return ir::Reduce::make(reducer, {source}, rdom, always_true, 0);
}

507
// Product reduction of `source` over `rdom`: the combiner is Mul on two
// placeholder variables, the identity element is the constant 1 of the
// source type, and the reduce condition is constant true.
Expr prod(Expr source, Array<IterVar> rdom) {
  Var lhs_var("x", source.type());
  Var rhs_var("y", source.type());
  Expr combined = ir::Mul::make(lhs_var, rhs_var);
  Expr identity = make_const(source.type(), 1);
  ir::CommReducer reducer =
    ir::CommReducerNode::make({lhs_var}, {rhs_var}, {combined}, {identity});
  Expr always_true = make_const(Bool(1), true);
  return ir::Reduce::make(reducer, {source}, rdom, always_true, 0);
}

516 517 518 519 520 521
// fmod(x, y): operand types are unified first, then the (unified) type is
// required to be float. Emits an "fmod" pure-intrinsic call.
Expr fmod(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.type().is_float()) << "fmod only applies to float";
  return ir::Call::make(
      x.type(),
      "fmod",
      { x, y },
      ir::Call::PureIntrinsic);
}

522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542
// floor(x): a FloatImm operand is folded with std::floor; anything else
// becomes a "floor" pure-intrinsic call of the operand's type.
Expr floor(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::floor(imm->value));
  }
  return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic);
}

// ceil(x): a FloatImm operand is folded with std::ceil; anything else
// becomes a "ceil" pure-intrinsic call of the operand's type.
Expr ceil(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::ceil(imm->value));
  }
  return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic);
}

// round(x): a FloatImm operand is folded with std::nearbyint (so the folded
// value follows the current floating-point rounding mode, not std::round);
// anything else becomes a "round" pure-intrinsic call.
// NOTE(review): confirm the "round" intrinsic is meant to match nearbyint.
Expr round(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::nearbyint(imm->value));
  }
  return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic);
}

543 544 545 546 547 548 549
// nearbyint(x): a FloatImm operand is folded with std::nearbyint; anything
// else becomes a "nearbyint" pure-intrinsic call of the operand's type.
Expr nearbyint(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::nearbyint(imm->value));
  }
  return ir::Call::make(x.type(), "nearbyint", {x}, ir::Call::PureIntrinsic);
}

550 551 552 553 554 555 556 557 558 559
// trunc(x): a FloatImm operand is folded by rounding toward zero (ceil for
// negative values, floor otherwise); anything else becomes a "trunc"
// pure-intrinsic call of the operand's type.
Expr trunc(Expr x) {
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    const auto v = imm->value;
    const auto toward_zero = (v < 0 ? std::ceil(v) : std::floor(v));
    return ir::FloatImm::make(x.type(), toward_zero);
  }
  return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic);
}

560
}  // namespace tvm