Unverified commit f0079a57 by Tianqi Chen, committed by GitHub

[ARITH] Refactor to use explicit div/mod functions instead of operators. (#4000)

* [ARITH] Use explicit div/mod functions instead of operators.

* fix pooling case
parent 17c2c0a1
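For context, a minimal sketch of the migration pattern this commit asks for (the include path matches the header touched below; the function and variable names in the example are illustrative only):

```cpp
// Sketch only: how call sites change under this refactor.
#include <tvm/expr_operator.h>

void MigrationExample(tvm::Expr a, tvm::Expr b) {
  // Before: operator/ and operator% were overloaded on Expr, leaving the
  // truncated-vs-floor semantics of integer division implicit.
  //   tvm::Expr q = a / b;
  //   tvm::Expr r = a % b;

  // After: every integer division is spelled out explicitly.
  tvm::Expr q_trunc = tvm::truncdiv(a, b);   // C-style truncated division
  tvm::Expr r_trunc = tvm::truncmod(a, b);   // C-style truncated remainder
  tvm::Expr q_floor = tvm::floordiv(a, b);   // floor division
  tvm::Expr r_floor = tvm::floormod(a, b);   // floor remainder
  tvm::Expr q_index = tvm::indexdiv(a, b);   // index split; operands assumed non-negative
  tvm::Expr r_index = tvm::indexmod(a, b);   // index split remainder
  (void)q_trunc; (void)r_trunc; (void)q_floor;
  (void)r_floor; (void)q_index; (void)r_index;
}
```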
......@@ -217,16 +217,6 @@ TVM_DLL Expr operator*(Expr a, Expr b);
*/
TVM_DLL Expr operator/(Expr a, Expr b);
/*!
* \brief mod operator
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr operator%(Expr a, Expr b);
/*!
* \brief left shift operator
*
* \param a left operand
......@@ -371,6 +361,35 @@ TVM_DLL Expr truncdiv(Expr a, Expr b);
*/
TVM_DLL Expr truncmod(Expr a, Expr b);
/*!
* \brief compute floor(a / b) where a and b are non-negative.
*
* Use this function for index split calculation.
*
* This function might take advantage of the fact
* that a and b are non-negative.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr indexdiv(Expr a, Expr b);
/*!
* \brief compute the remainder of floor(a / b), i.e. a - floordiv(a, b) * b, where a and b are non-negative.
*
* Use this function for index split calculation.
* This function might take advantage of the fact
* that a and b are non-negative.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr indexmod(Expr a, Expr b);
/*!
* \brief compute floor(a / b)
*
* \param a left operand
......@@ -662,21 +681,6 @@ inline Expr make_zero(Type t) {
return make_const(t, 0);
}
/*!
* \brief Helper function to raise a compiler error about division ambiguity.
* \note The call to this function will always result in a compiler error.
* \tparam TA Any class type.
*/
template<typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
"TVM supports multiple types of integer divisions, "
"please call div, floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
"Checkout these functions in expr_operator.h.");
}
// additional const expression overloading
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \
......@@ -718,11 +722,9 @@ inline void DivAmbiguityError(const TA& a) {
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
......@@ -731,11 +733,12 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
......@@ -745,5 +748,45 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
/*!
* \brief Helper function to raise a compiler error about division ambiguity.
* \note The call to this function will always result in a compiler error.
* \tparam TA Any class type.
*/
template<typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
"TVM supports multiple types of integer divisions, "
"please call div, indexdiv/indexmod, "
"floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
"Checkout these functions in expr_operator.h.");
}
// The following overloads are not intended to be used in the codebase.
// Instead, they generate clear compiler errors that ask developers
// to call the specific division function.
// Making the second argument a template parameter ensures the body is
// only instantiated lazily by the compiler at the point of invocation.
template<typename TB>
inline Expr operator/(const Expr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
template<typename TB>
inline Expr operator/=(const Expr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
template<typename TB>
inline Expr operator%(const Expr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
} // namespace tvm
#endif // TVM_EXPR_OPERATOR_H_
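As a hedged sketch of the developer-facing effect of the expr_operator.h changes above (the expressions below are illustrative, not part of the commit):

```cpp
// Sketch of how the new helpers and the DivAmbiguityError guard behave.
#include <tvm/expr_operator.h>

void IndexSplitExample(tvm::Expr i, tvm::Expr factor) {
  // indexdiv/indexmod are the helpers intended for index split arithmetic;
  // in this commit they lower to truncdiv/truncmod (see the TODO in
  // expr_operator.cc further down in this diff about switching to floordiv).
  tvm::Expr outer = tvm::indexdiv(i, factor);
  tvm::Expr inner = tvm::indexmod(i, factor);
  (void)outer; (void)inner;

  // The templated operator/ and operator% overloads above exist only to
  // trigger DivAmbiguityError, so the lines below would fail to compile
  // with the "TVM supports multiple types of integer divisions" message:
  //   tvm::Expr bad_div = i / 4;
  //   tvm::Expr bad_mod = i % 4;
}
```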
......@@ -235,8 +235,6 @@ DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
DEFINE_OVERLOAD_SLICE_BINARY_OP(/);
DEFINE_OVERLOAD_SLICE_BINARY_OP(%);
DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
......
......@@ -198,8 +198,8 @@ TVM_REGISTER_API("make.Allocate")
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, operator/);
REGISTER_MAKE_BINARY_OP(_OpMod, operator%);
REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
......
......@@ -146,10 +146,12 @@ class BoundDeducer: public IRVisitor {
success_ = false;
return;
}
// always use relax bound
bool divided = analyzer_.CanProve(result_ % operand == 0);
bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);
result_ = result_ / operand;
// TODO(tvm-team): use floordiv, which could give better bound.
result_ = truncdiv(result_, operand);
if (!divided) {
// Handle non-divisible case
......
......@@ -912,7 +912,7 @@ Mutate_(const Mod* op, const Expr& self) {
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
return temp % c1.Eval();
return truncmod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT) {
......
......@@ -93,12 +93,12 @@ inline Expr Compute<ir::Mul>(Expr a, Expr b) {
template<>
inline Expr Compute<ir::Div>(Expr a, Expr b) {
return a / b;
return truncdiv(a, b);
}
template<>
inline Expr Compute<ir::Mod>(Expr a, Expr b) {
return a % b;
return truncmod(a, b);
}
template<>
......
......@@ -227,7 +227,7 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value % b->min_value);
return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
......
......@@ -31,6 +31,10 @@
namespace tvm {
// TODO(tqchen): change to floormod/div
using IndexMod = ir::Mod;
using IndexDiv = ir::Div;
Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) {
array.Set(i, ir::Simplify(array[i]));
......@@ -109,7 +113,7 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
Expr mult_inner; // The inner multiplication factor
Expr no_opt_sum; // Sum of the exprs that cannot be optimized
while (true) {
auto inner_div_ptr = search_ptr->as<Div>();
auto inner_div_ptr = search_ptr->as<IndexDiv>();
auto inner_mult_ptr = search_ptr->as<Mul>();
auto inner_add_ptr = search_ptr->as<Add>();
if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
......@@ -156,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
*has_mult = false;
*has_mod = false;
for (const Expr* ele : eles) {
auto mod_ptr = ele->as<Mod>();
auto mod_ptr = ele->as<IndexMod>();
auto mult_ptr = ele->as<Mul>();
if (mod_ptr) {
*has_mod = true;
......@@ -235,7 +239,8 @@ inline Expr MergeMulMod(const Expr &base) {
}
for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
it != mod_exprs.end(); ++it) {
no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second;
no_opt_sum = no_opt_sum.get() ?
no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second);
}
return no_opt_sum;
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -236,10 +236,10 @@ inline bool GetStoreRule(Array<Expr>* rule,
if (store_axis.IsPrimal()) {
const int32_t factor = dst_layout.FactorOf(store_axis);
if (factor > 0) {
store = store / Expr(factor);
store = indexdiv(store, Expr(factor));
}
} else {
store = store % store_axis_impl->dom->extent;
store = indexmod(store, store_axis_impl->dom->extent);
}
rule->push_back(store);
......
......@@ -206,6 +206,15 @@ Expr operator%(Expr a, Expr b) {
return truncmod(a, b);
}
// TODO(tqchen): switch to floordiv
Expr indexdiv(Expr a, Expr b) {
return truncdiv(a, b);
}
Expr indexmod(Expr a, Expr b) {
return truncmod(a, b);
}
Expr floordiv(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
......
......@@ -309,7 +309,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
if (op->loop_var.get() == inner) {
CHECK(under_outer);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent % op->extent;
rmap[op->loop_var.get()] = indexmod(parent, op->extent);
extent = op->extent;
fused = true;
return ir::Substitute(op->body, rmap);
......@@ -317,7 +317,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
under_outer = true;
Stmt body = IRMutator::Mutate(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent / extent;
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = ir::Substitute(body, rmap);
under_outer = false;
return For::make(parent->var, Expr(0), extent * op->extent,
......@@ -325,7 +325,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
} else if (under_outer) {
Stmt body = IRMutator::Mutate(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent / extent % op->extent;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = ir::Substitute(body, rmap);
extent = extent * op->extent;
return body;
......
......@@ -120,7 +120,8 @@ void ArgBinder::BindBuffer(const Buffer& arg,
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
BinderAddAssert(truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
}
......@@ -288,7 +289,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
Expr offset = buffer->elem_offset;
Expr factor = make_const(offset.type(), buffer->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
......
......@@ -18,8 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
......@@ -230,7 +228,7 @@ class DoubleBufferInjector : public IRMutator {
Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.type());
e.switch_read_var = e.loop->loop_var % two;
e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
in_double_buffer_scope_ = false;
......@@ -239,7 +237,7 @@ class DoubleBufferInjector : public IRMutator {
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = loop_shift % two;
vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
body = Substitute(body, vmap);
body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
body = IfThenElse::make(loop_shift < e.loop->extent, body);
......
......@@ -178,6 +178,24 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::Mutate_(op, e);
}
Expr Mutate_(const EQ* op, const Expr& e) final {
using namespace arith;
PVar<Expr> x, y;
if ((floormod(x, y) == 0).Match(e)) {
return Mutate((truncmod(x, y) == 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
}
Expr Mutate_(const NE* op, const Expr& e) final {
using namespace arith;
PVar<Expr> x, y;
if ((floormod(x, y) != 0).Match(e)) {
return Mutate((truncmod(x, y) != 0).Eval());
}
return IRMutatorWithAnalyzer::Mutate_(op, e);
}
private:
Expr SwapBroadcastCast(const Expr& e) {
// Try to change broadcast(cast(x)) to cast(broadcast(x))
......
......@@ -264,14 +264,15 @@ class WarpAccessRewriter : protected IRMutator {
// simple case, warp index is on the highest.
if (warp_group_ == 1) {
Expr x = analyzer_->canonical_simplify(index % m);
Expr z = analyzer_->canonical_simplify(index / m);
Expr x = analyzer_->canonical_simplify(indexmod(index, m));
Expr z = analyzer_->canonical_simplify(indexdiv(index, m));
return std::make_pair(x, z);
} else {
Expr x = analyzer_->canonical_simplify(index % m);
Expr x = analyzer_->canonical_simplify(indexmod(index, m));
Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_);
y = y * m + x;
Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m;
Expr z = indexdiv(indexmod(index, make_const(index.type(), warp_coeff_ * warp_size_)),
m);
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
}
......
......@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
if (dim < avec.size() && avec[dim].align_factor != 0) {
Expr factor = make_const(stride.type(), avec[dim].align_factor);
Expr offset = make_const(stride.type(), avec[dim].align_offset);
stride = stride + (factor + offset - stride % factor) % factor;
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
stride = ir::Simplify(stride);
}
rstrides.push_back(stride);
......
......@@ -610,8 +610,8 @@ class StoragePlanRewriter : public IRMutator {
}
// transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided = analyzer_.CanProve(combo_size % type_bits == 0);
combo_size = combo_size / type_bits;
bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
combo_size = indexdiv(combo_size, type_bits);
// round up when the size is not exactly divisible
if (!divided) {
combo_size = combo_size + make_const(Int(32), 1);
......
......@@ -66,12 +66,12 @@ bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (i == bit_axis) {
out_shape.push_back(bits);
if (i == pack_axis) {
out_shape.push_back(data->shape[i] / pack_bits);
out_shape.push_back(indexdiv(data->shape[i], pack_bits));
} else {
out_shape.push_back(data->shape[i]);
}
} else if (i == pack_axis) {
out_shape.push_back(data->shape[i] / pack_bits);
out_shape.push_back(indexdiv(data->shape[i], pack_bits));
} else {
out_shape.push_back(data->shape[i]);
}
......@@ -102,12 +102,12 @@ TVM_REGISTER_API("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack);
RELAY_REGISTER_OP("nn.bitpack")
.describe(R"code(Bitpack layer that prepares data for bitserial operations.
This layer packs the bits of an input into a single datatype, allowing
efficient implementation of bitserial operations.
- **data**: Input tensor of any shape, dimension that is to be
packed must be divisible by number of bits.
- **out**: Packed tensor with shape appropriately compressed.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.BitPackAttrs")
......@@ -183,7 +183,7 @@ on some platforms.
When data is NCHW, weight is expected to be OIHW or OIHWi.
When data is NHWC weight is expected to be HWIO or HWIOi.
- **out**: Output with same layout as input.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BinaryConv2DAttrs")
.set_num_inputs(2)
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -154,9 +154,9 @@ bool Conv2DTransposeRel(const Array<Type>& types,
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape({dshape_nchw[1],
param->channels / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
indexdiv(param->channels, param->groups),
param->kernel_size[0],
param->kernel_size[1]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
......@@ -184,7 +184,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[0]));
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
......@@ -738,7 +738,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape(
{param->channels,
data->shape[1] / param->groups,
indexdiv(data->shape[1], param->groups),
param->kernel_size[0],
param->kernel_size[1]});
channels = param->channels;
......@@ -767,7 +767,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
<< " channels=" << param->channels
<< " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
channels = wshape[0];
ksize_y = wshape[2];
ksize_x = wshape[3];
......@@ -777,8 +777,10 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
oshape.Set(2, (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
oshape.Set(2, indexdiv(data->shape[2] + param->padding[0] * 2 - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(data->shape[3] + param->padding[1] * 2 - dilated_ksize_x,
param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;
// infer offset shape
......
......@@ -74,10 +74,10 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (tvm::ir::Equal(param->channels, param->groups) && !tvm::ir::Equal(param->channels, 1)) {
// infer weight's shape for depthwise convolution
wshape = {{dshape_nchw[1], param->groups / dshape_nchw[1], param->kernel_size[0],
wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0],
param->kernel_size[1]}};
} else {
wshape = {{param->channels, dshape_nchw[1] / param->groups, param->kernel_size[0],
wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
param->kernel_size[1]}};
}
......@@ -108,7 +108,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
channels = wshape[0];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
......@@ -116,8 +116,10 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
......
......@@ -615,7 +615,7 @@ bool ReshapeRel(const Array<Type>& types,
if (d0.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d0 / d2);
oshape.push_back(indexdiv(d0, d2));
}
used_output_dims.insert(oshape.size());
oshape.push_back(d2);
......@@ -627,7 +627,7 @@ bool ReshapeRel(const Array<Type>& types,
if (d0.as<Any>()) {
oshape.push_back(Any::make());
} else {
oshape.push_back(d0 / d1);
oshape.push_back(indexdiv(d0, d1));
}
} else {
oshape.push_back(d2);
......@@ -659,7 +659,7 @@ bool ReshapeRel(const Array<Type>& types,
infer_dim = Any::make();
break;
}
infer_dim /= oshape[i];
infer_dim = indexdiv(infer_dim, oshape[i]);
}
}
oshape.Set(infer_idx, infer_dim);
......@@ -1987,13 +1987,13 @@ bool SplitRel(const Array<Type>& types,
<< "axis should be within the input dimension range.";
if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
CHECK(reporter->Assert(data->shape[axis] %
sections->value == make_zero(Int(64))))
CHECK(reporter->Assert(indexmod(data->shape[axis],
sections->value) == make_zero(Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
std::vector<Type> fields;
for (int i = 0; i < sections->value; ++i) {
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] /= int32_t(sections->value);
oshape[axis] = indexdiv(oshape[axis], sections->value);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -55,8 +55,8 @@ bool YoloReorgRel(const Array<Type>& types,
CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[1] = oshape[1] * param->stride * param->stride;
oshape[2] = oshape[2] / param->stride;
oshape[3] = oshape[3] / param->stride;
oshape[2] = indexdiv(oshape[2], param->stride);
oshape[3] = indexdiv(oshape[3], param->stride);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
......
......@@ -56,10 +56,10 @@ void PassDownDomain(const Stage& stage,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](Expr a, Expr b) {
if (actx->CanProve(a % b == 0)) {
return actx->Simplify(a / b);
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
}
return actx->Simplify((a + (b - 1)) / b);
return actx->Simplify(indexdiv(a + (b - 1), b));
};
auto& state = *p_state;
......@@ -146,8 +146,8 @@ void PassUpIndex(const Stage& stage,
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
state[s->outer] = indexdiv(value, factor);
state[s->inner] = indexmod(value, factor);
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
......@@ -190,8 +190,8 @@ void PassDownIndex(const Stage& stage,
CHECK(is_zero(r->min));
Expr parent = state.at(s->parent);
Expr factor = r->extent;
state[s->outer] = parent / factor;
state[s->inner] = parent % factor;
state[s->outer] = indexdiv(parent, factor);
state[s->inner] = indexmod(parent, factor);
} else if (const FuseNode* s = rel.as<FuseNode>()) {
if (!state.count(s->inner) && !state.count(s->outer)) {
CHECK(allow_missing);
......@@ -266,8 +266,8 @@ void PassUpDomain(const FuseNode* s,
if (fused.is_single_point()) {
Expr value = fused.point_value();
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
Expr v_outer = indexdiv(value, factor);
Expr v_inner = indexmod(value, factor);
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
......@@ -275,17 +275,18 @@ void PassUpDomain(const FuseNode* s,
} else {
Expr fused_extent = (fused.max() - fused.min() + 1);
Expr inner_extent = dom_map.at(s->inner)->extent;
*outer = IntSet::interval(outer_min + fused.min() / inner_extent,
outer_min + fused.max() / inner_extent);
if (is_zero(Simplify(inner_extent % fused_extent)) &&
is_zero(Simplify(fused.min() % fused_extent)) ) {
*outer = IntSet::interval(
outer_min + indexdiv(fused.min(), inner_extent),
outer_min + indexdiv(fused.max(), inner_extent));
if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
// fused never spans multiple rows, make a tight bounding box
// there may be other cases when bounding box could be tightened
*inner = IntSet::interval(inner_min + fused.min() % inner_extent,
inner_min + fused.max() % inner_extent);
*inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
inner_min + indexmod(fused.max(), inner_extent));
} else { // fused may span multiple rows, use full row widths
if (!is_zero(Simplify(fused_extent % inner_extent)) ||
!is_zero(Simplify(fused.min() % inner_extent))) {
if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
!is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
LOG(WARNING) <<
"fused and original axes are not aligned, this may cause redundant computations";
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -193,8 +193,7 @@ TOPI_DEFINE_OP_OVERLOAD(operator*, multiply);
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(divide, { return a / b; });
TOPI_DEFINE_OP_OVERLOAD(operator/, divide);
TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); });
/*!
* \fn mod
......@@ -207,8 +206,7 @@ TOPI_DEFINE_OP_OVERLOAD(operator/, divide);
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(mod, { return a % b; });
TOPI_DEFINE_OP_OVERLOAD(operator%, mod);
TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); });
/*!
* \fn maximum
......
......@@ -47,8 +47,8 @@ inline Array<Expr> GetPadTuple(Expr pad_h, Expr pad_w) {
pad_h *= 2;
pad_w *= 2;
auto pad_top = (pad_h + 1) / 2;
auto pad_left = (pad_w + 1) / 2;
auto pad_top = indexdiv(pad_h + 1, 2);
auto pad_left = indexdiv(pad_w + 1, 2);
return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left };
}
......
......@@ -68,8 +68,8 @@ inline Array<Expr> UnravelIndex(Expr idx, Array<Expr> shape) {
std::vector<Expr> indices;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
indices.push_back(idx % shape[i]);
idx = idx / shape[i];
indices.push_back(indexmod(idx, shape[i]));
idx = indexdiv(idx, shape[i]);
}
std::reverse(indices.begin(), indices.end());
return indices;
......
......@@ -288,10 +288,10 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
auto pH = I->shape[2];
auto pW = I->shape[3];
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
W->shape[0], // O
(I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1, // H
(I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1 // W
I->shape[0], // B
W->shape[0], // O
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
......@@ -339,8 +339,8 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
auto pH = I->shape[2];
auto pW = I->shape[3];
tvm::Array<tvm::Expr> output_shape{
(I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1, // H
(I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1, // W
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W
I->shape[2], // B
W->shape[3] // O
};
......@@ -393,8 +393,8 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
W->shape[1], // O
(I->shape[2] - W->shape[2] + 2 * pad_h) / stride_h + 1, // H
(I->shape[3] - W->shape[3] + 2 * pad_w) / stride_w + 1 // W
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
......@@ -403,8 +403,8 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
? I
: pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
return tvm::sum(T(b, i / pCM, stride_h * h + kh, stride_w * w + kw) *
W(i / pCM, o % pCM, kh, kw),
return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
{i, kh, kw});
};
return tvm::compute(output_shape, l, name, tag);
......@@ -425,8 +425,8 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
auto pCM = W->shape[1]; // channel_multiplier
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
(I->shape[1] - W->shape[1] + 2 * pad_h) / stride_h + 1, // H
(I->shape[2] - W->shape[2] + 2 * pad_w) / stride_w + 1, // W
indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W
W->shape[3], // O
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
......@@ -436,8 +436,8 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
? I
: pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)});
auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, i / pCM) *
W(kh, kw, i / pCM, o % pCM),
return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
{kh, kw, i});
};
return tvm::compute(output_shape, l, name, tag);
......@@ -479,8 +479,8 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
I->shape[0], // B
I->shape[1], // G
W->shape[2], // O
(I->shape[3] - W->shape[3] + 2 * pad_h) / stride_h + 1, // H
(I->shape[4] - W->shape[4] + 2 * pad_w) / stride_w + 1 // W
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W
};
auto i = tvm::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
......
......@@ -58,7 +58,7 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
Array<Expr> oshape;
for (size_t i = 0; i < n; ++i) {
oshape.push_back(i == static_cast<size_t>(axis) ?
tvm::ir::Simplify(ishape[i] / 32) :
tvm::ir::Simplify(indexdiv(ishape[i], 32)) :
ishape[i]);
}
......
......@@ -89,8 +89,8 @@ inline Tensor dilate(const Tensor& x,
if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
index_tuple.push_back(indices[i]);
} else {
index_tuple.push_back(indices[i] / strides[i]);
not_zero.push_back((indices[i] % strides[i]) == 0);
index_tuple.push_back(indexdiv(indices[i], strides[i]));
not_zero.push_back((indexmod(indices[i], strides[i])) == 0);
}
}
if (not_zero.size() > 0) {
......
......@@ -70,8 +70,8 @@ inline Tensor flatten(const Tensor& x,
Expr idx = j;
std::vector<Expr> index;
for (auto s : extra_shape) {
index.push_back(idx % s);
idx = idx / s;
index.push_back(indexmod(idx, s));
idx = indexdiv(idx, s);
}
index.push_back(i);
std::reverse(index.begin(), index.end());
......
......@@ -85,7 +85,7 @@ inline Tensor lrn(const Tensor& data,
input_shape,
[&](Var i, Var j, Var k, Var l) {
return tvm::pow(bias +
(alpha * sqr_sum(i, j, k, l) / size),
(div(alpha * sqr_sum(i, j, k, l), size)),
beta);
});
return topi::divide(data, sqrt_sum_up);
......
......@@ -102,9 +102,9 @@ inline Tensor pool_impl(const Tensor& x,
pad_after.Set(width_axis, pad_right);
auto out_height = tvm::ir::Simplify(
(height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1);
auto out_width = tvm::ir::Simplify(
(width - kernel_width + pad_left + pad_right) / stride_width + 1);
indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1);
auto dheight = tvm::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
......@@ -149,7 +149,7 @@ inline Tensor pool_impl(const Tensor& x,
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
if (count_include_pad) {
return pool_sum(indices) / (kernel_height * kernel_width);
return div(pool_sum(indices), (kernel_height * kernel_width));
} else {
Expr h_start = output[height_axis] * stride_height - pad_top;
Expr w_start = output[width_axis] * stride_width - pad_left;
......@@ -159,7 +159,7 @@ inline Tensor pool_impl(const Tensor& x,
w_start = ir::Max::make(w_start, make_const(Int(32), 0));
Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
make_const(Int(32), 1));
return pool_sum(indices) / divide_factor;
return div(pool_sum(indices), divide_factor);
}
}, "tensor", kElementWise);
} else {
......@@ -439,14 +439,14 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<Exp
inline Expr start_index(const Var& out_index,
const Expr& odim,
const Expr& idim) {
return out_index * idim / odim;
return indexdiv(out_index * idim, odim);
}
inline Expr end_index(const Var& out_index,
const Expr& odim,
const Expr& idim) {
Expr tmp = (out_index + 1) * idim / odim;
return tvm::ir::Select::make((out_index + 1) * idim % odim == 0,
Expr tmp = indexdiv((out_index + 1) * idim, odim);
return tvm::ir::Select::make(indexmod((out_index + 1) * idim, odim) == 0,
tmp, tmp + 1);
}
......@@ -505,7 +505,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
indices.Set(height_axis, i_start_h + dheight);
indices.Set(width_axis, i_start_w + dwidth);
return tvm::sum(x(indices) / divide_factor, { dheight, dwidth });
return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
}, "tensor", "adaptive_pool_avg");
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
......
......@@ -658,7 +658,7 @@ inline Tensor take(const Tensor& a,
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
auto idx = (indices(out_index) % a_size + a_size) % a_size;
auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
}
......@@ -787,7 +787,7 @@ inline Tensor take(const Tensor& a,
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
......@@ -888,7 +888,7 @@ inline Tensor repeat(const Tensor& x,
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
idx.push_back(indices[i]);
}
idx.push_back(indices[axis] / repeats);
idx.push_back(indexdiv(indices[axis], repeats));
for (size_t i = axis + 1; i < indices.size(); ++i) {
idx.push_back(indices[i]);
}
......@@ -944,10 +944,10 @@ inline Tensor tile(const Tensor& x,
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[i] % x->shape[i]);
idx.push_back(indexmod(indices[i], x->shape[i]));
} else {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[rdim - ndim + i] % x->shape[i]);
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
return x(idx);
}, name, tag);
......@@ -1253,7 +1253,7 @@ inline Tensor ndarray_size(const Tensor& src,
}
/*!
* \brief Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value.
* \param indices locations to set to on_value.
* \param on_value value that locations represented by indices take on.
......
......@@ -64,9 +64,9 @@ inline Tensor reorg(const Tensor &data,
auto out = tvm::compute(input_shape,
[&](Var b, Var k, Var j, Var i) {
return data(b * stride * stride,
(k % out_c) * stride * stride,
(j*stride + (k / out_c) / stride) * stride,
(i*stride + (k / out_c) % stride));
indexmod(k, out_c) * stride * stride,
(j*stride + indexdiv(indexdiv(k, out_c), stride)) * stride,
(i*stride + indexmod(indexdiv(k, out_c), stride)));
},
name,
tag);
......