Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2ded2d8c
Unverified
Commit
2ded2d8c
authored
Sep 27, 2019
by
Tianqi Chen
Committed by
GitHub
Sep 27, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Use explicit div mode in python. (#4014)
parent
16bed7e6
Hide whitespace changes
Inline
Side-by-side
Showing
46 changed files
with
481 additions
and
192 deletions
+481
-192
docs/api/python/tvm.rst
+14
-0
python/tvm/api.py
+71
-0
python/tvm/contrib/nnpack.py
+8
-4
python/tvm/expr.py
+29
-4
python/tvm/generic.py
+18
-1
python/tvm/hybrid/parser.py
+15
-2
python/tvm/relay/op/_transform.py
+4
-4
src/api/api_ir.cc
+2
-0
src/contrib/hybrid/codegen_hybrid.cc
+15
-2
src/contrib/hybrid/codegen_hybrid.h
+6
-4
src/lang/expr_operator.cc
+4
-0
src/pass/lower_intrin.cc
+0
-3
tests/python/unittest/test_arith_canonical_simplify.py
+50
-40
tests/python/unittest/test_arith_const_int_bound.py
+21
-14
tests/python/unittest/test_arith_deduce_bound.py
+6
-4
tests/python/unittest/test_arith_intset.py
+6
-4
tests/python/unittest/test_arith_modular_set.py
+16
-11
tests/python/unittest/test_autotvm_flop_calculator.py
+7
-2
tests/python/unittest/test_build_lower.py
+1
-1
tests/python/unittest/test_codegen_llvm.py
+4
-3
tests/python/unittest/test_ir_builder.py
+4
-2
tests/python/unittest/test_lang_buffer.py
+18
-11
tests/python/unittest/test_lang_operator.py
+8
-4
tests/python/unittest/test_lang_tensor_overload_op.py
+2
-2
tests/python/unittest/test_pass_basic.py
+4
-2
tests/python/unittest/test_pass_equal.py
+1
-1
tests/python/unittest/test_pass_loop_partition.py
+2
-2
tests/python/unittest/test_schedule_bound_inference.py
+7
-5
tests/python/unittest/test_schedule_tensorize.py
+5
-2
topi/python/topi/arm_cpu/conv2d_spatial_pack.py
+6
-1
topi/python/topi/arm_cpu/conv2d_transpose.py
+6
-1
topi/python/topi/arm_cpu/depthwise_conv2d.py
+16
-8
topi/python/topi/cuda/conv2d_transpose_nchw.py
+6
-4
topi/python/topi/cuda/conv2d_winograd.py
+8
-3
topi/python/topi/cuda/nms.py
+2
-2
topi/python/topi/cuda/ssd/multibox.py
+1
-1
topi/python/topi/mali/conv2d.py
+7
-3
topi/python/topi/nn/bitserial_conv2d.py
+14
-6
topi/python/topi/nn/bitserial_dense.py
+8
-3
topi/python/topi/nn/conv2d.py
+7
-4
topi/python/topi/nn/depthwise_conv2d.py
+15
-5
topi/python/topi/nn/dilate.py
+4
-2
topi/python/topi/nn/flatten.py
+4
-2
topi/python/topi/x86/conv2d.py
+18
-9
topi/python/topi/x86/dense.py
+3
-1
topi/python/topi/x86/depthwise_conv2d.py
+8
-3
No files found.
docs/api/python/tvm.rst
View file @
2ded2d8c
...
@@ -35,6 +35,13 @@ The user facing API for computation declaration.
...
@@ -35,6 +35,13 @@ The user facing API for computation declaration.
tvm.thread_axis
tvm.thread_axis
tvm.comm_reducer
tvm.comm_reducer
tvm.sum
tvm.sum
tvm.div
tvm.indexdiv
tvm.indexmod
tvm.truncdiv
tvm.truncmod
tvm.floordiv
tvm.floormod
tvm.min
tvm.min
tvm.max
tvm.max
tvm.tag_scope
tvm.tag_scope
...
@@ -53,6 +60,13 @@ The user facing API for computation declaration.
...
@@ -53,6 +60,13 @@ The user facing API for computation declaration.
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.sum
.. autofunction:: tvm.sum
.. autofunction:: tvm.div
.. autofunction:: tvm.indexdiv
.. autofunction:: tvm.indexmod
.. autofunction:: tvm.truncdiv
.. autofunction:: tvm.truncmod
.. autofunction:: tvm.floordiv
.. autofunction:: tvm.floormod
.. autofunction:: tvm.min
.. autofunction:: tvm.min
.. autofunction:: tvm.max
.. autofunction:: tvm.max
.. autofunction:: tvm.tag_scope
.. autofunction:: tvm.tag_scope
python/tvm/api.py
View file @
2ded2d8c
...
@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
...
@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
reducer
.
__doc__
=
doc_str
.
format
(
name
)
reducer
.
__doc__
=
doc_str
.
format
(
name
)
return
reducer
return
reducer
def
div
(
a
,
b
):
"""Compute a / b as in C/C++ semantics.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
When operands are integers, returns truncdiv(a, b).
"""
return
_make
.
_OpDiv
(
a
,
b
)
def
indexdiv
(
a
,
b
):
"""Compute floor(a / b) where a and b are non-negative.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of operands'
non-negativeness.
"""
return
_make
.
_OpIndexDiv
(
a
,
b
)
def
indexmod
(
a
,
b
):
"""Compute the remainder of indexdiv. a and b are non-negative.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of operands'
non-negativeness.
"""
return
_make
.
_OpIndexMod
(
a
,
b
)
def
truncdiv
(
a
,
b
):
def
truncdiv
(
a
,
b
):
"""Compute the truncdiv of two expressions.
"""Compute the truncdiv of two expressions.
...
...
python/tvm/contrib/nnpack.py
View file @
2ded2d8c
...
@@ -101,8 +101,11 @@ def convolution_inference(
...
@@ -101,8 +101,11 @@ def convolution_inference(
assert
isinstance
(
stride
,
list
)
and
len
(
stride
)
==
2
assert
isinstance
(
stride
,
list
)
and
len
(
stride
)
==
2
batch
,
_
,
input_height
,
input_width
=
data
.
shape
batch
,
_
,
input_height
,
input_width
=
data
.
shape
output_channels
,
_
,
kernel_height
,
kernel_width
=
kernel
.
shape
output_channels
,
_
,
kernel_height
,
kernel_width
=
kernel
.
shape
output_height
=
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
)
/
stride
[
0
]
+
1
idxdiv
=
_api
.
indexdiv
output_width
=
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
)
/
stride
[
1
]
+
1
output_height
=
idxdiv
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
,
stride
[
0
])
+
1
output_width
=
idxdiv
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
,
stride
[
1
])
+
1
return
_api
.
extern
(
return
_api
.
extern
(
(
batch
,
output_channels
,
output_height
,
output_width
),
(
batch
,
output_channels
,
output_height
,
output_width
),
...
@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform(
...
@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform(
batch
,
_
,
input_height
,
input_width
=
data
.
shape
batch
,
_
,
input_height
,
input_width
=
data
.
shape
output_channels
,
_
,
_
,
_
=
transformed_kernel
.
shape
output_channels
,
_
,
_
,
_
=
transformed_kernel
.
shape
kernel_height
,
kernel_width
=
(
3
,
3
)
kernel_height
,
kernel_width
=
(
3
,
3
)
output_height
=
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
)
/
stride
[
0
]
+
1
idxdiv
=
_api
.
indexdiv
output_width
=
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
)
/
stride
[
1
]
+
1
output_height
=
idxdiv
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
,
stride
[
0
])
+
1
output_width
=
idxdiv
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
,
stride
[
1
])
+
1
return
_api
.
extern
(
return
_api
.
extern
(
(
batch
,
output_channels
,
output_height
,
output_width
),
(
batch
,
output_channels
,
output_height
,
output_width
),
...
...
python/tvm/expr.py
View file @
2ded2d8c
...
@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node.
...
@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node.
# pylint: disable=missing-docstring
# pylint: disable=missing-docstring
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
._ffi.node
import
NodeBase
,
NodeGeneric
,
register_node
from
._ffi.node
import
NodeBase
,
NodeGeneric
,
register_node
from
._ffi.runtime_ctypes
import
TVMType
,
TypeCode
from
.
import
make
as
_make
from
.
import
make
as
_make
from
.
import
generic
as
_generic
from
.
import
generic
as
_generic
from
.
import
_api_internal
from
.
import
_api_internal
def
div_ambiguity_error
():
return
RuntimeError
(
"TVM supports multiple types of integer divisions, "
+
"please call div, indexdiv/indexmod, floordiv/floormod "
+
" or truncdiv/truncmod directly to avoid ambiguity in the code."
)
def
_dtype_is_int
(
value
):
if
isinstance
(
value
,
int
):
return
True
return
(
isinstance
(
value
,
ExprOp
)
and
TVMType
(
value
.
dtype
)
.
type_code
==
TypeCode
.
INT
)
class
ExprOp
(
object
):
class
ExprOp
(
object
):
def
__add__
(
self
,
other
):
def
__add__
(
self
,
other
):
return
_generic
.
add
(
self
,
other
)
return
_generic
.
add
(
self
,
other
)
...
@@ -58,24 +72,35 @@ class ExprOp(object):
...
@@ -58,24 +72,35 @@ class ExprOp(object):
return
_generic
.
multiply
(
other
,
self
)
return
_generic
.
multiply
(
other
,
self
)
def
__div__
(
self
,
other
):
def
__div__
(
self
,
other
):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
self
,
other
)
return
_generic
.
divide
(
self
,
other
)
def
__rdiv__
(
self
,
other
):
def
__rdiv__
(
self
,
other
):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
other
,
self
)
return
_generic
.
divide
(
other
,
self
)
def
__truediv__
(
self
,
other
):
def
__truediv__
(
self
,
other
):
return
self
.
__div__
(
other
)
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
self
,
other
)
def
__rtruediv__
(
self
,
other
):
def
__rtruediv__
(
self
,
other
):
return
self
.
__rdiv__
(
other
)
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
other
,
self
)
def
__floordiv__
(
self
,
other
):
def
__floordiv__
(
self
,
other
):
return
self
.
__div__
(
other
)
# return _generic.floordiv(self, other)
return
_generic
.
divide
(
self
,
other
)
def
__rfloordiv__
(
self
,
other
):
def
__rfloordiv__
(
self
,
other
):
return
self
.
__rdiv__
(
other
)
# return _generic.floordiv(other, self)
return
_generic
.
divide
(
other
,
self
)
def
__mod__
(
self
,
other
):
def
__mod__
(
self
,
other
):
# raise div_ambiguity_error()
return
_make
.
_OpMod
(
self
,
other
)
return
_make
.
_OpMod
(
self
,
other
)
def
__neg__
(
self
):
def
__neg__
(
self
):
...
...
python/tvm/generic.py
View file @
2ded2d8c
...
@@ -25,6 +25,7 @@ from . import make as _make
...
@@ -25,6 +25,7 @@ from . import make as _make
#Operator precedence used when overloading.
#Operator precedence used when overloading.
__op_priority__
=
0
__op_priority__
=
0
def
add
(
lhs
,
rhs
):
def
add
(
lhs
,
rhs
):
"""Generic add operator.
"""Generic add operator.
...
@@ -78,7 +79,6 @@ def multiply(lhs, rhs):
...
@@ -78,7 +79,6 @@ def multiply(lhs, rhs):
"""
"""
return
_make
.
_OpMul
(
lhs
,
rhs
)
return
_make
.
_OpMul
(
lhs
,
rhs
)
def
divide
(
lhs
,
rhs
):
def
divide
(
lhs
,
rhs
):
"""Generic divide operator.
"""Generic divide operator.
...
@@ -96,6 +96,23 @@ def divide(lhs, rhs):
...
@@ -96,6 +96,23 @@ def divide(lhs, rhs):
"""
"""
return
_make
.
_OpDiv
(
lhs
,
rhs
)
return
_make
.
_OpDiv
(
lhs
,
rhs
)
def
floordiv
(
lhs
,
rhs
):
"""Generic floordiv operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return
_make
.
_OpFloorDiv
(
lhs
,
rhs
)
def
cast
(
src
,
dtype
):
def
cast
(
src
,
dtype
):
"""Generic cast operator.
"""Generic cast operator.
...
...
python/tvm/hybrid/parser.py
View file @
2ded2d8c
...
@@ -31,6 +31,7 @@ from . import util
...
@@ -31,6 +31,7 @@ from . import util
from
.preprocessor
import
determine_variable_usage
from
.preprocessor
import
determine_variable_usage
from
..api
import
all
as
_all
from
..api
import
all
as
_all
from
..api
import
any
as
_any
from
..api
import
any
as
_any
from
..container
import
Array
from
..container
import
Array
from
..tensor
import
Tensor
,
Operation
from
..tensor
import
Tensor
,
Operation
from
..
import
_api_internal
as
_tvm_internal
from
..
import
_api_internal
as
_tvm_internal
...
@@ -78,6 +79,18 @@ class Symbol(Enum):
...
@@ -78,6 +79,18 @@ class Symbol(Enum):
ThreadBind
=
10
ThreadBind
=
10
def
_floordiv
(
x
,
y
):
if
isinstance
(
x
,
_expr
.
ExprOp
)
or
isinstance
(
y
,
_expr
.
ExprOp
):
return
_api
.
floordiv
(
x
,
y
)
return
operator
.
floordiv
(
x
,
y
)
def
_floormod
(
x
,
y
):
if
isinstance
(
x
,
_expr
.
ExprOp
)
or
isinstance
(
y
,
_expr
.
ExprOp
):
return
_api
.
floormod
(
x
,
y
)
return
operator
.
mod
(
x
,
y
)
class
HybridParser
(
ast
.
NodeVisitor
):
class
HybridParser
(
ast
.
NodeVisitor
):
"""Python AST visitor pass which finally lowers it to HalideIR"""
"""Python AST visitor pass which finally lowers it to HalideIR"""
...
@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor):
...
@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor):
ast
.
Sub
:
operator
.
sub
,
ast
.
Sub
:
operator
.
sub
,
ast
.
Mult
:
operator
.
mul
,
ast
.
Mult
:
operator
.
mul
,
ast
.
Div
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
truediv
,
ast
.
Div
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
truediv
,
ast
.
FloorDiv
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
true
div
,
ast
.
FloorDiv
:
_floor
div
,
ast
.
Mod
:
operator
.
mod
,
ast
.
Mod
:
_floor
mod
,
ast
.
BitOr
:
operator
.
or_
,
ast
.
BitOr
:
operator
.
or_
,
ast
.
BitAnd
:
operator
.
and_
,
ast
.
BitAnd
:
operator
.
and_
,
ast
.
BitXor
:
operator
.
xor
,
ast
.
BitXor
:
operator
.
xor
,
...
...
python/tvm/relay/op/_transform.py
View file @
2ded2d8c
...
@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
...
@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
@script
@script
def
_arange_shape_func
(
start
,
stop
,
step
):
def
_arange_shape_func
(
start
,
stop
,
step
):
out
=
output_tensor
((
1
,),
"int64"
)
out
=
output_tensor
((
1
,),
"int64"
)
out
[
0
]
=
int64
(
ceil_div
((
float32
(
stop
[
0
])
-
float32
(
start
[
0
])),
float32
(
step
[
0
])))
out
[
0
]
=
int64
(
ceil_div
((
int64
(
stop
[
0
])
-
int64
(
start
[
0
])),
int64
(
step
[
0
])))
return
out
return
out
@_reg.register_shape_func
(
"arange"
,
True
)
@_reg.register_shape_func
(
"arange"
,
True
)
...
@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim):
...
@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim):
assert
len
(
newshape
)
-
i
>
2
,
"Not enough dims in new shape for -4"
assert
len
(
newshape
)
-
i
>
2
,
"Not enough dims in new shape for -4"
if
newshape
[
i
+
1
]
==
-
1
:
if
newshape
[
i
+
1
]
==
-
1
:
assert
newshape
[
i
+
2
]
!=
-
1
,
"Split dims cannot both be -1."
assert
newshape
[
i
+
2
]
!=
-
1
,
"Split dims cannot both be -1."
out
[
dst_idx
]
=
data_shape
[
src_idx
]
/
int64
(
newshape
[
i
+
2
])
out
[
dst_idx
]
=
data_shape
[
src_idx
]
/
/
int64
(
newshape
[
i
+
2
])
out
[
dst_idx
+
1
]
=
int64
(
newshape
[
i
+
2
])
out
[
dst_idx
+
1
]
=
int64
(
newshape
[
i
+
2
])
else
:
else
:
out
[
dst_idx
]
=
int64
(
newshape
[
i
+
1
])
out
[
dst_idx
]
=
int64
(
newshape
[
i
+
1
])
if
newshape
[
i
+
2
]
==
-
1
:
if
newshape
[
i
+
2
]
==
-
1
:
out
[
dst_idx
+
1
]
=
data_shape
[
src_idx
]
/
int64
(
newshape
[
i
+
1
])
out
[
dst_idx
+
1
]
=
data_shape
[
src_idx
]
/
/
int64
(
newshape
[
i
+
1
])
else
:
else
:
out
[
dst_idx
+
1
]
=
int64
(
newshape
[
i
+
2
])
out
[
dst_idx
+
1
]
=
int64
(
newshape
[
i
+
2
])
assert
data_shape
[
src_idx
]
==
out
[
dst_idx
]
*
out
[
dst_idx
+
1
],
\
assert
data_shape
[
src_idx
]
==
out
[
dst_idx
]
*
out
[
dst_idx
+
1
],
\
...
@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim):
...
@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim):
new_size
=
int64
(
1
)
new_size
=
int64
(
1
)
for
i
in
const_range
(
out
.
shape
[
0
]):
for
i
in
const_range
(
out
.
shape
[
0
]):
new_size
*=
out
[
i
]
new_size
*=
out
[
i
]
out
[
infer_idx
]
=
old_size
/
new_size
out
[
infer_idx
]
=
old_size
/
/
new_size
return
out
return
out
@_reg.register_shape_func
(
"reshape"
,
False
)
@_reg.register_shape_func
(
"reshape"
,
False
)
...
...
src/api/api_ir.cc
View file @
2ded2d8c
...
@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
...
@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP
(
_OpMul
,
operator
*
);
REGISTER_MAKE_BINARY_OP
(
_OpMul
,
operator
*
);
REGISTER_MAKE_BINARY_OP
(
_OpDiv
,
div
);
REGISTER_MAKE_BINARY_OP
(
_OpDiv
,
div
);
REGISTER_MAKE_BINARY_OP
(
_OpMod
,
truncmod
);
REGISTER_MAKE_BINARY_OP
(
_OpMod
,
truncmod
);
REGISTER_MAKE_BINARY_OP
(
_OpIndexDiv
,
indexdiv
);
REGISTER_MAKE_BINARY_OP
(
_OpIndexMod
,
indexmod
);
REGISTER_MAKE_BINARY_OP
(
_OpFloorDiv
,
floordiv
);
REGISTER_MAKE_BINARY_OP
(
_OpFloorDiv
,
floordiv
);
REGISTER_MAKE_BINARY_OP
(
_OpFloorMod
,
floormod
);
REGISTER_MAKE_BINARY_OP
(
_OpFloorMod
,
floormod
);
REGISTER_MAKE_BINARY_OP
(
_OpTruncDiv
,
truncdiv
);
REGISTER_MAKE_BINARY_OP
(
_OpTruncDiv
,
truncdiv
);
...
...
src/contrib/hybrid/codegen_hybrid.cc
View file @
2ded2d8c
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
...
@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
void
CodeGenHybrid
::
VisitExpr_
(
const
Mul
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
void
CodeGenHybrid
::
VisitExpr_
(
const
Mul
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"*"
,
os
,
this
);
PrintBinaryExpr
(
op
,
"*"
,
os
,
this
);
}
}
void
CodeGenHybrid
::
VisitExpr_
(
const
Div
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
void
CodeGenHybrid
::
VisitExpr_
(
const
Div
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
if
(
op
->
type
.
is_int
())
if
(
op
->
type
.
is_int
())
PrintBinaryExpr
(
op
,
"//"
,
os
,
this
);
PrintBinaryExpr
(
op
,
"//"
,
os
,
this
);
else
else
PrintBinaryExpr
(
op
,
"/"
,
os
,
this
);
PrintBinaryExpr
(
op
,
"/"
,
os
,
this
);
}
}
void
CodeGenHybrid
::
VisitExpr_
(
const
FloorDiv
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
if
(
op
->
type
.
is_int
())
PrintBinaryExpr
(
op
,
"//"
,
os
,
this
);
else
PrintBinaryExpr
(
op
,
"/"
,
os
,
this
);
}
void
CodeGenHybrid
::
VisitExpr_
(
const
Mod
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
void
CodeGenHybrid
::
VisitExpr_
(
const
Mod
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"%"
,
os
,
this
);
PrintBinaryExpr
(
op
,
"%"
,
os
,
this
);
}
}
void
CodeGenHybrid
::
VisitExpr_
(
const
FloorMod
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"%"
,
os
,
this
);
}
void
CodeGenHybrid
::
VisitExpr_
(
const
Min
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
void
CodeGenHybrid
::
VisitExpr_
(
const
Min
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"min"
,
os
,
this
);
PrintBinaryExpr
(
op
,
"min"
,
os
,
this
);
}
}
...
...
src/contrib/hybrid/codegen_hybrid.h
View file @
2ded2d8c
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -100,6 +100,8 @@ class CodeGenHybrid :
...
@@ -100,6 +100,8 @@ class CodeGenHybrid :
void
VisitExpr_
(
const
Mul
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Mul
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Div
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Div
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Mod
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Mod
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
FloorDiv
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
FloorMod
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Min
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Min
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Max
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Max
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
EQ
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
EQ
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
...
@@ -161,12 +163,12 @@ class CodeGenHybrid :
...
@@ -161,12 +163,12 @@ class CodeGenHybrid :
std
::
string
GetUniqueName
(
std
::
string
prefix
);
std
::
string
GetUniqueName
(
std
::
string
prefix
);
/*! \brief The output code string builder. */
/*! \brief The output code string builder. */
std
::
stringstream
stream
;
std
::
stringstream
stream
;
/*!
/*!
* \brief Get or allocate the ID for the given variable.
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
* \param v The given variable.
*/
*/
std
::
string
GetVarID
(
const
Variable
*
v
);
std
::
string
GetVarID
(
const
Variable
*
v
);
/*!
/*!
* \brief Get or allocate the ID for the given tensor.
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
* \param func The tensor to allocate a name.
* \param value_index The value index of the given tensor.
* \param value_index The value index of the given tensor.
...
...
src/lang/expr_operator.cc
View file @
2ded2d8c
...
@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) {
...
@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) {
}
}
Expr
floordiv
(
Expr
a
,
Expr
b
)
{
Expr
floordiv
(
Expr
a
,
Expr
b
)
{
CHECK
(
a
.
type
().
is_int
()
||
a
.
type
().
is_uint
());
CHECK
(
b
.
type
().
is_int
()
||
b
.
type
().
is_uint
());
BinaryOpMatchTypes
(
a
,
b
);
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
FloorDiv
>
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
FloorDiv
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
if
(
ret
.
defined
())
return
ret
;
...
@@ -223,6 +225,8 @@ Expr floordiv(Expr a, Expr b) {
...
@@ -223,6 +225,8 @@ Expr floordiv(Expr a, Expr b) {
}
}
Expr
floormod
(
Expr
a
,
Expr
b
)
{
Expr
floormod
(
Expr
a
,
Expr
b
)
{
CHECK
(
a
.
type
().
is_int
()
||
a
.
type
().
is_uint
());
CHECK
(
b
.
type
().
is_int
()
||
b
.
type
().
is_uint
());
BinaryOpMatchTypes
(
a
,
b
);
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
FloorMod
>
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
FloorMod
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
if
(
ret
.
defined
())
return
ret
;
...
...
src/pass/lower_intrin.cc
View file @
2ded2d8c
...
@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
...
@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if
(
op
==
nullptr
)
return
ret
;
if
(
op
==
nullptr
)
return
ret
;
int
shift
;
int
shift
;
const
DataType
&
dtype
=
op
->
type
;
const
DataType
&
dtype
=
op
->
type
;
if
(
dtype
.
is_float
())
{
return
floor
(
Div
::
make
(
op
->
a
,
op
->
b
));
}
CHECK
(
dtype
.
is_int
()
||
!
dtype
.
is_uint
());
CHECK
(
dtype
.
is_int
()
||
!
dtype
.
is_uint
());
if
(
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
if
(
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
...
...
tests/python/unittest/test_arith_canonical_simplify.py
View file @
2ded2d8c
...
@@ -33,9 +33,11 @@ def test_mul_sum_simplify():
...
@@ -33,9 +33,11 @@ def test_mul_sum_simplify():
x
*
13
+
z
*
4
+
y
*
4
+
6
)
x
*
13
+
z
*
4
+
y
*
4
+
6
)
ck
.
verify
(
x
*
3
-
4
*
x
+
1
,
1
-
x
)
ck
.
verify
(
x
*
3
-
4
*
x
+
1
,
1
-
x
)
ck
.
verify
(
y
+
x
*
3
-
5
*
x
+
1
+
y
,
y
*
2
+
1
-
x
*
2
)
ck
.
verify
(
y
+
x
*
3
-
5
*
x
+
1
+
y
,
y
*
2
+
1
-
x
*
2
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
# trucdiv
# trucdiv
ck
.
verify
(
(
x
+
y
+
x
+
y
*
3
)
/
2
,
y
*
2
+
x
)
ck
.
verify
(
tdiv
(
x
+
y
+
x
+
y
*
3
,
2
)
,
y
*
2
+
x
)
ck
.
verify
(
(
x
+
y
+
x
+
y
*
3
)
%
2
,
0
)
ck
.
verify
(
tmod
(
x
+
y
+
x
+
y
*
3
,
2
)
,
0
)
# floordiv
# floordiv
fld
=
tvm
.
floordiv
fld
=
tvm
.
floordiv
...
@@ -51,28 +53,31 @@ def test_split_index_simplify():
...
@@ -51,28 +53,31 @@ def test_split_index_simplify():
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# trucdiv
# trucdiv
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
# split div const
# split div const
ck
.
verify
(
(
x
/
3
)
*
3
+
x
%
3
,
x
)
ck
.
verify
(
tdiv
(
x
,
3
)
*
3
+
tmod
(
x
,
3
)
,
x
)
ck
.
verify
(
(
x
/
6
)
*
6
+
((
x
/
3
)
%
2
)
*
3
+
x
%
3
,
x
)
ck
.
verify
(
tdiv
(
x
,
6
)
*
6
+
tmod
(
tdiv
(
x
,
3
),
2
)
*
3
+
tmod
(
x
,
3
)
,
x
)
ck
.
verify
(
((
x
%
16
)
/
2
)
*
2
/
4
,
(
x
%
16
)
/
4
)
ck
.
verify
(
tdiv
(
tdiv
(
tmod
(
x
,
16
),
2
)
*
2
,
4
),
tdiv
(
tmod
(
x
,
16
),
4
)
)
ck
.
verify
(
(
x
%
2
)
/
8
,
0
)
ck
.
verify
(
tdiv
(
tmod
(
x
,
2
),
8
)
,
0
)
ck
.
verify
(
(
x
%
2
)
/
7
,
0
)
ck
.
verify
(
tdiv
(
tmod
(
x
,
2
),
7
)
,
0
)
ck
.
verify
(
((
x
%
16
)
/
2
)
*
2
/
6
,
(
x
%
16
)
/
6
)
ck
.
verify
(
tdiv
(
tdiv
(
tmod
(
x
,
16
),
2
)
*
2
,
6
),
tdiv
(
tmod
(
x
,
16
),
6
)
)
# split mod const
# split mod const
ck
.
verify
(
(
x
*
8
)
%
16
,
(
x
%
2
)
*
8
)
ck
.
verify
(
tmod
((
x
*
8
),
16
),
tmod
(
x
,
2
)
*
8
)
ck
.
verify
(
(
x
*
8
)
%
2
,
0
)
ck
.
verify
(
tmod
(
x
*
8
,
2
)
,
0
)
# simplify then fold
# simplify then fold
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
))
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
))
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
))
ck
.
verify
(
(
x
*
4
+
y
)
/
2
*
2
+
(
x
*
4
+
y
)
%
2
,
x
*
4
+
y
)
ck
.
verify
(
tdiv
(
x
*
4
+
y
,
2
)
*
2
+
tmod
(
x
*
4
+
y
,
2
)
,
x
*
4
+
y
)
# complex fold
# complex fold
ck
.
verify
(
(
z
*
9
+
y
)
/
2
*
2
+
(
z
*
9
+
y
)
%
2
,
z
*
9
+
y
)
ck
.
verify
(
tdiv
(
z
*
9
+
y
,
2
)
*
2
+
tmod
(
z
*
9
+
y
,
2
)
,
z
*
9
+
y
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
100
,
1000
),
True
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
100
,
1000
),
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
100
,
1000
),
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
100
,
1000
),
True
)
ck
.
verify
(
(
x
*
4
+
y
)
/
2
*
2
+
(
x
*
4
+
y
)
%
2
,
x
*
4
+
y
)
ck
.
verify
(
tdiv
(
x
*
4
+
y
,
2
)
*
2
+
tmod
(
x
*
4
+
y
,
2
)
,
x
*
4
+
y
)
# floordiv
# floordiv
fld
=
tvm
.
floordiv
fld
=
tvm
.
floordiv
...
@@ -85,23 +90,24 @@ def test_split_index_simplify():
...
@@ -85,23 +90,24 @@ def test_split_index_simplify():
ck
.
verify
(
fld
(
fld
(
flm
(
x
,
16
),
2
)
*
2
,
6
),
fld
(
flm
(
x
,
16
),
6
))
ck
.
verify
(
fld
(
fld
(
flm
(
x
,
16
),
2
)
*
2
,
6
),
fld
(
flm
(
x
,
16
),
6
))
# cannot simplify mixed case, unless we canonicalize into one mode.
# cannot simplify mixed case, unless we canonicalize into one mode.
ck
.
verify
(
(
x
/
6
)
*
2
+
fld
(
x
,
3
)
%
2
,
(
x
/
6
)
*
2
+
fld
(
x
,
3
)
%
2
)
ck
.
verify
(
tdiv
(
x
,
6
)
*
2
+
tmod
(
fld
(
x
,
3
),
2
),
tdiv
(
x
,
6
)
*
2
+
tmod
(
fld
(
x
,
3
),
2
)
)
def
test_div_simplify
():
def
test_div_simplify
():
ck
=
CanonicalChecker
()
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
x
=
tvm
.
var
(
"x"
)
tdiv
=
tvm
.
truncdiv
# truc div
# truc div
ck
.
verify
(
(
16
+
48
*
x
)
/
16
,
x
*
3
+
1
)
ck
.
verify
(
tdiv
(
16
+
48
*
x
,
16
)
,
x
*
3
+
1
)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
# (17+48*x)/16 != 1+3*x
ck
.
verify
(
(
17
+
48
*
x
)
/
16
,
(
x
*
48
+
17
)
/
16
)
ck
.
verify
(
tdiv
(
17
+
48
*
x
,
16
),
tdiv
(
x
*
48
+
17
,
16
)
)
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
10
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
10
))
ck
.
verify
(
(
17
+
48
*
x
)
/
16
,
x
*
3
+
1
)
ck
.
verify
(
tdiv
(
17
+
48
*
x
,
16
)
,
x
*
3
+
1
)
# Trying expressions that are not simplifiable for any values of the variables
# Trying expressions that are not simplifiable for any values of the variables
ck
.
verify
(
(
17
+
47
*
x
)
/
16
,
(
x
*
47
+
17
)
/
16
)
ck
.
verify
(
tdiv
(
17
+
47
*
x
,
16
),
tdiv
(
x
*
47
+
17
,
16
)
)
# floordiv
# floordiv
fld
=
tvm
.
floordiv
fld
=
tvm
.
floordiv
...
@@ -124,8 +130,10 @@ def test_canonical_mixed():
...
@@ -124,8 +130,10 @@ def test_canonical_mixed():
ck
=
CanonicalChecker
()
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
x
=
tvm
.
var
(
"x"
)
z
=
tvm
.
const
(
3
,
"int32"
)
z
=
tvm
.
const
(
3
,
"int32"
)
ck
.
verify
(
x
/
(
z
*
z
)
-
x
/
(
z
*
z
),
0
)
tdiv
=
tvm
.
truncdiv
ck
.
verify
(
x
/
(
z
+
z
)
-
x
/
(
z
+
z
),
0
)
tmod
=
tvm
.
truncmod
ck
.
verify
(
tdiv
(
x
,
(
z
*
z
))
-
tdiv
(
x
,
(
z
*
z
)),
0
)
ck
.
verify
(
tdiv
(
x
,
(
z
+
z
))
-
tdiv
(
x
,
(
z
+
z
)),
0
)
ck
.
verify
(
x
-
2
<
3
,
x
<
5
)
ck
.
verify
(
x
-
2
<
3
,
x
<
5
)
ck
.
verify
(
tvm
.
max
(
x
,
1
)
-
tvm
.
max
(
x
,
1
),
0
)
ck
.
verify
(
tvm
.
max
(
x
,
1
)
-
tvm
.
max
(
x
,
1
),
0
)
ck
.
verify
(
tvm
.
min
(
x
,
1
)
-
tvm
.
min
(
x
,
1
),
0
)
ck
.
verify
(
tvm
.
min
(
x
,
1
)
-
tvm
.
min
(
x
,
1
),
0
)
...
@@ -207,42 +215,44 @@ def test_reduce_simplify():
...
@@ -207,42 +215,44 @@ def test_reduce_simplify():
tvm
.
sum
(
k
+
j
,
[
k
,
j
]))
tvm
.
sum
(
k
+
j
,
[
k
,
j
]))
ck
.
verify
(
tvm
.
sum
(
A
[
3
],
[]),
A
[
3
])
ck
.
verify
(
tvm
.
sum
(
A
[
3
],
[]),
A
[
3
])
# The rule below is not typical, removed for now
# The rule below is not typical, removed for now
ck
.
verify
(
tvm
.
sum
(
k
/
10
,
k
),
tvm
.
sum
(
tvm
.
const
(
0
,
"int32"
),
k
))
ck
.
verify
(
tvm
.
sum
(
tvm
.
div
(
k
,
10
)
,
k
),
tvm
.
sum
(
tvm
.
const
(
0
,
"int32"
),
k
))
def
test_simplify_if_then_else
():
def
test_simplify_if_then_else
():
ck
=
CanonicalChecker
()
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
x
=
tvm
.
var
(
"x"
)
y
=
tvm
.
var
(
"y"
)
y
=
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
# simplification that takes condition into account.
# simplification that takes condition into account.
res
=
tvm
.
if_then_else
((
x
*
4
+
y
)
>=
466036
,
res
=
tvm
.
if_then_else
((
x
*
4
+
y
)
>=
466036
,
tvm
.
if_then_else
(
24512
<=
((((
x
*
4
)
+
y
)
-
466036
)
%
24528
),
tvm
.
if_then_else
(
24512
<=
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
),
(((((
x
*
4
)
+
y
)
-
466036
)
%
24528
)
-
24512
)
%
16
,
tmod
(
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
)
-
24512
,
16
)
,
x
),
y
)
x
),
y
)
res2
=
tvm
.
if_then_else
((
x
*
4
)
>=
466036
-
y
,
res2
=
tvm
.
if_then_else
((
x
*
4
)
>=
466036
-
y
,
tvm
.
if_then_else
(
24512
<=
((((
x
*
4
)
+
y
)
-
466036
)
%
24528
),
tvm
.
if_then_else
(
24512
<=
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
),
(((((
x
*
4
)
+
y
)
-
466036
)
%
24528
)
-
24512
)
%
16
,
tmod
(
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
)
-
24512
,
16
)
,
x
),
y
)
x
),
y
)
expected
=
tvm
.
if_then_else
(
expected
=
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
466036
,
(
x
*
4
+
y
)),
tvm
.
expr
.
LE
(
466036
,
(
x
*
4
+
y
)),
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
24512
,
((((
x
*
4
)
+
y
)
-
4
)
%
24528
)),
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
24512
,
tmod
(((
x
*
4
)
+
y
)
-
4
,
24528
)),
(((
x
*
4
)
+
y
)
-
4
)
%
16
,
tmod
(((
x
*
4
)
+
y
)
-
4
,
16
)
,
x
),
y
)
x
),
y
)
ck
.
verify
(
res
,
expected
)
ck
.
verify
(
res
,
expected
)
ck
.
verify
(
res2
,
expected
)
ck
.
verify
(
res2
,
expected
)
# can only simplify if condition
# can only simplify if condition
res
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
100
)
%
3
,
(
x
+
100
)
%
3
)
res
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
tmod
(
x
+
y
+
100
,
3
),
tmod
(
x
+
100
,
3
)
)
expected
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
1
)
%
3
,
(
x
+
100
)
%
3
)
expected
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
tmod
(
x
+
y
+
1
,
3
),
tmod
(
x
+
100
,
3
)
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
>
2
,
x
,
0
),
0
)
tvm
.
if_then_else
(
tdiv
(
x
,
3
)
>
2
,
x
,
0
),
0
)
expected
=
tvm
.
expr
.
Select
(
x
>=
10
,
x
,
0
)
expected
=
tvm
.
expr
.
Select
(
x
>=
10
,
x
,
0
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
<
2
,
x
,
0
),
0
)
tvm
.
if_then_else
(
tdiv
(
x
,
3
)
<
2
,
x
,
0
),
0
)
ck
.
verify
(
res
,
0
)
ck
.
verify
(
res
,
0
)
...
@@ -250,20 +260,20 @@ def test_complex_cases():
...
@@ -250,20 +260,20 @@ def test_complex_cases():
ck
=
CanonicalChecker
()
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
x
=
tvm
.
var
(
"x"
)
y
=
tvm
.
var
(
"y"
)
y
=
tvm
.
var
(
"y"
)
res2
=
(((((((((((
x
*
128
)
+
y
)
%
1296
)
/
36
)
*
2
)
+
1
)
/
2
)
*
36
)
+
tdiv
=
tvm
.
truncdiv
((((((
x
*
128
)
+
y
)
%
36
)
*
2
)
+
1
)
/
2
))
tmod
=
tvm
.
truncmod
-
(((
x
*
128
)
+
y
)
%
1296
))
+
1
)
res2
=
(
tdiv
(
tdiv
(
tmod
(
x
*
128
+
y
,
1296
),
36
)
*
2
+
1
,
2
)
*
36
+
tdiv
(
tmod
((
x
*
128
)
+
y
,
36
)
*
2
+
1
,
2
)
-
tmod
((
x
*
128
)
+
y
,
1296
)
+
1
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
5
))
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
5
))
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
127
))
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
127
))
ck
.
verify
(
res2
,
1
)
ck
.
verify
(
res2
,
1
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1024
),
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1024
),
True
)
res3
=
((((((((((
x
*
1024
)
+
y
)
/
65536
)
+
((((
x
*
1024
)
+
y
)
%
65536
)
/
256
))
res3
=
(
tdiv
(
x
*
1024
+
y
,
65536
)
+
tdiv
(
tmod
(
x
*
1024
+
y
,
65536
),
256
)
+
((((
x
*
1024
)
+
y
)
%
256
)
/
16
))
+
(((
x
*
1024
)
+
y
)
%
16
))
-
(
y
/
256
))
-
+
tdiv
(
tmod
(
x
*
1024
+
y
,
256
),
16
)
+
tmod
(
x
*
1024
+
y
,
16
)
-
tdiv
(
y
,
256
)
-
((
y
%
256
)
/
16
))
-
(
y
%
16
))
-
(
x
*
4
))
tdiv
(
tmod
(
y
,
256
),
16
)
-
tmod
(
y
,
16
)
-
(
x
*
4
))
ck
.
verify
(
res3
,
((((
x
*
1024
)
+
y
)
/
256
)
-
(
y
/
256
))
-
(
x
*
4
))
ck
.
verify
(
res3
,
tdiv
((
x
*
1024
)
+
y
,
256
)
-
tdiv
(
y
,
256
)
-
(
x
*
4
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/python/unittest/test_arith_const_int_bound.py
View file @
2ded2d8c
...
@@ -38,12 +38,13 @@ def test_dtype_bound():
...
@@ -38,12 +38,13 @@ def test_dtype_bound():
def
test_cast_bound
():
def
test_cast_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
x
=
tvm
.
var
(
"x"
,
dtype
=
"int8"
)
x
=
tvm
.
var
(
"x"
,
dtype
=
"int8"
)
bd
=
analyzer
.
const_int_bound
((
x
%
3
)
.
astype
(
"uint32"
))
tmod
=
tvm
.
truncmod
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
3
)
.
astype
(
"uint32"
))
assert
bd
.
min_value
==
0
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
2
assert
bd
.
max_value
==
2
bd
=
analyzer
.
const_int_bound
(
bd
=
analyzer
.
const_int_bound
(
(
x
%
3
)
.
astype
(
"float32"
)
.
astype
(
"int32"
))
tmod
(
x
,
3
)
.
astype
(
"float32"
)
.
astype
(
"int32"
))
assert
bd
.
min_value
==
-
2
assert
bd
.
min_value
==
-
2
assert
bd
.
max_value
==
2
assert
bd
.
max_value
==
2
...
@@ -98,47 +99,50 @@ def test_mul_bound():
...
@@ -98,47 +99,50 @@ def test_mul_bound():
assert
bd
.
max_value
==
bd
.
POS_INF
assert
bd
.
max_value
==
bd
.
POS_INF
def
test_div_bound
():
def
test_
trunc
div_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
))
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
))
bd
=
analyzer
.
const_int_bound
(
x
/
y
)
bd
=
analyzer
.
const_int_bound
(
tdiv
(
x
,
y
)
)
assert
bd
.
min_value
==
-
2
assert
bd
.
min_value
==
-
2
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
),
override
=
True
)
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
2
,
0
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
2
,
0
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
/
y
)
bd
=
analyzer
.
const_int_bound
(
tdiv
(
x
,
y
)
)
assert
bd
.
min_value
==
-
4
assert
bd
.
min_value
==
-
4
assert
bd
.
max_value
==
9
assert
bd
.
max_value
==
9
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
bd
.
NEG_INF
,
4
),
override
=
True
)
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
bd
.
NEG_INF
,
4
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
2
,
1
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
2
,
1
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
/
y
)
bd
=
analyzer
.
const_int_bound
(
tdiv
(
x
,
y
)
)
assert
bd
.
min_value
==
bd
.
NEG_INF
assert
bd
.
min_value
==
bd
.
NEG_INF
assert
bd
.
max_value
==
bd
.
POS_INF
assert
bd
.
max_value
==
bd
.
POS_INF
def
test_mod_bound
():
def
test_
trunc
mod_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tmod
=
tvm
.
truncmod
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
))
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
))
bd
=
analyzer
.
const_int_bound
(
x
%
y
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
y
)
)
assert
bd
.
min_value
==
-
9
assert
bd
.
min_value
==
-
9
assert
bd
.
max_value
==
4
assert
bd
.
max_value
==
4
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
bd
.
NEG_INF
,
bd
.
POS_INF
),
override
=
True
)
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
bd
.
NEG_INF
,
bd
.
POS_INF
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
%
y
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
y
)
)
assert
bd
.
min_value
==
-
9
assert
bd
.
min_value
==
-
9
assert
bd
.
max_value
==
9
assert
bd
.
max_value
==
9
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
1
,
bd
.
POS_INF
),
override
=
True
)
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
1
,
bd
.
POS_INF
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
%
y
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
y
)
)
assert
bd
.
min_value
==
0
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
9
assert
bd
.
max_value
==
9
...
@@ -253,9 +257,12 @@ def test_shift_and_bound():
...
@@ -253,9 +257,12 @@ def test_shift_and_bound():
def
test_mix_index_bound
():
def
test_mix_index_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
24
-
1
))
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
24
-
1
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
3
-
1
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
3
-
1
))
bd
=
analyzer
.
const_int_bound
(
(
x
%
8
)
+
(
x
/
8
)
*
8
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
8
)
+
tdiv
(
x
,
8
)
*
8
)
assert
bd
.
min_value
==
0
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
24
-
1
assert
bd
.
max_value
==
24
-
1
...
@@ -263,7 +270,7 @@ def test_mix_index_bound():
...
@@ -263,7 +270,7 @@ def test_mix_index_bound():
assert
bd
.
min_value
==
0
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
24
*
3
-
1
assert
bd
.
max_value
==
24
*
3
-
1
bd
=
analyzer
.
const_int_bound
(
(
x
%
7
)
+
(
x
/
7
)
*
7
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
7
)
+
tdiv
(
x
,
7
)
*
7
)
assert
bd
.
min_value
==
0
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
(
23
//
7
)
*
7
+
6
assert
bd
.
max_value
==
(
23
//
7
)
*
7
+
6
...
@@ -273,8 +280,8 @@ if __name__ == "__main__":
...
@@ -273,8 +280,8 @@ if __name__ == "__main__":
test_cast_bound
()
test_cast_bound
()
test_add_sub_bound
()
test_add_sub_bound
()
test_mul_bound
()
test_mul_bound
()
test_div_bound
()
test_
trunc
div_bound
()
test_mod_bound
()
test_
trunc
mod_bound
()
test_floordiv_bound
()
test_floordiv_bound
()
test_floormod_bound
()
test_floormod_bound
()
test_min_max_bound
()
test_min_max_bound
()
...
...
tests/python/unittest/test_arith_deduce_bound.py
View file @
2ded2d8c
...
@@ -35,9 +35,11 @@ def test_deduce():
...
@@ -35,9 +35,11 @@ def test_deduce():
d_s
=
tvm
.
arith
.
IntervalSet
(
-
3
,
-
1
)
d_s
=
tvm
.
arith
.
IntervalSet
(
-
3
,
-
1
)
zero
=
tvm
.
const
(
0
,
"int32"
)
zero
=
tvm
.
const
(
0
,
"int32"
)
tdiv
=
tvm
.
truncdiv
e0
=
(
-
b
)
*
a
+
c
-
d
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
(
(
d
-
c
)
/
(
b
*-
1
)
+
(
-
1
))
ans0
=
(
tdiv
(
d
-
c
,
b
*-
1
)
+
(
-
1
))
assert_expr_equal
(
res0
.
max_value
,
ans0
)
assert_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
# expression containing variable a is on rhs
...
@@ -46,7 +48,7 @@ def test_deduce():
...
@@ -46,7 +48,7 @@ def test_deduce():
e0
=
d
*
a
+
c
-
d
e0
=
d
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
(
(
d
-
c
)
/
d
-
1
)
ans0
=
(
tdiv
(
d
-
c
,
d
)
-
1
)
assert_expr_equal
(
res0
.
max_value
,
ans0
)
assert_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
# expression containing variable a is on rhs
...
@@ -56,7 +58,7 @@ def test_deduce():
...
@@ -56,7 +58,7 @@ def test_deduce():
e1
=
(
a
*
4
+
b
<
c
)
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(
((
c
-
b
)
+
-
1
)
/
4
-
1
)
ans1
=
(
tdiv
((
c
-
b
)
+
-
1
,
4
)
-
1
)
assert_expr_equal
(
res1
.
max_value
,
ans1
)
assert_expr_equal
(
res1
.
max_value
,
ans1
)
...
@@ -79,7 +81,7 @@ def test_deduce():
...
@@ -79,7 +81,7 @@ def test_deduce():
e3
=
(
-
b
)
+
a
*
c
-
d
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
ans3
=
tdiv
(
2
,
c
)
+
1
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
...
...
tests/python/unittest/test_arith_intset.py
View file @
2ded2d8c
...
@@ -60,13 +60,14 @@ def test_add_sub():
...
@@ -60,13 +60,14 @@ def test_add_sub():
def
test_mul_div
():
def
test_mul_div
():
ck
=
IntSetChecker
()
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
verify
(
x
*
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
*
y
))
ck
.
verify
(
x
*
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
*
y
))
ck
.
verify
(
x
*
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
2
,
20
))
ck
.
verify
(
x
*
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
2
,
20
))
ck
.
verify
(
x
*
-
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
-
20
,
-
2
))
ck
.
verify
(
x
*
-
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
-
20
,
-
2
))
ck
.
verify
(
x
/
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
/
y
))
ck
.
verify
(
tdiv
(
x
,
y
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
tdiv
(
10
,
y
)
))
ck
.
verify
(
x
/
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
5
))
ck
.
verify
(
tdiv
(
x
,
2
)
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
5
))
fld
=
tvm
.
floordiv
fld
=
tvm
.
floordiv
ck
.
verify
(
fld
(
x
,
y
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
fld
(
10
,
y
)))
ck
.
verify
(
fld
(
x
,
y
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
fld
(
10
,
y
)))
...
@@ -76,9 +77,10 @@ def test_mul_div():
...
@@ -76,9 +77,10 @@ def test_mul_div():
def
test_mod
():
def
test_mod
():
ck
=
IntSetChecker
()
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tmod
=
tvm
.
truncmod
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
verify
(
x
%
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
y
-
1
))
ck
.
verify
(
tmod
(
x
,
y
)
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
y
-
1
))
ck
.
verify
(
x
%
10
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
9
))
ck
.
verify
(
tmod
(
x
,
10
)
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
9
))
flm
=
tvm
.
floormod
flm
=
tvm
.
floormod
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
-
10
,
10
)},
(
0
,
9
))
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
-
10
,
10
)},
(
0
,
9
))
...
...
tests/python/unittest/test_arith_modular_set.py
View file @
2ded2d8c
...
@@ -54,7 +54,8 @@ def test_div_shift():
...
@@ -54,7 +54,8 @@ def test_div_shift():
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
# not sure if x is non-negative
# not sure if x is non-negative
m
=
analyzer
.
modular_set
((
x
*
4
+
2
)
/
2
)
tdiv
=
tvm
.
truncdiv
m
=
analyzer
.
modular_set
(
tdiv
(
x
*
4
+
2
,
2
))
assert
m
.
coeff
==
1
assert
m
.
coeff
==
1
assert
m
.
base
==
0
assert
m
.
base
==
0
# right shift always round down so it is fine
# right shift always round down so it is fine
...
@@ -67,7 +68,7 @@ def test_div_shift():
...
@@ -67,7 +68,7 @@ def test_div_shift():
assert
m
.
base
==
1
assert
m
.
base
==
1
# x is non-negative
# x is non-negative
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
100
))
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
100
))
m
=
analyzer
.
modular_set
(
(
x
*
4
+
2
)
/
2
)
m
=
analyzer
.
modular_set
(
tdiv
(
x
*
4
+
2
,
2
)
)
assert
m
.
coeff
==
2
assert
m
.
coeff
==
2
assert
m
.
base
==
1
assert
m
.
base
==
1
...
@@ -92,6 +93,7 @@ def test_mix_index():
...
@@ -92,6 +93,7 @@ def test_mix_index():
a
=
tvm
.
var
(
"a"
)
a
=
tvm
.
var
(
"a"
)
b
=
tvm
.
var
(
"b"
)
b
=
tvm
.
var
(
"b"
)
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
tdiv
=
tvm
.
truncdiv
m
=
analyzer
.
modular_set
(
a
*
4
+
b
*
6
+
7
)
m
=
analyzer
.
modular_set
(
a
*
4
+
b
*
6
+
7
)
assert
m
.
coeff
==
2
assert
m
.
coeff
==
2
assert
m
.
base
==
1
assert
m
.
base
==
1
...
@@ -100,11 +102,11 @@ def test_mix_index():
...
@@ -100,11 +102,11 @@ def test_mix_index():
assert
m
.
coeff
==
4
assert
m
.
coeff
==
4
assert
m
.
base
==
3
assert
m
.
base
==
3
m
=
analyzer
.
modular_set
(
(
a
*
4
+
1
)
/
(
b
*
8
+
3
))
m
=
analyzer
.
modular_set
(
tdiv
(
a
*
4
+
1
,
b
*
8
+
3
))
assert
m
.
coeff
==
1
assert
m
.
coeff
==
1
assert
m
.
base
==
0
assert
m
.
base
==
0
m
=
analyzer
.
modular_set
((
a
*
4
+
1
)
*
(
b
*
8
/
4
))
m
=
analyzer
.
modular_set
((
a
*
4
+
1
)
*
tdiv
(
b
*
8
,
4
))
assert
m
.
coeff
==
2
assert
m
.
coeff
==
2
assert
m
.
base
==
0
assert
m
.
base
==
0
...
@@ -121,11 +123,13 @@ def test_constraint_scope():
...
@@ -121,11 +123,13 @@ def test_constraint_scope():
a
=
tvm
.
var
(
"a"
)
a
=
tvm
.
var
(
"a"
)
b
=
tvm
.
var
(
"b"
)
b
=
tvm
.
var
(
"b"
)
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
with
analyzer
.
constraint_scope
(
b
%
4
==
2
):
tmod
=
tvm
.
truncmod
with
analyzer
.
constraint_scope
(
tmod
(
b
,
4
)
==
2
):
m
=
analyzer
.
modular_set
(
b
+
1
)
m
=
analyzer
.
modular_set
(
b
+
1
)
assert
m
.
coeff
==
4
assert
m
.
coeff
==
4
assert
m
.
base
==
3
assert
m
.
base
==
3
with
analyzer
.
constraint_scope
(
a
%
2
==
1
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
2
)
==
1
):
m
=
analyzer
.
modular_set
(
b
+
a
*
2
)
m
=
analyzer
.
modular_set
(
b
+
a
*
2
)
assert
m
.
coeff
==
4
assert
m
.
coeff
==
4
assert
m
.
base
==
0
assert
m
.
base
==
0
...
@@ -140,15 +144,16 @@ def test_constraint_scope():
...
@@ -140,15 +144,16 @@ def test_constraint_scope():
def
test_intersect
():
def
test_intersect
():
a
=
tvm
.
var
(
"a"
)
a
=
tvm
.
var
(
"a"
)
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
with
analyzer
.
constraint_scope
(
a
%
4
==
1
):
tmod
=
tvm
.
truncmod
with
analyzer
.
constraint_scope
(
a
%
3
==
1
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
4
)
==
1
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
3
)
==
1
):
m
=
analyzer
.
modular_set
(
a
)
m
=
analyzer
.
modular_set
(
a
)
assert
m
.
coeff
==
12
assert
m
.
coeff
==
12
assert
m
.
base
==
1
assert
m
.
base
==
1
with
analyzer
.
constraint_scope
(
a
%
3
==
2
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
3
)
==
2
):
with
analyzer
.
constraint_scope
(
a
%
5
==
3
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
5
)
==
3
):
with
analyzer
.
constraint_scope
(
a
%
7
==
2
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
7
)
==
2
):
m
=
analyzer
.
modular_set
(
a
)
m
=
analyzer
.
modular_set
(
a
)
assert
m
.
coeff
==
105
assert
m
.
coeff
==
105
assert
m
.
base
==
23
assert
m
.
base
==
23
...
...
tests/python/unittest/test_autotvm_flop_calculator.py
View file @
2ded2d8c
...
@@ -60,11 +60,14 @@ def test_pack_gemm():
...
@@ -60,11 +60,14 @@ def test_pack_gemm():
k
=
tvm
.
reduce_axis
((
0
,
L
))
k
=
tvm
.
reduce_axis
((
0
,
L
))
bn
=
4
bn
=
4
fld
=
tvm
.
floordiv
flm
=
tvm
.
floormod
A_pack
=
tvm
.
compute
((
N
//
bn
,
L
,
bn
),
lambda
i
,
j
,
k
:
A
[
i
*
bn
+
k
][
j
])
A_pack
=
tvm
.
compute
((
N
//
bn
,
L
,
bn
),
lambda
i
,
j
,
k
:
A
[
i
*
bn
+
k
][
j
])
B_pack
=
tvm
.
compute
((
M
//
bn
,
L
,
bn
),
lambda
i
,
j
,
k
:
B
[
i
*
bn
+
k
][
j
])
B_pack
=
tvm
.
compute
((
M
//
bn
,
L
,
bn
),
lambda
i
,
j
,
k
:
B
[
i
*
bn
+
k
][
j
])
C_pack
=
tvm
.
compute
((
N
//
bn
,
M
//
bn
,
bn
,
bn
),
lambda
i
,
j
,
ii
,
jj
:
C_pack
=
tvm
.
compute
((
N
//
bn
,
M
//
bn
,
bn
,
bn
),
lambda
i
,
j
,
ii
,
jj
:
tvm
.
sum
(
A_pack
[
i
,
k
,
ii
]
.
astype
(
acc_dtype
)
*
B_pack
[
j
,
k
,
jj
]
.
astype
(
acc_dtype
),
axis
=
[
k
]))
tvm
.
sum
(
A_pack
[
i
,
k
,
ii
]
.
astype
(
acc_dtype
)
*
B_pack
[
j
,
k
,
jj
]
.
astype
(
acc_dtype
),
axis
=
[
k
]))
C
=
tvm
.
compute
((
N
,
M
),
lambda
i
,
j
:
C_pack
[
i
//
bn
][
j
//
bn
][
i
%
bn
][
j
%
bn
])
C
=
tvm
.
compute
((
N
,
M
),
lambda
i
,
j
:
C_pack
[
fld
(
i
,
bn
)][
fld
(
j
,
bn
)][
flm
(
i
,
bn
)][
flm
(
j
,
bn
)
])
s
=
tvm
.
create_schedule
([
C
.
op
])
s
=
tvm
.
create_schedule
([
C
.
op
])
assert
compute_flop
(
s
)
==
2
*
N
*
L
*
M
assert
compute_flop
(
s
)
==
2
*
N
*
L
*
M
...
@@ -119,9 +122,11 @@ def test_average_pool():
...
@@ -119,9 +122,11 @@ def test_average_pool():
OH
=
(
H
-
KH
)
+
1
OH
=
(
H
-
KH
)
+
1
OW
=
(
W
-
KW
)
+
1
OW
=
(
W
-
KW
)
+
1
C
=
tvm
.
compute
(
C
=
tvm
.
compute
(
(
N
,
CO
,
OH
,
OW
),
(
N
,
CO
,
OH
,
OW
),
lambda
n
,
co
,
h
,
w
:
tvm
.
sum
(
D
[
n
][
co
][
h
+
kh
][
w
+
kw
]
.
astype
(
acc_dtype
)
/
(
KW
*
KH
),
axis
=
[
kh
,
kw
]))
lambda
n
,
co
,
h
,
w
:
tvm
.
sum
(
tvm
.
div
(
D
[
n
][
co
][
h
+
kh
][
w
+
kw
]
.
astype
(
acc_dtype
),
(
KW
*
KH
)),
axis
=
[
kh
,
kw
]))
s
=
tvm
.
create_schedule
([
C
.
op
])
s
=
tvm
.
create_schedule
([
C
.
op
])
...
...
tests/python/unittest/test_build_lower.py
View file @
2ded2d8c
...
@@ -35,7 +35,7 @@ def test_lower_rfactor():
...
@@ -35,7 +35,7 @@ def test_lower_rfactor():
def
test_dependent_output_shape
():
def
test_dependent_output_shape
():
n
,
m
,
x
=
tvm
.
var
(
'n'
),
tvm
.
var
(
'm'
),
tvm
.
var
(
'x'
)
n
,
m
,
x
=
tvm
.
var
(
'n'
),
tvm
.
var
(
'm'
),
tvm
.
var
(
'x'
)
A
=
tvm
.
placeholder
((
n
,
m
))
A
=
tvm
.
placeholder
((
n
,
m
))
B
=
tvm
.
compute
((
m
,
n
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
,
name
=
'B'
)
B
=
tvm
.
compute
((
m
,
n
/
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
,
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
s
=
tvm
.
create_schedule
(
B
.
op
)
mod
=
tvm
.
build
(
s
,
[
A
,
B
,
x
])
mod
=
tvm
.
build
(
s
,
[
A
,
B
,
x
])
...
...
tests/python/unittest/test_codegen_llvm.py
View file @
2ded2d8c
...
@@ -409,7 +409,7 @@ def test_llvm_div():
...
@@ -409,7 +409,7 @@ def test_llvm_div():
"""Check that the semantics of div and mod is the same as in C/C++"""
"""Check that the semantics of div and mod is the same as in C/C++"""
def
check_div
(
start
,
end
,
divisor
,
dtype
):
def
check_div
(
start
,
end
,
divisor
,
dtype
):
T
=
tvm
.
compute
((
end
-
start
,),
T
=
tvm
.
compute
((
end
-
start
,),
lambda
i
:
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
))
/
tvm
.
const
(
divisor
,
dtype
))
lambda
i
:
tvm
.
div
(
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
)),
tvm
.
const
(
divisor
,
dtype
)
))
s
=
tvm
.
create_schedule
([
T
.
op
])
s
=
tvm
.
create_schedule
([
T
.
op
])
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
...
@@ -418,8 +418,9 @@ def test_llvm_div():
...
@@ -418,8 +418,9 @@ def test_llvm_div():
tvm
.
testing
.
assert_allclose
(
a
.
asnumpy
(),
ref
)
tvm
.
testing
.
assert_allclose
(
a
.
asnumpy
(),
ref
)
def
check_mod
(
start
,
end
,
divisor
,
dtype
):
def
check_mod
(
start
,
end
,
divisor
,
dtype
):
tmod
=
tvm
.
truncmod
T
=
tvm
.
compute
((
end
-
start
,),
T
=
tvm
.
compute
((
end
-
start
,),
lambda
i
:
t
vm
.
expr
.
Cast
(
dtype
,
(
start
+
i
))
%
tvm
.
const
(
divisor
,
dtype
))
lambda
i
:
t
mod
(
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
)),
tvm
.
const
(
divisor
,
dtype
)
))
s
=
tvm
.
create_schedule
([
T
.
op
])
s
=
tvm
.
create_schedule
([
T
.
op
])
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
...
@@ -443,7 +444,7 @@ def test_llvm_div():
...
@@ -443,7 +444,7 @@ def test_llvm_div():
def
test_llvm_fp_math
():
def
test_llvm_fp_math
():
def
check_llvm_reciprocal
(
n
):
def
check_llvm_reciprocal
(
n
):
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
compute
((
n
,),
lambda
i
:
1.0
/
(
1e+37
*
A
[
i
]
),
name
=
'B'
)
B
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
div
(
1.0
,(
1e+37
*
A
[
i
])
),
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
s
=
tvm
.
create_schedule
(
B
.
op
)
f
=
tvm
.
build
(
s
,
[
A
,
B
],
"llvm"
)
f
=
tvm
.
build
(
s
,
[
A
,
B
],
"llvm"
)
...
...
tests/python/unittest/test_ir_builder.py
View file @
2ded2d8c
...
@@ -41,8 +41,9 @@ def test_if():
...
@@ -41,8 +41,9 @@ def test_if():
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
n
=
tvm
.
var
(
"n"
)
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
tmod
=
tvm
.
truncmod
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
if_scope
(
(
i
%
2
)
==
0
):
with
ib
.
if_scope
(
tmod
(
i
,
2
)
==
0
):
A
[
i
]
=
A
[
i
]
+
1
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
else_scope
():
with
ib
.
else_scope
():
A
[
0
]
=
A
[
i
]
+
2
A
[
0
]
=
A
[
i
]
+
2
...
@@ -108,13 +109,14 @@ def test_gpu():
...
@@ -108,13 +109,14 @@ def test_gpu():
dtype
=
"float32"
dtype
=
"float32"
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
fld
=
tvm
.
floordiv
def
test_device_ir
(
A
,
B
,
C
):
def
test_device_ir
(
A
,
B
,
C
):
n
=
A
.
shape
[
0
]
n
=
A
.
shape
[
0
]
max_threads
=
32
max_threads
=
32
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
bx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
bx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
tx
=
tvm
.
thread_axis
(
"threadIdx.x"
)
tx
=
tvm
.
thread_axis
(
"threadIdx.x"
)
ib
.
scope_attr
(
bx
,
"thread_extent"
,
(
n
+
max_threads
-
1
)
//
max_threads
)
ib
.
scope_attr
(
bx
,
"thread_extent"
,
fld
(
n
+
max_threads
-
1
,
max_threads
)
)
ib
.
scope_attr
(
tx
,
"thread_extent"
,
max_threads
)
ib
.
scope_attr
(
tx
,
"thread_extent"
,
max_threads
)
idx
=
bx
.
var
*
max_threads
+
tx
.
var
idx
=
bx
.
var
*
max_threads
+
tx
.
var
Aptr
=
ib
.
buffer_ptr
(
A
)
Aptr
=
ib
.
buffer_ptr
(
A
)
...
...
tests/python/unittest/test_lang_buffer.py
View file @
2ded2d8c
...
@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod():
...
@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod():
def
assert_simplified_equal
(
index_simplified
,
index_direct
):
def
assert_simplified_equal
(
index_simplified
,
index_direct
):
assert
tvm
.
ir_pass
.
Equal
(
index_simplified
,
index_direct
),
\
assert
tvm
.
ir_pass
.
Equal
(
index_simplified
,
index_direct
),
\
"index_simplified=
%
s, index_direct=
%
s"
%
(
index_simplified
,
index_direct
)
"index_simplified=
%
s, index_direct=
%
s"
%
(
index_simplified
,
index_direct
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# Test Case1
# Test Case1
index_simplified
=
A_stride
.
vload
(((
k0
%
k1
)
/
s
,
(
k0
%
k1
)
%
s
+
(
k0
/
k1
)
*
k1
))
index_simplified
=
A_stride
.
vload
(
(
idxdiv
(
idxmod
(
k0
,
k1
),
s
),
idxmod
(
idxmod
(
k0
,
k1
),
s
)
+
idxdiv
(
k0
,
k1
)
*
k1
))
index_direct
=
A_stride
.
vload
((
0
,
k0
))
index_direct
=
A_stride
.
vload
((
0
,
k0
))
assert_simplified_equal
(
index_simplified
,
index_direct
)
assert_simplified_equal
(
index_simplified
,
index_direct
)
# Test Case2
# Test Case2
index_simplified
=
A
.
vload
((
(
k0
%
(
k1
/
s
))
/
n
,
index_simplified
=
A
.
vload
((
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)
,
(
k0
%
(
k1
/
s
))
%
n
+
(
k0
%
k1
)))
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)
+
idxmod
(
k0
,
k1
)))
index_direct
=
A
.
vload
((
0
,
k0
%
k1
+
k0
%
(
k1
/
s
)))
index_direct
=
A
.
vload
((
0
,
idxmod
(
k0
,
k1
)
+
idxmod
(
k0
,
idxdiv
(
k1
,
s
)
)))
assert_simplified_equal
(
index_simplified
,
index_direct
)
assert_simplified_equal
(
index_simplified
,
index_direct
)
# Test Case3
# Test Case3
index_simplified
=
A
.
vload
((((
k0
/
(
k1
/
s
))
*
(
k1
/
s
))
/
n
+
(
k0
%
(
k1
/
s
))
/
n
,
index_simplified
=
A
.
vload
((
idxdiv
((
idxdiv
(
k0
,
idxdiv
(
k1
,
s
))
*
idxdiv
(
k1
,
s
)),
n
)
+
((
k0
/
(
k1
/
s
))
*
(
k1
/
s
))
%
n
+
(
k0
%
(
k1
/
s
))
%
n
))
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
),
idxmod
((
idxdiv
(
k0
,
idxdiv
(
k1
,
s
))
*
idxdiv
(
k1
,
s
)),
n
)
+
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)))
index_direct
=
A
.
vload
((
0
,
k0
))
index_direct
=
A
.
vload
((
0
,
k0
))
assert_simplified_equal
(
index_simplified
,
index_direct
)
assert_simplified_equal
(
index_simplified
,
index_direct
)
# Test Case4 (not able to simplify)
# Test Case4 (not able to simplify)
index_simplified
=
A
.
vload
(((
k0
%
(
k1
/
s
))
/
n
,
index_simplified
=
A
.
vload
((
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
),
(
k0
%
(
k1
/
n
))
%
n
+
(
k0
%
k1
)))
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
n
)),
n
)
+
idxmod
(
k0
,
k1
)))
index_direct
=
A
.
vload
((
0
,
((
k0
%
(
k1
/
s
))
/
n
)
*
n
+
((
k0
%
(
k1
/
n
))
%
n
+
(
k0
%
k1
))))
index_direct
=
A
.
vload
((
0
,
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)
*
n
+
(
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
n
)),
n
)
+
idxmod
(
k0
,
k1
))))
assert_simplified_equal
(
index_simplified
,
index_direct
)
assert_simplified_equal
(
index_simplified
,
index_direct
)
...
@@ -143,14 +150,14 @@ def test_buffer_broadcast():
...
@@ -143,14 +150,14 @@ def test_buffer_broadcast():
check
()
check
()
def
test_b
buffer_
roadcast_expr
():
def
test_b
uffer_b
roadcast_expr
():
n0
,
m0
,
x
=
tvm
.
var
(
'n0'
),
tvm
.
var
(
'm0'
),
tvm
.
var
(
'x'
)
n0
,
m0
,
x
=
tvm
.
var
(
'n0'
),
tvm
.
var
(
'm0'
),
tvm
.
var
(
'x'
)
n1
,
m1
=
tvm
.
var
(
'n1'
),
tvm
.
var
(
'm1'
)
n1
,
m1
=
tvm
.
var
(
'n1'
),
tvm
.
var
(
'm1'
)
o0
,
o1
=
tvm
.
var
(
'o0'
),
tvm
.
var
(
'o1'
)
o0
,
o1
=
tvm
.
var
(
'o0'
),
tvm
.
var
(
'o1'
)
A
=
tvm
.
placeholder
((
m0
,
n0
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
m0
,
n0
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
m1
,
n1
),
name
=
'B'
)
B
=
tvm
.
placeholder
((
m1
,
n1
),
name
=
'B'
)
C
=
tvm
.
compute
((
o0
,
o1
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
+
B
[
i
,
j
],
name
=
'C'
)
C
=
tvm
.
compute
((
o0
,
o1
/
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
+
B
[
i
,
j
],
name
=
'C'
)
Ab
=
tvm
.
decl_buffer
(
A
.
shape
,
A
.
dtype
,
name
=
"Ab"
,
buffer_type
=
"auto_broadcast"
)
Ab
=
tvm
.
decl_buffer
(
A
.
shape
,
A
.
dtype
,
name
=
"Ab"
,
buffer_type
=
"auto_broadcast"
)
Bb
=
tvm
.
decl_buffer
(
B
.
shape
,
B
.
dtype
,
name
=
"Bb"
,
buffer_type
=
"auto_broadcast"
)
Bb
=
tvm
.
decl_buffer
(
B
.
shape
,
B
.
dtype
,
name
=
"Bb"
,
buffer_type
=
"auto_broadcast"
)
...
...
tests/python/unittest/test_lang_operator.py
View file @
2ded2d8c
...
@@ -32,10 +32,11 @@ def test_const_fold():
...
@@ -32,10 +32,11 @@ def test_const_fold():
if
not
isinstance
(
x
,
(
tvm
.
expr
.
IntImm
,
tvm
.
expr
.
UIntImm
))
or
x
.
value
!=
int
(
y
):
if
not
isinstance
(
x
,
(
tvm
.
expr
.
IntImm
,
tvm
.
expr
.
UIntImm
))
or
x
.
value
!=
int
(
y
):
raise
ValueError
(
"check error:
%
s vs
%
s "
%
(
x
,
y
))
raise
ValueError
(
"check error:
%
s vs
%
s "
%
(
x
,
y
))
tmod
=
tvm
.
truncmod
check
(
lambda
x
,
y
:
x
+
y
,
3
,
4
)
check
(
lambda
x
,
y
:
x
+
y
,
3
,
4
)
check
(
lambda
x
,
y
:
x
*
y
,
3
,
12
)
check
(
lambda
x
,
y
:
x
*
y
,
3
,
12
)
check
(
lambda
x
,
y
:
x
*
y
-
10
,
3
,
12
)
check
(
lambda
x
,
y
:
x
*
y
-
10
,
3
,
12
)
check
(
lambda
x
,
y
:
x
-
y
%
10
,
3
,
12
)
check
(
lambda
x
,
y
:
x
-
tmod
(
y
,
10
)
,
3
,
12
)
check
(
lambda
x
,
y
:
x
//
y
+
10
,
100
,
12
)
check
(
lambda
x
,
y
:
x
//
y
+
10
,
100
,
12
)
check
(
lambda
x
,
y
:
x
&
y
+
10
,
112
,
128
)
check
(
lambda
x
,
y
:
x
&
y
+
10
,
112
,
128
)
check
(
lambda
x
,
y
:
x
>
y
,
112
,
128
)
check
(
lambda
x
,
y
:
x
>
y
,
112
,
128
)
...
@@ -47,13 +48,15 @@ def test_const_fold():
...
@@ -47,13 +48,15 @@ def test_const_fold():
def
test_const_fold2
():
def
test_const_fold2
():
x
=
tvm
.
var
(
"x"
)
x
=
tvm
.
var
(
"x"
)
tmod
=
tvm
.
truncmod
tdiv
=
tvm
.
truncdiv
assert
(
x
+
0
)
.
same_as
(
x
)
assert
(
x
+
0
)
.
same_as
(
x
)
assert
(
0
+
x
)
.
same_as
(
x
)
assert
(
0
+
x
)
.
same_as
(
x
)
assert
(
x
-
0
)
.
same_as
(
x
)
assert
(
x
-
0
)
.
same_as
(
x
)
assert
(
x
%
1
)
.
value
==
0
assert
tmod
(
x
,
1
)
.
value
==
0
assert
(
x
*
1
)
.
same_as
(
x
)
assert
(
x
*
1
)
.
same_as
(
x
)
assert
(
1
*
x
)
.
same_as
(
x
)
assert
(
1
*
x
)
.
same_as
(
x
)
assert
isinstance
(
(
1
/
x
),
tvm
.
expr
.
Div
)
assert
isinstance
(
tdiv
(
1
,
x
),
tvm
.
expr
.
Div
)
def
test_const_fold3
():
def
test_const_fold3
():
# Test that using ints with logic operations is forbidden
# Test that using ints with logic operations is forbidden
...
@@ -88,8 +91,9 @@ def test_const_fold3():
...
@@ -88,8 +91,9 @@ def test_const_fold3():
def
test_const_fold4
():
def
test_const_fold4
():
x1
=
tvm
.
const
(
4
,
"int32"
)
x1
=
tvm
.
const
(
4
,
"int32"
)
x2
=
x1
+
5
x2
=
x1
+
5
tdiv
=
tvm
.
truncdiv
assert
isinstance
(
x2
,
tvm
.
expr
.
IntImm
)
and
x2
.
value
==
9
assert
isinstance
(
x2
,
tvm
.
expr
.
IntImm
)
and
x2
.
value
==
9
x3
=
x2
/
3
x3
=
tdiv
(
x2
,
3
)
assert
isinstance
(
x3
,
tvm
.
expr
.
IntImm
)
and
x3
.
value
==
3
assert
isinstance
(
x3
,
tvm
.
expr
.
IntImm
)
and
x3
.
value
==
3
x4
=
x3
+
0.55
x4
=
x3
+
0.55
assert
isinstance
(
x4
,
tvm
.
expr
.
FloatImm
)
and
abs
(
x4
.
value
-
3.55
)
<
1e-6
assert
isinstance
(
x4
,
tvm
.
expr
.
FloatImm
)
and
abs
(
x4
.
value
-
3.55
)
<
1e-6
...
...
tests/python/unittest/test_lang_tensor_overload_op.py
View file @
2ded2d8c
...
@@ -72,7 +72,7 @@ def test_combination():
...
@@ -72,7 +72,7 @@ def test_combination():
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
m
),
name
=
'B'
)
B
=
tvm
.
placeholder
((
n
,
m
),
name
=
'B'
)
C
=
tvm
.
placeholder
((
n
,
m
),
name
=
'C'
)
C
=
tvm
.
placeholder
((
n
,
m
),
name
=
'C'
)
D
=
k
+
A
-
B
*
C
/
x
D
=
k
+
A
-
B
*
C
+
x
s
=
tvm
.
create_schedule
(
D
.
op
)
s
=
tvm
.
create_schedule
(
D
.
op
)
foo
=
tvm
.
build
(
s
,
[
x
,
A
,
B
,
C
,
D
],
"llvm"
)
foo
=
tvm
.
build
(
s
,
[
x
,
A
,
B
,
C
,
D
],
"llvm"
)
ctx
=
tvm
.
cpu
(
0
)
ctx
=
tvm
.
cpu
(
0
)
...
@@ -82,7 +82,7 @@ def test_combination():
...
@@ -82,7 +82,7 @@ def test_combination():
c
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
m
))
.
astype
(
C
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
m
))
.
astype
(
C
.
dtype
),
ctx
)
d
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
D
.
dtype
),
ctx
)
d
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
D
.
dtype
),
ctx
)
foo
(
x
,
a
,
b
,
c
,
d
)
foo
(
x
,
a
,
b
,
c
,
d
)
tvm
.
testing
.
assert_allclose
(
d
.
asnumpy
(),
k
+
a
.
asnumpy
()
-
b
.
asnumpy
()
*
c
.
asnumpy
()
/
x
)
tvm
.
testing
.
assert_allclose
(
d
.
asnumpy
(),
k
+
a
.
asnumpy
()
-
b
.
asnumpy
()
*
c
.
asnumpy
()
+
x
)
def
verify_tensor_scalar_bop
(
shape
,
typ
=
"add"
):
def
verify_tensor_scalar_bop
(
shape
,
typ
=
"add"
):
...
...
tests/python/unittest/test_pass_basic.py
View file @
2ded2d8c
...
@@ -17,13 +17,15 @@
...
@@ -17,13 +17,15 @@
import
tvm
import
tvm
def
test_simplify
():
def
test_simplify
():
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
x
=
tvm
.
var
(
'x'
)
x
=
tvm
.
var
(
'x'
)
e1
=
tvm
.
ir_pass
.
Simplify
(
x
+
2
+
1
)
e1
=
tvm
.
ir_pass
.
Simplify
(
x
+
2
+
1
)
assert
(
tvm
.
ir_pass
.
Equal
(
e1
,
x
+
3
))
assert
(
tvm
.
ir_pass
.
Equal
(
e1
,
x
+
3
))
e2
=
tvm
.
ir_pass
.
Simplify
(
x
*
3
+
5
*
x
)
e2
=
tvm
.
ir_pass
.
Simplify
(
x
*
3
+
5
*
x
)
assert
(
tvm
.
ir_pass
.
Equal
(
e2
,
x
*
8
))
assert
(
tvm
.
ir_pass
.
Equal
(
e2
,
x
*
8
))
e3
=
tvm
.
ir_pass
.
Simplify
(
x
-
x
/
3
*
3
)
e3
=
tvm
.
ir_pass
.
Simplify
(
x
-
tdiv
(
x
,
3
)
*
3
)
assert
(
tvm
.
ir_pass
.
Equal
(
e3
,
t
vm
.
make
.
M
od
(
x
,
3
)))
assert
(
tvm
.
ir_pass
.
Equal
(
e3
,
t
m
od
(
x
,
3
)))
def
test_verify_ssa
():
def
test_verify_ssa
():
...
...
tests/python/unittest/test_pass_equal.py
View file @
2ded2d8c
...
@@ -24,7 +24,7 @@ def test_equal_expr():
...
@@ -24,7 +24,7 @@ def test_equal_expr():
return
x
+
y
+
1
return
x
+
y
+
1
def
func2
():
def
func2
():
return
tvm
.
exp
(
(
x
+
y
+
1
)
*
y
/
4
)
return
tvm
.
exp
(
tvm
.
truncdiv
((
x
+
y
+
1
)
*
y
,
4
)
)
assert
tvm
.
ir_pass
.
Equal
(
func1
(),
func1
())
assert
tvm
.
ir_pass
.
Equal
(
func1
(),
func1
())
assert
tvm
.
ir_pass
.
Equal
(
func2
(),
func2
())
assert
tvm
.
ir_pass
.
Equal
(
func2
(),
func2
())
...
...
tests/python/unittest/test_pass_loop_partition.py
View file @
2ded2d8c
...
@@ -162,7 +162,7 @@ def test_condition():
...
@@ -162,7 +162,7 @@ def test_condition():
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
m
=
tvm
.
var
(
'm'
)
m
=
tvm
.
var
(
'm'
)
n
=
tvm
.
var
(
'n'
)
n
=
tvm
.
var
(
'n'
)
with
ib
.
for_range
(
0
,
((
n
+
3
)
/
4
),
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
tvm
.
truncdiv
(
n
+
3
,
4
),
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
4
,
'j'
)
as
j
:
with
ib
.
for_range
(
0
,
4
,
'j'
)
as
j
:
ib
.
emit
(
tvm
.
make
.
Evaluate
(
ib
.
emit
(
tvm
.
make
.
Evaluate
(
tvm
.
make
.
Select
(
ib
.
likely
(
i
*
4
+
j
<
n
),
m
,
n
)))
tvm
.
make
.
Select
(
ib
.
likely
(
i
*
4
+
j
<
n
),
m
,
n
)))
...
@@ -206,7 +206,7 @@ def test_everything_during_deduction():
...
@@ -206,7 +206,7 @@ def test_everything_during_deduction():
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
with
ib
.
for_range
(
0
,
n
,
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
n
,
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
32
,
'j'
)
as
j
:
with
ib
.
for_range
(
0
,
32
,
'j'
)
as
j
:
with
ib
.
if_scope
(
ib
.
likely
(
i
/
j
<
m
)):
with
ib
.
if_scope
(
ib
.
likely
(
tvm
.
truncdiv
(
i
,
j
)
<
m
)):
# this guard will produce everything during deduction
# this guard will produce everything during deduction
ib
.
emit
(
tvm
.
make
.
Evaluate
(
m
))
ib
.
emit
(
tvm
.
make
.
Evaluate
(
m
))
stmt
=
ib
.
get
()
stmt
=
ib
.
get
()
...
...
tests/python/unittest/test_schedule_bound_inference.py
View file @
2ded2d8c
...
@@ -111,9 +111,11 @@ def test_bound_fusesplit1():
...
@@ -111,9 +111,11 @@ def test_bound_fusesplit1():
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
container
.
Map
)
assert
isinstance
(
bounds
,
tvm
.
container
.
Map
)
assert
(
tvm
.
ir_pass
.
Simplify
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
min
-
(
xo
*
split1
)
/
l
)
.
value
==
0
)
idxdiv
=
tvm
.
indexdiv
assert
(
tvm
.
ir_pass
.
Simplify
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
min
-
idxdiv
(
xo
*
split1
,
l
))
.
value
==
0
)
expected_extent
=
(
((
xo
+
1
)
*
split1
-
1
)
/
l
-
(
xo
*
split1
)
/
l
+
1
)
expected_extent
=
(
idxdiv
((
xo
+
1
)
*
split1
-
1
,
l
)
-
idxdiv
(
xo
*
split1
,
l
)
+
1
)
for
i
in
range
(
1
,
6
):
for
i
in
range
(
1
,
6
):
for
j
in
range
(
1
,
6
):
for
j
in
range
(
1
,
6
):
for
k
in
range
(
1
,
6
):
for
k
in
range
(
1
,
6
):
...
@@ -121,7 +123,7 @@ def test_bound_fusesplit1():
...
@@ -121,7 +123,7 @@ def test_bound_fusesplit1():
comp_ext
=
tvm
.
ir_pass
.
Simplify
(
tvm
.
ir_pass
.
Substitute
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
extent
,
vars
))
.
value
comp_ext
=
tvm
.
ir_pass
.
Simplify
(
tvm
.
ir_pass
.
Substitute
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
extent
,
vars
))
.
value
exp_ext
=
tvm
.
ir_pass
.
Simplify
(
tvm
.
ir_pass
.
Substitute
(
expected_extent
,
vars
))
.
value
exp_ext
=
tvm
.
ir_pass
.
Simplify
(
tvm
.
ir_pass
.
Substitute
(
expected_extent
,
vars
))
.
value
assert
(
comp_ext
==
exp_ext
)
assert
(
comp_ext
==
exp_ext
)
assert
(
tvm
.
ir_pass
.
Simplify
(
bounds
[
A1
.
op
.
axis
[
1
]]
.
extent
-
l
)
.
value
==
0
)
assert
(
tvm
.
ir_pass
.
Simplify
(
bounds
[
A1
.
op
.
axis
[
1
]]
.
extent
-
l
)
.
value
==
0
)
def
test_bound_fusesplit2
():
def
test_bound_fusesplit2
():
...
@@ -394,11 +396,11 @@ def test_bound_simplification_failure():
...
@@ -394,11 +396,11 @@ def test_bound_simplification_failure():
if
not
bounds
[
A
.
op
.
axis
[
0
]]
.
extent
.
value
<=
2
:
if
not
bounds
[
A
.
op
.
axis
[
0
]]
.
extent
.
value
<=
2
:
print
(
stmt
)
print
(
stmt
)
assert
bounds
[
A
.
op
.
axis
[
0
]]
.
extent
.
value
<=
2
assert
bounds
[
A
.
op
.
axis
[
0
]]
.
extent
.
value
<=
2
tdiv
=
tvm
.
truncdiv
# These are hard to simplify, moreover we don't simplify them
# These are hard to simplify, moreover we don't simplify them
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
tvm
.
min
(
3
*
i
,
4
*
i
)
+
tvm
.
min
(
-
3
*
i
,
-
2
*
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
tvm
.
min
(
3
*
i
,
4
*
i
)
+
tvm
.
min
(
-
3
*
i
,
-
2
*
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
tvm
.
min
(
3
*
i
,
4
*
i
)
+
tvm
.
max
(
-
3
*
i
,
-
4
*
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
tvm
.
min
(
3
*
i
,
4
*
i
)
+
tvm
.
max
(
-
3
*
i
,
-
4
*
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
-
2
*
(
i
/
2
)
-
tvm
.
min
(
i
,
0
-
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
-
2
*
tdiv
(
i
,
2
)
-
tvm
.
min
(
i
,
0
-
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
i
+
(
0
-
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
i
+
(
0
-
i
)]))
# This would cause out of bounds, but we nevertheless include it
# This would cause out of bounds, but we nevertheless include it
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
i
]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
i
]))
...
...
tests/python/unittest/test_schedule_tensorize.py
View file @
2ded2d8c
...
@@ -221,11 +221,14 @@ def test_tensorize_matmul():
...
@@ -221,11 +221,14 @@ def test_tensorize_matmul():
# This tests whether algorithm and intrinsics expressions are simplified
# This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696
# as much as possible first and then checked for equality. See Issue #696
def
test_tensorize_op
():
def
test_tensorize_op
():
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
def
op_intrin
():
def
op_intrin
():
bh
=
9
bh
=
9
bw
=
9
bw
=
9
x
=
tvm
.
placeholder
((
5
,
5
),
name
=
'A'
)
x
=
tvm
.
placeholder
((
5
,
5
),
name
=
'A'
)
y
=
tvm
.
compute
((
bh
,
bw
),
lambda
i
,
j
:
x
[
j
/
3
+
i
%
3
,
j
%
3
+
i
/
3
])
y
=
tvm
.
compute
((
bh
,
bw
),
lambda
i
,
j
:
x
[
tdiv
(
j
,
3
)
+
tmod
(
i
,
3
),
tmod
(
j
,
3
)
+
tdiv
(
i
,
3
)])
def
intrin_func
(
ins
,
outs
):
def
intrin_func
(
ins
,
outs
):
xx
,
=
ins
xx
,
=
ins
...
@@ -236,7 +239,7 @@ def test_tensorize_op():
...
@@ -236,7 +239,7 @@ def test_tensorize_op():
return
tvm
.
decl_tensor_intrin
(
y
.
op
,
intrin_func
)
return
tvm
.
decl_tensor_intrin
(
y
.
op
,
intrin_func
)
A
=
tvm
.
placeholder
((
5
,
5
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
5
,
5
),
name
=
'A'
)
B
=
tvm
.
compute
((
9
,
9
),
lambda
i
,
j
:
A
[
j
/
3
+
i
%
3
,
j
%
3
+
i
/
3
])
B
=
tvm
.
compute
((
9
,
9
),
lambda
i
,
j
:
A
[
tdiv
(
j
,
3
)
+
tmod
(
i
,
3
),
tmod
(
j
,
3
)
+
tdiv
(
i
,
3
)
])
bt
=
op_intrin
()
bt
=
op_intrin
()
s
=
tvm
.
create_schedule
(
B
.
op
)
s
=
tvm
.
create_schedule
(
B
.
op
)
...
...
topi/python/topi/arm_cpu/conv2d_spatial_pack.py
View file @
2ded2d8c
...
@@ -128,8 +128,13 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
...
@@ -128,8 +128,13 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
kernel_vec
[
co
,
ci
,
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
kernel_vec
[
co
,
ci
,
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
ci
,
kh
,
kw
]),
name
=
'conv'
)
axis
=
[
ci
,
kh
,
kw
]),
name
=
'conv'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
,
idxdiv
(
co
,
VC
),
idxdiv
(
h
,
VH
),
idxdiv
(
w
,
VW
),
idxmod
(
h
,
VH
),
idxmod
(
w
,
VW
),
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_conv2d_output'
)
name
=
'output_unpack'
,
tag
=
'spatial_conv2d_output'
)
return
output
return
output
...
...
topi/python/topi/arm_cpu/conv2d_transpose.py
View file @
2ded2d8c
...
@@ -123,8 +123,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
...
@@ -123,8 +123,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
kernel_vec
[
co
,
ci
,
KH
-
1
-
kh
,
KW
-
1
-
kw
,
vc
]
.
astype
(
out_dtype
),
kernel_vec
[
co
,
ci
,
KH
-
1
-
kh
,
KW
-
1
-
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
ci
,
kh
,
kw
]),
name
=
'conv'
)
axis
=
[
ci
,
kh
,
kw
]),
name
=
'conv'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
,
idxdiv
(
co
,
VC
),
idxdiv
(
h
,
VH
),
idxdiv
(
w
,
VW
),
idxmod
(
h
,
VH
),
idxmod
(
w
,
VW
),
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_conv2d_transpose_output'
)
name
=
'output_unpack'
,
tag
=
'spatial_conv2d_transpose_output'
)
return
output
return
output
...
...
topi/python/topi/arm_cpu/depthwise_conv2d.py
View file @
2ded2d8c
...
@@ -293,21 +293,29 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
...
@@ -293,21 +293,29 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
kh
=
tvm
.
reduce_axis
((
0
,
KH
),
name
=
'kh'
)
kh
=
tvm
.
reduce_axis
((
0
,
KH
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
KW
),
name
=
'kw'
)
kw
=
tvm
.
reduce_axis
((
0
,
KW
),
name
=
'kw'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
if
dilation_h
!=
1
or
dilation_w
!=
1
:
if
dilation_h
!=
1
or
dilation_w
!=
1
:
conv
=
tvm
.
compute
(
ovshape
,
lambda
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
:
\
conv
=
tvm
.
compute
(
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
(
co
*
VC
+
vc
)
//
M
,
kh
,
kw
,
vh
,
vw
]
ovshape
,
lambda
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
:
\
.
astype
(
out_dtype
)
*
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
idxdiv
(
co
*
VC
+
vc
,
M
),
kh
,
kw
,
vh
,
vw
]
kernel_vec
[
co
//
M
,
co
%
M
,
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
.
astype
(
out_dtype
)
*
axis
=
[
kh
,
kw
]),
name
=
'depthwise_conv'
)
kernel_vec
[
idxdiv
(
co
,
M
),
idxmod
(
co
,
M
),
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
]),
name
=
'depthwise_conv'
)
else
:
else
:
conv
=
tvm
.
compute
(
ovshape
,
lambda
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
:
\
conv
=
tvm
.
compute
(
ovshape
,
lambda
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
:
\
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
(
co
*
VC
+
vc
)
//
M
,
vh
*
HSTR
+
kh
,
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
idxdiv
((
co
*
VC
+
vc
),
M
)
,
vh
*
HSTR
+
kh
,
vw
*
WSTR
+
kw
]
.
astype
(
out_dtype
)
*
vw
*
WSTR
+
kw
]
.
astype
(
out_dtype
)
*
kernel_vec
[
co
//
M
,
co
%
M
,
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
kernel_vec
[
idxdiv
(
co
,
M
),
idxmod
(
co
,
M
),
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
]),
name
=
'depthwise_conv'
)
axis
=
[
kh
,
kw
]),
name
=
'depthwise_conv'
)
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
,
idxdiv
(
co
,
VC
),
idxdiv
(
h
,
VH
),
idxdiv
(
w
,
VW
),
idxmod
(
h
,
VH
),
idxmod
(
w
,
VW
),
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_depthwise_conv_nchw_output'
)
name
=
'output_unpack'
,
tag
=
'spatial_depthwise_conv_nchw_output'
)
return
output
return
output
...
...
topi/python/topi/cuda/conv2d_transpose_nchw.py
View file @
2ded2d8c
...
@@ -69,9 +69,11 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
...
@@ -69,9 +69,11 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
[
0
,
0
,
(
bpad_bottom
+
stride_h
-
1
)
//
stride_h
,
[
0
,
0
,
(
bpad_bottom
+
stride_h
-
1
)
//
stride_h
,
(
bpad_right
+
stride_w
-
1
)
//
stride_w
],
name
=
'FirstPad'
)
(
bpad_right
+
stride_w
-
1
)
//
stride_w
],
name
=
'FirstPad'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# remove extra padding introduced by dilatation
# remove extra padding introduced by dilatation
border_h
=
(
stride_h
-
bpad_top
%
stride_h
)
%
stride_h
border_h
=
idxmod
(
stride_h
-
idxmod
(
bpad_top
,
stride_h
),
stride_h
)
border_w
=
(
stride_w
-
bpad_left
%
stride_w
)
%
stride_w
border_w
=
idxmod
(
stride_w
-
idxmod
(
bpad_left
,
stride_w
),
stride_w
)
# dilation stage
# dilation stage
data
=
FirstPad
data
=
FirstPad
...
@@ -83,8 +85,8 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
...
@@ -83,8 +85,8 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
index_tuple
=
[]
index_tuple
=
[]
for
i
in
range
(
n
):
for
i
in
range
(
n
):
if
not
equal_const_int
(
strides
[
i
],
1
):
if
not
equal_const_int
(
strides
[
i
],
1
):
index_tuple
.
append
(
i
ndices
[
i
]
//
strides
[
i
]
)
index_tuple
.
append
(
i
dxdiv
(
indices
[
i
],
strides
[
i
])
)
not_zero
.
append
(
(
indices
[
i
]
%
strides
[
i
])
.
equal
(
0
))
not_zero
.
append
(
idxmod
(
indices
[
i
],
strides
[
i
])
.
equal
(
0
))
else
:
else
:
index_tuple
.
append
(
indices
[
i
])
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
if
not_zero
:
...
...
topi/python/topi/cuda/conv2d_winograd.py
View file @
2ded2d8c
...
@@ -85,10 +85,12 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
...
@@ -85,10 +85,12 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
else
:
else
:
kernel_pack
=
kernel
kernel_pack
=
kernel
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# pack input tile
# pack input tile
input_tile
=
tvm
.
compute
((
CI
,
P
,
alpha
,
alpha
),
lambda
c
,
p
,
eps
,
nu
:
input_tile
=
tvm
.
compute
((
CI
,
P
,
alpha
,
alpha
),
lambda
c
,
p
,
eps
,
nu
:
data_pad
[
p
//
(
nH
*
nW
)][
c
][
p
//
nW
%
nH
*
m
+
eps
]
data_pad
[
idxdiv
(
p
,
(
nH
*
nW
))][
c
][
idxmod
(
idxdiv
(
p
,
nW
),
nH
)
*
m
+
eps
]
[
p
%
nW
*
m
+
nu
],
name
=
'd'
)
[
idxmod
(
p
,
nW
)
*
m
+
nu
],
name
=
'd'
)
# transform data
# transform data
r_a
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_a'
)
r_a
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_a'
)
...
@@ -113,7 +115,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
...
@@ -113,7 +115,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
# output
# output
output
=
tvm
.
compute
((
N
,
CO
,
H
,
W
),
lambda
n
,
co
,
h
,
w
:
output
=
tvm
.
compute
((
N
,
CO
,
H
,
W
),
lambda
n
,
co
,
h
,
w
:
inverse
[
co
][
n
*
nH
*
nW
+
(
h
//
m
)
*
nW
+
w
//
m
][
h
%
m
][
w
%
m
],
inverse
[
co
,
n
*
nH
*
nW
+
idxdiv
(
h
,
m
)
*
nW
+
idxdiv
(
w
,
m
),
idxmod
(
h
,
m
),
idxmod
(
w
,
m
)],
name
=
'output'
,
tag
=
'conv2d_nchw_winograd'
)
name
=
'output'
,
tag
=
'conv2d_nchw_winograd'
)
cfg
.
add_flop
(
2
*
N
*
CO
*
H
*
W
*
CI
*
KH
*
KW
)
cfg
.
add_flop
(
2
*
N
*
CO
*
H
*
W
*
CI
*
KH
*
KW
)
...
...
topi/python/topi/cuda/nms.py
View file @
2ded2d8c
...
@@ -245,7 +245,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
...
@@ -245,7 +245,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
new_range
=
num_anchors
//
elem_per_thread
+
1
new_range
=
num_anchors
//
elem_per_thread
+
1
# Scan: Downsweep:
# Scan: Downsweep:
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
i
=
tid
/
num_anchors
# number of batches
i
=
tid
/
/
num_anchors
# number of batches
j
=
tid
%
num_anchors
# number of anchors
j
=
tid
%
num_anchors
# number of anchors
with
ib
.
if_scope
(
j
<
elem_per_thread
):
with
ib
.
if_scope
(
j
<
elem_per_thread
):
idx
[
tid
]
=
idx_in
[
tid
]
idx
[
tid
]
=
idx_in
[
tid
]
...
@@ -304,7 +304,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
...
@@ -304,7 +304,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
tid
=
bx
*
max_threads
+
tx
tid
=
bx
*
max_threads
+
tx
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
i
=
tid
/
num_anchors
i
=
tid
/
/
num_anchors
j
=
tid
%
num_anchors
j
=
tid
%
num_anchors
base_idx
=
i
*
num_anchors
*
elem_length
base_idx
=
i
*
num_anchors
*
elem_length
with
ib
.
if_scope
(
flag
[
tid
]
>
0
):
with
ib
.
if_scope
(
flag
[
tid
]
>
0
):
...
...
topi/python/topi/cuda/ssd/multibox.py
View file @
2ded2d8c
...
@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
...
@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
tid
=
bx
*
max_threads
+
tx
tid
=
bx
*
max_threads
+
tx
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
i
=
tid
/
num_anchors
i
=
tid
/
/
num_anchors
j
=
tid
%
num_anchors
j
=
tid
%
num_anchors
with
ib
.
if_scope
(
cls_id
[
tid
]
>
0
):
with
ib
.
if_scope
(
cls_id
[
tid
]
>
0
):
with
ib
.
if_scope
(
tid
==
0
):
with
ib
.
if_scope
(
tid
==
0
):
...
...
topi/python/topi/mali/conv2d.py
View file @
2ded2d8c
...
@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
...
@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
tvm
.
sum
(
input_tile
[
ci
][
p
][
r_a
][
r_b
][
vp
]
*
B
[
r_a
][
eps
]
*
B
[
r_b
][
nu
],
tvm
.
sum
(
input_tile
[
ci
][
p
][
r_a
][
r_b
][
vp
]
*
B
[
r_a
][
eps
]
*
B
[
r_b
][
nu
],
axis
=
[
r_a
,
r_b
]),
name
=
'V'
)
axis
=
[
r_a
,
r_b
]),
name
=
'V'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# batch gemm
# batch gemm
ci
=
tvm
.
reduce_axis
((
0
,
CI
),
name
=
'c'
)
ci
=
tvm
.
reduce_axis
((
0
,
CI
),
name
=
'c'
)
M
=
tvm
.
compute
((
alpha
,
alpha
,
CO
,
P_round
),
lambda
eps
,
nu
,
co
,
p
:
M
=
tvm
.
compute
((
alpha
,
alpha
,
CO
,
P_round
),
lambda
eps
,
nu
,
co
,
p
:
tvm
.
sum
(
U
[
eps
][
nu
][
co
//
bna
][
ci
][
co
%
bna
]
*
tvm
.
sum
(
U
[
eps
][
nu
][
idxdiv
(
co
,
bna
)][
ci
][
idxmod
(
co
,
bna
)
]
*
V
[
eps
][
nu
][
p
//
bnb
][
ci
][
p
%
bnb
],
axis
=
ci
),
name
=
'M'
)
V
[
eps
][
nu
][
idxdiv
(
p
,
bnb
)][
ci
][
idxmod
(
p
,
bnb
)
],
axis
=
ci
),
name
=
'M'
)
r_a
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_a'
)
r_a
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_a'
)
r_b
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_b'
)
r_b
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_b'
)
...
@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
...
@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# unpack output
# unpack output
output
=
tvm
.
compute
((
N
,
CO
,
H
,
W
),
lambda
n
,
co
,
h
,
w
:
output
=
tvm
.
compute
((
N
,
CO
,
H
,
W
),
lambda
n
,
co
,
h
,
w
:
Y
[
co
][
n
*
nH
*
nW
+
(
h
//
m
)
*
nW
+
w
//
m
][
h
%
m
][
w
%
m
]
Y
[
co
,
n
*
nH
*
nW
+
idxdiv
(
h
,
m
)
*
nW
+
idxdiv
(
w
,
m
),
idxmod
(
h
,
m
),
idxmod
(
w
,
m
)]
# The following hack term is used to make the padding in batch gemm ("M")
# The following hack term is used to make the padding in batch gemm ("M")
# effective, otherwise the padding will be eliminated by bound inference.
# effective, otherwise the padding will be eliminated by bound inference.
# Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding.
# Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding.
...
...
topi/python/topi/nn/bitserial_conv2d.py
View file @
2ded2d8c
...
@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
axis
=
[
ci
,
dh
,
dw
,
b1
,
b2
])
axis
=
[
ci
,
dh
,
dw
,
b1
,
b2
])
conv
=
tvm
.
compute
(
ovshape
,
_conv
,
name
=
'conv_out'
)
conv
=
tvm
.
compute
(
ovshape
,
_conv
,
name
=
'conv_out'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
return
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
return
tvm
.
compute
(
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
oshape
,
lambda
n
,
co
,
h
,
w
:
name
=
'conv_vec'
,
tag
=
'spatial_bitserial_conv_nchw'
)
conv
[
n
][
idxdiv
(
co
,
VC
)][
idxdiv
(
h
,
VH
)][
idxdiv
(
w
,
VW
)][
idxmod
(
h
,
VH
)][
idxmod
(
w
,
VW
)][
idxmod
(
co
,
VC
)],
name
=
'conv_vec'
,
tag
=
'spatial_bitserial_conv_nchw'
)
@autotvm.register_topi_compute
(
bitserial_conv2d_nhwc
,
'cpu'
,
'direct'
)
@autotvm.register_topi_compute
(
bitserial_conv2d_nhwc
,
'cpu'
,
'direct'
)
def
spatial_pack_nhwc
(
cfg
,
data
,
kernel
,
stride
,
padding
,
in_bits
,
weight_bits
,
def
spatial_pack_nhwc
(
cfg
,
data
,
kernel
,
stride
,
padding
,
in_bits
,
weight_bits
,
...
@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
...
@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
conv
=
tvm
.
compute
(
ovshape
,
_conv
,
name
=
'conv'
)
conv
=
tvm
.
compute
(
ovshape
,
_conv
,
name
=
'conv'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
h
,
w
,
co
:
idxdiv
=
tvm
.
indexdiv
conv
[
n
][
h
//
VH
][
w
//
VW
][
co
//
VC
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
idxmod
=
tvm
.
indexmod
name
=
'output_unpack'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
h
,
w
,
co
:
conv
[
n
][
idxdiv
(
h
,
VH
)][
idxdiv
(
w
,
VW
)][
idxdiv
(
co
,
VC
)][
idxmod
(
h
,
VH
)][
idxmod
(
w
,
VW
)][
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
@tvm.target.generic_func
@tvm.target.generic_func
def
bitserial_conv2d_legalize
(
attrs
,
inputs
,
types
):
def
bitserial_conv2d_legalize
(
attrs
,
inputs
,
types
):
...
...
topi/python/topi/nn/bitserial_dense.py
View file @
2ded2d8c
...
@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
...
@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
weight_vec
=
tvm
.
compute
(
wvshape
,
lambda
xo
,
wb
,
vx
,
k
:
weight_vec
=
tvm
.
compute
(
wvshape
,
lambda
xo
,
wb
,
vx
,
k
:
weight_packed
[
xo
*
VX
+
vx
][
wb
][
k
],
name
=
'weight_vec'
)
weight_packed
[
xo
*
VX
+
vx
][
wb
][
k
],
name
=
'weight_vec'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
matmul_unipolar
=
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
tvm
.
sum
(
matmul_unipolar
=
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
tvm
.
sum
(
(
tvm
.
popcount
(
weight_vec
[
j
//
VX
,
wb
,
j
%
VX
,
k
]
&
data_packed
[
i
,
db
,
k
])
-
(
tvm
.
popcount
(
weight_vec
[
idxdiv
(
j
,
VX
),
wb
,
idxmod
(
j
,
VX
),
k
]
&
data_packed
[
i
,
db
,
k
])
-
tvm
.
popcount
(
~
weight_vec
[
j
//
VX
,
wb
,
j
%
VX
,
k
]
&
data_packed
[
i
,
db
,
k
]))
.
astype
(
out_dtype
)
tvm
.
popcount
(
~
weight_vec
[
idxdiv
(
j
,
VX
),
wb
,
idxmod
(
j
,
VX
),
k
]
&
data_packed
[
i
,
db
,
k
])
)
.
astype
(
out_dtype
)
<<
(
db
+
wb
)
.
astype
(
out_dtype
),
axis
=
[
wb
,
db
,
k
]),
tag
=
'bitserial_dense_unipolar'
)
<<
(
db
+
wb
)
.
astype
(
out_dtype
),
axis
=
[
wb
,
db
,
k
]),
tag
=
'bitserial_dense_unipolar'
)
matmul
=
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
tvm
.
sum
(
matmul
=
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
tvm
.
sum
(
tvm
.
popcount
(
weight_vec
[
j
//
VX
,
wb
,
j
%
VX
,
k
]
&
data_packed
[
i
,
db
,
k
])
.
astype
(
out_dtype
)
tvm
.
popcount
(
weight_vec
[
idxdiv
(
j
,
VX
),
wb
,
idxmod
(
j
,
VX
),
k
]
&
data_packed
[
i
,
db
,
k
]
)
.
astype
(
out_dtype
)
<<
(
db
+
wb
)
.
astype
(
out_dtype
),
axis
=
[
wb
,
db
,
k
]),
tag
=
'bitserial_dense'
)
<<
(
db
+
wb
)
.
astype
(
out_dtype
),
axis
=
[
wb
,
db
,
k
]),
tag
=
'bitserial_dense'
)
# binary ops
# binary ops
...
...
topi/python/topi/nn/conv2d.py
View file @
2ded2d8c
...
@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
...
@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
tvm
.
sum
(
data_pad
[
n
,
i
c
//
ic_bn
,
i
dxdiv
(
ic
,
ic_bn
)
,
oh
*
HSTR
+
kh
*
dilation_h
,
oh
*
HSTR
+
kh
*
dilation_h
,
ow
*
WSTR
+
kw
*
dilation_w
,
ow
*
WSTR
+
kw
*
dilation_w
,
i
c
%
ic_bn
]
.
astype
(
out_dtype
)
i
dxmod
(
ic
,
ic_bn
)
]
.
astype
(
out_dtype
)
*
kernel
[
oc_chunk
,
*
kernel
[
oc_chunk
,
i
c
//
ic_bn
,
i
dxdiv
(
ic
,
ic_bn
)
,
kh
,
kh
,
kw
,
kw
,
i
c
%
ic_bn
,
i
dxmod
(
ic
,
ic_bn
)
,
oc_block
],
oc_block
],
axis
=
[
ic
,
kh
,
kw
]),
axis
=
[
ic
,
kh
,
kw
]),
name
=
'conv2d_NCHWc'
,
tag
=
"conv2d_NCHWc"
)
name
=
'conv2d_NCHWc'
,
tag
=
"conv2d_NCHWc"
)
...
...
topi/python/topi/nn/depthwise_conv2d.py
View file @
2ded2d8c
...
@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
...
@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after
=
[
0
,
0
,
pad_down
,
pad_right
]
pad_after
=
[
0
,
0
,
pad_down
,
pad_right
]
PaddedInput
=
pad
(
Input
,
pad_before
,
pad_after
,
name
=
"PaddedInput"
)
PaddedInput
=
pad
(
Input
,
pad_before
,
pad_after
,
name
=
"PaddedInput"
)
# depthconv stage
# depthconv stage
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
di
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'di'
)
di
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'di'
)
dj
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'dj'
)
dj
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'dj'
)
Output
=
tvm
.
compute
(
Output
=
tvm
.
compute
(
(
batch
,
out_channel
,
out_height
,
out_width
),
(
batch
,
out_channel
,
out_height
,
out_width
),
lambda
b
,
c
,
i
,
j
:
tvm
.
sum
(
lambda
b
,
c
,
i
,
j
:
tvm
.
sum
(
(
PaddedInput
[
b
,
c
/
channel_multiplier
,
i
*
stride_h
+
di
*
dilation_h
,
(
PaddedInput
[
b
,
idxdiv
(
c
,
channel_multiplier
)
,
i
*
stride_h
+
di
*
dilation_h
,
j
*
stride_w
+
dj
*
dilation_w
]
.
astype
(
out_dtype
)
*
j
*
stride_w
+
dj
*
dilation_w
]
.
astype
(
out_dtype
)
*
Filter
[
c
/
channel_multiplier
,
c
%
channel_multiplier
,
di
,
dj
]
.
astype
(
out_dtype
)),
Filter
[
idxdiv
(
c
,
channel_multiplier
),
idxmod
(
c
,
channel_multiplier
),
di
,
dj
]
.
astype
(
out_dtype
)),
axis
=
[
di
,
dj
]),
axis
=
[
di
,
dj
]),
name
=
'DepthwiseConv2d'
,
tag
=
"depthwise_conv2d_nchw"
)
name
=
'DepthwiseConv2d'
,
tag
=
"depthwise_conv2d_nchw"
)
return
Output
return
Output
...
@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
...
@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after
=
[
0
,
pad_down
,
pad_right
,
0
]
pad_after
=
[
0
,
pad_down
,
pad_right
,
0
]
PaddedInput
=
pad
(
Input
,
pad_before
,
pad_after
,
name
=
"PaddedInput"
)
PaddedInput
=
pad
(
Input
,
pad_before
,
pad_after
,
name
=
"PaddedInput"
)
# depthconv stage
# depthconv stage
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
di
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'di'
)
di
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'di'
)
dj
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'dj'
)
dj
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'dj'
)
Output
=
tvm
.
compute
(
Output
=
tvm
.
compute
(
(
batch
,
out_height
,
out_width
,
out_channel
),
(
batch
,
out_height
,
out_width
,
out_channel
),
lambda
b
,
i
,
j
,
c
:
tvm
.
sum
(
lambda
b
,
i
,
j
,
c
:
tvm
.
sum
(
(
PaddedInput
[
b
,
i
*
stride_h
+
di
*
dilation_h
,
j
*
stride_w
+
dj
*
dilation_w
,
(
PaddedInput
[
b
,
i
*
stride_h
+
di
*
dilation_h
,
j
*
stride_w
+
dj
*
dilation_w
,
c
/
channel_multiplier
]
.
astype
(
out_dtype
)
*
idxdiv
(
c
,
channel_multiplier
)]
.
astype
(
out_dtype
)
*
Filter
[
di
,
dj
,
c
/
channel_multiplier
,
c
%
channel_multiplier
]
.
astype
(
out_dtype
)),
Filter
[
di
,
dj
,
idxdiv
(
c
,
channel_multiplier
),
idxmod
(
c
,
channel_multiplier
)]
.
astype
(
out_dtype
)),
axis
=
[
di
,
dj
]),
axis
=
[
di
,
dj
]),
name
=
'DepthwiseConv2d'
,
tag
=
"depthwise_conv2d_nhwc"
)
name
=
'DepthwiseConv2d'
,
tag
=
"depthwise_conv2d_nhwc"
)
return
Output
return
Output
...
@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
...
@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
dh
=
tvm
.
reduce_axis
((
0
,
Out_grad
.
shape
[
1
]
.
value
),
name
=
'dh'
)
dh
=
tvm
.
reduce_axis
((
0
,
Out_grad
.
shape
[
1
]
.
value
),
name
=
'dh'
)
dw
=
tvm
.
reduce_axis
((
0
,
Out_grad
.
shape
[
2
]
.
value
),
name
=
'dw'
)
dw
=
tvm
.
reduce_axis
((
0
,
Out_grad
.
shape
[
2
]
.
value
),
name
=
'dw'
)
db
=
tvm
.
reduce_axis
((
0
,
batch
),
name
=
'db'
)
db
=
tvm
.
reduce_axis
((
0
,
batch
),
name
=
'db'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
Weight_grad
=
tvm
.
compute
(
Weight_grad
=
tvm
.
compute
(
(
filter_h
,
filter_w
,
in_c
,
channel_multiplier
),
(
filter_h
,
filter_w
,
in_c
,
channel_multiplier
),
lambda
fh
,
fw
,
c
,
m
:
tvm
.
sum
(
lambda
fh
,
fw
,
c
,
m
:
tvm
.
sum
(
Out_grad
[
db
,
dh
,
dw
,
c
*
channel_multiplier
+
m
%
channel_multiplier
]
*
Out_grad
[
db
,
dh
,
dw
,
c
*
channel_multiplier
+
idxmod
(
m
,
channel_multiplier
)
]
*
padded_in
[
db
,
fh
+
dh
*
stride_h
,
fw
+
dw
*
stride_w
,
c
],
axis
=
[
db
,
dh
,
dw
]),
padded_in
[
db
,
fh
+
dh
*
stride_h
,
fw
+
dw
*
stride_w
,
c
],
axis
=
[
db
,
dh
,
dw
]),
tag
=
'depthwise_conv2d_backward_weight_nhwc'
)
tag
=
'depthwise_conv2d_backward_weight_nhwc'
)
...
...
topi/python/topi/nn/dilate.py
View file @
2ded2d8c
...
@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"):
...
@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"):
def
_dilate
(
*
indices
):
def
_dilate
(
*
indices
):
not_zero
=
[]
not_zero
=
[]
index_tuple
=
[]
index_tuple
=
[]
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
for
i
in
range
(
n
):
for
i
in
range
(
n
):
if
not
util
.
equal_const_int
(
strides
[
i
],
1
):
if
not
util
.
equal_const_int
(
strides
[
i
],
1
):
index_tuple
.
append
(
i
ndices
[
i
]
/
strides
[
i
]
)
index_tuple
.
append
(
i
dxdiv
(
indices
[
i
],
strides
[
i
])
)
not_zero
.
append
(
(
indices
[
i
]
%
strides
[
i
])
.
equal
(
0
))
not_zero
.
append
(
idxmod
(
indices
[
i
],
strides
[
i
])
.
equal
(
0
))
else
:
else
:
index_tuple
.
append
(
indices
[
i
])
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
if
not_zero
:
...
...
topi/python/topi/nn/flatten.py
View file @
2ded2d8c
...
@@ -38,12 +38,14 @@ def flatten(data):
...
@@ -38,12 +38,14 @@ def flatten(data):
for
i
in
range
(
1
,
len
(
ishape
)):
for
i
in
range
(
1
,
len
(
ishape
)):
dim
=
dim
*
ishape
[
i
]
dim
=
dim
*
ishape
[
i
]
oshape
=
[
ishape
[
0
],
dim
]
oshape
=
[
ishape
[
0
],
dim
]
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
def
unwrap
(
idx
,
shape
):
def
unwrap
(
idx
,
shape
):
index
=
[]
index
=
[]
for
s
in
reversed
(
shape
):
for
s
in
reversed
(
shape
):
index
.
append
(
idx
%
s
)
index
.
append
(
idx
mod
(
idx
,
s
)
)
idx
=
idx
/
s
idx
=
idx
div
(
idx
,
s
)
return
list
(
reversed
(
index
))
return
list
(
reversed
(
index
))
return
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
data
(
i
,
*
unwrap
(
j
,
ishape
[
1
:])))
return
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
data
(
i
,
*
unwrap
(
j
,
ishape
[
1
:])))
topi/python/topi/x86/conv2d.py
View file @
2ded2d8c
...
@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
...
@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
ic
=
tvm
.
reduce_axis
((
0
,
in_channel
),
name
=
'ic'
)
ic
=
tvm
.
reduce_axis
((
0
,
in_channel
),
name
=
'ic'
)
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
idxmod
=
tvm
.
indexmod
idxdiv
=
tvm
.
indexdiv
conv
=
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
conv
=
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_vec
[
n
,
ic
//
ic_bn
,
oh
*
HSTR
+
kh
*
dilation_h
,
ic
%
ic_bn
,
tvm
.
sum
(
data_vec
[
n
,
idxdiv
(
ic
,
ic_bn
),
oh
*
HSTR
+
kh
*
dilation_h
,
idxmod
(
ic
,
ic_bn
),
ow
*
WSTR
+
kw
*
dilation_w
]
.
astype
(
out_dtype
)
*
ow
*
WSTR
+
kw
*
dilation_w
]
.
astype
(
out_dtype
)
*
kernel_vec
[
oc_chunk
,
ic
//
ic_bn
,
kh
,
kw
,
ic
%
ic_bn
,
kernel_vec
[
oc_chunk
,
idxdiv
(
ic
,
ic_bn
),
kh
,
kw
,
idxmod
(
ic
,
ic_bn
),
oc_block
]
.
astype
(
out_dtype
),
oc_block
]
.
astype
(
out_dtype
),
axis
=
[
ic
,
kh
,
kw
]),
name
=
'conv'
)
axis
=
[
ic
,
kh
,
kw
]),
name
=
'conv'
)
unpack
=
tvm
.
compute
(
unpack_shape
,
unpack
=
tvm
.
compute
(
unpack_shape
,
lambda
n
,
c
,
h
,
w
:
conv
[
n
,
c
//
oc_bn
,
h
,
w
,
c
%
oc_bn
]
lambda
n
,
c
,
h
,
w
:
conv
[
n
,
idxdiv
(
c
,
oc_bn
),
h
,
w
,
idxmod
(
c
,
oc_bn
)
]
.
astype
(
out_dtype
),
.
astype
(
out_dtype
),
name
=
'output_unpack'
,
name
=
'output_unpack'
,
tag
=
'conv2d_nchw'
)
tag
=
'conv2d_nchw'
)
...
@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
...
@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
cfg
=
get_config
()
cfg
=
get_config
()
_create_tuning_space
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
origin_layout
)
_create_tuning_space
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
origin_layout
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# change shape with the value in config
# change shape with the value in config
ic_bn
,
oc_bn
,
ow_bn
=
(
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
],
ic_bn
,
oc_bn
,
ow_bn
=
(
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
],
cfg
[
"tile_ow"
]
.
size
[
-
1
])
cfg
[
"tile_ow"
]
.
size
[
-
1
])
new_data_shape
=
(
raw_data_shape
[
0
],
raw_data_shape
[
1
]
//
ic_bn
,
new_data_shape
=
(
raw_data_shape
[
0
],
idxdiv
(
raw_data_shape
[
1
],
ic_bn
)
,
raw_data_shape
[
2
],
raw_data_shape
[
3
],
ic_bn
)
raw_data_shape
[
2
],
raw_data_shape
[
3
],
ic_bn
)
data_layout
=
"NCHW
%
dc"
%
ic_bn
data_layout
=
"NCHW
%
dc"
%
ic_bn
out_layout
=
"NCHW
%
dc"
%
oc_bn
out_layout
=
"NCHW
%
dc"
%
oc_bn
new_kernel_shape
=
(
raw_kernel_shape
[
0
]
//
oc_bn
,
raw_kernel_shape
[
1
]
//
ic_bn
,
new_kernel_shape
=
(
idxdiv
(
raw_kernel_shape
[
0
],
oc_bn
),
idxdiv
(
raw_kernel_shape
[
1
],
ic_bn
),
raw_kernel_shape
[
2
],
raw_kernel_shape
[
3
],
ic_bn
,
oc_bn
)
raw_kernel_shape
[
2
],
raw_kernel_shape
[
3
],
ic_bn
,
oc_bn
)
new_data
=
tvm
.
placeholder
(
new_data_shape
,
data
.
dtype
)
new_data
=
tvm
.
placeholder
(
new_data_shape
,
data
.
dtype
)
new_kernel
=
tvm
.
placeholder
(
new_kernel_shape
,
kernel
.
dtype
)
new_kernel
=
tvm
.
placeholder
(
new_kernel_shape
,
kernel
.
dtype
)
...
@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg):
...
@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg):
_
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
dtype
=
workload
_
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
dtype
=
workload
batch_size
,
in_channel
,
in_height
,
in_width
=
data
[:
-
1
]
batch_size
,
in_channel
,
in_height
,
in_width
=
data
[:
-
1
]
out_channel
,
_
,
k_height
,
k_width
=
kernel
[:
-
1
]
out_channel
,
_
,
k_height
,
k_width
=
kernel
[:
-
1
]
out_height
=
(
in_height
+
2
*
padding
[
0
]
-
k_height
)
//
strides
[
0
]
+
1
idxdiv
=
tvm
.
indexdiv
out_width
=
(
in_width
+
2
*
padding
[
1
]
-
k_width
)
//
strides
[
1
]
+
1
out_height
=
idxdiv
(
in_height
+
2
*
padding
[
0
]
-
k_height
,
strides
[
0
])
+
1
out_width
=
idxdiv
(
in_width
+
2
*
padding
[
1
]
-
k_width
,
strides
[
1
])
+
1
tile_ic
,
tile_oc
=
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
]
tile_ic
,
tile_oc
=
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
]
in_shape
=
(
batch_size
,
i
n_channel
//
tile_ic
,
in_height
,
in_width
,
tile_ic
)
in_shape
=
(
batch_size
,
i
dxdiv
(
in_channel
,
tile_ic
)
,
in_height
,
in_width
,
tile_ic
)
in_layout
=
"NCHW
%
dc"
%
tile_ic
in_layout
=
"NCHW
%
dc"
%
tile_ic
out_shape
=
(
batch_size
,
out_channel
//
tile_oc
,
out_height
,
out_width
,
tile_oc
)
out_shape
=
(
batch_size
,
idxdiv
(
out_channel
,
tile_oc
)
,
out_height
,
out_width
,
tile_oc
)
out_layout
=
"NCHW
%
dc"
%
tile_oc
out_layout
=
"NCHW
%
dc"
%
tile_oc
return
((
in_shape
,
in_layout
),),
((
out_shape
,
out_layout
),)
return
((
in_shape
,
in_layout
),),
((
out_shape
,
out_layout
),)
...
...
topi/python/topi/x86/dense.py
View file @
2ded2d8c
...
@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
...
@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
packw
=
tvm
.
compute
(
packw_shape
,
packw
=
tvm
.
compute
(
packw_shape
,
lambda
z
,
y
,
x
:
weight
[
z
*
packw_bn
+
x
,
y
],
name
=
"packed_weight"
)
lambda
z
,
y
,
x
:
weight
[
z
*
packw_bn
+
x
,
y
],
name
=
"packed_weight"
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
k
=
tvm
.
reduce_axis
((
0
,
K
),
name
=
"k"
)
k
=
tvm
.
reduce_axis
((
0
,
K
),
name
=
"k"
)
C
=
tvm
.
compute
((
M
,
N
),
C
=
tvm
.
compute
((
M
,
N
),
lambda
y
,
x
:
tvm
.
sum
(
lambda
y
,
x
:
tvm
.
sum
(
data
[
y
,
k
]
.
astype
(
out_dtype
)
*
data
[
y
,
k
]
.
astype
(
out_dtype
)
*
packw
[
x
//
packw_bn
,
k
,
x
%
packw_bn
]
.
astype
(
out_dtype
),
packw
[
idxdiv
(
x
,
packw_bn
),
k
,
idxmod
(
x
,
packw_bn
)
]
.
astype
(
out_dtype
),
axis
=
k
),
axis
=
k
),
tag
=
"dense_pack"
)
tag
=
"dense_pack"
)
if
bias
is
not
None
:
if
bias
is
not
None
:
...
...
topi/python/topi/x86/depthwise_conv2d.py
View file @
2ded2d8c
...
@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
...
@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
data_pad
=
data
data_pad
=
data
# depthconv stage
# depthconv stage
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
kh
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'kh'
)
kh
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'kw'
)
kw
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'kw'
)
Output
=
tvm
.
compute
(
Output
=
tvm
.
compute
(
(
batch
,
out_channel_chunk
,
out_height
,
out_width
,
out_channel_block
),
(
batch
,
out_channel_chunk
,
out_height
,
out_width
,
out_channel_block
),
lambda
b
,
oco
,
oh
,
ow
,
oci
:
tvm
.
sum
(
lambda
b
,
oco
,
oh
,
ow
,
oci
:
tvm
.
sum
(
(
data_pad
[
b
,
(
oco
*
out_channel_block
+
oci
)
//
channel_multiplier
//
in_channel_block
,
(
data_pad
[
oh
*
HSTR
+
kh
,
ow
*
WSTR
+
kw
,
b
,
((
oco
*
out_channel_block
+
oci
)
//
channel_multiplier
)
%
in_channel_block
]
idxdiv
(
idxdiv
(
oco
*
out_channel_block
+
oci
,
channel_multiplier
),
in_channel_block
),
oh
*
HSTR
+
kh
,
ow
*
WSTR
+
kw
,
idxmod
(
idxdiv
(
oco
*
out_channel_block
+
oci
,
channel_multiplier
),
in_channel_block
)]
.
astype
(
out_dtype
)
*
.
astype
(
out_dtype
)
*
kernel
[
oco
,
0
,
kh
,
kw
,
0
,
oci
]
.
astype
(
out_dtype
)),
kernel
[
oco
,
0
,
kh
,
kw
,
0
,
oci
]
.
astype
(
out_dtype
)),
axis
=
[
kh
,
kw
]),
axis
=
[
kh
,
kw
]),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment