Unverified Commit 32af4d28 by Tianqi Chen Committed by GitHub

[IR] eager constant folding in operator overloading (#1789)

parent 3455c8a5
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "ir_operator.h"
#include "node/container.h" #include "node/container.h"
namespace tvm { namespace tvm {
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#define TVM_EXPR_H_ #define TVM_EXPR_H_
#include <ir/Expr.h> #include <ir/Expr.h>
#include <ir/IROperator.h>
#include <ir/IRPrinter.h> #include <ir/IRPrinter.h>
#include <string> #include <string>
#include <algorithm> #include <algorithm>
...@@ -34,15 +33,6 @@ using HalideIR::Internal::Stmt; ...@@ -34,15 +33,6 @@ using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter; using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable; using HalideIR::Internal::Variable;
using HalideIR::Internal::make_const;
using HalideIR::Internal::make_zero;
using HalideIR::Internal::make_one;
using HalideIR::Internal::as_const_int;
using HalideIR::Internal::as_const_uint;
using HalideIR::Internal::const_true;
using HalideIR::Internal::const_false;
using HalideIR::Internal::is_no_op;
inline Type TVMShapeIndexType() { inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) { if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8); return Int(sizeof(tvm_index_t) * 8);
......
...@@ -495,8 +495,6 @@ using HalideIR::Internal::Block; ...@@ -495,8 +495,6 @@ using HalideIR::Internal::Block;
using HalideIR::Internal::IfThenElse; using HalideIR::Internal::IfThenElse;
using HalideIR::Internal::Evaluate; using HalideIR::Internal::Evaluate;
using HalideIR::Internal::Shuffle; using HalideIR::Internal::Shuffle;
// ir functions
using HalideIR::Internal::is_const_power_of_two_integer;
/*! /*!
* \brief Create a type annotation expression * \brief Create a type annotation expression
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2018 by Contributors
* \file tvm/ir_operator.h * \file tvm/ir_operator.h
* \brief Common operators of Expr * \brief Common operators defined for Expr.
*
* \note Most of the operators defined here perform simple constant folding
* when the type is int32 or int64 for simplifying the index expressions.
*/ */
#ifndef TVM_IR_OPERATOR_H_ #ifndef TVM_IR_OPERATOR_H_
#define TVM_IR_OPERATOR_H_ #define TVM_IR_OPERATOR_H_
#include <algorithm> #include <algorithm>
#include <type_traits>
#include "expr.h" #include "expr.h"
#include "ir.h" #include "ir.h"
namespace tvm { namespace tvm {
/*!
* \brief Make a const value with certain data type.
* \param t The target type.
* \param value The input value
* \return the result expression.
* \tparam ValueType The constant value type
*/
template<typename ValueType,
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline Expr make_const(Type t, ValueType value);
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \return the result expression.
*/
inline Expr make_zero(Type t);
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_true(int lanes = 1) {
  // The bool constant is built as a 1-bit unsigned value with `lanes` lanes.
  const Type bool_type = UInt(1, lanes);
  return make_const(bool_type, 1);
}
/*!
* \brief Make a constant false expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_false(int lanes = 1) {
  // The bool constant is built as a 1-bit unsigned value with `lanes` lanes.
  const Type bool_type = UInt(1, lanes);
  return make_const(bool_type, 0);
}
/*!
* \brief Get x as constant int expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not IntImm.
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::IntImm* op = x.as<ir::IntImm>()) {
return &(op->value);
} else {
return nullptr;
}
}
/*!
* \brief Get x as constant uint expression.
* \param x The expression
* \return the address to the uint expression,
* return nullptr, if x is not UIntImm.
*/
inline const uint64_t* as_const_uint(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImm* op = x.as<ir::UIntImm>()) {
return &(op->value);
} else {
return nullptr;
}
}
/*!
* \brief Check whether x is a constant integer expression.
* \param x The input argument
* \param value the value to be compared against.
* \return whether x is a constant integer expression whose value equals value.
*/
inline bool is_const_int(const Expr& x, int64_t value);
/*!
* \brief Check whether stmt is nop.
* \param stmt The input statement
* \return whether stmt is nop
*/
inline bool is_no_op(const Stmt& stmt);
/*!
* \brief Check whether x is a constant integer 1
* \param x The input argument.
* \note This only returns true for integer types.
* \return whether x is constant 1
*/
inline bool is_one(const Expr& x) {
  // Delegate to the generic integer-constant comparison with the value 1.
  const int64_t target = 1;
  return is_const_int(x, target);
}
using HalideIR::likely; /*!
using HalideIR::likely_if_innermost; * \brief Check whether x is a constant integer 0
// functions * \param x The input argument
using HalideIR::cast; * \return whether x is constant 0
using HalideIR::min; * \note This only return true for integer types.
using HalideIR::max; */
using HalideIR::select; inline bool is_zero(const Expr& x) {
return is_const_int(x, 0);
}
/*!
* \brief Check whether x is a constant.
* \note This only returns true for integer types.
* \return whether x is constant
*/
inline bool is_const(const Expr& x);
/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
*
* \param x The input expression.
* \param shift The output shift if x is power of two.
* \return whether x is constant power of two
*/
TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
/*!
* \brief cast value to type.
*
* \param t the target type.
* \param value The value
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr cast(const Type& t, Expr value);
/*!
* \brief perform reinterpret cast value to type.
*
* \param t the target type.
* \param value The value
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr reinterpret(const Type& t, Expr value);
/*!
* \brief add operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator+(Expr a, Expr b);
/*!
* \brief subtraction operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator-(Expr a, Expr b);
/*!
* \brief negation.
*
* \param a input.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator-(Expr a);
/*!
* \brief multiplication operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator*(Expr a, Expr b);
/*!
* \brief division operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator/(Expr a, Expr b);
/*!
* \brief mod operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator%(Expr a, Expr b);
/*!
* \brief left shift operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator<<(Expr a, Expr b);
/*!
* \brief right shift operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator>>(Expr a, Expr b);
/*!
* \brief greater
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator>(Expr a, Expr b);
/*!
* \brief greater_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator>=(Expr a, Expr b);
/*!
* \brief less
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator<(Expr a, Expr b);
/*!
* \brief less_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator<=(Expr a, Expr b);
/*!
* \brief equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator==(Expr a, Expr b);
/*!
* \brief not_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator!=(Expr a, Expr b);
/*!
* \brief and
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator&&(Expr a, Expr b);
/*!
* \brief or
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator||(Expr a, Expr b);
/*!
* \brief not
*
* \param a the input operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief take maximum of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr max(Expr a, Expr b);
/*!
* \brief take minimum of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr min(Expr a, Expr b);
/*!
* \brief right shift
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
* \note NOTE(review): this repeats the declaration of operator>> that
* already appears earlier in this header with the same signature;
* the repetition looks redundant and could be removed.
*/
TVM_DLL Expr operator>>(Expr a, Expr b);
/*!
* \brief left shift
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
* \note NOTE(review): this repeats the declaration of operator<< that
* already appears earlier in this header with the same signature;
* the repetition looks redundant and could be removed.
*/
TVM_DLL Expr operator<<(Expr a, Expr b);
/*!
* \brief take bitwise and of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator&(Expr a, Expr b);
/*!
* \brief take bitwise or of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator|(Expr a, Expr b);
/*!
* \brief take bitwise xor of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator^(Expr a, Expr b);
/*!
* \brief take bitwise negation of the input value
*
* \param a the input expression.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator~(Expr a);
/*!
* \brief select result by condition
*
* \param cond The condition
* \param true_value The value when results are true.
* \param false_value The value when results are false.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr select(Expr cond, Expr true_value, Expr false_value);
/*!
* \brief Mark condition as likely.
* \param cond The condition
* \return The marked expression.
*/
TVM_DLL Expr likely(Expr cond);
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
TVM_DLL Expr pow(Expr x, Expr y);
/*!
* \brief Calculate absolute value of x.
* \param x The input data
*
* \return The absolute value of input data x
*/
TVM_DLL Expr abs(Expr x);
/*! /*!
* \brief sum of of source expression over axis * \brief sum of of source expression over axis
...@@ -48,13 +450,12 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis); ...@@ -48,13 +450,12 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
*/ */
TVM_DLL Expr prod(Expr source, Array<IterVar> axis); TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
// Unary intrinsic operators // Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \ #define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \ inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \ return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \ } \
TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sigmoid);
...@@ -64,38 +465,152 @@ TVM_DECLARE_INTRIN_UNARY(floor); ...@@ -64,38 +465,152 @@ TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil); TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round); TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc); TVM_DECLARE_INTRIN_UNARY(trunc);
TVM_DECLARE_INTRIN_UNARY(popcount);
/*!
* \brief Calculate power(x, y) // Implementation details after this
* \param x The left operand. inline bool is_const(const Expr& x) {
* \param y The right operand. if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
*/ return true;
inline Expr pow(Expr x, Expr y) { } else if (const auto* op = x.as<ir::Broadcast>()) {
match_types(x, y); const Expr& val = op->value;
CHECK(x.type().is_float()) << "power only applies to float"; if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); return true;
}
}
return false;
} }
/*! inline bool is_positive_const(const Expr& a) {
* \brief Calculate absolute value of x, elementwise if (const ir::IntImm* op = a.as<ir::IntImm>()) {
* \param x The input data return op->value > 0;
* } else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) {
* \return The aboslute value of input data x return op->value > 0;
*/
inline Expr abs(Expr x) {
if (x.type().is_int()) {
return select(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
return x;
} else { } else {
LOG(WARNING) << "Warning: Data type " << x.type() return false;
<<" not supported for absolute op. Skipping absolute op...";
return x;
} }
} }
} // namespace tvm inline bool is_negative_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
return op->value < 0;
} else {
return false;
}
}
inline bool is_const_int(const Expr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImm>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImm>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::Broadcast>()) {
const Expr& val = op->value;
if (const auto* opv = val.as<ir::IntImm>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImm>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
return false;
}
inline bool is_no_op(const Stmt& stmt) {
  // An undefined statement counts as a no-op.
  if (!stmt.defined()) return true;
  // Otherwise only an Evaluate node wrapping a constant (per is_const) qualifies.
  const auto* eval = stmt.as<ir::Evaluate>();
  return eval != nullptr && is_const(eval->value);
}
/*!
 * \brief Build a scalar immediate of type t holding value.
 *
 * Dispatches on the type code of t: int -> IntImm, uint -> UIntImm,
 * float -> FloatImm. Any other type is reported through LOG(FATAL).
 */
template<typename ValueType>
inline Expr MakeConstScalar(Type t, ValueType value) {
  if (t.is_int()) {
    return ir::IntImm::make(t, static_cast<int64_t>(value));
  } else if (t.is_uint()) {
    return ir::UIntImm::make(t, static_cast<uint64_t>(value));
  } else if (t.is_float()) {
    return ir::FloatImm::make(t, static_cast<double>(value));
  }
  LOG(FATAL) << "cannot make const for type " << t;
  return Expr();
}
template<typename ValueType, typename>
inline Expr make_const(Type t, ValueType value) {
  // Multi-lane types: build the scalar element once and broadcast it.
  if (t.lanes() != 1) {
    Expr scalar = MakeConstScalar(t.element_of(), value);
    return ir::Broadcast::make(scalar, t.lanes());
  }
  // Single lane: the scalar immediate is the result.
  return MakeConstScalar(t, value);
}
inline Expr make_zero(Type t) {
  // Non-handle types get a plain constant 0 of that type.
  if (!t.is_handle()) {
    return make_const(t, 0);
  }
  // Handle types: reinterpret a 64-bit unsigned zero as the handle type.
  return reinterpret(t, make_const(UInt(64), 0));
}
// additional const expression overloading
// Defines a compound-assignment style function `Name(Expr& a, Expr b)`:
// it stores OpFunc(a, b) back into `a` and returns the updated `a`.
// Note that the return type is Expr (by value), not Expr&.
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \
a = OpFunc(a, b); \
return a; \
}
// Defines four overloads of the binary function `Name` that mix an Expr
// with a C++ scalar on either side:
//  - float operands are wrapped with Expr(b) / Expr(a);
//  - int operands are turned into a constant of the *other* operand's type
//    via make_const, so the int follows the Expr's type.
// Each overload forwards to the (Expr, Expr) version of `Name`.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, float b) { \
return Name(a, Expr(b)); \
} \
inline Expr Name(float a, const Expr& b) { \
return Name(Expr(a), b); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
} \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
}
// Defines two overloads of the logical function `Name` that accept a C++
// bool on either side. The bool is wrapped with Expr(b) / Expr(a) and the
// call is forwarded to the (Expr, Expr) version of `Name`.
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, bool b) { \
return Name(a, Expr(b)); \
} \
inline Expr Name(bool a, const Expr& b) { \
return Name(Expr(a), b); \
}
// Defines two overloads of the integer function `Name` that accept a C++
// int on either side. The int is turned into a constant of the other
// operand's type via make_const, then the call is forwarded to the
// (Expr, Expr) version of `Name`. No float overloads are generated here.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
}
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
// logical ops
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
} // namespace tvm
#endif // TVM_IR_OPERATOR_H_ #endif // TVM_IR_OPERATOR_H_
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "ir_operator.h"
#include "arithmetic.h" #include "arithmetic.h"
#include "node/container.h" #include "node/container.h"
......
...@@ -354,7 +354,7 @@ Example:: ...@@ -354,7 +354,7 @@ Example::
if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) }; if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
auto axis = ShapeToArray(r_axes); auto axis = ShapeToArray(r_axes);
Expr count = make_one(inputs[0]->dtype); Expr count = make_const(inputs[0]->dtype, 1);
for (auto& i : r_axes) { for (auto& i : r_axes) {
count *= inputs[0]->shape[i]; count *= inputs[0]->shape[i];
} }
......
...@@ -156,9 +156,9 @@ def any(*args): ...@@ -156,9 +156,9 @@ def any(*args):
raise ValueError("Any must take at least 1 argument") raise ValueError("Any must take at least 1 argument")
if len(args) == 1: if len(args) == 1:
return args[0] return args[0]
ret = _expr.Or(args[0], args[1]) ret = _make._OpOr(args[0], args[1])
for i in range(2, len(args)): for i in range(2, len(args)):
ret = _expr.Or(ret, args[i]) ret = _make._OpOr(ret, args[i])
return ret return ret
...@@ -180,9 +180,9 @@ def all(*args): ...@@ -180,9 +180,9 @@ def all(*args):
raise ValueError("Any must take at least 1 argument") raise ValueError("Any must take at least 1 argument")
if len(args) == 1: if len(args) == 1:
return args[0] return args[0]
ret = _expr.And(args[0], args[1]) ret = _make._OpAnd(args[0], args[1])
for i in range(2, len(args)): for i in range(2, len(args)):
ret = _expr.And(ret, args[i]) ret = _make._OpAnd(ret, args[i])
return ret return ret
...@@ -773,5 +773,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"): ...@@ -773,5 +773,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
_init_api("tvm.api") _init_api("tvm.api")
#pylint: disable=unnecessary-lambda #pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _expr.Min(x, y), max_value, name='min') min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _expr.Max(x, y), min_value, name='max') max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
...@@ -60,7 +60,7 @@ class ExprOp(object): ...@@ -60,7 +60,7 @@ class ExprOp(object):
return self.__rdiv__(other) return self.__rdiv__(other)
def __mod__(self, other): def __mod__(self, other):
return _make.Mod(self, other) return _make._OpMod(self, other)
def __neg__(self): def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype) neg_one = _api_internal._const(-1, self.dtype)
...@@ -85,10 +85,10 @@ class ExprOp(object): ...@@ -85,10 +85,10 @@ class ExprOp(object):
return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
def __lt__(self, other): def __lt__(self, other):
return _make.LT(self, other) return _make._OpLT(self, other)
def __le__(self, other): def __le__(self, other):
return _make.LE(self, other) return _make._OpLE(self, other)
def __eq__(self, other): def __eq__(self, other):
return EqualOp(self, other) return EqualOp(self, other)
...@@ -97,10 +97,10 @@ class ExprOp(object): ...@@ -97,10 +97,10 @@ class ExprOp(object):
return NotEqualOp(self, other) return NotEqualOp(self, other)
def __gt__(self, other): def __gt__(self, other):
return _make.GT(self, other) return _make._OpGT(self, other)
def __ge__(self, other): def __ge__(self, other):
return _make.GE(self, other) return _make._OpGE(self, other)
def __nonzero__(self): def __nonzero__(self):
raise ValueError("Cannot use and / or / not operator to Expr, hint: " + raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
...@@ -122,7 +122,7 @@ class ExprOp(object): ...@@ -122,7 +122,7 @@ class ExprOp(object):
ret : Expr ret : Expr
The equality expression. The equality expression.
""" """
return _make.EQ(self, other) return _make._OpEQ(self, other)
def astype(self, dtype): def astype(self, dtype):
"""Cast the expression to other type. """Cast the expression to other type.
...@@ -169,7 +169,7 @@ class EqualOp(NodeGeneric, ExprOp): ...@@ -169,7 +169,7 @@ class EqualOp(NodeGeneric, ExprOp):
def asnode(self): def asnode(self):
"""Convert node.""" """Convert node."""
return _make.EQ(self.a, self.b) return _make._OpEQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp): class NotEqualOp(NodeGeneric, ExprOp):
...@@ -201,7 +201,7 @@ class NotEqualOp(NodeGeneric, ExprOp): ...@@ -201,7 +201,7 @@ class NotEqualOp(NodeGeneric, ExprOp):
def asnode(self): def asnode(self):
"""Convert node.""" """Convert node."""
return _make.NE(self.a, self.b) return _make._OpNE(self.a, self.b)
class Expr(ExprOp, NodeBase): class Expr(ExprOp, NodeBase):
......
...@@ -24,7 +24,7 @@ def add(lhs, rhs): ...@@ -24,7 +24,7 @@ def add(lhs, rhs):
op : tvm.Expr op : tvm.Expr
The result Expr of add operaton. The result Expr of add operaton.
""" """
return _make.Add(lhs, rhs) return _make._OpAdd(lhs, rhs)
def subtract(lhs, rhs): def subtract(lhs, rhs):
...@@ -42,7 +42,7 @@ def subtract(lhs, rhs): ...@@ -42,7 +42,7 @@ def subtract(lhs, rhs):
op : tvm.Expr op : tvm.Expr
The result Expr of subtract operaton. The result Expr of subtract operaton.
""" """
return _make.Sub(lhs, rhs) return _make._OpSub(lhs, rhs)
def multiply(lhs, rhs): def multiply(lhs, rhs):
...@@ -60,7 +60,7 @@ def multiply(lhs, rhs): ...@@ -60,7 +60,7 @@ def multiply(lhs, rhs):
op : tvm.Expr op : tvm.Expr
The result Expr of multiply operaton. The result Expr of multiply operaton.
""" """
return _make.Mul(lhs, rhs) return _make._OpMul(lhs, rhs)
def divide(lhs, rhs): def divide(lhs, rhs):
...@@ -78,7 +78,7 @@ def divide(lhs, rhs): ...@@ -78,7 +78,7 @@ def divide(lhs, rhs):
op : tvm.Expr op : tvm.Expr
The result Expr of divide operaton. The result Expr of divide operaton.
""" """
return _make.Div(lhs, rhs) return _make._OpDiv(lhs, rhs)
def cast(src, dtype): def cast(src, dtype):
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <ir/IROperator.h> #include <tvm/ir_operator.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/ir_operator.h> #include <tvm/ir_operator.h>
...@@ -117,6 +117,50 @@ TVM_REGISTER_API("make.CommReducer") ...@@ -117,6 +117,50 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \ *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \ }) \
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE2(Add);
REGISTER_MAKE2(Sub);
REGISTER_MAKE2(Mul);
REGISTER_MAKE2(Div);
REGISTER_MAKE2(Mod);
REGISTER_MAKE2(Min);
REGISTER_MAKE2(Max);
REGISTER_MAKE2(EQ);
REGISTER_MAKE2(NE);
REGISTER_MAKE2(LT);
REGISTER_MAKE2(LE);
REGISTER_MAKE2(GT);
REGISTER_MAKE2(GE);
REGISTER_MAKE2(And);
REGISTER_MAKE2(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \ #define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
...@@ -138,50 +182,27 @@ TVM_REGISTER_API("make.CommReducer") ...@@ -138,50 +182,27 @@ TVM_REGISTER_API("make.CommReducer")
} \ } \
}) })
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm); REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE2(UIntImm); REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE2(FloatImm); REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE1(StringImm); REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(Add, operator+); REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(Sub, operator-); REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(Mul, operator*); REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(Div, operator/); REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(Mod, operator%); REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(Min, min); REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(Max, max); REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(EQ, operator==); REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(NE, operator!=); REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*) REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*) REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GE, operator>=);
REGISTER_MAKE_BINARY_OP(And, operator&&);
REGISTER_MAKE_BINARY_OP(Or, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&); REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>); REGISTER_MAKE_BIT_OP(right_shift, operator>>);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -14,10 +14,6 @@ ...@@ -14,10 +14,6 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using HalideIR::Internal::add_would_overflow;
using HalideIR::Internal::sub_would_overflow;
using HalideIR::Internal::mul_would_overflow;
/*! /*!
* \brief Compute the expression with the given binary op. * \brief Compute the expression with the given binary op.
* \param lhs The left operand * \param lhs The left operand
...@@ -42,23 +38,9 @@ template<typename Op> ...@@ -42,23 +38,9 @@ template<typename Op>
inline Expr ComputeReduce( inline Expr ComputeReduce(
const Array<Expr>& values, Expr empty_value); const Array<Expr>& values, Expr empty_value);
template<typename T> inline bool GetConst(Expr e, int64_t* out) {
inline bool GetConst(Expr e, T* out);
template<>
inline bool GetConst<int64_t>(Expr e, int64_t *out) {
if (e.type().is_vector()) return false;
const int64_t *v = as_const_int(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
template<>
inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
if (e.type().is_vector()) return false; if (e.type().is_vector()) return false;
const uint64_t *v = as_const_uint(e); const int64_t* v = as_const_int(e);
if (v) { if (v) {
*out = *v; return true; *out = *v; return true;
} else { } else {
...@@ -69,66 +51,37 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) { ...@@ -69,66 +51,37 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
// get a small constant int // get a small constant int
inline bool GetConstInt(Expr e, int* out) { inline bool GetConstInt(Expr e, int* out) {
int64_t v1 = 0; int64_t v1 = 0;
uint64_t v2 = 0;
if (GetConst(e, &v1)) { if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>( if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false; std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true; *out = static_cast<int>(v1); return true;
} }
if (GetConst(e, &v2)) {
if (v2 > static_cast<uint64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v2); return true;
}
return false; return false;
} }
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \
LOG(FATAL) << "signed int overflow"; \
} \
return ir::IntImm::make(a.type(), ia OP ib); \
} \
uint64_t ua = 0, ub = 0; \
if (GetConst(a, &ua) && GetConst(b, &ub)) { \
return ir::UIntImm::make(a.type(), ua OP ub); \
} \
template<> template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
if (is_zero(a)) return b; return a + b;
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(add, +);
return ir::Add::make(a, b);
} }
template<> template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
if (is_zero(b)) return a; return a - b;
TVM_CONST_PROPAGATION(sub, -);
return ir::Sub::make(a, b);
} }
template<> template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
if (is_one(a)) return b; return a * b;
if (is_one(b)) return a;
TVM_CONST_PROPAGATION(mul, *);
return ir::Mul::make(a, b);
} }
template<> template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
if (is_one(b)) return a; return a / b;
return ir::Div::make(a, b);
} }
template<> template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) { inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
if (is_zero(a)) return make_zero(a.type()); return a % b;
return ir::Mod::make(a, b);
} }
template<> template<>
......
...@@ -194,7 +194,7 @@ bool DetectClipBound( ...@@ -194,7 +194,7 @@ bool DetectClipBound(
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = Simplify(ret.coeff); ret.coeff = Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()]; IntervalEntry& p = (*bmap)[var.get()];
if (is_one(ret.coeff)) { if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift // var + shift >=0 -> var >= -shift
if (p.min_value.defined()) { if (p.min_value.defined()) {
p.min_value = ir::Max::make(p.min_value, -ret.base); p.min_value = ir::Max::make(p.min_value, -ret.base);
...@@ -203,7 +203,7 @@ bool DetectClipBound( ...@@ -203,7 +203,7 @@ bool DetectClipBound(
} }
return true; return true;
} }
if (is_const(ret.coeff, -1)) { if (is_const_int(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift // -var + shift >=0 -> var <= shift
if (p.max_value.defined()) { if (p.max_value.defined()) {
p.max_value = ir::Min::make(p.max_value, ret.base); p.max_value = ir::Min::make(p.max_value, ret.base);
......
...@@ -42,7 +42,7 @@ std::string CodeGenCUDA::Finish() { ...@@ -42,7 +42,7 @@ std::string CodeGenCUDA::Finish() {
} }
void CodeGenCUDA::VisitStmt_(const ir::For* op) { void CodeGenCUDA::VisitStmt_(const ir::For* op) {
CHECK(is_zero(op->min)); CHECK(is_const_int(op->min, 0));
if (op->for_type == ir::ForType::Unrolled) { if (op->for_type == ir::ForType::Unrolled) {
PrintIndent(); PrintIndent();
stream << "#pragma unroll\n"; stream << "#pragma unroll\n";
......
...@@ -195,7 +195,7 @@ class PipelineExtractor: public IRVisitor { ...@@ -195,7 +195,7 @@ class PipelineExtractor: public IRVisitor {
ChannelEntry& cb = cmap_.at(ch->handle_var.get()); ChannelEntry& cb = cmap_.at(ch->handle_var.get());
trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size()); trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
// Grab the advance constant size. // Grab the advance constant size.
int trigger_size; int trigger_size = 0;
if (attr->attr_key == attr::pipeline_stage_scope) { if (attr->attr_key == attr::pipeline_stage_scope) {
cb.node->ctrl_signals.push_back( cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kComputeFinish, 0)); ControlSignalNode::make(kComputeFinish, 0));
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <ir/IRPrinter.h> #include <ir/IRPrinter.h>
#include <memory> #include <memory>
......
...@@ -8,6 +8,406 @@ ...@@ -8,6 +8,406 @@
namespace tvm { namespace tvm {
/*!
 * \brief Tell whether a type is one of the integer types used for indexing.
 *
 *  Expressions of these types show up in shape arithmetic, so the
 *  operator overloads fold them eagerly.
 *
 * \param type The type to inspect.
 * \return true for a signed, single-lane integer of 32 or 64 bits.
 */
inline bool IsIndexType(const Type& type) {
  // must be a signed scalar integer
  if (!type.is_int() || type.lanes() != 1) return false;
  const int width = type.bits();
  return width == 32 || width == 64;
}
// Wrap value in a Cast node unless it already has type t.
// No constant folding or lane handling is attempted here.
inline Expr SimpleCast(const Type& t, Expr value) {
  if (value.type() != t) {
    return ir::Cast::make(t, value);
  }
  return value;
}
// Make the two operands of a binary op agree on a common type, in place.
// Returns immediately when the types already match.
void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) {  // NOLINT(*)
  if (lhs.type() == rhs.type()) return;
  // Keep the incoming types around for the error messages below.
  Type ltype = lhs.type();
  Type rtype = rhs.type();
  const bool lhs_scalar = (ltype.lanes() == 1);
  const bool rhs_scalar = (rtype.lanes() == 1);
  // Step 1: broadcast a scalar operand to the lane count of the other side.
  if (lhs_scalar && !rhs_scalar) {
    lhs = ir::Broadcast::make(lhs, rtype.lanes());
  } else if (rhs_scalar && !lhs_scalar) {
    rhs = ir::Broadcast::make(rhs, ltype.lanes());
  } else {
    CHECK(ltype.lanes() == rtype.lanes())
        << "Cannot match type " << ltype << " vs " << rtype;
  }
  if (lhs.type() == rhs.type()) return;
  // Step 2: only very simple type conversions are performed:
  //   int->float, int(32)->int(64)
  // The types are required to be relatively consistent.
  // This reduces the amount of code generated by operators
  // and also helps users find potential type conversion problems.
  const Type lt = lhs.type();
  const Type rt = rhs.type();
  if (!lt.is_float() && rt.is_float()) {
    // non-float lhs is converted to the float type of rhs
    lhs = ir::Cast::make(rt, lhs);
    return;
  }
  if (lt.is_float() && !rt.is_float()) {
    // non-float rhs is converted to the float type of lhs
    rhs = ir::Cast::make(lt, rhs);
    return;
  }
  const bool same_signedness =
      (lt.is_int() && rt.is_int()) || (lt.is_uint() && rt.is_uint());
  const bool mixed_signedness =
      (lt.is_int() && rt.is_uint()) || (lt.is_uint() && rt.is_int());
  if (same_signedness) {
    // promote the narrower side to the wider one
    if (lt.bits() < rt.bits()) {
      lhs = ir::Cast::make(rt, lhs);
    } else {
      rhs = ir::Cast::make(lt, rhs);
    }
  } else if (mixed_signedness) {
    // both sides become a signed integer of the larger bit width
    int bits = std::max(lt.bits(), rt.bits());
    lhs = SimpleCast(Int(bits, lt.lanes()), lhs);
    rhs = SimpleCast(Int(bits, rt.lanes()), rhs);
  } else {
    LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
  }
}
// Test whether val is a positive power of two.
// On success shift[0] holds log2(val). For a non-positive val the output
// is left untouched; for any other failure it holds the number of
// trailing zero bits that were stripped.
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
  if (val <= 0) return false;
  int& stripped = shift[0];
  stripped = 0;
  // val > 0 guarantees a set bit exists, so this loop terminates.
  for (; (val & 1) == 0; val = val >> 1) {
    ++stripped;
  }
  // A power of two has nothing left above its lowest set bit.
  return val == 1;
}
// Check whether x is a signed or unsigned integer immediate whose value
// is a power of two; the exponent is reported through shift.
bool is_const_power_of_two_integer(const Expr& x, int* shift) {
  if (const auto* signed_imm = x.as<ir::IntImm>()) {
    return ConstPowerHelper(signed_imm->value, shift);
  }
  if (const auto* unsigned_imm = x.as<ir::UIntImm>()) {
    return ConstPowerHelper(unsigned_imm->value, shift);
  }
  // anything that is not an integer immediate
  return false;
}
// Cast value to type t.
// IntImm values are folded into a constant of the target (element) type
// because they are used in index computations; a scalar value cast to a
// vector type is converted to the element type and then broadcast.
Expr cast(const Type& t, Expr value) {
  using ir::IntImm;
  if (value.type() == t) return value;
  // scalar target
  if (t.lanes() == 1) {
    const IntImm* imm = value.as<IntImm>();
    if (imm != nullptr) {
      return make_const(t, imm->value);
    }
    return ir::Cast::make(t, value);
  }
  // vector target, vector source: lane counts have to agree
  if (value.type().lanes() != 1) {
    CHECK(value.type().lanes() == t.lanes());
    return ir::Cast::make(t, value);
  }
  // vector target, scalar source: manually unroll the cast
  Type vtype = t.element_of();
  if (value.type() != vtype) {
    const IntImm* imm = value.as<IntImm>();
    if (imm != nullptr) {
      value = make_const(vtype, imm->value);
    } else {
      value = ir::Cast::make(vtype, value);
    }
  }
  return ir::Broadcast::make(value, t.lanes());
}
// Build a reinterpret intrinsic call that views value as type t.
// A value that already has type t is returned unchanged.
Expr reinterpret(const Type& t, Expr value) {
  if (value.type() != t) {
    return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
  }
  return value;
}
// Shared prologue for the binary operator overloads below.
//
// Must be expanded inside a function that has two mutable Expr variables
// named `a` and `b` in scope. It introduces:
//   pa, pb - a / b viewed as IntImm (nullptr when they are not IntImm)
//   ta, tb - the types of a / b before any type matching
// BODY is executed only when both ta and tb satisfy IsIndexType; it may
// `return` a folded expression. If BODY does not return (or was skipped),
// BinaryOpMatchTypes(a, b) is applied so the caller can build the IR node
// from operands of matching type.
//
// Note: BODY is a single macro argument, so any comma inside it has to be
// enclosed in parentheses.
#define TVM_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const Type& ta = a.type(); \
const Type& tb = b.type(); \
if (IsIndexType(ta) && IsIndexType(tb)) { \
BODY; \
} \
BinaryOpMatchTypes(a, b);
// Addition. For index-typed operands two constants are folded and a
// constant zero on either side is dropped; otherwise an Add node is built.
Expr operator+(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      // the result takes the wider of the two index types
      const bool lhs_is_wider = ta.bits() >= tb.bits();
      Type result_type = lhs_is_wider ? ta : tb;
      if (pa && pb) {
        return IntImm::make(result_type, pa->value + pb->value);
      }
      if (pa) {
        if (pa->value == 0) return SimpleCast(result_type, b);
      }
      if (pb) {
        if (pb->value == 0) return SimpleCast(result_type, a);
      }
    });
  return ir::Add::make(a, b);
}
// Unary negation. An IntImm is negated directly; anything else is
// expressed as (0 - a) through the binary subtraction operator.
Expr operator-(Expr a) {
  using ir::IntImm;
  if (const IntImm* imm = a.as<IntImm>()) {
    return IntImm::make(a.type(), -imm->value);
  }
  Expr zero = make_zero(a.type());
  return zero - a;
}
// Subtraction. For index-typed operands two constants are folded and
// subtracting a constant zero is dropped; otherwise a Sub node is built.
Expr operator-(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      // the result takes the wider of the two index types
      Type result_type = ta.bits() >= tb.bits() ? ta : tb;
      if (pa && pb) {
        return IntImm::make(result_type, pa->value - pb->value);
      }
      if (pb) {
        if (pb->value == 0) return SimpleCast(result_type, a);
      }
    });
  return ir::Sub::make(a, b);
}
// Multiplication with eager folding for index-typed operands.
//
// \param a The left operand.
// \param b The right operand.
// \return A folded constant when both operands are IntImm, the other
//         operand (cast to the result type) when one side is the constant 1,
//         a zero constant of the result type when one side is the
//         constant 0, and an ir::Mul node otherwise.
Expr operator*(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      // the result takes the wider of the two index types
      Type rtype = ta.bits() >= tb.bits() ? ta : tb;
      if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
      if (pa) {
        if (pa->value == 1) return SimpleCast(rtype, b);
        // Emit a constant of the result type: SimpleCast would wrap the
        // zero in a Cast node when the operand widths differ, which
        // later folding can no longer see through.
        if (pa->value == 0) return make_zero(rtype);
      }
      if (pb) {
        if (pb->value == 1) return SimpleCast(rtype, a);
        if (pb->value == 0) return make_zero(rtype);
      }
    });
  return ir::Mul::make(a, b);
}
// Division with eager folding for index-typed operands.
//
// \param a The dividend.
// \param b The divisor.
// \return A folded constant for non-negative / positive constants, a zero
//         constant of the result type for a constant zero dividend, the
//         dividend (cast to the result type) for a constant divisor of 1,
//         and an ir::Div node otherwise.
// A constant zero divisor of index type is rejected with a fatal check.
Expr operator/(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      Type rtype = ta.bits() >= tb.bits() ? ta : tb;
      // Validate the divisor before any shortcut is taken, so that a
      // constant 0 / 0 is reported instead of silently folding to 0.
      if (pb) {
        CHECK_NE(pb->value, 0) << "Divide by zero";
      }
      // due to division and mod can have different modes
      // only constant fold positive number where rule is fixed.
      if (pa && pb && pa->value >= 0 && pb->value > 0) {
        return IntImm::make(rtype, pa->value / pb->value);
      }
      if (pa) {
        // constant of the result type rather than a Cast around the zero
        if (pa->value == 0) return make_zero(rtype);
      }
      if (pb) {
        if (pb->value == 1) return SimpleCast(rtype, a);
      }
    });
  return ir::Div::make(a, b);
}
// Modulo with eager folding for index-typed operands.
//
// \param a The dividend.
// \param b The divisor.
// \return A folded constant for non-negative / positive constants, a zero
//         constant of the result type when the dividend is the constant 0
//         or the divisor is the constant 1, and an ir::Mod node otherwise.
// A constant zero divisor of index type is rejected with a fatal check.
Expr operator%(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      Type rtype = ta.bits() >= tb.bits() ? ta : tb;
      // Validate the divisor before any shortcut is taken, so that a
      // constant 0 % 0 is reported instead of silently folding to 0.
      if (pb) {
        CHECK_NE(pb->value, 0) << "Divide by zero";
      }
      // due to division and mod can have different modes
      // only constant fold positive number where rule is fixed.
      if (pa && pb && pa->value >= 0 && pb->value > 0) {
        return IntImm::make(rtype, pa->value % pb->value);
      }
      if (pa) {
        // constant of the result type rather than a Cast around the zero
        if (pa->value == 0) return make_zero(rtype);
      }
      if (pb) {
        if (pb->value == 1) return make_zero(rtype);
      }
    });
  return ir::Mod::make(a, b);
}
// Minimum of two expressions. Two index-typed constants are folded;
// everything else becomes a Min node.
Expr min(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        // the result takes the wider of the two index types
        Type result_type = ta.bits() >= tb.bits() ? ta : tb;
        return IntImm::make(result_type, std::min(pa->value, pb->value));
      }
    });
  return ir::Min::make(a, b);
}
// Maximum of two expressions. Two index-typed constants are folded;
// everything else becomes a Max node.
Expr max(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        // the result takes the wider of the two index types
        Type result_type = ta.bits() >= tb.bits() ? ta : tb;
        return IntImm::make(result_type, std::max(pa->value, pb->value));
      }
    });
  return ir::Max::make(a, b);
}
// Conditional selection. The condition must be of bool type and the two
// branches are brought to a common type first. A constant condition picks
// its branch immediately; otherwise a Select node is built.
Expr select(Expr cond, Expr true_value, Expr false_value) {
  using ir::IntImm;
  using ir::UIntImm;
  CHECK(cond.type().is_bool());
  BinaryOpMatchTypes(true_value, false_value);
  if (const UIntImm* unsigned_cond = cond.as<UIntImm>()) {
    return unsigned_cond->value != 0 ? true_value : false_value;
  }
  if (const IntImm* signed_cond = cond.as<IntImm>()) {
    return signed_cond->value != 0 ? true_value : false_value;
  }
  return ir::Select::make(cond, true_value, false_value);
}
// Wrap cond in the `likely` intrinsic; a constant condition is
// returned as is.
Expr likely(Expr cond) {
  if (!is_const(cond)) {
    return ir::Call::make(cond.type(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
  }
  return cond;
}
// Greater-than. Two index-typed constants fold to a UInt(1) immediate;
// otherwise a GT node is built.
Expr operator>(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        const bool holds = pa->value > pb->value;
        return UIntImm::make(UInt(1), holds);
      }
    });
  return ir::GT::make(a, b);
}
// Greater-or-equal. Two index-typed constants fold to a UInt(1) immediate;
// otherwise a GE node is built.
Expr operator>=(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        const bool holds = pa->value >= pb->value;
        return UIntImm::make(UInt(1), holds);
      }
    });
  return ir::GE::make(a, b);
}
// Less-than. Two index-typed constants fold to a UInt(1) immediate;
// otherwise an LT node is built.
Expr operator<(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        const bool holds = pa->value < pb->value;
        return UIntImm::make(UInt(1), holds);
      }
    });
  return ir::LT::make(a, b);
}
// Less-or-equal. Two index-typed constants fold to a UInt(1) immediate;
// otherwise an LE node is built.
Expr operator<=(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        const bool holds = pa->value <= pb->value;
        return UIntImm::make(UInt(1), holds);
      }
    });
  return ir::LE::make(a, b);
}
// Equality. Two index-typed constants fold to a UInt(1) immediate;
// otherwise an EQ node is built.
Expr operator==(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        const bool holds = pa->value == pb->value;
        return UIntImm::make(UInt(1), holds);
      }
    });
  return ir::EQ::make(a, b);
}
// Inequality. Two index-typed constants fold to a UInt(1) immediate;
// otherwise an NE node is built.
Expr operator!=(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        const bool holds = pa->value != pb->value;
        return UIntImm::make(UInt(1), holds);
      }
    });
  return ir::NE::make(a, b);
}
// Logical and. Folded only when both operands are UIntImm; otherwise an
// And node is built.
Expr operator&&(Expr a, Expr b) {
  using ir::UIntImm;
  const UIntImm* lhs = a.as<UIntImm>();
  const UIntImm* rhs = b.as<UIntImm>();
  if (lhs == nullptr || rhs == nullptr) {
    return ir::And::make(a, b);
  }
  return UIntImm::make(UInt(1), lhs->value && rhs->value);
}
// Logical or. Folded only when both operands are UIntImm; otherwise an
// Or node is built.
Expr operator||(Expr a, Expr b) {
  using ir::UIntImm;
  const UIntImm* lhs = a.as<UIntImm>();
  const UIntImm* rhs = b.as<UIntImm>();
  if (lhs == nullptr || rhs == nullptr) {
    return ir::Or::make(a, b);
  }
  return UIntImm::make(UInt(1), lhs->value || rhs->value);
}
// Logical not. A UIntImm operand is folded; otherwise a Not node is built.
Expr operator!(Expr a) {
  using ir::UIntImm;
  if (const UIntImm* imm = a.as<UIntImm>()) {
    return UIntImm::make(UInt(1), !(imm->value));
  }
  return ir::Not::make(a);
}
// Right shift with eager folding for index-typed operands.
//
// \param a The value to shift.
// \param b The shift amount.
// \return A folded constant when both operands are IntImm and the shift
//         amount lies in [0, bits of the result type), the value itself
//         (cast to the result type) for a constant shift of 0, and a
//         shift_right intrinsic call otherwise.
Expr operator>>(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      Type rtype = ta.bits() >= tb.bits() ? ta : tb;
      // Shifting by a negative amount or by at least the bit width is
      // undefined behaviour on the host, so such constants are not
      // folded here and fall through to the intrinsic call.
      if (pa && pb && pb->value >= 0 && pb->value < rtype.bits()) {
        return IntImm::make(rtype, (pa->value >> pb->value));
      }
      if (pb) {
        if (pb->value == 0) return SimpleCast(rtype, a);
      }
    });
  return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
}
// Left shift with eager folding for index-typed operands.
//
// \param a The value to shift.
// \param b The shift amount.
// \return A folded constant when both operands are IntImm and the shift
//         amount lies in [0, bits of the result type), the value itself
//         (cast to the result type) for a constant shift of 0, and a
//         shift_left intrinsic call otherwise.
Expr operator<<(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      Type rtype = ta.bits() >= tb.bits() ? ta : tb;
      // Shifting by a negative amount or by at least the bit width is
      // undefined behaviour on the host, so such constants are not
      // folded here and fall through to the intrinsic call.
      if (pa && pb && pb->value >= 0 && pb->value < rtype.bits()) {
        return IntImm::make(rtype, (pa->value << pb->value));
      }
      if (pb) {
        if (pb->value == 0) return SimpleCast(rtype, a);
      }
    });
  return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
}
// Bitwise and. Two index-typed constants are folded; otherwise a
// bitwise_and intrinsic call is built.
Expr operator&(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        // the result takes the wider of the two index types
        Type result_type = ta.bits() >= tb.bits() ? ta : tb;
        return IntImm::make(result_type, (pa->value & pb->value));
      }
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
}
// Bitwise or. Two index-typed constants are folded; otherwise a
// bitwise_or intrinsic call is built.
Expr operator|(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        // the result takes the wider of the two index types
        Type result_type = ta.bits() >= tb.bits() ? ta : tb;
        return IntImm::make(result_type, (pa->value | pb->value));
      }
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
}
// Bitwise xor. Two index-typed constants are folded; otherwise a
// bitwise_xor intrinsic call is built.
Expr operator^(Expr a, Expr b) {
  TVM_CONST_PROPAGATION({
      if (pa && pb) {
        // the result takes the wider of the two index types
        Type result_type = ta.bits() >= tb.bits() ? ta : tb;
        return IntImm::make(result_type, (pa->value ^ pb->value));
      }
    });
  return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
}
// Bitwise not, expressed as a bitwise_not intrinsic call.
// Only signed and unsigned integer operands are accepted.
Expr operator~(Expr a) {
  const Type operand_type = a.type();
  CHECK(operand_type.is_int() || operand_type.is_uint());
  return ir::Call::make(operand_type, ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic);
}
// Power, expressed as a call to the "pow" intrinsic.
// The operands are type-matched first and must end up as float.
Expr pow(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  const Type operand_type = x.type();
  CHECK(operand_type.is_float()) << "power only applies to float";
  return ir::Call::make(operand_type, "pow", { x, y }, ir::Call::PureIntrinsic);
}
// Absolute value.
//   signed int   -> select(x >= 0, x, -x)
//   float        -> "fabs" intrinsic call
//   unsigned int -> x unchanged
// Any other type is reported through LOG(FATAL).
Expr abs(Expr x) {
  const Type operand_type = x.type();
  if (operand_type.is_int()) {
    return select(x >= make_zero(operand_type), x, -x);
  }
  if (operand_type.is_float()) {
    return ir::Call::make(operand_type, "fabs", {x}, ir::Call::PureIntrinsic);
  }
  if (operand_type.is_uint()) {
    return x;
  }
  LOG(FATAL) << "Data type " << operand_type
             <<" not supported for absolute op. Skipping absolute op...";
  return x;
}
Expr sum(Expr source, Array<IterVar> rdom) { Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type()); Var x("x", source.type()), y("y", source.type());
Expr result = ir::Add::make(x, y); Expr result = ir::Add::make(x, y);
...@@ -38,7 +438,7 @@ Expr min(Expr source, Array<IterVar> rdom) { ...@@ -38,7 +438,7 @@ Expr min(Expr source, Array<IterVar> rdom) {
Expr prod(Expr source, Array<IterVar> rdom) { Expr prod(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type()); Var x("x", source.type()), y("y", source.type());
Expr result = ir::Mul::make(x, y); Expr result = ir::Mul::make(x, y);
Expr identity_element = make_one(source.type()); Expr identity_element = make_const(source.type(), 1);
ir::CommReducer combiner = ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define TVM_PASS_IR_UTIL_H_ #define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <vector> #include <vector>
...@@ -75,7 +76,7 @@ inline Expr TVMStructGet( ...@@ -75,7 +76,7 @@ inline Expr TVMStructGet(
Array<Expr> args ={ Array<Expr> args ={
handle, handle,
make_const(Int(32), index), make_const(Int(32), index),
make_const(Int(32), kind)}; make_const(Int(32), static_cast<int>(kind))};
return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic); return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
} }
...@@ -125,7 +126,7 @@ inline Stmt TVMStructSet( ...@@ -125,7 +126,7 @@ inline Stmt TVMStructSet(
Array<Expr> args ={ Array<Expr> args ={
handle, handle,
make_const(Int(32), index), make_const(Int(32), index),
make_const(Int(32), kind), make_const(Int(32), static_cast<int>(kind)),
value}; value};
return Evaluate::make( return Evaluate::make(
Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic)); Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
......
...@@ -102,9 +102,8 @@ class MarkChannelAccess : public IRMutator { ...@@ -102,9 +102,8 @@ class MarkChannelAccess : public IRMutator {
} else { } else {
alloc_size = op->extents[0]; alloc_size = op->extents[0];
for (size_t i = 1; i < op->extents.size(); ++i) { for (size_t i = 1; i < op->extents.size(); ++i) {
alloc_size *= op->extents[i]; alloc_size = alloc_size * op->extents[i];
} }
alloc_size = ir::Simplify(alloc_size);
} }
if (rw.write_count) { if (rw.write_count) {
......
...@@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator {
combo_size = combo_size / type_bits; combo_size = combo_size / type_bits;
// round up for can not divided // round up for can not divided
if (!divided) { if (!divided) {
combo_size += make_const(Int(32), 1); combo_size = combo_size + make_const(Int(32), 1);
} }
combo_size = ir::Simplify(combo_size); combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make( e->new_alloc = Allocate::make(
......
...@@ -437,7 +437,6 @@ class LoopVectorizer : public IRMutator { ...@@ -437,7 +437,6 @@ class LoopVectorizer : public IRMutator {
Stmt Mutate_(const For* op, const Stmt& s) final { Stmt Mutate_(const For* op, const Stmt& s) final {
if (op->for_type == ForType::Vectorized) { if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
CHECK(is_positive_const(op->extent));
int lanes = 0; int lanes = 0;
bool succ = arith::GetConstInt(op->extent, &lanes); bool succ = arith::GetConstInt(op->extent, &lanes);
if (!succ || lanes < 1) { if (!succ || lanes < 1) {
......
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
namespace { namespace {
using namespace tvm::ir; using namespace tvm::ir;
......
...@@ -35,7 +35,7 @@ def test_deduce(): ...@@ -35,7 +35,7 @@ def test_deduce():
e1 = (a*4+b < c) e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4) ans1 = (((c - b) + -1)/4)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0) e2 = (tvm.max(5, a * 4) < 0)
...@@ -63,7 +63,7 @@ def test_check(): ...@@ -63,7 +63,7 @@ def test_check():
assert res1.is_nothing() assert res1.is_nothing()
# multiple compare operators # multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3)>c , {b: b_s, c: c_s}, {}) res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing() assert res2.is_nothing()
# multiple target variable # multiple target variable
...@@ -88,11 +88,11 @@ def test_deduce_basic(): ...@@ -88,11 +88,11 @@ def test_deduce_basic():
res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s}) res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
test_basic(0, 4, 4) test_basic(0, 4, 4)
test_basic(1, 5, 4) test_basic(1, 5, 4)
test_basic(2, 6, 4) test_basic(2, 6, 4)
...@@ -137,4 +137,3 @@ if __name__ == "__main__": ...@@ -137,4 +137,3 @@ if __name__ == "__main__":
test_check() test_check()
test_deduce_basic() test_deduce_basic()
test_deduce_complex() test_deduce_complex()
...@@ -8,7 +8,7 @@ def test_const(): ...@@ -8,7 +8,7 @@ def test_const():
def test_make(): def test_make():
x = tvm.const(1) x = tvm.const(1)
y = tvm.make.IntImm('int32', 1) y = tvm.var("x")
z = x + y z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max) assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min) assert isinstance(tvm.min(x, y), tvm.expr.Min)
......
import tvm
def test_const_fold():
    """Operators applied to tvm constants must fold into an immediate
    whose value equals the same computation on plain Python numbers."""
    def check(f, *args):
        # Evaluate once on tvm constants and once on the raw Python values.
        folded = f(*[tvm.const(arg) for arg in args])
        expected = f(*args)
        is_imm = isinstance(folded, (tvm.expr.IntImm, tvm.expr.UIntImm))
        if not is_imm or folded.value != int(expected):
            raise ValueError("check error: %s vs %s " % (folded, expected))

    # (function, lhs, rhs) triples, checked in order
    cases = [
        (lambda x, y: x + y, 3, 4),
        (lambda x, y: x * y, 3, 12),
        (lambda x, y: x * y - 10, 3, 12),
        (lambda x, y: x - y % 10, 3, 12),
        (lambda x, y: x // y + 10, 100, 12),
        (lambda x, y: x & y + 10, 112, 128),
        (lambda x, y: x > y, 112, 128),
        (lambda x, y: x < y, 112, 128),
        (lambda x, y: x <= y, 112, 128),
        (lambda x, y: x >= y, 112, 128),
        (lambda x, y: (x | y) ^ 10, 112, 128),
    ]
    for case in cases:
        check(*case)
def test_const_fold2():
    """Identity simplifications on a symbolic variable."""
    v = tvm.var("x")
    # adding or subtracting zero returns the very same node
    assert (v + 0).same_as(v)
    assert (0 + v).same_as(v)
    assert (v - 0).same_as(v)
    # anything modulo one is the constant zero
    assert (v % 1).value == 0
    # multiplying by one returns the very same node
    assert (v * 1).same_as(v)
    assert (1 * v).same_as(v)
    # a constant divided by a variable is not folded
    assert isinstance((1 / v), tvm.expr.Div)
if __name__ == "__main__":
test_const_fold()
test_const_fold2()
...@@ -15,7 +15,7 @@ def test_make_smap(): ...@@ -15,7 +15,7 @@ def test_make_smap():
# save load json # save load json
x = tvm.const(1) x = tvm.const(1)
y = tvm.const(10) y = tvm.const(10)
z = x + y z = tvm.expr.Add(x, y)
smap = tvm.convert({"z": z, "x": x}) smap = tvm.convert({"z": z, "x": x})
json_str = tvm.save_json(tvm.convert([smap])) json_str = tvm.save_json(tvm.convert([smap]))
arr = tvm.load_json(json_str) arr = tvm.load_json(json_str)
......
...@@ -53,7 +53,6 @@ def test_canonical(): ...@@ -53,7 +53,6 @@ def test_canonical():
assert (tvm.ir_pass.Equal(ret1, ret2)) assert (tvm.ir_pass.Equal(ret1, ret2))
if __name__ == "__main__": if __name__ == "__main__":
test_modular()
test_bound() test_bound()
test_basic() test_basic()
test_simplify() test_simplify()
......
...@@ -163,7 +163,7 @@ inline Tensor full(const Array<Expr>& shape, ...@@ -163,7 +163,7 @@ inline Tensor full(const Array<Expr>& shape,
const Expr fill_value, const Expr fill_value,
std::string name = "tensor", std::string name = "tensor",
std::string tag = kElementWise) { std::string tag = kElementWise) {
Expr ev = lossless_cast(dtype, fill_value); Expr ev = cast(dtype, fill_value);
if (!ev.defined()) { if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << dtype; LOG(ERROR) << "Can't cast fill_value to " << dtype;
} }
...@@ -173,7 +173,7 @@ inline Tensor full(const Array<Expr>& shape, ...@@ -173,7 +173,7 @@ inline Tensor full(const Array<Expr>& shape,
} }
/*! /*!
* \brief Creates an operation that construct a tensor with same shape as input tensor, * \brief Creates an operation that construct a tensor with same shape as input tensor,
* then fill a tensor with fill_value * then fill a tensor with fill_value
* *
* \param x The input tensor * \param x The input tensor
...@@ -187,10 +187,7 @@ inline Tensor full_like(const Tensor& x, ...@@ -187,10 +187,7 @@ inline Tensor full_like(const Tensor& x,
const Expr fill_value, const Expr fill_value,
std::string name = "tensor", std::string name = "tensor",
std::string tag = kElementWise) { std::string tag = kElementWise) {
Expr ev = lossless_cast(x->dtype, fill_value); Expr ev = cast(x->dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << x->dtype;
}
return compute(x->shape, [&](const Array<Var>& i) { return compute(x->shape, [&](const Array<Var>& i) {
return ev; return ev;
}, name, tag); }, name, tag);
......
...@@ -94,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x, ...@@ -94,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x,
out_shape.Set(height_axis, out_height); out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width); out_shape.Set(width_axis, out_width);
const int64_t *padding_h0 = HalideIR::Internal::as_const_int(pad_top); const int64_t *padding_h0 = as_const_int(pad_top);
const int64_t *padding_w0 = HalideIR::Internal::as_const_int(pad_left); const int64_t *padding_w0 = as_const_int(pad_left);
const int64_t *padding_h1 = HalideIR::Internal::as_const_int(pad_bottom); const int64_t *padding_h1 = as_const_int(pad_bottom);
const int64_t *padding_w1 = HalideIR::Internal::as_const_int(pad_right); const int64_t *padding_w1 = as_const_int(pad_right);
const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
...@@ -192,7 +192,7 @@ inline bool find_height_width(const std::string& layout, ...@@ -192,7 +192,7 @@ inline bool find_height_width(const std::string& layout,
* Since pooling does not care about the factor size of dimensions * Since pooling does not care about the factor size of dimensions
* other than `H` and `W`, one can pass `NCHWc` as well. * other than `H` and `W`, one can pass `NCHWc` as well.
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
* *
* *
* \return The output tensor in the same layout * \return The output tensor in the same layout
*/ */
......
...@@ -164,10 +164,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho ...@@ -164,10 +164,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
oy = py * vy * ah + ay oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0 ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0 oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \ return tvm.select(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \ tvm.select(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \ tvm.select(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh) tvm.select(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
batch_size = cls_prob.shape[0] batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1] num_classes = cls_prob.shape[1]
...@@ -191,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho ...@@ -191,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
with ib.if_scope(j > 0): with ib.if_scope(j > 0):
temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i] temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i]
cls_id[0] = tvm.select(temp > score[0], j, cls_id[0]) cls_id[0] = tvm.select(temp > score[0], j, cls_id[0])
score[0] = tvm.make.Max(temp, score[0]) score[0] = tvm.max(temp, score[0])
with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)): with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
cls_id[0] = 0 cls_id[0] = 0
# [id, prob, xmin, ymin, xmax, ymax] # [id, prob, xmin, ymin, xmax, ymax]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment