compute_expr.h 3.26 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
9
 *
10
 *   http://www.apache.org/licenses/LICENSE-2.0
11
 *
12 13 14 15 16 17 18 19
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21
/*!
 * \file compute_expr.h
 * \brief Utility to invoke certain compute operations.
 */
24 25
#ifndef TVM_ARITH_COMPUTE_EXPR_H_
#define TVM_ARITH_COMPUTE_EXPR_H_
26

27
#include <tvm/tir/expr.h>
28
#include <limits>
29
#include <algorithm>
30 31

namespace tvm {
32
namespace arith {
33 34 35 36 37

/*!
 * \brief Compute the expression with the given binary op.
 * \param lhs The left operand
 * \param rhs The right operand
38
 * \tparam Op the computation operator
39 40 41
 * \return The result.
 */
template<typename OP>
42
inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) {
43 44 45
  return OP::make(lhs, rhs);
}

46 47 48
/*!
 * \brief Compute an reduction with Op
 * \param values The input values.
49 50
 * \param empty_value The value when return if it is empty, can be Expr()
 *        which will cause an error to be rasied.
51 52 53 54
 * \tparam Op The computation operator
 * \return The result.
 */
template<typename Op>
55 56
inline PrimExpr ComputeReduce(
    const Array<PrimExpr>& values, PrimExpr empty_value);
57

58
/*!
 * \brief Try to read a non-vector expression as a constant 64-bit integer.
 * \param e The expression to inspect.
 * \param out Receives the constant; written only when the function returns true.
 * \return false if the dtype of e is a vector or tir::as_const_int yields
 *         a null pointer for e, true otherwise.
 */
inline bool GetConst(PrimExpr e, int64_t* out) {
  if (e.dtype().is_vector()) {
    return false;
  }
  const int64_t* value = tir::as_const_int(e);
  if (!value) {
    return false;
  }
  *out = *value;
  return true;
}

68
// get a small constant int
69
inline bool GetConstInt(PrimExpr e, int* out) {
70 71 72 73 74 75 76 77 78
  int64_t v1 = 0;
  if (GetConst(e, &v1)) {
    if (v1 > static_cast<int64_t>(
            std::numeric_limits<int>::max())) return false;
    *out = static_cast<int>(v1); return true;
  }
  return false;
}

79
template<>
80
inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
81
  return a + b;
82 83 84
}

template<>
85
inline PrimExpr Compute<tir::SubNode>(PrimExpr a, PrimExpr b) {
86
  return a - b;
87 88 89
}

template<>
90
inline PrimExpr Compute<tir::MulNode>(PrimExpr a, PrimExpr b) {
91
  return a * b;
92 93 94
}

template<>
95
inline PrimExpr Compute<tir::DivNode>(PrimExpr a, PrimExpr b) {
96
  return truncdiv(a, b);
97 98 99
}

template<>
100
inline PrimExpr Compute<tir::ModNode>(PrimExpr a, PrimExpr b) {
101
  return truncmod(a, b);
102 103 104
}

template<>
105
inline PrimExpr Compute<tir::MaxNode>(PrimExpr a, PrimExpr b) {
106
  return max(a, b);
107 108 109
}

template<>
110
inline PrimExpr Compute<tir::MinNode>(PrimExpr a, PrimExpr b) {
111
  return min(a, b);
112 113
}

114
template<typename Op>
115
inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value) {
116 117 118 119
  if (values.size() == 0U) {
    CHECK(empty_value.defined());
    return empty_value;
  }
120
  PrimExpr res = values[0];
121
  for (size_t i = 1; i < values.size(); ++i) {
122
    res = Compute<Op>(res, values[i]);
123 124 125 126
  }
  return res;
}

127
}  // namespace arith
128
}  // namespace tvm
129
#endif   // TVM_ARITH_COMPUTE_EXPR_H_