Unverified Commit 2ded2d8c by Tianqi Chen Committed by GitHub

[ARITH] Use explicit div mode in python. (#4014)

parent 16bed7e6
......@@ -35,6 +35,13 @@ The user facing API for computation declaration.
tvm.thread_axis
tvm.comm_reducer
tvm.sum
tvm.div
tvm.indexdiv
tvm.indexmod
tvm.truncdiv
tvm.truncmod
tvm.floordiv
tvm.floormod
tvm.min
tvm.max
tvm.tag_scope
......@@ -53,6 +60,13 @@ The user facing API for computation declaration.
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.sum
.. autofunction:: tvm.div
.. autofunction:: tvm.indexdiv
.. autofunction:: tvm.indexmod
.. autofunction:: tvm.truncdiv
.. autofunction:: tvm.truncmod
.. autofunction:: tvm.floordiv
.. autofunction:: tvm.floormod
.. autofunction:: tvm.min
.. autofunction:: tvm.max
.. autofunction:: tvm.tag_scope
......@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
reducer.__doc__ = doc_str.format(name)
return reducer
def div(a, b):
"""Compute a / b as in C/C++ semantics.
Parameters
----------
a : Expr
The left hand operand.
b : Expr
The right hand operand.
Returns
-------
res : Expr
The result expression.
Note
----
When operands are integers, returns truncdiv(a, b).
"""
return _make._OpDiv(a, b)
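As a quick illustration of the modes this commit makes explicit, here is a
minimal sketch (assuming a standard "import tvm" at this revision; the
constant operands fold immediately, so the values are shown as comments):

    import tvm

    a = tvm.const(-7, "int32")
    b = tvm.const(4, "int32")
    # div follows C/C++ semantics: for integers it truncates toward zero.
    print(tvm.div(a, b))       # -1
    # floordiv rounds toward negative infinity instead.
    print(tvm.floordiv(a, b))  # -2
    # The matching remainders keep a == b * q + r in each mode.
    print(tvm.truncmod(a, b))  # -3
    print(tvm.floormod(a, b))  # 1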
def indexdiv(a, b):
"""Compute floor(a / b) where a and b are non-negative.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of the operands'
non-negativity.
"""
return _make._OpIndexDiv(a, b)
def indexmod(a, b):
"""Compute the remainder of indexdiv. a and b are non-negative.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of the operands'
non-negativity.
"""
return _make._OpIndexMod(a, b)
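The typical use of indexdiv/indexmod, as the notes above suggest, is splitting
a flattened buffer index back into coordinates. A hedged sketch (the names W,
i, row, col are illustrative, not from the patch):

    import tvm

    W = 16                    # hypothetical inner-row width
    i = tvm.var("i")          # flattened index, assumed non-negative
    row = tvm.indexdiv(i, W)  # floor semantics; may exploit i >= 0
    col = tvm.indexmod(i, W)
    # row * W + col reconstructs i for any non-negative i.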
def truncdiv(a, b):
"""Compute the truncdiv of two expressions.
......
......@@ -101,8 +101,11 @@ def convolution_inference(
assert isinstance(stride, list) and len(stride) == 2
batch, _, input_height, input_width = data.shape
output_channels, _, kernel_height, kernel_width = kernel.shape
output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
idxdiv = _api.indexdiv
output_height = idxdiv(
input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
output_width = idxdiv(
input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
return _api.extern(
(batch, output_channels, output_height, output_width),
......@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform(
batch, _, input_height, input_width = data.shape
output_channels, _, _, _ = transformed_kernel.shape
kernel_height, kernel_width = (3, 3)
output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
idxdiv = _api.indexdiv
output_height = idxdiv(input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
output_width = idxdiv(input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
return _api.extern(
(batch, output_channels, output_height, output_width),
......
......@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node.
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
from . import _api_internal
def div_ambiguity_error():
return RuntimeError(
"TVM supports multiple types of integer divisions, " +
"please call div, indexdiv/indexmod, floordiv/floormod " +
" or truncdiv/truncmod directly to avoid ambiguity in the code.")
def _dtype_is_int(value):
if isinstance(value, int):
return True
return (isinstance(value, ExprOp) and
TVMType(value.dtype).type_code == TypeCode.INT)
class ExprOp(object):
def __add__(self, other):
return _generic.add(self, other)
......@@ -58,24 +72,35 @@ class ExprOp(object):
return _generic.multiply(other, self)
def __div__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(self, other)
def __rdiv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(other, self)
def __truediv__(self, other):
return self.__div__(other)
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(self, other)
def __rtruediv__(self, other):
return self.__rdiv__(other)
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(other, self)
def __floordiv__(self, other):
return self.__div__(other)
# return _generic.floordiv(self, other)
return _generic.divide(self, other)
def __rfloordiv__(self, other):
return self.__rdiv__(other)
# return _generic.floordiv(other, self)
return _generic.divide(other, self)
def __mod__(self, other):
# raise div_ambiguity_error()
return _make._OpMod(self, other)
def __neg__(self):
......
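The commented-out div_ambiguity_error calls stage a later behavior change:
once enabled, / and % on two integer expressions will raise, forcing callers
to name a division mode. Until then the overloads keep routing to the old
truncation-based _OpDiv/_OpMod, so code that cares should already spell the
mode out. A hedged sketch of the two spellings:

    import tvm

    x = tvm.var("x")
    d1 = x / 3               # still accepted: routed through _OpDiv
    d2 = tvm.truncdiv(x, 3)  # explicit, survives the planned restriction
    q = tvm.indexdiv(x, 3)   # preferred for non-negative index math
    r = tvm.indexmod(x, 3)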
......@@ -25,6 +25,7 @@ from . import make as _make
# Operator precedence used when overloading.
__op_priority__ = 0
def add(lhs, rhs):
"""Generic add operator.
......@@ -78,7 +79,6 @@ def multiply(lhs, rhs):
"""
return _make._OpMul(lhs, rhs)
def divide(lhs, rhs):
"""Generic divide operator.
......@@ -96,6 +96,23 @@ def divide(lhs, rhs):
"""
return _make._OpDiv(lhs, rhs)
def floordiv(lhs, rhs):
"""Generic floordiv operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of floordiv operation.
"""
return _make._OpFloorDiv(lhs, rhs)
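Like its siblings in this module, floordiv accepts a mix of Python numbers
and Exprs; plain numbers are converted to const Exprs on the way into the
C++ _OpFloorDiv call. A hedged example:

    import tvm
    from tvm import generic

    n = tvm.var("n")
    e = generic.floordiv(n + 3, 4)  # mixed Expr/int operands are fine
    # e is a FloorDiv node over the promoted operands.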
def cast(src, dtype):
"""Generic cast operator.
......
......@@ -31,6 +31,7 @@ from . import util
from .preprocessor import determine_variable_usage
from ..api import all as _all
from ..api import any as _any
from ..container import Array
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
......@@ -78,6 +79,18 @@ class Symbol(Enum):
ThreadBind = 10
def _floordiv(x, y):
if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
return _api.floordiv(x, y)
return operator.floordiv(x, y)
def _floormod(x, y):
if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
return _api.floormod(x, y)
return operator.mod(x, y)
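These wrappers let the hybrid frontend dispatch on operand type: plain Python
ints keep Python's own floor semantics, while any Expr operand produces a
FloorDiv/FloorMod node. A hedged sketch:

    import operator
    import tvm

    assert operator.floordiv(7, 2) == 3   # pure-Python path
    x = tvm.var("x")
    e = tvm.floordiv(x, 2)                # what _floordiv(x, 2) builds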
class HybridParser(ast.NodeVisitor):
"""Python AST visitor pass which finally lowers it to HalideIR"""
......@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor):
ast.Sub : operator.sub,
ast.Mult : operator.mul,
ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.Mod : operator.mod,
ast.FloorDiv: _floordiv,
ast.Mod : _floormod,
ast.BitOr : operator.or_,
ast.BitAnd : operator.and_,
ast.BitXor : operator.xor,
......
......@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
@script
def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64")
out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0])))
out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
return out
@_reg.register_shape_func("arange", True)
......@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim):
assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i+1] == -1:
assert newshape[i+2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2])
out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
out[dst_idx+1] = int64(newshape[i+2])
else:
out[dst_idx] = int64(newshape[i+1])
if newshape[i+2] == -1:
out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1])
out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
else:
out[dst_idx+1] = int64(newshape[i+2])
assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
......@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim):
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size / new_size
out[infer_idx] = old_size // new_size
return out
@_reg.register_shape_func("reshape", False)
......
......@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
if (op->type.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) { // NOLINT(*)
if (op->type.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
void CodeGenHybrid::VisitExpr_(const FloorMod *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -100,6 +100,8 @@ class CodeGenHybrid :
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
......@@ -161,12 +163,12 @@ class CodeGenHybrid :
std::string GetUniqueName(std::string prefix);
/*! \brief The output code string builder. */
std::stringstream stream;
/*!
/*!
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
*/
std::string GetVarID(const Variable *v);
/*!
/*!
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
* \param value_index The value index of the given tensor.
......
......@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) {
}
Expr floordiv(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
if (ret.defined()) return ret;
......@@ -223,6 +225,8 @@ Expr floordiv(Expr a, Expr b) {
}
Expr floormod(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
if (ret.defined()) return ret;
......
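The new CHECKs reject floating-point operands when the node is constructed,
matching the Python-side intent that floordiv/floormod are integer index
operators. A hedged Python-side sketch:

    import tvm

    i = tvm.var("i")             # int32 by default
    ok = tvm.floordiv(i, 8)      # integer operands pass the CHECK
    f = tvm.var("f", dtype="float32")
    # tvm.floordiv(f, 8.0) would now fail the is_int/is_uint CHECK;
    # floating-point division goes through tvm.div instead.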
......@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (op == nullptr) return ret;
int shift;
const DataType& dtype = op->type;
if (dtype.is_float()) {
return floor(Div::make(op->a, op->b));
}
CHECK(dtype.is_int() || !dtype.is_uint());
if (is_const_power_of_two_integer(op->b, &shift)) {
......
......@@ -33,9 +33,11 @@ def test_mul_sum_simplify():
x * 13 + z * 4 + y * 4 +6)
ck.verify(x * 3 - 4 * x + 1, 1 - x)
ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# truncdiv
ck.verify((x + y + x + y * 3) / 2, y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0)
ck.verify(tdiv(x + y + x + y * 3, 2), y * 2 + x)
ck.verify(tmod(x + y + x + y * 3, 2), 0)
# floordiv
fld = tvm.floordiv
......@@ -51,28 +53,31 @@ def test_split_index_simplify():
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# truncdiv
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# split div const
ck.verify((x/3) *3 + x % 3, x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x)
ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4)
ck.verify((x % 2) / 8, 0)
ck.verify((x % 2) / 7, 0)
ck.verify(((x % 16) / 2) * 2 / 6, (x % 16) / 6)
ck.verify(tdiv(x, 3) *3 + tmod(x, 3), x)
ck.verify(tdiv(x, 6) * 6 + tmod(tdiv(x, 3), 2) * 3 + tmod(x, 3), x)
ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 4), tdiv(tmod(x, 16), 4))
ck.verify(tdiv(tmod(x, 2), 8), 0)
ck.verify(tdiv(tmod(x, 2), 7), 0)
ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 6), tdiv(tmod(x, 16), 6))
# split mod const
ck.verify((x * 8) % 16, (x % 2) * 8)
ck.verify((x * 8) % 2, 0)
ck.verify(tmod((x * 8), 16), tmod(x, 2) * 8)
ck.verify(tmod(x * 8, 2), 0)
# simplify then fold
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
# complex fold
ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y)
ck.verify(tdiv(z * 9 + y, 2) * 2 + tmod(z * 9 + y, 2), z * 9 + y)
ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y)
ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
# floordiv
fld = tvm.floordiv
......@@ -85,23 +90,24 @@ def test_split_index_simplify():
ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
# cannot simplify mixed case, unless we canonicalize into one mode.
ck.verify((x/6) * 2 + fld(x,3) % 2, (x/6) * 2 + fld(x,3) % 2)
ck.verify(tdiv(x,6) * 2 + tmod(fld(x,3), 2), tdiv(x,6) * 2 + tmod(fld(x,3), 2))
def test_div_simplify():
ck = CanonicalChecker()
x = tvm.var("x")
tdiv = tvm.truncdiv
# truncdiv
ck.verify((16+48*x)/16, x*3 + 1)
ck.verify(tdiv(16+48*x,16), x*3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
ck.verify((17+48*x)/16, (x * 48 + 17) / 16)
ck.verify(tdiv(17 + 48 * x, 16), tdiv(x * 48 + 17, 16))
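# e.g. x = -1: truncdiv(17 - 48, 16) = truncdiv(-31, 16) = -1,
# while 1 + 3*(-1) = -2, so folding here would change the result.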
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
ck.verify((17+48*x)/16, x * 3 + 1)
ck.verify(tdiv(17 + 48 * x, 16), x * 3 + 1)
# Trying expressions that are not simplifiable for any values of the variables
ck.verify((17+47*x)/16, (x * 47 + 17) / 16)
ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))
# floordiv
fld = tvm.floordiv
......@@ -124,8 +130,10 @@ def test_canonical_mixed():
ck = CanonicalChecker()
x = tvm.var("x")
z = tvm.const(3, "int32")
ck.verify(x / (z*z) - x / (z*z), 0)
ck.verify(x / (z+z) - x / (z+z), 0)
tdiv = tvm.truncdiv
tmod = tvm.truncmod
ck.verify(tdiv(x, (z*z)) - tdiv(x, (z*z)), 0)
ck.verify(tdiv(x, (z+z)) - tdiv(x, (z+z)), 0)
ck.verify(x - 2 < 3, x < 5)
ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0)
ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
......@@ -207,42 +215,44 @@ def test_reduce_simplify():
tvm.sum(k + j, [k, j]))
ck.verify(tvm.sum(A[3], []), A[3])
# The rule below is not typical, removed for now
ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k))
ck.verify(tvm.sum(tvm.div(k, 10), k), tvm.sum(tvm.const(0, "int32"), k))
def test_simplify_if_then_else():
ck = CanonicalChecker()
x = tvm.var("x")
y = tvm.var("y")
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# simplification that takes condition into account.
res = tvm.if_then_else((x * 4 + y) >= 466036,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16,
tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
x), y)
res2 = tvm.if_then_else((x * 4) >= 466036 - y,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16,
tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
x), y)
expected = tvm.if_then_else(
tvm.expr.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)),
(((x*4) + y) - 4) % 16,
tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)),
tmod(((x*4) + y) - 4, 16),
x), y)
ck.verify(res, expected)
ck.verify(res2, expected)
# can only simplify if condition
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3)
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3)
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
ck.verify(res, ck.analyzer.canonical_simplify(expected))
res = tvm.expr.Select(x >= 10,
tvm.if_then_else(x / 3 > 2, x, 0), 0)
tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
expected = tvm.expr.Select(x >= 10, x, 0)
ck.verify(res, ck.analyzer.canonical_simplify(expected))
res = tvm.expr.Select(x >= 10,
tvm.if_then_else(x / 3 < 2, x, 0), 0)
tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
ck.verify(res, 0)
......@@ -250,20 +260,20 @@ def test_complex_cases():
ck = CanonicalChecker()
x = tvm.var("x")
y = tvm.var("y")
res2 = (((((((((((x*128) + y) % 1296)/36)*2) + 1)/2)*36) +
((((((x*128) + y) % 36)*2) + 1)/2))
- (((x*128) + y) % 1296)) + 1)
tdiv = tvm.truncdiv
tmod = tvm.truncmod
res2 = (tdiv(tdiv(tmod(x*128 + y, 1296),36)*2 + 1,2)*36 +
tdiv(tmod((x*128) + y, 36)*2 + 1,2)
- tmod((x*128) + y, 1296) + 1)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
ck.verify(res2, 1)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
res3 = ((((((((((x*1024) + y)/65536) + ((((x*1024) + y) % 65536)/256))
+ ((((x*1024) + y) % 256)/16)) + (((x*1024) + y) % 16)) - (y/256)) -
((y % 256)/16)) - (y % 16)) - (x*4))
ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4))
res3 = (tdiv(x*1024 + y,65536) + tdiv(tmod(x*1024 + y, 65536),256)
+ tdiv(tmod(x*1024 + y, 256),16) + tmod(x*1024 + y, 16) - tdiv(y,256) -
tdiv(tmod(y, 256),16) - tmod(y, 16) - (x*4))
ck.verify(res3, tdiv((x*1024) + y, 256) - tdiv(y,256) - (x*4))
if __name__ == "__main__":
......
......@@ -38,12 +38,13 @@ def test_dtype_bound():
def test_cast_bound():
analyzer = tvm.arith.Analyzer()
x = tvm.var("x", dtype="int8")
bd = analyzer.const_int_bound((x % 3).astype("uint32"))
tmod = tvm.truncmod
bd = analyzer.const_int_bound(tmod(x, 3).astype("uint32"))
assert bd.min_value == 0
assert bd.max_value == 2
bd = analyzer.const_int_bound(
(x % 3).astype("float32").astype("int32"))
tmod(x, 3).astype("float32").astype("int32"))
assert bd.min_value == -2
assert bd.max_value == 2
......@@ -98,47 +99,50 @@ def test_mul_bound():
assert bd.max_value == bd.POS_INF
def test_div_bound():
def test_truncdiv_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
tdiv = tvm.truncdiv
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(x / y)
bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == -2
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
bd = analyzer.const_int_bound(x / y)
bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == -4
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
bd = analyzer.const_int_bound(x / y)
bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF
def test_mod_bound():
def test_truncmod_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
tmod = tvm.truncmod
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(x % y)
bd = analyzer.const_int_bound(tmod(x, y))
assert bd.min_value == -9
assert bd.max_value == 4
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(x % y)
bd = analyzer.const_int_bound(tmod(x, y))
assert bd.min_value == -9
assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(x % y)
bd = analyzer.const_int_bound(tmod(x, y))
assert bd.min_value == 0
assert bd.max_value == 9
......@@ -253,9 +257,12 @@ def test_shift_and_bound():
def test_mix_index_bound():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
tdiv = tvm.truncdiv
tmod = tvm.truncmod
analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1))
analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1))
bd = analyzer.const_int_bound((x % 8) + (x / 8) * 8)
bd = analyzer.const_int_bound(tmod(x, 8) + tdiv(x, 8) * 8)
assert bd.min_value == 0
assert bd.max_value == 24 - 1
......@@ -263,7 +270,7 @@ def test_mix_index_bound():
assert bd.min_value == 0
assert bd.max_value == 24 * 3 - 1
bd = analyzer.const_int_bound((x % 7) + (x / 7) * 7)
bd = analyzer.const_int_bound(tmod(x, 7) + tdiv(x, 7) * 7)
assert bd.min_value == 0
assert bd.max_value == (23 // 7) * 7 + 6
......@@ -273,8 +280,8 @@ if __name__ == "__main__":
test_cast_bound()
test_add_sub_bound()
test_mul_bound()
test_div_bound()
test_mod_bound()
test_truncdiv_bound()
test_truncmod_bound()
test_floordiv_bound()
test_floormod_bound()
test_min_max_bound()
......
......@@ -35,9 +35,11 @@ def test_deduce():
d_s = tvm.arith.IntervalSet(-3, -1)
zero = tvm.const(0, "int32")
tdiv = tvm.truncdiv
e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d - c) /(b*-1) + (-1))
ans0 = (tdiv(d - c, b*-1) + (-1))
assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
......@@ -46,7 +48,7 @@ def test_deduce():
e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d-c)/d - 1)
ans0 = (tdiv(d-c,d) - 1)
assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
......@@ -56,7 +58,7 @@ def test_deduce():
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4 -1)
ans1 = (tdiv((c - b) + -1,4) -1)
assert_expr_equal(res1.max_value, ans1)
......@@ -79,7 +81,7 @@ def test_deduce():
e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
ans3 = tdiv(2,c)+1
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
......
......@@ -60,13 +60,14 @@ def test_add_sub():
def test_mul_div():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
tdiv = tvm.truncdiv
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
ck.verify(tdiv(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
ck.verify(tdiv(x, 2), {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
fld = tvm.floordiv
ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
......@@ -76,9 +77,10 @@ def test_mul_div():
def test_mod():
ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y")
tmod = tvm.truncmod
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
ck.verify(tmod(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(tmod(x, 10), {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
flm = tvm.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
......
......@@ -54,7 +54,8 @@ def test_div_shift():
analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y")
# not sure if x is non-negative
m = analyzer.modular_set((x * 4 + 2) / 2)
tdiv = tvm.truncdiv
m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
assert m.coeff == 1
assert m.base == 0
# right shift always round down so it is fine
......@@ -67,7 +68,7 @@ def test_div_shift():
assert m.base == 1
# x is non-negative
analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
m = analyzer.modular_set((x * 4 + 2) / 2)
m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
assert m.coeff == 2
assert m.base == 1
......@@ -92,6 +93,7 @@ def test_mix_index():
a = tvm.var("a")
b = tvm.var("b")
analyzer = tvm.arith.Analyzer()
tdiv = tvm.truncdiv
m = analyzer.modular_set(a * 4 + b * 6 + 7)
assert m.coeff == 2
assert m.base == 1
......@@ -100,11 +102,11 @@ def test_mix_index():
assert m.coeff == 4
assert m.base == 3
m = analyzer.modular_set((a * 4 + 1) / (b * 8 + 3))
m = analyzer.modular_set(tdiv(a * 4 + 1, b * 8 + 3))
assert m.coeff == 1
assert m.base == 0
m = analyzer.modular_set((a * 4 + 1) * (b * 8 / 4))
m = analyzer.modular_set((a * 4 + 1) * tdiv(b * 8, 4))
assert m.coeff == 2
assert m.base == 0
......@@ -121,11 +123,13 @@ def test_constraint_scope():
a = tvm.var("a")
b = tvm.var("b")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(b % 4 == 2):
tmod = tvm.truncmod
with analyzer.constraint_scope(tmod(b, 4) == 2):
m = analyzer.modular_set(b + 1)
assert m.coeff == 4
assert m.base == 3
with analyzer.constraint_scope(a % 2 == 1):
with analyzer.constraint_scope(tmod(a, 2) == 1):
m = analyzer.modular_set(b + a * 2)
assert m.coeff == 4
assert m.base == 0
......@@ -140,15 +144,16 @@ def test_constraint_scope():
def test_intersect():
a = tvm.var("a")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(a % 4 == 1):
with analyzer.constraint_scope(a % 3 == 1):
tmod = tvm.truncmod
with analyzer.constraint_scope(tmod(a, 4) == 1):
with analyzer.constraint_scope(tmod(a, 3) == 1):
m = analyzer.modular_set(a)
assert m.coeff == 12
assert m.base == 1
with analyzer.constraint_scope(a % 3 == 2):
with analyzer.constraint_scope(a % 5 == 3):
with analyzer.constraint_scope(a % 7 == 2):
with analyzer.constraint_scope(tmod(a, 3) == 2):
with analyzer.constraint_scope(tmod(a, 5) == 3):
with analyzer.constraint_scope(tmod(a, 7) == 2):
m = analyzer.modular_set(a)
assert m.coeff == 105
assert m.base == 23
......
......@@ -60,11 +60,14 @@ def test_pack_gemm():
k = tvm.reduce_axis((0, L))
bn = 4
fld = tvm.floordiv
flm = tvm.floormod
A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn])
C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M
......@@ -119,9 +122,11 @@ def test_average_pool():
OH = (H - KH) + 1
OW = (W - KW) + 1
C = tvm.compute(
(N, CO, OH, OW),
lambda n, co, h, w: tvm.sum(D[n][co][h + kh][w + kw].astype(acc_dtype) / (KW * KH), axis=[kh, kw]))
lambda n, co, h, w: tvm.sum(
tvm.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)), axis=[kh, kw]))
s = tvm.create_schedule([C.op])
......
......@@ -35,7 +35,7 @@ def test_lower_rfactor():
def test_dependent_output_shape():
n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
A = tvm.placeholder((n, m))
B = tvm.compute((m, n/x), lambda i, j: A[i,j] , name='B')
B = tvm.compute((m, n//x), lambda i, j: A[i,j] , name='B')
s = tvm.create_schedule(B.op)
mod = tvm.build(s, [A, B, x])
......
......@@ -409,7 +409,7 @@ def test_llvm_div():
"""Check that the semantics of div and mod is the same as in C/C++"""
def check_div(start, end, divisor, dtype):
T = tvm.compute((end - start,),
lambda i: tvm.expr.Cast(dtype, (start + i)) / tvm.const(divisor, dtype))
lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype)
......@@ -418,8 +418,9 @@ def test_llvm_div():
tvm.testing.assert_allclose(a.asnumpy(), ref)
def check_mod(start, end, divisor, dtype):
tmod = tvm.truncmod
T = tvm.compute((end - start,),
lambda i: tvm.expr.Cast(dtype, (start + i)) % tvm.const(divisor, dtype))
lambda i: tmod(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype)
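For reference, the C/C++ semantics this test pins down differ from Python's
own operators, which floor. A hedged comparison:

    q, r = -7 // 2, -7 % 2   # Python floors: (-4, 1)
    assert (q, r) == (-4, 1)
    # C/C++ (and tvm.div/tvm.truncmod on integer dtypes) truncate: (-3, -1).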
......@@ -443,7 +444,7 @@ def test_llvm_div():
def test_llvm_fp_math():
def check_llvm_reciprocal(n):
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: 1.0/(1e+37*A[i]), name='B')
B = tvm.compute((n,), lambda i: tvm.div(1.0,(1e+37*A[i])), name='B')
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
......
......@@ -41,8 +41,9 @@ def test_if():
ib = tvm.ir_builder.create()
n = tvm.var("n")
A = ib.pointer("float32", name="A")
tmod = tvm.truncmod
with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2) == 0):
with ib.if_scope(tmod(i, 2) == 0):
A[i] = A[i] + 1
with ib.else_scope():
A[0] = A[i] + 2
......@@ -108,13 +109,14 @@ def test_gpu():
dtype = "float32"
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
fld = tvm.floordiv
def test_device_ir(A, B, C):
n = A.shape[0]
max_threads = 32
ib = tvm.ir_builder.create()
bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", (n+max_threads-1) // max_threads)
ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
ib.scope_attr(tx, "thread_extent", max_threads)
idx = bx.var * max_threads + tx.var
Aptr = ib.buffer_ptr(A)
......
......@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod():
def assert_simplified_equal(index_simplified, index_direct):
assert tvm.ir_pass.Equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# Test Case1
index_simplified = A_stride.vload(((k0 % k1) / s, (k0 % k1) % s + (k0 / k1) * k1))
index_simplified = A_stride.vload(
(idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
index_direct = A_stride.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case2
index_simplified = A.vload(((k0 % (k1 / s)) / n,
(k0 % (k1 / s)) % n + (k0 % k1)))
index_direct = A.vload((0, k0 % k1 + k0 % (k1 / s)))
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
assert_simplified_equal(index_simplified, index_direct)
# Test Case3
index_simplified = A.vload((((k0 / (k1 / s)) * (k1 / s)) / n + (k0 % (k1 / s)) / n,
((k0 / (k1 / s)) * (k1 / s)) % n + (k0 % (k1 / s)) % n))
index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxmod(idxmod(k0, idxdiv(k1, s)), n)))
index_direct = A.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case4 (not able to simplify)
index_simplified = A.vload(((k0 % (k1 / s)) / n,
(k0 % (k1 / n)) % n + (k0 % k1)))
index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
(idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
assert_simplified_equal(index_simplified, index_direct)
......@@ -143,14 +150,14 @@ def test_buffer_broadcast():
check()
def test_bbuffer_roadcast_expr():
def test_buffer_broadcast_expr():
n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x')
n1, m1 = tvm.var('n1'), tvm.var('m1')
o0, o1 = tvm.var('o0'), tvm.var('o1')
A = tvm.placeholder((m0, n0), name='A')
B = tvm.placeholder((m1, n1), name='B')
C = tvm.compute((o0, o1/x), lambda i, j: A[i, j] + B[i, j], name='C')
C = tvm.compute((o0, o1//x), lambda i, j: A[i, j] + B[i, j], name='C')
Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
......
......@@ -32,10 +32,11 @@ def test_const_fold():
if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
raise ValueError("check error: %s vs %s " % (x, y))
tmod = tvm.truncmod
check(lambda x, y: x + y, 3, 4)
check(lambda x, y: x * y, 3, 12)
check(lambda x, y: x * y - 10, 3, 12)
check(lambda x, y: x - y % 10, 3, 12)
check(lambda x, y: x - tmod(y, 10), 3, 12)
check(lambda x, y: x // y + 10, 100, 12)
check(lambda x, y: x & y + 10, 112, 128)
check(lambda x, y: x > y, 112, 128)
......@@ -47,13 +48,15 @@ def test_const_fold():
def test_const_fold2():
x = tvm.var("x")
tmod = tvm.truncmod
tdiv = tvm.truncdiv
assert (x + 0).same_as(x)
assert (0 + x).same_as(x)
assert (x - 0).same_as(x)
assert (x % 1).value == 0
assert tmod(x, 1).value == 0
assert (x * 1).same_as(x)
assert (1 * x).same_as(x)
assert isinstance((1 / x), tvm.expr.Div)
assert isinstance(tdiv(1, x), tvm.expr.Div)
def test_const_fold3():
# Test that using ints with logic operations is forbidden
......@@ -88,8 +91,9 @@ def test_const_fold3():
def test_const_fold4():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
tdiv = tvm.truncdiv
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3
x3 = tdiv(x2, 3)
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.55
assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
......
......@@ -72,7 +72,7 @@ def test_combination():
A = tvm.placeholder((n, m), name='A')
B = tvm.placeholder((n, m), name='B')
C = tvm.placeholder((n, m), name='C')
D = k + A - B * C / x
D = k + A - B * C + x
s = tvm.create_schedule(D.op)
foo = tvm.build(s, [x, A, B, C, D], "llvm")
ctx = tvm.cpu(0)
......@@ -82,7 +82,7 @@ def test_combination():
c = tvm.nd.array(np.random.uniform(size=(n, m)).astype(C.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
foo(x, a, b, c, d)
tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() / x)
tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() + x)
def verify_tensor_scalar_bop(shape, typ="add"):
......
......@@ -17,13 +17,15 @@
import tvm
def test_simplify():
tdiv = tvm.truncdiv
tmod = tvm.truncmod
x = tvm.var('x')
e1 = tvm.ir_pass.Simplify(x + 2 + 1)
assert(tvm.ir_pass.Equal(e1, x + 3))
e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
e3 = tvm.ir_pass.Simplify(x - tdiv(x, 3) * 3)
assert(tvm.ir_pass.Equal(e3, tmod(x, 3)))
def test_verify_ssa():
......
......@@ -24,7 +24,7 @@ def test_equal_expr():
return x + y + 1
def func2():
return tvm.exp((x + y + 1) * y / 4)
return tvm.exp(tvm.truncdiv((x + y + 1) * y, 4))
assert tvm.ir_pass.Equal(func1(), func1())
assert tvm.ir_pass.Equal(func2(), func2())
......
......@@ -162,7 +162,7 @@ def test_condition():
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
with ib.for_range(0, ((n+3)/4), 'i') as i:
with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i:
with ib.for_range(0, 4, 'j') as j:
ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(i*4+j<n), m, n)))
......@@ -206,7 +206,7 @@ def test_everything_during_deduction():
ib = tvm.ir_builder.create()
with ib.for_range(0, n, 'i') as i:
with ib.for_range(0, 32, 'j') as j:
with ib.if_scope(ib.likely(i/j < m)):
with ib.if_scope(ib.likely(tvm.truncdiv(i,j) < m)):
# this guard will produce everything during deduction
ib.emit(tvm.make.Evaluate(m))
stmt = ib.get()
......
......@@ -111,9 +111,11 @@ def test_bound_fusesplit1():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split1) / l ).value == 0)
idxdiv = tvm.indexdiv
assert(tvm.ir_pass.Simplify(
bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0)
expected_extent = (((xo + 1) * split1 - 1) / l - (xo * split1) / l + 1)
expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1)
for i in range(1, 6):
for j in range(1, 6):
for k in range(1, 6):
......@@ -121,7 +123,7 @@ def test_bound_fusesplit1():
comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value
exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value
assert(comp_ext == exp_ext)
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0)
def test_bound_fusesplit2():
......@@ -394,11 +396,11 @@ def test_bound_simplification_failure():
if not bounds[A.op.axis[0]].extent.value <= 2:
print(stmt)
assert bounds[A.op.axis[0]].extent.value <= 2
tdiv = tvm.truncdiv
# These are hard to simplify; moreover, we don't simplify them
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
_check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)]))
_check(tvm.compute((10,), lambda i: A[-2*tdiv(i,2) - tvm.min(i, 0-i)]))
_check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
# This would cause out of bounds, but we nevertheless include it
_check(tvm.compute((10,), lambda i: A[i]))
......
......@@ -221,11 +221,14 @@ def test_tensorize_matmul():
# This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696
def test_tensorize_op():
tdiv = tvm.truncdiv
tmod = tvm.truncmod
def op_intrin():
bh = 9
bw = 9
x = tvm.placeholder((5, 5), name='A')
y = tvm.compute((bh, bw), lambda i,j: x[j/3 + i%3, j%3+ i/3])
y = tvm.compute((bh, bw),
lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
def intrin_func(ins, outs):
xx, = ins
......@@ -236,7 +239,7 @@ def test_tensorize_op():
return tvm.decl_tensor_intrin(y.op, intrin_func)
A = tvm.placeholder((5, 5), name='A')
B = tvm.compute((9,9), lambda i, j: A[j/3 + i%3, j%3 + i/3])
B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
bt = op_intrin()
s = tvm.create_schedule(B.op)
......
......@@ -128,8 +128,13 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
conv[n,
idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
name='output_unpack', tag='spatial_conv2d_output')
return output
......
......@@ -123,8 +123,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
kernel_vec[co, ci, KH - 1 - kh, KW - 1 - kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
conv[n,
idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
name='output_unpack', tag='spatial_conv2d_transpose_output')
return output
......
......@@ -293,21 +293,29 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
if dilation_h != 1 or dilation_w != 1:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, kh, kw, vh, vw]
.astype(out_dtype) *
kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
axis=[kh, kw]), name='depthwise_conv')
conv = tvm.compute(
ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw]
.astype(out_dtype) *
kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
axis=[kh, kw]), name='depthwise_conv')
else:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, vh * HSTR + kh,
tvm.sum(data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh,
vw * WSTR + kw].astype(out_dtype) *
kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
kernel_vec[idxdiv(co, M),
idxmod(co, M),
kh, kw, vc].astype(out_dtype),
axis=[kh, kw]), name='depthwise_conv')
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
conv[n,
idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
name='output_unpack', tag='spatial_depthwise_conv_nchw_output')
return output
......
......@@ -69,9 +69,11 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
[0, 0, (bpad_bottom + stride_h - 1) // stride_h,
(bpad_right + stride_w - 1) // stride_w], name='FirstPad')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# remove extra padding introduced by dilation
border_h = (stride_h - bpad_top % stride_h) % stride_h
border_w = (stride_w - bpad_left % stride_w) % stride_w
border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
# dilation stage
data = FirstPad
......@@ -83,8 +85,8 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
index_tuple = []
for i in range(n):
if not equal_const_int(strides[i], 1):
index_tuple.append(indices[i] // strides[i])
not_zero.append((indices[i] % strides[i]).equal(0))
index_tuple.append(idxdiv(indices[i], strides[i]))
not_zero.append(idxmod(indices[i], strides[i]).equal(0))
else:
index_tuple.append(indices[i])
if not_zero:
......
......@@ -85,10 +85,12 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
else:
kernel_pack = kernel
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# pack input tile
input_tile = tvm.compute((CI, P, alpha, alpha), lambda c, p, eps, nu:
data_pad[p // (nH * nW)][c][p // nW % nH * m + eps]
[p % nW * m + nu], name='d')
data_pad[idxdiv(p, (nH * nW))][c][idxmod(idxdiv(p, nW), nH) * m + eps]
[idxmod(p, nW) * m + nu], name='d')
# transform data
r_a = tvm.reduce_axis((0, alpha), 'r_a')
......@@ -113,7 +115,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
# output
output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
inverse[co][n * nH * nW + (h // m) * nW + w // m][h % m][w % m],
inverse[co,
n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
idxmod(h, m),
idxmod(w, m)],
name='output', tag='conv2d_nchw_winograd')
cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
......
......@@ -245,7 +245,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
new_range = num_anchors // elem_per_thread + 1
# Scan: Downsweep:
with ib. if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors # number of batches
i = tid // num_anchors # number of batches
j = tid % num_anchors # number of anchors
with ib.if_scope(j < elem_per_thread):
idx[tid] = idx_in[tid]
......@@ -304,7 +304,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors
i = tid // num_anchors
j = tid % num_anchors
base_idx = i * num_anchors * elem_length
with ib.if_scope(flag[tid] > 0):
......
......@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors
i = tid // num_anchors
j = tid % num_anchors
with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0):
......
......@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
tvm.sum(input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu],
axis=[r_a, r_b]), name='V')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# batch gemm
ci = tvm.reduce_axis((0, CI), name='c')
M = tvm.compute((alpha, alpha, CO, P_round), lambda eps, nu, co, p:
tvm.sum(U[eps][nu][co // bna][ci][co % bna] *
V[eps][nu][p // bnb][ci][p % bnb], axis=ci), name='M')
tvm.sum(U[eps][nu][idxdiv(co, bna)][ci][idxmod(co, bna)] *
V[eps][nu][idxdiv(p, bnb)][ci][idxmod(p, bnb)], axis=ci), name='M')
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_b')
......@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# unpack output
output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
Y[co][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
Y[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
idxmod(h, m), idxmod(w, m)]
# The following hack term is used to make the padding in batch gemm ("M")
# effective, otherwise the padding will be eliminated by bound inference.
# Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding.
......
......@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
axis=[ci, dh, dw, b1, b2])
conv = tvm.compute(ovshape, _conv, name='conv_out')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
return tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
name='conv_vec', tag='spatial_bitserial_conv_nchw')
return tvm.compute(
oshape, lambda n, co, h, w:
conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv(
w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
name='conv_vec', tag='spatial_bitserial_conv_nchw')
@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
......@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
conv = tvm.compute(ovshape, _conv, name='conv')
return tvm.compute(oshape, lambda n, h, w, co:
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
return tvm.compute(
oshape, lambda n, h, w, co:
conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv(
co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
@tvm.target.generic_func
def bitserial_conv2d_legalize(attrs, inputs, types):
......
......@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
weight_packed[xo*VX+vx][wb][k], name='weight_vec')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
(tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]) -
tvm.popcount(~weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k])).astype(out_dtype)
(tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]).astype(out_dtype)
tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
# binary ops
......
......@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n,
ic // ic_bn,
idxdiv(ic, ic_bn),
oh * HSTR + kh * dilation_h,
ow * WSTR + kw * dilation_w,
ic % ic_bn].astype(out_dtype)
idxmod(ic, ic_bn)].astype(out_dtype)
* kernel[oc_chunk,
ic // ic_bn,
idxdiv(ic, ic_bn),
kh,
kw,
ic % ic_bn,
idxmod(ic, ic_bn),
oc_block],
axis=[ic, kh, kw]),
name='conv2d_NCHWc', tag="conv2d_NCHWc")
......
......@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after = [0, 0, pad_down, pad_right]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h+di*dilation_h,
(PaddedInput[b, idxdiv(c, channel_multiplier), i*stride_h+di*dilation_h,
j*stride_w+dj*dilation_w].astype(out_dtype) *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
Filter[idxdiv(c, channel_multiplier),
idxmod(c, channel_multiplier), di, dj].astype(out_dtype)),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output
......@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w,
c/channel_multiplier].astype(out_dtype) *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)),
idxdiv(c, channel_multiplier)].astype(out_dtype) *
Filter[di, dj,
idxdiv(c, channel_multiplier),
idxmod(c, channel_multiplier)].astype(out_dtype)),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output
......@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
dh = tvm.reduce_axis((0, Out_grad.shape[1].value), name='dh')
dw = tvm.reduce_axis((0, Out_grad.shape[2].value), name='dw')
db = tvm.reduce_axis((0, batch), name='db')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
Weight_grad = tvm.compute(
(filter_h, filter_w, in_c, channel_multiplier),
lambda fh, fw, c, m: tvm.sum(
Out_grad[db, dh, dw, c*channel_multiplier+m%channel_multiplier] *
Out_grad[db, dh, dw, c*channel_multiplier+idxmod(m, channel_multiplier)] *
padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]),
tag='depthwise_conv2d_backward_weight_nhwc')
......
......@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"):
def _dilate(*indices):
not_zero = []
index_tuple = []
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
for i in range(n):
if not util.equal_const_int(strides[i], 1):
index_tuple.append(indices[i] / strides[i])
not_zero.append((indices[i] % strides[i]).equal(0))
index_tuple.append(idxdiv(indices[i], strides[i]))
not_zero.append(idxmod(indices[i], strides[i]).equal(0))
else:
index_tuple.append(indices[i])
if not_zero:
......
......@@ -38,12 +38,14 @@ def flatten(data):
for i in range(1, len(ishape)):
dim = dim * ishape[i]
oshape = [ishape[0], dim]
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
def unwrap(idx, shape):
index = []
for s in reversed(shape):
index.append(idx % s)
idx = idx / s
index.append(idxmod(idx, s))
idx = idxdiv(idx, s)
return list(reversed(index))
return tvm.compute(oshape, lambda i, j: data(i, *unwrap(j, ishape[1:])))
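unwrap peels coordinates off a flattened index from the innermost dimension
outward. A hedged worked example:

    # For ishape[1:] == [3, 4], flattened j = 11:
    #   idxmod(11, 4) = 3, then idx = idxdiv(11, 4) = 2, idxmod(2, 3) = 2
    #   => coordinates (2, 3), since 2 * 4 + 3 == 11.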
......@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
idxmod = tvm.indexmod
idxdiv = tvm.indexdiv
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn,
tvm.sum(data_vec[n, idxdiv(ic, ic_bn), oh*HSTR+kh*dilation_h,
idxmod(ic, ic_bn),
ow*WSTR+kw*dilation_w].astype(out_dtype) *
kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn,
kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kh, kw,
idxmod(ic, ic_bn),
oc_block].astype(out_dtype),
axis=[ic, kh, kw]), name='conv')
unpack = tvm.compute(unpack_shape,
lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn]
lambda n, c, h, w: conv[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
.astype(out_dtype),
name='output_unpack',
tag='conv2d_nchw')
......@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
cfg = get_config()
_create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# change shape with the value in config
ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
cfg["tile_ow"].size[-1])
new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
new_data_shape = (raw_data_shape[0], idxdiv(raw_data_shape[1], ic_bn),
raw_data_shape[2], raw_data_shape[3], ic_bn)
data_layout = "NCHW%dc" % ic_bn
out_layout = "NCHW%dc" % oc_bn
new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
new_kernel_shape = (idxdiv(raw_kernel_shape[0], oc_bn),
idxdiv(raw_kernel_shape[1], ic_bn),
raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
......@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg):
_, data, kernel, strides, padding, dilation, layout, dtype = workload
batch_size, in_channel, in_height, in_width = data[:-1]
out_channel, _, k_height, k_width = kernel[:-1]
out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
idxdiv = tvm.indexdiv
out_height = idxdiv(in_height + 2 * padding[0] - k_height, strides[0]) + 1
out_width = idxdiv(in_width + 2 * padding[1] - k_width, strides[1]) + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
in_layout = "NCHW%dc" % tile_ic
out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
out_layout = "NCHW%dc" % tile_oc
return ((in_shape, in_layout),), ((out_shape, out_layout),)
......
......@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
k = tvm.reduce_axis((0, K), name="k")
C = tvm.compute((M, N),
lambda y, x: tvm.sum(
data[y, k].astype(out_dtype) *
packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
axis=k),
tag="dense_pack")
if bias is not None:
......
......@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
data_pad = data
# depthconv stage
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
kh = tvm.reduce_axis((0, filter_height), name='kh')
kw = tvm.reduce_axis((0, filter_width), name='kw')
Output = tvm.compute(
(batch, out_channel_chunk, out_height, out_width, out_channel_block),
lambda b, oco, oh, ow, oci: tvm.sum(
(data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block,
oh*HSTR+kh, ow*WSTR+kw,
((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block]
(data_pad[
b,
idxdiv(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block),
oh*HSTR+kh, ow*WSTR+kw,
idxmod(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block)]
.astype(out_dtype) *
kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
axis=[kh, kw]),
......