Unverified Commit 32af4d28 by Tianqi Chen Committed by GitHub

[IR] eager constant folding in operator overloading (#1789)

parent 3455c8a5
......@@ -10,6 +10,7 @@
#include "base.h"
#include "expr.h"
#include "ir_operator.h"
#include "node/container.h"
namespace tvm {
......
......@@ -7,7 +7,6 @@
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/IROperator.h>
#include <ir/IRPrinter.h>
#include <string>
#include <algorithm>
......@@ -34,15 +33,6 @@ using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
using HalideIR::Internal::make_const;
using HalideIR::Internal::make_zero;
using HalideIR::Internal::make_one;
using HalideIR::Internal::as_const_int;
using HalideIR::Internal::as_const_uint;
using HalideIR::Internal::const_true;
using HalideIR::Internal::const_false;
using HalideIR::Internal::is_no_op;
inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
......
......@@ -495,8 +495,6 @@ using HalideIR::Internal::Block;
using HalideIR::Internal::IfThenElse;
using HalideIR::Internal::Evaluate;
using HalideIR::Internal::Shuffle;
// ir functions
using HalideIR::Internal::is_const_power_of_two_integer;
/*!
* \brief Create a type annotation expression
......
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2018 by Contributors
* \file tvm/ir_operator.h
* \brief Common operators of Expr
* \brief Common operators defined for Expr.
*
* \note Most of the operator defined here perform simple constant folding
* when the type is int32 or int64 for simplifying the index expressions.
*/
#ifndef TVM_IR_OPERATOR_H_
#define TVM_IR_OPERATOR_H_
#include <algorithm>
#include <type_traits>
#include "expr.h"
#include "ir.h"
namespace tvm {
/*!
* \brief Make a const value with certain data type.
* \param t The target type.
* \param value The input value
* \return the result expression.
* \tparam ValueType The constant value type
*/
template<typename ValueType,
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline Expr make_const(Type t, ValueType value);
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \return the result expression.
*/
inline Expr make_zero(Type t);
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_true(int lanes = 1) {
return make_const(UInt(1, lanes), 1);
}
/*!
* \brief Make a constant false expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_false(int lanes = 1) {
return make_const(UInt(1, lanes), 0);
}
/*!
* \brief Get x as constant int expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not IntImm.
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::IntImm* op = x.as<ir::IntImm>()) {
return &(op->value);
} else {
return nullptr;
}
}
/*!
* \brief Get x as constant uint expression.
* \param x The expression
* \return the address to the int expression,
* return nullptr, if x is not UIntImm.
*/
inline const uint64_t* as_const_uint(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImm* op = x.as<ir::UIntImm>()) {
return &(op->value);
} else {
return nullptr;
}
}
/*!
* \brief Check whether x is a constant integer expression.
* \param x The input argument
* \param value the value to be compared against.
* \return whether x is constant expression.
*/
inline bool is_const_int(const Expr& x, int64_t value);
/*!
* \brief Check whether stmt is nop.
* \param stmt The input statement
* \return whether stmt is nop
*/
inline bool is_no_op(const Stmt& stmt);
/*!
* \brief Check whether x is a constant integer 1
* \param x The input argument.
* \note This only return true for integer types.
* \return whether x is constant 1
*/
inline bool is_one(const Expr& x) {
return is_const_int(x, 1);
}
using HalideIR::likely;
using HalideIR::likely_if_innermost;
// functions
using HalideIR::cast;
using HalideIR::min;
using HalideIR::max;
using HalideIR::select;
/*!
* \brief Check whether x is a constant integer 0
* \param x The input argument
* \return whether x is constant 0
* \note This only return true for integer types.
*/
inline bool is_zero(const Expr& x) {
return is_const_int(x, 0);
}
/*!
* \brief Check whether x is a constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
inline bool is_const(const Expr& x);
/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
*
* \param x The input expression.
* \param shift The output shift if x is power of two.
* \return whether x is constant power of two
*/
TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
/*!
* \brief cast value to type.
*
* \param t the target type.
* \param value The value
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr cast(const Type& t, Expr value);
/*!
* \brief perform reinterpret cast value to type.
*
* \param t the target type.
* \param value The value
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr reinterpret(const Type& t, Expr value);
/*!
* \brief add operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator+(Expr a, Expr b);
/*!
* \brief subtraction operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator-(Expr a, Expr b);
/*!
* \brief negation.
*
* \param a input.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator-(Expr a);
/*!
* \brief multiplication operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator*(Expr a, Expr b);
/*!
* \brief division operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator/(Expr a, Expr b);
/*!
* \brief mod operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator%(Expr a, Expr b);
/*!
* \brief left shift operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator<<(Expr a, Expr b);
/*!
* \brief right shift operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator>>(Expr a, Expr b);
/*!
* \brief greater
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator>(Expr a, Expr b);
/*!
* \brief greater_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator>=(Expr a, Expr b);
/*!
* \brief less
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator<(Expr a, Expr b);
/*!
* \brief less_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator<=(Expr a, Expr b);
/*!
* \brief equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator==(Expr a, Expr b);
/*!
* \brief not_equal
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator!=(Expr a, Expr b);
/*!
* \brief and
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator&&(Expr a, Expr b);
/*!
* \brief or
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator||(Expr a, Expr b);
/*!
* \brief not
*
* \param a left operand
* \return The result expression.
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief take maximum of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr max(Expr a, Expr b);
/*!
* \brief take minimum of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr min(Expr a, Expr b);
/*!
* \brief right shift
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator>>(Expr a, Expr b);
/*!
* \brief left shift
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator<<(Expr a, Expr b);
/*!
* \brief take bitwise and of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator&(Expr a, Expr b);
/*!
* \brief take bitwise or of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator|(Expr a, Expr b);
/*!
* \brief take bitwise xor of two values
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator^(Expr a, Expr b);
/*!
* \brief take bitwise negation of two values
*
* \param a the input expression.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator~(Expr a);
/*!
* \brief select result by condition
*
* \param cond The condition
* \param true_value The value when results are true.
* \param false_value The value when results are false.
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr select(Expr cond, Expr true_value, Expr false_value);
/*!
* \brief Mark condition as likely.
* \param cond The condition
* \return The marked expression.
*/
TVM_DLL Expr likely(Expr cond);
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
TVM_DLL Expr pow(Expr x, Expr y);
/*!
* \brief Calculate absolute value of x.
* \param x The input data
*
* \return The aboslute value of input data x
*/
TVM_DLL Expr abs(Expr x);
/*!
* \brief sum of of source expression over axis
......@@ -48,13 +450,12 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
// Unary intrinsic operators
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
......@@ -64,38 +465,152 @@ TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
TVM_DECLARE_INTRIN_UNARY(popcount);
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
inline Expr pow(Expr x, Expr y) {
match_types(x, y);
CHECK(x.type().is_float()) << "power only applies to float";
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
// Implementation details after this
inline bool is_const(const Expr& x) {
if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
return true;
} else if (const auto* op = x.as<ir::Broadcast>()) {
const Expr& val = op->value;
if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) {
return true;
}
}
return false;
}
/*!
* \brief Calculate absolute value of x, elementwise
* \param x The input data
*
* \return The aboslute value of input data x
*/
inline Expr abs(Expr x) {
if (x.type().is_int()) {
return select(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
return x;
inline bool is_positive_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
return op->value > 0;
} else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) {
return op->value > 0;
} else {
LOG(WARNING) << "Warning: Data type " << x.type()
<<" not supported for absolute op. Skipping absolute op...";
return x;
return false;
}
}
} // namespace tvm
inline bool is_negative_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
return op->value < 0;
} else {
return false;
}
}
inline bool is_const_int(const Expr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImm>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImm>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::Broadcast>()) {
const Expr& val = op->value;
if (const auto* opv = val.as<ir::IntImm>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImm>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
return false;
}
inline bool is_no_op(const Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<ir::Evaluate>()) {
return is_const(op->value);
}
return false;
}
template<typename ValueType>
inline Expr MakeConstScalar(Type t, ValueType value) {
if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value));
LOG(FATAL) << "cannot make const for type " << t;
return Expr();
}
template<typename ValueType, typename>
inline Expr make_const(Type t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
return ir::Broadcast::make(
MakeConstScalar(t.element_of(), value), t.lanes());
}
}
inline Expr make_zero(Type t) {
if (t.is_handle()) {
return reinterpret(t, make_const(UInt(64), 0));
}
return make_const(t, 0);
}
// additional const expression overloading
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \
a = OpFunc(a, b); \
return a; \
}
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, float b) { \
return Name(a, Expr(b)); \
} \
inline Expr Name(float a, const Expr& b) { \
return Name(Expr(a), b); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
} \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
}
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, bool b) { \
return Name(a, Expr(b)); \
} \
inline Expr Name(bool a, const Expr& b) { \
return Name(Expr(a), b); \
}
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
}
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
// logical ops
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
} // namespace tvm
#endif // TVM_IR_OPERATOR_H_
......@@ -13,6 +13,7 @@
#include "base.h"
#include "expr.h"
#include "ir_operator.h"
#include "arithmetic.h"
#include "node/container.h"
......
......@@ -354,7 +354,7 @@ Example::
if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
auto axis = ShapeToArray(r_axes);
Expr count = make_one(inputs[0]->dtype);
Expr count = make_const(inputs[0]->dtype, 1);
for (auto& i : r_axes) {
count *= inputs[0]->shape[i];
}
......
......@@ -156,9 +156,9 @@ def any(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _expr.Or(args[0], args[1])
ret = _make._OpOr(args[0], args[1])
for i in range(2, len(args)):
ret = _expr.Or(ret, args[i])
ret = _make._OpOr(ret, args[i])
return ret
......@@ -180,9 +180,9 @@ def all(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _expr.And(args[0], args[1])
ret = _make._OpAnd(args[0], args[1])
for i in range(2, len(args)):
ret = _expr.And(ret, args[i])
ret = _make._OpAnd(ret, args[i])
return ret
......@@ -773,5 +773,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _expr.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _expr.Max(x, y), min_value, name='max')
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
......@@ -60,7 +60,7 @@ class ExprOp(object):
return self.__rdiv__(other)
def __mod__(self, other):
return _make.Mod(self, other)
return _make._OpMod(self, other)
def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
......@@ -85,10 +85,10 @@ class ExprOp(object):
return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
def __lt__(self, other):
return _make.LT(self, other)
return _make._OpLT(self, other)
def __le__(self, other):
return _make.LE(self, other)
return _make._OpLE(self, other)
def __eq__(self, other):
return EqualOp(self, other)
......@@ -97,10 +97,10 @@ class ExprOp(object):
return NotEqualOp(self, other)
def __gt__(self, other):
return _make.GT(self, other)
return _make._OpGT(self, other)
def __ge__(self, other):
return _make.GE(self, other)
return _make._OpGE(self, other)
def __nonzero__(self):
raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
......@@ -122,7 +122,7 @@ class ExprOp(object):
ret : Expr
The equality expression.
"""
return _make.EQ(self, other)
return _make._OpEQ(self, other)
def astype(self, dtype):
"""Cast the expression to other type.
......@@ -169,7 +169,7 @@ class EqualOp(NodeGeneric, ExprOp):
def asnode(self):
"""Convert node."""
return _make.EQ(self.a, self.b)
return _make._OpEQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp):
......@@ -201,7 +201,7 @@ class NotEqualOp(NodeGeneric, ExprOp):
def asnode(self):
"""Convert node."""
return _make.NE(self.a, self.b)
return _make._OpNE(self.a, self.b)
class Expr(ExprOp, NodeBase):
......
......@@ -24,7 +24,7 @@ def add(lhs, rhs):
op : tvm.Expr
The result Expr of add operaton.
"""
return _make.Add(lhs, rhs)
return _make._OpAdd(lhs, rhs)
def subtract(lhs, rhs):
......@@ -42,7 +42,7 @@ def subtract(lhs, rhs):
op : tvm.Expr
The result Expr of subtract operaton.
"""
return _make.Sub(lhs, rhs)
return _make._OpSub(lhs, rhs)
def multiply(lhs, rhs):
......@@ -60,7 +60,7 @@ def multiply(lhs, rhs):
op : tvm.Expr
The result Expr of multiply operaton.
"""
return _make.Mul(lhs, rhs)
return _make._OpMul(lhs, rhs)
def divide(lhs, rhs):
......@@ -78,7 +78,7 @@ def divide(lhs, rhs):
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make.Div(lhs, rhs)
return _make._OpDiv(lhs, rhs)
def cast(src, dtype):
......
......@@ -5,7 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <ir/IROperator.h>
#include <tvm/ir_operator.h>
#include <tvm/api_registry.h>
#include <tvm/ir_operator.h>
......@@ -117,6 +117,50 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE2(Add);
REGISTER_MAKE2(Sub);
REGISTER_MAKE2(Mul);
REGISTER_MAKE2(Div);
REGISTER_MAKE2(Mod);
REGISTER_MAKE2(Min);
REGISTER_MAKE2(Max);
REGISTER_MAKE2(EQ);
REGISTER_MAKE2(NE);
REGISTER_MAKE2(LT);
REGISTER_MAKE2(LE);
REGISTER_MAKE2(GT);
REGISTER_MAKE2(GE);
REGISTER_MAKE2(And);
REGISTER_MAKE2(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
......@@ -138,50 +182,27 @@ TVM_REGISTER_API("make.CommReducer")
} \
})
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE_BINARY_OP(Add, operator+);
REGISTER_MAKE_BINARY_OP(Sub, operator-);
REGISTER_MAKE_BINARY_OP(Mul, operator*);
REGISTER_MAKE_BINARY_OP(Div, operator/);
REGISTER_MAKE_BINARY_OP(Mod, operator%);
REGISTER_MAKE_BINARY_OP(Min, min);
REGISTER_MAKE_BINARY_OP(Max, max);
REGISTER_MAKE_BINARY_OP(EQ, operator==);
REGISTER_MAKE_BINARY_OP(NE, operator!=);
REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GE, operator>=);
REGISTER_MAKE_BINARY_OP(And, operator&&);
REGISTER_MAKE_BINARY_OP(Or, operator||);
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
} // namespace ir
} // namespace tvm
......@@ -14,10 +14,6 @@
namespace tvm {
namespace arith {
using HalideIR::Internal::add_would_overflow;
using HalideIR::Internal::sub_would_overflow;
using HalideIR::Internal::mul_would_overflow;
/*!
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
......@@ -42,23 +38,9 @@ template<typename Op>
inline Expr ComputeReduce(
const Array<Expr>& values, Expr empty_value);
template<typename T>
inline bool GetConst(Expr e, T* out);
template<>
inline bool GetConst<int64_t>(Expr e, int64_t *out) {
if (e.type().is_vector()) return false;
const int64_t *v = as_const_int(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
template<>
inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
inline bool GetConst(Expr e, int64_t* out) {
if (e.type().is_vector()) return false;
const uint64_t *v = as_const_uint(e);
const int64_t* v = as_const_int(e);
if (v) {
*out = *v; return true;
} else {
......@@ -69,66 +51,37 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
// get a small constant int
inline bool GetConstInt(Expr e, int* out) {
int64_t v1 = 0;
uint64_t v2 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true;
}
if (GetConst(e, &v2)) {
if (v2 > static_cast<uint64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v2); return true;
}
return false;
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \
LOG(FATAL) << "signed int overflow"; \
} \
return ir::IntImm::make(a.type(), ia OP ib); \
} \
uint64_t ua = 0, ub = 0; \
if (GetConst(a, &ua) && GetConst(b, &ub)) { \
return ir::UIntImm::make(a.type(), ua OP ub); \
} \
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
if (is_zero(a)) return b;
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(add, +);
return ir::Add::make(a, b);
return a + b;
}
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(sub, -);
return ir::Sub::make(a, b);
return a - b;
}
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
if (is_one(a)) return b;
if (is_one(b)) return a;
TVM_CONST_PROPAGATION(mul, *);
return ir::Mul::make(a, b);
return a * b;
}
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
if (is_one(b)) return a;
return ir::Div::make(a, b);
return a / b;
}
template<>
inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
if (is_zero(a)) return make_zero(a.type());
return ir::Mod::make(a, b);
return a % b;
}
template<>
......
......@@ -194,7 +194,7 @@ bool DetectClipBound(
if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
ret.coeff = Simplify(ret.coeff);
IntervalEntry& p = (*bmap)[var.get()];
if (is_one(ret.coeff)) {
if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift
if (p.min_value.defined()) {
p.min_value = ir::Max::make(p.min_value, -ret.base);
......@@ -203,7 +203,7 @@ bool DetectClipBound(
}
return true;
}
if (is_const(ret.coeff, -1)) {
if (is_const_int(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift
if (p.max_value.defined()) {
p.max_value = ir::Min::make(p.max_value, ret.base);
......
......@@ -42,7 +42,7 @@ std::string CodeGenCUDA::Finish() {
}
void CodeGenCUDA::VisitStmt_(const ir::For* op) {
CHECK(is_zero(op->min));
CHECK(is_const_int(op->min, 0));
if (op->for_type == ir::ForType::Unrolled) {
PrintIndent();
stream << "#pragma unroll\n";
......
......@@ -195,7 +195,7 @@ class PipelineExtractor: public IRVisitor {
ChannelEntry& cb = cmap_.at(ch->handle_var.get());
trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
// Grab the advance constant size.
int trigger_size;
int trigger_size = 0;
if (attr->attr_key == attr::pipeline_stage_scope) {
cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kComputeFinish, 0));
......
......@@ -5,6 +5,7 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <ir/IRPrinter.h>
#include <memory>
......
......@@ -8,6 +8,406 @@
namespace tvm {
/*!
* \brief Check whether type is used to represent index.
*
* Index types are frequently used in shape computation
* and need to be aggressively constant-folded.
*
* \param type The type to represent index.
* \return the checked result.
*/
inline bool IsIndexType(const Type& type) {
return type.is_int() && type.lanes() == 1 &&
(type.bits() == 32 || type.bits() == 64);
}
// simple cast that only checks if type matches and cast
inline Expr SimpleCast(const Type& t, Expr value) {
if (value.type() == t) return value;
return ir::Cast::make(t, value);
}
// The public function with a quick checking path.
void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*)
if (lhs.type() == rhs.type()) return;
Type ltype = lhs.type();
Type rtype = rhs.type();
if (ltype.lanes() == 1 && rtype.lanes() != 1) {
lhs = ir::Broadcast::make(lhs, rtype.lanes());
} else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
rhs = ir::Broadcast::make(rhs, ltype.lanes());
} else {
CHECK(ltype.lanes() == rtype.lanes())
<< "Cannot match type " << ltype << " vs " << rtype;
}
if (lhs.type() == rhs.type()) return;
// Only do very simple type coversion
// int->float, int(32)->int(64)
// require the types to be relatively consistent
// This will the reduce amount code generated by operators
// and also help user to find potential type conversion problems.
if (!lhs.type().is_float() && rhs.type().is_float()) {
// int->float
lhs = ir::Cast::make(rhs.type(), lhs);
} else if (lhs.type().is_float() && !rhs.type().is_float()) {
// int->float
rhs = ir::Cast::make(lhs.type(), rhs);
} else if ((lhs.type().is_int() && rhs.type().is_int()) ||
(lhs.type().is_uint() && rhs.type().is_uint())) {
// promote int to higher bits
if (lhs.type().bits() < rhs.type().bits()) {
lhs = ir::Cast::make(rhs.type(), lhs);
} else {
rhs = ir::Cast::make(lhs.type(), rhs);
}
} else if ((lhs.type().is_int() && rhs.type().is_uint()) ||
(lhs.type().is_uint() && rhs.type().is_int())) {
int bits = std::max(lhs.type().bits(), rhs.type().bits());
lhs = SimpleCast(Int(bits, lhs.type().lanes()), lhs);
rhs = SimpleCast(Int(bits, rhs.type().lanes()), rhs);
} else {
LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
}
}
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
if (val <= 0) return false;
shift[0] = 0;
while (val != 0) {
if (val & 1) {
return (val == 1);
}
++shift[0];
val = val >> 1;
}
return true;
}
bool is_const_power_of_two_integer(const Expr& x, int* shift) {
if (const auto* op = x.as<ir::IntImm>()) {
return ConstPowerHelper(op->value, shift);
} else if (const auto* op = x.as<ir::UIntImm>()) {
return ConstPowerHelper(op->value, shift);
} else {
return false;
}
}
Expr cast(const Type& t, Expr value) {
using ir::IntImm;
if (value.type() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) {
return make_const(t, op->value);
}
return ir::Cast::make(t, value);
} else {
if (value.type().lanes() == 1) {
// manually unroll cast
Type vtype = t.element_of();
if (value.type() != vtype) {
if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value);
} else {
value = ir::Cast::make(vtype, value);
}
}
return ir::Broadcast::make(value, t.lanes());
} else {
CHECK(value.type().lanes() == t.lanes());
return ir::Cast::make(t, value);
}
}
}
Expr reinterpret(const Type& t, Expr value) {
if (value.type() == t) return value;
return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
}
#define TVM_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const Type& ta = a.type(); \
const Type& tb = b.type(); \
if (IsIndexType(ta) && IsIndexType(tb)) { \
BODY; \
} \
BinaryOpMatchTypes(a, b);
Expr operator+(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return SimpleCast(rtype, b);
if (pb && pb->value == 0) return SimpleCast(rtype, a);
});
return ir::Add::make(a, b);
}
Expr operator-(Expr a) {
using ir::IntImm;
const IntImm* pa = a.as<IntImm>();
if (pa) {
return ir::IntImm::make(a.type(), -pa->value);
}
return make_zero(a.type()) - a;
}
Expr operator-(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return SimpleCast(rtype, a);
});
return ir::Sub::make(a, b);
}
Expr operator*(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
if (pa) {
if (pa->value == 1) return SimpleCast(rtype, b);
if (pa->value == 0) return SimpleCast(rtype, a);
}
if (pb) {
if (pb->value == 1) return SimpleCast(rtype, a);
if (pb->value == 0) return SimpleCast(rtype, b);
}
});
return ir::Mul::make(a, b);
}
Expr operator/(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
if (pa->value == 0) return SimpleCast(rtype, a);
}
if (pb) {
if (pb->value == 1) return SimpleCast(rtype, a);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return ir::Div::make(a, b);
}
Expr operator%(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
return IntImm::make(rtype, pa->value % pb->value);
}
if (pa) {
if (pa->value == 0) return SimpleCast(rtype, a);
}
if (pb) {
if (pb->value == 1) return make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return ir::Mod::make(a, b);
}
Expr min(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
});
return ir::Min::make(a, b);
}
Expr max(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
});
return ir::Max::make(a, b);
}
Expr select(Expr cond, Expr true_value, Expr false_value) {
using ir::IntImm;
using ir::UIntImm;
CHECK(cond.type().is_bool());
BinaryOpMatchTypes(true_value, false_value);
if (const UIntImm* op = cond.as<UIntImm>()) {
if (op->value != 0) {
return true_value;
} else {
return false_value;
}
} else if (const IntImm* op = cond.as<IntImm>()) {
if (op->value != 0) {
return true_value;
} else {
return false_value;
}
}
return ir::Select::make(cond, true_value, false_value);
}
Expr likely(Expr cond) {
if (is_const(cond)) return cond;
return ir::Call::make(cond.type(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
}
Expr operator>(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value);
});
return ir::GT::make(a, b);
}
Expr operator>=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value);
});
return ir::GE::make(a, b);
}
Expr operator<(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value);
});
return ir::LT::make(a, b);
}
Expr operator<=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value);
});
return ir::LE::make(a, b);
}
Expr operator==(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value);
});
return ir::EQ::make(a, b);
}
Expr operator!=(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value);
});
return ir::NE::make(a, b);
}
Expr operator&&(Expr a, Expr b) {
using ir::UIntImm;
const UIntImm* pa = a.as<UIntImm>();
const UIntImm* pb = b.as<UIntImm>();
if (pa && pb) {
return UIntImm::make(UInt(1), pa->value && pb->value);
}
return ir::And::make(a, b);
}
Expr operator||(Expr a, Expr b) {
using ir::UIntImm;
const UIntImm* pa = a.as<UIntImm>();
const UIntImm* pb = b.as<UIntImm>();
if (pa && pb) {
return UIntImm::make(UInt(1), pa->value || pb->value);
}
return ir::Or::make(a, b);
}
Expr operator!(Expr a) {
using ir::UIntImm;
const UIntImm* pa = a.as<UIntImm>();
if (pa) {
return UIntImm::make(UInt(1), !(pa->value));
}
return ir::Not::make(a);
}
Expr operator>>(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
if (pb) {
if (pb->value == 0) return SimpleCast(rtype, a);
}
});
return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator<<(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
if (pb) {
if (pb->value == 0) return SimpleCast(rtype, a);
}
});
return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator&(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
});
return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator|(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
});
return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator^(Expr a, Expr b) {
TVM_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value));
});
return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator~(Expr a) {
CHECK(a.type().is_int() || a.type().is_uint());
return ir::Call::make(a.type(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic);
}
Expr pow(Expr x, Expr y) {
BinaryOpMatchTypes(x, y);
CHECK(x.type().is_float()) << "power only applies to float";
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
}
Expr abs(Expr x) {
if (x.type().is_int()) {
return select(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
return x;
} else {
LOG(FATAL) << "Data type " << x.type()
<<" not supported for absolute op. Skipping absolute op...";
return x;
}
}
Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Add::make(x, y);
......@@ -38,7 +438,7 @@ Expr min(Expr source, Array<IterVar> rdom) {
Expr prod(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Mul::make(x, y);
Expr identity_element = make_one(source.type());
Expr identity_element = make_const(source.type(), 1);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
......
......@@ -7,6 +7,7 @@
#define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <tvm/runtime/device_api.h>
#include <vector>
......@@ -75,7 +76,7 @@ inline Expr TVMStructGet(
Array<Expr> args ={
handle,
make_const(Int(32), index),
make_const(Int(32), kind)};
make_const(Int(32), static_cast<int>(kind))};
return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
}
......@@ -125,7 +126,7 @@ inline Stmt TVMStructSet(
Array<Expr> args ={
handle,
make_const(Int(32), index),
make_const(Int(32), kind),
make_const(Int(32), static_cast<int>(kind)),
value};
return Evaluate::make(
Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
......
......@@ -102,9 +102,8 @@ class MarkChannelAccess : public IRMutator {
} else {
alloc_size = op->extents[0];
for (size_t i = 1; i < op->extents.size(); ++i) {
alloc_size *= op->extents[i];
alloc_size = alloc_size * op->extents[i];
}
alloc_size = ir::Simplify(alloc_size);
}
if (rw.write_count) {
......
......@@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator {
combo_size = combo_size / type_bits;
// round up for can not divided
if (!divided) {
combo_size += make_const(Int(32), 1);
combo_size = combo_size + make_const(Int(32), 1);
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
......
......@@ -437,7 +437,6 @@ class LoopVectorizer : public IRMutator {
Stmt Mutate_(const For* op, const Stmt& s) final {
if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min));
CHECK(is_positive_const(op->extent));
int lanes = 0;
bool succ = arith::GetConstInt(op->extent, &lanes);
if (!succ || lanes < 1) {
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
namespace {
using namespace tvm::ir;
......
......@@ -35,7 +35,7 @@ def test_deduce():
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4)
ans1 = (((c - b) + -1)/4)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0)
......@@ -63,7 +63,7 @@ def test_check():
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3)>c , {b: b_s, c: c_s}, {})
res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing()
# multiple target variable
......@@ -88,11 +88,11 @@ def test_deduce_basic():
res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
test_basic(0, 4, 4)
test_basic(1, 5, 4)
test_basic(2, 6, 4)
......@@ -137,4 +137,3 @@ if __name__ == "__main__":
test_check()
test_deduce_basic()
test_deduce_complex()
......@@ -8,7 +8,7 @@ def test_const():
def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
y = tvm.var("x")
z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)
......
import tvm
def test_const_fold():
def check(f, *args):
x = f(*[tvm.const(x) for x in args])
y = f(*args)
if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
raise ValueError("check error: %s vs %s " % (x, y))
check(lambda x, y: x + y, 3, 4)
check(lambda x, y: x * y, 3, 12)
check(lambda x, y: x * y - 10, 3, 12)
check(lambda x, y: x - y % 10, 3, 12)
check(lambda x, y: x // y + 10, 100, 12)
check(lambda x, y: x & y + 10, 112, 128)
check(lambda x, y: x > y, 112, 128)
check(lambda x, y: x < y, 112, 128)
check(lambda x, y: x <= y, 112, 128)
check(lambda x, y: x >= y, 112, 128)
check(lambda x, y: (x | y) ^ 10, 112, 128)
def test_const_fold2():
x = tvm.var("x")
assert (x + 0).same_as(x)
assert (0 + x).same_as(x)
assert (x - 0).same_as(x)
assert (x % 1).value == 0
assert (x * 1).same_as(x)
assert (1 * x).same_as(x)
assert isinstance((1 / x), tvm.expr.Div)
if __name__ == "__main__":
test_const_fold()
test_const_fold2()
......@@ -15,7 +15,7 @@ def test_make_smap():
# save load json
x = tvm.const(1)
y = tvm.const(10)
z = x + y
z = tvm.expr.Add(x, y)
smap = tvm.convert({"z": z, "x": x})
json_str = tvm.save_json(tvm.convert([smap]))
arr = tvm.load_json(json_str)
......
......@@ -53,7 +53,6 @@ def test_canonical():
assert (tvm.ir_pass.Equal(ret1, ret2))
if __name__ == "__main__":
test_modular()
test_bound()
test_basic()
test_simplify()
......
......@@ -163,7 +163,7 @@ inline Tensor full(const Array<Expr>& shape,
const Expr fill_value,
std::string name = "tensor",
std::string tag = kElementWise) {
Expr ev = lossless_cast(dtype, fill_value);
Expr ev = cast(dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << dtype;
}
......@@ -173,7 +173,7 @@ inline Tensor full(const Array<Expr>& shape,
}
/*!
* \brief Creates an operation that construct a tensor with same shape as input tensor,
* \brief Creates an operation that construct a tensor with same shape as input tensor,
* then fill a tensor with fill_value
*
* \param x The input tensor
......@@ -187,10 +187,7 @@ inline Tensor full_like(const Tensor& x,
const Expr fill_value,
std::string name = "tensor",
std::string tag = kElementWise) {
Expr ev = lossless_cast(x->dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << x->dtype;
}
Expr ev = cast(x->dtype, fill_value);
return compute(x->shape, [&](const Array<Var>& i) {
return ev;
}, name, tag);
......
......@@ -94,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x,
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
const int64_t *padding_h0 = HalideIR::Internal::as_const_int(pad_top);
const int64_t *padding_w0 = HalideIR::Internal::as_const_int(pad_left);
const int64_t *padding_h1 = HalideIR::Internal::as_const_int(pad_bottom);
const int64_t *padding_w1 = HalideIR::Internal::as_const_int(pad_right);
const int64_t *padding_h0 = as_const_int(pad_top);
const int64_t *padding_w0 = as_const_int(pad_left);
const int64_t *padding_h1 = as_const_int(pad_bottom);
const int64_t *padding_w1 = as_const_int(pad_right);
const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
......@@ -192,7 +192,7 @@ inline bool find_height_width(const std::string& layout,
* Since pooling does not care about the factor size of dimensions
* other than `H` and `W`, one can pass `NCHWc` as well.
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
*
*
*
* \return The output tensor in the same layout
*/
......
......@@ -164,10 +164,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
return tvm.select(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
......@@ -191,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
with ib.if_scope(j > 0):
temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i]
cls_id[0] = tvm.select(temp > score[0], j, cls_id[0])
score[0] = tvm.make.Max(temp, score[0])
score[0] = tvm.max(temp, score[0])
with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
cls_id[0] = 0
# [id, prob, xmin, ymin, xmax, ymax]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment