Unverified Commit 2ded2d8c by Tianqi Chen Committed by GitHub

[ARITH] Use explicit div mode in python. (#4014)

parent 16bed7e6
...@@ -35,6 +35,13 @@ The user facing API for computation declaration. ...@@ -35,6 +35,13 @@ The user facing API for computation declaration.
tvm.thread_axis tvm.thread_axis
tvm.comm_reducer tvm.comm_reducer
tvm.sum tvm.sum
tvm.div
tvm.indexdiv
tvm.indexmod
tvm.truncdiv
tvm.truncmod
tvm.floordiv
tvm.floormod
tvm.min tvm.min
tvm.max tvm.max
tvm.tag_scope tvm.tag_scope
...@@ -53,6 +60,13 @@ The user facing API for computation declaration. ...@@ -53,6 +60,13 @@ The user facing API for computation declaration.
.. autofunction:: tvm.thread_axis .. autofunction:: tvm.thread_axis
.. autofunction:: tvm.comm_reducer .. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.sum .. autofunction:: tvm.sum
.. autofunction:: tvm.div
.. autofunction:: tvm.indexdiv
.. autofunction:: tvm.indexmod
.. autofunction:: tvm.truncdiv
.. autofunction:: tvm.truncmod
.. autofunction:: tvm.floordiv
.. autofunction:: tvm.floormod
.. autofunction:: tvm.min .. autofunction:: tvm.min
.. autofunction:: tvm.max .. autofunction:: tvm.max
.. autofunction:: tvm.tag_scope .. autofunction:: tvm.tag_scope
...@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"): ...@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
reducer.__doc__ = doc_str.format(name) reducer.__doc__ = doc_str.format(name)
return reducer return reducer
def div(a, b):
    """Compute a / b as in C/C++ semantics.

    Parameters
    ----------
    a : Expr
        The left hand operand.

    b : Expr
        The right hand operand.

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    When operands are integers, returns truncdiv(a, b).
    Use indexdiv/indexmod instead when splitting indices that are
    known to be non-negative.
    """
    return _make._OpDiv(a, b)
def indexdiv(a, b):
    """Compute floor(a / b) for non-negative a and b.

    Parameters
    ----------
    a : Expr
        The left hand operand, known to be non-negative.

    b : Expr
        The right hand operand, known to be non-negative.

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    Intended for splitting non-negative indices.
    The implementation may exploit the fact that both
    operands are non-negative.
    """
    quotient = _make._OpIndexDiv(a, b)
    return quotient
def indexmod(a, b):
    """Compute the remainder matching indexdiv, for non-negative a and b.

    Parameters
    ----------
    a : Expr
        The left hand operand, known to be non-negative.

    b : Expr
        The right hand operand, known to be non-negative.

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    Intended for splitting non-negative indices.
    The implementation may exploit the fact that both
    operands are non-negative.
    """
    remainder = _make._OpIndexMod(a, b)
    return remainder
def truncdiv(a, b): def truncdiv(a, b):
"""Compute the truncdiv of two expressions. """Compute the truncdiv of two expressions.
......
...@@ -101,8 +101,11 @@ def convolution_inference( ...@@ -101,8 +101,11 @@ def convolution_inference(
assert isinstance(stride, list) and len(stride) == 2 assert isinstance(stride, list) and len(stride) == 2
batch, _, input_height, input_width = data.shape batch, _, input_height, input_width = data.shape
output_channels, _, kernel_height, kernel_width = kernel.shape output_channels, _, kernel_height, kernel_width = kernel.shape
output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1 idxdiv = _api.indexdiv
output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1 output_height = idxdiv(
input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
output_width = idxdiv(
input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
return _api.extern( return _api.extern(
(batch, output_channels, output_height, output_width), (batch, output_channels, output_height, output_width),
...@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform( ...@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform(
batch, _, input_height, input_width = data.shape batch, _, input_height, input_width = data.shape
output_channels, _, _, _ = transformed_kernel.shape output_channels, _, _, _ = transformed_kernel.shape
kernel_height, kernel_width = (3, 3) kernel_height, kernel_width = (3, 3)
output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1 idxdiv = _api.indexdiv
output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1 output_height = idxdiv(input_height + padding[0] + padding[1] - kernel_height, stride[0]) + 1
output_width = idxdiv(input_width + padding[0] + padding[1] - kernel_width, stride[1]) + 1
return _api.extern( return _api.extern(
(batch, output_channels, output_height, output_width), (batch, output_channels, output_height, output_width),
......
...@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node. ...@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node.
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node from ._ffi.node import NodeBase, NodeGeneric, register_node
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make from . import make as _make
from . import generic as _generic from . import generic as _generic
from . import _api_internal from . import _api_internal
def div_ambiguity_error():
    """Build the error used to report an ambiguous integer division.

    The exception is constructed and returned, not raised; the caller
    decides whether to raise it.

    Returns
    -------
    err : RuntimeError
        Error whose message lists the explicit division functions to call.
    """
    message = "".join((
        "TVM supports multiple types of integer divisions, ",
        "please call div, indexdiv/indexmod, floordiv/floormod ",
        " or truncdiv/truncmod directly to avoid ambiguity in the code.",
    ))
    return RuntimeError(message)
def _dtype_is_int(value):
    """Tell whether value is a Python int or an ExprOp with an INT type code.

    Parameters
    ----------
    value : object
        The value to inspect.

    Returns
    -------
    result : bool
        True for Python ints and for ExprOp instances whose dtype has
        type code TypeCode.INT; False for anything else.
    """
    if isinstance(value, int):
        return True
    if not isinstance(value, ExprOp):
        return False
    # Only expressions carry a dtype; compare its type code against INT.
    return TVMType(value.dtype).type_code == TypeCode.INT
class ExprOp(object): class ExprOp(object):
def __add__(self, other): def __add__(self, other):
return _generic.add(self, other) return _generic.add(self, other)
...@@ -58,24 +72,35 @@ class ExprOp(object): ...@@ -58,24 +72,35 @@ class ExprOp(object):
return _generic.multiply(other, self) return _generic.multiply(other, self)
def __div__(self, other): def __div__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(self, other) return _generic.divide(self, other)
def __rdiv__(self, other): def __rdiv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(other, self) return _generic.divide(other, self)
def __truediv__(self, other): def __truediv__(self, other):
return self.__div__(other) # if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(self, other)
def __rtruediv__(self, other): def __rtruediv__(self, other):
return self.__rdiv__(other) # if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return _generic.divide(other, self)
def __floordiv__(self, other): def __floordiv__(self, other):
return self.__div__(other) # return _generic.floordiv(self, other)
return _generic.divide(self, other)
def __rfloordiv__(self, other): def __rfloordiv__(self, other):
return self.__rdiv__(other) # return _generic.floordiv(other, self)
return _generic.divide(other, self)
def __mod__(self, other): def __mod__(self, other):
# raise div_ambiguity_error()
return _make._OpMod(self, other) return _make._OpMod(self, other)
def __neg__(self): def __neg__(self):
......
...@@ -25,6 +25,7 @@ from . import make as _make ...@@ -25,6 +25,7 @@ from . import make as _make
#Operator precedence used when overloading. #Operator precedence used when overloading.
__op_priority__ = 0 __op_priority__ = 0
def add(lhs, rhs): def add(lhs, rhs):
"""Generic add operator. """Generic add operator.
...@@ -78,7 +79,6 @@ def multiply(lhs, rhs): ...@@ -78,7 +79,6 @@ def multiply(lhs, rhs):
""" """
return _make._OpMul(lhs, rhs) return _make._OpMul(lhs, rhs)
def divide(lhs, rhs): def divide(lhs, rhs):
"""Generic divide operator. """Generic divide operator.
...@@ -96,6 +96,23 @@ def divide(lhs, rhs): ...@@ -96,6 +96,23 @@ def divide(lhs, rhs):
""" """
return _make._OpDiv(lhs, rhs) return _make._OpDiv(lhs, rhs)
def floordiv(lhs, rhs):
    """Generic floordiv operator.

    Parameters
    ----------
    lhs : object
        The left operand.
    rhs : object
        The right operand.

    Returns
    -------
    op : tvm.Expr
        The result Expr of floordiv operation.
    """
    return _make._OpFloorDiv(lhs, rhs)
def cast(src, dtype): def cast(src, dtype):
"""Generic cast operator. """Generic cast operator.
......
...@@ -31,6 +31,7 @@ from . import util ...@@ -31,6 +31,7 @@ from . import util
from .preprocessor import determine_variable_usage from .preprocessor import determine_variable_usage
from ..api import all as _all from ..api import all as _all
from ..api import any as _any from ..api import any as _any
from ..container import Array from ..container import Array
from ..tensor import Tensor, Operation from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal from .. import _api_internal as _tvm_internal
...@@ -78,6 +79,18 @@ class Symbol(Enum): ...@@ -78,6 +79,18 @@ class Symbol(Enum):
ThreadBind = 10 ThreadBind = 10
def _floordiv(x, y):
    """Floor-divide x by y, dispatching on the kind of operands.

    When either operand is an _expr.ExprOp the division is built with
    _api.floordiv; otherwise operator.floordiv is applied to the values.
    """
    has_expr_operand = any(isinstance(operand, _expr.ExprOp) for operand in (x, y))
    if not has_expr_operand:
        return operator.floordiv(x, y)
    return _api.floordiv(x, y)
def _floormod(x, y):
    """Compute x modulo y, dispatching on the kind of operands.

    When either operand is an _expr.ExprOp the result is built with
    _api.floormod; otherwise operator.mod is applied to the values.
    """
    if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp):
        handler = _api.floormod
    else:
        handler = operator.mod
    return handler(x, y)
class HybridParser(ast.NodeVisitor): class HybridParser(ast.NodeVisitor):
"""Python AST visitor pass which finally lowers it to HalideIR""" """Python AST visitor pass which finally lowers it to HalideIR"""
...@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor):
ast.Sub : operator.sub, ast.Sub : operator.sub,
ast.Mult : operator.mul, ast.Mult : operator.mul,
ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv, ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv, ast.FloorDiv: _floordiv,
ast.Mod : operator.mod, ast.Mod : _floormod,
ast.BitOr : operator.or_, ast.BitOr : operator.or_,
ast.BitAnd : operator.and_, ast.BitAnd : operator.and_,
ast.BitXor : operator.xor, ast.BitXor : operator.xor,
......
...@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) ...@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
@script @script
def _arange_shape_func(start, stop, step): def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64") out = output_tensor((1,), "int64")
out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0]))) out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
return out return out
@_reg.register_shape_func("arange", True) @_reg.register_shape_func("arange", True)
...@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim): ...@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim):
assert len(newshape) - i > 2, "Not enough dims in new shape for -4" assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i+1] == -1: if newshape[i+1] == -1:
assert newshape[i+2] != -1, "Split dims cannot both be -1." assert newshape[i+2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2]) out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
out[dst_idx+1] = int64(newshape[i+2]) out[dst_idx+1] = int64(newshape[i+2])
else: else:
out[dst_idx] = int64(newshape[i+1]) out[dst_idx] = int64(newshape[i+1])
if newshape[i+2] == -1: if newshape[i+2] == -1:
out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1]) out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
else: else:
out[dst_idx+1] = int64(newshape[i+2]) out[dst_idx+1] = int64(newshape[i+2])
assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\ assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
...@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim): ...@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim):
new_size = int64(1) new_size = int64(1)
for i in const_range(out.shape[0]): for i in const_range(out.shape[0]):
new_size *= out[i] new_size *= out[i]
out[infer_idx] = old_size / new_size out[infer_idx] = old_size // new_size
return out return out
@_reg.register_shape_func("reshape", False) @_reg.register_shape_func("reshape", False)
......
...@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-); ...@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*); REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, div); REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod); REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv); REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod); REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv); REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*) ...@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this); PrintBinaryExpr(op, "*", os, this);
} }
void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
if (op->type.is_int()) if (op->type.is_int())
PrintBinaryExpr(op, "//", os, this); PrintBinaryExpr(op, "//", os, this);
else else
PrintBinaryExpr(op, "/", os, this); PrintBinaryExpr(op, "/", os, this);
} }
void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) { // NOLINT(*)
  // An integer-typed FloorDiv is printed with "//"; every other result
  // type is printed with "/".
  const char* opstr = op->type.is_int() ? "//" : "/";
  PrintBinaryExpr(op, opstr, os, this);
}
void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*) void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this); PrintBinaryExpr(op, "%", os, this);
} }
void CodeGenHybrid::VisitExpr_(const FloorMod *op, std::ostream& os) { // NOLINT(*)
  // FloorMod is printed with the "%" operator.
  const char* opstr = "%";
  PrintBinaryExpr(op, opstr, os, this);
}
void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this); PrintBinaryExpr(op, "min", os, this);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -100,6 +100,8 @@ class CodeGenHybrid : ...@@ -100,6 +100,8 @@ class CodeGenHybrid :
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
...@@ -161,12 +163,12 @@ class CodeGenHybrid : ...@@ -161,12 +163,12 @@ class CodeGenHybrid :
std::string GetUniqueName(std::string prefix); std::string GetUniqueName(std::string prefix);
/*! \brief The output code string builder. */ /*! \brief The output code string builder. */
std::stringstream stream; std::stringstream stream;
/*! /*!
* \brief Get or allocate the ID for the given variable. * \brief Get or allocate the ID for the given variable.
* \param v The given variable. * \param v The given variable.
*/ */
std::string GetVarID(const Variable *v); std::string GetVarID(const Variable *v);
/*! /*!
* \brief Get or allocate the ID for the given tensor. * \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name. * \param func The tensor to allocate a name.
* \param value_index The value index of the given tensor. * \param value_index The value index of the given tensor.
......
...@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) { ...@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) {
} }
Expr floordiv(Expr a, Expr b) { Expr floordiv(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b); Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
if (ret.defined()) return ret; if (ret.defined()) return ret;
...@@ -223,6 +225,8 @@ Expr floordiv(Expr a, Expr b) { ...@@ -223,6 +225,8 @@ Expr floordiv(Expr a, Expr b) {
} }
Expr floormod(Expr a, Expr b) { Expr floormod(Expr a, Expr b) {
CHECK(a.type().is_int() || a.type().is_uint());
CHECK(b.type().is_int() || b.type().is_uint());
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorMod>(a, b); Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
if (ret.defined()) return ret; if (ret.defined()) return ret;
......
...@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { ...@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (op == nullptr) return ret; if (op == nullptr) return ret;
int shift; int shift;
const DataType& dtype = op->type; const DataType& dtype = op->type;
if (dtype.is_float()) {
return floor(Div::make(op->a, op->b));
}
CHECK(dtype.is_int() || !dtype.is_uint()); CHECK(dtype.is_int() || !dtype.is_uint());
if (is_const_power_of_two_integer(op->b, &shift)) { if (is_const_power_of_two_integer(op->b, &shift)) {
......
...@@ -33,9 +33,11 @@ def test_mul_sum_simplify(): ...@@ -33,9 +33,11 @@ def test_mul_sum_simplify():
x * 13 + z * 4 + y * 4 +6) x * 13 + z * 4 + y * 4 +6)
ck.verify(x * 3 - 4 * x + 1, 1 - x) ck.verify(x * 3 - 4 * x + 1, 1 - x)
ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2) ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2)
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# truncdiv # truncdiv
ck.verify((x + y + x + y * 3) / 2, y * 2 + x) ck.verify(tdiv(x + y + x + y * 3, 2), y * 2 + x)
ck.verify((x + y + x + y * 3) % 2, 0) ck.verify(tmod(x + y + x + y * 3, 2), 0)
# floordiv # floordiv
fld = tvm.floordiv fld = tvm.floordiv
...@@ -51,28 +53,31 @@ def test_split_index_simplify(): ...@@ -51,28 +53,31 @@ def test_split_index_simplify():
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# truncdiv # truncdiv
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# split div const # split div const
ck.verify((x/3) *3 + x % 3, x) ck.verify(tdiv(x, 3) *3 + tmod(x, 3), x)
ck.verify((x/6) * 6 + ((x/3) % 2) * 3 + x % 3, x) ck.verify(tdiv(x, 6) * 6 + tmod(tdiv(x, 3), 2) * 3 + tmod(x, 3), x)
ck.verify(((x % 16) / 2) * 2 / 4, (x % 16) / 4) ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 4), tdiv(tmod(x, 16), 4))
ck.verify((x % 2) / 8, 0) ck.verify(tdiv(tmod(x, 2), 8), 0)
ck.verify((x % 2) / 7, 0) ck.verify(tdiv(tmod(x, 2), 7), 0)
ck.verify(((x % 16) / 2) * 2 / 6, (x % 16) / 6) ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 6), tdiv(tmod(x, 16), 6))
# split mod const # split mod const
ck.verify((x * 8) % 16, (x % 2) * 8) ck.verify(tmod((x * 8), 16), tmod(x, 2) * 8)
ck.verify((x * 8) % 2, 0) ck.verify(tmod(x * 8, 2), 0)
# simplify then fold # simplify then fold
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000)) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000)) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y) ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
# complex fold # complex fold
ck.verify((z * 9 + y) / 2 * 2 + (z * 9 + y) % 2, z * 9 + y) ck.verify(tdiv(z * 9 + y, 2) * 2 + tmod(z * 9 + y, 2), z * 9 + y)
ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True) ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True) ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
ck.verify((x * 4 + y) / 2 * 2 + (x * 4 + y) % 2, x * 4 + y) ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
# floordiv # floordiv
fld = tvm.floordiv fld = tvm.floordiv
...@@ -85,23 +90,24 @@ def test_split_index_simplify(): ...@@ -85,23 +90,24 @@ def test_split_index_simplify():
ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6)) ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
# cannot simplify mixed case, unless we canonicalize into one mode. # cannot simplify mixed case, unless we canonicalize into one mode.
ck.verify((x/6) * 2 + fld(x,3) % 2, (x/6) * 2 + fld(x,3) % 2) ck.verify(tdiv(x,6) * 2 + tmod(fld(x,3), 2), tdiv(x,6) * 2 + tmod(fld(x,3), 2))
def test_div_simplify(): def test_div_simplify():
ck = CanonicalChecker() ck = CanonicalChecker()
x = tvm.var("x") x = tvm.var("x")
tdiv = tvm.truncdiv
# truncdiv # truncdiv
ck.verify((16+48*x)/16, x*3 + 1) ck.verify(tdiv(16+48*x,16), x*3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0 # (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x # (17+48*x)/16 != 1+3*x
ck.verify((17+48*x)/16, (x * 48 + 17) / 16) ck.verify(tdiv(17 + 48 * x, 16), tdiv(x * 48 + 17, 16))
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified # However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10)) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
ck.verify((17+48*x)/16, x * 3 + 1) ck.verify(tdiv(17 + 48 * x, 16), x * 3 + 1)
# Trying expressions that are not simplifiable for any values of the variables # Trying expressions that are not simplifiable for any values of the variables
ck.verify((17+47*x)/16, (x * 47 + 17) / 16) ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))
# floordiv # floordiv
fld = tvm.floordiv fld = tvm.floordiv
...@@ -124,8 +130,10 @@ def test_canonical_mixed(): ...@@ -124,8 +130,10 @@ def test_canonical_mixed():
ck = CanonicalChecker() ck = CanonicalChecker()
x = tvm.var("x") x = tvm.var("x")
z = tvm.const(3, "int32") z = tvm.const(3, "int32")
ck.verify(x / (z*z) - x / (z*z), 0) tdiv = tvm.truncdiv
ck.verify(x / (z+z) - x / (z+z), 0) tmod = tvm.truncmod
ck.verify(tdiv(x, (z*z)) - tdiv(x, (z*z)), 0)
ck.verify(tdiv(x, (z+z)) - tdiv(x, (z+z)), 0)
ck.verify(x - 2 < 3, x < 5) ck.verify(x - 2 < 3, x < 5)
ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0) ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0)
ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0) ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
...@@ -207,42 +215,44 @@ def test_reduce_simplify(): ...@@ -207,42 +215,44 @@ def test_reduce_simplify():
tvm.sum(k + j, [k, j])) tvm.sum(k + j, [k, j]))
ck.verify(tvm.sum(A[3], []), A[3]) ck.verify(tvm.sum(A[3], []), A[3])
# The rule below is not typical, removed for now # The rule below is not typical, removed for now
ck.verify(tvm.sum(k / 10, k), tvm.sum(tvm.const(0, "int32"), k)) ck.verify(tvm.sum(tvm.div(k, 10), k), tvm.sum(tvm.const(0, "int32"), k))
def test_simplify_if_then_else(): def test_simplify_if_then_else():
ck = CanonicalChecker() ck = CanonicalChecker()
x = tvm.var("x") x = tvm.var("x")
y = tvm.var("y") y = tvm.var("y")
tdiv = tvm.truncdiv
tmod = tvm.truncmod
# simplification that takes condition into account. # simplification that takes condition into account.
res = tvm.if_then_else((x * 4 + y) >= 466036, res = tvm.if_then_else((x * 4 + y) >= 466036,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528), tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16, tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
x), y) x), y)
res2 = tvm.if_then_else((x * 4) >= 466036 - y, res2 = tvm.if_then_else((x * 4) >= 466036 - y,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528), tvm.if_then_else(24512 <= tmod(((x*4) + y) - 466036, 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16, tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16),
x), y) x), y)
expected = tvm.if_then_else( expected = tvm.if_then_else(
tvm.expr.LE(466036, (x * 4 + y)), tvm.expr.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)), tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)),
(((x*4) + y) - 4) % 16, tmod(((x*4) + y) - 4, 16),
x), y) x), y)
ck.verify(res, expected) ck.verify(res, expected)
ck.verify(res2, expected) ck.verify(res2, expected)
# can only simplify if condition # can only simplify if condition
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3) res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3) expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
ck.verify(res, ck.analyzer.canonical_simplify(expected)) ck.verify(res, ck.analyzer.canonical_simplify(expected))
res = tvm.expr.Select(x >= 10, res = tvm.expr.Select(x >= 10,
tvm.if_then_else(x / 3 > 2, x, 0), 0) tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
expected = tvm.expr.Select(x >= 10, x, 0) expected = tvm.expr.Select(x >= 10, x, 0)
ck.verify(res, ck.analyzer.canonical_simplify(expected)) ck.verify(res, ck.analyzer.canonical_simplify(expected))
res = tvm.expr.Select(x >= 10, res = tvm.expr.Select(x >= 10,
tvm.if_then_else(x / 3 < 2, x, 0), 0) tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
ck.verify(res, 0) ck.verify(res, 0)
...@@ -250,20 +260,20 @@ def test_complex_cases(): ...@@ -250,20 +260,20 @@ def test_complex_cases():
ck = CanonicalChecker() ck = CanonicalChecker()
x = tvm.var("x") x = tvm.var("x")
y = tvm.var("y") y = tvm.var("y")
res2 = (((((((((((x*128) + y) % 1296)/36)*2) + 1)/2)*36) + tdiv = tvm.truncdiv
((((((x*128) + y) % 36)*2) + 1)/2)) tmod = tvm.truncmod
- (((x*128) + y) % 1296)) + 1) res2 = (tdiv(tdiv(tmod(x*128 + y, 1296),36)*2 + 1,2)*36 +
tdiv(tmod((x*128) + y, 36)*2 + 1,2)
- tmod((x*128) + y, 1296) + 1)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5)) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127)) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
ck.verify(res2, 1) ck.verify(res2, 1)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
res3 = ((((((((((x*1024) + y)/65536) + ((((x*1024) + y) % 65536)/256)) res3 = (tdiv(x*1024 + y,65536) + tdiv(tmod(x*1024 + y, 65536),256)
+ ((((x*1024) + y) % 256)/16)) + (((x*1024) + y) % 16)) - (y/256)) - + tdiv(tmod(x*1024 + y, 256),16) + tmod(x*1024 + y, 16) - tdiv(y,256) -
((y % 256)/16)) - (y % 16)) - (x*4)) tdiv(tmod(y, 256),16) - tmod(y, 16) - (x*4))
ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4)) ck.verify(res3, tdiv((x*1024) + y, 256) - tdiv(y,256) - (x*4))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -38,12 +38,13 @@ def test_dtype_bound(): ...@@ -38,12 +38,13 @@ def test_dtype_bound():
def test_cast_bound(): def test_cast_bound():
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
x = tvm.var("x", dtype="int8") x = tvm.var("x", dtype="int8")
bd = analyzer.const_int_bound((x % 3).astype("uint32")) tmod = tvm.truncmod
bd = analyzer.const_int_bound(tmod(x, 3).astype("uint32"))
assert bd.min_value == 0 assert bd.min_value == 0
assert bd.max_value == 2 assert bd.max_value == 2
bd = analyzer.const_int_bound( bd = analyzer.const_int_bound(
(x % 3).astype("float32").astype("int32")) tmod(x, 3).astype("float32").astype("int32"))
assert bd.min_value == -2 assert bd.min_value == -2
assert bd.max_value == 2 assert bd.max_value == 2
...@@ -98,47 +99,50 @@ def test_mul_bound(): ...@@ -98,47 +99,50 @@ def test_mul_bound():
assert bd.max_value == bd.POS_INF assert bd.max_value == bd.POS_INF
def test_div_bound(): def test_truncdiv_bound():
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
tdiv = tvm.truncdiv
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(x / y) bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == -2 assert bd.min_value == -2
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True) analyzer.update(x, tvm.arith.ConstIntBound(-9, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True) analyzer.update(y, tvm.arith.ConstIntBound(-2, 0), override=True)
bd = analyzer.const_int_bound(x / y) bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == -4 assert bd.min_value == -4
assert bd.max_value == 9 assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True) analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, 4), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True) analyzer.update(y, tvm.arith.ConstIntBound(-2, 1), override=True)
bd = analyzer.const_int_bound(x / y) bd = analyzer.const_int_bound(tdiv(x, y))
assert bd.min_value == bd.NEG_INF assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF assert bd.max_value == bd.POS_INF
def test_mod_bound(): def test_truncmod_bound():
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
tmod = tvm.truncmod
analyzer.update(x, tvm.arith.ConstIntBound(-9, 4)) analyzer.update(x, tvm.arith.ConstIntBound(-9, 4))
analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) analyzer.update(y, tvm.arith.ConstIntBound(4, 10))
bd = analyzer.const_int_bound(x % y) bd = analyzer.const_int_bound(tmod(x, y))
assert bd.min_value == -9 assert bd.min_value == -9
assert bd.max_value == 4 assert bd.max_value == 4
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True) analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(x % y) bd = analyzer.const_int_bound(tmod(x, y))
assert bd.min_value == -9 assert bd.min_value == -9
assert bd.max_value == 9 assert bd.max_value == 9
analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True) analyzer.update(x, tvm.arith.ConstIntBound(1, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True) analyzer.update(y, tvm.arith.ConstIntBound(4, 10), override=True)
bd = analyzer.const_int_bound(x % y) bd = analyzer.const_int_bound(tmod(x, y))
assert bd.min_value == 0 assert bd.min_value == 0
assert bd.max_value == 9 assert bd.max_value == 9
...@@ -253,9 +257,12 @@ def test_shift_and_bound(): ...@@ -253,9 +257,12 @@ def test_shift_and_bound():
def test_mix_index_bound(): def test_mix_index_bound():
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
tdiv = tvm.truncdiv
tmod = tvm.truncmod
analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1)) analyzer.update(x, tvm.arith.ConstIntBound(0, 24 - 1))
analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1)) analyzer.update(y, tvm.arith.ConstIntBound(0, 3 - 1))
bd = analyzer.const_int_bound((x % 8) + (x / 8) * 8) bd = analyzer.const_int_bound(tmod(x, 8) + tdiv(x, 8) * 8)
assert bd.min_value == 0 assert bd.min_value == 0
assert bd.max_value == 24 - 1 assert bd.max_value == 24 - 1
...@@ -263,7 +270,7 @@ def test_mix_index_bound(): ...@@ -263,7 +270,7 @@ def test_mix_index_bound():
assert bd.min_value == 0 assert bd.min_value == 0
assert bd.max_value == 24 * 3 - 1 assert bd.max_value == 24 * 3 - 1
bd = analyzer.const_int_bound((x % 7) + (x / 7) * 7) bd = analyzer.const_int_bound(tmod(x, 7) + tdiv(x, 7) * 7)
assert bd.min_value == 0 assert bd.min_value == 0
assert bd.max_value == (23 // 7) * 7 + 6 assert bd.max_value == (23 // 7) * 7 + 6
...@@ -273,8 +280,8 @@ if __name__ == "__main__": ...@@ -273,8 +280,8 @@ if __name__ == "__main__":
test_cast_bound() test_cast_bound()
test_add_sub_bound() test_add_sub_bound()
test_mul_bound() test_mul_bound()
test_div_bound() test_truncdiv_bound()
test_mod_bound() test_truncmod_bound()
test_floordiv_bound() test_floordiv_bound()
test_floormod_bound() test_floormod_bound()
test_min_max_bound() test_min_max_bound()
......
...@@ -35,9 +35,11 @@ def test_deduce(): ...@@ -35,9 +35,11 @@ def test_deduce():
d_s = tvm.arith.IntervalSet(-3, -1) d_s = tvm.arith.IntervalSet(-3, -1)
zero = tvm.const(0, "int32") zero = tvm.const(0, "int32")
tdiv = tvm.truncdiv
e0 = (-b)*a+c-d e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d - c) /(b*-1) + (-1)) ans0 = (tdiv(d - c, b*-1) + (-1))
assert_expr_equal(res0.max_value, ans0) assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs # expression containing variable a is on rhs
...@@ -46,7 +48,7 @@ def test_deduce(): ...@@ -46,7 +48,7 @@ def test_deduce():
e0 = d*a+c-d e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d-c)/d - 1) ans0 = (tdiv(d-c,d) - 1)
assert_expr_equal(res0.max_value, ans0) assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs # expression containing variable a is on rhs
...@@ -56,7 +58,7 @@ def test_deduce(): ...@@ -56,7 +58,7 @@ def test_deduce():
e1 = (a*4+b < c) e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4 -1) ans1 = (tdiv((c - b) + -1,4) -1)
assert_expr_equal(res1.max_value, ans1) assert_expr_equal(res1.max_value, ans1)
...@@ -79,7 +81,7 @@ def test_deduce(): ...@@ -79,7 +81,7 @@ def test_deduce():
e3 = (-b)+a*c-d e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1 ans3 = tdiv(2,c)+1
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
......
...@@ -60,13 +60,14 @@ def test_add_sub(): ...@@ -60,13 +60,14 @@ def test_add_sub():
def test_mul_div(): def test_mul_div():
ck = IntSetChecker() ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
tdiv = tvm.truncdiv
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20)) ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2)) ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y)) ck.verify(tdiv(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5)) ck.verify(tdiv(x, 2), {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
fld = tvm.floordiv fld = tvm.floordiv
ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y))) ck.verify(fld(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
...@@ -76,9 +77,10 @@ def test_mul_div(): ...@@ -76,9 +77,10 @@ def test_mul_div():
def test_mod(): def test_mod():
ck = IntSetChecker() ck = IntSetChecker()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
tmod = tvm.truncmod
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1)) ck.verify(tmod(x, y), {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9)) ck.verify(tmod(x, 10), {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
flm = tvm.floormod flm = tvm.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
......
...@@ -54,7 +54,8 @@ def test_div_shift(): ...@@ -54,7 +54,8 @@ def test_div_shift():
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
x, y = tvm.var("x"), tvm.var("y") x, y = tvm.var("x"), tvm.var("y")
# not sure if x is non-negative # not sure if x is non-negative
m = analyzer.modular_set((x * 4 + 2) / 2) tdiv = tvm.truncdiv
m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
assert m.coeff == 1 assert m.coeff == 1
assert m.base == 0 assert m.base == 0
# right shift always rounds down so it is fine # right shift always rounds down so it is fine
...@@ -67,7 +68,7 @@ def test_div_shift(): ...@@ -67,7 +68,7 @@ def test_div_shift():
assert m.base == 1 assert m.base == 1
# x is non-negative # x is non-negative
analyzer.update(x, tvm.arith.ConstIntBound(0, 100)) analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
m = analyzer.modular_set((x * 4 + 2) / 2) m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
assert m.coeff == 2 assert m.coeff == 2
assert m.base == 1 assert m.base == 1
...@@ -92,6 +93,7 @@ def test_mix_index(): ...@@ -92,6 +93,7 @@ def test_mix_index():
a = tvm.var("a") a = tvm.var("a")
b = tvm.var("b") b = tvm.var("b")
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
tdiv = tvm.truncdiv
m = analyzer.modular_set(a * 4 + b * 6 + 7) m = analyzer.modular_set(a * 4 + b * 6 + 7)
assert m.coeff == 2 assert m.coeff == 2
assert m.base == 1 assert m.base == 1
...@@ -100,11 +102,11 @@ def test_mix_index(): ...@@ -100,11 +102,11 @@ def test_mix_index():
assert m.coeff == 4 assert m.coeff == 4
assert m.base == 3 assert m.base == 3
m = analyzer.modular_set((a * 4 + 1) / (b * 8 + 3)) m = analyzer.modular_set(tdiv(a * 4 + 1, b * 8 + 3))
assert m.coeff == 1 assert m.coeff == 1
assert m.base == 0 assert m.base == 0
m = analyzer.modular_set((a * 4 + 1) * (b * 8 / 4)) m = analyzer.modular_set((a * 4 + 1) * tdiv(b * 8, 4))
assert m.coeff == 2 assert m.coeff == 2
assert m.base == 0 assert m.base == 0
...@@ -121,11 +123,13 @@ def test_constraint_scope(): ...@@ -121,11 +123,13 @@ def test_constraint_scope():
a = tvm.var("a") a = tvm.var("a")
b = tvm.var("b") b = tvm.var("b")
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(b % 4 == 2): tmod = tvm.truncmod
with analyzer.constraint_scope(tmod(b, 4) == 2):
m = analyzer.modular_set(b + 1) m = analyzer.modular_set(b + 1)
assert m.coeff == 4 assert m.coeff == 4
assert m.base == 3 assert m.base == 3
with analyzer.constraint_scope(a % 2 == 1): with analyzer.constraint_scope(tmod(a, 2) == 1):
m = analyzer.modular_set(b + a * 2) m = analyzer.modular_set(b + a * 2)
assert m.coeff == 4 assert m.coeff == 4
assert m.base == 0 assert m.base == 0
...@@ -140,15 +144,16 @@ def test_constraint_scope(): ...@@ -140,15 +144,16 @@ def test_constraint_scope():
def test_intersect(): def test_intersect():
a = tvm.var("a") a = tvm.var("a")
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(a % 4 == 1): tmod = tvm.truncmod
with analyzer.constraint_scope(a % 3 == 1): with analyzer.constraint_scope(tmod(a, 4) == 1):
with analyzer.constraint_scope(tmod(a, 3) == 1):
m = analyzer.modular_set(a) m = analyzer.modular_set(a)
assert m.coeff == 12 assert m.coeff == 12
assert m.base == 1 assert m.base == 1
with analyzer.constraint_scope(a % 3 == 2): with analyzer.constraint_scope(tmod(a, 3) == 2):
with analyzer.constraint_scope(a % 5 == 3): with analyzer.constraint_scope(tmod(a, 5) == 3):
with analyzer.constraint_scope(a % 7 == 2): with analyzer.constraint_scope(tmod(a, 7) == 2):
m = analyzer.modular_set(a) m = analyzer.modular_set(a)
assert m.coeff == 105 assert m.coeff == 105
assert m.base == 23 assert m.base == 23
......
...@@ -60,11 +60,14 @@ def test_pack_gemm(): ...@@ -60,11 +60,14 @@ def test_pack_gemm():
k = tvm.reduce_axis((0, L)) k = tvm.reduce_axis((0, L))
bn = 4 bn = 4
fld = tvm.floordiv
flm = tvm.floormod
A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j]) A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j]) B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj: C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k])) tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn]) C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
s = tvm.create_schedule([C.op]) s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M assert compute_flop(s) == 2 * N * L * M
...@@ -119,9 +122,11 @@ def test_average_pool(): ...@@ -119,9 +122,11 @@ def test_average_pool():
OH = (H - KH) + 1 OH = (H - KH) + 1
OW = (W - KW) + 1 OW = (W - KW) + 1
C = tvm.compute( C = tvm.compute(
(N, CO, OH, OW), (N, CO, OH, OW),
lambda n, co, h, w: tvm.sum(D[n][co][h + kh][w + kw].astype(acc_dtype) / (KW * KH), axis=[kh, kw])) lambda n, co, h, w: tvm.sum(
tvm.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)), axis=[kh, kw]))
s = tvm.create_schedule([C.op]) s = tvm.create_schedule([C.op])
......
...@@ -35,7 +35,7 @@ def test_lower_rfactor(): ...@@ -35,7 +35,7 @@ def test_lower_rfactor():
def test_dependent_output_shape(): def test_dependent_output_shape():
n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x') n, m, x = tvm.var('n'), tvm.var('m'), tvm.var('x')
A = tvm.placeholder((n, m)) A = tvm.placeholder((n, m))
B = tvm.compute((m, n/x), lambda i, j: A[i,j] , name='B') B = tvm.compute((m, n//x), lambda i, j: A[i,j] , name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
mod = tvm.build(s, [A, B, x]) mod = tvm.build(s, [A, B, x])
......
...@@ -409,7 +409,7 @@ def test_llvm_div(): ...@@ -409,7 +409,7 @@ def test_llvm_div():
"""Check that the semantics of div and mod is the same as in C/C++""" """Check that the semantics of div and mod is the same as in C/C++"""
def check_div(start, end, divisor, dtype): def check_div(start, end, divisor, dtype):
T = tvm.compute((end - start,), T = tvm.compute((end - start,),
lambda i: tvm.expr.Cast(dtype, (start + i)) / tvm.const(divisor, dtype)) lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
s = tvm.create_schedule([T.op]) s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm") f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype) a = tvm.nd.empty((end - start,), dtype)
...@@ -418,8 +418,9 @@ def test_llvm_div(): ...@@ -418,8 +418,9 @@ def test_llvm_div():
tvm.testing.assert_allclose(a.asnumpy(), ref) tvm.testing.assert_allclose(a.asnumpy(), ref)
def check_mod(start, end, divisor, dtype): def check_mod(start, end, divisor, dtype):
tmod = tvm.truncmod
T = tvm.compute((end - start,), T = tvm.compute((end - start,),
lambda i: tvm.expr.Cast(dtype, (start + i)) % tvm.const(divisor, dtype)) lambda i: tmod(tvm.expr.Cast(dtype, (start + i)), tvm.const(divisor, dtype)))
s = tvm.create_schedule([T.op]) s = tvm.create_schedule([T.op])
f = tvm.build(s, [T], "llvm") f = tvm.build(s, [T], "llvm")
a = tvm.nd.empty((end - start,), dtype) a = tvm.nd.empty((end - start,), dtype)
...@@ -443,7 +444,7 @@ def test_llvm_div(): ...@@ -443,7 +444,7 @@ def test_llvm_div():
def test_llvm_fp_math(): def test_llvm_fp_math():
def check_llvm_reciprocal(n): def check_llvm_reciprocal(n):
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: 1.0/(1e+37*A[i]), name='B') B = tvm.compute((n,), lambda i: tvm.div(1.0,(1e+37*A[i])), name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm") f = tvm.build(s, [A, B], "llvm")
......
...@@ -41,8 +41,9 @@ def test_if(): ...@@ -41,8 +41,9 @@ def test_if():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.var("n")
A = ib.pointer("float32", name="A") A = ib.pointer("float32", name="A")
tmod = tvm.truncmod
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2) == 0): with ib.if_scope(tmod(i, 2) == 0):
A[i] = A[i] + 1 A[i] = A[i] + 1
with ib.else_scope(): with ib.else_scope():
A[0] = A[i] + 2 A[0] = A[i] + 2
...@@ -108,13 +109,14 @@ def test_gpu(): ...@@ -108,13 +109,14 @@ def test_gpu():
dtype = "float32" dtype = "float32"
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
fld = tvm.floordiv
def test_device_ir(A, B, C): def test_device_ir(A, B, C):
n = A.shape[0] n = A.shape[0]
max_threads = 32 max_threads = 32
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
bx = tvm.thread_axis("blockIdx.x") bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", (n+max_threads-1) // max_threads) ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
ib.scope_attr(tx, "thread_extent", max_threads) ib.scope_attr(tx, "thread_extent", max_threads)
idx = bx.var * max_threads + tx.var idx = bx.var * max_threads + tx.var
Aptr = ib.buffer_ptr(A) Aptr = ib.buffer_ptr(A)
......
...@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod(): ...@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod():
def assert_simplified_equal(index_simplified, index_direct): def assert_simplified_equal(index_simplified, index_direct):
assert tvm.ir_pass.Equal(index_simplified, index_direct),\ assert tvm.ir_pass.Equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct) "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# Test Case1 # Test Case1
index_simplified = A_stride.vload(((k0 % k1) / s, (k0 % k1) % s + (k0 / k1) * k1)) index_simplified = A_stride.vload(
(idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
index_direct = A_stride.vload((0, k0)) index_direct = A_stride.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
# Test Case2 # Test Case2
index_simplified = A.vload(((k0 % (k1 / s)) / n, index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
(k0 % (k1 / s)) % n + (k0 % k1))) idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, k0 % k1 + k0 % (k1 / s))) index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
# Test Case3 # Test Case3
index_simplified = A.vload((((k0 / (k1 / s)) * (k1 / s)) / n + (k0 % (k1 / s)) / n, index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
((k0 / (k1 / s)) * (k1 / s)) % n + (k0 % (k1 / s)) % n)) idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxmod(idxmod(k0, idxdiv(k1, s)), n)))
index_direct = A.vload((0, k0)) index_direct = A.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
# Test Case4 (not able to simplify) # Test Case4 (not able to simplify)
index_simplified = A.vload(((k0 % (k1 / s)) / n, index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
(k0 % (k1 / n)) % n + (k0 % k1))) idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1)))) index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
(idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
...@@ -143,14 +150,14 @@ def test_buffer_broadcast(): ...@@ -143,14 +150,14 @@ def test_buffer_broadcast():
check() check()
def test_bbuffer_roadcast_expr(): def test_buffer_broadcast_expr():
n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x') n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x')
n1, m1 = tvm.var('n1'), tvm.var('m1') n1, m1 = tvm.var('n1'), tvm.var('m1')
o0, o1 = tvm.var('o0'), tvm.var('o1') o0, o1 = tvm.var('o0'), tvm.var('o1')
A = tvm.placeholder((m0, n0), name='A') A = tvm.placeholder((m0, n0), name='A')
B = tvm.placeholder((m1, n1), name='B') B = tvm.placeholder((m1, n1), name='B')
C = tvm.compute((o0, o1/x), lambda i, j: A[i, j] + B[i, j], name='C') C = tvm.compute((o0, o1//x), lambda i, j: A[i, j] + B[i, j], name='C')
Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast") Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast") Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
......
...@@ -32,10 +32,11 @@ def test_const_fold(): ...@@ -32,10 +32,11 @@ def test_const_fold():
if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y): if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
raise ValueError("check error: %s vs %s " % (x, y)) raise ValueError("check error: %s vs %s " % (x, y))
tmod = tvm.truncmod
check(lambda x, y: x + y, 3, 4) check(lambda x, y: x + y, 3, 4)
check(lambda x, y: x * y, 3, 12) check(lambda x, y: x * y, 3, 12)
check(lambda x, y: x * y - 10, 3, 12) check(lambda x, y: x * y - 10, 3, 12)
check(lambda x, y: x - y % 10, 3, 12) check(lambda x, y: x - tmod(y, 10), 3, 12)
check(lambda x, y: x // y + 10, 100, 12) check(lambda x, y: x // y + 10, 100, 12)
check(lambda x, y: x & y + 10, 112, 128) check(lambda x, y: x & y + 10, 112, 128)
check(lambda x, y: x > y, 112, 128) check(lambda x, y: x > y, 112, 128)
...@@ -47,13 +48,15 @@ def test_const_fold(): ...@@ -47,13 +48,15 @@ def test_const_fold():
def test_const_fold2(): def test_const_fold2():
x = tvm.var("x") x = tvm.var("x")
tmod = tvm.truncmod
tdiv = tvm.truncdiv
assert (x + 0).same_as(x) assert (x + 0).same_as(x)
assert (0 + x).same_as(x) assert (0 + x).same_as(x)
assert (x - 0).same_as(x) assert (x - 0).same_as(x)
assert (x % 1).value == 0 assert tmod(x, 1).value == 0
assert (x * 1).same_as(x) assert (x * 1).same_as(x)
assert (1 * x).same_as(x) assert (1 * x).same_as(x)
assert isinstance((1 / x), tvm.expr.Div) assert isinstance(tdiv(1, x), tvm.expr.Div)
def test_const_fold3(): def test_const_fold3():
# Test that using ints with logic operations is forbidden # Test that using ints with logic operations is forbidden
...@@ -88,8 +91,9 @@ def test_const_fold3(): ...@@ -88,8 +91,9 @@ def test_const_fold3():
def test_const_fold4(): def test_const_fold4():
x1 = tvm.const(4, "int32") x1 = tvm.const(4, "int32")
x2 = x1 + 5 x2 = x1 + 5
tdiv = tvm.truncdiv
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9 assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3 x3 = tdiv(x2, 3)
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3 assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.55 x4 = x3 + 0.55
assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6 assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
......
...@@ -72,7 +72,7 @@ def test_combination(): ...@@ -72,7 +72,7 @@ def test_combination():
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
B = tvm.placeholder((n, m), name='B') B = tvm.placeholder((n, m), name='B')
C = tvm.placeholder((n, m), name='C') C = tvm.placeholder((n, m), name='C')
D = k + A - B * C / x D = k + A - B * C + x
s = tvm.create_schedule(D.op) s = tvm.create_schedule(D.op)
foo = tvm.build(s, [x, A, B, C, D], "llvm") foo = tvm.build(s, [x, A, B, C, D], "llvm")
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
...@@ -82,7 +82,7 @@ def test_combination(): ...@@ -82,7 +82,7 @@ def test_combination():
c = tvm.nd.array(np.random.uniform(size=(n, m)).astype(C.dtype), ctx) c = tvm.nd.array(np.random.uniform(size=(n, m)).astype(C.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx) d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
foo(x, a, b, c, d) foo(x, a, b, c, d)
tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() / x) tvm.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() + x)
def verify_tensor_scalar_bop(shape, typ="add"): def verify_tensor_scalar_bop(shape, typ="add"):
......
...@@ -17,13 +17,15 @@ ...@@ -17,13 +17,15 @@
import tvm import tvm
def test_simplify(): def test_simplify():
tdiv = tvm.truncdiv
tmod = tvm.truncmod
x = tvm.var('x') x = tvm.var('x')
e1 = tvm.ir_pass.Simplify(x + 2 + 1) e1 = tvm.ir_pass.Simplify(x + 2 + 1)
assert(tvm.ir_pass.Equal(e1, x + 3)) assert(tvm.ir_pass.Equal(e1, x + 3))
e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x) e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
assert(tvm.ir_pass.Equal(e2, x * 8)) assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) e3 = tvm.ir_pass.Simplify(x - tdiv(x, 3) * 3)
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) assert(tvm.ir_pass.Equal(e3, tmod(x, 3)))
def test_verify_ssa(): def test_verify_ssa():
......
...@@ -24,7 +24,7 @@ def test_equal_expr(): ...@@ -24,7 +24,7 @@ def test_equal_expr():
return x + y + 1 return x + y + 1
def func2(): def func2():
return tvm.exp((x + y + 1) * y / 4) return tvm.exp(tvm.truncdiv((x + y + 1) * y, 4))
assert tvm.ir_pass.Equal(func1(), func1()) assert tvm.ir_pass.Equal(func1(), func1())
assert tvm.ir_pass.Equal(func2(), func2()) assert tvm.ir_pass.Equal(func2(), func2())
......
...@@ -162,7 +162,7 @@ def test_condition(): ...@@ -162,7 +162,7 @@ def test_condition():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
with ib.for_range(0, ((n+3)/4), 'i') as i: with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i:
with ib.for_range(0, 4, 'j') as j: with ib.for_range(0, 4, 'j') as j:
ib.emit(tvm.make.Evaluate( ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(i*4+j<n), m, n))) tvm.make.Select(ib.likely(i*4+j<n), m, n)))
...@@ -206,7 +206,7 @@ def test_everything_during_deduction(): ...@@ -206,7 +206,7 @@ def test_everything_during_deduction():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
with ib.for_range(0, n, 'i') as i: with ib.for_range(0, n, 'i') as i:
with ib.for_range(0, 32, 'j') as j: with ib.for_range(0, 32, 'j') as j:
with ib.if_scope(ib.likely(i/j < m)): with ib.if_scope(ib.likely(tvm.truncdiv(i,j) < m)):
# this guard will produce everything during deduction # this guard will produce everything during deduction
ib.emit(tvm.make.Evaluate(m)) ib.emit(tvm.make.Evaluate(m))
stmt = ib.get() stmt = ib.get()
......
...@@ -111,9 +111,11 @@ def test_bound_fusesplit1(): ...@@ -111,9 +111,11 @@ def test_bound_fusesplit1():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split1) / l ).value == 0) idxdiv = tvm.indexdiv
assert(tvm.ir_pass.Simplify(
bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0)
expected_extent = (((xo + 1) * split1 - 1) / l - (xo * split1) / l + 1) expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1)
for i in range(1, 6): for i in range(1, 6):
for j in range(1, 6): for j in range(1, 6):
for k in range(1, 6): for k in range(1, 6):
...@@ -121,7 +123,7 @@ def test_bound_fusesplit1(): ...@@ -121,7 +123,7 @@ def test_bound_fusesplit1():
comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value
exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value
assert(comp_ext == exp_ext) assert(comp_ext == exp_ext)
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0) assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0)
def test_bound_fusesplit2(): def test_bound_fusesplit2():
...@@ -394,11 +396,11 @@ def test_bound_simplification_failure(): ...@@ -394,11 +396,11 @@ def test_bound_simplification_failure():
if not bounds[A.op.axis[0]].extent.value <= 2: if not bounds[A.op.axis[0]].extent.value <= 2:
print(stmt) print(stmt)
assert bounds[A.op.axis[0]].extent.value <= 2 assert bounds[A.op.axis[0]].extent.value <= 2
tdiv = tvm.truncdiv
# These are hard to simplify; moreover, we don't simplify them # These are hard to simplify; moreover, we don't simplify them
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)])) _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
_check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)])) _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
_check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)])) _check(tvm.compute((10,), lambda i: A[-2*tdiv(i,2) - tvm.min(i, 0-i)]))
_check(tvm.compute((10,), lambda i: A[i + (0 - i)])) _check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
# This would cause out of bounds, but we nevertheless include it # This would cause out of bounds, but we nevertheless include it
_check(tvm.compute((10,), lambda i: A[i])) _check(tvm.compute((10,), lambda i: A[i]))
......
...@@ -221,11 +221,14 @@ def test_tensorize_matmul(): ...@@ -221,11 +221,14 @@ def test_tensorize_matmul():
# This tests whether algorithm and intrinsics expressions are simplified # This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696 # as much as possible first and then checked for equality. See Issue #696
def test_tensorize_op(): def test_tensorize_op():
tdiv = tvm.truncdiv
tmod = tvm.truncmod
def op_intrin(): def op_intrin():
bh = 9 bh = 9
bw = 9 bw = 9
x = tvm.placeholder((5, 5), name='A') x = tvm.placeholder((5, 5), name='A')
y = tvm.compute((bh, bw), lambda i,j: x[j/3 + i%3, j%3+ i/3]) y = tvm.compute((bh, bw),
lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
def intrin_func(ins, outs): def intrin_func(ins, outs):
xx, = ins xx, = ins
...@@ -236,7 +239,7 @@ def test_tensorize_op(): ...@@ -236,7 +239,7 @@ def test_tensorize_op():
return tvm.decl_tensor_intrin(y.op, intrin_func) return tvm.decl_tensor_intrin(y.op, intrin_func)
A = tvm.placeholder((5, 5), name='A') A = tvm.placeholder((5, 5), name='A')
B = tvm.compute((9,9), lambda i, j: A[j/3 + i%3, j%3 + i/3]) B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
bt = op_intrin() bt = op_intrin()
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
......
...@@ -128,8 +128,13 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation, ...@@ -128,8 +128,13 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv') axis=[ci, kh, kw]), name='conv')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
output = tvm.compute(oshape, lambda n, co, h, w: output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], conv[n,
idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
name='output_unpack', tag='spatial_conv2d_output') name='output_unpack', tag='spatial_conv2d_output')
return output return output
......
...@@ -123,8 +123,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n ...@@ -123,8 +123,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
kernel_vec[co, ci, KH - 1 - kh, KW - 1 - kw, vc].astype(out_dtype), kernel_vec[co, ci, KH - 1 - kh, KW - 1 - kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv') axis=[ci, kh, kw]), name='conv')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
output = tvm.compute(oshape, lambda n, co, h, w: output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], conv[n,
idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
name='output_unpack', tag='spatial_conv2d_transpose_output') name='output_unpack', tag='spatial_conv2d_transpose_output')
return output return output
......
...@@ -293,21 +293,29 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, ...@@ -293,21 +293,29 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
kh = tvm.reduce_axis((0, KH), name='kh') kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw') kw = tvm.reduce_axis((0, KW), name='kw')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
if dilation_h != 1 or dilation_w != 1: if dilation_h != 1 or dilation_w != 1:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ conv = tvm.compute(
tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, kh, kw, vh, vw] ovshape, lambda n, co, h, w, vh, vw, vc: \
.astype(out_dtype) * tvm.sum(data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw]
kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype), .astype(out_dtype) *
axis=[kh, kw]), name='depthwise_conv') kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
axis=[kh, kw]), name='depthwise_conv')
else: else:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, vh * HSTR + kh, tvm.sum(data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh,
vw * WSTR + kw].astype(out_dtype) * vw * WSTR + kw].astype(out_dtype) *
kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype), kernel_vec[idxdiv(co, M),
idxmod(co, M),
kh, kw, vc].astype(out_dtype),
axis=[kh, kw]), name='depthwise_conv') axis=[kh, kw]), name='depthwise_conv')
output = tvm.compute(oshape, lambda n, co, h, w: output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], conv[n,
idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
name='output_unpack', tag='spatial_depthwise_conv_nchw_output') name='output_unpack', tag='spatial_depthwise_conv_nchw_output')
return output return output
......
...@@ -69,9 +69,11 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype): ...@@ -69,9 +69,11 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
[0, 0, (bpad_bottom + stride_h - 1) // stride_h, [0, 0, (bpad_bottom + stride_h - 1) // stride_h,
(bpad_right + stride_w - 1) // stride_w], name='FirstPad') (bpad_right + stride_w - 1) // stride_w], name='FirstPad')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# remove extra padding introduced by dilatation # remove extra padding introduced by dilatation
border_h = (stride_h - bpad_top % stride_h) % stride_h border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
border_w = (stride_w - bpad_left % stride_w) % stride_w border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
# dilation stage # dilation stage
data = FirstPad data = FirstPad
...@@ -83,8 +85,8 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype): ...@@ -83,8 +85,8 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
index_tuple = [] index_tuple = []
for i in range(n): for i in range(n):
if not equal_const_int(strides[i], 1): if not equal_const_int(strides[i], 1):
index_tuple.append(indices[i] // strides[i]) index_tuple.append(idxdiv(indices[i], strides[i]))
not_zero.append((indices[i] % strides[i]).equal(0)) not_zero.append(idxmod(indices[i], strides[i]).equal(0))
else: else:
index_tuple.append(indices[i]) index_tuple.append(indices[i])
if not_zero: if not_zero:
......
...@@ -85,10 +85,12 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty ...@@ -85,10 +85,12 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
else: else:
kernel_pack = kernel kernel_pack = kernel
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# pack input tile # pack input tile
input_tile = tvm.compute((CI, P, alpha, alpha), lambda c, p, eps, nu: input_tile = tvm.compute((CI, P, alpha, alpha), lambda c, p, eps, nu:
data_pad[p // (nH * nW)][c][p // nW % nH * m + eps] data_pad[idxdiv(p, (nH * nW))][c][idxmod(idxdiv(p, nW), nH) * m + eps]
[p % nW * m + nu], name='d') [idxmod(p, nW) * m + nu], name='d')
# transform data # transform data
r_a = tvm.reduce_axis((0, alpha), 'r_a') r_a = tvm.reduce_axis((0, alpha), 'r_a')
...@@ -113,7 +115,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty ...@@ -113,7 +115,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
# output # output
output = tvm.compute((N, CO, H, W), lambda n, co, h, w: output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
inverse[co][n * nH * nW + (h // m) * nW + w // m][h % m][w % m], inverse[co,
n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
idxmod(h, m),
idxmod(w, m)],
name='output', tag='conv2d_nchw_winograd') name='output', tag='conv2d_nchw_winograd')
cfg.add_flop(2 * N * CO * H * W * CI * KH * KW) cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
......
...@@ -245,7 +245,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx): ...@@ -245,7 +245,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
new_range = num_anchors // elem_per_thread + 1 new_range = num_anchors // elem_per_thread + 1
# Scan: Downsweep: # Scan: Downsweep:
with ib. if_scope(tid < batch_size * num_anchors): with ib. if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors # number of batches i = tid // num_anchors # number of batches
j = tid % num_anchors # number of anchors j = tid % num_anchors # number of anchors
with ib.if_scope(j < elem_per_thread): with ib.if_scope(j < elem_per_thread):
idx[tid] = idx_in[tid] idx[tid] = idx_in[tid]
...@@ -304,7 +304,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): ...@@ -304,7 +304,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
tid = bx * max_threads + tx tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors): with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors i = tid // num_anchors
j = tid % num_anchors j = tid % num_anchors
base_idx = i * num_anchors * elem_length base_idx = i * num_anchors * elem_length
with ib.if_scope(flag[tid] > 0): with ib.if_scope(flag[tid] > 0):
......
...@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score ...@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
tid = bx * max_threads + tx tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors): with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors i = tid // num_anchors
j = tid % num_anchors j = tid % num_anchors
with ib.if_scope(cls_id[tid] > 0): with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0): with ib.if_scope(tid == 0):
......
...@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt ...@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
tvm.sum(input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu], tvm.sum(input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu],
axis=[r_a, r_b]), name='V') axis=[r_a, r_b]), name='V')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# batch gemm # batch gemm
ci = tvm.reduce_axis((0, CI), name='c') ci = tvm.reduce_axis((0, CI), name='c')
M = tvm.compute((alpha, alpha, CO, P_round), lambda eps, nu, co, p: M = tvm.compute((alpha, alpha, CO, P_round), lambda eps, nu, co, p:
tvm.sum(U[eps][nu][co // bna][ci][co % bna] * tvm.sum(U[eps][nu][idxdiv(co, bna)][ci][idxmod(co, bna)] *
V[eps][nu][p // bnb][ci][p % bnb], axis=ci), name='M') V[eps][nu][idxdiv(p, bnb)][ci][idxmod(p, bnb)], axis=ci), name='M')
r_a = tvm.reduce_axis((0, alpha), 'r_a') r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_b') r_b = tvm.reduce_axis((0, alpha), 'r_b')
...@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt ...@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# unpack output # unpack output
output = tvm.compute((N, CO, H, W), lambda n, co, h, w: output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
Y[co][n * nH * nW + (h//m) * nW + w//m][h % m][w % m] Y[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
idxmod(h, m), idxmod(w, m)]
# The following hack term is used to make the padding in batch gemm ("M") # The following hack term is used to make the padding in batch gemm ("M")
# effective, otherwise the padding will be eliminated by bound inference. # effective, otherwise the padding will be eliminated by bound inference.
# Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding. # Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding.
......
...@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits, ...@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
axis=[ci, dh, dw, b1, b2]) axis=[ci, dh, dw, b1, b2])
conv = tvm.compute(ovshape, _conv, name='conv_out') conv = tvm.compute(ovshape, _conv, name='conv_out')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
return tvm.compute(oshape, lambda n, co, h, w: return tvm.compute(
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], oshape, lambda n, co, h, w:
name='conv_vec', tag='spatial_bitserial_conv_nchw') conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv(
w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
name='conv_vec', tag='spatial_bitserial_conv_nchw')
@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct') @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits, def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
...@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits, ...@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
conv = tvm.compute(ovshape, _conv, name='conv') conv = tvm.compute(ovshape, _conv, name='conv')
return tvm.compute(oshape, lambda n, h, w, co: idxdiv = tvm.indexdiv
conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC], idxmod = tvm.indexmod
name='output_unpack', tag='spatial_bitserial_conv_nhwc') return tvm.compute(
oshape, lambda n, h, w, co:
conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv(
co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
@tvm.target.generic_func @tvm.target.generic_func
def bitserial_conv2d_legalize(attrs, inputs, types): def bitserial_conv2d_legalize(attrs, inputs, types):
......
...@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp ...@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k: weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
weight_packed[xo*VX+vx][wb][k], name='weight_vec') weight_packed[xo*VX+vx][wb][k], name='weight_vec')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum( matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
(tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]) - (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
tvm.popcount(~weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k])).astype(out_dtype) tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar') << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
matmul = tvm.compute(oshape, lambda i, j: tvm.sum( matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
tvm.popcount(weight_vec[j//VX, wb, j%VX, k] & data_packed[i, db, k]).astype(out_dtype) tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
).astype(out_dtype)
<< (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense') << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
# binary ops # binary ops
......
...@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l ...@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
kh = tvm.reduce_axis((0, kernel_height), name='kh') kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw') kw = tvm.reduce_axis((0, kernel_width), name='kw')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, tvm.sum(data_pad[n,
ic // ic_bn, idxdiv(ic, ic_bn),
oh * HSTR + kh * dilation_h, oh * HSTR + kh * dilation_h,
ow * WSTR + kw * dilation_w, ow * WSTR + kw * dilation_w,
ic % ic_bn].astype(out_dtype) idxmod(ic, ic_bn)].astype(out_dtype)
* kernel[oc_chunk, * kernel[oc_chunk,
ic // ic_bn, idxdiv(ic, ic_bn),
kh, kh,
kw, kw,
ic % ic_bn, idxmod(ic, ic_bn),
oc_block], oc_block],
axis=[ic, kh, kw]), axis=[ic, kh, kw]),
name='conv2d_NCHWc', tag="conv2d_NCHWc") name='conv2d_NCHWc', tag="conv2d_NCHWc")
......
...@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No ...@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after = [0, 0, pad_down, pad_right] pad_after = [0, 0, pad_down, pad_right]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage # depthconv stage
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
di = tvm.reduce_axis((0, filter_height), name='di') di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj') dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute( Output = tvm.compute(
(batch, out_channel, out_height, out_width), (batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum( lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h+di*dilation_h, (PaddedInput[b, idxdiv(c, channel_multiplier), i*stride_h+di*dilation_h,
j*stride_w+dj*dilation_w].astype(out_dtype) * j*stride_w+dj*dilation_w].astype(out_dtype) *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)), Filter[idxdiv(c, channel_multiplier),
idxmod(c, channel_multiplier), di, dj].astype(out_dtype)),
axis=[di, dj]), axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw") name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output return Output
...@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No ...@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after = [0, pad_down, pad_right, 0] pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage # depthconv stage
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
di = tvm.reduce_axis((0, filter_height), name='di') di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj') dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute( Output = tvm.compute(
(batch, out_height, out_width, out_channel), (batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum( lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w, (PaddedInput[b, i*stride_h + di*dilation_h, j*stride_w + dj*dilation_w,
c/channel_multiplier].astype(out_dtype) * idxdiv(c, channel_multiplier)].astype(out_dtype) *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)), Filter[di, dj,
idxdiv(c, channel_multiplier),
idxmod(c, channel_multiplier)].astype(out_dtype)),
axis=[di, dj]), axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc") name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output return Output
...@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid ...@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
dh = tvm.reduce_axis((0, Out_grad.shape[1].value), name='dh') dh = tvm.reduce_axis((0, Out_grad.shape[1].value), name='dh')
dw = tvm.reduce_axis((0, Out_grad.shape[2].value), name='dw') dw = tvm.reduce_axis((0, Out_grad.shape[2].value), name='dw')
db = tvm.reduce_axis((0, batch), name='db') db = tvm.reduce_axis((0, batch), name='db')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
Weight_grad = tvm.compute( Weight_grad = tvm.compute(
(filter_h, filter_w, in_c, channel_multiplier), (filter_h, filter_w, in_c, channel_multiplier),
lambda fh, fw, c, m: tvm.sum( lambda fh, fw, c, m: tvm.sum(
Out_grad[db, dh, dw, c*channel_multiplier+m%channel_multiplier] * Out_grad[db, dh, dw, c*channel_multiplier+idxmod(m, channel_multiplier)] *
padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]), padded_in[db, fh+dh*stride_h, fw+dw*stride_w, c], axis=[db, dh, dw]),
tag='depthwise_conv2d_backward_weight_nhwc') tag='depthwise_conv2d_backward_weight_nhwc')
......
...@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"): ...@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"):
def _dilate(*indices): def _dilate(*indices):
not_zero = [] not_zero = []
index_tuple = [] index_tuple = []
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
for i in range(n): for i in range(n):
if not util.equal_const_int(strides[i], 1): if not util.equal_const_int(strides[i], 1):
index_tuple.append(indices[i] / strides[i]) index_tuple.append(idxdiv(indices[i], strides[i]))
not_zero.append((indices[i] % strides[i]).equal(0)) not_zero.append(idxmod(indices[i], strides[i]).equal(0))
else: else:
index_tuple.append(indices[i]) index_tuple.append(indices[i])
if not_zero: if not_zero:
......
...@@ -38,12 +38,14 @@ def flatten(data): ...@@ -38,12 +38,14 @@ def flatten(data):
for i in range(1, len(ishape)): for i in range(1, len(ishape)):
dim = dim * ishape[i] dim = dim * ishape[i]
oshape = [ishape[0], dim] oshape = [ishape[0], dim]
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
def unwrap(idx, shape): def unwrap(idx, shape):
index = [] index = []
for s in reversed(shape): for s in reversed(shape):
index.append(idx % s) index.append(idxmod(idx, s))
idx = idx / s idx = idxdiv(idx, s)
return list(reversed(index)) return list(reversed(index))
return tvm.compute(oshape, lambda i, j: data(i, *unwrap(j, ishape[1:]))) return tvm.compute(oshape, lambda i, j: data(i, *unwrap(j, ishape[1:])))
...@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout ...@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
ic = tvm.reduce_axis((0, in_channel), name='ic') ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh') kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw') kw = tvm.reduce_axis((0, kernel_width), name='kw')
idxmod = tvm.indexmod
idxdiv = tvm.indexdiv
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn, tvm.sum(data_vec[n, idxdiv(ic, ic_bn), oh*HSTR+kh*dilation_h,
idxmod(ic, ic_bn),
ow*WSTR+kw*dilation_w].astype(out_dtype) * ow*WSTR+kw*dilation_w].astype(out_dtype) *
kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kh, kw,
idxmod(ic, ic_bn),
oc_block].astype(out_dtype), oc_block].astype(out_dtype),
axis=[ic, kh, kw]), name='conv') axis=[ic, kh, kw]), name='conv')
unpack = tvm.compute(unpack_shape, unpack = tvm.compute(unpack_shape,
lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn] lambda n, c, h, w: conv[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
.astype(out_dtype), .astype(out_dtype),
name='output_unpack', name='output_unpack',
tag='conv2d_nchw') tag='conv2d_nchw')
...@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs): ...@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
cfg = get_config() cfg = get_config()
_create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout) _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# change shape with the value in config # change shape with the value in config
ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
cfg["tile_ow"].size[-1]) cfg["tile_ow"].size[-1])
new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn, new_data_shape = (raw_data_shape[0], idxdiv(raw_data_shape[1], ic_bn),
raw_data_shape[2], raw_data_shape[3], ic_bn) raw_data_shape[2], raw_data_shape[3], ic_bn)
data_layout = "NCHW%dc" % ic_bn data_layout = "NCHW%dc" % ic_bn
out_layout = "NCHW%dc" % oc_bn out_layout = "NCHW%dc" % oc_bn
new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn, new_kernel_shape = (idxdiv(raw_kernel_shape[0], oc_bn),
idxdiv(raw_kernel_shape[1], ic_bn),
raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn) raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
new_data = tvm.placeholder(new_data_shape, data.dtype) new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype) new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
...@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg): ...@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg):
_, data, kernel, strides, padding, dilation, layout, dtype = workload _, data, kernel, strides, padding, dilation, layout, dtype = workload
batch_size, in_channel, in_height, in_width = data[:-1] batch_size, in_channel, in_height, in_width = data[:-1]
out_channel, _, k_height, k_width = kernel[:-1] out_channel, _, k_height, k_width = kernel[:-1]
out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1 idxdiv = tvm.indexdiv
out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
out_height = idxdiv(in_height + 2 * padding[0] - k_height, strides[0]) + 1
out_width = idxdiv(in_width + 2 * padding[1] - k_width, strides[1]) + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic) in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
in_layout = "NCHW%dc" % tile_ic in_layout = "NCHW%dc" % tile_ic
out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc) out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
out_layout = "NCHW%dc" % tile_oc out_layout = "NCHW%dc" % tile_oc
return ((in_shape, in_layout),), ((out_shape, out_layout),) return ((in_shape, in_layout),), ((out_shape, out_layout),)
......
...@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None): ...@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
packw = tvm.compute(packw_shape, packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight") lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
k = tvm.reduce_axis((0, K), name="k") k = tvm.reduce_axis((0, K), name="k")
C = tvm.compute((M, N), C = tvm.compute((M, N),
lambda y, x: tvm.sum( lambda y, x: tvm.sum(
data[y, k].astype(out_dtype) * data[y, k].astype(out_dtype) *
packw[x // packw_bn, k, x % packw_bn].astype(out_dtype), packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
axis=k), axis=k),
tag="dense_pack") tag="dense_pack")
if bias is not None: if bias is not None:
......
...@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation, ...@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
data_pad = data data_pad = data
# depthconv stage # depthconv stage
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
kh = tvm.reduce_axis((0, filter_height), name='kh') kh = tvm.reduce_axis((0, filter_height), name='kh')
kw = tvm.reduce_axis((0, filter_width), name='kw') kw = tvm.reduce_axis((0, filter_width), name='kw')
Output = tvm.compute( Output = tvm.compute(
(batch, out_channel_chunk, out_height, out_width, out_channel_block), (batch, out_channel_chunk, out_height, out_width, out_channel_block),
lambda b, oco, oh, ow, oci: tvm.sum( lambda b, oco, oh, ow, oci: tvm.sum(
(data_pad[b, (oco * out_channel_block + oci) // channel_multiplier // in_channel_block, (data_pad[
oh*HSTR+kh, ow*WSTR+kw, b,
((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block] idxdiv(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block),
oh*HSTR+kh, ow*WSTR+kw,
idxmod(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block)]
.astype(out_dtype) * .astype(out_dtype) *
kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)), kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
axis=[kh, kw]), axis=[kh, kw]),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment