/*!
 *  Copyright (c) 2017 by Contributors
 * \file compute_expr.h
 * \brief Utilities for integer expressions with quick eager simplification.
 *  This is weaker than Simplify but can be done eagerly.
 */
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_

#include <tvm/ir.h>
#include <arithmetic/Interval.h>
#include <limits>

namespace tvm {
namespace arith {

/*!
 * \brief Compute the expression with the given binary op.
 *
 *  The primary template forwards directly to Op::make(lhs, rhs) and does no
 *  simplification of its own. Individual ops may provide an explicit
 *  specialization of this template instead.
 * \param lhs The left operand
 * \param rhs The right operand
 * \tparam Op The computation operator; must be callable as Op::make(lhs, rhs).
 * \return The result.
 */
template<typename Op>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
  return Op::make(lhs, rhs);
}

/*!
 * \brief Compute an reduction with Op
 * \param values The input values.
32 33
 * \param empty_value The value when return if it is empty, can be Expr()
 *        which will cause an error to be rasied.
34 35 36 37
 * \tparam Op The computation operator
 * \return The result.
 */
template<typename Op>
38 39
inline Expr ComputeReduce(
    const Array<Expr>& values, Expr empty_value);

/*!
 * \brief Try to read e as a scalar integer constant.
 * \param e The expression to inspect.
 * \param out Receives the constant value on success; left untouched otherwise.
 * \return true if e is not vector-typed and as_const_int(e) yields a value,
 *         false otherwise.
 */
inline bool GetConst(Expr e, int64_t* out) {
  // Vector-typed expressions are never reported as scalar constants.
  if (e.type().is_vector()) return false;
  const int64_t* v = as_const_int(e);
  if (v == nullptr) return false;
  *out = *v;
  return true;
}

/*!
 * \brief Try to read e as a scalar integer constant that fits in an int.
 * \param e The expression to inspect.
 * \param out Receives the constant value on success; left untouched otherwise.
 * \return true if e is a scalar integer constant within
 *         [std::numeric_limits<int>::min(), std::numeric_limits<int>::max()],
 *         false otherwise.
 */
inline bool GetConstInt(Expr e, int* out) {
  int64_t v1 = 0;
  if (!GetConst(e, &v1)) return false;
  // Both bounds must be checked: narrowing an out-of-range int64_t to int
  // would not preserve the value, so such constants are rejected.
  if (v1 > static_cast<int64_t>(std::numeric_limits<int>::max()) ||
      v1 < static_cast<int64_t>(std::numeric_limits<int>::min())) {
    return false;
  }
  *out = static_cast<int>(v1);
  return true;
}

/*! \brief Add specialization: builds the sum through Expr's operator+. */
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
  return a + b;
}

/*! \brief Sub specialization: builds the difference through Expr's operator-. */
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
  return a - b;
}

/*! \brief Mul specialization: builds the product through Expr's operator*. */
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
  return a * b;
}

/*!
 * \brief Div specialization: builds the quotient through Expr's operator/.
 *  Rounding semantics are those of that operator (defined elsewhere).
 */
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
  return a / b;
}

template<>
83
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
84
  return a % b;
85 86 87
}

template<>
88
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
89
  return HalideIR::Internal::Interval::make_max(a, b);
90 91 92 93
}

/*!
 * \brief Min specialization: delegates to HalideIR's Interval::make_min
 *  rather than constructing ir::Min directly.
 */
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
  return HalideIR::Internal::Interval::make_min(a, b);
}

template<typename Op>
98 99 100 101 102
inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
  if (values.size() == 0U) {
    CHECK(empty_value.defined());
    return empty_value;
  }
103 104 105 106 107 108 109
  Expr res = values[0];
  for (size_t i = 1; i < values.size(); ++i) {
    res = ComputeExpr<Op>(res, values[i]);
  }
  return res;
}

}  // namespace arith
}  // namespace tvm
#endif   // TVM_ARITHMETIC_COMPUTE_EXPR_H_