Unverified Commit f0079a57 by Tianqi Chen Committed by GitHub

[ARITH] Refactor to use explicit div/mod functions instead of operators. (#4000)

* [ARITH] Use explicit div/mod functions instead of operators.

* fix pooling case
parent 17c2c0a1
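
The change in a nutshell, as an illustrative sketch rather than code taken from this diff: call sites that used the overloaded `/` and `%` on `Expr` now name the rounding convention explicitly through the helpers declared in the header below (`div`, `indexdiv`/`indexmod`, `truncdiv`/`truncmod`, `floordiv`/`floormod`). The wrapper function here is hypothetical and only shows the before/after style of a call site.

```cpp
#include <tvm/expr_operator.h>

// Hypothetical helper, written only to illustrate the new call-site style.
tvm::Expr SplitIndex(tvm::Expr i, tvm::Expr factor) {
  // Before this commit one would have written: i / factor and i % factor.
  tvm::Expr outer = tvm::indexdiv(i, factor);  // floor(i / factor); i and factor assumed non-negative
  tvm::Expr inner = tvm::indexmod(i, factor);  // remainder of the same division
  return outer * factor + inner;               // reconstructs i
}
```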
@@ -217,16 +217,6 @@ TVM_DLL Expr operator*(Expr a, Expr b);
*/
TVM_DLL Expr operator/(Expr a, Expr b);
/*!
* \brief mod operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator%(Expr a, Expr b);
/*!
* \brief left shift operator
*
* \param a left operand
@@ -371,6 +361,35 @@ TVM_DLL Expr truncdiv(Expr a, Expr b);
*/
TVM_DLL Expr truncmod(Expr a, Expr b);
/*!
* \brief compute floor(a / b) where a and b are non-negative.
*
* Use this function for index split calculation.
*
* This function might take advantage of the fact
* that a and b are non-negative.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr indexdiv(Expr a, Expr b);
/*!
* \brief compute the remainder of floor(a / b) where a and b are non-negative.
*
* Use this function for index split calculation.
* This function might take advantage of the fact
* that a and b are non-negative.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr indexmod(Expr a, Expr b);
/*!
* \brief compute floor(a / b)
*
* \param a left operand
@@ -662,21 +681,6 @@ inline Expr make_zero(Type t) {
return make_const(t, 0);
}
/*!
* \brief Helper function to raise a compiler error about division ambiguity.
* \note The call to this function will always results in a compiler error.
* \tparam TA Any class type.
*/
template<typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
"TVM supports multiple types of integer divisions, "
"please call div, floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
"Checkout these functions in expr_operator.h.");
}
// additional const expression overloading
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \
@@ -718,11 +722,9 @@ inline void DivAmbiguityError(const TA& a) {
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
@@ -731,11 +733,12 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
@@ -745,5 +748,45 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
/*!
* \brief Helper function to raise a compiler error about division ambiguity.
* \note The call to this function will always results in a compiler error.
* \tparam TA Any class type.
*/
template<typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
"TVM supports multiple types of integer divisions, "
"please call div, indexdiv/indexmod, "
"floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
"Checkout these functions in expr_operator.h.");
}
// The following code are not intended to be used in the codebase.
// Instead, they generate clear compiler errors that ask developers
// to use the specific division function.
// The second template argument is necessary to make sure the
// code compiles lazily by the compiler during invocation.
template<typename TB>
inline Expr operator/(const Expr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
template<typename TB>
inline Expr operator/=(const Expr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
template<typename TB>
inline Expr operator%(const Expr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
} // namespace tvm
#endif // TVM_EXPR_OPERATOR_H_
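
The deleted `operator%` declaration and the templated `operator/`, `operator/=`, and `operator%` added at the bottom of this header work together: the operators still exist syntactically, but their bodies call `DivAmbiguityError`, whose `static_assert` depends on the template parameter and therefore fires only when the operator is actually instantiated. A minimal standalone sketch of that trick (not TVM code; `Expr` here is a stand-in type):

```cpp
#include <type_traits>

struct Expr {};  // stand-in for tvm::Expr in this sketch

template <typename TA>
inline void DivAmbiguityError(const TA&) {
  // False whenever TA is a class type, but only evaluated at instantiation time.
  constexpr bool div_ambiguity = !std::is_class<TA>::value;
  static_assert(div_ambiguity,
                "call div, indexdiv/indexmod, floordiv/floormod "
                "or truncdiv/truncmod explicitly");
}

template <typename TB>
inline Expr operator/(const Expr& a, const TB&) {
  DivAmbiguityError(a);  // the compile error is raised here, only if '/' is used
  return a;
}

int main() {
  Expr e;
  (void)e;
  // Expr f = e / 2;  // uncommenting this line triggers the static_assert message
  return 0;
}
```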
@@ -235,8 +235,6 @@ DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
DEFINE_OVERLOAD_SLICE_BINARY_OP(/);
DEFINE_OVERLOAD_SLICE_BINARY_OP(%);
DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
......
@@ -198,8 +198,8 @@ TVM_REGISTER_API("make.Allocate")
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
......
@@ -146,10 +146,12 @@ class BoundDeducer: public IRVisitor {
success_ = false;
return;
}
// always use relax bound
bool divided = analyzer_.CanProve(result_ % operand == 0);
bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);
result_ = result_ / operand;
// TODO(tvm-team): use floordiv, which could give better bound.
result_ = truncdiv(result_, operand);
if (!divided) {
// Handle non-divisible case
......
@@ -912,7 +912,7 @@ Mutate_(const Mod* op, const Expr& self) {
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
return temp % c1.Eval();
return truncmod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT) {
......
@@ -93,12 +93,12 @@ inline Expr Compute<ir::Mul>(Expr a, Expr b) {
template<>
inline Expr Compute<ir::Div>(Expr a, Expr b) {
return a / b;
return truncdiv(a, b);
}
template<>
inline Expr Compute<ir::Mod>(Expr a, Expr b) {
return a % b;
return truncmod(a, b);
}
template<>
......
@@ -227,7 +227,7 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value % b->min_value);
return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
......
@@ -31,6 +31,10 @@
namespace tvm {
// TODO(tqchen): change to floormod/div
using IndexMod = ir::Mod;
using IndexDiv = ir::Div;
Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) {
array.Set(i, ir::Simplify(array[i]));
@@ -109,7 +113,7 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
Expr mult_inner; // The inner multiplication factor
Expr no_opt_sum; // Sum of the exprs that cannot be optimized
while (true) {
auto inner_div_ptr = search_ptr->as<Div>();
auto inner_div_ptr = search_ptr->as<IndexDiv>();
auto inner_mult_ptr = search_ptr->as<Mul>();
auto inner_add_ptr = search_ptr->as<Add>();
if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
@@ -156,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
*has_mult = false;
*has_mod = false;
for (const Expr* ele : eles) {
auto mod_ptr = ele->as<Mod>();
auto mod_ptr = ele->as<IndexMod>();
auto mult_ptr = ele->as<Mul>();
if (mod_ptr) {
*has_mod = true;
@@ -235,7 +239,8 @@ inline Expr MergeMulMod(const Expr &base) {
}
for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
it != mod_exprs.end(); ++it) {
no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second;
no_opt_sum = no_opt_sum.get() ?
no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second);
}
return no_opt_sum;
}
......
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -236,10 +236,10 @@ inline bool GetStoreRule(Array<Expr>* rule,
if (store_axis.IsPrimal()) {
const int32_t factor = dst_layout.FactorOf(store_axis);
if (factor > 0) {
store = store / Expr(factor);
store = indexdiv(store, Expr(factor));
}
} else {
store = store % store_axis_impl->dom->extent;
store = indexmod(store, store_axis_impl->dom->extent);
}
rule->push_back(store);
......
@@ -206,6 +206,15 @@ Expr operator%(Expr a, Expr b) {
return truncmod(a, b);
}
// TODO(tqchen): switch to floordiv
Expr indexdiv(Expr a, Expr b) {
return truncdiv(a, b);
}
Expr indexmod(Expr a, Expr b) {
return truncmod(a, b);
}
Expr floordiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
......
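
Per the `TODO(tqchen)` above, `indexdiv`/`indexmod` currently lower to truncated division and remainder. That matches the floor-based semantics documented for them only when both operands are non-negative, which is exactly the assumption the new functions are allowed to make. A standalone integer illustration of the difference (plain C++, not TVM code):

```cpp
#include <cstdio>

int main() {
  int a = -7, b = 4;
  int trunc_q = a / b;              // -1: C++ '/' truncates toward zero
  int trunc_r = a % b;              // -3
  int floor_r = ((a % b) + b) % b;  //  1: floor-style remainder for b > 0
  int floor_q = (a - floor_r) / b;  // -2: floor(-7 / 4)
  std::printf("trunc: q=%d r=%d\n", trunc_q, trunc_r);
  std::printf("floor: q=%d r=%d\n", floor_q, floor_r);
  return 0;  // for non-negative a the two conventions coincide
}
```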
@@ -309,7 +309,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
if (op->loop_var.get() == inner) {
CHECK(under_outer);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent % op->extent;
rmap[op->loop_var.get()] = indexmod(parent, op->extent);
extent = op->extent;
fused = true;
return ir::Substitute(op->body, rmap);
@@ -317,7 +317,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
under_outer = true;
Stmt body = IRMutator::Mutate(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent / extent;
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = ir::Substitute(body, rmap);
under_outer = false;
return For::make(parent->var, Expr(0), extent * op->extent,
@@ -325,7 +325,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
} else if (under_outer) {
Stmt body = IRMutator::Mutate(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent / extent % op->extent;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = ir::Substitute(body, rmap);
extent = extent * op->extent;
return body;
......
@@ -120,7 +120,8 @@ void ArgBinder::BindBuffer(const Buffer& arg,
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
BinderAddAssert(truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
}
@@ -288,7 +289,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
Expr offset = buffer->elem_offset;
Expr factor = make_const(offset.type(), buffer->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
......
@@ -18,8 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
@@ -230,7 +228,7 @@ class DoubleBufferInjector : public IRMutator {
Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.type());
e.switch_read_var = e.loop->loop_var % two;
e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
in_double_buffer_scope_ = false;
@@ -239,7 +237,7 @@ class DoubleBufferInjector : public IRMutator {
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = loop_shift % two;
vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
body = Substitute(body, vmap);
body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
body = IfThenElse::make(loop_shift < e.loop->extent, body);
......
@@ -178,6 +178,24 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::Mutate_(op, e);
}
Expr Mutate_(const EQ* op, const Expr& e) final {
using namespace arith;
PVar<Expr> x, y;
if ((floormod(x, y) == 0).Match(e)) {
return Mutate((truncmod(x, y) == 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
}
Expr Mutate_(const NE* op, const Expr& e) final {
using namespace arith;
PVar<Expr> x, y;
if ((floormod(x, y) != 0).Match(e)) {
return Mutate((truncmod(x, y) != 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
}
private:
Expr SwapBroadcastCast(const Expr& e) {
// Try to change broadcast(cast(x)) to cast(broadcast(x))
......
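
The two new `Mutate_` overloads above rely on the fact that `floormod(x, y)` and `truncmod(x, y)` vanish for exactly the same inputs (both are zero precisely when y divides x; only the sign of a nonzero remainder differs), so comparisons against zero can be lowered to the cheaper truncated form. A standalone check of that identity with plain integers (not TVM code):

```cpp
#include <cassert>

int truncmod_i(int x, int y) { return x % y; }

int floormod_i(int x, int y) {
  int r = x % y;
  return (r != 0 && ((r < 0) != (y < 0))) ? r + y : r;
}

int main() {
  for (int x = -8; x <= 8; ++x) {
    for (int y = 1; y <= 4; ++y) {
      // Divisibility is the same under both conventions.
      assert((floormod_i(x, y) == 0) == (truncmod_i(x, y) == 0));
    }
  }
  return 0;
}
```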
@@ -264,14 +264,15 @@ class WarpAccessRewriter : protected IRMutator {
// simple case, warp index is on the highest.
if (warp_group_ == 1) {
Expr x = analyzer_->canonical_simplify(index % m);
Expr x = analyzer_->canonical_simplify(indexmod(index, m));
Expr z = analyzer_->canonical_simplify(index / m);
Expr z = analyzer_->canonical_simplify(indexdiv(index, m));
return std::make_pair(x, z);
} else {
Expr x = analyzer_->canonical_simplify(index % m);
Expr x = analyzer_->canonical_simplify(indexmod(index, m));
Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_);
y = y * m + x;
Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m;
Expr z = indexdiv(indexmod(index, make_const(index.type(), warp_coeff_ * warp_size_)),
m);
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
}
......
@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
if (dim < avec.size() && avec[dim].align_factor != 0) {
Expr factor = make_const(stride.type(), avec[dim].align_factor);
Expr offset = make_const(stride.type(), avec[dim].align_offset);
stride = stride + (factor + offset - stride % factor) % factor;
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
stride = ir::Simplify(stride);
}
rstrides.push_back(stride);
......
@@ -610,8 +610,8 @@ class StoragePlanRewriter : public IRMutator {
}
// transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided = analyzer_.CanProve(combo_size % type_bits == 0);
bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
combo_size = combo_size / type_bits;
combo_size = indexdiv(combo_size, type_bits);
// round up for can not divided
if (!divided) {
combo_size = combo_size + make_const(Int(32), 1);
......
@@ -66,12 +66,12 @@ bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (i == bit_axis) {
out_shape.push_back(bits);
if (i == pack_axis) {
out_shape.push_back(data->shape[i] / pack_bits);
out_shape.push_back(indexdiv(data->shape[i], pack_bits));
} else {
out_shape.push_back(data->shape[i]);
}
} else if (i == pack_axis) {
out_shape.push_back(data->shape[i] / pack_bits);
out_shape.push_back(indexdiv(data->shape[i], pack_bits));
} else {
out_shape.push_back(data->shape[i]);
}
@@ -102,12 +102,12 @@ TVM_REGISTER_API("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack);
RELAY_REGISTER_OP("nn.bitpack")
.describe(R"code(Bitpack layer that prepares data for bitserial operations.
This layer backs the bits of an input into a single datatype, allowing
efficient implementation of bitserial operations.
- **data**: Input tensor of any shape, dimension that is to be
packed must be divisible by number of bits.
- **out**: Packed tensor with shape appropriately compressed.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.BitPackAttrs")
@@ -183,7 +183,7 @@ on some platforms.
When data is NCHW, weight is expected to be OIHW or OIHWi.
When data is NHWC weight is expected to be HWIO or HWIOi.
- **out**: Output with same layout as input.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BinaryConv2DAttrs")
.set_num_inputs(2)
......
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -154,9 +154,9 @@ bool Conv2DTransposeRel(const Array<Type>& types,
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape({dshape_nchw[1],
param->channels / param->groups,
indexdiv(param->channels, param->groups),
param->kernel_size[0],
param->kernel_size[1]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
@@ -184,7 +184,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[0]));
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -738,7 +738,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape(
{param->channels,
data->shape[1] / param->groups,
indexdiv(data->shape[1], param->groups),
param->kernel_size[0],
param->kernel_size[1]});
channels = param->channels;
@@ -767,7 +767,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
<< " channels=" << param->channels
<< " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
channels = wshape[0];
ksize_y = wshape[2];
ksize_x = wshape[3];
@@ -777,8 +777,10 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
oshape.Set(2, (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
oshape.Set(2, indexdiv(data->shape[2] + param->padding[0] * 2 - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(data->shape[3] + param->padding[1] * 2 - dilated_ksize_x,
param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;
// infer offset shape
......
@@ -74,10 +74,10 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (tvm::ir::Equal(param->channels, param->groups) && !tvm::ir::Equal(param->channels, 1)) {
// infer weight's shape for depthwise convolution
wshape = {{dshape_nchw[1], param->groups / dshape_nchw[1], param->kernel_size[0],
wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0],
param->kernel_size[1]}};
} else {
wshape = {{param->channels, dshape_nchw[1] / param->groups, param->kernel_size[0],
wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
param->kernel_size[1]}};
}
@@ -108,7 +108,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
channels = wshape[0];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -116,8 +116,10 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
......
@@ -615,7 +615,7 @@ bool ReshapeRel(const Array<Type>& types,
if (d0.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d0 / d2);
oshape.push_back(indexdiv(d0, d2));
}
used_output_dims.insert(oshape.size());
oshape.push_back(d2);
@@ -627,7 +627,7 @@ bool ReshapeRel(const Array<Type>& types,
if (d0.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d0 / d1);
oshape.push_back(indexdiv(d0, d1));
}
} else {
oshape.push_back(d2);
@@ -659,7 +659,7 @@ bool ReshapeRel(const Array<Type>& types,
infer_dim = Any::make();
break;
}
infer_dim /= oshape[i];
infer_dim = indexdiv(infer_dim, oshape[i]);
}
}
oshape.Set(infer_idx, infer_dim);
@@ -1987,13 +1987,13 @@ bool SplitRel(const Array<Type>& types,
<< "axis should be within the input dimension range.";
if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
CHECK(reporter->Assert(data->shape[axis] %
sections->value == make_zero(Int(64))))
CHECK(reporter->Assert(indexmod(data->shape[axis],
sections->value) == make_zero(Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
std::vector<Type> fields;
for (int i = 0; i < sections->value; ++i) {
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] /= int32_t(sections->value);
oshape[axis] = indexdiv(oshape[axis], sections->value);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
......
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -55,8 +55,8 @@ bool YoloReorgRel(const Array<Type>& types,
CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[1] = oshape[1] * param->stride * param->stride;
oshape[2] = oshape[2] / param->stride;
oshape[3] = oshape[3] / param->stride;
oshape[2] = indexdiv(oshape[2], param->stride);
oshape[3] = indexdiv(oshape[3], param->stride);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
......
@@ -56,10 +56,10 @@ void PassDownDomain(const Stage& stage,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](Expr a, Expr b) {
if (actx->CanProve(a % b == 0)) {
return actx->Simplify(a / b);
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
}
return actx->Simplify((a + (b - 1)) / b);
return actx->Simplify(indexdiv(a + (b - 1), b));
};
auto& state = *p_state;
@@ -146,8 +146,8 @@ void PassUpIndex(const Stage& stage,
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
state[s->outer] = indexdiv(value, factor);
state[s->inner] = indexmod(value, factor);
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
@@ -190,8 +190,8 @@ void PassDownIndex(const Stage& stage,
CHECK(is_zero(r->min));
Expr parent = state.at(s->parent);
Expr factor = r->extent;
state[s->outer] = parent / factor;
state[s->inner] = parent % factor;
state[s->outer] = indexdiv(parent, factor);
state[s->inner] = indexmod(parent, factor);
} else if (const FuseNode* s = rel.as<FuseNode>()) {
if (!state.count(s->inner) && !state.count(s->outer)) {
CHECK(allow_missing);
@@ -266,8 +266,8 @@ void PassUpDomain(const FuseNode* s,
if (fused.is_single_point()) {
Expr value = fused.point_value();
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
Expr v_outer = indexdiv(value, factor);
Expr v_inner = indexmod(value, factor);
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
@@ -275,17 +275,18 @@ void PassUpDomain(const FuseNode* s,
} else {
Expr fused_extent = (fused.max() - fused.min() + 1);
Expr inner_extent = dom_map.at(s->inner)->extent;
*outer = IntSet::interval(outer_min + fused.min() / inner_extent,
outer_min + fused.max() / inner_extent);
if (is_zero(Simplify(inner_extent % fused_extent)) &&
is_zero(Simplify(fused.min() % fused_extent)) ) {
*outer = IntSet::interval(
outer_min + indexdiv(fused.min(), inner_extent),
outer_min + indexdiv(fused.max(), inner_extent));
if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
// fused never spans multiple rows, make a tight bounding box
// there may be other cases when bounding box could be tightened
*inner = IntSet::interval(inner_min + fused.min() % inner_extent,
inner_min + fused.max() % inner_extent);
*inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
inner_min + indexmod(fused.max(), inner_extent));
} else { // fused may span multiple rows, use full row widths
if (!is_zero(Simplify(fused_extent % inner_extent)) ||
!is_zero(Simplify(fused.min() % inner_extent))) {
if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
!is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
LOG(WARNING) <<
"fused and original axes are not aligned, this may cause redundant computations";
}
......
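
The schedule message-passing changes above all follow the same split/fuse arithmetic: a fused index maps to `(indexdiv(value, factor), indexmod(value, factor))`, and the `ceil_div` lambda rounds an extent up with `indexdiv(a + (b - 1), b)`. A standalone integer sketch of both identities (not TVM code):

```cpp
#include <cassert>

// ceil(a / b) for non-negative a and positive b, mirroring the ceil_div lambda above.
int ceil_div_int(int a, int b) {
  return (a % b == 0) ? a / b : (a + (b - 1)) / b;
}

int main() {
  const int factor = 4;
  for (int value = 0; value < 32; ++value) {
    int outer = value / factor;  // indexdiv(value, factor) for non-negative operands
    int inner = value % factor;  // indexmod(value, factor)
    assert(outer * factor + inner == value);  // fusing the split recovers the index
  }
  assert(ceil_div_int(10, 4) == 3);
  assert(ceil_div_int(12, 4) == 3);
  return 0;
}
```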
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -193,8 +193,7 @@ TOPI_DEFINE_OP_OVERLOAD(operator*, multiply);
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(divide, { return a / b; });
TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); });
TOPI_DEFINE_OP_OVERLOAD(operator/, divide);
/*!
* \fn mod
@@ -207,8 +206,7 @@ TOPI_DEFINE_OP_OVERLOAD(operator/, divide);
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(mod, { return a % b; });
TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); });
TOPI_DEFINE_OP_OVERLOAD(operator%, mod);
/*!
* \fn maximum
......
@@ -47,8 +47,8 @@ inline Array<Expr> GetPadTuple(Expr pad_h, Expr pad_w) {
pad_h *= 2;
pad_w *= 2;
auto pad_top = (pad_h + 1) / 2;
auto pad_left = (pad_w + 1) / 2;
auto pad_top = indexdiv(pad_h + 1, 2);
auto pad_left = indexdiv(pad_w + 1, 2);
return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left };
}
......
@@ -68,8 +68,8 @@ inline Array<Expr> UnravelIndex(Expr idx, Array<Expr> shape) {
std::vector<Expr> indices;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
indices.push_back(idx % shape[i]);
idx = idx / shape[i];
indices.push_back(indexmod(idx, shape[i]));
idx = indexdiv(idx, shape[i]);
}
std::reverse(indices.begin(), indices.end());
return indices;
......
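
`UnravelIndex` above peels a flat index into per-dimension coordinates, using `indexmod` for the innermost dimension and `indexdiv` to move outward. The same loop with plain integers, as a standalone sketch (not TVM code):

```cpp
#include <vector>

std::vector<int> UnravelIndexInt(int idx, const std::vector<int>& shape) {
  std::vector<int> indices(shape.size());
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    indices[i] = idx % shape[i];  // coordinate in dimension i
    idx /= shape[i];              // drop that dimension
  }
  return indices;
}

// Example: UnravelIndexInt(7, {2, 3, 2}) == {1, 0, 1}, since 7 == 1*6 + 0*2 + 1.
```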
@@ -288,10 +288,10 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
auto pH = I->shape[2];
auto pW = I->shape[3];
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
W->shape[0], // O
(I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1, // H
(I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1 // W
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
@@ -339,8 +339,8 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
auto pH = I->shape[2];
auto pW = I->shape[3];
tvm::Array<tvm::Expr> output_shape{
(I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1, // H
(I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1, // W
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W
I->shape[2], // B
W->shape[3] // O
};
@@ -393,8 +393,8 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
W->shape[1], // O
(I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1, // H
(I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1 // W
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
@@ -403,8 +403,8 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
? I
: pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
return tvm::sum(T(b, i / pCM, stride_h * h + kh, stride_w * w + kw) *
W(i / pCM, o % pCM, kh, kw),
return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
{i, kh, kw});
};
return tvm::compute(output_shape, l, name, tag);
@@ -425,8 +425,8 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
auto pCM = W->shape[1]; // channel_multiplier
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
(I->shape[1] - W->shape[1] + 2 * pad_h) / stride_h + 1, // H
(I->shape[2] - W->shape[2] + 2 * pad_w) / stride_w + 1, // W
indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W
W->shape[3], // O
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
@@ -436,8 +436,8 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
? I
: pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)});
auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, i / pCM) *
W(kh, kw, i / pCM, o % pCM),
return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
{kh, kw, i});
};
return tvm::compute(output_shape, l, name, tag);
@@ -479,8 +479,8 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
I->shape[0], // B
I->shape[1], // G
W->shape[2], // O
(I->shape[3] - W->shape[3] + 2 * pad_h) / stride_h + 1, // H
(I->shape[4] - W->shape[4] + 2 * pad_w) / stride_w + 1 // W
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
......
@@ -58,7 +58,7 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
Array<Expr> oshape;
for (size_t i = 0; i < n; ++i) {
oshape.push_back(i == static_cast<size_t>(axis) ?
tvm::ir::Simplify(ishape[i] / 32) :
tvm::ir::Simplify(indexdiv(ishape[i], 32)) :
ishape[i]);
}
......
@@ -89,8 +89,8 @@ inline Tensor dilate(const Tensor& x,
if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
index_tuple.push_back(indices[i]);
} else {
index_tuple.push_back(indices[i] / strides[i]);
not_zero.push_back((indices[i] % strides[i]) == 0);
index_tuple.push_back(indexdiv(indices[i], strides[i]));
not_zero.push_back((indexmod(indices[i], strides[i])) == 0);
}
}
if (not_zero.size() > 0) {
......
@@ -70,8 +70,8 @@ inline Tensor flatten(const Tensor& x,
Expr idx = j;
std::vector<Expr> index;
for (auto s : extra_shape) {
index.push_back(idx % s);
idx = idx / s;
index.push_back(indexmod(idx, s));
idx = indexdiv(idx, s);
}
index.push_back(i);
std::reverse(index.begin(), index.end());
......
@@ -85,7 +85,7 @@ inline Tensor lrn(const Tensor& data,
input_shape,
[&](Var i, Var j, Var k, Var l) {
return tvm::pow(bias +
(alpha * sqr_sum(i, j, k, l) / size),
(div(alpha * sqr_sum(i, j, k, l), size)),
beta);
});
return topi::divide(data, sqrt_sum_up);
......
@@ -102,9 +102,9 @@ inline Tensor pool_impl(const Tensor& x,
pad_after.Set(width_axis, pad_right);
auto out_height = tvm::ir::Simplify(
(height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
auto out_width = tvm::ir::Simplify(
(width - kernel_width + pad_left + pad_right) / stride_width + 1);
indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
auto dheight = tvm::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
@@ -149,7 +149,7 @@ inline Tensor pool_impl(const Tensor& x,
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
if (count_include_pad) {
return pool_sum(indices) / (kernel_height * kernel_width);
return div(pool_sum(indices), (kernel_height * kernel_width));
} else {
Expr h_start = output[height_axis] * stride_height - pad_top;
Expr w_start = output[width_axis] * stride_width - pad_left;
@@ -159,7 +159,7 @@ inline Tensor pool_impl(const Tensor& x,
w_start = ir::Max::make(w_start, make_const(Int(32), 0));
Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
make_const(Int(32), 1));
return pool_sum(indices) / divide_factor;
return div(pool_sum(indices), divide_factor);
}
}, "tensor", kElementWise);
} else {
@@ -439,14 +439,14 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<Exp
inline Expr start_index(const Var& out_index,
const Expr& odim,
const Expr& idim) {
return out_index * idim / odim;
return indexdiv(out_index * idim, odim);
}
inline Expr end_index(const Var& out_index,
const Expr& odim,
const Expr& idim) {
Expr tmp = (out_index + 1) * idim / odim;
return tvm::ir::Select::make((out_index + 1) * idim % odim == 0,
Expr tmp = indexdiv((out_index + 1) * idim, odim);
return tvm::ir::Select::make(indexmod((out_index + 1) * idim, odim) == 0,
tmp, tmp + 1);
}
@@ -505,7 +505,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
indices.Set(height_axis, i_start_h + dheight);
indices.Set(width_axis, i_start_w + dwidth);
return tvm::sum(x(indices) / divide_factor, { dheight, dwidth });
return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
}, "tensor", "adaptive_pool_avg");
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
......
@@ -658,7 +658,7 @@ inline Tensor take(const Tensor& a,
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
auto idx = (indices(out_index) % a_size + a_size) % a_size;
auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
}
@@ -787,7 +787,7 @@ inline Tensor take(const Tensor& a,
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
@@ -888,7 +888,7 @@ inline Tensor repeat(const Tensor& x,
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
idx.push_back(indices[i]);
}
idx.push_back(indices[axis] / repeats);
idx.push_back(indexdiv(indices[axis], repeats));
for (size_t i = axis + 1; i < indices.size(); ++i) {
idx.push_back(indices[i]);
}
@@ -944,10 +944,10 @@ inline Tensor tile(const Tensor& x,
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[i] % x->shape[i]);
idx.push_back(indexmod(indices[i], x->shape[i]));
} else {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[rdim - ndim + i] % x->shape[i]);
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
return x(idx);
}, name, tag);
@@ -1253,7 +1253,7 @@ inline Tensor ndarray_size(const Tensor& src,
}
/*!
* \brief Returns a one-hot tensor where the locations repsented by indices take value on_value,
other locations take value off_value.
* \param indices locations to set to on_value.
* \param on_value value that locations represented by indices take on.
......
@@ -64,9 +64,9 @@ inline Tensor reorg(const Tensor &data,
auto out = tvm::compute(input_shape,
[&](Var b, Var k, Var j, Var i) {
return data(b * stride * stride,
(k % out_c) * stride * stride,
(j*stride + (k / out_c) / stride) * stride,
(i*stride + (k / out_c) % stride));
indexmod(k, out_c) * stride * stride,
(j*stride + indexdiv(indexdiv(k, out_c), stride)) * stride,
(i*stride + indexmod(indexdiv(k, out_c), stride)));
},
name,
tag);
......