Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2ded2d8c
Unverified
Commit
2ded2d8c
authored
Sep 27, 2019
by
Tianqi Chen
Committed by
GitHub
Sep 27, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Use explicit div mode in python. (#4014)
parent
16bed7e6
Hide whitespace changes
Inline
Side-by-side
Showing
46 changed files
with
481 additions
and
192 deletions
+481
-192
docs/api/python/tvm.rst
+14
-0
python/tvm/api.py
+71
-0
python/tvm/contrib/nnpack.py
+8
-4
python/tvm/expr.py
+29
-4
python/tvm/generic.py
+18
-1
python/tvm/hybrid/parser.py
+15
-2
python/tvm/relay/op/_transform.py
+4
-4
src/api/api_ir.cc
+2
-0
src/contrib/hybrid/codegen_hybrid.cc
+15
-2
src/contrib/hybrid/codegen_hybrid.h
+6
-4
src/lang/expr_operator.cc
+4
-0
src/pass/lower_intrin.cc
+0
-3
tests/python/unittest/test_arith_canonical_simplify.py
+50
-40
tests/python/unittest/test_arith_const_int_bound.py
+21
-14
tests/python/unittest/test_arith_deduce_bound.py
+6
-4
tests/python/unittest/test_arith_intset.py
+6
-4
tests/python/unittest/test_arith_modular_set.py
+16
-11
tests/python/unittest/test_autotvm_flop_calculator.py
+7
-2
tests/python/unittest/test_build_lower.py
+1
-1
tests/python/unittest/test_codegen_llvm.py
+4
-3
tests/python/unittest/test_ir_builder.py
+4
-2
tests/python/unittest/test_lang_buffer.py
+18
-11
tests/python/unittest/test_lang_operator.py
+8
-4
tests/python/unittest/test_lang_tensor_overload_op.py
+2
-2
tests/python/unittest/test_pass_basic.py
+4
-2
tests/python/unittest/test_pass_equal.py
+1
-1
tests/python/unittest/test_pass_loop_partition.py
+2
-2
tests/python/unittest/test_schedule_bound_inference.py
+7
-5
tests/python/unittest/test_schedule_tensorize.py
+5
-2
topi/python/topi/arm_cpu/conv2d_spatial_pack.py
+6
-1
topi/python/topi/arm_cpu/conv2d_transpose.py
+6
-1
topi/python/topi/arm_cpu/depthwise_conv2d.py
+16
-8
topi/python/topi/cuda/conv2d_transpose_nchw.py
+6
-4
topi/python/topi/cuda/conv2d_winograd.py
+8
-3
topi/python/topi/cuda/nms.py
+2
-2
topi/python/topi/cuda/ssd/multibox.py
+1
-1
topi/python/topi/mali/conv2d.py
+7
-3
topi/python/topi/nn/bitserial_conv2d.py
+14
-6
topi/python/topi/nn/bitserial_dense.py
+8
-3
topi/python/topi/nn/conv2d.py
+7
-4
topi/python/topi/nn/depthwise_conv2d.py
+15
-5
topi/python/topi/nn/dilate.py
+4
-2
topi/python/topi/nn/flatten.py
+4
-2
topi/python/topi/x86/conv2d.py
+18
-9
topi/python/topi/x86/dense.py
+3
-1
topi/python/topi/x86/depthwise_conv2d.py
+8
-3
No files found.
docs/api/python/tvm.rst
View file @
2ded2d8c
...
...
@@ -35,6 +35,13 @@ The user facing API for computation declaration.
tvm.thread_axis
tvm.comm_reducer
tvm.sum
tvm.div
tvm.indexdiv
tvm.indexmod
tvm.truncdiv
tvm.truncmod
tvm.floordiv
tvm.floormod
tvm.min
tvm.max
tvm.tag_scope
...
...
@@ -53,6 +60,13 @@ The user facing API for computation declaration.
.. autofunction:: tvm.thread_axis
.. autofunction:: tvm.comm_reducer
.. autofunction:: tvm.sum
.. autofunction:: tvm.div
.. autofunction:: tvm.indexdiv
.. autofunction:: tvm.indexmod
.. autofunction:: tvm.truncdiv
.. autofunction:: tvm.truncmod
.. autofunction:: tvm.floordiv
.. autofunction:: tvm.floormod
.. autofunction:: tvm.min
.. autofunction:: tvm.max
.. autofunction:: tvm.tag_scope
python/tvm/api.py
View file @
2ded2d8c
...
...
@@ -890,6 +890,77 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
reducer
.
__doc__
=
doc_str
.
format
(
name
)
return
reducer
def
div
(
a
,
b
):
"""Compute a / b as in C/C++ semantics.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
When operands are integers, returns truncdiv(a, b).
"""
return
_make
.
_OpDiv
(
a
,
b
)
def
indexdiv
(
a
,
b
):
"""Compute floor(a / b) where a and b are non-negative.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of operands'
non-negativeness.
"""
return
_make
.
_OpIndexDiv
(
a
,
b
)
def
indexmod
(
a
,
b
):
"""Compute the remainder of indexdiv. a and b are non-negative.
Parameters
----------
a : Expr
The left hand operand, known to be non-negative.
b : Expr
The right hand operand, known to be non-negative.
Returns
-------
res : Expr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of operands'
non-negativeness.
"""
return
_make
.
_OpIndexMod
(
a
,
b
)
def
truncdiv
(
a
,
b
):
"""Compute the truncdiv of two expressions.
...
...
python/tvm/contrib/nnpack.py
View file @
2ded2d8c
...
...
@@ -101,8 +101,11 @@ def convolution_inference(
assert
isinstance
(
stride
,
list
)
and
len
(
stride
)
==
2
batch
,
_
,
input_height
,
input_width
=
data
.
shape
output_channels
,
_
,
kernel_height
,
kernel_width
=
kernel
.
shape
output_height
=
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
)
/
stride
[
0
]
+
1
output_width
=
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
)
/
stride
[
1
]
+
1
idxdiv
=
_api
.
indexdiv
output_height
=
idxdiv
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
,
stride
[
0
])
+
1
output_width
=
idxdiv
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
,
stride
[
1
])
+
1
return
_api
.
extern
(
(
batch
,
output_channels
,
output_height
,
output_width
),
...
...
@@ -153,8 +156,9 @@ def convolution_inference_without_weight_transform(
batch
,
_
,
input_height
,
input_width
=
data
.
shape
output_channels
,
_
,
_
,
_
=
transformed_kernel
.
shape
kernel_height
,
kernel_width
=
(
3
,
3
)
output_height
=
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
)
/
stride
[
0
]
+
1
output_width
=
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
)
/
stride
[
1
]
+
1
idxdiv
=
_api
.
indexdiv
output_height
=
idxdiv
(
input_height
+
padding
[
0
]
+
padding
[
1
]
-
kernel_height
,
stride
[
0
])
+
1
output_width
=
idxdiv
(
input_width
+
padding
[
0
]
+
padding
[
1
]
-
kernel_width
,
stride
[
1
])
+
1
return
_api
.
extern
(
(
batch
,
output_channels
,
output_height
,
output_width
),
...
...
python/tvm/expr.py
View file @
2ded2d8c
...
...
@@ -33,11 +33,25 @@ For example, you can use addexp.a to get the left operand of an Add node.
# pylint: disable=missing-docstring
from
__future__
import
absolute_import
as
_abs
from
._ffi.node
import
NodeBase
,
NodeGeneric
,
register_node
from
._ffi.runtime_ctypes
import
TVMType
,
TypeCode
from
.
import
make
as
_make
from
.
import
generic
as
_generic
from
.
import
_api_internal
def
div_ambiguity_error
():
return
RuntimeError
(
"TVM supports multiple types of integer divisions, "
+
"please call div, indexdiv/indexmod, floordiv/floormod "
+
" or truncdiv/truncmod directly to avoid ambiguity in the code."
)
def
_dtype_is_int
(
value
):
if
isinstance
(
value
,
int
):
return
True
return
(
isinstance
(
value
,
ExprOp
)
and
TVMType
(
value
.
dtype
)
.
type_code
==
TypeCode
.
INT
)
class
ExprOp
(
object
):
def
__add__
(
self
,
other
):
return
_generic
.
add
(
self
,
other
)
...
...
@@ -58,24 +72,35 @@ class ExprOp(object):
return
_generic
.
multiply
(
other
,
self
)
def
__div__
(
self
,
other
):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
self
,
other
)
def
__rdiv__
(
self
,
other
):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
other
,
self
)
def
__truediv__
(
self
,
other
):
return
self
.
__div__
(
other
)
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
self
,
other
)
def
__rtruediv__
(
self
,
other
):
return
self
.
__rdiv__
(
other
)
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
return
_generic
.
divide
(
other
,
self
)
def
__floordiv__
(
self
,
other
):
return
self
.
__div__
(
other
)
# return _generic.floordiv(self, other)
return
_generic
.
divide
(
self
,
other
)
def
__rfloordiv__
(
self
,
other
):
return
self
.
__rdiv__
(
other
)
# return _generic.floordiv(other, self)
return
_generic
.
divide
(
other
,
self
)
def
__mod__
(
self
,
other
):
# raise div_ambiguity_error()
return
_make
.
_OpMod
(
self
,
other
)
def
__neg__
(
self
):
...
...
python/tvm/generic.py
View file @
2ded2d8c
...
...
@@ -25,6 +25,7 @@ from . import make as _make
#Operator precedence used when overloading.
__op_priority__
=
0
def
add
(
lhs
,
rhs
):
"""Generic add operator.
...
...
@@ -78,7 +79,6 @@ def multiply(lhs, rhs):
"""
return
_make
.
_OpMul
(
lhs
,
rhs
)
def
divide
(
lhs
,
rhs
):
"""Generic divide operator.
...
...
@@ -96,6 +96,23 @@ def divide(lhs, rhs):
"""
return
_make
.
_OpDiv
(
lhs
,
rhs
)
def
floordiv
(
lhs
,
rhs
):
"""Generic floordiv operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return
_make
.
_OpFloorDiv
(
lhs
,
rhs
)
def
cast
(
src
,
dtype
):
"""Generic cast operator.
...
...
python/tvm/hybrid/parser.py
View file @
2ded2d8c
...
...
@@ -31,6 +31,7 @@ from . import util
from
.preprocessor
import
determine_variable_usage
from
..api
import
all
as
_all
from
..api
import
any
as
_any
from
..container
import
Array
from
..tensor
import
Tensor
,
Operation
from
..
import
_api_internal
as
_tvm_internal
...
...
@@ -78,6 +79,18 @@ class Symbol(Enum):
ThreadBind
=
10
def
_floordiv
(
x
,
y
):
if
isinstance
(
x
,
_expr
.
ExprOp
)
or
isinstance
(
y
,
_expr
.
ExprOp
):
return
_api
.
floordiv
(
x
,
y
)
return
operator
.
floordiv
(
x
,
y
)
def
_floormod
(
x
,
y
):
if
isinstance
(
x
,
_expr
.
ExprOp
)
or
isinstance
(
y
,
_expr
.
ExprOp
):
return
_api
.
floormod
(
x
,
y
)
return
operator
.
mod
(
x
,
y
)
class
HybridParser
(
ast
.
NodeVisitor
):
"""Python AST visitor pass which finally lowers it to HalideIR"""
...
...
@@ -87,8 +100,8 @@ class HybridParser(ast.NodeVisitor):
ast
.
Sub
:
operator
.
sub
,
ast
.
Mult
:
operator
.
mul
,
ast
.
Div
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
truediv
,
ast
.
FloorDiv
:
operator
.
div
if
sys
.
version_info
[
0
]
==
2
else
operator
.
true
div
,
ast
.
Mod
:
operator
.
mod
,
ast
.
FloorDiv
:
_floor
div
,
ast
.
Mod
:
_floor
mod
,
ast
.
BitOr
:
operator
.
or_
,
ast
.
BitAnd
:
operator
.
and_
,
ast
.
BitXor
:
operator
.
xor
,
...
...
python/tvm/relay/op/_transform.py
View file @
2ded2d8c
...
...
@@ -67,7 +67,7 @@ _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
@script
def
_arange_shape_func
(
start
,
stop
,
step
):
out
=
output_tensor
((
1
,),
"int64"
)
out
[
0
]
=
int64
(
ceil_div
((
float32
(
stop
[
0
])
-
float32
(
start
[
0
])),
float32
(
step
[
0
])))
out
[
0
]
=
int64
(
ceil_div
((
int64
(
stop
[
0
])
-
int64
(
start
[
0
])),
int64
(
step
[
0
])))
return
out
@_reg.register_shape_func
(
"arange"
,
True
)
...
...
@@ -131,12 +131,12 @@ def _reshape_shape_func(data_shape, newshape, ndim):
assert
len
(
newshape
)
-
i
>
2
,
"Not enough dims in new shape for -4"
if
newshape
[
i
+
1
]
==
-
1
:
assert
newshape
[
i
+
2
]
!=
-
1
,
"Split dims cannot both be -1."
out
[
dst_idx
]
=
data_shape
[
src_idx
]
/
int64
(
newshape
[
i
+
2
])
out
[
dst_idx
]
=
data_shape
[
src_idx
]
/
/
int64
(
newshape
[
i
+
2
])
out
[
dst_idx
+
1
]
=
int64
(
newshape
[
i
+
2
])
else
:
out
[
dst_idx
]
=
int64
(
newshape
[
i
+
1
])
if
newshape
[
i
+
2
]
==
-
1
:
out
[
dst_idx
+
1
]
=
data_shape
[
src_idx
]
/
int64
(
newshape
[
i
+
1
])
out
[
dst_idx
+
1
]
=
data_shape
[
src_idx
]
/
/
int64
(
newshape
[
i
+
1
])
else
:
out
[
dst_idx
+
1
]
=
int64
(
newshape
[
i
+
2
])
assert
data_shape
[
src_idx
]
==
out
[
dst_idx
]
*
out
[
dst_idx
+
1
],
\
...
...
@@ -159,7 +159,7 @@ def _reshape_shape_func(data_shape, newshape, ndim):
new_size
=
int64
(
1
)
for
i
in
const_range
(
out
.
shape
[
0
]):
new_size
*=
out
[
i
]
out
[
infer_idx
]
=
old_size
/
new_size
out
[
infer_idx
]
=
old_size
/
/
new_size
return
out
@_reg.register_shape_func
(
"reshape"
,
False
)
...
...
src/api/api_ir.cc
View file @
2ded2d8c
...
...
@@ -200,6 +200,8 @@ REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP
(
_OpMul
,
operator
*
);
REGISTER_MAKE_BINARY_OP
(
_OpDiv
,
div
);
REGISTER_MAKE_BINARY_OP
(
_OpMod
,
truncmod
);
REGISTER_MAKE_BINARY_OP
(
_OpIndexDiv
,
indexdiv
);
REGISTER_MAKE_BINARY_OP
(
_OpIndexMod
,
indexmod
);
REGISTER_MAKE_BINARY_OP
(
_OpFloorDiv
,
floordiv
);
REGISTER_MAKE_BINARY_OP
(
_OpFloorMod
,
floormod
);
REGISTER_MAKE_BINARY_OP
(
_OpTruncDiv
,
truncdiv
);
...
...
src/contrib/hybrid/codegen_hybrid.cc
View file @
2ded2d8c
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -146,15 +146,28 @@ void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
void
CodeGenHybrid
::
VisitExpr_
(
const
Mul
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"*"
,
os
,
this
);
}
void
CodeGenHybrid
::
VisitExpr_
(
const
Div
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
if
(
op
->
type
.
is_int
())
PrintBinaryExpr
(
op
,
"//"
,
os
,
this
);
else
PrintBinaryExpr
(
op
,
"/"
,
os
,
this
);
}
void
CodeGenHybrid
::
VisitExpr_
(
const
FloorDiv
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
if
(
op
->
type
.
is_int
())
PrintBinaryExpr
(
op
,
"//"
,
os
,
this
);
else
PrintBinaryExpr
(
op
,
"/"
,
os
,
this
);
}
void
CodeGenHybrid
::
VisitExpr_
(
const
Mod
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"%"
,
os
,
this
);
}
void
CodeGenHybrid
::
VisitExpr_
(
const
FloorMod
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"%"
,
os
,
this
);
}
void
CodeGenHybrid
::
VisitExpr_
(
const
Min
*
op
,
std
::
ostream
&
os
)
{
// NOLINT(*)
PrintBinaryExpr
(
op
,
"min"
,
os
,
this
);
}
...
...
src/contrib/hybrid/codegen_hybrid.h
View file @
2ded2d8c
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -100,6 +100,8 @@ class CodeGenHybrid :
void
VisitExpr_
(
const
Mul
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Div
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Mod
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
FloorDiv
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
FloorMod
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Min
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
Max
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
void
VisitExpr_
(
const
EQ
*
op
,
std
::
ostream
&
os
)
override
;
// NOLINT(*)
...
...
@@ -161,12 +163,12 @@ class CodeGenHybrid :
std
::
string
GetUniqueName
(
std
::
string
prefix
);
/*! \brief The output code string builder. */
std
::
stringstream
stream
;
/*!
/*!
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
*/
std
::
string
GetVarID
(
const
Variable
*
v
);
/*!
/*!
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
* \param value_index The value index of the given tensor.
...
...
src/lang/expr_operator.cc
View file @
2ded2d8c
...
...
@@ -216,6 +216,8 @@ Expr indexmod(Expr a, Expr b) {
}
Expr
floordiv
(
Expr
a
,
Expr
b
)
{
CHECK
(
a
.
type
().
is_int
()
||
a
.
type
().
is_uint
());
CHECK
(
b
.
type
().
is_int
()
||
b
.
type
().
is_uint
());
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
FloorDiv
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
...
...
@@ -223,6 +225,8 @@ Expr floordiv(Expr a, Expr b) {
}
Expr
floormod
(
Expr
a
,
Expr
b
)
{
CHECK
(
a
.
type
().
is_int
()
||
a
.
type
().
is_uint
());
CHECK
(
b
.
type
().
is_int
()
||
b
.
type
().
is_uint
());
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
FloorMod
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
...
...
src/pass/lower_intrin.cc
View file @
2ded2d8c
...
...
@@ -74,9 +74,6 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if
(
op
==
nullptr
)
return
ret
;
int
shift
;
const
DataType
&
dtype
=
op
->
type
;
if
(
dtype
.
is_float
())
{
return
floor
(
Div
::
make
(
op
->
a
,
op
->
b
));
}
CHECK
(
dtype
.
is_int
()
||
!
dtype
.
is_uint
());
if
(
is_const_power_of_two_integer
(
op
->
b
,
&
shift
))
{
...
...
tests/python/unittest/test_arith_canonical_simplify.py
View file @
2ded2d8c
...
...
@@ -33,9 +33,11 @@ def test_mul_sum_simplify():
x
*
13
+
z
*
4
+
y
*
4
+
6
)
ck
.
verify
(
x
*
3
-
4
*
x
+
1
,
1
-
x
)
ck
.
verify
(
y
+
x
*
3
-
5
*
x
+
1
+
y
,
y
*
2
+
1
-
x
*
2
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
# trucdiv
ck
.
verify
(
(
x
+
y
+
x
+
y
*
3
)
/
2
,
y
*
2
+
x
)
ck
.
verify
(
(
x
+
y
+
x
+
y
*
3
)
%
2
,
0
)
ck
.
verify
(
tdiv
(
x
+
y
+
x
+
y
*
3
,
2
)
,
y
*
2
+
x
)
ck
.
verify
(
tmod
(
x
+
y
+
x
+
y
*
3
,
2
)
,
0
)
# floordiv
fld
=
tvm
.
floordiv
...
...
@@ -51,28 +53,31 @@ def test_split_index_simplify():
x
,
y
,
z
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
),
tvm
.
var
(
"z"
)
# trucdiv
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
# split div const
ck
.
verify
(
(
x
/
3
)
*
3
+
x
%
3
,
x
)
ck
.
verify
(
(
x
/
6
)
*
6
+
((
x
/
3
)
%
2
)
*
3
+
x
%
3
,
x
)
ck
.
verify
(
((
x
%
16
)
/
2
)
*
2
/
4
,
(
x
%
16
)
/
4
)
ck
.
verify
(
(
x
%
2
)
/
8
,
0
)
ck
.
verify
(
(
x
%
2
)
/
7
,
0
)
ck
.
verify
(
((
x
%
16
)
/
2
)
*
2
/
6
,
(
x
%
16
)
/
6
)
ck
.
verify
(
tdiv
(
x
,
3
)
*
3
+
tmod
(
x
,
3
)
,
x
)
ck
.
verify
(
tdiv
(
x
,
6
)
*
6
+
tmod
(
tdiv
(
x
,
3
),
2
)
*
3
+
tmod
(
x
,
3
)
,
x
)
ck
.
verify
(
tdiv
(
tdiv
(
tmod
(
x
,
16
),
2
)
*
2
,
4
),
tdiv
(
tmod
(
x
,
16
),
4
)
)
ck
.
verify
(
tdiv
(
tmod
(
x
,
2
),
8
)
,
0
)
ck
.
verify
(
tdiv
(
tmod
(
x
,
2
),
7
)
,
0
)
ck
.
verify
(
tdiv
(
tdiv
(
tmod
(
x
,
16
),
2
)
*
2
,
6
),
tdiv
(
tmod
(
x
,
16
),
6
)
)
# split mod const
ck
.
verify
(
(
x
*
8
)
%
16
,
(
x
%
2
)
*
8
)
ck
.
verify
(
(
x
*
8
)
%
2
,
0
)
ck
.
verify
(
tmod
((
x
*
8
),
16
),
tmod
(
x
,
2
)
*
8
)
ck
.
verify
(
tmod
(
x
*
8
,
2
)
,
0
)
# simplify then fold
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
))
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1000
))
ck
.
verify
(
(
x
*
4
+
y
)
/
2
*
2
+
(
x
*
4
+
y
)
%
2
,
x
*
4
+
y
)
ck
.
verify
(
tdiv
(
x
*
4
+
y
,
2
)
*
2
+
tmod
(
x
*
4
+
y
,
2
)
,
x
*
4
+
y
)
# complex fold
ck
.
verify
(
(
z
*
9
+
y
)
/
2
*
2
+
(
z
*
9
+
y
)
%
2
,
z
*
9
+
y
)
ck
.
verify
(
tdiv
(
z
*
9
+
y
,
2
)
*
2
+
tmod
(
z
*
9
+
y
,
2
)
,
z
*
9
+
y
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
100
,
1000
),
True
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
100
,
1000
),
True
)
ck
.
verify
(
(
x
*
4
+
y
)
/
2
*
2
+
(
x
*
4
+
y
)
%
2
,
x
*
4
+
y
)
ck
.
verify
(
tdiv
(
x
*
4
+
y
,
2
)
*
2
+
tmod
(
x
*
4
+
y
,
2
)
,
x
*
4
+
y
)
# floordiv
fld
=
tvm
.
floordiv
...
...
@@ -85,23 +90,24 @@ def test_split_index_simplify():
ck
.
verify
(
fld
(
fld
(
flm
(
x
,
16
),
2
)
*
2
,
6
),
fld
(
flm
(
x
,
16
),
6
))
# cannot simplify mixed case, unless we canonicalize into one mode.
ck
.
verify
(
(
x
/
6
)
*
2
+
fld
(
x
,
3
)
%
2
,
(
x
/
6
)
*
2
+
fld
(
x
,
3
)
%
2
)
ck
.
verify
(
tdiv
(
x
,
6
)
*
2
+
tmod
(
fld
(
x
,
3
),
2
),
tdiv
(
x
,
6
)
*
2
+
tmod
(
fld
(
x
,
3
),
2
)
)
def
test_div_simplify
():
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
tdiv
=
tvm
.
truncdiv
# truc div
ck
.
verify
(
(
16
+
48
*
x
)
/
16
,
x
*
3
+
1
)
ck
.
verify
(
tdiv
(
16
+
48
*
x
,
16
)
,
x
*
3
+
1
)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
ck
.
verify
(
(
17
+
48
*
x
)
/
16
,
(
x
*
48
+
17
)
/
16
)
ck
.
verify
(
tdiv
(
17
+
48
*
x
,
16
),
tdiv
(
x
*
48
+
17
,
16
)
)
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
10
))
ck
.
verify
(
(
17
+
48
*
x
)
/
16
,
x
*
3
+
1
)
ck
.
verify
(
tdiv
(
17
+
48
*
x
,
16
)
,
x
*
3
+
1
)
# Trying expressions that are not simplifiable for any values of the variables
ck
.
verify
(
(
17
+
47
*
x
)
/
16
,
(
x
*
47
+
17
)
/
16
)
ck
.
verify
(
tdiv
(
17
+
47
*
x
,
16
),
tdiv
(
x
*
47
+
17
,
16
)
)
# floordiv
fld
=
tvm
.
floordiv
...
...
@@ -124,8 +130,10 @@ def test_canonical_mixed():
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
z
=
tvm
.
const
(
3
,
"int32"
)
ck
.
verify
(
x
/
(
z
*
z
)
-
x
/
(
z
*
z
),
0
)
ck
.
verify
(
x
/
(
z
+
z
)
-
x
/
(
z
+
z
),
0
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
ck
.
verify
(
tdiv
(
x
,
(
z
*
z
))
-
tdiv
(
x
,
(
z
*
z
)),
0
)
ck
.
verify
(
tdiv
(
x
,
(
z
+
z
))
-
tdiv
(
x
,
(
z
+
z
)),
0
)
ck
.
verify
(
x
-
2
<
3
,
x
<
5
)
ck
.
verify
(
tvm
.
max
(
x
,
1
)
-
tvm
.
max
(
x
,
1
),
0
)
ck
.
verify
(
tvm
.
min
(
x
,
1
)
-
tvm
.
min
(
x
,
1
),
0
)
...
...
@@ -207,42 +215,44 @@ def test_reduce_simplify():
tvm
.
sum
(
k
+
j
,
[
k
,
j
]))
ck
.
verify
(
tvm
.
sum
(
A
[
3
],
[]),
A
[
3
])
# The rule below is not typical, removed for now
ck
.
verify
(
tvm
.
sum
(
k
/
10
,
k
),
tvm
.
sum
(
tvm
.
const
(
0
,
"int32"
),
k
))
ck
.
verify
(
tvm
.
sum
(
tvm
.
div
(
k
,
10
)
,
k
),
tvm
.
sum
(
tvm
.
const
(
0
,
"int32"
),
k
))
def
test_simplify_if_then_else
():
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
y
=
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
# simplification that takes condition into account.
res
=
tvm
.
if_then_else
((
x
*
4
+
y
)
>=
466036
,
tvm
.
if_then_else
(
24512
<=
((((
x
*
4
)
+
y
)
-
466036
)
%
24528
),
(((((
x
*
4
)
+
y
)
-
466036
)
%
24528
)
-
24512
)
%
16
,
tvm
.
if_then_else
(
24512
<=
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
),
tmod
(
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
)
-
24512
,
16
)
,
x
),
y
)
res2
=
tvm
.
if_then_else
((
x
*
4
)
>=
466036
-
y
,
tvm
.
if_then_else
(
24512
<=
((((
x
*
4
)
+
y
)
-
466036
)
%
24528
),
(((((
x
*
4
)
+
y
)
-
466036
)
%
24528
)
-
24512
)
%
16
,
tvm
.
if_then_else
(
24512
<=
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
),
tmod
(
tmod
(((
x
*
4
)
+
y
)
-
466036
,
24528
)
-
24512
,
16
)
,
x
),
y
)
expected
=
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
466036
,
(
x
*
4
+
y
)),
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
24512
,
((((
x
*
4
)
+
y
)
-
4
)
%
24528
)),
(((
x
*
4
)
+
y
)
-
4
)
%
16
,
tvm
.
if_then_else
(
tvm
.
expr
.
LE
(
24512
,
tmod
(((
x
*
4
)
+
y
)
-
4
,
24528
)),
tmod
(((
x
*
4
)
+
y
)
-
4
,
16
)
,
x
),
y
)
ck
.
verify
(
res
,
expected
)
ck
.
verify
(
res2
,
expected
)
# can only simplify if condition
res
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
100
)
%
3
,
(
x
+
100
)
%
3
)
expected
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
(
x
+
y
+
1
)
%
3
,
(
x
+
100
)
%
3
)
res
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
tmod
(
x
+
y
+
100
,
3
),
tmod
(
x
+
100
,
3
)
)
expected
=
tvm
.
expr
.
Select
(
tvm
.
all
(
x
>=
-
1
,
y
>=
0
),
tmod
(
x
+
y
+
1
,
3
),
tmod
(
x
+
100
,
3
)
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
>
2
,
x
,
0
),
0
)
tvm
.
if_then_else
(
tdiv
(
x
,
3
)
>
2
,
x
,
0
),
0
)
expected
=
tvm
.
expr
.
Select
(
x
>=
10
,
x
,
0
)
ck
.
verify
(
res
,
ck
.
analyzer
.
canonical_simplify
(
expected
))
res
=
tvm
.
expr
.
Select
(
x
>=
10
,
tvm
.
if_then_else
(
x
/
3
<
2
,
x
,
0
),
0
)
tvm
.
if_then_else
(
tdiv
(
x
,
3
)
<
2
,
x
,
0
),
0
)
ck
.
verify
(
res
,
0
)
...
...
@@ -250,20 +260,20 @@ def test_complex_cases():
ck
=
CanonicalChecker
()
x
=
tvm
.
var
(
"x"
)
y
=
tvm
.
var
(
"y"
)
res2
=
(((((((((((
x
*
128
)
+
y
)
%
1296
)
/
36
)
*
2
)
+
1
)
/
2
)
*
36
)
+
((((((
x
*
128
)
+
y
)
%
36
)
*
2
)
+
1
)
/
2
))
-
(((
x
*
128
)
+
y
)
%
1296
))
+
1
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
res2
=
(
tdiv
(
tdiv
(
tmod
(
x
*
128
+
y
,
1296
),
36
)
*
2
+
1
,
2
)
*
36
+
tdiv
(
tmod
((
x
*
128
)
+
y
,
36
)
*
2
+
1
,
2
)
-
tmod
((
x
*
128
)
+
y
,
1296
)
+
1
)
ck
.
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
5
))
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
127
))
ck
.
verify
(
res2
,
1
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
1024
),
True
)
res3
=
((((((((((
x
*
1024
)
+
y
)
/
65536
)
+
((((
x
*
1024
)
+
y
)
%
65536
)
/
256
))
+
((((
x
*
1024
)
+
y
)
%
256
)
/
16
))
+
(((
x
*
1024
)
+
y
)
%
16
))
-
(
y
/
256
))
-
((
y
%
256
)
/
16
))
-
(
y
%
16
))
-
(
x
*
4
))
ck
.
verify
(
res3
,
((((
x
*
1024
)
+
y
)
/
256
)
-
(
y
/
256
))
-
(
x
*
4
))
res3
=
(
tdiv
(
x
*
1024
+
y
,
65536
)
+
tdiv
(
tmod
(
x
*
1024
+
y
,
65536
),
256
)
+
tdiv
(
tmod
(
x
*
1024
+
y
,
256
),
16
)
+
tmod
(
x
*
1024
+
y
,
16
)
-
tdiv
(
y
,
256
)
-
tdiv
(
tmod
(
y
,
256
),
16
)
-
tmod
(
y
,
16
)
-
(
x
*
4
))
ck
.
verify
(
res3
,
tdiv
((
x
*
1024
)
+
y
,
256
)
-
tdiv
(
y
,
256
)
-
(
x
*
4
))
if
__name__
==
"__main__"
:
...
...
tests/python/unittest/test_arith_const_int_bound.py
View file @
2ded2d8c
...
...
@@ -38,12 +38,13 @@ def test_dtype_bound():
def
test_cast_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
x
=
tvm
.
var
(
"x"
,
dtype
=
"int8"
)
bd
=
analyzer
.
const_int_bound
((
x
%
3
)
.
astype
(
"uint32"
))
tmod
=
tvm
.
truncmod
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
3
)
.
astype
(
"uint32"
))
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
2
bd
=
analyzer
.
const_int_bound
(
(
x
%
3
)
.
astype
(
"float32"
)
.
astype
(
"int32"
))
tmod
(
x
,
3
)
.
astype
(
"float32"
)
.
astype
(
"int32"
))
assert
bd
.
min_value
==
-
2
assert
bd
.
max_value
==
2
...
...
@@ -98,47 +99,50 @@ def test_mul_bound():
assert
bd
.
max_value
==
bd
.
POS_INF
def
test_div_bound
():
def
test_
trunc
div_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
))
bd
=
analyzer
.
const_int_bound
(
x
/
y
)
bd
=
analyzer
.
const_int_bound
(
tdiv
(
x
,
y
)
)
assert
bd
.
min_value
==
-
2
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
2
,
0
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
/
y
)
bd
=
analyzer
.
const_int_bound
(
tdiv
(
x
,
y
)
)
assert
bd
.
min_value
==
-
4
assert
bd
.
max_value
==
9
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
bd
.
NEG_INF
,
4
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
-
2
,
1
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
/
y
)
bd
=
analyzer
.
const_int_bound
(
tdiv
(
x
,
y
)
)
assert
bd
.
min_value
==
bd
.
NEG_INF
assert
bd
.
max_value
==
bd
.
POS_INF
def
test_mod_bound
():
def
test_
trunc
mod_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tmod
=
tvm
.
truncmod
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
-
9
,
4
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
))
bd
=
analyzer
.
const_int_bound
(
x
%
y
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
y
)
)
assert
bd
.
min_value
==
-
9
assert
bd
.
max_value
==
4
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
bd
.
NEG_INF
,
bd
.
POS_INF
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
%
y
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
y
)
)
assert
bd
.
min_value
==
-
9
assert
bd
.
max_value
==
9
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
1
,
bd
.
POS_INF
),
override
=
True
)
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
4
,
10
),
override
=
True
)
bd
=
analyzer
.
const_int_bound
(
x
%
y
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
y
)
)
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
9
...
...
@@ -253,9 +257,12 @@ def test_shift_and_bound():
def
test_mix_index_bound
():
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
24
-
1
))
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
0
,
3
-
1
))
bd
=
analyzer
.
const_int_bound
(
(
x
%
8
)
+
(
x
/
8
)
*
8
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
8
)
+
tdiv
(
x
,
8
)
*
8
)
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
24
-
1
...
...
@@ -263,7 +270,7 @@ def test_mix_index_bound():
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
24
*
3
-
1
bd
=
analyzer
.
const_int_bound
(
(
x
%
7
)
+
(
x
/
7
)
*
7
)
bd
=
analyzer
.
const_int_bound
(
tmod
(
x
,
7
)
+
tdiv
(
x
,
7
)
*
7
)
assert
bd
.
min_value
==
0
assert
bd
.
max_value
==
(
23
//
7
)
*
7
+
6
...
...
@@ -273,8 +280,8 @@ if __name__ == "__main__":
test_cast_bound
()
test_add_sub_bound
()
test_mul_bound
()
test_div_bound
()
test_mod_bound
()
test_
trunc
div_bound
()
test_
trunc
mod_bound
()
test_floordiv_bound
()
test_floormod_bound
()
test_min_max_bound
()
...
...
tests/python/unittest/test_arith_deduce_bound.py
View file @
2ded2d8c
...
...
@@ -35,9 +35,11 @@ def test_deduce():
d_s
=
tvm
.
arith
.
IntervalSet
(
-
3
,
-
1
)
zero
=
tvm
.
const
(
0
,
"int32"
)
tdiv
=
tvm
.
truncdiv
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
(
(
d
-
c
)
/
(
b
*-
1
)
+
(
-
1
))
ans0
=
(
tdiv
(
d
-
c
,
b
*-
1
)
+
(
-
1
))
assert_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
...
...
@@ -46,7 +48,7 @@ def test_deduce():
e0
=
d
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
(
(
d
-
c
)
/
d
-
1
)
ans0
=
(
tdiv
(
d
-
c
,
d
)
-
1
)
assert_expr_equal
(
res0
.
max_value
,
ans0
)
# expression containing variable a is on rhs
...
...
@@ -56,7 +58,7 @@ def test_deduce():
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(
((
c
-
b
)
+
-
1
)
/
4
-
1
)
ans1
=
(
tdiv
((
c
-
b
)
+
-
1
,
4
)
-
1
)
assert_expr_equal
(
res1
.
max_value
,
ans1
)
...
...
@@ -79,7 +81,7 @@ def test_deduce():
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
ans3
=
tdiv
(
2
,
c
)
+
1
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
...
...
tests/python/unittest/test_arith_intset.py
View file @
2ded2d8c
...
...
@@ -60,13 +60,14 @@ def test_add_sub():
def
test_mul_div
():
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tdiv
=
tvm
.
truncdiv
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
verify
(
x
*
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
*
y
))
ck
.
verify
(
x
*
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
2
,
20
))
ck
.
verify
(
x
*
-
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
-
20
,
-
2
))
ck
.
verify
(
x
/
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
/
y
))
ck
.
verify
(
x
/
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
5
))
ck
.
verify
(
tdiv
(
x
,
y
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
tdiv
(
10
,
y
)
))
ck
.
verify
(
tdiv
(
x
,
2
)
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
5
))
fld
=
tvm
.
floordiv
ck
.
verify
(
fld
(
x
,
y
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
fld
(
10
,
y
)))
...
...
@@ -76,9 +77,10 @@ def test_mul_div():
def
test_mod
():
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
tmod
=
tvm
.
truncmod
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
verify
(
x
%
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
y
-
1
))
ck
.
verify
(
x
%
10
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
9
))
ck
.
verify
(
tmod
(
x
,
y
)
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
y
-
1
))
ck
.
verify
(
tmod
(
x
,
10
)
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
9
))
flm
=
tvm
.
floormod
ck
.
verify
(
flm
(
x
,
10
),
{
x
:
tvm
.
arith
.
IntervalSet
(
-
10
,
10
)},
(
0
,
9
))
...
...
tests/python/unittest/test_arith_modular_set.py
View file @
2ded2d8c
...
...
@@ -54,7 +54,8 @@ def test_div_shift():
analyzer
=
tvm
.
arith
.
Analyzer
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
# not sure if x is non-negative
m
=
analyzer
.
modular_set
((
x
*
4
+
2
)
/
2
)
tdiv
=
tvm
.
truncdiv
m
=
analyzer
.
modular_set
(
tdiv
(
x
*
4
+
2
,
2
))
assert
m
.
coeff
==
1
assert
m
.
base
==
0
# right shift always round down so it is fine
...
...
@@ -67,7 +68,7 @@ def test_div_shift():
assert
m
.
base
==
1
# x is non-negative
analyzer
.
update
(
x
,
tvm
.
arith
.
ConstIntBound
(
0
,
100
))
m
=
analyzer
.
modular_set
(
(
x
*
4
+
2
)
/
2
)
m
=
analyzer
.
modular_set
(
tdiv
(
x
*
4
+
2
,
2
)
)
assert
m
.
coeff
==
2
assert
m
.
base
==
1
...
...
@@ -92,6 +93,7 @@ def test_mix_index():
a
=
tvm
.
var
(
"a"
)
b
=
tvm
.
var
(
"b"
)
analyzer
=
tvm
.
arith
.
Analyzer
()
tdiv
=
tvm
.
truncdiv
m
=
analyzer
.
modular_set
(
a
*
4
+
b
*
6
+
7
)
assert
m
.
coeff
==
2
assert
m
.
base
==
1
...
...
@@ -100,11 +102,11 @@ def test_mix_index():
assert
m
.
coeff
==
4
assert
m
.
base
==
3
m
=
analyzer
.
modular_set
(
(
a
*
4
+
1
)
/
(
b
*
8
+
3
))
m
=
analyzer
.
modular_set
(
tdiv
(
a
*
4
+
1
,
b
*
8
+
3
))
assert
m
.
coeff
==
1
assert
m
.
base
==
0
m
=
analyzer
.
modular_set
((
a
*
4
+
1
)
*
(
b
*
8
/
4
))
m
=
analyzer
.
modular_set
((
a
*
4
+
1
)
*
tdiv
(
b
*
8
,
4
))
assert
m
.
coeff
==
2
assert
m
.
base
==
0
...
...
@@ -121,11 +123,13 @@ def test_constraint_scope():
a
=
tvm
.
var
(
"a"
)
b
=
tvm
.
var
(
"b"
)
analyzer
=
tvm
.
arith
.
Analyzer
()
with
analyzer
.
constraint_scope
(
b
%
4
==
2
):
tmod
=
tvm
.
truncmod
with
analyzer
.
constraint_scope
(
tmod
(
b
,
4
)
==
2
):
m
=
analyzer
.
modular_set
(
b
+
1
)
assert
m
.
coeff
==
4
assert
m
.
base
==
3
with
analyzer
.
constraint_scope
(
a
%
2
==
1
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
2
)
==
1
):
m
=
analyzer
.
modular_set
(
b
+
a
*
2
)
assert
m
.
coeff
==
4
assert
m
.
base
==
0
...
...
@@ -140,15 +144,16 @@ def test_constraint_scope():
def
test_intersect
():
a
=
tvm
.
var
(
"a"
)
analyzer
=
tvm
.
arith
.
Analyzer
()
with
analyzer
.
constraint_scope
(
a
%
4
==
1
):
with
analyzer
.
constraint_scope
(
a
%
3
==
1
):
tmod
=
tvm
.
truncmod
with
analyzer
.
constraint_scope
(
tmod
(
a
,
4
)
==
1
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
3
)
==
1
):
m
=
analyzer
.
modular_set
(
a
)
assert
m
.
coeff
==
12
assert
m
.
base
==
1
with
analyzer
.
constraint_scope
(
a
%
3
==
2
):
with
analyzer
.
constraint_scope
(
a
%
5
==
3
):
with
analyzer
.
constraint_scope
(
a
%
7
==
2
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
3
)
==
2
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
5
)
==
3
):
with
analyzer
.
constraint_scope
(
tmod
(
a
,
7
)
==
2
):
m
=
analyzer
.
modular_set
(
a
)
assert
m
.
coeff
==
105
assert
m
.
base
==
23
...
...
tests/python/unittest/test_autotvm_flop_calculator.py
View file @
2ded2d8c
...
...
@@ -60,11 +60,14 @@ def test_pack_gemm():
k
=
tvm
.
reduce_axis
((
0
,
L
))
bn
=
4
fld
=
tvm
.
floordiv
flm
=
tvm
.
floormod
A_pack
=
tvm
.
compute
((
N
//
bn
,
L
,
bn
),
lambda
i
,
j
,
k
:
A
[
i
*
bn
+
k
][
j
])
B_pack
=
tvm
.
compute
((
M
//
bn
,
L
,
bn
),
lambda
i
,
j
,
k
:
B
[
i
*
bn
+
k
][
j
])
C_pack
=
tvm
.
compute
((
N
//
bn
,
M
//
bn
,
bn
,
bn
),
lambda
i
,
j
,
ii
,
jj
:
tvm
.
sum
(
A_pack
[
i
,
k
,
ii
]
.
astype
(
acc_dtype
)
*
B_pack
[
j
,
k
,
jj
]
.
astype
(
acc_dtype
),
axis
=
[
k
]))
C
=
tvm
.
compute
((
N
,
M
),
lambda
i
,
j
:
C_pack
[
i
//
bn
][
j
//
bn
][
i
%
bn
][
j
%
bn
])
C
=
tvm
.
compute
((
N
,
M
),
lambda
i
,
j
:
C_pack
[
fld
(
i
,
bn
)][
fld
(
j
,
bn
)][
flm
(
i
,
bn
)][
flm
(
j
,
bn
)
])
s
=
tvm
.
create_schedule
([
C
.
op
])
assert
compute_flop
(
s
)
==
2
*
N
*
L
*
M
...
...
@@ -119,9 +122,11 @@ def test_average_pool():
OH
=
(
H
-
KH
)
+
1
OW
=
(
W
-
KW
)
+
1
C
=
tvm
.
compute
(
(
N
,
CO
,
OH
,
OW
),
lambda
n
,
co
,
h
,
w
:
tvm
.
sum
(
D
[
n
][
co
][
h
+
kh
][
w
+
kw
]
.
astype
(
acc_dtype
)
/
(
KW
*
KH
),
axis
=
[
kh
,
kw
]))
lambda
n
,
co
,
h
,
w
:
tvm
.
sum
(
tvm
.
div
(
D
[
n
][
co
][
h
+
kh
][
w
+
kw
]
.
astype
(
acc_dtype
),
(
KW
*
KH
)),
axis
=
[
kh
,
kw
]))
s
=
tvm
.
create_schedule
([
C
.
op
])
...
...
tests/python/unittest/test_build_lower.py
View file @
2ded2d8c
...
...
@@ -35,7 +35,7 @@ def test_lower_rfactor():
def
test_dependent_output_shape
():
n
,
m
,
x
=
tvm
.
var
(
'n'
),
tvm
.
var
(
'm'
),
tvm
.
var
(
'x'
)
A
=
tvm
.
placeholder
((
n
,
m
))
B
=
tvm
.
compute
((
m
,
n
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
,
name
=
'B'
)
B
=
tvm
.
compute
((
m
,
n
/
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
,
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
mod
=
tvm
.
build
(
s
,
[
A
,
B
,
x
])
...
...
tests/python/unittest/test_codegen_llvm.py
View file @
2ded2d8c
...
...
@@ -409,7 +409,7 @@ def test_llvm_div():
"""Check that the semantics of div and mod is the same as in C/C++"""
def
check_div
(
start
,
end
,
divisor
,
dtype
):
T
=
tvm
.
compute
((
end
-
start
,),
lambda
i
:
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
))
/
tvm
.
const
(
divisor
,
dtype
))
lambda
i
:
tvm
.
div
(
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
)),
tvm
.
const
(
divisor
,
dtype
)
))
s
=
tvm
.
create_schedule
([
T
.
op
])
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
...
...
@@ -418,8 +418,9 @@ def test_llvm_div():
tvm
.
testing
.
assert_allclose
(
a
.
asnumpy
(),
ref
)
def
check_mod
(
start
,
end
,
divisor
,
dtype
):
tmod
=
tvm
.
truncmod
T
=
tvm
.
compute
((
end
-
start
,),
lambda
i
:
t
vm
.
expr
.
Cast
(
dtype
,
(
start
+
i
))
%
tvm
.
const
(
divisor
,
dtype
))
lambda
i
:
t
mod
(
tvm
.
expr
.
Cast
(
dtype
,
(
start
+
i
)),
tvm
.
const
(
divisor
,
dtype
)
))
s
=
tvm
.
create_schedule
([
T
.
op
])
f
=
tvm
.
build
(
s
,
[
T
],
"llvm"
)
a
=
tvm
.
nd
.
empty
((
end
-
start
,),
dtype
)
...
...
@@ -443,7 +444,7 @@ def test_llvm_div():
def
test_llvm_fp_math
():
def
check_llvm_reciprocal
(
n
):
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
compute
((
n
,),
lambda
i
:
1.0
/
(
1e+37
*
A
[
i
]
),
name
=
'B'
)
B
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
div
(
1.0
,(
1e+37
*
A
[
i
])
),
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
f
=
tvm
.
build
(
s
,
[
A
,
B
],
"llvm"
)
...
...
tests/python/unittest/test_ir_builder.py
View file @
2ded2d8c
...
...
@@ -41,8 +41,9 @@ def test_if():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
tmod
=
tvm
.
truncmod
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
if_scope
(
(
i
%
2
)
==
0
):
with
ib
.
if_scope
(
tmod
(
i
,
2
)
==
0
):
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
else_scope
():
A
[
0
]
=
A
[
i
]
+
2
...
...
@@ -108,13 +109,14 @@ def test_gpu():
dtype
=
"float32"
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
fld
=
tvm
.
floordiv
def
test_device_ir
(
A
,
B
,
C
):
n
=
A
.
shape
[
0
]
max_threads
=
32
ib
=
tvm
.
ir_builder
.
create
()
bx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
tx
=
tvm
.
thread_axis
(
"threadIdx.x"
)
ib
.
scope_attr
(
bx
,
"thread_extent"
,
(
n
+
max_threads
-
1
)
//
max_threads
)
ib
.
scope_attr
(
bx
,
"thread_extent"
,
fld
(
n
+
max_threads
-
1
,
max_threads
)
)
ib
.
scope_attr
(
tx
,
"thread_extent"
,
max_threads
)
idx
=
bx
.
var
*
max_threads
+
tx
.
var
Aptr
=
ib
.
buffer_ptr
(
A
)
...
...
tests/python/unittest/test_lang_buffer.py
View file @
2ded2d8c
...
...
@@ -94,24 +94,31 @@ def test_buffer_index_merge_mult_mod():
def
assert_simplified_equal
(
index_simplified
,
index_direct
):
assert
tvm
.
ir_pass
.
Equal
(
index_simplified
,
index_direct
),
\
"index_simplified=
%
s, index_direct=
%
s"
%
(
index_simplified
,
index_direct
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# Test Case1
index_simplified
=
A_stride
.
vload
(((
k0
%
k1
)
/
s
,
(
k0
%
k1
)
%
s
+
(
k0
/
k1
)
*
k1
))
index_simplified
=
A_stride
.
vload
(
(
idxdiv
(
idxmod
(
k0
,
k1
),
s
),
idxmod
(
idxmod
(
k0
,
k1
),
s
)
+
idxdiv
(
k0
,
k1
)
*
k1
))
index_direct
=
A_stride
.
vload
((
0
,
k0
))
assert_simplified_equal
(
index_simplified
,
index_direct
)
# Test Case2
index_simplified
=
A
.
vload
((
(
k0
%
(
k1
/
s
))
/
n
,
(
k0
%
(
k1
/
s
))
%
n
+
(
k0
%
k1
)))
index_direct
=
A
.
vload
((
0
,
k0
%
k1
+
k0
%
(
k1
/
s
)))
index_simplified
=
A
.
vload
((
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)
,
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)
+
idxmod
(
k0
,
k1
)))
index_direct
=
A
.
vload
((
0
,
idxmod
(
k0
,
k1
)
+
idxmod
(
k0
,
idxdiv
(
k1
,
s
)
)))
assert_simplified_equal
(
index_simplified
,
index_direct
)
# Test Case3
index_simplified
=
A
.
vload
((((
k0
/
(
k1
/
s
))
*
(
k1
/
s
))
/
n
+
(
k0
%
(
k1
/
s
))
/
n
,
((
k0
/
(
k1
/
s
))
*
(
k1
/
s
))
%
n
+
(
k0
%
(
k1
/
s
))
%
n
))
index_simplified
=
A
.
vload
((
idxdiv
((
idxdiv
(
k0
,
idxdiv
(
k1
,
s
))
*
idxdiv
(
k1
,
s
)),
n
)
+
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
),
idxmod
((
idxdiv
(
k0
,
idxdiv
(
k1
,
s
))
*
idxdiv
(
k1
,
s
)),
n
)
+
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)))
index_direct
=
A
.
vload
((
0
,
k0
))
assert_simplified_equal
(
index_simplified
,
index_direct
)
# Test Case4 (not able to simplify)
index_simplified
=
A
.
vload
(((
k0
%
(
k1
/
s
))
/
n
,
(
k0
%
(
k1
/
n
))
%
n
+
(
k0
%
k1
)))
index_direct
=
A
.
vload
((
0
,
((
k0
%
(
k1
/
s
))
/
n
)
*
n
+
((
k0
%
(
k1
/
n
))
%
n
+
(
k0
%
k1
))))
index_simplified
=
A
.
vload
((
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
),
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
n
)),
n
)
+
idxmod
(
k0
,
k1
)))
index_direct
=
A
.
vload
((
0
,
idxdiv
(
idxmod
(
k0
,
idxdiv
(
k1
,
s
)),
n
)
*
n
+
(
idxmod
(
idxmod
(
k0
,
idxdiv
(
k1
,
n
)),
n
)
+
idxmod
(
k0
,
k1
))))
assert_simplified_equal
(
index_simplified
,
index_direct
)
...
...
@@ -143,14 +150,14 @@ def test_buffer_broadcast():
check
()
def
test_b
buffer_
roadcast_expr
():
def
test_b
uffer_b
roadcast_expr
():
n0
,
m0
,
x
=
tvm
.
var
(
'n0'
),
tvm
.
var
(
'm0'
),
tvm
.
var
(
'x'
)
n1
,
m1
=
tvm
.
var
(
'n1'
),
tvm
.
var
(
'm1'
)
o0
,
o1
=
tvm
.
var
(
'o0'
),
tvm
.
var
(
'o1'
)
A
=
tvm
.
placeholder
((
m0
,
n0
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
m1
,
n1
),
name
=
'B'
)
C
=
tvm
.
compute
((
o0
,
o1
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
+
B
[
i
,
j
],
name
=
'C'
)
C
=
tvm
.
compute
((
o0
,
o1
/
/
x
),
lambda
i
,
j
:
A
[
i
,
j
]
+
B
[
i
,
j
],
name
=
'C'
)
Ab
=
tvm
.
decl_buffer
(
A
.
shape
,
A
.
dtype
,
name
=
"Ab"
,
buffer_type
=
"auto_broadcast"
)
Bb
=
tvm
.
decl_buffer
(
B
.
shape
,
B
.
dtype
,
name
=
"Bb"
,
buffer_type
=
"auto_broadcast"
)
...
...
tests/python/unittest/test_lang_operator.py
View file @
2ded2d8c
...
...
@@ -32,10 +32,11 @@ def test_const_fold():
if
not
isinstance
(
x
,
(
tvm
.
expr
.
IntImm
,
tvm
.
expr
.
UIntImm
))
or
x
.
value
!=
int
(
y
):
raise
ValueError
(
"check error:
%
s vs
%
s "
%
(
x
,
y
))
tmod
=
tvm
.
truncmod
check
(
lambda
x
,
y
:
x
+
y
,
3
,
4
)
check
(
lambda
x
,
y
:
x
*
y
,
3
,
12
)
check
(
lambda
x
,
y
:
x
*
y
-
10
,
3
,
12
)
check
(
lambda
x
,
y
:
x
-
y
%
10
,
3
,
12
)
check
(
lambda
x
,
y
:
x
-
tmod
(
y
,
10
)
,
3
,
12
)
check
(
lambda
x
,
y
:
x
//
y
+
10
,
100
,
12
)
check
(
lambda
x
,
y
:
x
&
y
+
10
,
112
,
128
)
check
(
lambda
x
,
y
:
x
>
y
,
112
,
128
)
...
...
@@ -47,13 +48,15 @@ def test_const_fold():
def
test_const_fold2
():
x
=
tvm
.
var
(
"x"
)
tmod
=
tvm
.
truncmod
tdiv
=
tvm
.
truncdiv
assert
(
x
+
0
)
.
same_as
(
x
)
assert
(
0
+
x
)
.
same_as
(
x
)
assert
(
x
-
0
)
.
same_as
(
x
)
assert
(
x
%
1
)
.
value
==
0
assert
tmod
(
x
,
1
)
.
value
==
0
assert
(
x
*
1
)
.
same_as
(
x
)
assert
(
1
*
x
)
.
same_as
(
x
)
assert
isinstance
(
(
1
/
x
),
tvm
.
expr
.
Div
)
assert
isinstance
(
tdiv
(
1
,
x
),
tvm
.
expr
.
Div
)
def
test_const_fold3
():
# Test that using ints with logic operations is forbidden
...
...
@@ -88,8 +91,9 @@ def test_const_fold3():
def
test_const_fold4
():
x1
=
tvm
.
const
(
4
,
"int32"
)
x2
=
x1
+
5
tdiv
=
tvm
.
truncdiv
assert
isinstance
(
x2
,
tvm
.
expr
.
IntImm
)
and
x2
.
value
==
9
x3
=
x2
/
3
x3
=
tdiv
(
x2
,
3
)
assert
isinstance
(
x3
,
tvm
.
expr
.
IntImm
)
and
x3
.
value
==
3
x4
=
x3
+
0.55
assert
isinstance
(
x4
,
tvm
.
expr
.
FloatImm
)
and
abs
(
x4
.
value
-
3.55
)
<
1e-6
...
...
tests/python/unittest/test_lang_tensor_overload_op.py
View file @
2ded2d8c
...
...
@@ -72,7 +72,7 @@ def test_combination():
A
=
tvm
.
placeholder
((
n
,
m
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
m
),
name
=
'B'
)
C
=
tvm
.
placeholder
((
n
,
m
),
name
=
'C'
)
D
=
k
+
A
-
B
*
C
/
x
D
=
k
+
A
-
B
*
C
+
x
s
=
tvm
.
create_schedule
(
D
.
op
)
foo
=
tvm
.
build
(
s
,
[
x
,
A
,
B
,
C
,
D
],
"llvm"
)
ctx
=
tvm
.
cpu
(
0
)
...
...
@@ -82,7 +82,7 @@ def test_combination():
c
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
m
))
.
astype
(
C
.
dtype
),
ctx
)
d
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
D
.
dtype
),
ctx
)
foo
(
x
,
a
,
b
,
c
,
d
)
tvm
.
testing
.
assert_allclose
(
d
.
asnumpy
(),
k
+
a
.
asnumpy
()
-
b
.
asnumpy
()
*
c
.
asnumpy
()
/
x
)
tvm
.
testing
.
assert_allclose
(
d
.
asnumpy
(),
k
+
a
.
asnumpy
()
-
b
.
asnumpy
()
*
c
.
asnumpy
()
+
x
)
def
verify_tensor_scalar_bop
(
shape
,
typ
=
"add"
):
...
...
tests/python/unittest/test_pass_basic.py
View file @
2ded2d8c
...
...
@@ -17,13 +17,15 @@
import
tvm
def
test_simplify
():
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
x
=
tvm
.
var
(
'x'
)
e1
=
tvm
.
ir_pass
.
Simplify
(
x
+
2
+
1
)
assert
(
tvm
.
ir_pass
.
Equal
(
e1
,
x
+
3
))
e2
=
tvm
.
ir_pass
.
Simplify
(
x
*
3
+
5
*
x
)
assert
(
tvm
.
ir_pass
.
Equal
(
e2
,
x
*
8
))
e3
=
tvm
.
ir_pass
.
Simplify
(
x
-
x
/
3
*
3
)
assert
(
tvm
.
ir_pass
.
Equal
(
e3
,
t
vm
.
make
.
M
od
(
x
,
3
)))
e3
=
tvm
.
ir_pass
.
Simplify
(
x
-
tdiv
(
x
,
3
)
*
3
)
assert
(
tvm
.
ir_pass
.
Equal
(
e3
,
t
m
od
(
x
,
3
)))
def
test_verify_ssa
():
...
...
tests/python/unittest/test_pass_equal.py
View file @
2ded2d8c
...
...
@@ -24,7 +24,7 @@ def test_equal_expr():
return
x
+
y
+
1
def
func2
():
return
tvm
.
exp
(
(
x
+
y
+
1
)
*
y
/
4
)
return
tvm
.
exp
(
tvm
.
truncdiv
((
x
+
y
+
1
)
*
y
,
4
)
)
assert
tvm
.
ir_pass
.
Equal
(
func1
(),
func1
())
assert
tvm
.
ir_pass
.
Equal
(
func2
(),
func2
())
...
...
tests/python/unittest/test_pass_loop_partition.py
View file @
2ded2d8c
...
...
@@ -162,7 +162,7 @@ def test_condition():
ib
=
tvm
.
ir_builder
.
create
()
m
=
tvm
.
var
(
'm'
)
n
=
tvm
.
var
(
'n'
)
with
ib
.
for_range
(
0
,
((
n
+
3
)
/
4
),
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
tvm
.
truncdiv
(
n
+
3
,
4
),
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
4
,
'j'
)
as
j
:
ib
.
emit
(
tvm
.
make
.
Evaluate
(
tvm
.
make
.
Select
(
ib
.
likely
(
i
*
4
+
j
<
n
),
m
,
n
)))
...
...
@@ -206,7 +206,7 @@ def test_everything_during_deduction():
ib
=
tvm
.
ir_builder
.
create
()
with
ib
.
for_range
(
0
,
n
,
'i'
)
as
i
:
with
ib
.
for_range
(
0
,
32
,
'j'
)
as
j
:
with
ib
.
if_scope
(
ib
.
likely
(
i
/
j
<
m
)):
with
ib
.
if_scope
(
ib
.
likely
(
tvm
.
truncdiv
(
i
,
j
)
<
m
)):
# this guard will produce everything during deduction
ib
.
emit
(
tvm
.
make
.
Evaluate
(
m
))
stmt
=
ib
.
get
()
...
...
tests/python/unittest/test_schedule_bound_inference.py
View file @
2ded2d8c
...
...
@@ -111,9 +111,11 @@ def test_bound_fusesplit1():
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
container
.
Map
)
assert
(
tvm
.
ir_pass
.
Simplify
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
min
-
(
xo
*
split1
)
/
l
)
.
value
==
0
)
idxdiv
=
tvm
.
indexdiv
assert
(
tvm
.
ir_pass
.
Simplify
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
min
-
idxdiv
(
xo
*
split1
,
l
))
.
value
==
0
)
expected_extent
=
(
((
xo
+
1
)
*
split1
-
1
)
/
l
-
(
xo
*
split1
)
/
l
+
1
)
expected_extent
=
(
idxdiv
((
xo
+
1
)
*
split1
-
1
,
l
)
-
idxdiv
(
xo
*
split1
,
l
)
+
1
)
for
i
in
range
(
1
,
6
):
for
j
in
range
(
1
,
6
):
for
k
in
range
(
1
,
6
):
...
...
@@ -121,7 +123,7 @@ def test_bound_fusesplit1():
comp_ext
=
tvm
.
ir_pass
.
Simplify
(
tvm
.
ir_pass
.
Substitute
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
extent
,
vars
))
.
value
exp_ext
=
tvm
.
ir_pass
.
Simplify
(
tvm
.
ir_pass
.
Substitute
(
expected_extent
,
vars
))
.
value
assert
(
comp_ext
==
exp_ext
)
assert
(
tvm
.
ir_pass
.
Simplify
(
bounds
[
A1
.
op
.
axis
[
1
]]
.
extent
-
l
)
.
value
==
0
)
def
test_bound_fusesplit2
():
...
...
@@ -394,11 +396,11 @@ def test_bound_simplification_failure():
if
not
bounds
[
A
.
op
.
axis
[
0
]]
.
extent
.
value
<=
2
:
print
(
stmt
)
assert
bounds
[
A
.
op
.
axis
[
0
]]
.
extent
.
value
<=
2
tdiv
=
tvm
.
truncdiv
# These are hard to simplify, moreover we don't simplify them
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
tvm
.
min
(
3
*
i
,
4
*
i
)
+
tvm
.
min
(
-
3
*
i
,
-
2
*
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
tvm
.
min
(
3
*
i
,
4
*
i
)
+
tvm
.
max
(
-
3
*
i
,
-
4
*
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
-
2
*
(
i
/
2
)
-
tvm
.
min
(
i
,
0
-
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
-
2
*
tdiv
(
i
,
2
)
-
tvm
.
min
(
i
,
0
-
i
)]))
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
i
+
(
0
-
i
)]))
# This would cause out of bounds, but we nevertheless include it
_check
(
tvm
.
compute
((
10
,),
lambda
i
:
A
[
i
]))
...
...
tests/python/unittest/test_schedule_tensorize.py
View file @
2ded2d8c
...
...
@@ -221,11 +221,14 @@ def test_tensorize_matmul():
# This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696
def
test_tensorize_op
():
tdiv
=
tvm
.
truncdiv
tmod
=
tvm
.
truncmod
def
op_intrin
():
bh
=
9
bw
=
9
x
=
tvm
.
placeholder
((
5
,
5
),
name
=
'A'
)
y
=
tvm
.
compute
((
bh
,
bw
),
lambda
i
,
j
:
x
[
j
/
3
+
i
%
3
,
j
%
3
+
i
/
3
])
y
=
tvm
.
compute
((
bh
,
bw
),
lambda
i
,
j
:
x
[
tdiv
(
j
,
3
)
+
tmod
(
i
,
3
),
tmod
(
j
,
3
)
+
tdiv
(
i
,
3
)])
def
intrin_func
(
ins
,
outs
):
xx
,
=
ins
...
...
@@ -236,7 +239,7 @@ def test_tensorize_op():
return
tvm
.
decl_tensor_intrin
(
y
.
op
,
intrin_func
)
A
=
tvm
.
placeholder
((
5
,
5
),
name
=
'A'
)
B
=
tvm
.
compute
((
9
,
9
),
lambda
i
,
j
:
A
[
j
/
3
+
i
%
3
,
j
%
3
+
i
/
3
])
B
=
tvm
.
compute
((
9
,
9
),
lambda
i
,
j
:
A
[
tdiv
(
j
,
3
)
+
tmod
(
i
,
3
),
tmod
(
j
,
3
)
+
tdiv
(
i
,
3
)
])
bt
=
op_intrin
()
s
=
tvm
.
create_schedule
(
B
.
op
)
...
...
topi/python/topi/arm_cpu/conv2d_spatial_pack.py
View file @
2ded2d8c
...
...
@@ -128,8 +128,13 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
kernel_vec
[
co
,
ci
,
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
ci
,
kh
,
kw
]),
name
=
'conv'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
,
idxdiv
(
co
,
VC
),
idxdiv
(
h
,
VH
),
idxdiv
(
w
,
VW
),
idxmod
(
h
,
VH
),
idxmod
(
w
,
VW
),
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_conv2d_output'
)
return
output
...
...
topi/python/topi/arm_cpu/conv2d_transpose.py
View file @
2ded2d8c
...
...
@@ -123,8 +123,13 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
kernel_vec
[
co
,
ci
,
KH
-
1
-
kh
,
KW
-
1
-
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
ci
,
kh
,
kw
]),
name
=
'conv'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
,
idxdiv
(
co
,
VC
),
idxdiv
(
h
,
VH
),
idxdiv
(
w
,
VW
),
idxmod
(
h
,
VH
),
idxmod
(
w
,
VW
),
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_conv2d_transpose_output'
)
return
output
...
...
topi/python/topi/arm_cpu/depthwise_conv2d.py
View file @
2ded2d8c
...
...
@@ -293,21 +293,29 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
kh
=
tvm
.
reduce_axis
((
0
,
KH
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
KW
),
name
=
'kw'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
if
dilation_h
!=
1
or
dilation_w
!=
1
:
conv
=
tvm
.
compute
(
ovshape
,
lambda
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
:
\
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
(
co
*
VC
+
vc
)
//
M
,
kh
,
kw
,
vh
,
vw
]
.
astype
(
out_dtype
)
*
kernel_vec
[
co
//
M
,
co
%
M
,
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
]),
name
=
'depthwise_conv'
)
conv
=
tvm
.
compute
(
ovshape
,
lambda
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
:
\
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
idxdiv
(
co
*
VC
+
vc
,
M
),
kh
,
kw
,
vh
,
vw
]
.
astype
(
out_dtype
)
*
kernel_vec
[
idxdiv
(
co
,
M
),
idxmod
(
co
,
M
),
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
]),
name
=
'depthwise_conv'
)
else
:
conv
=
tvm
.
compute
(
ovshape
,
lambda
n
,
co
,
h
,
w
,
vh
,
vw
,
vc
:
\
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
(
co
*
VC
+
vc
)
//
M
,
vh
*
HSTR
+
kh
,
tvm
.
sum
(
data_vec
[
n
,
h
,
w
,
idxdiv
((
co
*
VC
+
vc
),
M
)
,
vh
*
HSTR
+
kh
,
vw
*
WSTR
+
kw
]
.
astype
(
out_dtype
)
*
kernel_vec
[
co
//
M
,
co
%
M
,
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
kernel_vec
[
idxdiv
(
co
,
M
),
idxmod
(
co
,
M
),
kh
,
kw
,
vc
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
]),
name
=
'depthwise_conv'
)
output
=
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
conv
[
n
,
idxdiv
(
co
,
VC
),
idxdiv
(
h
,
VH
),
idxdiv
(
w
,
VW
),
idxmod
(
h
,
VH
),
idxmod
(
w
,
VW
),
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_depthwise_conv_nchw_output'
)
return
output
...
...
topi/python/topi/cuda/conv2d_transpose_nchw.py
View file @
2ded2d8c
...
...
@@ -69,9 +69,11 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
[
0
,
0
,
(
bpad_bottom
+
stride_h
-
1
)
//
stride_h
,
(
bpad_right
+
stride_w
-
1
)
//
stride_w
],
name
=
'FirstPad'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# remove extra padding introduced by dilatation
border_h
=
(
stride_h
-
bpad_top
%
stride_h
)
%
stride_h
border_w
=
(
stride_w
-
bpad_left
%
stride_w
)
%
stride_w
border_h
=
idxmod
(
stride_h
-
idxmod
(
bpad_top
,
stride_h
),
stride_h
)
border_w
=
idxmod
(
stride_w
-
idxmod
(
bpad_left
,
stride_w
),
stride_w
)
# dilation stage
data
=
FirstPad
...
...
@@ -83,8 +85,8 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
index_tuple
=
[]
for
i
in
range
(
n
):
if
not
equal_const_int
(
strides
[
i
],
1
):
index_tuple
.
append
(
i
ndices
[
i
]
//
strides
[
i
]
)
not_zero
.
append
(
(
indices
[
i
]
%
strides
[
i
])
.
equal
(
0
))
index_tuple
.
append
(
i
dxdiv
(
indices
[
i
],
strides
[
i
])
)
not_zero
.
append
(
idxmod
(
indices
[
i
],
strides
[
i
])
.
equal
(
0
))
else
:
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
...
...
topi/python/topi/cuda/conv2d_winograd.py
View file @
2ded2d8c
...
...
@@ -85,10 +85,12 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
else
:
kernel_pack
=
kernel
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# pack input tile
input_tile
=
tvm
.
compute
((
CI
,
P
,
alpha
,
alpha
),
lambda
c
,
p
,
eps
,
nu
:
data_pad
[
p
//
(
nH
*
nW
)][
c
][
p
//
nW
%
nH
*
m
+
eps
]
[
p
%
nW
*
m
+
nu
],
name
=
'd'
)
data_pad
[
idxdiv
(
p
,
(
nH
*
nW
))][
c
][
idxmod
(
idxdiv
(
p
,
nW
),
nH
)
*
m
+
eps
]
[
idxmod
(
p
,
nW
)
*
m
+
nu
],
name
=
'd'
)
# transform data
r_a
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_a'
)
...
...
@@ -113,7 +115,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
# output
output
=
tvm
.
compute
((
N
,
CO
,
H
,
W
),
lambda
n
,
co
,
h
,
w
:
inverse
[
co
][
n
*
nH
*
nW
+
(
h
//
m
)
*
nW
+
w
//
m
][
h
%
m
][
w
%
m
],
inverse
[
co
,
n
*
nH
*
nW
+
idxdiv
(
h
,
m
)
*
nW
+
idxdiv
(
w
,
m
),
idxmod
(
h
,
m
),
idxmod
(
w
,
m
)],
name
=
'output'
,
tag
=
'conv2d_nchw_winograd'
)
cfg
.
add_flop
(
2
*
N
*
CO
*
H
*
W
*
CI
*
KH
*
KW
)
...
...
topi/python/topi/cuda/nms.py
View file @
2ded2d8c
...
...
@@ -245,7 +245,7 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
new_range
=
num_anchors
//
elem_per_thread
+
1
# Scan: Downsweep:
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
i
=
tid
/
num_anchors
# number of batches
i
=
tid
/
/
num_anchors
# number of batches
j
=
tid
%
num_anchors
# number of anchors
with
ib
.
if_scope
(
j
<
elem_per_thread
):
idx
[
tid
]
=
idx_in
[
tid
]
...
...
@@ -304,7 +304,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
tid
=
bx
*
max_threads
+
tx
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
i
=
tid
/
num_anchors
i
=
tid
/
/
num_anchors
j
=
tid
%
num_anchors
base_idx
=
i
*
num_anchors
*
elem_length
with
ib
.
if_scope
(
flag
[
tid
]
>
0
):
...
...
topi/python/topi/cuda/ssd/multibox.py
View file @
2ded2d8c
...
...
@@ -315,7 +315,7 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
tid
=
bx
*
max_threads
+
tx
with
ib
.
if_scope
(
tid
<
batch_size
*
num_anchors
):
i
=
tid
/
num_anchors
i
=
tid
/
/
num_anchors
j
=
tid
%
num_anchors
with
ib
.
if_scope
(
cls_id
[
tid
]
>
0
):
with
ib
.
if_scope
(
tid
==
0
):
...
...
topi/python/topi/mali/conv2d.py
View file @
2ded2d8c
...
...
@@ -293,11 +293,14 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
tvm
.
sum
(
input_tile
[
ci
][
p
][
r_a
][
r_b
][
vp
]
*
B
[
r_a
][
eps
]
*
B
[
r_b
][
nu
],
axis
=
[
r_a
,
r_b
]),
name
=
'V'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# batch gemm
ci
=
tvm
.
reduce_axis
((
0
,
CI
),
name
=
'c'
)
M
=
tvm
.
compute
((
alpha
,
alpha
,
CO
,
P_round
),
lambda
eps
,
nu
,
co
,
p
:
tvm
.
sum
(
U
[
eps
][
nu
][
co
//
bna
][
ci
][
co
%
bna
]
*
V
[
eps
][
nu
][
p
//
bnb
][
ci
][
p
%
bnb
],
axis
=
ci
),
name
=
'M'
)
tvm
.
sum
(
U
[
eps
][
nu
][
idxdiv
(
co
,
bna
)][
ci
][
idxmod
(
co
,
bna
)
]
*
V
[
eps
][
nu
][
idxdiv
(
p
,
bnb
)][
ci
][
idxmod
(
p
,
bnb
)
],
axis
=
ci
),
name
=
'M'
)
r_a
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_a'
)
r_b
=
tvm
.
reduce_axis
((
0
,
alpha
),
'r_b'
)
...
...
@@ -307,7 +310,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# unpack output
output
=
tvm
.
compute
((
N
,
CO
,
H
,
W
),
lambda
n
,
co
,
h
,
w
:
Y
[
co
][
n
*
nH
*
nW
+
(
h
//
m
)
*
nW
+
w
//
m
][
h
%
m
][
w
%
m
]
Y
[
co
,
n
*
nH
*
nW
+
idxdiv
(
h
,
m
)
*
nW
+
idxdiv
(
w
,
m
),
idxmod
(
h
,
m
),
idxmod
(
w
,
m
)]
# The following hack term is used to make the padding in batch gemm ("M")
# effective, otherwise the padding will be eliminated by bound inference.
# Use `tvm.expr.Mul` instead of `*` to avoid issues in const folding.
...
...
topi/python/topi/nn/bitserial_conv2d.py
View file @
2ded2d8c
...
...
@@ -313,10 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
axis
=
[
ci
,
dh
,
dw
,
b1
,
b2
])
conv
=
tvm
.
compute
(
ovshape
,
_conv
,
name
=
'conv_out'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
return
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
co
//
VC
][
h
//
VH
][
w
//
VW
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
name
=
'conv_vec'
,
tag
=
'spatial_bitserial_conv_nchw'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
co
,
h
,
w
:
conv
[
n
][
idxdiv
(
co
,
VC
)][
idxdiv
(
h
,
VH
)][
idxdiv
(
w
,
VW
)][
idxmod
(
h
,
VH
)][
idxmod
(
w
,
VW
)][
idxmod
(
co
,
VC
)],
name
=
'conv_vec'
,
tag
=
'spatial_bitserial_conv_nchw'
)
@autotvm.register_topi_compute
(
bitserial_conv2d_nhwc
,
'cpu'
,
'direct'
)
def
spatial_pack_nhwc
(
cfg
,
data
,
kernel
,
stride
,
padding
,
in_bits
,
weight_bits
,
...
...
@@ -415,9 +419,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
conv
=
tvm
.
compute
(
ovshape
,
_conv
,
name
=
'conv'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
h
,
w
,
co
:
conv
[
n
][
h
//
VH
][
w
//
VW
][
co
//
VC
][
h
%
VH
][
w
%
VW
][
co
%
VC
],
name
=
'output_unpack'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
return
tvm
.
compute
(
oshape
,
lambda
n
,
h
,
w
,
co
:
conv
[
n
][
idxdiv
(
h
,
VH
)][
idxdiv
(
w
,
VW
)][
idxdiv
(
co
,
VC
)][
idxmod
(
h
,
VH
)][
idxmod
(
w
,
VW
)][
idxmod
(
co
,
VC
)],
name
=
'output_unpack'
,
tag
=
'spatial_bitserial_conv_nhwc'
)
@tvm.target.generic_func
def
bitserial_conv2d_legalize
(
attrs
,
inputs
,
types
):
...
...
topi/python/topi/nn/bitserial_dense.py
View file @
2ded2d8c
...
...
@@ -121,13 +121,18 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
weight_vec
=
tvm
.
compute
(
wvshape
,
lambda
xo
,
wb
,
vx
,
k
:
weight_packed
[
xo
*
VX
+
vx
][
wb
][
k
],
name
=
'weight_vec'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
matmul_unipolar
=
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
tvm
.
sum
(
(
tvm
.
popcount
(
weight_vec
[
j
//
VX
,
wb
,
j
%
VX
,
k
]
&
data_packed
[
i
,
db
,
k
])
-
tvm
.
popcount
(
~
weight_vec
[
j
//
VX
,
wb
,
j
%
VX
,
k
]
&
data_packed
[
i
,
db
,
k
]))
.
astype
(
out_dtype
)
(
tvm
.
popcount
(
weight_vec
[
idxdiv
(
j
,
VX
),
wb
,
idxmod
(
j
,
VX
),
k
]
&
data_packed
[
i
,
db
,
k
])
-
tvm
.
popcount
(
~
weight_vec
[
idxdiv
(
j
,
VX
),
wb
,
idxmod
(
j
,
VX
),
k
]
&
data_packed
[
i
,
db
,
k
])
)
.
astype
(
out_dtype
)
<<
(
db
+
wb
)
.
astype
(
out_dtype
),
axis
=
[
wb
,
db
,
k
]),
tag
=
'bitserial_dense_unipolar'
)
matmul
=
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
tvm
.
sum
(
tvm
.
popcount
(
weight_vec
[
j
//
VX
,
wb
,
j
%
VX
,
k
]
&
data_packed
[
i
,
db
,
k
])
.
astype
(
out_dtype
)
tvm
.
popcount
(
weight_vec
[
idxdiv
(
j
,
VX
),
wb
,
idxmod
(
j
,
VX
),
k
]
&
data_packed
[
i
,
db
,
k
]
)
.
astype
(
out_dtype
)
<<
(
db
+
wb
)
.
astype
(
out_dtype
),
axis
=
[
wb
,
db
,
k
]),
tag
=
'bitserial_dense'
)
# binary ops
...
...
topi/python/topi/nn/conv2d.py
View file @
2ded2d8c
...
...
@@ -480,17 +480,20 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
i
c
//
ic_bn
,
i
dxdiv
(
ic
,
ic_bn
)
,
oh
*
HSTR
+
kh
*
dilation_h
,
ow
*
WSTR
+
kw
*
dilation_w
,
i
c
%
ic_bn
]
.
astype
(
out_dtype
)
i
dxmod
(
ic
,
ic_bn
)
]
.
astype
(
out_dtype
)
*
kernel
[
oc_chunk
,
i
c
//
ic_bn
,
i
dxdiv
(
ic
,
ic_bn
)
,
kh
,
kw
,
i
c
%
ic_bn
,
i
dxmod
(
ic
,
ic_bn
)
,
oc_block
],
axis
=
[
ic
,
kh
,
kw
]),
name
=
'conv2d_NCHWc'
,
tag
=
"conv2d_NCHWc"
)
...
...
topi/python/topi/nn/depthwise_conv2d.py
View file @
2ded2d8c
...
...
@@ -105,14 +105,17 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after
=
[
0
,
0
,
pad_down
,
pad_right
]
PaddedInput
=
pad
(
Input
,
pad_before
,
pad_after
,
name
=
"PaddedInput"
)
# depthconv stage
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
di
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'di'
)
dj
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'dj'
)
Output
=
tvm
.
compute
(
(
batch
,
out_channel
,
out_height
,
out_width
),
lambda
b
,
c
,
i
,
j
:
tvm
.
sum
(
(
PaddedInput
[
b
,
c
/
channel_multiplier
,
i
*
stride_h
+
di
*
dilation_h
,
(
PaddedInput
[
b
,
idxdiv
(
c
,
channel_multiplier
)
,
i
*
stride_h
+
di
*
dilation_h
,
j
*
stride_w
+
dj
*
dilation_w
]
.
astype
(
out_dtype
)
*
Filter
[
c
/
channel_multiplier
,
c
%
channel_multiplier
,
di
,
dj
]
.
astype
(
out_dtype
)),
Filter
[
idxdiv
(
c
,
channel_multiplier
),
idxmod
(
c
,
channel_multiplier
),
di
,
dj
]
.
astype
(
out_dtype
)),
axis
=
[
di
,
dj
]),
name
=
'DepthwiseConv2d'
,
tag
=
"depthwise_conv2d_nchw"
)
return
Output
...
...
@@ -176,14 +179,19 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
pad_after
=
[
0
,
pad_down
,
pad_right
,
0
]
PaddedInput
=
pad
(
Input
,
pad_before
,
pad_after
,
name
=
"PaddedInput"
)
# depthconv stage
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
di
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'di'
)
dj
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'dj'
)
Output
=
tvm
.
compute
(
(
batch
,
out_height
,
out_width
,
out_channel
),
lambda
b
,
i
,
j
,
c
:
tvm
.
sum
(
(
PaddedInput
[
b
,
i
*
stride_h
+
di
*
dilation_h
,
j
*
stride_w
+
dj
*
dilation_w
,
c
/
channel_multiplier
]
.
astype
(
out_dtype
)
*
Filter
[
di
,
dj
,
c
/
channel_multiplier
,
c
%
channel_multiplier
]
.
astype
(
out_dtype
)),
idxdiv
(
c
,
channel_multiplier
)]
.
astype
(
out_dtype
)
*
Filter
[
di
,
dj
,
idxdiv
(
c
,
channel_multiplier
),
idxmod
(
c
,
channel_multiplier
)]
.
astype
(
out_dtype
)),
axis
=
[
di
,
dj
]),
name
=
'DepthwiseConv2d'
,
tag
=
"depthwise_conv2d_nhwc"
)
return
Output
...
...
@@ -286,11 +294,13 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
dh
=
tvm
.
reduce_axis
((
0
,
Out_grad
.
shape
[
1
]
.
value
),
name
=
'dh'
)
dw
=
tvm
.
reduce_axis
((
0
,
Out_grad
.
shape
[
2
]
.
value
),
name
=
'dw'
)
db
=
tvm
.
reduce_axis
((
0
,
batch
),
name
=
'db'
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
Weight_grad
=
tvm
.
compute
(
(
filter_h
,
filter_w
,
in_c
,
channel_multiplier
),
lambda
fh
,
fw
,
c
,
m
:
tvm
.
sum
(
Out_grad
[
db
,
dh
,
dw
,
c
*
channel_multiplier
+
m
%
channel_multiplier
]
*
Out_grad
[
db
,
dh
,
dw
,
c
*
channel_multiplier
+
idxmod
(
m
,
channel_multiplier
)
]
*
padded_in
[
db
,
fh
+
dh
*
stride_h
,
fw
+
dw
*
stride_w
,
c
],
axis
=
[
db
,
dh
,
dw
]),
tag
=
'depthwise_conv2d_backward_weight_nhwc'
)
...
...
topi/python/topi/nn/dilate.py
View file @
2ded2d8c
...
...
@@ -52,10 +52,12 @@ def dilate(data, strides, name="DilatedInput"):
def
_dilate
(
*
indices
):
not_zero
=
[]
index_tuple
=
[]
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
for
i
in
range
(
n
):
if
not
util
.
equal_const_int
(
strides
[
i
],
1
):
index_tuple
.
append
(
i
ndices
[
i
]
/
strides
[
i
]
)
not_zero
.
append
(
(
indices
[
i
]
%
strides
[
i
])
.
equal
(
0
))
index_tuple
.
append
(
i
dxdiv
(
indices
[
i
],
strides
[
i
])
)
not_zero
.
append
(
idxmod
(
indices
[
i
],
strides
[
i
])
.
equal
(
0
))
else
:
index_tuple
.
append
(
indices
[
i
])
if
not_zero
:
...
...
topi/python/topi/nn/flatten.py
View file @
2ded2d8c
...
...
@@ -38,12 +38,14 @@ def flatten(data):
for
i
in
range
(
1
,
len
(
ishape
)):
dim
=
dim
*
ishape
[
i
]
oshape
=
[
ishape
[
0
],
dim
]
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
def
unwrap
(
idx
,
shape
):
index
=
[]
for
s
in
reversed
(
shape
):
index
.
append
(
idx
%
s
)
idx
=
idx
/
s
index
.
append
(
idx
mod
(
idx
,
s
)
)
idx
=
idx
div
(
idx
,
s
)
return
list
(
reversed
(
index
))
return
tvm
.
compute
(
oshape
,
lambda
i
,
j
:
data
(
i
,
*
unwrap
(
j
,
ishape
[
1
:])))
topi/python/topi/x86/conv2d.py
View file @
2ded2d8c
...
...
@@ -175,16 +175,20 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
ic
=
tvm
.
reduce_axis
((
0
,
in_channel
),
name
=
'ic'
)
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
idxmod
=
tvm
.
indexmod
idxdiv
=
tvm
.
indexdiv
conv
=
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_vec
[
n
,
ic
//
ic_bn
,
oh
*
HSTR
+
kh
*
dilation_h
,
ic
%
ic_bn
,
tvm
.
sum
(
data_vec
[
n
,
idxdiv
(
ic
,
ic_bn
),
oh
*
HSTR
+
kh
*
dilation_h
,
idxmod
(
ic
,
ic_bn
),
ow
*
WSTR
+
kw
*
dilation_w
]
.
astype
(
out_dtype
)
*
kernel_vec
[
oc_chunk
,
ic
//
ic_bn
,
kh
,
kw
,
ic
%
ic_bn
,
kernel_vec
[
oc_chunk
,
idxdiv
(
ic
,
ic_bn
),
kh
,
kw
,
idxmod
(
ic
,
ic_bn
),
oc_block
]
.
astype
(
out_dtype
),
axis
=
[
ic
,
kh
,
kw
]),
name
=
'conv'
)
unpack
=
tvm
.
compute
(
unpack_shape
,
lambda
n
,
c
,
h
,
w
:
conv
[
n
,
c
//
oc_bn
,
h
,
w
,
c
%
oc_bn
]
lambda
n
,
c
,
h
,
w
:
conv
[
n
,
idxdiv
(
c
,
oc_bn
),
h
,
w
,
idxmod
(
c
,
oc_bn
)
]
.
astype
(
out_dtype
),
name
=
'output_unpack'
,
tag
=
'conv2d_nchw'
)
...
...
@@ -311,14 +315,17 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
cfg
=
get_config
()
_create_tuning_space
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
origin_layout
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
# change shape with the value in config
ic_bn
,
oc_bn
,
ow_bn
=
(
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
],
cfg
[
"tile_ow"
]
.
size
[
-
1
])
new_data_shape
=
(
raw_data_shape
[
0
],
raw_data_shape
[
1
]
//
ic_bn
,
new_data_shape
=
(
raw_data_shape
[
0
],
idxdiv
(
raw_data_shape
[
1
],
ic_bn
)
,
raw_data_shape
[
2
],
raw_data_shape
[
3
],
ic_bn
)
data_layout
=
"NCHW
%
dc"
%
ic_bn
out_layout
=
"NCHW
%
dc"
%
oc_bn
new_kernel_shape
=
(
raw_kernel_shape
[
0
]
//
oc_bn
,
raw_kernel_shape
[
1
]
//
ic_bn
,
new_kernel_shape
=
(
idxdiv
(
raw_kernel_shape
[
0
],
oc_bn
),
idxdiv
(
raw_kernel_shape
[
1
],
ic_bn
),
raw_kernel_shape
[
2
],
raw_kernel_shape
[
3
],
ic_bn
,
oc_bn
)
new_data
=
tvm
.
placeholder
(
new_data_shape
,
data
.
dtype
)
new_kernel
=
tvm
.
placeholder
(
new_kernel_shape
,
kernel
.
dtype
)
...
...
@@ -334,12 +341,14 @@ def _conv2d_infer_layout(workload, cfg):
_
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
dtype
=
workload
batch_size
,
in_channel
,
in_height
,
in_width
=
data
[:
-
1
]
out_channel
,
_
,
k_height
,
k_width
=
kernel
[:
-
1
]
out_height
=
(
in_height
+
2
*
padding
[
0
]
-
k_height
)
//
strides
[
0
]
+
1
out_width
=
(
in_width
+
2
*
padding
[
1
]
-
k_width
)
//
strides
[
1
]
+
1
idxdiv
=
tvm
.
indexdiv
out_height
=
idxdiv
(
in_height
+
2
*
padding
[
0
]
-
k_height
,
strides
[
0
])
+
1
out_width
=
idxdiv
(
in_width
+
2
*
padding
[
1
]
-
k_width
,
strides
[
1
])
+
1
tile_ic
,
tile_oc
=
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
]
in_shape
=
(
batch_size
,
i
n_channel
//
tile_ic
,
in_height
,
in_width
,
tile_ic
)
in_shape
=
(
batch_size
,
i
dxdiv
(
in_channel
,
tile_ic
)
,
in_height
,
in_width
,
tile_ic
)
in_layout
=
"NCHW
%
dc"
%
tile_ic
out_shape
=
(
batch_size
,
out_channel
//
tile_oc
,
out_height
,
out_width
,
tile_oc
)
out_shape
=
(
batch_size
,
idxdiv
(
out_channel
,
tile_oc
)
,
out_height
,
out_width
,
tile_oc
)
out_layout
=
"NCHW
%
dc"
%
tile_oc
return
((
in_shape
,
in_layout
),),
((
out_shape
,
out_layout
),)
...
...
topi/python/topi/x86/dense.py
View file @
2ded2d8c
...
...
@@ -64,11 +64,13 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
packw
=
tvm
.
compute
(
packw_shape
,
lambda
z
,
y
,
x
:
weight
[
z
*
packw_bn
+
x
,
y
],
name
=
"packed_weight"
)
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
k
=
tvm
.
reduce_axis
((
0
,
K
),
name
=
"k"
)
C
=
tvm
.
compute
((
M
,
N
),
lambda
y
,
x
:
tvm
.
sum
(
data
[
y
,
k
]
.
astype
(
out_dtype
)
*
packw
[
x
//
packw_bn
,
k
,
x
%
packw_bn
]
.
astype
(
out_dtype
),
packw
[
idxdiv
(
x
,
packw_bn
),
k
,
idxmod
(
x
,
packw_bn
)
]
.
astype
(
out_dtype
),
axis
=
k
),
tag
=
"dense_pack"
)
if
bias
is
not
None
:
...
...
topi/python/topi/x86/depthwise_conv2d.py
View file @
2ded2d8c
...
...
@@ -117,14 +117,19 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
data_pad
=
data
# depthconv stage
idxdiv
=
tvm
.
indexdiv
idxmod
=
tvm
.
indexmod
kh
=
tvm
.
reduce_axis
((
0
,
filter_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
filter_width
),
name
=
'kw'
)
Output
=
tvm
.
compute
(
(
batch
,
out_channel_chunk
,
out_height
,
out_width
,
out_channel_block
),
lambda
b
,
oco
,
oh
,
ow
,
oci
:
tvm
.
sum
(
(
data_pad
[
b
,
(
oco
*
out_channel_block
+
oci
)
//
channel_multiplier
//
in_channel_block
,
oh
*
HSTR
+
kh
,
ow
*
WSTR
+
kw
,
((
oco
*
out_channel_block
+
oci
)
//
channel_multiplier
)
%
in_channel_block
]
(
data_pad
[
b
,
idxdiv
(
idxdiv
(
oco
*
out_channel_block
+
oci
,
channel_multiplier
),
in_channel_block
),
oh
*
HSTR
+
kh
,
ow
*
WSTR
+
kw
,
idxmod
(
idxdiv
(
oco
*
out_channel_block
+
oci
,
channel_multiplier
),
in_channel_block
)]
.
astype
(
out_dtype
)
*
kernel
[
oco
,
0
,
kh
,
kw
,
0
,
oci
]
.
astype
(
out_dtype
)),
axis
=
[
kh
,
kw
]),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment