Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
153417a5
Unverified
Commit
153417a5
authored
Jun 13, 2019
by
Tianqi Chen
Committed by
GitHub
Jun 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Revamp IntSet (#3272)
parent
9bb16872
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1184 additions
and
842 deletions
+1184
-842
include/tvm/arithmetic.h
+109
-85
python/tvm/arith.py
+31
-12
src/api/api_arith.cc
+5
-0
src/arithmetic/analyzer.cc
+3
-2
src/arithmetic/bound_deducer.cc
+4
-4
src/arithmetic/canonical_simplify.cc
+4
-2
src/arithmetic/compute_expr.h
+5
-5
src/arithmetic/const_fold.h
+56
-2
src/arithmetic/detect_linear_equation.cc
+4
-4
src/arithmetic/int_op_overflow.h
+2
-2
src/arithmetic/int_set.cc
+544
-476
src/arithmetic/int_set.h
+143
-0
src/arithmetic/int_set_internal.h
+0
-79
src/lang/expr_operator.cc
+17
-2
src/pass/loop_partition.cc
+16
-13
tests/python/unittest/test_arith_deduce_bound.py
+168
-0
tests/python/unittest/test_arith_intset.py
+73
-154
No files found.
include/tvm/arithmetic.h
View file @
153417a5
...
...
@@ -328,71 +328,14 @@ class ConstraintContext {
std
::
function
<
void
()
>
exit_
;
};
/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class
Analyzer
{
public
:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer
const_int_bound
;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer
modular_set
;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier
rewrite_simplify
;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier
canonical_simplify
;
/*! \brief constructor */
Analyzer
();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void
Bind
(
const
VarExpr
&
var
,
const
Expr
&
expr
);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void
Bind
(
const
VarExpr
&
var
,
const
Range
&
range
);
/*!
* \brief Whether can we proof expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
);
};
//-----------------------------------------------
// Integer set
abstraction API
.
// Integer set
data structure
.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
* \brief Sign
of an expression or set
.
* \brief Sign
type of an integer expression
.
*/
enum
SignType
{
kPositive
,
...
...
@@ -401,8 +344,13 @@ enum SignType {
kUnknown
};
// internal node container of int set.
struct
IntSetNode
;
/*!
* \brief Base class of all IntSet containers.
*/
struct
IntSetNode
:
public
Node
{
static
constexpr
const
char
*
_type_key
=
"IntSet"
;
TVM_DECLARE_BASE_NODE_INFO
(
IntSetNode
,
Node
);
};
/*!
* \brief Integer set class, represent a set of integers in one dimension.
...
...
@@ -424,11 +372,6 @@ class IntSet : public NodeRef {
* \return The covering range.
*/
Range
cover_range
(
Range
max_range
)
const
;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
IntSet
cover_interval
()
const
;
/*! \return Lower bound of the set */
Expr
min
()
const
;
/*! \return upper bound of the set */
...
...
@@ -493,33 +436,91 @@ class IntSet : public NodeRef {
};
/*!
* \brief
Base class of all IntSet containers
.
* \brief
Integer set analyzer
.
*/
struct
IntSetNode
:
public
Node
{
static
constexpr
const
char
*
_type_key
=
"IntSet"
;
TVM_DECLARE_BASE_NODE_INFO
(
IntSetNode
,
Node
);
class
IntSetAnalyzer
{
public
:
/*!
* \brief Find a symbolic integer set that contains all possible values of
* expr given the domain of each variables.
*
* \param expr The expression of interest.
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet
operator
()(
const
Expr
&
expr
,
const
Map
<
Var
,
IntSet
>&
dom_map
);
private
:
friend
class
Analyzer
;
explicit
IntSetAnalyzer
(
Analyzer
*
parent
);
~
IntSetAnalyzer
();
class
Impl
;
/*! \brief Internal impl */
Impl
*
impl_
;
};
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
* \brief Analyzer that contains bunch of sub-analyzers.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array
<
Expr
>
DetectLinearEquation
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
Array
<
Expr
>
DetectClipBound
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
class
Analyzer
{
public
:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer
const_int_bound
;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer
modular_set
;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier
rewrite_simplify
;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier
canonical_simplify
;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer
int_set
;
/*! \brief constructor */
Analyzer
();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void
Bind
(
const
VarExpr
&
var
,
const
Expr
&
expr
);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void
Bind
(
const
VarExpr
&
var
,
const
Range
&
range
);
/*!
* \brief Whether can we prove expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
);
};
//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
...
...
@@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
Domain
DomainTouched
(
Stmt
body
,
const
Tensor
&
tensor
,
bool
consider_calls
,
bool
consider_provides
);
// Expression pattern detector.
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array
<
Expr
>
DetectLinearEquation
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array
<
Expr
>
DetectClipBound
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
// implementation
inline
const
IntSetNode
*
IntSet
::
operator
->
()
const
{
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
...
...
python/tvm/arith.py
View file @
153417a5
...
...
@@ -32,21 +32,21 @@ class IntSet(NodeBase):
return
_api_internal
.
_IntSetIsEverything
(
self
)
@register_node
@register_node
(
"arith.IntervalSet"
)
class
IntervalSet
(
IntSet
):
"""Represent set of continuous interval"""
def
min
(
self
):
"""get the minimum value"""
return
_api_internal
.
_IntervalSetGetMin
(
self
)
def
max
(
self
):
"""get the maximum value"""
return
_api_internal
.
_IntervalSetGetMax
(
self
)
"""Represent set of continuous interval [min_value, max_value]
Parameters
----------
min_value : Expr
The minimum value in the interval.
@register_node
class
StrideSet
(
IntSet
):
"""Represent set of strided integers"""
max_value : Expr
The maximum value in the interval.
"""
def
__init__
(
self
,
min_value
,
max_value
):
self
.
__init_handle_by_constructor__
(
_make_IntervalSet
,
min_value
,
max_value
)
@register_node
(
"arith.ModularSet"
)
...
...
@@ -114,6 +114,7 @@ class Analyzer:
self
.
_modular_set
=
_mod
(
"modular_set"
)
self
.
_rewrite_simplify
=
_mod
(
"rewrite_simplify"
)
self
.
_canonical_simplify
=
_mod
(
"canonical_simplify"
)
self
.
_int_set
=
_mod
(
"int_set"
)
self
.
_enter_constraint_context
=
_mod
(
"enter_constraint_context"
)
def
const_int_bound
(
self
,
expr
):
...
...
@@ -176,6 +177,24 @@ class Analyzer:
"""
return
self
.
_canonical_simplify
(
expr
)
def
int_set
(
self
,
expr
,
dom_map
):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.
Parameters
----------
expr : tvm.Expr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
Returns
-------
result : IntSet
The result.
"""
return
self
.
_int_set
(
expr
,
dom_map
)
def
bind
(
self
,
var
,
expr
):
"""Bind a variable to the expression.
...
...
src/api/api_arith.cc
View file @
153417a5
...
...
@@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_API
(
"arith.intset_interval"
)
.
set_body_typed
(
IntSet
::
interval
);
TVM_REGISTER_API
(
"arith.DetectLinearEquation"
)
.
set_body_typed
(
DetectLinearEquation
);
...
...
@@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
self
->
canonical_simplify
(
args
[
0
]);
});
}
else
if
(
name
==
"int_set"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
self
->
int_set
(
args
[
0
],
args
[
1
]);
});
}
else
if
(
name
==
"bind"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
1
].
node_sptr
();
...
...
src/arithmetic/analyzer.cc
View file @
153417a5
...
...
@@ -31,7 +31,8 @@ Analyzer::Analyzer()
:
const_int_bound
(
this
),
modular_set
(
this
),
rewrite_simplify
(
this
),
canonical_simplify
(
this
)
{
canonical_simplify
(
this
),
int_set
(
this
)
{
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Expr
&
expr
)
{
...
...
@@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() {
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
>
lower_bound
;
return
ptr
->
value
>
=
lower_bound
;
}
auto
bd
=
this
->
const_int_bound
(
this
->
rewrite_simplify
(
expr
));
if
(
bd
->
min_value
>=
lower_bound
)
return
true
;
...
...
src/arithmetic/bound_deducer.cc
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -30,12 +30,12 @@
#include <unordered_set>
#include <unordered_map>
#include "int_set.h"
namespace
tvm
{
namespace
arith
{
using
namespace
ir
;
using
HalideIR
::
Internal
::
Interval
;
// a visitor to find the path to the target variable
// from a expression.
...
...
@@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
BoundDeducer
d
(
v
,
e
,
hint_map
,
relax_map
);
d
.
Deduce
();
if
(
!
d
.
success
)
return
IntSet
::
nothing
();
Expr
min
=
Interval
::
neg_inf
,
max
=
Interval
::
pos_inf
;
Expr
min
=
neg_inf
(),
max
=
pos_inf
()
;
if
(
d
.
is_greater
)
{
min
=
d
.
result
;
}
else
{
...
...
src/arithmetic/canonical_simplify.cc
View file @
153417a5
...
...
@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file canonical_simplify.cc
* \brief Canonical form based simplification.
*/
...
...
@@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) {
if
(
TryCompare
(
temp
,
cval
)
==
kLT
)
{
return
temp
;
}
else
{
return
SplitModConst
(
ToSplitExpr
(
temp
),
cval
);
// contonue to use logic below.
a
=
extra
;
psum
=
a
.
as
<
SumExprNode
>
();
CHECK
(
psum
!=
nullptr
);
}
}
}
...
...
src/arithmetic/compute_expr.h
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -27,8 +27,8 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <arithmetic/Interval.h>
#include <limits>
#include <algorithm>
namespace
tvm
{
namespace
arith
{
...
...
@@ -105,12 +105,12 @@ inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
template
<>
inline
Expr
ComputeExpr
<
ir
::
Max
>
(
Expr
a
,
Expr
b
)
{
return
HalideIR
::
Internal
::
Interval
::
make_
max
(
a
,
b
);
return
max
(
a
,
b
);
}
template
<>
inline
Expr
ComputeExpr
<
ir
::
Min
>
(
Expr
a
,
Expr
b
)
{
return
HalideIR
::
Internal
::
Interval
::
make_
min
(
a
,
b
);
return
min
(
a
,
b
);
}
template
<
typename
Op
>
...
...
src/arithmetic/const_fold.h
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -206,6 +206,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
min
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
min
(
fa
->
value
,
fb
->
value
));
});
if
(
a
.
same_as
(
b
))
return
a
;
return
Expr
();
}
...
...
@@ -216,6 +217,7 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
max
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
max
(
fa
->
value
,
fb
->
value
));
});
if
(
a
.
same_as
(
b
))
return
a
;
return
Expr
();
}
...
...
@@ -307,6 +309,58 @@ inline Expr TryConstFold<ir::Not>(Expr a) {
return
Expr
();
}
/*! \brief Helper namespace for symbolic value limits */
struct
SymbolicLimits
{
/*! \brief positive infinity */
static
Expr
pos_inf_
;
/*! \brief negative infinity */
static
Expr
neg_inf_
;
};
/*!
* \brief Opaque expression representing positive infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return positive infinity.
*/
inline
Expr
pos_inf
()
{
return
SymbolicLimits
::
pos_inf_
;
}
/*!
* \brief Check if value is positive infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline
bool
is_pos_inf
(
const
Expr
&
value
)
{
return
value
.
same_as
(
SymbolicLimits
::
pos_inf_
);
}
/*!
* \brief Opaque expression representing negative infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return negative infinity.
*/
inline
Expr
neg_inf
()
{
return
SymbolicLimits
::
neg_inf_
;
}
/*!
* \brief Check if value is negative infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline
bool
is_neg_inf
(
const
Expr
&
value
)
{
return
value
.
same_as
(
SymbolicLimits
::
neg_inf_
);
}
}
// namespace arith
}
// namespace tvm
#endif // TVM_ARITHMETIC_CONST_FOLD_H_
src/arithmetic/detect_linear_equation.cc
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -19,8 +19,8 @@
/*!
* Copyright (c) 2017 by Contributors
* \file
bound_deducer
.cc
* \brief Utility to de
duce bound of expression
* \file
detect_linear_equation
.cc
* \brief Utility to de
tect patterns in the expression.
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
...
...
src/arithmetic/int_op_overflow.h
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
src/arithmetic/int_set.cc
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -18,201 +18,55 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <arithmetic/Interval.h>
#include <tvm/api_registry.h>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include "
compute_expr
.h"
#include "
int_set_internal
.h"
#include "
int_set
.h"
#include "
pattern_match
.h"
namespace
tvm
{
namespace
arith
{
using
HalideIR
::
Internal
::
Interval
;
using
namespace
ir
;
inline
IntSet
IntSet
::
cover_interval
()
const
{
if
((
*
this
).
as
<
IntervalSet
>
())
return
*
this
;
const
StrideSet
*
s
=
(
*
this
).
as
<
StrideSet
>
();
if
(
s
)
{
CHECK_NE
(
s
->
extents
.
size
(),
0U
);
Expr
max
=
s
->
base
.
max
;
for
(
size_t
i
=
0
;
i
<
s
->
extents
.
size
();
++
i
)
{
max
=
max
+
s
->
extents
[
i
]
*
s
->
strides
[
i
]
-
s
->
strides
[
i
];
}
return
IntervalSet
::
make
(
s
->
base
.
min
,
Simplify
(
max
));
}
LOG
(
FATAL
)
<<
"cannot convert set "
<<
(
*
this
)
->
type_key
()
<<
" to interval"
;
return
IntSet
::
everything
();
}
Range
IntSet
::
cover_range
(
Range
max_range
)
const
{
IntSet
temp
;
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
if
(
s_int
==
nullptr
)
{
temp
=
this
->
cover_interval
();
s_int
=
temp
.
as
<
IntervalSet
>
();
}
if
(
s_int
->
i
.
is_bounded
())
{
return
Range
::
make_by_min_extent
(
s_int
->
i
.
min
,
Simplify
(
s_int
->
i
.
max
+
1
-
s_int
->
i
.
min
));
}
return
max_range
;
}
Expr
IntSet
::
min
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
CHECK
(
s_int
);
return
s_int
->
i
.
min
;
}
Expr
IntSet
::
max
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
CHECK
(
s_int
);
return
s_int
->
i
.
max
;
}
bool
IntSet
::
is_nothing
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_empty
());
}
bool
IntSet
::
is_everything
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_everything
());
}
Expr
SymbolicLimits
::
pos_inf_
=
Var
(
"pos_inf"
,
Handle
());
Expr
SymbolicLimits
::
neg_inf_
=
Var
(
"neg_inf"
,
Handle
());
bool
IntSet
::
is_single_point
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_single_point
());
IntervalSet
::
IntervalSet
(
Expr
min_value
,
Expr
max_value
)
{
auto
node
=
make_node
<
IntervalSetNode
>
();
node
->
min_value
=
std
::
move
(
min_value
);
node
->
max_value
=
std
::
move
(
max_value
);
node_
=
std
::
move
(
node
);
}
bool
IntSet
::
can_prove_positive
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
is_positive_const
(
ir
::
Simplify
(
s_int
->
i
.
min
)));
IntervalSet
MakeIntervalSet
(
Expr
min_value
,
Expr
max_value
)
{
return
IntervalSet
(
min_value
,
max_value
);
}
bool
IntSet
::
can_prove_negative
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
is_negative_const
(
ir
::
Simplify
(
s_int
->
i
.
max
)));
}
TVM_REGISTER_API
(
"arith._make_IntervalSet"
)
.
set_body_typed
(
MakeIntervalSet
);
bool
IntSet
::
can_prove_non_positive
()
const
{
if
(
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
())
{
auto
max
=
ir
::
Simplify
(
s_int
->
i
.
max
);
return
is_zero
(
max
)
||
is_negative_const
(
max
);
}
return
false
;
}
bool
IntSet
::
can_prove_non_negative
()
const
{
if
(
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
())
{
// Any reason why we should or should not use can_prove() to implement
// these functions?
auto
min
=
ir
::
Simplify
(
s_int
->
i
.
min
);
return
is_zero
(
min
)
||
is_positive_const
(
min
);
}
return
false
;
}
SignType
IntSet
::
sign_type
()
const
{
if
(
can_prove_positive
())
{
return
kPositive
;
}
else
if
(
can_prove_negative
())
{
return
kNegative
;
}
else
if
(
is_single_point
()
&&
is_zero
(
point_value
()))
{
return
kZero
;
IntervalSet
Intersect
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
Expr
max_value
=
min
(
a
->
max_value
,
b
->
max_value
);
Expr
min_value
=
max
(
a
->
min_value
,
b
->
min_value
);
if
((
max_value
.
type
().
is_int
()
||
max_value
.
type
().
is_uint
())
&&
(
min_value
.
type
().
is_int
()
||
min_value
.
type
().
is_uint
())
&&
analyzer
->
CanProveGreaterEqual
(
min_value
-
max_value
,
1
))
{
return
IntervalSet
::
Empty
();
}
else
{
return
kUnknown
;
}
}
Expr
IntSet
::
point_value
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
CHECK
(
s_int
&&
s_int
->
i
.
is_single_point
());
return
s_int
->
i
.
min
;
}
IntSet
IntSet
::
nothing
()
{
return
IntervalSet
::
make
(
Interval
::
nothing
());
}
IntSet
IntSet
::
everything
()
{
return
IntervalSet
::
make
(
Interval
::
everything
());
}
IntSet
IntSet
::
single_point
(
Expr
x
)
{
return
IntervalSet
::
make
(
Interval
::
single_point
(
x
));
}
IntSet
IntSet
::
range
(
Range
r
)
{
// must make sure it can be matched back by MatchRange.
if
(
is_one
(
r
->
extent
))
{
return
IntSet
::
single_point
(
r
->
min
);
}
if
(
is_positive_const
(
r
->
extent
)
&&
is_const
(
r
->
min
))
{
return
IntervalSet
::
make
(
r
->
min
,
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
r
->
extent
,
r
->
min
),
1
));
}
return
IntervalSet
::
make
(
r
->
min
,
(
r
->
extent
+
r
->
min
)
-
1
);
}
IntSet
IntSet
::
interval
(
Expr
min
,
Expr
max
)
{
if
(
min
.
same_as
(
max
))
{
return
IntSet
::
single_point
(
min
);
}
return
IntervalSet
::
make
(
min
,
max
);
}
inline
bool
prove_equal
(
Expr
lhs
,
Expr
rhs
)
{
return
is_zero
(
ir
::
Simplify
(
lhs
-
rhs
));
}
// Check if a is created from b.
bool
IntSet
::
match_range
(
const
Range
&
b
)
const
{
const
IntSet
&
a
=
*
this
;
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
return
prove_equal
(
i
.
min
,
b
->
min
)
&&
prove_equal
(
i
.
max
,
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
b
->
extent
,
b
->
min
),
1
));
}
inline
bool
MatchPoint
(
const
IntSet
&
a
,
const
Expr
&
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
return
i
.
is_single_point
()
&&
i
.
min
.
same_as
(
b
);
}
IntSet
Union
(
const
Array
<
IntSet
>&
sets
)
{
if
(
sets
.
size
()
==
0
)
return
IntSet
::
nothing
();
if
(
sets
.
size
()
==
1
)
return
sets
[
0
];
Interval
x
=
sets
[
0
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
IntSet
s
=
sets
[
i
].
cover_interval
();
const
Interval
&
y
=
s
.
as
<
IntervalSet
>
()
->
i
;
x
.
include
(
y
);
return
IntervalSet
(
min_value
,
max_value
);
}
x
.
max
=
ir
::
Simplify
(
x
.
max
);
x
.
min
=
ir
::
Simplify
(
x
.
min
);
return
IntervalSet
::
make
(
x
);
}
IntSet
Intersect
(
const
Array
<
IntSet
>&
sets
)
{
Interval
x
=
sets
[
0
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
Interval
y
=
sets
[
i
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
x
=
Interval
::
make_intersection
(
x
,
y
);
}
return
IntervalSet
::
make
(
x
);
IntervalSet
Union
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
Expr
max_value
=
max
(
a
->
max_value
,
b
->
max_value
);
Expr
min_value
=
min
(
a
->
min_value
,
b
->
min_value
);
return
IntervalSet
(
min_value
,
max_value
);
}
// type traits
...
...
@@ -227,407 +81,623 @@ struct is_logical_op {
static const bool value = true; \
};
// interval related.
template
<
typename
OP
>
inline
IntSet
CombineInterval
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
OP
>
(
a
.
min
,
b
.
min
));
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval "
<<
OP
::
_type_key
;
return
IntSet
::
everything
();
TVM_DECLARE_LOGICAL_OP
(
And
);
TVM_DECLARE_LOGICAL_OP
(
Or
);
TVM_DECLARE_LOGICAL_OP
(
EQ
);
TVM_DECLARE_LOGICAL_OP
(
NE
);
TVM_DECLARE_LOGICAL_OP
(
GE
);
TVM_DECLARE_LOGICAL_OP
(
GT
);
TVM_DECLARE_LOGICAL_OP
(
LE
);
TVM_DECLARE_LOGICAL_OP
(
LT
);
TVM_DECLARE_LOGICAL_OP
(
Not
);
/*!
* \brief Combine two interval set under arithmetic operations.
* \note this can possibly relax the set.
*/
template
<
typename
Op
>
inline
IntervalSet
Combine
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
Expr
res
=
TryConstFold
<
Op
>
(
a
->
min_value
,
b
->
min_value
);
if
(
!
res
.
defined
())
res
=
Op
::
make
(
a
->
min_value
,
b
->
min_value
);
return
IntervalSet
::
SinglePoint
(
res
);
}
if
(
is_logical_op
<
Op
>::
value
)
{
return
IntervalSet
(
make_const
(
a
->
min_value
.
type
(),
0
),
make_const
(
a
->
min_value
.
type
(),
1
));
}
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
a
->
IsEverything
())
return
a
;
if
(
b
->
IsEverything
())
return
b
;
return
IntervalSet
::
Everything
();
}
template
<>
inline
IntSet
CombineInterval
<
Add
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
));
}
Interval
r
=
Interval
::
everything
();
if
(
a
.
has_lower_bound
()
&&
b
.
has_lower_bound
())
{
r
.
min
=
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
);
}
if
(
a
.
has_upper_bound
()
&&
b
.
has_upper_bound
())
{
r
.
max
=
ComputeExpr
<
Add
>
(
a
.
max
,
b
.
max
);
}
return
IntervalSet
::
make
(
r
);
inline
IntervalSet
Combine
<
ir
::
Add
>
(
Analyzer
*
analyer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
+
b
->
min_value
);
}
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
Expr
min_value
=
a
->
HasLowerBound
()
&&
b
->
HasLowerBound
()
?
a
->
min_value
+
b
->
min_value
:
neg_inf
();
Expr
max_value
=
a
->
HasUpperBound
()
&&
b
->
HasUpperBound
()
?
a
->
max_value
+
b
->
max_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
template
<>
inline
IntSet
CombineInterval
<
Sub
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
min
));
inline
IntervalSet
Combine
<
ir
::
Sub
>
(
Analyzer
*
analyer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
-
b
->
min_value
);
}
Interval
r
=
Interval
::
everything
();
if
(
a
.
has_lower_bound
()
&&
b
.
has_upper_bound
())
{
r
.
min
=
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
max
);
}
if
(
a
.
has_upper_bound
()
&&
b
.
has_lower_bound
())
{
r
.
max
=
ComputeExpr
<
Sub
>
(
a
.
max
,
b
.
min
);
}
return
IntervalSet
::
make
(
r
);
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
Expr
min_value
=
a
->
HasLowerBound
()
&&
b
->
HasUpperBound
()
?
a
->
min_value
-
b
->
max_value
:
neg_inf
();
Expr
max_value
=
a
->
HasUpperBound
()
&&
b
->
HasLowerBound
()
?
a
->
max_value
-
b
->
min_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
template
<>
inline
IntSet
CombineInterval
<
Mul
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
));
}
if
(
a
.
is_single_point
()
&&
!
b
.
is_single_point
())
{
inline
IntervalSet
Combine
<
ir
::
Mul
>
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
*
b
->
min_value
);
}
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
a
->
IsSinglePoint
())
{
std
::
swap
(
a
,
b
);
}
if
(
b
.
is_single_p
oint
())
{
if
(
is_zero
(
b
.
min
))
return
IntSet
::
single_point
(
0
)
;
if
(
is_one
(
b
.
min
))
return
IntervalSet
::
make
(
a
)
;
Expr
e1
=
a
.
has_lower_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
)
:
a
.
min
;
Expr
e2
=
a
.
has_upper_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
max
,
b
.
min
)
:
a
.
max
;
// no relaxation is needed in here due to set is inclusive
// TODO(tqchen): consider convert to StrideSet.
if
(
is_positive_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e1
,
e2
);
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounde
d
())
{
if
(
b
->
IsSingleP
oint
())
{
if
(
is_zero
(
b
->
min_value
))
return
b
;
if
(
is_one
(
b
->
min_value
))
return
a
;
if
(
analyzer
->
CanProveGreaterEqual
(
b
->
min_value
,
0
))
{
Expr
min_value
=
a
->
HasLowerBound
()
?
a
->
min_value
*
b
->
min_value
:
neg_inf
()
;
Expr
max_value
=
a
->
HasUpperBound
()
?
a
->
max_value
*
b
->
min_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
else
if
(
analyzer
->
CanProveGreaterEqual
(
-
b
->
min_value
,
1
))
{
Expr
min_value
=
a
->
HasUpperBound
()
?
a
->
max_value
*
b
->
min_value
:
neg_inf
(
);
Expr
max_value
=
a
->
HasLowerBound
()
?
a
->
min_value
*
b
->
min_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
else
if
(
a
->
HasUpperBound
()
&&
a
->
HasLowerBoun
d
())
{
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
Expr
sign
=
b
->
min_value
>=
make_zero
(
b
->
min_value
.
type
().
element_of
());
Expr
e1
=
a
->
min_value
*
b
->
min_value
;
Expr
e2
=
a
->
max_value
*
b
->
min_value
;
return
IntervalSet
(
Select
::
make
(
sign
,
e1
,
e2
),
Select
::
make
(
sign
,
e2
,
e1
));
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
return
Int
Set
::
e
verything
();
D
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
return
Int
ervalSet
::
E
verything
();
}
template
<>
inline
IntSet
CombineInterval
<
Div
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Div
>
(
a
.
min
,
b
.
min
));
}
if
(
b
.
is_single_point
())
{
if
(
is_zero
(
b
.
min
))
{
inline
IntervalSet
Combine
<
ir
::
Div
>
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
/
b
->
min_value
);
}
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
b
->
IsSinglePoint
())
{
if
(
is_zero
(
b
->
min_value
))
{
LOG
(
FATAL
)
<<
"Divide by zero in CombineInterval Div"
;
}
if
(
is_one
(
b
.
min
))
return
IntervalSet
::
make
(
a
);
Expr
e1
=
a
.
has_lower_bound
()
?
ComputeExpr
<
Div
>
(
a
.
min
,
b
.
min
)
:
a
.
min
;
Expr
e2
=
a
.
has_upper_bound
()
?
ComputeExpr
<
Div
>
(
a
.
max
,
b
.
min
)
:
a
.
max
;
if
(
is_one
(
b
->
min_value
))
return
a
;
// no relaxation is needed in here due to set is inclusive
if
(
is_positive_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e1
,
e2
);
}
else
if
(
is_negative_const
(
b
.
min
))
{
return
IntervalSet
::
make
(
e2
,
e1
);
}
else
if
(
a
.
is_bounded
())
{
if
(
analyzer
->
CanProveGreaterEqual
(
b
->
min_value
,
0
))
{
Expr
min_value
=
a
->
HasLowerBound
()
?
a
->
min_value
/
b
->
min_value
:
neg_inf
();
Expr
max_value
=
a
->
HasUpperBound
()
?
a
->
max_value
/
b
->
min_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
else
if
(
analyzer
->
CanProveGreaterEqual
(
-
b
->
min_value
,
1
))
{
Expr
min_value
=
a
->
HasUpperBound
()
?
a
->
max_value
/
b
->
min_value
:
neg_inf
();
Expr
max_value
=
a
->
HasLowerBound
()
?
a
->
min_value
/
b
->
min_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
else
if
(
a
->
HasUpperBound
()
&&
a
->
HasLowerBound
())
{
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
Expr
sign
=
b
->
min_value
>=
make_zero
(
b
->
min_value
.
type
().
element_of
());
Expr
e1
=
a
->
min_value
/
b
->
min_value
;
Expr
e2
=
a
->
max_value
/
b
->
min_value
;
return
IntervalSet
(
Select
::
make
(
sign
,
e1
,
e2
),
Select
::
make
(
sign
,
e2
,
e1
));
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Div"
;
return
Int
Set
::
e
verything
();
D
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Div"
;
return
Int
ervalSet
::
E
verything
();
}
template
<>
inline
IntSet
CombineInterval
<
Mod
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Mod
>
(
a
.
min
,
b
.
min
));
inline
IntervalSet
Combine
<
ir
::
Mod
>
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
%
b
->
min_value
);
}
if
(
b
.
is_single_point
())
{
Expr
divisor
=
b
.
min
;
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
b
->
IsSinglePoint
())
{
const
Expr
&
divisor
=
b
->
min_value
;
if
(
is_zero
(
divisor
))
{
LOG
(
FATAL
)
<<
"Modular by zero in CombineInterval Mod"
;
}
return
IntervalSet
::
make
(
make_zero
(
divisor
.
type
()),
divisor
-
1
);
// We need to add more bound constraints throughout the code.
// The logic below assumes a is non-negative, which usually
// is the case of our application.
// TODO(tqchen): add bound constraints for a.
if
(
analyzer
->
CanProveGreaterEqual
(
divisor
,
0
))
{
return
IntervalSet
(
make_zero
(
divisor
.
type
()),
divisor
-
1
);
}
else
{
Expr
bound
=
abs
(
divisor
)
-
1
;
return
IntervalSet
(
-
bound
,
bound
);
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mod"
;
return
IntSet
::
everything
();
DLOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mod"
;
return
IntervalSet
::
Everything
();
}
template
<>
inline
IntSet
CombineInterval
<
Max
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Max
>
(
a
.
min
,
b
.
min
));
inline
IntervalSet
Combine
<
ir
::
Max
>
(
Analyzer
*
analzyer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
max
(
a
->
min_value
,
b
->
min_value
));
}
return
IntervalSet
::
make
(
Interval
::
make_max
(
a
.
min
,
b
.
min
),
Interval
::
make_max
(
a
.
max
,
b
.
max
));
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
return
IntervalSet
(
max
(
a
->
min_value
,
b
->
min_value
),
max
(
a
->
max_value
,
b
->
max_value
));
}
template
<>
inline
IntSet
CombineInterval
<
Min
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Min
>
(
a
.
min
,
b
.
min
));
inline
IntervalSet
Combine
<
ir
::
Min
>
(
Analyzer
*
analzyer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
min
(
a
->
min_value
,
b
->
min_value
));
}
return
IntervalSet
::
make
(
Interval
::
make_min
(
a
.
min
,
b
.
min
),
Interval
::
make_min
(
a
.
max
,
b
.
max
));
}
template
<
typename
OP
>
inline
IntSet
CombineInterval_
(
IntSet
a
,
IntSet
b
)
{
return
CombineInterval
<
OP
>
(
a
.
as
<
IntervalSet
>
()
->
i
,
b
.
as
<
IntervalSet
>
()
->
i
);
}
// stride related
inline
IntSet
AsStrideSet
(
IntSet
a
)
{
if
(
a
.
as
<
StrideSet
>
())
return
a
;
const
IntervalSet
*
s
=
a
.
as
<
IntervalSet
>
();
CHECK
(
s
->
i
.
is_bounded
());
NodePtr
<
StrideSet
>
n
=
make_node
<
StrideSet
>
();
n
->
base
=
s
->
i
;
return
IntSet
(
n
);
}
template
<
typename
OP
>
inline
IntSet
CombineSets
(
IntSet
a
,
IntSet
b
)
{
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
.
cover_interval
());
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
return
IntervalSet
(
min
(
a
->
min_value
,
b
->
min_value
),
min
(
a
->
max_value
,
b
->
max_value
));
}
template
<>
inline
IntSet
CombineSets
<
Add
>
(
IntSet
a
,
IntSet
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
if
(
a_int
&&
is_zero
(
a_int
->
i
.
min
))
return
b
;
if
(
b_int
&&
is_zero
(
b_int
->
i
.
min
))
return
a
;
a
=
AsStrideSet
(
a
);
b
=
AsStrideSet
(
b
);
const
StrideSet
*
a_stride
=
a
.
as
<
StrideSet
>
();
const
StrideSet
*
b_stride
=
b
.
as
<
StrideSet
>
();
auto
n
=
make_node
<
StrideSet
>
(
*
a_stride
);
for
(
size_t
i
=
0
;
i
<
b_stride
->
extents
.
size
();
++
i
)
{
n
->
extents
.
push_back
(
b_stride
->
extents
[
i
]);
n
->
strides
.
push_back
(
b_stride
->
strides
[
i
]);
}
n
->
base
=
CombineInterval
<
Add
>
(
a_stride
->
base
,
b_stride
->
base
).
as
<
IntervalSet
>
()
->
i
;
return
IntSet
(
n
);
}
inline
IntSet
NegateSet
(
IntSet
a
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
a_int
)
{
if
(
a_int
->
i
.
is_single_point
())
{
return
IntSet
::
single_point
(
-
a_int
->
i
.
min
);
}
else
{
Interval
r
=
Interval
::
everything
();
if
(
a_int
->
i
.
has_upper_bound
())
{
r
.
min
=
-
(
a_int
->
i
.
max
);
}
if
(
a_int
->
i
.
has_lower_bound
())
{
r
.
max
=
-
(
a_int
->
i
.
min
);
}
return
IntervalSet
::
make
(
r
);
}
}
else
{
return
NegateSet
(
a
.
cover_interval
());
// internal helper function to get an interval set
IntervalSet
ToIntervalSet
(
IntSet
set
)
{
if
(
auto
*
node
=
set
.
as
<
IntervalSetNode
>
())
{
return
GetRef
<
IntervalSet
>
(
node
);
}
DLOG
(
INFO
)
<<
"cannot resolve int set "
<<
set
;
return
IntervalSet
::
Everything
();
}
template
<>
inline
IntSet
CombineSets
<
Sub
>
(
IntSet
a
,
IntSet
b
)
{
return
CombineSets
<
Add
>
(
a
,
NegateSet
(
b
));
}
TVM_DECLARE_LOGICAL_OP
(
And
);
TVM_DECLARE_LOGICAL_OP
(
Or
);
TVM_DECLARE_LOGICAL_OP
(
EQ
);
TVM_DECLARE_LOGICAL_OP
(
NE
);
TVM_DECLARE_LOGICAL_OP
(
GE
);
TVM_DECLARE_LOGICAL_OP
(
GT
);
TVM_DECLARE_LOGICAL_OP
(
LE
);
TVM_DECLARE_LOGICAL_OP
(
LT
);
TVM_DECLARE_LOGICAL_OP
(
Not
);
using
namespace
ir
;
// generic combine operations of two sets
template
<
typename
OP
>
inline
IntSet
Combine
(
const
IntSet
&
a
,
const
IntSet
&
b
)
{
if
(
is_logical_op
<
OP
>::
value
)
{
return
IntervalSet
::
make
(
0
,
1
);
// Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset.
class
IntervalSetEvaluator
:
public
ExprFunctor
<
IntervalSet
(
const
Expr
&
)
>
{
public
:
IntervalSetEvaluator
(
Analyzer
*
analyzer
,
const
Map
<
Var
,
IntSet
>&
dom_map
,
bool
eval_vec
=
false
)
:
analyzer_
(
analyzer
),
dom_map_
(
dom_map
),
eval_vec_
(
eval_vec
)
{
}
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
if
(
a_int
&&
a_int
->
i
.
is_everything
())
return
a
;
if
(
b_int
&&
b_int
->
i
.
is_everything
())
return
b
;
if
(
a_int
&&
b_int
)
{
return
CombineInterval
<
OP
>
(
a_int
->
i
,
b_int
->
i
);
IntervalSet
Eval
(
const
Expr
&
val
)
{
return
this
->
VisitExpr
(
val
);
}
if
(
a_int
&&
!
(
a_int
->
i
.
is_bounded
()))
{
return
CombineInterval_
<
OP
>
(
a
,
b
.
cover_interval
());
IntervalSet
VisitExpr_
(
const
IntImm
*
op
)
final
{
return
IntervalSet
::
SinglePoint
(
GetRef
<
Expr
>
(
op
));
}
if
(
b_int
&&
!
(
b_int
->
i
.
is_bounded
()))
{
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
);
IntervalSet
VisitExpr_
(
const
UIntImm
*
op
)
final
{
return
IntervalSet
::
SinglePoint
(
GetRef
<
Expr
>
(
op
));
}
return
CombineSets
<
OP
>
(
a
,
b
);
}
class
IntSetEvaluator
:
public
ExprFunctor
<
IntSet
(
const
Expr
&
,
const
Expr
&
)
>
{
public
:
explicit
IntSetEvaluator
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
,
bool
eval_vec
=
false
)
:
dom_map_
(
dom_map
),
eval_vec_
(
eval_vec
)
{}
// Evaluate.
IntSet
Eval
(
const
Expr
&
e
)
{
return
this
->
VisitExpr
(
e
,
e
);
}
IntSet
VisitExpr_
(
const
IntImm
*
op
,
const
Expr
&
e
)
final
{
return
IntSet
::
single_point
(
e
);
}
IntSet
VisitExpr_
(
const
UIntImm
*
op
,
const
Expr
&
e
)
final
{
return
IntSet
::
single_point
(
e
);
}
IntSet
VisitExpr_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
auto
it
=
dom_map_
.
find
(
op
);
IntervalSet
VisitExpr_
(
const
Variable
*
op
)
final
{
Var
var
=
GetRef
<
Var
>
(
op
);
auto
it
=
dom_map_
.
find
(
var
);
if
(
it
!=
dom_map_
.
end
())
{
return
it
->
second
;
return
ToIntervalSet
((
*
it
).
second
)
;
}
else
{
return
Int
Set
::
single_point
(
e
);
return
Int
ervalSet
::
SinglePoint
(
var
);
}
}
IntSet
VisitExpr_
(
const
Add
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Add
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Sub
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Sub
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Mul
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Mul
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Div
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Div
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Mod
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Mod
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Min
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Min
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Max
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Max
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
EQ
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
EQ
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
NE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
NE
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
LT
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
LT
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
LE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
LE
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
GT
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
GT
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
GE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
GE
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
And
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
And
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Or
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Or
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
IntSet
VisitExpr_
(
const
Ramp
*
op
,
const
Expr
&
e
)
final
{
IntervalSet
VisitExpr_
(
const
Ramp
*
op
)
final
{
CHECK
(
eval_vec_
);
IntSet
base
=
Eval
(
op
->
base
);
int
v
stride
;
if
(
GetConstInt
(
op
->
stride
,
&
v
stride
))
{
Int
erval
Set
base
=
Eval
(
op
->
base
);
PVar
<
Integer
>
stride
;
if
(
stride
.
Match
(
op
->
stride
))
{
Type
t
=
op
->
base
.
type
();
if
(
vstride
>
0
)
{
int64_t
vstride
=
stride
.
Eval
()
->
value
;
if
(
vstride
>
0
)
{
return
Combine
<
Add
>
(
analyzer_
,
base
,
IntSet
::
interval
(
make_zero
(
t
),
make_const
(
t
,
vstride
*
op
->
lanes
-
1
)));
IntervalSet
(
make_zero
(
t
),
make_const
(
t
,
vstride
*
op
->
lanes
-
1
)));
}
else
{
return
Combine
<
Add
>
(
analyzer_
,
base
,
IntSet
::
interval
(
make_const
(
t
,
vstride
*
op
->
lanes
+
1
),
make_zero
(
t
)));
IntervalSet
(
make_const
(
t
,
vstride
*
op
->
lanes
+
1
),
make_zero
(
t
)));
}
}
LOG
(
WARNING
)
<<
"cannot evaluate set on expression "
<<
e
;
return
Int
Set
::
e
verything
();
DLOG
(
WARNING
)
<<
"cannot evaluate set on expression "
<<
GetRef
<
Expr
>
(
op
)
;
return
Int
ervalSet
::
E
verything
();
}
IntSet
VisitExpr_
(
const
Broadcast
*
op
,
const
Expr
&
e
)
final
{
IntervalSet
VisitExpr_
(
const
Broadcast
*
op
)
final
{
CHECK
(
eval_vec_
);
return
Eval
(
op
->
value
);
return
VisitExpr
(
op
->
value
);
}
IntSet
VisitExpr_
(
const
Select
*
op
,
const
Expr
&
e
)
final
{
IntSet
true_set
=
this
->
Eval
(
op
->
true_value
);
IntSet
false_set
=
this
->
Eval
(
op
->
false_value
);
return
Union
({
false_set
,
true_set
});
IntervalSet
VisitExpr_
(
const
Select
*
op
)
final
{
IntervalSet
true_set
=
this
->
Eval
(
op
->
true_value
);
IntervalSet
false_set
=
this
->
Eval
(
op
->
false_value
);
return
Union
(
analyzer_
,
false_set
,
true_set
);
}
IntSet
VisitExprDefault_
(
const
Node
*
op
,
const
Expr
&
e
)
final
{
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
e
->
type_key
();
return
IntSet
::
everything
();
IntervalSet
VisitExprDefault_
(
const
Node
*
op
)
final
{
DLOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
op
->
type_key
();
return
IntervalSet
::
Everything
();
}
private
:
// whether set is exactly single point that equals value.
bool
MatchPoint
(
const
IntervalSet
&
set
,
const
Expr
&
value
)
const
{
return
set
->
min_value
.
same_as
(
value
)
&&
set
->
max_value
.
same_as
(
value
);
}
template
<
typename
T
>
inline
Int
Set
Binary
(
const
T
*
op
,
const
Expr
&
e
)
{
IntSet
a
=
this
->
Eval
(
op
->
a
);
IntSet
b
=
this
->
Eval
(
op
->
b
);
inline
Int
ervalSet
VisitBinaryExpr_
(
const
T
*
op
)
{
Int
erval
Set
a
=
this
->
Eval
(
op
->
a
);
Int
erval
Set
b
=
this
->
Eval
(
op
->
b
);
if
(
MatchPoint
(
a
,
op
->
a
)
&&
MatchPoint
(
b
,
op
->
b
))
{
return
Int
Set
::
single_point
(
e
);
return
Int
ervalSet
::
SinglePoint
(
GetRef
<
Expr
>
(
op
)
);
}
return
Combine
<
T
>
(
a
,
b
);
return
Combine
<
T
>
(
a
nalyzer_
,
a
,
b
);
}
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map_
;
Analyzer
*
analyzer_
;
const
Map
<
Var
,
IntSet
>&
dom_map_
;
bool
eval_vec_
{
false
};
};
class
IntSetAnalyzer
::
Impl
{
public
:
explicit
Impl
(
Analyzer
*
analyzer
)
:
analyzer_
(
analyzer
)
{
}
IntSet
Eval
(
const
Expr
&
expr
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
const
{
return
IntervalSetEvaluator
(
analyzer_
,
dom_map
).
Eval
(
expr
);
}
private
:
Analyzer
*
analyzer_
;
};
IntSetAnalyzer
::
IntSetAnalyzer
(
Analyzer
*
parent
)
:
impl_
(
new
Impl
(
parent
))
{
}
IntSetAnalyzer
::~
IntSetAnalyzer
()
{
delete
impl_
;
}
IntSet
IntSetAnalyzer
::
operator
()(
const
Expr
&
expr
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
{
return
impl_
->
Eval
(
expr
,
dom_map
);
}
// Quickly adapt to IntSet interface
// TODO(tqchen): revisit IntSet interface as well.
Range
IntSet
::
cover_range
(
Range
max_range
)
const
{
IntSet
temp
;
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
!=
nullptr
);
if
(
s_int
->
HasUpperBound
()
&&
s_int
->
HasLowerBound
())
{
return
Range
::
make_by_min_extent
(
s_int
->
min_value
,
Simplify
(
s_int
->
max_value
+
1
-
s_int
->
min_value
));
}
return
max_range
;
}
Expr
IntSet
::
min
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
);
return
s_int
->
min_value
;
}
Expr
IntSet
::
max
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
);
return
s_int
->
max_value
;
}
bool
IntSet
::
is_nothing
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
s_int
->
IsEmpty
());
}
bool
IntSet
::
is_everything
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
s_int
->
IsEverything
());
}
bool
IntSet
::
is_single_point
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
s_int
->
IsSinglePoint
());
}
bool
IntSet
::
can_prove_positive
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
is_positive_const
(
ir
::
Simplify
(
s_int
->
min_value
)));
}
bool
IntSet
::
can_prove_negative
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
is_negative_const
(
ir
::
Simplify
(
s_int
->
max_value
)));
}
bool
IntSet
::
can_prove_non_positive
()
const
{
if
(
const
auto
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
())
{
auto
max
=
ir
::
Simplify
(
s_int
->
max_value
);
return
is_zero
(
max
)
||
is_negative_const
(
max
);
}
return
false
;
}
bool
IntSet
::
can_prove_non_negative
()
const
{
if
(
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
())
{
auto
min
=
ir
::
Simplify
(
s_int
->
min_value
);
return
is_zero
(
min
)
||
is_positive_const
(
min
);
}
return
false
;
}
SignType
IntSet
::
sign_type
()
const
{
if
(
can_prove_positive
())
{
return
kPositive
;
}
else
if
(
can_prove_negative
())
{
return
kNegative
;
}
else
if
(
is_single_point
()
&&
is_zero
(
point_value
()))
{
return
kZero
;
}
else
{
return
kUnknown
;
}
}
Expr
IntSet
::
point_value
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
&&
s_int
->
IsSinglePoint
());
return
s_int
->
min_value
;
}
IntSet
IntSet
::
nothing
()
{
return
IntervalSet
::
Empty
();
}
IntSet
IntSet
::
everything
()
{
return
IntervalSet
::
Everything
();
}
IntSet
IntSet
::
single_point
(
Expr
x
)
{
return
IntervalSet
::
SinglePoint
(
x
);
}
IntSet
IntSet
::
interval
(
Expr
min
,
Expr
max
)
{
if
(
min
.
same_as
(
max
))
{
return
IntSet
::
single_point
(
min
);
}
return
IntervalSet
(
min
,
max
);
}
// Range related code
inline
bool
ProveEqual
(
Expr
lhs
,
Expr
rhs
)
{
return
is_zero
(
ir
::
Simplify
(
lhs
-
rhs
));
}
IntSet
IntSet
::
range
(
Range
r
)
{
// must make sure it can be matched back by MatchRange.
if
(
is_one
(
r
->
extent
))
{
return
IntSet
::
single_point
(
r
->
min
);
}
return
IntervalSet
(
r
->
min
,
r
->
extent
+
r
->
min
-
1
);
}
bool
IntSet
::
match_range
(
const
Range
&
b
)
const
{
const
IntSet
&
a
=
*
this
;
const
IntervalSetNode
*
a_int
=
a
.
as
<
IntervalSetNode
>
();
if
(
!
a_int
)
return
false
;
return
ProveEqual
(
a_int
->
min_value
,
b
->
min
)
&&
ProveEqual
(
a_int
->
max_value
,
b
->
extent
+
b
->
min
-
1
);
}
IntSet
Union
(
const
Array
<
IntSet
>&
sets
)
{
if
(
sets
.
size
()
==
0
)
return
IntSet
::
nothing
();
if
(
sets
.
size
()
==
1
)
return
sets
[
0
];
Analyzer
ana
;
IntervalSet
x
=
ToIntervalSet
(
sets
[
0
]);
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
x
=
Union
(
&
ana
,
x
,
ToIntervalSet
(
sets
[
i
]));
}
return
IntervalSet
(
ir
::
Simplify
(
x
->
min_value
),
ir
::
Simplify
(
x
->
max_value
));
}
IntSet
Intersect
(
const
Array
<
IntSet
>&
sets
)
{
if
(
sets
.
size
()
==
0
)
return
IntSet
::
nothing
();
if
(
sets
.
size
()
==
1
)
return
sets
[
0
];
Analyzer
ana
;
IntervalSet
x
=
ToIntervalSet
(
sets
[
0
]);
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
x
=
Intersect
(
&
ana
,
x
,
ToIntervalSet
(
sets
[
i
]));
}
return
IntervalSet
(
ir
::
Simplify
(
x
->
min_value
),
ir
::
Simplify
(
x
->
max_value
));
}
Map
<
Var
,
IntSet
>
ConvertDomMap
(
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
Map
<
Var
,
IntSet
>
dmap
;
for
(
auto
kv
:
dom_map
)
{
dmap
.
Set
(
kv
.
first
->
var
,
kv
.
second
);
}
return
dmap
;
}
Map
<
Var
,
IntSet
>
ConvertDomMap
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
Map
<
Var
,
IntSet
>
dmap
;
for
(
auto
kv
:
dom_map
)
{
dmap
.
Set
(
GetRef
<
Var
>
(
kv
.
first
),
kv
.
second
);
}
return
dmap
;
}
IntSet
EvalSet
(
Expr
e
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
return
IntSetEvaluator
(
dom_map
,
false
).
Eval
(
e
);
const
Map
<
Var
,
IntSet
>&
dom_map
)
{
Analyzer
ana
;
return
IntervalSetEvaluator
(
&
ana
,
dom_map
,
false
).
Eval
(
e
);
}
IntSet
IntSet
::
vector
(
Expr
x
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dmap
;
return
IntSetEvaluator
(
dmap
,
true
).
Eval
(
x
);
Analyzer
ana
;
Map
<
Var
,
IntSet
>
dmap
;
return
IntervalSetEvaluator
(
&
ana
,
dmap
,
true
).
Eval
(
x
);
}
IntSet
EvalSet
(
Expr
e
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dmap
;
for
(
auto
kv
:
dom_map
)
{
dmap
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
}
return
EvalSet
(
e
,
dmap
);
return
EvalSet
(
e
,
ConvertDomMap
(
dom_map
));
}
IntSet
EvalSet
(
Range
r
,
IntSet
EvalSet
(
Expr
e
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
IntSetEvaluator
m
(
dom_map
);
IntSet
min_set
=
m
.
Eval
(
r
->
min
).
cover_interval
();
return
EvalSet
(
e
,
ConvertDomMap
(
dom_map
));
}
IntSet
EvalSet
(
Range
r
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
{
Analyzer
ana
;
IntervalSetEvaluator
m
(
&
ana
,
dom_map
);
IntervalSet
min_set
=
m
.
Eval
(
r
->
min
);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
Expr
sum
=
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
r
->
min
,
r
->
extent
),
1
);
IntSet
max_set
=
m
.
Eval
(
Simplify
(
sum
)).
cover_interval
();
const
Interval
&
ni
=
min_set
.
as
<
IntervalSet
>
()
->
i
;
const
Interval
&
xi
=
max_set
.
as
<
IntervalSet
>
()
->
i
;
if
(
!
ni
.
has_lower_bound
())
return
IntSet
::
everything
();
if
(
!
xi
.
has_upper_bound
())
return
IntSet
::
everything
();
return
IntervalSet
::
make
(
ni
.
min
,
xi
.
max
);
Expr
sum
=
r
->
min
+
r
->
extent
-
1
;
IntervalSet
max_set
=
m
.
Eval
(
Simplify
(
sum
));
if
(
!
min_set
->
HasLowerBound
())
return
IntSet
::
everything
();
if
(
!
max_set
->
HasUpperBound
())
return
IntSet
::
everything
();
return
IntervalSet
(
min_set
->
min_value
,
max_set
->
max_value
);
}
IntSet
EvalSet
(
IntSet
s
,
IntSet
EvalSet
(
Range
r
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
IntSetEvaluator
m
(
dom_map
);
s
=
s
.
cover_interval
();
const
IntervalSet
*
s_int
=
s
.
as
<
IntervalSet
>
();
Expr
vmax
=
s_int
->
i
.
has_upper_bound
()
?
m
.
Eval
(
s_int
->
i
.
max
).
cover_interval
().
max
()
:
s_int
->
i
.
max
;
Expr
vmin
=
s_int
->
i
.
has_lower_bound
()
?
m
.
Eval
(
s_int
->
i
.
min
).
cover_interval
().
min
()
:
s_int
->
i
.
min
;
return
IntervalSet
::
make
(
vmin
,
vmax
);
return
EvalSet
(
r
,
ConvertDomMap
(
dom_map
));
}
class
SubExprIntSetEvaluator
:
public
IntSetEvaluator
{
IntSet
EvalSet
(
IntSet
s
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
Analyzer
ana
;
auto
dmap
=
ConvertDomMap
(
dom_map
);
IntervalSetEvaluator
m
(
&
ana
,
dmap
);
const
IntervalSetNode
*
s_int
=
s
.
as
<
IntervalSetNode
>
();
Expr
vmax
=
s_int
->
HasUpperBound
()
?
m
.
Eval
(
s_int
->
max_value
).
max
()
:
s_int
->
max_value
;
Expr
vmin
=
s_int
->
HasLowerBound
()
?
m
.
Eval
(
s_int
->
min_value
).
min
()
:
s_int
->
min_value
;
return
IntervalSet
(
vmin
,
vmax
);
}
class
SubExprIntervalSetEvaluator
:
public
IntervalSetEvaluator
{
public
:
explicit
SubExprIntSetEvaluator
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
:
IntSetEvaluator
(
dom_map
)
{}
explicit
SubExprIntervalSetEvaluator
(
Analyzer
*
analyzer
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
:
IntervalSetEvaluator
(
analyzer
,
dom_map
)
{}
Int
Set
VisitExpr
(
const
Expr
&
n
,
const
Expr
&
e
)
final
{
Int
Set
ret
=
IntSetEvaluator
::
VisitExpr
(
n
,
e
);
Int
ervalSet
VisitExpr
(
const
Expr
&
n
)
final
{
Int
ervalSet
ret
=
IntervalSetEvaluator
::
VisitExpr
(
n
);
expr_map
[
n
]
=
ret
;
return
ret
;
}
...
...
@@ -635,28 +705,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator {
ExprIntSetMap
expr_map
;
};
ExprIntSetMap
EvalSetForEachSubExpr
(
Expr
e
,
ExprIntSetMap
EvalSetForEachSubExpr
(
Expr
e
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
SubExprIntSetEvaluator
m
(
dom_map
);
Analyzer
ana
;
auto
dmap
=
ConvertDomMap
(
dom_map
);
SubExprIntervalSetEvaluator
m
(
&
ana
,
dmap
);
m
.
Eval
(
e
);
return
m
.
expr_map
;
}
IntSet
EvalSet
(
Range
r
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dmap
;
for
(
auto
kv
:
dom_map
)
{
dmap
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
}
return
EvalSet
(
r
,
dmap
);
return
EvalSet
(
r
,
ConvertDomMap
(
dom_map
));
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IntervalSet
>
([](
const
IntervalSet
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"
interval-s
et"
<<
"["
<<
op
->
i
.
min
<<
", "
<<
op
->
i
.
max
<<
']'
;
.
set_dispatch
<
IntervalSet
Node
>
([](
const
IntervalSetNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"
IntervalS
et"
<<
"["
<<
op
->
min_value
<<
", "
<<
op
->
max_value
<<
']'
;
});
}
// namespace arith
}
// namespace tvm
src/arithmetic/int_set.h
0 → 100644
View file @
153417a5
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_set.h
* \brief Internal data structure for integer set.
*/
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <limits>
#include "const_fold.h"
namespace
tvm
{
namespace
arith
{
/*!
* \brief Symbolic interval set.
*
* \note We intentionally keep the internal of IntSet private,
as we might change it later.
*/
class
IntervalSetNode
:
public
IntSetNode
{
public
:
/*! \brief Minimum value in the interval. */
Expr
min_value
;
/*! \brief Maximum value in the interval. */
Expr
max_value
;
// visitor overload.
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"min_value"
,
&
min_value
);
v
->
Visit
(
"max_value"
,
&
max_value
);
}
/*! \return Whether the interval has upper bound. */
bool
HasUpperBound
()
const
{
return
!
is_pos_inf
(
max_value
)
&&
!
IsEmpty
();
}
/*! \return Whether the interval has lower bound. */
bool
HasLowerBound
()
const
{
return
!
is_neg_inf
(
min_value
)
&&
!
IsEmpty
();
}
/*! \return Whether the interval is a single point. */
bool
IsSinglePoint
()
const
{
return
min_value
.
same_as
(
max_value
);
}
/*! \return whether interval represent nothing */
bool
IsEmpty
()
const
{
// during computations, either extreme could occur.
return
is_pos_inf
(
min_value
)
||
is_neg_inf
(
max_value
);
}
/*! \return whether interval represent everything */
bool
IsEverything
()
const
{
return
is_neg_inf
(
min_value
)
&&
is_pos_inf
(
max_value
);
}
static
constexpr
const
char
*
_type_key
=
"arith.IntervalSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntervalSetNode
,
IntSetNode
);
};
/*!
* \brief Interval set used for symbolic integer analysis.
* \sa IntervalSetNode
*/
class
IntervalSet
:
public
IntSet
{
public
:
/*!
* \brief Make a new instance of interval set.
* \param min_value The minimum value in the interval.
* \param max_value The maximum value in the interval.
* \return The created set.
*/
TVM_DLL
IntervalSet
(
Expr
min_value
,
Expr
max_value
);
/*!
* \brief Create an IntervalSet that represents a single point.
* \param value The value to be represented.
* \return The result set.
*/
static
IntervalSet
SinglePoint
(
Expr
value
)
{
return
IntervalSet
(
value
,
value
);
}
/*!
* \brief Create an IntervalSet that represents everything.
* \param value The value to be represented.
* \return The result set.
*/
static
IntervalSet
Everything
()
{
return
IntervalSet
(
neg_inf
(),
pos_inf
());
}
/*!
* \brief Create an empty eet.
* \return The result set.
*/
static
IntervalSet
Empty
()
{
return
IntervalSet
(
pos_inf
(),
neg_inf
());
}
TVM_DEFINE_NODE_REF_COW
(
IntervalSetNode
);
TVM_DEFINE_NODE_REF_METHODS
(
IntervalSet
,
IntSet
,
IntervalSetNode
);
};
/*!
* \brief Create union of two IntervalSets.
* \param analyzer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL
IntervalSet
Union
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
);
/*!
* \brief Create insersection of two IntervalSets.
* \param analzyer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL
IntervalSet
Intersect
(
Analyzer
*
analzyer
,
IntervalSet
a
,
IntervalSet
b
);
}
// namespace arith
}
// namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_H_
src/arithmetic/int_set_internal.h
deleted
100644 → 0
View file @
9bb16872
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file int_set_internal.h
* \brief Implementations of integer set
*/
#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
namespace
tvm
{
namespace
arith
{
using
HalideIR
::
Internal
::
Interval
;
/*! \brief Set of continuous interval */
struct
IntervalSet
:
public
IntSetNode
{
/*! \brief the internal interval*/
Interval
i
;
static
IntSet
make
(
Interval
i
)
{
NodePtr
<
IntervalSet
>
n
=
make_node
<
IntervalSet
>
();
n
->
i
=
i
;
return
IntSet
(
n
);
}
static
IntSet
make
(
Expr
min
,
Expr
max
)
{
NodePtr
<
IntervalSet
>
n
=
make_node
<
IntervalSet
>
();
n
->
i
.
min
=
min
;
n
->
i
.
max
=
max
;
return
IntSet
(
n
);
}
static
constexpr
const
char
*
_type_key
=
"IntervalSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntervalSet
,
IntSetNode
);
};
/*!
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
struct
StrideSet
:
public
IntSetNode
{
/*! \brief the base inetrval */
Interval
base
;
/*! \brief additional extents in positive number */
Array
<
Expr
>
extents
;
/*! \brief additional strides in positive number */
Array
<
Expr
>
strides
;
static
constexpr
const
char
*
_type_key
=
"StrideSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
StrideSet
,
IntSetNode
);
};
}
// namespace arith
}
// namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_
src/lang/expr_operator.cc
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) {
return
ir
::
Mod
::
make
(
a
,
b
);
}
Expr
min
(
Expr
a
,
Expr
b
)
{
// inf-aware simplificaiton
using
arith
::
is_pos_inf
;
using
arith
::
is_neg_inf
;
if
(
is_pos_inf
(
a
))
return
b
;
if
(
is_neg_inf
(
a
))
return
a
;
if
(
is_pos_inf
(
b
))
return
a
;
if
(
is_neg_inf
(
b
))
return
b
;
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Min
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
...
...
@@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) {
}
Expr
max
(
Expr
a
,
Expr
b
)
{
// inf-aware simplificaiton
using
arith
::
is_pos_inf
;
using
arith
::
is_neg_inf
;
if
(
is_pos_inf
(
a
))
return
a
;
if
(
is_neg_inf
(
a
))
return
b
;
if
(
is_pos_inf
(
b
))
return
b
;
if
(
is_neg_inf
(
b
))
return
a
;
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Max
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
...
...
src/pass/loop_partition.cc
View file @
153417a5
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -28,7 +28,7 @@
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
#include "../arithmetic/int_set
_internal
.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
namespace
tvm
{
...
...
@@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator {
std
::
pair
<
IntSet
,
std
::
unordered_set
<
const
Node
*>>
GetIntervalAndCondset
(
const
Partition
&
partitions
,
const
arith
::
Interval
&
for_interval
,
const
arith
::
Interval
Set
&
for_interval
,
bool
cond_value
);
inline
Stmt
MakeFor
(
const
Node
*
op
,
Expr
extent
,
Stmt
body
);
...
...
@@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator {
/* Candidate IRs that may be partitioned potentially */
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
hint_map_
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
relax_map_
;
arith
::
Analyzer
analyzer_
;
CandidateSelector
selector
;
};
...
...
@@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator {
// given in the second component provably have value given by cond_value
std
::
pair
<
IntSet
,
std
::
unordered_set
<
const
Node
*>>
LoopPartitioner
::
GetIntervalAndCondset
(
const
Partition
&
partitions
,
const
arith
::
Interval
&
for_interval
,
const
arith
::
Interval
Set
&
for_interval
,
bool
cond_value
)
{
Array
<
IntSet
>
sets
;
std
::
unordered_set
<
const
Node
*>
cond_set
;
for
(
const
auto
&
kv
:
partitions
)
{
if
(
kv
.
first
.
second
==
cond_value
)
{
arith
::
Interval
interval
=
kv
.
second
.
as
<
arith
::
IntervalSet
>
()
->
i
;
arith
::
Interval
intersection
=
arith
::
Interval
::
make_intersection
(
interval
,
for_interval
);
if
(
!
intersection
.
is_empty
())
{
arith
::
IntervalSet
interval
=
Downcast
<
arith
::
IntervalSet
>
(
kv
.
second
);
arith
::
IntervalSet
intersection
=
arith
::
Intersect
(
&
analyzer_
,
interval
,
for_interval
);
if
(
!
intersection
->
IsEmpty
())
{
sets
.
push_back
(
kv
.
second
);
cond_set
.
insert
(
kv
.
first
.
first
);
}
...
...
@@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
max
,
Stmt
body
,
bool
partition_thread_scope
)
{
using
namespace
arith
;
PartitionFinder
finder
(
var
,
hint_map_
,
relax_map_
);
finder
.
Visit
(
body
);
if
(
finder
.
partitions
.
empty
())
return
Stmt
();
arith
::
Interval
for_interval
(
min
,
max
);
arith
::
Interval
Set
for_interval
(
min
,
max
);
bool
cond_value
;
IntSet
middle_interval
;
std
::
unordered_set
<
const
Node
*>
cond_set
;
...
...
@@ -478,7 +481,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
// if such interval doesn't exist, find an interval in which all
// conditions on var are false
std
::
tie
(
middle_interval
,
cond_set
)
=
GetIntervalAndCondset
(
finder
.
partitions
,
for_interval
,
false
);
GetIntervalAndCondset
(
finder
.
partitions
,
for_interval
,
false
);
if
(
middle_interval
.
is_nothing
())
// we couldn't find an interval in which the condintions are provably true or false
// Therefore, we can't partition the loop based on those conds
...
...
@@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
cond_value
=
true
;
}
arith
::
Interval
middle_interval_i
=
middle_interval
.
as
<
arith
::
IntervalSet
>
()
->
i
;
IntervalSet
middle_interval_i
=
Downcast
<
IntervalSet
>
(
middle_interval
)
;
// middle_interval is the subrange of the loop variable range for which a
// set of conditions are true (or false resp.)
// The part of the loop variable range that is before (after resp.) that
...
...
@@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
body_begin
;
Stmt
pre_stmt
;
bool
pre_stmt_recurse
=
true
;
if
(
middle_interval_i
.
has_lower_b
ound
())
{
if
(
middle_interval_i
->
HasLowerB
ound
())
{
body_begin
=
ir
::
Simplify
(
middle_interval
.
min
());
if
(
!
can_prove
(
body_begin
==
min
))
{
Expr
cond
=
(
body_begin
-
min
>=
0
);
...
...
@@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
post_doubt_begin
;
Stmt
post_stmt
;
bool
post_stmt_recurse
=
true
;
if
(
middle_interval_i
.
has_upper_b
ound
())
{
if
(
middle_interval_i
->
HasUpperB
ound
())
{
post_doubt_begin
=
ir
::
Simplify
(
middle_interval
.
max
()
+
1
);
if
(
!
can_prove
(
middle_interval
.
max
()
==
max
))
{
// require the extent to be non-negative
...
...
tests/python/unittest/test_arith_deduce_bound.py
0 → 100644
View file @
153417a5
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
def
test_deduce
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
c
=
tvm
.
var
(
'c'
)
d
=
tvm
.
var
(
'd'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
2
,
3
)
c_s
=
tvm
.
arith
.
IntervalSet
(
10
,
15
)
d_s
=
tvm
.
arith
.
IntervalSet
(
-
3
,
-
1
)
zero
=
tvm
.
const
(
0
,
"int32"
)
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
d
-
c
)
/
(
b
*-
1
))
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
e0
=
d
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
0
-
c
)
/
d
+
1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
# expression containing variable a is on rhs
e1
=
(
c
>
a
*
4
+
b
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
res2
.
max_value
)
==
"neg_inf"
assert
str
(
res2
.
min_value
)
==
"pos_inf"
# expression containing variable a is on rhs
e2
=
(
zero
<
tvm
.
max
(
5
,
a
*
4
))
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
res2
.
max_value
)
==
"neg_inf"
assert
str
(
res2
.
min_value
)
==
"pos_inf"
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
def
test_check
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
c
=
tvm
.
var
(
'c'
)
d
=
tvm
.
var
(
'd'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
2
,
3
)
c_s
=
tvm
.
arith
.
IntervalSet
(
5
,
7
)
d_s
=
tvm
.
arith
.
IntervalSet
(
-
3
,
-
1
)
# no compare operator
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
a
+
b
,
{
b
:
b_s
},
{})
assert
res1
.
is_nothing
()
# multiple compare operators
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
(
a
+
b
>
3
)
.
astype
(
c
.
dtype
)
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
assert
res2
.
is_nothing
()
# multiple target variable
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
a
*
2
-
a
>
b
,
{
b
:
b_s
},
{})
assert
res2
.
is_nothing
()
def
test_deduce_basic
():
def
test_basic
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
a1
,
a2
)
e0
=
b
+
a
*
coff
+
3
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
<
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<=
17
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>=
17
))
.
value
==
1
test_basic
(
0
,
4
,
4
)
test_basic
(
1
,
5
,
4
)
test_basic
(
2
,
6
,
4
)
test_basic
(
0
,
4
,
-
4
)
test_basic
(
1
,
5
,
-
4
)
test_basic
(
2
,
6
,
-
4
)
def
test_deduce_complex
():
def
test_complex
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
a1
,
a2
)
e0
=
(
b
*
3
+
a
*
coff
)
*
4
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<=
63
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
<=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>=
63
))
.
value
==
1
test_complex
(
0
,
4
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
2
,
6
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
1
,
5
,
-
4
)
test_complex
(
2
,
6
,
-
4
)
if
__name__
==
"__main__"
:
test_check
()
test_deduce_basic
()
test_deduce_complex
()
tests/python/unittest/test_arith_intset.py
View file @
153417a5
...
...
@@ -16,168 +16,87 @@
# under the License.
import
tvm
class
IntSetChecker
:
def
__init__
(
self
):
self
.
analyzer
=
tvm
.
arith
.
Analyzer
()
def
verify
(
self
,
data
,
dmap
,
expected
):
res
=
self
.
analyzer
.
int_set
(
data
,
dmap
)
def
err_msg
():
return
"
\n
data={}
\n
dmap={}
\n
res={}
\n
expected={}"
.
format
(
data
,
dmap
,
res
,
expected
)
def
equal
(
x
,
y
):
res
=
self
.
analyzer
.
canonical_simplify
(
x
-
y
)
return
tvm
.
ir_pass
.
Equal
(
res
,
0
)
assert
equal
(
res
.
min_value
,
expected
[
0
]),
err_msg
()
assert
equal
(
res
.
max_value
,
expected
[
1
]),
err_msg
()
def
test_basic
():
s
=
tvm
.
arith
.
intset_interval
(
2
,
3
)
assert
s
.
min
()
.
value
==
2
assert
s
.
max
()
.
value
==
3
s
=
tvm
.
arith
.
IntervalSet
(
2
,
3
)
assert
s
.
min_value
.
value
==
2
assert
s
.
max_value
.
value
==
3
def
test_vector
():
base
=
10
stride
=
3
lanes
=
2
s
=
tvm
.
arith
.
intset_vector
(
tvm
.
make
.
Ramp
(
base
,
stride
,
lanes
))
assert
s
.
min
()
.
value
==
base
assert
s
.
max
()
.
value
==
base
+
stride
*
lanes
-
1
def
test_deduce
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
c
=
tvm
.
var
(
'c'
)
d
=
tvm
.
var
(
'd'
)
b_s
=
tvm
.
arith
.
intset_interval
(
2
,
3
)
c_s
=
tvm
.
arith
.
intset_interval
(
10
,
15
)
d_s
=
tvm
.
arith
.
intset_interval
(
-
3
,
-
1
)
zero
=
tvm
.
const
(
0
,
"int32"
)
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
d
-
c
)
/
(
b
*-
1
))
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
e0
=
d
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
0
-
c
)
/
d
+
1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max
()))
==
str
(
ans1
)
# expression containing variable a is on rhs
e1
=
(
c
>
a
*
4
+
b
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max
()))
==
str
(
ans1
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
res2
.
max
())
==
"neg_inf"
assert
str
(
res2
.
min
())
==
"pos_inf"
# expression containing variable a is on rhs
e2
=
(
zero
<
tvm
.
max
(
5
,
a
*
4
))
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
res2
.
max
())
==
"neg_inf"
assert
str
(
res2
.
min
())
==
"pos_inf"
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min
()))
==
str
(
ans3
)
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min
()))
==
str
(
ans3
)
def
test_check
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
c
=
tvm
.
var
(
'c'
)
d
=
tvm
.
var
(
'd'
)
b_s
=
tvm
.
arith
.
intset_interval
(
2
,
3
)
c_s
=
tvm
.
arith
.
intset_interval
(
5
,
7
)
d_s
=
tvm
.
arith
.
intset_interval
(
-
3
,
-
1
)
# no compare operator
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
a
+
b
,
{
b
:
b_s
},
{})
assert
res1
.
is_nothing
()
# multiple compare operators
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
(
a
+
b
>
3
)
.
astype
(
c
.
dtype
)
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
assert
res2
.
is_nothing
()
# multiple target variable
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
a
*
2
-
a
>
b
,
{
b
:
b_s
},
{})
assert
res2
.
is_nothing
()
def
test_deduce_basic
():
def
test_basic
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
intset_interval
(
a1
,
a2
)
e0
=
b
+
a
*
coff
+
3
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
<
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<=
17
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>=
17
))
.
value
==
1
test_basic
(
0
,
4
,
4
)
test_basic
(
1
,
5
,
4
)
test_basic
(
2
,
6
,
4
)
test_basic
(
0
,
4
,
-
4
)
test_basic
(
1
,
5
,
-
4
)
test_basic
(
2
,
6
,
-
4
)
def
test_deduce_complex
():
def
test_complex
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
intset_interval
(
a1
,
a2
)
e0
=
(
b
*
3
+
a
*
coff
)
*
4
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<=
63
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
<=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>=
63
))
.
value
==
1
test_complex
(
0
,
4
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
2
,
6
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
1
,
5
,
-
4
)
test_complex
(
2
,
6
,
-
4
)
assert
s
.
min_value
.
value
==
base
assert
s
.
max_value
.
value
==
base
+
stride
*
lanes
-
1
def
test_add_sub
():
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
ck
.
verify
(
x
+
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
y
,
10
+
y
))
ck
.
verify
(
x
+
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
),
y
:
tvm
.
arith
.
IntervalSet
(
1
,
11
)},
(
1
,
21
))
ck
.
verify
(
x
-
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
),
y
:
tvm
.
arith
.
IntervalSet
(
1
,
11
)},
(
-
11
,
9
))
def
test_mul_div
():
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
verify
(
x
*
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
*
y
))
ck
.
verify
(
x
*
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
2
,
20
))
ck
.
verify
(
x
*
-
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
-
20
,
-
2
))
ck
.
verify
(
x
/
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
/
y
))
ck
.
verify
(
x
/
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
5
))
def
test_mod
():
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
ck
.
verify
(
x
%
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
y
-
1
))
ck
.
verify
(
x
%
10
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
9
))
def
test_max_min
():
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
ck
.
verify
(
tvm
.
max
(
x
,
x
+
1
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
1
,
11
))
ck
.
verify
(
tvm
.
min
(
x
-
1
,
x
+
1
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
-
1
,
9
))
ck
.
verify
(
tvm
.
min
(
x
,
y
),
{},
(
tvm
.
min
(
x
,
y
),
tvm
.
min
(
x
,
y
)))
ck
.
verify
(
tvm
.
max
(
x
,
y
),
{},
(
tvm
.
max
(
x
,
y
),
tvm
.
max
(
x
,
y
)))
def
test_select
():
ck
=
IntSetChecker
()
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
x
-
1
,
x
+
1
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
-
1
,
11
))
if
__name__
==
"__main__"
:
test_basic
()
test_vector
()
test_deduce
()
test_check
()
test_deduce_basic
()
test_deduce_complex
()
test_add_sub
()
test_mul_div
()
test_max_min
()
test_select
()
test_mod
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment