Unverified commit f0079a57 by Tianqi Chen, committed by GitHub

[ARITH] Refactor to use explicit div/mod functions instead of operators. (#4000)

* [ARITH] Use explicit div/mod functions instead of operators.

* fix pooling case
parent 17c2c0a1
@@ -217,16 +217,6 @@ TVM_DLL Expr operator*(Expr a, Expr b);
  */
 TVM_DLL Expr operator/(Expr a, Expr b);
 /*!
- * \brief mod operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *  index types(int32, int64) when possible.
- */
-TVM_DLL Expr operator%(Expr a, Expr b);
-/*!
  * \brief left shift operator
  *
  * \param a left operand
@@ -371,6 +361,35 @@ TVM_DLL Expr truncdiv(Expr a, Expr b);
  */
 TVM_DLL Expr truncmod(Expr a, Expr b);
 /*!
+ * \brief compute floor(a / b) where a and b are non-negative.
+ *
+ * Use this function for index split calculation.
+ *
+ * This function might take advantage of the fact
+ * that a and b are non-negative.
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *  index types(int32, int64) when possible.
+ */
+TVM_DLL Expr indexdiv(Expr a, Expr b);
+/*!
+ * \brief compute the remainder of floor(a / b), i.e. a - floor(a / b) * b,
+ *  where a and b are non-negative.
+ *
+ * Use this function for index split calculation.
+ * This function might take advantage of the fact
+ * that a and b are non-negative.
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *  index types(int32, int64) when possible.
+ */
+TVM_DLL Expr indexmod(Expr a, Expr b);
+/*!
  * \brief compute floor(a / b)
  *
  * \param a left operand
@@ -662,21 +681,6 @@ inline Expr make_zero(Type t) {
   return make_const(t, 0);
 }
-/*!
- * \brief Helper function to raise a compiler error about division ambiguity.
- * \note The call to this function will always results in a compiler error.
- * \tparam TA Any class type.
- */
-template<typename TA>
-inline void DivAmbiguityError(const TA& a) {
-  constexpr bool div_ambiguity = !std::is_class<TA>::value;
-  static_assert(div_ambiguity,
-                "TVM supports multiple types of integer divisions, "
-                "please call div, floordiv/floormod or truncdiv/truncmod directly "
-                "to avoid ambiguity in the code. "
-                "Checkout these functions in expr_operator.h.");
-}
 // additional const expression overloading
 #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
   inline Expr Name(Expr& a, Expr b) { \
@@ -718,11 +722,9 @@ inline void DivAmbiguityError(const TA& a) {
 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
-TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
-TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
@@ -731,11 +733,12 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<);  // NOLINT(*)
 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
 // integer related ops
-TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexdiv);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexmod);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
-TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>);  // NOLINT(*)
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<);  // NOLINT(*)
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
@@ -745,5 +748,45 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
 TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
 TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
+
+/*!
+ * \brief Helper function to raise a compiler error about division ambiguity.
+ * \note A call to this function always results in a compiler error.
+ * \tparam TA Any class type.
+ */
+template<typename TA>
+inline void DivAmbiguityError(const TA& a) {
+  constexpr bool div_ambiguity = !std::is_class<TA>::value;
+  static_assert(div_ambiguity,
+                "TVM supports multiple types of integer divisions, "
+                "please call div, indexdiv/indexmod, "
+                "floordiv/floormod or truncdiv/truncmod directly "
+                "to avoid ambiguity in the code. "
+                "Check out these functions in expr_operator.h.");
+}
+// The following code is not intended to be used in the codebase.
+// Instead, it generates clear compiler errors that ask developers
+// to use the specific division functions.
+// The template argument is necessary so that the overload is only
+// instantiated lazily, at the invocation site.
+template<typename TB>
+inline Expr operator/(const Expr& a, const TB& b) {
+  DivAmbiguityError(a);
+  return a;
+}
+
+template<typename TB>
+inline Expr operator/=(const Expr& a, const TB& b) {
+  DivAmbiguityError(a);
+  return a;
+}
+
+template<typename TB>
+inline Expr operator%(const Expr& a, const TB& b) {
+  DivAmbiguityError(a);
+  return a;
+}
 }  // namespace tvm
 #endif  // TVM_EXPR_OPERATOR_H_
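What the new header spells out is the distinction between integer division conventions that bare `operator/` and `operator%` used to blur. Below is a minimal standalone sketch of those semantics using plain C++ `int` rather than `tvm::Expr`; the `trunc_div`/`floor_div` helpers are illustrative stand-ins, not TVM's implementation.

```cpp
#include <cassert>

// Truncated division/modulo: round toward zero (C/C++ built-in behavior).
int trunc_div(int a, int b) { return a / b; }
int trunc_mod(int a, int b) { return a % b; }

// Floor division/modulo: round toward negative infinity.
int floor_div(int a, int b) {
  int q = a / b;
  if ((a % b) != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}
int floor_mod(int a, int b) { return a - floor_div(a, b) * b; }

int main() {
  // For non-negative operands the two conventions agree, which is why
  // indexdiv/indexmod (defined only for non-negative indices) may lower
  // to either form.
  assert(trunc_div(7, 3) == 2 && floor_div(7, 3) == 2);
  assert(trunc_mod(7, 3) == 1 && floor_mod(7, 3) == 1);
  // They diverge on negative dividends; picking one implicitly via
  // operator/ was the ambiguity this commit removes.
  assert(trunc_div(-7, 3) == -2 && trunc_mod(-7, 3) == -1);
  assert(floor_div(-7, 3) == -3 && floor_mod(-7, 3) == 2);
  return 0;
}
```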
@@ -235,8 +235,6 @@ DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
 DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
 DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
 DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(/);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(%);
 DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
 DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
 DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
......
@@ -198,8 +198,8 @@ TVM_REGISTER_API("make.Allocate")
 REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
 REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
 REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
-REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
-REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
+REGISTER_MAKE_BINARY_OP(_OpDiv, div);
+REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
 REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
 REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
 REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
......
@@ -146,10 +146,12 @@ class BoundDeducer: public IRVisitor {
       success_ = false;
       return;
     }
     // always use relax bound
-    bool divided = analyzer_.CanProve(result_ % operand == 0);
-    result_ = result_ / operand;
+    bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);
+    // TODO(tvm-team): use floordiv, which could give better bound.
+    result_ = truncdiv(result_, operand);
     if (!divided) {
       // Handle non-divisible case
......
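To see why the deduced bound has to be relaxed when the division is inexact, consider deducing `x` from `x * 3 <= 10`. A small numeric walk-through under those assumed values, with plain `int` standing in for the `tvm::Expr` arithmetic above:

```cpp
#include <cassert>

int main() {
  int result = 10, operand = 3;            // deducing x from x * 3 <= 10
  bool divided = (result % operand == 0);  // 10 % 3 != 0 -> inexact division
  int bound = result / operand;            // truncdiv(10, 3) == 3
  assert(!divided && bound == 3);
  // 3 * 3 <= 10 holds but 4 * 3 > 10, so x <= 3 is the relaxed bound;
  // the "non-divisible case" branch adjusts for strictness and direction.
  return 0;
}
```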
@@ -912,7 +912,7 @@ Mutate_(const Mod* op, const Expr& self) {
         analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
       Expr temp = Normalize(extra);
       if (temp.as<IntImm>()) {
-        return temp % c1.Eval();
+        return truncmod(temp, c1.Eval());
       } else {
         // If temp < cval && temp >=0 then can remove the mod.
         if (TryCompare(temp, cval) == kLT) {
......
@@ -93,12 +93,12 @@ inline Expr Compute<ir::Mul>(Expr a, Expr b) {
 template<>
 inline Expr Compute<ir::Div>(Expr a, Expr b) {
-  return a / b;
+  return truncdiv(a, b);
 }
 template<>
 inline Expr Compute<ir::Mod>(Expr a, Expr b) {
-  return a % b;
+  return truncmod(a, b);
 }
 template<>
......
@@ -227,7 +227,7 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
                                     IntervalSet a,
                                     IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
-    return IntervalSet::SinglePoint(a->min_value % b->min_value);
+    return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
   }
   if (a->IsEmpty()) return a;
   if (b->IsEmpty()) return b;
......
@@ -31,6 +31,10 @@
 namespace tvm {
+// TODO(tqchen): change to floormod/div
+using IndexMod = ir::Mod;
+using IndexDiv = ir::Div;
+
 Array<Expr> SimplifyArray(Array<Expr> array) {
   for (size_t i = 0; i < array.size(); ++i) {
     array.Set(i, ir::Simplify(array[i]));
@@ -109,7 +113,7 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
   Expr mult_inner;  // The inner multiplication factor
   Expr no_opt_sum;  // Sum of the exprs that cannot be optimized
   while (true) {
-    auto inner_div_ptr = search_ptr->as<Div>();
+    auto inner_div_ptr = search_ptr->as<IndexDiv>();
     auto inner_mult_ptr = search_ptr->as<Mul>();
     auto inner_add_ptr = search_ptr->as<Add>();
     if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
@@ -156,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
   *has_mult = false;
   *has_mod = false;
   for (const Expr* ele : eles) {
-    auto mod_ptr = ele->as<Mod>();
+    auto mod_ptr = ele->as<IndexMod>();
     auto mult_ptr = ele->as<Mul>();
     if (mod_ptr) {
       *has_mod = true;
@@ -235,7 +239,8 @@ inline Expr MergeMulMod(const Expr &base) {
   }
   for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
        it != mod_exprs.end(); ++it) {
-    no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second;
+    no_opt_sum = no_opt_sum.get() ?
+        no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second);
   }
   return no_opt_sum;
 }
......
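`MergeMulMod` relies on the recomposition identity for non-negative indices: `indexdiv(x, c) * c + indexmod(x, c) == x`, which lets a sum such as `(x / 8) * 8 + x % 8` collapse back to `x`. A brute-force check of the identity with plain `int`, not the TVM simplifier:

```cpp
#include <cassert>

int main() {
  const int c = 8;
  for (int x = 0; x < 100; ++x) {
    // The pattern MergeMulMod looks for: a Mul of a Div, plus a Mod.
    assert((x / c) * c + (x % c) == x);
  }
  return 0;
}
```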
@@ -236,10 +236,10 @@ inline bool GetStoreRule(Array<Expr>* rule,
       if (store_axis.IsPrimal()) {
         const int32_t factor = dst_layout.FactorOf(store_axis);
         if (factor > 0) {
-          store = store / Expr(factor);
+          store = indexdiv(store, Expr(factor));
         }
       } else {
-        store = store % store_axis_impl->dom->extent;
+        store = indexmod(store, store_axis_impl->dom->extent);
       }
       rule->push_back(store);
......
@@ -206,6 +206,15 @@ Expr operator%(Expr a, Expr b) {
   return truncmod(a, b);
 }
+// TODO(tqchen): switch to floordiv
+Expr indexdiv(Expr a, Expr b) {
+  return truncdiv(a, b);
+}
+
+Expr indexmod(Expr a, Expr b) {
+  return truncmod(a, b);
+}
+
 Expr floordiv(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
......
@@ -309,7 +309,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
       if (op->loop_var.get() == inner) {
         CHECK(under_outer);
         std::unordered_map<const Variable *, Expr> rmap;
-        rmap[op->loop_var.get()] = parent % op->extent;
+        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
         extent = op->extent;
         fused = true;
         return ir::Substitute(op->body, rmap);
@@ -317,7 +317,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
         under_outer = true;
         Stmt body = IRMutator::Mutate(op->body);
         std::unordered_map<const Variable *, Expr> rmap;
-        rmap[op->loop_var.get()] = parent / extent;
+        rmap[op->loop_var.get()] = indexdiv(parent, extent);
         body = ir::Substitute(body, rmap);
         under_outer = false;
         return For::make(parent->var, Expr(0), extent * op->extent,
@@ -325,7 +325,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
       } else if (under_outer) {
         Stmt body = IRMutator::Mutate(op->body);
         std::unordered_map<const Variable *, Expr> rmap;
-        rmap[op->loop_var.get()] = parent / extent % op->extent;
+        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
         body = ir::Substitute(body, rmap);
         extent = extent * op->extent;
         return body;
......
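The substitutions above implement the usual fuse/split index algebra: for a fused index `parent` running over `outer_extent * inner_extent` iterations, `indexmod(parent, inner_extent)` recovers the inner index and `indexdiv(parent, inner_extent)` the outer one. A standalone round-trip check under assumed extents:

```cpp
#include <cassert>

int main() {
  const int outer_extent = 2, inner_extent = 3;
  for (int p = 0; p < outer_extent * inner_extent; ++p) {
    int outer = p / inner_extent;  // indexdiv(parent, extent)
    int inner = p % inner_extent;  // indexmod(parent, extent)
    // The fused index decomposes and recomposes exactly.
    assert(p == outer * inner_extent + inner);
  }
  return 0;
}
```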
@@ -120,7 +120,8 @@ void ArgBinder::BindBuffer(const Buffer& arg,
     Expr offset = value->elem_offset;
     Expr factor = make_const(offset.type(), arg->offset_factor);
     Expr zero = make_zero(offset.type());
-    BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
+    BinderAddAssert(truncmod(offset, factor) == zero,
+                    arg_name + ".elem_offset", &asserts_);
   }
 }
@@ -288,7 +289,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       Expr offset = buffer->elem_offset;
       Expr factor = make_const(offset.type(), buffer->offset_factor);
       Expr zero = make_zero(offset.type());
-      BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
+      BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
     }
   }
 }
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch. * \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc * \file inject_double_buffer.cc
*/ */
...@@ -230,7 +228,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -230,7 +228,7 @@ class DoubleBufferInjector : public IRMutator {
Expr loop_shift = e.loop->loop_var + one; Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.type()); e.loop->loop_var.type());
e.switch_read_var = e.loop->loop_var % two; e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true; in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body); Stmt body = Mutate(op->body);
in_double_buffer_scope_ = false; in_double_buffer_scope_ = false;
...@@ -239,7 +237,7 @@ class DoubleBufferInjector : public IRMutator { ...@@ -239,7 +237,7 @@ class DoubleBufferInjector : public IRMutator {
vmap[e.loop->loop_var.get()] = zero; vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap)); loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift; vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = loop_shift % two; vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
body = Substitute(body, vmap); body = Substitute(body, vmap);
body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body); body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
body = IfThenElse::make(loop_shift < e.loop->extent, body); body = IfThenElse::make(loop_shift < e.loop->extent, body);
......
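Double buffering alternates between two halves of the buffer: iteration `i` reads from half `i % 2` while the prefetch for iteration `i + 1` writes into `(i + 1) % 2`. A sketch of the ping-pong selection the injected `indexmod` expressions produce, with plain `int` and illustrative names:

```cpp
#include <cassert>

int main() {
  for (int i = 0; i < 8; ++i) {
    int read_sel = i % 2;         // e.switch_read_var  = indexmod(loop_var, two)
    int write_sel = (i + 1) % 2;  // e.switch_write_var = indexmod(loop_shift, two)
    // Reads and prefetch-writes always target opposite halves.
    assert(read_sel != write_sel);
  }
  return 0;
}
```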
@@ -178,6 +178,24 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::Mutate_(op, e);
   }
+
+  Expr Mutate_(const EQ* op, const Expr& e) final {
+    using namespace arith;
+    PVar<Expr> x, y;
+    if ((floormod(x, y) == 0).Match(e)) {
+      return Mutate((truncmod(x, y) == 0).Eval());
+    }
+    return IRMutatorWithAnalyzer::Mutate_(op, e);
+  }
+
+  Expr Mutate_(const NE* op, const Expr& e) final {
+    using namespace arith;
+    PVar<Expr> x, y;
+    if ((floormod(x, y) != 0).Match(e)) {
+      return Mutate((truncmod(x, y) != 0).Eval());
+    }
+    return IRMutatorWithAnalyzer::Mutate_(op, e);
+  }
+
  private:
   Expr SwapBroadcastCast(const Expr& e) {
     // Try to change broadcast(cast(x)) to cast(broadcast(x))
......
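These EQ/NE rewrites are sound because `floormod` and `truncmod` differ only by a multiple of the divisor, so one is zero exactly when the other is; the cheaper truncated form can therefore be emitted. A brute-force check of the equivalence, where the `floor_mod` helper is a local stand-in for TVM's `floormod`:

```cpp
#include <cassert>

int floor_mod(int a, int b) {
  int r = a % b;                         // truncated remainder
  if (r != 0 && ((r < 0) != (b < 0))) r += b;
  return r;
}

int main() {
  for (int a = -50; a <= 50; ++a) {
    for (int b = 1; b <= 10; ++b) {
      // floormod(a, b) == 0  <=>  truncmod(a, b) == 0.
      assert((floor_mod(a, b) == 0) == (a % b == 0));
    }
  }
  return 0;
}
```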
@@ -264,14 +264,15 @@ class WarpAccessRewriter : protected IRMutator {
     // simple case, warp index is on the highest.
     if (warp_group_ == 1) {
-      Expr x = analyzer_->canonical_simplify(index % m);
-      Expr z = analyzer_->canonical_simplify(index / m);
+      Expr x = analyzer_->canonical_simplify(indexmod(index, m));
+      Expr z = analyzer_->canonical_simplify(indexdiv(index, m));
       return std::make_pair(x, z);
     } else {
-      Expr x = analyzer_->canonical_simplify(index % m);
+      Expr x = analyzer_->canonical_simplify(indexmod(index, m));
       Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_);
       y = y * m + x;
-      Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m;
+      Expr z = indexdiv(indexmod(index, make_const(index.type(), warp_coeff_ * warp_size_)),
+                        m);
       return std::make_pair(analyzer_->canonical_simplify(y),
                             analyzer_->canonical_simplify(z));
     }
......
@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
         if (dim < avec.size() && avec[dim].align_factor != 0) {
           Expr factor = make_const(stride.type(), avec[dim].align_factor);
           Expr offset = make_const(stride.type(), avec[dim].align_offset);
-          stride = stride + (factor + offset - stride % factor) % factor;
+          stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
           stride = ir::Simplify(stride);
         }
         rstrides.push_back(stride);
......
@@ -610,8 +610,8 @@ class StoragePlanRewriter : public IRMutator {
       }
       // transform to alloc bytes
       auto type_bits = alloc_type.bits() * alloc_type.lanes();
-      bool divided = analyzer_.CanProve(combo_size % type_bits == 0);
-      combo_size = combo_size / type_bits;
+      bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
+      combo_size = indexdiv(combo_size, type_bits);
       // round up for can not divided
       if (!divided) {
         combo_size = combo_size + make_const(Int(32), 1);
......
@@ -66,12 +66,12 @@ bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     if (i == bit_axis) {
       out_shape.push_back(bits);
       if (i == pack_axis) {
-        out_shape.push_back(data->shape[i] / pack_bits);
+        out_shape.push_back(indexdiv(data->shape[i], pack_bits));
       } else {
         out_shape.push_back(data->shape[i]);
       }
     } else if (i == pack_axis) {
-      out_shape.push_back(data->shape[i] / pack_bits);
+      out_shape.push_back(indexdiv(data->shape[i], pack_bits));
     } else {
       out_shape.push_back(data->shape[i]);
     }
......
@@ -154,7 +154,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
     CHECK_EQ(param->dilation.size(), 2);
     Array<IndexExpr> wshape({dshape_nchw[1],
-                             param->channels / param->groups,
+                             indexdiv(param->channels, param->groups),
                              param->kernel_size[0],
                              param->kernel_size[1]});
@@ -184,7 +184,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
           << " channels=" << param->channels
           << " wshape=" << Array<IndexExpr>(wshape);
     }
-    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[0]));
+    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
     channels = wshape[1];
     dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -738,7 +738,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
     CHECK_EQ(param->dilation.size(), 2);
     Array<IndexExpr> wshape(
         {param->channels,
-         data->shape[1] / param->groups,
+         indexdiv(data->shape[1], param->groups),
          param->kernel_size[0],
          param->kernel_size[1]});
     channels = param->channels;
@@ -767,7 +767,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
           << " channels=" << param->channels
          << " wshape=" << wshape;
     }
-    CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
+    CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
     channels = wshape[0];
     ksize_y = wshape[2];
     ksize_x = wshape[3];
@@ -777,8 +777,10 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
   // dilation
   Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
-  oshape.Set(2, (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
-  oshape.Set(3, (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
+  oshape.Set(2, indexdiv(data->shape[2] + param->padding[0] * 2 - dilated_ksize_y,
+                         param->strides[0]) + 1);
+  oshape.Set(3, indexdiv(data->shape[3] + param->padding[1] * 2 - dilated_ksize_x,
+                         param->strides[1]) + 1);
   DataType out_dtype = param->out_dtype;
   // infer offset shape
......
@@ -74,10 +74,10 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     if (tvm::ir::Equal(param->channels, param->groups) && !tvm::ir::Equal(param->channels, 1)) {
       // infer weight's shape for depthwise convolution
-      wshape = {{dshape_nchw[1], param->groups / dshape_nchw[1], param->kernel_size[0],
+      wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0],
                  param->kernel_size[1]}};
     } else {
-      wshape = {{param->channels, dshape_nchw[1] / param->groups, param->kernel_size[0],
+      wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
                  param->kernel_size[1]}};
     }
@@ -108,7 +108,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
           << "Conv2D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels << " wshape=" << wshape;
     }
-    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
+    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
     channels = wshape[0];
     dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -116,8 +116,10 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // dilation
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-  oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
-  oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
+  oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
+                         param->strides[0]) + 1);
+  oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
+                         param->strides[1]) + 1);
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
......
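Both relations use the standard convolution output-size formula `out = (in + 2 * pad - dilated_ksize) / stride + 1`, where every operand is non-negative, which is exactly the case `indexdiv` is meant for. Worked numbers for a hypothetical 7x7, stride-2 layer on a 224-pixel input:

```cpp
#include <cassert>

int main() {
  int in = 224, pad = 3, ksize = 7, dilation = 1, stride = 2;
  int dilated_ksize = 1 + (ksize - 1) * dilation;          // 7
  int out = (in + 2 * pad - dilated_ksize) / stride + 1;   // (224 + 6 - 7) / 2 + 1
  assert(out == 112);  // 223 / 2 floors to 111, plus 1
  return 0;
}
```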
@@ -615,7 +615,7 @@ bool ReshapeRel(const Array<Type>& types,
       if (d0.as<Any>()) {
         oshape.push_back(Any::make());
       } else {
-        oshape.push_back(d0 / d2);
+        oshape.push_back(indexdiv(d0, d2));
       }
       used_output_dims.insert(oshape.size());
       oshape.push_back(d2);
@@ -627,7 +627,7 @@ bool ReshapeRel(const Array<Type>& types,
       if (d0.as<Any>()) {
         oshape.push_back(Any::make());
       } else {
-        oshape.push_back(d0 / d1);
+        oshape.push_back(indexdiv(d0, d1));
       }
     } else {
       oshape.push_back(d2);
@@ -659,7 +659,7 @@ bool ReshapeRel(const Array<Type>& types,
         infer_dim = Any::make();
         break;
       }
-      infer_dim /= oshape[i];
+      infer_dim = indexdiv(infer_dim, oshape[i]);
     }
   }
   oshape.Set(infer_idx, infer_dim);
@@ -1987,13 +1987,13 @@ bool SplitRel(const Array<Type>& types,
       << "axis should be within the input dimension range.";
   if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
-    CHECK(reporter->Assert(data->shape[axis] %
-                           sections->value == make_zero(Int(64))))
+    CHECK(reporter->Assert(indexmod(data->shape[axis],
+                                    sections->value) == make_zero(Int(64))))
         << "indices_or_sections need to be able to divide input.shape[axis]";
     std::vector<Type> fields;
     for (int i = 0; i < sections->value; ++i) {
       std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
-      oshape[axis] /= int32_t(sections->value);
+      oshape[axis] = indexdiv(oshape[axis], sections->value);
       auto vec_type = TensorTypeNode::make(oshape, data->dtype);
       fields.push_back(vec_type);
     }
......
@@ -55,8 +55,8 @@ bool YoloReorgRel(const Array<Type>& types,
   CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
   std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
   oshape[1] = oshape[1] * param->stride * param->stride;
-  oshape[2] = oshape[2] / param->stride;
-  oshape[3] = oshape[3] / param->stride;
+  oshape[2] = indexdiv(oshape[2], param->stride);
+  oshape[3] = indexdiv(oshape[3], param->stride);
   reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
   return true;
 }
......
@@ -56,10 +56,10 @@ void PassDownDomain(const Stage& stage,
                     arith::Analyzer* actx,
                     bool allow_missing) {
   auto ceil_div = [actx](Expr a, Expr b) {
-    if (actx->CanProve(a % b == 0)) {
-      return actx->Simplify(a / b);
+    if (actx->CanProve(indexmod(a, b) == 0)) {
+      return actx->Simplify(indexdiv(a, b));
     }
-    return actx->Simplify((a + (b - 1)) / b);
+    return actx->Simplify(indexdiv(a + (b - 1), b));
   };
   auto& state = *p_state;
@@ -146,8 +146,8 @@ void PassUpIndex(const Stage& stage,
       Expr factor = dom_map.at(s->inner)->extent;
       Expr outer_min = dom_map.at(s->outer)->min;
       Expr inner_min = dom_map.at(s->inner)->min;
-      state[s->outer] = value / factor;
-      state[s->inner] = value % factor;
+      state[s->outer] = indexdiv(value, factor);
+      state[s->inner] = indexmod(value, factor);
       // add min if they exist
       if (!is_zero(outer_min)) {
         state[s->outer] = state[s->outer] + outer_min;
@@ -190,8 +190,8 @@ void PassDownIndex(const Stage& stage,
       CHECK(is_zero(r->min));
       Expr parent = state.at(s->parent);
       Expr factor = r->extent;
-      state[s->outer] = parent / factor;
-      state[s->inner] = parent % factor;
+      state[s->outer] = indexdiv(parent, factor);
+      state[s->inner] = indexmod(parent, factor);
     } else if (const FuseNode* s = rel.as<FuseNode>()) {
       if (!state.count(s->inner) && !state.count(s->outer)) {
         CHECK(allow_missing);
@@ -266,8 +266,8 @@ void PassUpDomain(const FuseNode* s,
   if (fused.is_single_point()) {
     Expr value = fused.point_value();
     Expr factor = dom_map.at(s->inner)->extent;
-    Expr v_outer = value / factor;
-    Expr v_inner = value % factor;
+    Expr v_outer = indexdiv(value, factor);
+    Expr v_inner = indexmod(value, factor);
     if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
     if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
     *outer = IntSet::single_point(v_outer);
@@ -275,17 +275,18 @@ void PassUpDomain(const FuseNode* s,
   } else {
     Expr fused_extent = (fused.max() - fused.min() + 1);
     Expr inner_extent = dom_map.at(s->inner)->extent;
-    *outer = IntSet::interval(outer_min + fused.min() / inner_extent,
-                              outer_min + fused.max() / inner_extent);
-    if (is_zero(Simplify(inner_extent % fused_extent)) &&
-        is_zero(Simplify(fused.min() % fused_extent)) ) {
+    *outer = IntSet::interval(
+        outer_min + indexdiv(fused.min(), inner_extent),
+        outer_min + indexdiv(fused.max(), inner_extent));
+    if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
+        is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
       // fused never spans multiple rows, make a tight bounding box
       // there may be other cases when bounding box could be tightened
-      *inner = IntSet::interval(inner_min + fused.min() % inner_extent,
-                                inner_min + fused.max() % inner_extent);
+      *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
+                                inner_min + indexmod(fused.max(), inner_extent));
     } else {  // fused may span multiple rows, use full row widths
-      if (!is_zero(Simplify(fused_extent % inner_extent)) ||
-          !is_zero(Simplify(fused.min() % inner_extent))) {
+      if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
+          !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
         LOG(WARNING) <<
           "fused and original axes are not aligned, this may cause redundant computations";
       }
......
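The `ceil_div` helper above uses the standard trick for non-negative operands: `(a + b - 1) / b` equals `ceil(a / b)`, with the provably exact case kept separate so the simplifier sees the plain `a / b` form. A quick brute-force check of the identity with plain `int`:

```cpp
#include <cassert>

int main() {
  for (int a = 0; a < 100; ++a) {
    for (int b = 1; b <= 10; ++b) {
      int ceil_div = (a + b - 1) / b;       // the relaxed form
      int expected = a / b + (a % b != 0);  // floor(a / b), +1 if inexact
      assert(ceil_div == expected);
    }
  }
  return 0;
}
```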
@@ -193,8 +193,7 @@ TOPI_DEFINE_OP_OVERLOAD(operator*, multiply);
  *
  * \return The result.
  */
-TOPI_DEFINE_BCAST_OP(divide, { return a / b; });
-TOPI_DEFINE_OP_OVERLOAD(operator/, divide);
+TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); });
 /*!
  * \fn mod
@@ -207,8 +206,7 @@ TOPI_DEFINE_OP_OVERLOAD(operator/, divide);
  *
  * \return The result.
  */
-TOPI_DEFINE_BCAST_OP(mod, { return a % b; });
-TOPI_DEFINE_OP_OVERLOAD(operator%, mod);
+TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); });
 /*!
  * \fn maximum
......
@@ -47,8 +47,8 @@ inline Array<Expr> GetPadTuple(Expr pad_h, Expr pad_w) {
   pad_h *= 2;
   pad_w *= 2;
-  auto pad_top = (pad_h + 1) / 2;
-  auto pad_left = (pad_w + 1) / 2;
+  auto pad_top = indexdiv(pad_h + 1, 2);
+  auto pad_left = indexdiv(pad_w + 1, 2);
   return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left };
 }
......
@@ -68,8 +68,8 @@ inline Array<Expr> UnravelIndex(Expr idx, Array<Expr> shape) {
   std::vector<Expr> indices;
   for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
-    indices.push_back(idx % shape[i]);
-    idx = idx / shape[i];
+    indices.push_back(indexmod(idx, shape[i]));
+    idx = indexdiv(idx, shape[i]);
   }
   std::reverse(indices.begin(), indices.end());
   return indices;
......
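`UnravelIndex` converts a flat row-major offset back into per-dimension indices by repeatedly taking `indexmod` by the innermost extent and `indexdiv` to move outward; for example, flat index 5 in a 2x3 shape maps to (1, 2). The same loop sketched with plain `int`:

```cpp
#include <cassert>
#include <vector>

int main() {
  std::vector<int> shape = {2, 3};
  int idx = 5;
  std::vector<int> indices(shape.size());
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    indices[i] = idx % shape[i];  // indexmod(idx, shape[i])
    idx = idx / shape[i];         // indexdiv(idx, shape[i])
  }
  assert(indices[0] == 1 && indices[1] == 2);  // 5 == 1 * 3 + 2
  return 0;
}
```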
@@ -290,8 +290,8 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
   tvm::Array<tvm::Expr> output_shape{
       I->shape[0],  // B
       W->shape[0],  // O
-      (I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1,  // H
-      (I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1   // W
+      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
+      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
   };
   auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
   auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
@@ -339,8 +339,8 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
   auto pH = I->shape[2];
   auto pW = I->shape[3];
   tvm::Array<tvm::Expr> output_shape{
-      (I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1,  // H
-      (I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1,  // W
+      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
+      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1,  // W
       I->shape[2],  // B
       W->shape[3]   // O
   };
@@ -393,8 +393,8 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
   tvm::Array<tvm::Expr> output_shape{
       I->shape[0],  // B
       W->shape[1],  // O
-      (I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1,  // H
-      (I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1   // W
+      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
+      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
   };
   auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
   auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
@@ -403,8 +403,8 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
       ? I
       : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
   auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
-    return tvm::sum(T(b, i / pCM, stride_h * h + kh, stride_w * w + kw) *
-                    W(i / pCM, o % pCM, kh, kw),
+    return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
+                    W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
                     {i, kh, kw});
   };
   return tvm::compute(output_shape, l, name, tag);
@@ -425,8 +425,8 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
   auto pCM = W->shape[1];  // channel_multiplier
   tvm::Array<tvm::Expr> output_shape{
       I->shape[0],  // B
-      (I->shape[1] - W->shape[1] + 2 * pad_h) / stride_h + 1,  // H
-      (I->shape[2] - W->shape[2] + 2 * pad_w) / stride_w + 1,  // W
+      indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1,  // H
+      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,  // W
       W->shape[3],  // O
   };
   auto i = tvm::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
@@ -436,8 +436,8 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
       ? I
       : pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)});
   auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
-    return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, i / pCM) *
-                    W(kh, kw, i / pCM, o % pCM),
+    return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
+                    W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
                     {kh, kw, i});
   };
   return tvm::compute(output_shape, l, name, tag);
@@ -479,8 +479,8 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
       I->shape[0],  // B
       I->shape[1],  // G
       W->shape[2],  // O
-      (I->shape[3] - W->shape[3] + 2 * pad_h) / stride_h + 1,  // H
-      (I->shape[4] - W->shape[4] + 2 * pad_w) / stride_w + 1   // W
+      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1,  // H
+      indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1   // W
   };
   auto i = tvm::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
   auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
......
@@ -58,7 +58,7 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
   Array<Expr> oshape;
   for (size_t i = 0; i < n; ++i) {
     oshape.push_back(i == static_cast<size_t>(axis) ?
-                     tvm::ir::Simplify(ishape[i] / 32) :
+                     tvm::ir::Simplify(indexdiv(ishape[i], 32)) :
                      ishape[i]);
   }
......
@@ -89,8 +89,8 @@ inline Tensor dilate(const Tensor& x,
     if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
       index_tuple.push_back(indices[i]);
     } else {
-      index_tuple.push_back(indices[i] / strides[i]);
-      not_zero.push_back((indices[i] % strides[i]) == 0);
+      index_tuple.push_back(indexdiv(indices[i], strides[i]));
+      not_zero.push_back((indexmod(indices[i], strides[i])) == 0);
     }
   }
   if (not_zero.size() > 0) {
......
@@ -70,8 +70,8 @@ inline Tensor flatten(const Tensor& x,
     Expr idx = j;
     std::vector<Expr> index;
     for (auto s : extra_shape) {
-      index.push_back(idx % s);
-      idx = idx / s;
+      index.push_back(indexmod(idx, s));
+      idx = indexdiv(idx, s);
     }
     index.push_back(i);
     std::reverse(index.begin(), index.end());
......
@@ -85,7 +85,7 @@ inline Tensor lrn(const Tensor& data,
       input_shape,
       [&](Var i, Var j, Var k, Var l) {
         return tvm::pow(bias +
-                        (alpha * sqr_sum(i, j, k, l) / size),
+                        (div(alpha * sqr_sum(i, j, k, l), size)),
                         beta);
       });
   return topi::divide(data, sqrt_sum_up);
......
@@ -102,9 +102,9 @@ inline Tensor pool_impl(const Tensor& x,
   pad_after.Set(width_axis, pad_right);
   auto out_height = tvm::ir::Simplify(
-      (height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
+      indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
   auto out_width = tvm::ir::Simplify(
-      (width - kernel_width + pad_left + pad_right) / stride_width + 1);
+      indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
   auto dheight = tvm::reduce_axis(Range(0, kernel_height));
   auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
@@ -149,7 +149,7 @@ inline Tensor pool_impl(const Tensor& x,
       Array<Expr> indices;
       for (const Var& var : output) indices.push_back(var);
       if (count_include_pad) {
-        return pool_sum(indices) / (kernel_height * kernel_width);
+        return div(pool_sum(indices), (kernel_height * kernel_width));
       } else {
         Expr h_start = output[height_axis] * stride_height - pad_top;
         Expr w_start = output[width_axis] * stride_width - pad_left;
@@ -159,7 +159,7 @@ inline Tensor pool_impl(const Tensor& x,
         w_start = ir::Max::make(w_start, make_const(Int(32), 0));
         Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
                                            make_const(Int(32), 1));
-        return pool_sum(indices) / divide_factor;
+        return div(pool_sum(indices), divide_factor);
       }
     }, "tensor", kElementWise);
   } else {
@@ -439,14 +439,14 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<Exp
 inline Expr start_index(const Var& out_index,
                         const Expr& odim,
                         const Expr& idim) {
-  return out_index * idim / odim;
+  return indexdiv(out_index * idim, odim);
 }
 inline Expr end_index(const Var& out_index,
                       const Expr& odim,
                       const Expr& idim) {
-  Expr tmp = (out_index + 1) * idim / odim;
-  return tvm::ir::Select::make((out_index + 1) * idim % odim == 0,
+  Expr tmp = indexdiv((out_index + 1) * idim, odim);
+  return tvm::ir::Select::make(indexmod((out_index + 1) * idim, odim) == 0,
                                tmp, tmp + 1);
 }
@@ -505,7 +505,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
         auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
         indices.Set(height_axis, i_start_h + dheight);
         indices.Set(width_axis, i_start_w + dwidth);
-        return tvm::sum(x(indices) / divide_factor, { dheight, dwidth });
+        return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
       }, "tensor", "adaptive_pool_avg");
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
......
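`start_index` and `end_index` compute `floor(o * idim / odim)` and `ceil((o + 1) * idim / odim)`, so consecutive adaptive-pooling bins jointly cover the whole input even when `odim` does not divide `idim`. Worked numbers for an assumed `idim = 7`, `odim = 3`, with plain `int` mirroring the `Select` logic above:

```cpp
#include <cassert>

int start_index(int o, int odim, int idim) { return o * idim / odim; }
int end_index(int o, int odim, int idim) {
  int tmp = (o + 1) * idim / odim;
  return ((o + 1) * idim % odim == 0) ? tmp : tmp + 1;  // the Select above
}

int main() {
  const int odim = 3, idim = 7;
  // Bins: [0,3), [2,5), [4,7) -- overlapping, but jointly covering 0..6.
  assert(start_index(0, odim, idim) == 0 && end_index(0, odim, idim) == 3);
  assert(start_index(1, odim, idim) == 2 && end_index(1, odim, idim) == 5);
  assert(start_index(2, odim, idim) == 4 && end_index(2, odim, idim) == 7);
  return 0;
}
```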
@@ -658,7 +658,7 @@ inline Tensor take(const Tensor& a,
   } else {  // mode == "wrap"
     return compute(
         out_shape, [&](const Array<Var>& out_index) {
-          auto idx = (indices(out_index) % a_size + a_size) % a_size;
+          auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
           return a(UnravelIndex(idx, a_shape));
         }, name, tag);
   }
@@ -787,7 +787,7 @@ inline Tensor take(const Tensor& a,
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
+          auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
           real_indices.push_back(idx);
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
@@ -888,7 +888,7 @@ inline Tensor repeat(const Tensor& x,
         for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
           idx.push_back(indices[i]);
         }
-        idx.push_back(indices[axis] / repeats);
+        idx.push_back(indexdiv(indices[axis], repeats));
         for (size_t i = axis + 1; i < indices.size(); ++i) {
           idx.push_back(indices[i]);
         }
@@ -944,10 +944,10 @@ inline Tensor tile(const Tensor& x,
         Array<Expr> idx;
         if (ndim >= rdim) {
          for (size_t i = 0; i < ndim; ++i)
-            idx.push_back(indices[i] % x->shape[i]);
+            idx.push_back(indexmod(indices[i], x->shape[i]));
         } else {
           for (size_t i = 0; i < ndim; ++i)
-            idx.push_back(indices[rdim - ndim + i] % x->shape[i]);
+            idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
         }
         return x(idx);
       }, name, tag);
......
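The wrap mode needs the double `truncmod` because a single truncated remainder of a negative index is negative or zero; adding `a_size` and reducing again lands in `[0, a_size)`. A quick illustration with plain `int`, whose `%` is truncated just like `truncmod`:

```cpp
#include <cassert>

int main() {
  const int a_size = 5;
  int idx = -3;                // out-of-range index under wrap mode
  assert(idx % a_size == -3);  // a single truncmod stays negative
  int wrapped = ((idx % a_size) + a_size) % a_size;
  assert(wrapped == 2);        // -3 wraps to a_size - 3
  for (int i = -10; i <= 10; ++i) {
    int w = ((i % a_size) + a_size) % a_size;
    assert(0 <= w && w < a_size);  // always in [0, a_size)
  }
  return 0;
}
```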
@@ -64,9 +64,9 @@ inline Tensor reorg(const Tensor &data,
   auto out = tvm::compute(input_shape,
                           [&](Var b, Var k, Var j, Var i) {
                             return data(b * stride * stride,
-                                        (k % out_c) * stride * stride,
-                                        (j*stride + (k / out_c) / stride) * stride,
-                                        (i*stride + (k / out_c) % stride));
+                                        indexmod(k, out_c) * stride * stride,
+                                        (j*stride + indexdiv(indexdiv(k, out_c), stride)) * stride,
+                                        (i*stride + indexmod(indexdiv(k, out_c), stride)));
                           },
                           name,
                           tag);
......