Unverified Commit 6d460606 by Tianqi Chen Committed by GitHub

[EXPR] ir_operator.h->expr_operator.h Centralize const folder logic (#2719)

parent 1eb1dac4
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "ir_operator.h" #include "expr_operator.h"
#include "tvm/node/container.h" #include "tvm/node/container.h"
namespace tvm { namespace tvm {
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
#include "ir_operator.h" #include "expr_operator.h"
namespace tvm { namespace tvm {
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file tvm/ir_operator.h * \file tvm/expr_operator.h
* \brief Common operators defined for Expr. * \brief Common operators defined for Expr.
* *
* \note Most of the operator defined here perform simple constant folding * \note Most of the operator defined here perform simple constant folding
* when the type is int32 or int64 for simplifying the index expressions. * when the type is int32 or int64 for simplifying the index expressions.
*/ */
#ifndef TVM_IR_OPERATOR_H_ #ifndef TVM_EXPR_OPERATOR_H_
#define TVM_IR_OPERATOR_H_ #define TVM_EXPR_OPERATOR_H_
#include <algorithm> #include <algorithm>
#include <type_traits> #include <type_traits>
...@@ -617,4 +617,4 @@ TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&); ...@@ -617,4 +617,4 @@ TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
} // namespace tvm } // namespace tvm
#endif // TVM_IR_OPERATOR_H_ #endif // TVM_EXPR_OPERATOR_H_
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "expr.h" #include "expr.h"
#include "ir_operator.h" #include "expr_operator.h"
#include "tensor.h" #include "tensor.h"
#include "schedule.h" #include "schedule.h"
#include "arithmetic.h" #include "arithmetic.h"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "ir_operator.h" #include "expr_operator.h"
#include "arithmetic.h" #include "arithmetic.h"
namespace tvm { namespace tvm {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "base.h" #include "base.h"
#include "expr.h" #include "expr.h"
#include "ir_operator.h" #include "expr_operator.h"
#include "tensor.h" #include "tensor.h"
#include "operation.h" #include "operation.h"
#include "packed_func_ext.h" #include "packed_func_ext.h"
......
...@@ -5,9 +5,8 @@ ...@@ -5,9 +5,8 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
/*!
* Copyright (c) 2019 by Contributors
* \file const_fold.h
* \brief Centralized location for constant folding.
*/
#ifndef TVM_ARITHMETIC_CONST_FOLD_H_
#define TVM_ARITHMETIC_CONST_FOLD_H_
#include <tvm/ir.h>
#include <algorithm>
namespace tvm {
namespace arith {
/*!
* \brief Try to run binary compute with constant folding.
*
* \param a The left operand.
* \param b The right operand.
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
inline Expr TryConstFold(Expr a, Expr b);
/*!
* \brief Try to run unary compute with constant folding.
*
* \param a The left operand.
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
inline Expr TryConstFold(Expr a);
/*!
* \brief Check whether type is used to represent index.
*
* Index types are frequently used in shape computation
* and need to be aggressively constant-folded.
*
* \param type The type to represent index.
* \return the checked result.
*/
inline bool IsIndexType(const Type& type) {
return type.is_int() && type.lanes() == 1 &&
(type.bits() == 32 || type.bits() == 64);
}
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
using ir::FloatImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const FloatImm* fa = a.as<FloatImm>(); \
const FloatImm* fb = b.as<FloatImm>(); \
BODY;
#define TVM_INDEX_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const Type& ta = a.type(); \
const Type& tb = b.type(); \
if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
BODY; \
} \
// specialization of constant folders.
template<>
inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return a;
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
if (pa) {
if (pa->value == 1) return b;
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
if (fa) {
if (fa->value == 1) return b;
if (fa->value == 0) return a;
}
if (fb) {
if (fb->value == 1) return a;
if (fb->value == 0) return b;
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return a;
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm::make(rtype, fa->value / fb->value);
}
if (fa && fa->value == 0) return a;
if (fb) {
if (fb->value == 1) return a;
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
return IntImm::make(rtype, pa->value % pb->value);
}
if (pa) {
if (pa->value == 0) return a;
}
if (pb) {
if (pb->value == 1) return make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::GT>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value);
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::GE>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value);
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::LT>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value);
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::LE>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value);
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value);
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::NE>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value);
if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value);
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::And>(Expr a, Expr b) {
using ir::UIntImm;
const UIntImm* pa = a.as<UIntImm>();
const UIntImm* pb = b.as<UIntImm>();
if (pa && pa->value) return b;
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
if (pb && !pb->value) return b;
return Expr();
}
template<>
inline Expr TryConstFold<ir::Or>(Expr a, Expr b) {
using ir::UIntImm;
const UIntImm* pa = a.as<UIntImm>();
const UIntImm* pb = b.as<UIntImm>();
if (pa && pa->value) return a;
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
if (pb && !pb->value) return a;
return Expr();
}
template<>
inline Expr TryConstFold<ir::Not>(Expr a) {
using ir::UIntImm;
const UIntImm* pa = a.as<UIntImm>();
if (pa) {
return UIntImm::make(UInt(1), !(pa->value));
}
return Expr();
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_CONST_FOLD_H_
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* \brief Modular set analysis * \brief Modular set analysis
*/ */
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <limits> #include <limits>
#include "pattern_match.h" #include "pattern_match.h"
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <ir/IRPrinter.h> #include <ir/IRPrinter.h>
#include <memory> #include <memory>
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file ir_operator.cc * \file expr_operator.cc
*/ */
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <cmath> #include <cmath>
// Centralized header for constant folders.
#include "../arithmetic/const_fold.h"
namespace tvm { namespace tvm {
/*!
* \brief Check whether type is used to represent index.
*
* Index types are frequently used in shape computation
* and need to be aggressively constant-folded.
*
* \param type The type to represent index.
* \return the checked result.
*/
inline bool IsIndexType(const Type& type) {
return type.is_int() && type.lanes() == 1 &&
(type.bits() == 32 || type.bits() == 64);
}
// simple cast that only checks if type matches and cast // simple cast that only checks if type matches and cast
inline Expr SimpleCast(const Type& t, Expr value) { inline Expr SimpleCast(const Type& t, Expr value) {
if (value.type() == t) return value; if (value.type() == t) return value;
...@@ -135,45 +123,14 @@ Expr reinterpret(const Type& t, Expr value) { ...@@ -135,45 +123,14 @@ Expr reinterpret(const Type& t, Expr value) {
return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic); return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
} }
#define TVM_INDEX_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const Type& ta = a.type(); \
const Type& tb = b.type(); \
if (IsIndexType(ta) && IsIndexType(tb)) { \
BODY; \
} \
BinaryOpMatchTypes(a, b);
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
using ir::FloatImm; \
BinaryOpMatchTypes(a, b); \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const FloatImm* fa = a.as<FloatImm>(); \
const FloatImm* fb = b.as<FloatImm>(); \
BODY;
Expr operator+(Expr a, Expr b) { Expr operator+(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
const Type& ta = a.type(); Expr ret = arith::TryConstFold<ir::Add>(a, b);
const Type& tb = b.type(); if (ret.defined()) return ret;
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return SimpleCast(rtype, b);
if (pb && pb->value == 0) return SimpleCast(rtype, a);
if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
if (fa && fa->value == 0) return SimpleCast(rtype, b);
if (fb && fb->value == 0) return SimpleCast(rtype, a);
});
return ir::Add::make(a, b); return ir::Add::make(a, b);
} }
// negation
Expr operator-(Expr a) { Expr operator-(Expr a) {
using ir::IntImm; using ir::IntImm;
using ir::FloatImm; using ir::FloatImm;
...@@ -185,114 +142,44 @@ Expr operator-(Expr a) { ...@@ -185,114 +142,44 @@ Expr operator-(Expr a) {
} }
Expr operator-(Expr a, Expr b) { Expr operator-(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
const Type& ta = a.type(); Expr ret = arith::TryConstFold<ir::Sub>(a, b);
const Type& tb = b.type(); if (ret.defined()) return ret;
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return SimpleCast(rtype, a);
if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return SimpleCast(rtype, a);
});
return ir::Sub::make(a, b); return ir::Sub::make(a, b);
} }
Expr operator*(Expr a, Expr b) { Expr operator*(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
const Type& ta = a.type(); Expr ret = arith::TryConstFold<ir::Mul>(a, b);
const Type& tb = b.type(); if (ret.defined()) return ret;
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
if (pa) {
if (pa->value == 1) return SimpleCast(rtype, b);
if (pa->value == 0) return SimpleCast(rtype, a);
}
if (pb) {
if (pb->value == 1) return SimpleCast(rtype, a);
if (pb->value == 0) return SimpleCast(rtype, b);
}
if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
if (fa) {
if (fa->value == 1) return SimpleCast(rtype, b);
if (fa->value == 0) return SimpleCast(rtype, a);
}
if (fb) {
if (fb->value == 1) return SimpleCast(rtype, a);
if (fb->value == 0) return SimpleCast(rtype, b);
}
});
return ir::Mul::make(a, b); return ir::Mul::make(a, b);
} }
Expr operator/(Expr a, Expr b) { Expr operator/(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
const Type& ta = a.type(); Expr ret = arith::TryConstFold<ir::Div>(a, b);
const Type& tb = b.type(); if (ret.defined()) return ret;
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
if (pa->value == 0) return SimpleCast(rtype, a);
}
if (pb) {
if (pb->value == 1) return SimpleCast(rtype, a);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm::make(rtype, fa->value / fb->value);
}
if (fa && fa->value == 0) {
return SimpleCast(rtype, a);
}
if (fb) {
if (fb->value == 1) return SimpleCast(rtype, a);
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
return ir::Div::make(a, b); return ir::Div::make(a, b);
} }
Expr operator%(Expr a, Expr b) { Expr operator%(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
Type rtype = ta.bits() >= tb.bits() ? ta : tb; Expr ret = arith::TryConstFold<ir::Mod>(a, b);
// due to division and mod can have different modes if (ret.defined()) return ret;
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
return IntImm::make(rtype, pa->value % pb->value);
}
if (pa) {
if (pa->value == 0) return SimpleCast(rtype, a);
}
if (pb) {
if (pb->value == 1) return make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
return ir::Mod::make(a, b); return ir::Mod::make(a, b);
} }
Expr min(Expr a, Expr b) { Expr min(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
const Type& ta = a.type(); Expr ret = arith::TryConstFold<ir::Min>(a, b);
const Type& tb = b.type(); if (ret.defined()) return ret;
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
});
return ir::Min::make(a, b); return ir::Min::make(a, b);
} }
Expr max(Expr a, Expr b) { Expr max(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
const Type& ta = a.type(); Expr ret = arith::TryConstFold<ir::Max>(a, b);
const Type& tb = b.type(); if (ret.defined()) return ret;
Type rtype = ta.bits() >= tb.bits() ? ta : tb;
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
});
return ir::Max::make(a, b); return ir::Max::make(a, b);
} }
...@@ -328,129 +215,116 @@ Expr likely(Expr cond) { ...@@ -328,129 +215,116 @@ Expr likely(Expr cond) {
} }
Expr operator>(Expr a, Expr b) { Expr operator>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); Expr ret = arith::TryConstFold<ir::GT>(a, b);
if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value); if (ret.defined()) return ret;
});
return ir::GT::make(a, b); return ir::GT::make(a, b);
} }
Expr operator>=(Expr a, Expr b) { Expr operator>=(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); Expr ret = arith::TryConstFold<ir::GE>(a, b);
if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value); if (ret.defined()) return ret;
});
return ir::GE::make(a, b); return ir::GE::make(a, b);
} }
Expr operator<(Expr a, Expr b) { Expr operator<(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); Expr ret = arith::TryConstFold<ir::LT>(a, b);
if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value); if (ret.defined()) return ret;
});
return ir::LT::make(a, b); return ir::LT::make(a, b);
} }
Expr operator<=(Expr a, Expr b) { Expr operator<=(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); Expr ret = arith::TryConstFold<ir::LE>(a, b);
if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value); if (ret.defined()) return ret;
});
return ir::LE::make(a, b); return ir::LE::make(a, b);
} }
Expr operator==(Expr a, Expr b) { Expr operator==(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); Expr ret = arith::TryConstFold<ir::EQ>(a, b);
if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value); if (ret.defined()) return ret;
});
return ir::EQ::make(a, b); return ir::EQ::make(a, b);
} }
Expr operator!=(Expr a, Expr b) { Expr operator!=(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ BinaryOpMatchTypes(a, b);
if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); Expr ret = arith::TryConstFold<ir::NE>(a, b);
if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value); if (ret.defined()) return ret;
});
return ir::NE::make(a, b); return ir::NE::make(a, b);
} }
Expr operator&&(Expr a, Expr b) { Expr operator&&(Expr a, Expr b) {
using ir::UIntImm; CHECK(a.type().is_bool());
if (a.type().is_bool() && b.type().is_bool()) { CHECK(b.type().is_bool());
const UIntImm* pa = a.as<UIntImm>(); Expr ret = arith::TryConstFold<ir::And>(a, b);
const UIntImm* pb = b.as<UIntImm>(); if (ret.defined()) return ret;
if (pa && pa->value) return b;
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
if (pb && !pb->value) return b;
}
return ir::And::make(a, b); return ir::And::make(a, b);
} }
Expr operator||(Expr a, Expr b) { Expr operator||(Expr a, Expr b) {
using ir::UIntImm; CHECK(a.type().is_bool());
if (a.type().is_bool() && b.type().is_bool()) { CHECK(b.type().is_bool());
const UIntImm* pa = a.as<UIntImm>(); Expr ret = arith::TryConstFold<ir::Or>(a, b);
const UIntImm* pb = b.as<UIntImm>(); if (ret.defined()) return ret;
if (pa && pa->value) return a;
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
if (pb && !pb->value) return a;
}
return ir::Or::make(a, b); return ir::Or::make(a, b);
} }
Expr operator!(Expr a) { Expr operator!(Expr a) {
using ir::UIntImm; CHECK(a.type().is_bool());
const UIntImm* pa = a.as<UIntImm>(); Expr ret = arith::TryConstFold<ir::Not>(a);
if (pa) { if (ret.defined()) return ret;
return UIntImm::make(UInt(1), !(pa->value));
}
return ir::Not::make(a); return ir::Not::make(a);
} }
Expr operator>>(Expr a, Expr b) { Expr operator>>(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
if (pb) { if (pb) {
if (pb->value == 0) return SimpleCast(rtype, a); if (pb->value == 0) return a;
} }
}); });
return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic); return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
} }
Expr operator<<(Expr a, Expr b) { Expr operator<<(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
if (pb) { if (pb) {
if (pb->value == 0) return SimpleCast(rtype, a); if (pb->value == 0) return a;
} }
}); });
return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic); return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
} }
Expr operator&(Expr a, Expr b) { Expr operator&(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
}); });
return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic); return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
} }
Expr operator|(Expr a, Expr b) { Expr operator|(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
}); });
return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic); return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
} }
Expr operator^(Expr a, Expr b) { Expr operator^(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
Type rtype = ta.bits() >= tb.bits() ? ta : tb; const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value)); if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value));
}); });
return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic); return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include <ir/Expr.h> #include <ir/Expr.h>
#include <unordered_set> #include <unordered_set>
#include <string> #include <string>
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#define TVM_PASS_IR_UTIL_H_ #define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <vector> #include <vector>
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/target_info.h> #include <tvm/target_info.h>
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* \brief Implementation of operator pad * \brief Implementation of operator pad
*/ */
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn.h> #include <topi/nn.h>
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
#include <topi/transform.h> #include <topi/transform.h>
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* \brief This is a backend-aware optimization pass. * \brief This is a backend-aware optimization pass.
* Fuse necessary ops into a single one. * Fuse necessary ops into a single one.
*/ */
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
......
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h> #include <tvm/expr_operator.h>
namespace { namespace {
using namespace tvm::ir; using namespace tvm::ir;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment