Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
153417a5
Unverified
Commit
153417a5
authored
Jun 13, 2019
by
Tianqi Chen
Committed by
GitHub
Jun 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Revamp IntSet (#3272)
parent
9bb16872
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1165 additions
and
823 deletions
+1165
-823
include/tvm/arithmetic.h
+107
-83
python/tvm/arith.py
+31
-12
src/api/api_arith.cc
+5
-0
src/arithmetic/analyzer.cc
+3
-2
src/arithmetic/bound_deducer.cc
+2
-2
src/arithmetic/canonical_simplify.cc
+4
-2
src/arithmetic/compute_expr.h
+3
-3
src/arithmetic/const_fold.h
+54
-0
src/arithmetic/detect_linear_equation.cc
+2
-2
src/arithmetic/int_op_overflow.h
+0
-0
src/arithmetic/int_set.cc
+542
-474
src/arithmetic/int_set.h
+143
-0
src/arithmetic/int_set_internal.h
+0
-79
src/lang/expr_operator.cc
+15
-0
src/pass/loop_partition.cc
+13
-10
tests/python/unittest/test_arith_deduce_bound.py
+168
-0
tests/python/unittest/test_arith_intset.py
+73
-154
No files found.
include/tvm/arithmetic.h
View file @
153417a5
...
@@ -328,71 +328,14 @@ class ConstraintContext {
...
@@ -328,71 +328,14 @@ class ConstraintContext {
std
::
function
<
void
()
>
exit_
;
std
::
function
<
void
()
>
exit_
;
};
};
/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class
Analyzer
{
public
:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer
const_int_bound
;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer
modular_set
;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier
rewrite_simplify
;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier
canonical_simplify
;
/*! \brief constructor */
Analyzer
();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void
Bind
(
const
VarExpr
&
var
,
const
Expr
&
expr
);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void
Bind
(
const
VarExpr
&
var
,
const
Range
&
range
);
/*!
* \brief Whether can we proof expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
);
};
//-----------------------------------------------
//-----------------------------------------------
// Integer set
abstraction API
.
// Integer set
data structure
.
//
//
// This is a API build on top of the base
// This is a API build on top of the base
// integer analysis API to provide set analysis.
// integer analysis API to provide set analysis.
//------------------------------------------------
//------------------------------------------------
/*!
/*!
* \brief Sign
of an expression or set
.
* \brief Sign
type of an integer expression
.
*/
*/
enum
SignType
{
enum
SignType
{
kPositive
,
kPositive
,
...
@@ -401,8 +344,13 @@ enum SignType {
...
@@ -401,8 +344,13 @@ enum SignType {
kUnknown
kUnknown
};
};
// internal node container of int set.
/*!
struct
IntSetNode
;
* \brief Base class of all IntSet containers.
*/
struct
IntSetNode
:
public
Node
{
static
constexpr
const
char
*
_type_key
=
"IntSet"
;
TVM_DECLARE_BASE_NODE_INFO
(
IntSetNode
,
Node
);
};
/*!
/*!
* \brief Integer set class, represent a set of integers in one dimension.
* \brief Integer set class, represent a set of integers in one dimension.
...
@@ -424,11 +372,6 @@ class IntSet : public NodeRef {
...
@@ -424,11 +372,6 @@ class IntSet : public NodeRef {
* \return The covering range.
* \return The covering range.
*/
*/
Range
cover_range
(
Range
max_range
)
const
;
Range
cover_range
(
Range
max_range
)
const
;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
IntSet
cover_interval
()
const
;
/*! \return Lower bound of the set */
/*! \return Lower bound of the set */
Expr
min
()
const
;
Expr
min
()
const
;
/*! \return upper bound of the set */
/*! \return upper bound of the set */
...
@@ -493,33 +436,91 @@ class IntSet : public NodeRef {
...
@@ -493,33 +436,91 @@ class IntSet : public NodeRef {
};
};
/*!
/*!
* \brief
Base class of all IntSet containers
.
* \brief
Integer set analyzer
.
*/
*/
struct
IntSetNode
:
public
Node
{
class
IntSetAnalyzer
{
static
constexpr
const
char
*
_type_key
=
"IntSet"
;
public
:
TVM_DECLARE_BASE_NODE_INFO
(
IntSetNode
,
Node
);
/*!
* \brief Find a symbolic integer set that contains all possible values of
* expr given the domain of each variables.
*
* \param expr The expression of interest.
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet
operator
()(
const
Expr
&
expr
,
const
Map
<
Var
,
IntSet
>&
dom_map
);
private
:
friend
class
Analyzer
;
explicit
IntSetAnalyzer
(
Analyzer
*
parent
);
~
IntSetAnalyzer
();
class
Impl
;
/*! \brief Internal impl */
Impl
*
impl_
;
};
};
/*!
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* \brief Analyzer that contains bunch of sub-analyzers.
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
*
* \param e The expression to be detected.
* Each sub-analyzer can make use of another sub-analyzer
* \param vars List of variables to be used in detection.
* by weak reference of this.
* \return [coeff[i]] if it is possible, empty array if it is not.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
class
Analyzer
{
public
:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer
const_int_bound
;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer
modular_set
;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier
rewrite_simplify
;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier
canonical_simplify
;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer
int_set
;
/*! \brief constructor */
Analyzer
();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void
Bind
(
const
VarExpr
&
var
,
const
Expr
&
expr
);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
*/
Array
<
Expr
>
DetectLinearEquation
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
void
Bind
(
const
VarExpr
&
var
,
const
Range
&
range
);
/*!
* \brief Whether can we prove expr >= val.
/*!
* Non-negative proof is very useful in integer analysis
* \brief Detect if expression corresponds to clip bound of the vars
* to lower divisions and mods given difference in trunc and ceil mode.
*
*
* \param e The expression to be detected.
* \param expr The expression.
* \param vars List of variables to be used in detection.
* \param lower_bound The lower bound.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* \return Whether we can prove it.
* return empty if the e does not match the pattern.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
*/
Array
<
Expr
>
DetectClipBound
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
bool
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
);
};
//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
/*!
* \brief Find an symbolic integer set that contains all possible values of
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
* e given the domain of each iteration variables.
...
@@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
...
@@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
*/
Domain
DomainTouched
(
Stmt
body
,
const
Tensor
&
tensor
,
bool
consider_calls
,
bool
consider_provides
);
Domain
DomainTouched
(
Stmt
body
,
const
Tensor
&
tensor
,
bool
consider_calls
,
bool
consider_provides
);
// Expression pattern detector.
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array
<
Expr
>
DetectLinearEquation
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array
<
Expr
>
DetectClipBound
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
// implementation
// implementation
inline
const
IntSetNode
*
IntSet
::
operator
->
()
const
{
inline
const
IntSetNode
*
IntSet
::
operator
->
()
const
{
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
...
...
python/tvm/arith.py
View file @
153417a5
...
@@ -32,21 +32,21 @@ class IntSet(NodeBase):
...
@@ -32,21 +32,21 @@ class IntSet(NodeBase):
return
_api_internal
.
_IntSetIsEverything
(
self
)
return
_api_internal
.
_IntSetIsEverything
(
self
)
@register_node
@register_node
(
"arith.IntervalSet"
)
class
IntervalSet
(
IntSet
):
class
IntervalSet
(
IntSet
):
"""Represent set of continuous interval"""
"""Represent set of continuous interval [min_value, max_value]
def
min
(
self
):
"""get the minimum value"""
return
_api_internal
.
_IntervalSetGetMin
(
self
)
def
max
(
self
):
"""get the maximum value"""
return
_api_internal
.
_IntervalSetGetMax
(
self
)
Parameters
----------
min_value : Expr
The minimum value in the interval.
@register_node
max_value : Expr
class
StrideSet
(
IntSet
):
The maximum value in the interval.
"""Represent set of strided integers"""
"""
def
__init__
(
self
,
min_value
,
max_value
):
self
.
__init_handle_by_constructor__
(
_make_IntervalSet
,
min_value
,
max_value
)
@register_node
(
"arith.ModularSet"
)
@register_node
(
"arith.ModularSet"
)
...
@@ -114,6 +114,7 @@ class Analyzer:
...
@@ -114,6 +114,7 @@ class Analyzer:
self
.
_modular_set
=
_mod
(
"modular_set"
)
self
.
_modular_set
=
_mod
(
"modular_set"
)
self
.
_rewrite_simplify
=
_mod
(
"rewrite_simplify"
)
self
.
_rewrite_simplify
=
_mod
(
"rewrite_simplify"
)
self
.
_canonical_simplify
=
_mod
(
"canonical_simplify"
)
self
.
_canonical_simplify
=
_mod
(
"canonical_simplify"
)
self
.
_int_set
=
_mod
(
"int_set"
)
self
.
_enter_constraint_context
=
_mod
(
"enter_constraint_context"
)
self
.
_enter_constraint_context
=
_mod
(
"enter_constraint_context"
)
def
const_int_bound
(
self
,
expr
):
def
const_int_bound
(
self
,
expr
):
...
@@ -176,6 +177,24 @@ class Analyzer:
...
@@ -176,6 +177,24 @@ class Analyzer:
"""
"""
return
self
.
_canonical_simplify
(
expr
)
return
self
.
_canonical_simplify
(
expr
)
def
int_set
(
self
,
expr
,
dom_map
):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.
Parameters
----------
expr : tvm.Expr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
Returns
-------
result : IntSet
The result.
"""
return
self
.
_int_set
(
expr
,
dom_map
)
def
bind
(
self
,
var
,
expr
):
def
bind
(
self
,
var
,
expr
):
"""Bind a variable to the expression.
"""Bind a variable to the expression.
...
...
src/api/api_arith.cc
View file @
153417a5
...
@@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
...
@@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_API
(
"arith.intset_interval"
)
TVM_REGISTER_API
(
"arith.intset_interval"
)
.
set_body_typed
(
IntSet
::
interval
);
.
set_body_typed
(
IntSet
::
interval
);
TVM_REGISTER_API
(
"arith.DetectLinearEquation"
)
TVM_REGISTER_API
(
"arith.DetectLinearEquation"
)
.
set_body_typed
(
DetectLinearEquation
);
.
set_body_typed
(
DetectLinearEquation
);
...
@@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
...
@@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
self
->
canonical_simplify
(
args
[
0
]);
*
ret
=
self
->
canonical_simplify
(
args
[
0
]);
});
});
}
else
if
(
name
==
"int_set"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
self
->
int_set
(
args
[
0
],
args
[
1
]);
});
}
else
if
(
name
==
"bind"
)
{
}
else
if
(
name
==
"bind"
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
1
].
node_sptr
();
auto
&
sptr
=
args
[
1
].
node_sptr
();
...
...
src/arithmetic/analyzer.cc
View file @
153417a5
...
@@ -31,7 +31,8 @@ Analyzer::Analyzer()
...
@@ -31,7 +31,8 @@ Analyzer::Analyzer()
:
const_int_bound
(
this
),
:
const_int_bound
(
this
),
modular_set
(
this
),
modular_set
(
this
),
rewrite_simplify
(
this
),
rewrite_simplify
(
this
),
canonical_simplify
(
this
)
{
canonical_simplify
(
this
),
int_set
(
this
)
{
}
}
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Expr
&
expr
)
{
void
Analyzer
::
Bind
(
const
VarExpr
&
v
,
const
Expr
&
expr
)
{
...
@@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() {
...
@@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() {
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
>
lower_bound
;
return
ptr
->
value
>
=
lower_bound
;
}
}
auto
bd
=
this
->
const_int_bound
(
this
->
rewrite_simplify
(
expr
));
auto
bd
=
this
->
const_int_bound
(
this
->
rewrite_simplify
(
expr
));
if
(
bd
->
min_value
>=
lower_bound
)
return
true
;
if
(
bd
->
min_value
>=
lower_bound
)
return
true
;
...
...
src/arithmetic/bound_deducer.cc
View file @
153417a5
...
@@ -30,12 +30,12 @@
...
@@ -30,12 +30,12 @@
#include <unordered_set>
#include <unordered_set>
#include <unordered_map>
#include <unordered_map>
#include "int_set.h"
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
using
namespace
ir
;
using
namespace
ir
;
using
HalideIR
::
Internal
::
Interval
;
// a visitor to find the path to the target variable
// a visitor to find the path to the target variable
// from a expression.
// from a expression.
...
@@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
...
@@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
BoundDeducer
d
(
v
,
e
,
hint_map
,
relax_map
);
BoundDeducer
d
(
v
,
e
,
hint_map
,
relax_map
);
d
.
Deduce
();
d
.
Deduce
();
if
(
!
d
.
success
)
return
IntSet
::
nothing
();
if
(
!
d
.
success
)
return
IntSet
::
nothing
();
Expr
min
=
Interval
::
neg_inf
,
max
=
Interval
::
pos_inf
;
Expr
min
=
neg_inf
(),
max
=
pos_inf
()
;
if
(
d
.
is_greater
)
{
if
(
d
.
is_greater
)
{
min
=
d
.
result
;
min
=
d
.
result
;
}
else
{
}
else
{
...
...
src/arithmetic/canonical_simplify.cc
View file @
153417a5
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2019 by Contributors
* \file canonical_simplify.cc
* \file canonical_simplify.cc
* \brief Canonical form based simplification.
* \brief Canonical form based simplification.
*/
*/
...
@@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) {
...
@@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) {
if
(
TryCompare
(
temp
,
cval
)
==
kLT
)
{
if
(
TryCompare
(
temp
,
cval
)
==
kLT
)
{
return
temp
;
return
temp
;
}
else
{
}
else
{
return
SplitModConst
(
ToSplitExpr
(
temp
),
cval
);
// contonue to use logic below.
a
=
extra
;
psum
=
a
.
as
<
SumExprNode
>
();
CHECK
(
psum
!=
nullptr
);
}
}
}
}
}
}
...
...
src/arithmetic/compute_expr.h
View file @
153417a5
...
@@ -27,8 +27,8 @@
...
@@ -27,8 +27,8 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <arithmetic/Interval.h>
#include <limits>
#include <limits>
#include <algorithm>
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
...
@@ -105,12 +105,12 @@ inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
...
@@ -105,12 +105,12 @@ inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
template
<>
template
<>
inline
Expr
ComputeExpr
<
ir
::
Max
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
ComputeExpr
<
ir
::
Max
>
(
Expr
a
,
Expr
b
)
{
return
HalideIR
::
Internal
::
Interval
::
make_
max
(
a
,
b
);
return
max
(
a
,
b
);
}
}
template
<>
template
<>
inline
Expr
ComputeExpr
<
ir
::
Min
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
ComputeExpr
<
ir
::
Min
>
(
Expr
a
,
Expr
b
)
{
return
HalideIR
::
Internal
::
Interval
::
make_
min
(
a
,
b
);
return
min
(
a
,
b
);
}
}
template
<
typename
Op
>
template
<
typename
Op
>
...
...
src/arithmetic/const_fold.h
View file @
153417a5
...
@@ -206,6 +206,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
...
@@ -206,6 +206,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
min
(
pa
->
value
,
pb
->
value
));
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
min
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
min
(
fa
->
value
,
fb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
min
(
fa
->
value
,
fb
->
value
));
});
});
if
(
a
.
same_as
(
b
))
return
a
;
return
Expr
();
return
Expr
();
}
}
...
@@ -216,6 +217,7 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
...
@@ -216,6 +217,7 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
max
(
pa
->
value
,
pb
->
value
));
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
max
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
max
(
fa
->
value
,
fb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
max
(
fa
->
value
,
fb
->
value
));
});
});
if
(
a
.
same_as
(
b
))
return
a
;
return
Expr
();
return
Expr
();
}
}
...
@@ -307,6 +309,58 @@ inline Expr TryConstFold<ir::Not>(Expr a) {
...
@@ -307,6 +309,58 @@ inline Expr TryConstFold<ir::Not>(Expr a) {
return
Expr
();
return
Expr
();
}
}
/*! \brief Helper namespace for symbolic value limits */
struct
SymbolicLimits
{
/*! \brief positive infinity */
static
Expr
pos_inf_
;
/*! \brief negative infinity */
static
Expr
neg_inf_
;
};
/*!
* \brief Opaque expression representing positive infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return positive infinity.
*/
inline
Expr
pos_inf
()
{
return
SymbolicLimits
::
pos_inf_
;
}
/*!
* \brief Check if value is positive infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline
bool
is_pos_inf
(
const
Expr
&
value
)
{
return
value
.
same_as
(
SymbolicLimits
::
pos_inf_
);
}
/*!
* \brief Opaque expression representing negative infinity.
*
* It can can only be used as parameter of by min/max
* for integer analysis and cannot be used in normal expressions.
*
* \return negative infinity.
*/
inline
Expr
neg_inf
()
{
return
SymbolicLimits
::
neg_inf_
;
}
/*!
* \brief Check if value is negative infinity.
* \param value The value to be checked.
*
* \return The check result.
*/
inline
bool
is_neg_inf
(
const
Expr
&
value
)
{
return
value
.
same_as
(
SymbolicLimits
::
neg_inf_
);
}
}
// namespace arith
}
// namespace arith
}
// namespace tvm
}
// namespace tvm
#endif // TVM_ARITHMETIC_CONST_FOLD_H_
#endif // TVM_ARITHMETIC_CONST_FOLD_H_
src/arithmetic/detect_linear_equation.cc
View file @
153417a5
...
@@ -19,8 +19,8 @@
...
@@ -19,8 +19,8 @@
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file
bound_deducer
.cc
* \file
detect_linear_equation
.cc
* \brief Utility to de
duce bound of expression
* \brief Utility to de
tect patterns in the expression.
*/
*/
#include <tvm/expr.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
...
...
src/arithmetic/int_op_overflow.h
View file @
153417a5
src/arithmetic/int_set.cc
View file @
153417a5
...
@@ -18,201 +18,55 @@
...
@@ -18,201 +18,55 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* \file int_set.cc
* \file int_set.cc
* \brief The integer set functions
* \brief The integer set functions
*/
*/
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_functor_ext.h>
#include <arithmetic/Interval.h>
#include <tvm/api_registry.h>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <unordered_map>
#include "
compute_expr
.h"
#include "
int_set
.h"
#include "
int_set_internal
.h"
#include "
pattern_match
.h"
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
using
HalideIR
::
Internal
::
Interval
;
Expr
SymbolicLimits
::
pos_inf_
=
Var
(
"pos_inf"
,
Handle
())
;
using
namespace
ir
;
Expr
SymbolicLimits
::
neg_inf_
=
Var
(
"neg_inf"
,
Handle
())
;
inline
IntSet
IntSet
::
cover_interval
()
const
{
IntervalSet
::
IntervalSet
(
Expr
min_value
,
Expr
max_value
)
{
if
((
*
this
).
as
<
IntervalSet
>
())
return
*
this
;
auto
node
=
make_node
<
IntervalSetNode
>
();
const
StrideSet
*
s
=
(
*
this
).
as
<
StrideSet
>
();
node
->
min_value
=
std
::
move
(
min_value
);
if
(
s
)
{
node
->
max_value
=
std
::
move
(
max_value
);
CHECK_NE
(
s
->
extents
.
size
(),
0U
);
node_
=
std
::
move
(
node
);
Expr
max
=
s
->
base
.
max
;
for
(
size_t
i
=
0
;
i
<
s
->
extents
.
size
();
++
i
)
{
max
=
max
+
s
->
extents
[
i
]
*
s
->
strides
[
i
]
-
s
->
strides
[
i
];
}
return
IntervalSet
::
make
(
s
->
base
.
min
,
Simplify
(
max
));
}
LOG
(
FATAL
)
<<
"cannot convert set "
<<
(
*
this
)
->
type_key
()
<<
" to interval"
;
return
IntSet
::
everything
();
}
}
Range
IntSet
::
cover_range
(
Range
max_range
)
const
{
IntervalSet
MakeIntervalSet
(
Expr
min_value
,
Expr
max_value
)
{
IntSet
temp
;
return
IntervalSet
(
min_value
,
max_value
);
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
if
(
s_int
==
nullptr
)
{
temp
=
this
->
cover_interval
();
s_int
=
temp
.
as
<
IntervalSet
>
();
}
if
(
s_int
->
i
.
is_bounded
())
{
return
Range
::
make_by_min_extent
(
s_int
->
i
.
min
,
Simplify
(
s_int
->
i
.
max
+
1
-
s_int
->
i
.
min
));
}
return
max_range
;
}
}
Expr
IntSet
::
min
()
const
{
TVM_REGISTER_API
(
"arith._make_IntervalSet"
)
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
.
set_body_typed
(
MakeIntervalSet
);
CHECK
(
s_int
);
return
s_int
->
i
.
min
;
}
Expr
IntSet
::
max
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
CHECK
(
s_int
);
return
s_int
->
i
.
max
;
}
bool
IntSet
::
is_nothing
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_empty
());
}
bool
IntSet
::
is_everything
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
s_int
->
i
.
is_everything
());
}
bool
IntSet
::
is_single_point
()
const
{
IntervalSet
Intersect
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
Expr
max_value
=
min
(
a
->
max_value
,
b
->
max_value
);
return
(
s_int
&&
s_int
->
i
.
is_single_point
());
Expr
min_value
=
max
(
a
->
min_value
,
b
->
min_value
);
}
if
((
max_value
.
type
().
is_int
()
||
max_value
.
type
().
is_uint
())
&&
(
min_value
.
type
().
is_int
()
||
min_value
.
type
().
is_uint
())
&&
bool
IntSet
::
can_prove_positive
()
const
{
analyzer
->
CanProveGreaterEqual
(
min_value
-
max_value
,
1
))
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
IntervalSet
::
Empty
();
return
(
s_int
&&
is_positive_const
(
ir
::
Simplify
(
s_int
->
i
.
min
)));
}
bool
IntSet
::
can_prove_negative
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
return
(
s_int
&&
is_negative_const
(
ir
::
Simplify
(
s_int
->
i
.
max
)));
}
bool
IntSet
::
can_prove_non_positive
()
const
{
if
(
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
())
{
auto
max
=
ir
::
Simplify
(
s_int
->
i
.
max
);
return
is_zero
(
max
)
||
is_negative_const
(
max
);
}
return
false
;
}
bool
IntSet
::
can_prove_non_negative
()
const
{
if
(
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
())
{
// Any reason why we should or should not use can_prove() to implement
// these functions?
auto
min
=
ir
::
Simplify
(
s_int
->
i
.
min
);
return
is_zero
(
min
)
||
is_positive_const
(
min
);
}
return
false
;
}
SignType
IntSet
::
sign_type
()
const
{
if
(
can_prove_positive
())
{
return
kPositive
;
}
else
if
(
can_prove_negative
())
{
return
kNegative
;
}
else
if
(
is_single_point
()
&&
is_zero
(
point_value
()))
{
return
kZero
;
}
else
{
}
else
{
return
kUnknown
;
return
IntervalSet
(
min_value
,
max_value
);
}
}
Expr
IntSet
::
point_value
()
const
{
const
IntervalSet
*
s_int
=
(
*
this
).
as
<
IntervalSet
>
();
CHECK
(
s_int
&&
s_int
->
i
.
is_single_point
());
return
s_int
->
i
.
min
;
}
IntSet
IntSet
::
nothing
()
{
return
IntervalSet
::
make
(
Interval
::
nothing
());
}
IntSet
IntSet
::
everything
()
{
return
IntervalSet
::
make
(
Interval
::
everything
());
}
IntSet
IntSet
::
single_point
(
Expr
x
)
{
return
IntervalSet
::
make
(
Interval
::
single_point
(
x
));
}
IntSet
IntSet
::
range
(
Range
r
)
{
// must make sure it can be matched back by MatchRange.
if
(
is_one
(
r
->
extent
))
{
return
IntSet
::
single_point
(
r
->
min
);
}
if
(
is_positive_const
(
r
->
extent
)
&&
is_const
(
r
->
min
))
{
return
IntervalSet
::
make
(
r
->
min
,
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
r
->
extent
,
r
->
min
),
1
));
}
return
IntervalSet
::
make
(
r
->
min
,
(
r
->
extent
+
r
->
min
)
-
1
);
}
IntSet
IntSet
::
interval
(
Expr
min
,
Expr
max
)
{
if
(
min
.
same_as
(
max
))
{
return
IntSet
::
single_point
(
min
);
}
}
return
IntervalSet
::
make
(
min
,
max
);
}
}
inline
bool
prove_equal
(
Expr
lhs
,
Expr
rhs
)
{
IntervalSet
Union
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
return
is_zero
(
ir
::
Simplify
(
lhs
-
rhs
));
Expr
max_value
=
max
(
a
->
max_value
,
b
->
max_value
);
}
Expr
min_value
=
min
(
a
->
min_value
,
b
->
min_value
);
return
IntervalSet
(
min_value
,
max_value
);
// Check if a is created from b.
bool
IntSet
::
match_range
(
const
Range
&
b
)
const
{
const
IntSet
&
a
=
*
this
;
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
return
prove_equal
(
i
.
min
,
b
->
min
)
&&
prove_equal
(
i
.
max
,
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
b
->
extent
,
b
->
min
),
1
));
}
inline
bool
MatchPoint
(
const
IntSet
&
a
,
const
Expr
&
b
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
!
a_int
)
return
false
;
const
Interval
&
i
=
a_int
->
i
;
return
i
.
is_single_point
()
&&
i
.
min
.
same_as
(
b
);
}
IntSet
Union
(
const
Array
<
IntSet
>&
sets
)
{
if
(
sets
.
size
()
==
0
)
return
IntSet
::
nothing
();
if
(
sets
.
size
()
==
1
)
return
sets
[
0
];
Interval
x
=
sets
[
0
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
IntSet
s
=
sets
[
i
].
cover_interval
();
const
Interval
&
y
=
s
.
as
<
IntervalSet
>
()
->
i
;
x
.
include
(
y
);
}
x
.
max
=
ir
::
Simplify
(
x
.
max
);
x
.
min
=
ir
::
Simplify
(
x
.
min
);
return
IntervalSet
::
make
(
x
);
}
IntSet
Intersect
(
const
Array
<
IntSet
>&
sets
)
{
Interval
x
=
sets
[
0
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
Interval
y
=
sets
[
i
].
cover_interval
().
as
<
IntervalSet
>
()
->
i
;
x
=
Interval
::
make_intersection
(
x
,
y
);
}
return
IntervalSet
::
make
(
x
);
}
}
// type traits
// type traits
...
@@ -227,407 +81,623 @@ struct is_logical_op {
...
@@ -227,407 +81,623 @@ struct is_logical_op {
static const bool value = true; \
static const bool value = true; \
};
};
// interval related.
TVM_DECLARE_LOGICAL_OP
(
And
);
template
<
typename
OP
>
TVM_DECLARE_LOGICAL_OP
(
Or
);
inline
IntSet
CombineInterval
(
Interval
a
,
Interval
b
)
{
TVM_DECLARE_LOGICAL_OP
(
EQ
);
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
TVM_DECLARE_LOGICAL_OP
(
NE
);
return
IntSet
::
single_point
(
ComputeExpr
<
OP
>
(
a
.
min
,
b
.
min
));
TVM_DECLARE_LOGICAL_OP
(
GE
);
}
TVM_DECLARE_LOGICAL_OP
(
GT
);
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval "
<<
OP
::
_type_key
;
TVM_DECLARE_LOGICAL_OP
(
LE
);
return
IntSet
::
everything
();
TVM_DECLARE_LOGICAL_OP
(
LT
);
TVM_DECLARE_LOGICAL_OP
(
Not
);
/*!
* \brief Combine two interval set under arithmetic operations.
* \note this can possibly relax the set.
*/
template
<
typename
Op
>
inline
IntervalSet
Combine
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
Expr
res
=
TryConstFold
<
Op
>
(
a
->
min_value
,
b
->
min_value
);
if
(
!
res
.
defined
())
res
=
Op
::
make
(
a
->
min_value
,
b
->
min_value
);
return
IntervalSet
::
SinglePoint
(
res
);
}
if
(
is_logical_op
<
Op
>::
value
)
{
return
IntervalSet
(
make_const
(
a
->
min_value
.
type
(),
0
),
make_const
(
a
->
min_value
.
type
(),
1
));
}
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
a
->
IsEverything
())
return
a
;
if
(
b
->
IsEverything
())
return
b
;
return
IntervalSet
::
Everything
();
}
}
template
<>
template
<>
inline
IntSet
CombineInterval
<
Add
>
(
Interval
a
,
Interval
b
)
{
inline
IntervalSet
Combine
<
ir
::
Add
>
(
Analyzer
*
analyer
,
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
IntervalSet
a
,
return
IntSet
::
single_point
(
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
));
IntervalSet
b
)
{
}
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
Interval
r
=
Interval
::
everything
();
return
IntervalSet
::
SinglePoint
(
a
->
min_value
+
b
->
min_value
);
if
(
a
.
has_lower_bound
()
&&
b
.
has_lower_bound
())
{
}
r
.
min
=
ComputeExpr
<
Add
>
(
a
.
min
,
b
.
min
);
if
(
a
->
IsEmpty
())
return
a
;
}
if
(
b
->
IsEmpty
())
return
b
;
if
(
a
.
has_upper_bound
()
&&
b
.
has_upper_bound
())
{
Expr
min_value
=
r
.
max
=
ComputeExpr
<
Add
>
(
a
.
max
,
b
.
max
);
a
->
HasLowerBound
()
&&
b
->
HasLowerBound
()
?
}
a
->
min_value
+
b
->
min_value
:
neg_inf
();
return
IntervalSet
::
make
(
r
);
Expr
max_value
=
a
->
HasUpperBound
()
&&
b
->
HasUpperBound
()
?
a
->
max_value
+
b
->
max_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
}
template
<>
template
<>
inline
IntSet
CombineInterval
<
Sub
>
(
Interval
a
,
Interval
b
)
{
inline
IntervalSet
Combine
<
ir
::
Sub
>
(
Analyzer
*
analyer
,
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
IntervalSet
a
,
return
IntSet
::
single_point
(
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
min
));
IntervalSet
b
)
{
}
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
Interval
r
=
Interval
::
everything
();
return
IntervalSet
::
SinglePoint
(
a
->
min_value
-
b
->
min_value
);
if
(
a
.
has_lower_bound
()
&&
b
.
has_upper_bound
())
{
r
.
min
=
ComputeExpr
<
Sub
>
(
a
.
min
,
b
.
max
);
}
if
(
a
.
has_upper_bound
()
&&
b
.
has_lower_bound
())
{
r
.
max
=
ComputeExpr
<
Sub
>
(
a
.
max
,
b
.
min
);
}
}
return
IntervalSet
::
make
(
r
);
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
Expr
min_value
=
a
->
HasLowerBound
()
&&
b
->
HasUpperBound
()
?
a
->
min_value
-
b
->
max_value
:
neg_inf
();
Expr
max_value
=
a
->
HasUpperBound
()
&&
b
->
HasLowerBound
()
?
a
->
max_value
-
b
->
min_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
}
template
<>
template
<>
inline
IntSet
CombineInterval
<
Mul
>
(
Interval
a
,
Interval
b
)
{
inline
IntervalSet
Combine
<
ir
::
Mul
>
(
Analyzer
*
analyzer
,
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
IntervalSet
a
,
return
IntSet
::
single_point
(
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
));
IntervalSet
b
)
{
}
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
if
(
a
.
is_single_point
()
&&
!
b
.
is_single_point
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
*
b
->
min_value
);
}
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
a
->
IsSinglePoint
())
{
std
::
swap
(
a
,
b
);
std
::
swap
(
a
,
b
);
}
}
if
(
b
.
is_single_p
oint
())
{
if
(
b
->
IsSingleP
oint
())
{
if
(
is_zero
(
b
.
min
))
return
IntSet
::
single_point
(
0
)
;
if
(
is_zero
(
b
->
min_value
))
return
b
;
if
(
is_one
(
b
.
min
))
return
IntervalSet
::
make
(
a
)
;
if
(
is_one
(
b
->
min_value
))
return
a
;
Expr
e1
=
a
.
has_lower_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
min
,
b
.
min
)
:
a
.
min
;
if
(
analyzer
->
CanProveGreaterEqual
(
b
->
min_value
,
0
))
{
Expr
e2
=
a
.
has_upper_bound
()
?
ComputeExpr
<
Mul
>
(
a
.
max
,
b
.
min
)
:
a
.
max
;
Expr
min_value
=
a
->
HasLowerBound
()
?
a
->
min_value
*
b
->
min_value
:
neg_inf
()
;
// no relaxation is needed in here due to set is inclusive
Expr
max_value
=
a
->
HasUpperBound
()
?
a
->
max_value
*
b
->
min_value
:
pos_inf
();
// TODO(tqchen): consider convert to StrideSet.
return
IntervalSet
(
min_value
,
max_value
);
if
(
is_positive_const
(
b
.
min
))
{
}
else
if
(
analyzer
->
CanProveGreaterEqual
(
-
b
->
min_value
,
1
))
{
return
IntervalSet
::
make
(
e1
,
e2
);
Expr
min_value
=
a
->
HasUpperBound
()
?
a
->
max_value
*
b
->
min_value
:
neg_inf
(
);
}
else
if
(
is_negative_const
(
b
.
min
))
{
Expr
max_value
=
a
->
HasLowerBound
()
?
a
->
min_value
*
b
->
min_value
:
pos_inf
();
return
IntervalSet
::
make
(
e2
,
e1
);
return
IntervalSet
(
min_value
,
max_value
);
}
else
if
(
a
.
is_bounde
d
())
{
}
else
if
(
a
->
HasUpperBound
()
&&
a
->
HasLowerBoun
d
())
{
using
ir
::
Select
;
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
Expr
sign
=
b
->
min_value
>=
make_zero
(
b
->
min_value
.
type
().
element_of
());
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
Expr
e1
=
a
->
min_value
*
b
->
min_value
;
Expr
e2
=
a
->
max_value
*
b
->
min_value
;
return
IntervalSet
(
Select
::
make
(
sign
,
e1
,
e2
),
Select
::
make
(
sign
,
e2
,
e1
));
}
}
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
D
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mul"
;
return
Int
Set
::
e
verything
();
return
Int
ervalSet
::
E
verything
();
}
}
template
<>
template
<>
inline
IntSet
CombineInterval
<
Div
>
(
Interval
a
,
Interval
b
)
{
inline
IntervalSet
Combine
<
ir
::
Div
>
(
Analyzer
*
analyzer
,
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
IntervalSet
a
,
return
IntSet
::
single_point
(
ComputeExpr
<
Div
>
(
a
.
min
,
b
.
min
));
IntervalSet
b
)
{
}
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
if
(
b
.
is_single_point
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
/
b
->
min_value
);
if
(
is_zero
(
b
.
min
))
{
}
if
(
a
->
IsEmpty
())
return
a
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
b
->
IsSinglePoint
())
{
if
(
is_zero
(
b
->
min_value
))
{
LOG
(
FATAL
)
<<
"Divide by zero in CombineInterval Div"
;
LOG
(
FATAL
)
<<
"Divide by zero in CombineInterval Div"
;
}
}
if
(
is_one
(
b
.
min
))
return
IntervalSet
::
make
(
a
);
if
(
is_one
(
b
->
min_value
))
return
a
;
Expr
e1
=
a
.
has_lower_bound
()
?
ComputeExpr
<
Div
>
(
a
.
min
,
b
.
min
)
:
a
.
min
;
Expr
e2
=
a
.
has_upper_bound
()
?
ComputeExpr
<
Div
>
(
a
.
max
,
b
.
min
)
:
a
.
max
;
// no relaxation is needed in here due to set is inclusive
// no relaxation is needed in here due to set is inclusive
if
(
is_positive_const
(
b
.
min
))
{
if
(
analyzer
->
CanProveGreaterEqual
(
b
->
min_value
,
0
))
{
return
IntervalSet
::
make
(
e1
,
e2
);
Expr
min_value
=
a
->
HasLowerBound
()
?
a
->
min_value
/
b
->
min_value
:
neg_inf
();
}
else
if
(
is_negative_const
(
b
.
min
))
{
Expr
max_value
=
a
->
HasUpperBound
()
?
a
->
max_value
/
b
->
min_value
:
pos_inf
();
return
IntervalSet
::
make
(
e2
,
e1
);
return
IntervalSet
(
min_value
,
max_value
);
}
else
if
(
a
.
is_bounded
())
{
}
else
if
(
analyzer
->
CanProveGreaterEqual
(
-
b
->
min_value
,
1
))
{
Expr
min_value
=
a
->
HasUpperBound
()
?
a
->
max_value
/
b
->
min_value
:
neg_inf
();
Expr
max_value
=
a
->
HasLowerBound
()
?
a
->
min_value
/
b
->
min_value
:
pos_inf
();
return
IntervalSet
(
min_value
,
max_value
);
}
else
if
(
a
->
HasUpperBound
()
&&
a
->
HasLowerBound
())
{
using
ir
::
Select
;
using
ir
::
Select
;
Expr
cmp
=
b
.
min
>=
make_zero
(
b
.
min
.
type
().
element_of
());
Expr
sign
=
b
->
min_value
>=
make_zero
(
b
->
min_value
.
type
().
element_of
());
return
IntervalSet
::
make
(
Select
::
make
(
cmp
,
e1
,
e2
),
Select
::
make
(
cmp
,
e2
,
e1
));
Expr
e1
=
a
->
min_value
/
b
->
min_value
;
Expr
e2
=
a
->
max_value
/
b
->
min_value
;
return
IntervalSet
(
Select
::
make
(
sign
,
e1
,
e2
),
Select
::
make
(
sign
,
e2
,
e1
));
}
}
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Div"
;
D
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Div"
;
return
Int
Set
::
e
verything
();
return
Int
ervalSet
::
E
verything
();
}
}
template
<>
template
<>
inline
IntSet
CombineInterval
<
Mod
>
(
Interval
a
,
Interval
b
)
{
inline
IntervalSet
Combine
<
ir
::
Mod
>
(
Analyzer
*
analyzer
,
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
IntervalSet
a
,
return
IntSet
::
single_point
(
ComputeExpr
<
Mod
>
(
a
.
min
,
b
.
min
));
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
a
->
min_value
%
b
->
min_value
);
}
}
if
(
b
.
is_single_point
())
{
if
(
a
->
IsEmpty
())
return
a
;
Expr
divisor
=
b
.
min
;
if
(
b
->
IsEmpty
())
return
b
;
if
(
b
->
IsSinglePoint
())
{
const
Expr
&
divisor
=
b
->
min_value
;
if
(
is_zero
(
divisor
))
{
if
(
is_zero
(
divisor
))
{
LOG
(
FATAL
)
<<
"Modular by zero in CombineInterval Mod"
;
LOG
(
FATAL
)
<<
"Modular by zero in CombineInterval Mod"
;
}
}
return
IntervalSet
::
make
(
make_zero
(
divisor
.
type
()),
divisor
-
1
);
// We need to add more bound constraints throughout the code.
// The logic below assumes a is non-negative, which usually
// is the case of our application.
// TODO(tqchen): add bound constraints for a.
if
(
analyzer
->
CanProveGreaterEqual
(
divisor
,
0
))
{
return
IntervalSet
(
make_zero
(
divisor
.
type
()),
divisor
-
1
);
}
else
{
Expr
bound
=
abs
(
divisor
)
-
1
;
return
IntervalSet
(
-
bound
,
bound
);
}
}
LOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mod"
;
return
IntSet
::
everything
();
}
template
<>
inline
IntSet
CombineInterval
<
Max
>
(
Interval
a
,
Interval
b
)
{
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
return
IntSet
::
single_point
(
ComputeExpr
<
Max
>
(
a
.
min
,
b
.
min
));
}
}
return
IntervalSet
::
make
(
Interval
::
make_max
(
a
.
min
,
b
.
min
),
DLOG
(
WARNING
)
<<
"Return Everything in CombineInterval Mod"
;
Interval
::
make_max
(
a
.
max
,
b
.
max
)
);
return
IntervalSet
::
Everything
(
);
}
}
template
<>
template
<>
inline
IntSet
CombineInterval
<
Min
>
(
Interval
a
,
Interval
b
)
{
inline
IntervalSet
Combine
<
ir
::
Max
>
(
Analyzer
*
analzyer
,
if
(
a
.
is_single_point
()
&&
b
.
is_single_point
())
{
IntervalSet
a
,
return
IntSet
::
single_point
(
ComputeExpr
<
Min
>
(
a
.
min
,
b
.
min
));
IntervalSet
b
)
{
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
return
IntervalSet
::
SinglePoint
(
max
(
a
->
min_value
,
b
->
min_value
));
}
}
return
IntervalSet
::
make
(
Interval
::
make_min
(
a
.
min
,
b
.
min
),
if
(
a
->
IsEmpty
())
return
a
;
Interval
::
make_min
(
a
.
max
,
b
.
max
));
if
(
b
->
IsEmpty
())
return
b
;
}
return
IntervalSet
(
max
(
a
->
min_value
,
b
->
min_value
),
max
(
a
->
max_value
,
b
->
max_value
));
template
<
typename
OP
>
inline
IntSet
CombineInterval_
(
IntSet
a
,
IntSet
b
)
{
return
CombineInterval
<
OP
>
(
a
.
as
<
IntervalSet
>
()
->
i
,
b
.
as
<
IntervalSet
>
()
->
i
);
}
// stride related
inline
IntSet
AsStrideSet
(
IntSet
a
)
{
if
(
a
.
as
<
StrideSet
>
())
return
a
;
const
IntervalSet
*
s
=
a
.
as
<
IntervalSet
>
();
CHECK
(
s
->
i
.
is_bounded
());
NodePtr
<
StrideSet
>
n
=
make_node
<
StrideSet
>
();
n
->
base
=
s
->
i
;
return
IntSet
(
n
);
}
template
<
typename
OP
>
inline
IntSet
CombineSets
(
IntSet
a
,
IntSet
b
)
{
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
.
cover_interval
());
}
}
template
<>
template
<>
inline
IntSet
CombineSets
<
Add
>
(
IntSet
a
,
IntSet
b
)
{
inline
IntervalSet
Combine
<
ir
::
Min
>
(
Analyzer
*
analzyer
,
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
IntervalSet
a
,
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
IntervalSet
b
)
{
if
(
a_int
&&
is_zero
(
a_int
->
i
.
min
))
return
b
;
if
(
a
->
IsSinglePoint
()
&&
b
->
IsSinglePoint
())
{
if
(
b_int
&&
is_zero
(
b_int
->
i
.
min
))
return
a
;
return
IntervalSet
::
SinglePoint
(
min
(
a
->
min_value
,
b
->
min_value
));
a
=
AsStrideSet
(
a
);
b
=
AsStrideSet
(
b
);
const
StrideSet
*
a_stride
=
a
.
as
<
StrideSet
>
();
const
StrideSet
*
b_stride
=
b
.
as
<
StrideSet
>
();
auto
n
=
make_node
<
StrideSet
>
(
*
a_stride
);
for
(
size_t
i
=
0
;
i
<
b_stride
->
extents
.
size
();
++
i
)
{
n
->
extents
.
push_back
(
b_stride
->
extents
[
i
]);
n
->
strides
.
push_back
(
b_stride
->
strides
[
i
]);
}
n
->
base
=
CombineInterval
<
Add
>
(
a_stride
->
base
,
b_stride
->
base
).
as
<
IntervalSet
>
()
->
i
;
return
IntSet
(
n
);
}
inline
IntSet
NegateSet
(
IntSet
a
)
{
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
if
(
a_int
)
{
if
(
a_int
->
i
.
is_single_point
())
{
return
IntSet
::
single_point
(
-
a_int
->
i
.
min
);
}
else
{
Interval
r
=
Interval
::
everything
();
if
(
a_int
->
i
.
has_upper_bound
())
{
r
.
min
=
-
(
a_int
->
i
.
max
);
}
if
(
a_int
->
i
.
has_lower_bound
())
{
r
.
max
=
-
(
a_int
->
i
.
min
);
}
return
IntervalSet
::
make
(
r
);
}
}
}
else
{
if
(
a
->
IsEmpty
())
return
a
;
return
NegateSet
(
a
.
cover_interval
());
if
(
b
->
IsEmpty
())
return
b
;
}
return
IntervalSet
(
min
(
a
->
min_value
,
b
->
min_value
),
}
min
(
a
->
max_value
,
b
->
max_value
));
template
<>
inline
IntSet
CombineSets
<
Sub
>
(
IntSet
a
,
IntSet
b
)
{
return
CombineSets
<
Add
>
(
a
,
NegateSet
(
b
));
}
}
TVM_DECLARE_LOGICAL_OP
(
And
);
// internal helper function to get an interval set
TVM_DECLARE_LOGICAL_OP
(
Or
);
IntervalSet
ToIntervalSet
(
IntSet
set
)
{
TVM_DECLARE_LOGICAL_OP
(
EQ
);
if
(
auto
*
node
=
set
.
as
<
IntervalSetNode
>
())
{
TVM_DECLARE_LOGICAL_OP
(
NE
);
return
GetRef
<
IntervalSet
>
(
node
);
TVM_DECLARE_LOGICAL_OP
(
GE
);
TVM_DECLARE_LOGICAL_OP
(
GT
);
TVM_DECLARE_LOGICAL_OP
(
LE
);
TVM_DECLARE_LOGICAL_OP
(
LT
);
TVM_DECLARE_LOGICAL_OP
(
Not
);
// generic combine operations of two sets
template
<
typename
OP
>
inline
IntSet
Combine
(
const
IntSet
&
a
,
const
IntSet
&
b
)
{
if
(
is_logical_op
<
OP
>::
value
)
{
return
IntervalSet
::
make
(
0
,
1
);
}
const
IntervalSet
*
a_int
=
a
.
as
<
IntervalSet
>
();
const
IntervalSet
*
b_int
=
b
.
as
<
IntervalSet
>
();
if
(
a_int
&&
a_int
->
i
.
is_everything
())
return
a
;
if
(
b_int
&&
b_int
->
i
.
is_everything
())
return
b
;
if
(
a_int
&&
b_int
)
{
return
CombineInterval
<
OP
>
(
a_int
->
i
,
b_int
->
i
);
}
if
(
a_int
&&
!
(
a_int
->
i
.
is_bounded
()))
{
return
CombineInterval_
<
OP
>
(
a
,
b
.
cover_interval
());
}
}
if
(
b_int
&&
!
(
b_int
->
i
.
is_bounded
()))
{
DLOG
(
INFO
)
<<
"cannot resolve int set "
<<
set
;
return
CombineInterval_
<
OP
>
(
a
.
cover_interval
(),
b
);
return
IntervalSet
::
Everything
();
}
return
CombineSets
<
OP
>
(
a
,
b
);
}
}
class
IntSetEvaluator
:
using
namespace
ir
;
public
ExprFunctor
<
IntSet
(
const
Expr
&
,
const
Expr
&
)
>
{
// Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset.
class
IntervalSetEvaluator
:
public
ExprFunctor
<
IntervalSet
(
const
Expr
&
)
>
{
public
:
public
:
explicit
IntSetEvaluator
(
IntervalSetEvaluator
(
Analyzer
*
analyzer
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
,
const
Map
<
Var
,
IntSet
>&
dom_map
,
bool
eval_vec
=
false
)
bool
eval_vec
=
false
)
:
dom_map_
(
dom_map
),
eval_vec_
(
eval_vec
)
{}
:
analyzer_
(
analyzer
),
// Evaluate.
dom_map_
(
dom_map
),
IntSet
Eval
(
const
Expr
&
e
)
{
eval_vec_
(
eval_vec
)
{
return
this
->
VisitExpr
(
e
,
e
);
}
}
IntSet
VisitExpr_
(
const
IntImm
*
op
,
const
Expr
&
e
)
final
{
return
IntSet
::
single_point
(
e
);
IntervalSet
Eval
(
const
Expr
&
val
)
{
return
this
->
VisitExpr
(
val
);
}
}
IntSet
VisitExpr_
(
const
UIntImm
*
op
,
const
Expr
&
e
)
final
{
return
IntSet
::
single_point
(
e
);
IntervalSet
VisitExpr_
(
const
IntImm
*
op
)
final
{
return
IntervalSet
::
SinglePoint
(
GetRef
<
Expr
>
(
op
));
}
IntervalSet
VisitExpr_
(
const
UIntImm
*
op
)
final
{
return
IntervalSet
::
SinglePoint
(
GetRef
<
Expr
>
(
op
));
}
}
IntSet
VisitExpr_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
auto
it
=
dom_map_
.
find
(
op
);
IntervalSet
VisitExpr_
(
const
Variable
*
op
)
final
{
Var
var
=
GetRef
<
Var
>
(
op
);
auto
it
=
dom_map_
.
find
(
var
);
if
(
it
!=
dom_map_
.
end
())
{
if
(
it
!=
dom_map_
.
end
())
{
return
it
->
second
;
return
ToIntervalSet
((
*
it
).
second
)
;
}
else
{
}
else
{
return
Int
Set
::
single_point
(
e
);
return
Int
ervalSet
::
SinglePoint
(
var
);
}
}
}
}
IntSet
VisitExpr_
(
const
Add
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Add
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Sub
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Sub
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Mul
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Mul
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Div
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Div
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Mod
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Mod
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Min
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Min
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Max
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Max
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
EQ
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
EQ
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
NE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
NE
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
LT
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
LT
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
LE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
LE
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
GT
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
GT
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
GE
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
GE
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
And
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
And
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Or
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
IntervalSet
VisitExpr_
(
const
Or
*
op
)
final
{
return
VisitBinaryExpr_
(
op
);
}
}
IntSet
VisitExpr_
(
const
Ramp
*
op
,
const
Expr
&
e
)
final
{
IntervalSet
VisitExpr_
(
const
Ramp
*
op
)
final
{
CHECK
(
eval_vec_
);
CHECK
(
eval_vec_
);
IntSet
base
=
Eval
(
op
->
base
);
Int
erval
Set
base
=
Eval
(
op
->
base
);
int
v
stride
;
PVar
<
Integer
>
stride
;
if
(
GetConstInt
(
op
->
stride
,
&
v
stride
))
{
if
(
stride
.
Match
(
op
->
stride
))
{
Type
t
=
op
->
base
.
type
();
Type
t
=
op
->
base
.
type
();
if
(
vstride
>
0
)
{
int64_t
vstride
=
stride
.
Eval
()
->
value
;
if
(
vstride
>
0
)
{
return
Combine
<
Add
>
(
return
Combine
<
Add
>
(
analyzer_
,
base
,
base
,
IntSet
::
interval
(
make_zero
(
t
),
IntervalSet
(
make_zero
(
t
),
make_const
(
t
,
vstride
*
op
->
lanes
-
1
)));
make_const
(
t
,
vstride
*
op
->
lanes
-
1
)));
}
else
{
}
else
{
return
Combine
<
Add
>
(
return
Combine
<
Add
>
(
analyzer_
,
base
,
base
,
IntSet
::
interval
(
make_const
(
t
,
vstride
*
op
->
lanes
+
1
),
IntervalSet
(
make_const
(
t
,
vstride
*
op
->
lanes
+
1
),
make_zero
(
t
)));
make_zero
(
t
)));
}
}
}
}
LOG
(
WARNING
)
<<
"cannot evaluate set on expression "
<<
e
;
DLOG
(
WARNING
)
<<
"cannot evaluate set on expression "
<<
GetRef
<
Expr
>
(
op
)
;
return
Int
Set
::
e
verything
();
return
Int
ervalSet
::
E
verything
();
}
}
IntSet
VisitExpr_
(
const
Broadcast
*
op
,
const
Expr
&
e
)
final
{
IntervalSet
VisitExpr_
(
const
Broadcast
*
op
)
final
{
CHECK
(
eval_vec_
);
CHECK
(
eval_vec_
);
return
Eval
(
op
->
value
);
return
VisitExpr
(
op
->
value
);
}
}
IntSet
VisitExpr_
(
const
Select
*
op
,
const
Expr
&
e
)
final
{
IntSet
true_set
=
this
->
Eval
(
op
->
true_value
);
IntervalSet
VisitExpr_
(
const
Select
*
op
)
final
{
IntSet
false_set
=
this
->
Eval
(
op
->
false_value
);
IntervalSet
true_set
=
this
->
Eval
(
op
->
true_value
);
return
Union
({
false_set
,
true_set
});
IntervalSet
false_set
=
this
->
Eval
(
op
->
false_value
);
return
Union
(
analyzer_
,
false_set
,
true_set
);
}
}
IntSet
VisitExprDefault_
(
const
Node
*
op
,
const
Expr
&
e
)
final
{
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
e
->
type_key
();
IntervalSet
VisitExprDefault_
(
const
Node
*
op
)
final
{
return
IntSet
::
everything
();
DLOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
op
->
type_key
();
return
IntervalSet
::
Everything
();
}
}
private
:
private
:
// whether set is exactly single point that equals value.
bool
MatchPoint
(
const
IntervalSet
&
set
,
const
Expr
&
value
)
const
{
return
set
->
min_value
.
same_as
(
value
)
&&
set
->
max_value
.
same_as
(
value
);
}
template
<
typename
T
>
template
<
typename
T
>
inline
Int
Set
Binary
(
const
T
*
op
,
const
Expr
&
e
)
{
inline
Int
ervalSet
VisitBinaryExpr_
(
const
T
*
op
)
{
IntSet
a
=
this
->
Eval
(
op
->
a
);
Int
erval
Set
a
=
this
->
Eval
(
op
->
a
);
IntSet
b
=
this
->
Eval
(
op
->
b
);
Int
erval
Set
b
=
this
->
Eval
(
op
->
b
);
if
(
MatchPoint
(
a
,
op
->
a
)
&&
MatchPoint
(
b
,
op
->
b
))
{
if
(
MatchPoint
(
a
,
op
->
a
)
&&
MatchPoint
(
b
,
op
->
b
))
{
return
Int
Set
::
single_point
(
e
);
return
Int
ervalSet
::
SinglePoint
(
GetRef
<
Expr
>
(
op
)
);
}
}
return
Combine
<
T
>
(
a
,
b
);
return
Combine
<
T
>
(
a
nalyzer_
,
a
,
b
);
}
}
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map_
;
Analyzer
*
analyzer_
;
const
Map
<
Var
,
IntSet
>&
dom_map_
;
bool
eval_vec_
{
false
};
bool
eval_vec_
{
false
};
};
};
IntSet
EvalSet
(
Expr
e
,
class
IntSetAnalyzer
::
Impl
{
public
:
explicit
Impl
(
Analyzer
*
analyzer
)
:
analyzer_
(
analyzer
)
{
}
IntSet
Eval
(
const
Expr
&
expr
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
const
{
return
IntervalSetEvaluator
(
analyzer_
,
dom_map
).
Eval
(
expr
);
}
private
:
Analyzer
*
analyzer_
;
};
IntSetAnalyzer
::
IntSetAnalyzer
(
Analyzer
*
parent
)
:
impl_
(
new
Impl
(
parent
))
{
}
IntSetAnalyzer
::~
IntSetAnalyzer
()
{
delete
impl_
;
}
IntSet
IntSetAnalyzer
::
operator
()(
const
Expr
&
expr
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
{
return
impl_
->
Eval
(
expr
,
dom_map
);
}
// Quickly adapt to IntSet interface
// TODO(tqchen): revisit IntSet interface as well.
Range
IntSet
::
cover_range
(
Range
max_range
)
const
{
IntSet
temp
;
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
!=
nullptr
);
if
(
s_int
->
HasUpperBound
()
&&
s_int
->
HasLowerBound
())
{
return
Range
::
make_by_min_extent
(
s_int
->
min_value
,
Simplify
(
s_int
->
max_value
+
1
-
s_int
->
min_value
));
}
return
max_range
;
}
Expr
IntSet
::
min
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
);
return
s_int
->
min_value
;
}
Expr
IntSet
::
max
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
);
return
s_int
->
max_value
;
}
bool
IntSet
::
is_nothing
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
s_int
->
IsEmpty
());
}
bool
IntSet
::
is_everything
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
s_int
->
IsEverything
());
}
bool
IntSet
::
is_single_point
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
s_int
->
IsSinglePoint
());
}
bool
IntSet
::
can_prove_positive
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
is_positive_const
(
ir
::
Simplify
(
s_int
->
min_value
)));
}
bool
IntSet
::
can_prove_negative
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
return
(
s_int
&&
is_negative_const
(
ir
::
Simplify
(
s_int
->
max_value
)));
}
bool
IntSet
::
can_prove_non_positive
()
const
{
if
(
const
auto
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
())
{
auto
max
=
ir
::
Simplify
(
s_int
->
max_value
);
return
is_zero
(
max
)
||
is_negative_const
(
max
);
}
return
false
;
}
bool
IntSet
::
can_prove_non_negative
()
const
{
if
(
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
())
{
auto
min
=
ir
::
Simplify
(
s_int
->
min_value
);
return
is_zero
(
min
)
||
is_positive_const
(
min
);
}
return
false
;
}
SignType
IntSet
::
sign_type
()
const
{
if
(
can_prove_positive
())
{
return
kPositive
;
}
else
if
(
can_prove_negative
())
{
return
kNegative
;
}
else
if
(
is_single_point
()
&&
is_zero
(
point_value
()))
{
return
kZero
;
}
else
{
return
kUnknown
;
}
}
Expr
IntSet
::
point_value
()
const
{
const
IntervalSetNode
*
s_int
=
(
*
this
).
as
<
IntervalSetNode
>
();
CHECK
(
s_int
&&
s_int
->
IsSinglePoint
());
return
s_int
->
min_value
;
}
IntSet
IntSet
::
nothing
()
{
return
IntervalSet
::
Empty
();
}
IntSet
IntSet
::
everything
()
{
return
IntervalSet
::
Everything
();
}
IntSet
IntSet
::
single_point
(
Expr
x
)
{
return
IntervalSet
::
SinglePoint
(
x
);
}
IntSet
IntSet
::
interval
(
Expr
min
,
Expr
max
)
{
if
(
min
.
same_as
(
max
))
{
return
IntSet
::
single_point
(
min
);
}
return
IntervalSet
(
min
,
max
);
}
// Range related code
inline
bool
ProveEqual
(
Expr
lhs
,
Expr
rhs
)
{
return
is_zero
(
ir
::
Simplify
(
lhs
-
rhs
));
}
IntSet
IntSet
::
range
(
Range
r
)
{
// must make sure it can be matched back by MatchRange.
if
(
is_one
(
r
->
extent
))
{
return
IntSet
::
single_point
(
r
->
min
);
}
return
IntervalSet
(
r
->
min
,
r
->
extent
+
r
->
min
-
1
);
}
bool
IntSet
::
match_range
(
const
Range
&
b
)
const
{
const
IntSet
&
a
=
*
this
;
const
IntervalSetNode
*
a_int
=
a
.
as
<
IntervalSetNode
>
();
if
(
!
a_int
)
return
false
;
return
ProveEqual
(
a_int
->
min_value
,
b
->
min
)
&&
ProveEqual
(
a_int
->
max_value
,
b
->
extent
+
b
->
min
-
1
);
}
IntSet
Union
(
const
Array
<
IntSet
>&
sets
)
{
if
(
sets
.
size
()
==
0
)
return
IntSet
::
nothing
();
if
(
sets
.
size
()
==
1
)
return
sets
[
0
];
Analyzer
ana
;
IntervalSet
x
=
ToIntervalSet
(
sets
[
0
]);
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
x
=
Union
(
&
ana
,
x
,
ToIntervalSet
(
sets
[
i
]));
}
return
IntervalSet
(
ir
::
Simplify
(
x
->
min_value
),
ir
::
Simplify
(
x
->
max_value
));
}
IntSet
Intersect
(
const
Array
<
IntSet
>&
sets
)
{
if
(
sets
.
size
()
==
0
)
return
IntSet
::
nothing
();
if
(
sets
.
size
()
==
1
)
return
sets
[
0
];
Analyzer
ana
;
IntervalSet
x
=
ToIntervalSet
(
sets
[
0
]);
for
(
size_t
i
=
1
;
i
<
sets
.
size
();
++
i
)
{
x
=
Intersect
(
&
ana
,
x
,
ToIntervalSet
(
sets
[
i
]));
}
return
IntervalSet
(
ir
::
Simplify
(
x
->
min_value
),
ir
::
Simplify
(
x
->
max_value
));
}
Map
<
Var
,
IntSet
>
ConvertDomMap
(
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
Map
<
Var
,
IntSet
>
dmap
;
for
(
auto
kv
:
dom_map
)
{
dmap
.
Set
(
kv
.
first
->
var
,
kv
.
second
);
}
return
dmap
;
}
Map
<
Var
,
IntSet
>
ConvertDomMap
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
return
IntSetEvaluator
(
dom_map
,
false
).
Eval
(
e
);
Map
<
Var
,
IntSet
>
dmap
;
for
(
auto
kv
:
dom_map
)
{
dmap
.
Set
(
GetRef
<
Var
>
(
kv
.
first
),
kv
.
second
);
}
return
dmap
;
}
IntSet
EvalSet
(
Expr
e
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
{
Analyzer
ana
;
return
IntervalSetEvaluator
(
&
ana
,
dom_map
,
false
).
Eval
(
e
);
}
}
IntSet
IntSet
::
vector
(
Expr
x
)
{
IntSet
IntSet
::
vector
(
Expr
x
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dmap
;
Analyzer
ana
;
return
IntSetEvaluator
(
dmap
,
true
).
Eval
(
x
);
Map
<
Var
,
IntSet
>
dmap
;
return
IntervalSetEvaluator
(
&
ana
,
dmap
,
true
).
Eval
(
x
);
}
}
IntSet
EvalSet
(
Expr
e
,
IntSet
EvalSet
(
Expr
e
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dmap
;
return
EvalSet
(
e
,
ConvertDomMap
(
dom_map
));
for
(
auto
kv
:
dom_map
)
{
dmap
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
}
return
EvalSet
(
e
,
dmap
);
}
}
IntSet
EvalSet
(
Range
r
,
IntSet
EvalSet
(
Expr
e
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
IntSetEvaluator
m
(
dom_map
);
return
EvalSet
(
e
,
ConvertDomMap
(
dom_map
));
IntSet
min_set
=
m
.
Eval
(
r
->
min
).
cover_interval
();
}
IntSet
EvalSet
(
Range
r
,
const
Map
<
Var
,
IntSet
>&
dom_map
)
{
Analyzer
ana
;
IntervalSetEvaluator
m
(
&
ana
,
dom_map
);
IntervalSet
min_set
=
m
.
Eval
(
r
->
min
);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
// Simplifying first can give tighter bounds if r->min and r->extent share variables
Expr
sum
=
ComputeExpr
<
Sub
>
(
ComputeExpr
<
Add
>
(
r
->
min
,
r
->
extent
),
1
);
Expr
sum
=
r
->
min
+
r
->
extent
-
1
;
IntSet
max_set
=
m
.
Eval
(
Simplify
(
sum
)).
cover_interval
();
IntervalSet
max_set
=
m
.
Eval
(
Simplify
(
sum
));
const
Interval
&
ni
=
min_set
.
as
<
IntervalSet
>
()
->
i
;
if
(
!
min_set
->
HasLowerBound
())
return
IntSet
::
everything
();
const
Interval
&
xi
=
max_set
.
as
<
IntervalSet
>
()
->
i
;
if
(
!
max_set
->
HasUpperBound
())
return
IntSet
::
everything
();
if
(
!
ni
.
has_lower_bound
())
return
IntSet
::
everything
();
return
IntervalSet
(
min_set
->
min_value
,
max_set
->
max_value
);
if
(
!
xi
.
has_upper_bound
())
return
IntSet
::
everything
();
return
IntervalSet
::
make
(
ni
.
min
,
xi
.
max
);
}
}
IntSet
EvalSet
(
IntSet
s
,
IntSet
EvalSet
(
Range
r
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
IntSetEvaluator
m
(
dom_map
);
return
EvalSet
(
r
,
ConvertDomMap
(
dom_map
));
s
=
s
.
cover_interval
();
const
IntervalSet
*
s_int
=
s
.
as
<
IntervalSet
>
();
Expr
vmax
=
s_int
->
i
.
has_upper_bound
()
?
m
.
Eval
(
s_int
->
i
.
max
).
cover_interval
().
max
()
:
s_int
->
i
.
max
;
Expr
vmin
=
s_int
->
i
.
has_lower_bound
()
?
m
.
Eval
(
s_int
->
i
.
min
).
cover_interval
().
min
()
:
s_int
->
i
.
min
;
return
IntervalSet
::
make
(
vmin
,
vmax
);
}
}
class
SubExprIntSetEvaluator
:
public
IntSetEvaluator
{
IntSet
EvalSet
(
IntSet
s
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
Analyzer
ana
;
auto
dmap
=
ConvertDomMap
(
dom_map
);
IntervalSetEvaluator
m
(
&
ana
,
dmap
);
const
IntervalSetNode
*
s_int
=
s
.
as
<
IntervalSetNode
>
();
Expr
vmax
=
s_int
->
HasUpperBound
()
?
m
.
Eval
(
s_int
->
max_value
).
max
()
:
s_int
->
max_value
;
Expr
vmin
=
s_int
->
HasLowerBound
()
?
m
.
Eval
(
s_int
->
min_value
).
min
()
:
s_int
->
min_value
;
return
IntervalSet
(
vmin
,
vmax
);
}
class
SubExprIntervalSetEvaluator
:
public
IntervalSetEvaluator
{
public
:
public
:
explicit
SubExprIntSetEvaluator
(
explicit
SubExprIntervalSetEvaluator
(
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
Analyzer
*
analyzer
,
:
IntSetEvaluator
(
dom_map
)
{}
const
Map
<
Var
,
IntSet
>&
dom_map
)
:
IntervalSetEvaluator
(
analyzer
,
dom_map
)
{}
Int
Set
VisitExpr
(
const
Expr
&
n
,
const
Expr
&
e
)
final
{
Int
ervalSet
VisitExpr
(
const
Expr
&
n
)
final
{
Int
Set
ret
=
IntSetEvaluator
::
VisitExpr
(
n
,
e
);
Int
ervalSet
ret
=
IntervalSetEvaluator
::
VisitExpr
(
n
);
expr_map
[
n
]
=
ret
;
expr_map
[
n
]
=
ret
;
return
ret
;
return
ret
;
}
}
...
@@ -635,28 +705,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator {
...
@@ -635,28 +705,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator {
ExprIntSetMap
expr_map
;
ExprIntSetMap
expr_map
;
};
};
ExprIntSetMap
EvalSetForEachSubExpr
(
Expr
e
,
ExprIntSetMap
EvalSetForEachSubExpr
(
Expr
e
,
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
const
std
::
unordered_map
<
const
Variable
*
,
IntSet
>&
dom_map
)
{
SubExprIntSetEvaluator
m
(
dom_map
);
Analyzer
ana
;
auto
dmap
=
ConvertDomMap
(
dom_map
);
SubExprIntervalSetEvaluator
m
(
&
ana
,
dmap
);
m
.
Eval
(
e
);
m
.
Eval
(
e
);
return
m
.
expr_map
;
return
m
.
expr_map
;
}
}
IntSet
EvalSet
(
Range
r
,
IntSet
EvalSet
(
Range
r
,
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
const
Map
<
IterVar
,
IntSet
>&
dom_map
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dmap
;
return
EvalSet
(
r
,
ConvertDomMap
(
dom_map
));
for
(
auto
kv
:
dom_map
)
{
dmap
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
}
return
EvalSet
(
r
,
dmap
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IntervalSet
>
([](
const
IntervalSet
*
op
,
IRPrinter
*
p
)
{
.
set_dispatch
<
IntervalSet
Node
>
([](
const
IntervalSetNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"
interval-s
et"
p
->
stream
<<
"
IntervalS
et"
<<
"["
<<
op
->
i
.
min
<<
", "
<<
"["
<<
op
->
min_value
<<
", "
<<
op
->
i
.
max
<<
']'
;
<<
op
->
max_value
<<
']'
;
});
});
}
// namespace arith
}
// namespace arith
}
// namespace tvm
}
// namespace tvm
src/arithmetic/int_set.h
0 → 100644
View file @
153417a5
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_set.h
* \brief Internal data structure for integer set.
*/
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_
#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>
#include <limits>
#include "const_fold.h"
namespace
tvm
{
namespace
arith
{
/*!
* \brief Symbolic interval set.
*
* \note We intentionally keep the internal of IntSet private,
as we might change it later.
*/
class
IntervalSetNode
:
public
IntSetNode
{
public
:
/*! \brief Minimum value in the interval. */
Expr
min_value
;
/*! \brief Maximum value in the interval. */
Expr
max_value
;
// visitor overload.
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"min_value"
,
&
min_value
);
v
->
Visit
(
"max_value"
,
&
max_value
);
}
/*! \return Whether the interval has upper bound. */
bool
HasUpperBound
()
const
{
return
!
is_pos_inf
(
max_value
)
&&
!
IsEmpty
();
}
/*! \return Whether the interval has lower bound. */
bool
HasLowerBound
()
const
{
return
!
is_neg_inf
(
min_value
)
&&
!
IsEmpty
();
}
/*! \return Whether the interval is a single point. */
bool
IsSinglePoint
()
const
{
return
min_value
.
same_as
(
max_value
);
}
/*! \return whether interval represent nothing */
bool
IsEmpty
()
const
{
// during computations, either extreme could occur.
return
is_pos_inf
(
min_value
)
||
is_neg_inf
(
max_value
);
}
/*! \return whether interval represent everything */
bool
IsEverything
()
const
{
return
is_neg_inf
(
min_value
)
&&
is_pos_inf
(
max_value
);
}
static
constexpr
const
char
*
_type_key
=
"arith.IntervalSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntervalSetNode
,
IntSetNode
);
};
/*!
* \brief Interval set used for symbolic integer analysis.
* \sa IntervalSetNode
*/
class
IntervalSet
:
public
IntSet
{
public
:
/*!
* \brief Make a new instance of interval set.
* \param min_value The minimum value in the interval.
* \param max_value The maximum value in the interval.
* \return The created set.
*/
TVM_DLL
IntervalSet
(
Expr
min_value
,
Expr
max_value
);
/*!
* \brief Create an IntervalSet that represents a single point.
* \param value The value to be represented.
* \return The result set.
*/
static
IntervalSet
SinglePoint
(
Expr
value
)
{
return
IntervalSet
(
value
,
value
);
}
/*!
* \brief Create an IntervalSet that represents everything.
* \param value The value to be represented.
* \return The result set.
*/
static
IntervalSet
Everything
()
{
return
IntervalSet
(
neg_inf
(),
pos_inf
());
}
/*!
* \brief Create an empty eet.
* \return The result set.
*/
static
IntervalSet
Empty
()
{
return
IntervalSet
(
pos_inf
(),
neg_inf
());
}
TVM_DEFINE_NODE_REF_COW
(
IntervalSetNode
);
TVM_DEFINE_NODE_REF_METHODS
(
IntervalSet
,
IntSet
,
IntervalSetNode
);
};
/*!
* \brief Create union of two IntervalSets.
* \param analyzer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL
IntervalSet
Union
(
Analyzer
*
analyzer
,
IntervalSet
a
,
IntervalSet
b
);
/*!
* \brief Create insersection of two IntervalSets.
* \param analzyer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL
IntervalSet
Intersect
(
Analyzer
*
analzyer
,
IntervalSet
a
,
IntervalSet
b
);
}
// namespace arith
}
// namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_H_
src/arithmetic/int_set_internal.h
deleted
100644 → 0
View file @
9bb16872
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file int_set_internal.h
* \brief Implementations of integer set
*/
#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
namespace
tvm
{
namespace
arith
{
using
HalideIR
::
Internal
::
Interval
;
/*! \brief Set of continuous interval */
struct
IntervalSet
:
public
IntSetNode
{
/*! \brief the internal interval*/
Interval
i
;
static
IntSet
make
(
Interval
i
)
{
NodePtr
<
IntervalSet
>
n
=
make_node
<
IntervalSet
>
();
n
->
i
=
i
;
return
IntSet
(
n
);
}
static
IntSet
make
(
Expr
min
,
Expr
max
)
{
NodePtr
<
IntervalSet
>
n
=
make_node
<
IntervalSet
>
();
n
->
i
.
min
=
min
;
n
->
i
.
max
=
max
;
return
IntSet
(
n
);
}
static
constexpr
const
char
*
_type_key
=
"IntervalSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntervalSet
,
IntSetNode
);
};
/*!
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
struct
StrideSet
:
public
IntSetNode
{
/*! \brief the base inetrval */
Interval
base
;
/*! \brief additional extents in positive number */
Array
<
Expr
>
extents
;
/*! \brief additional strides in positive number */
Array
<
Expr
>
strides
;
static
constexpr
const
char
*
_type_key
=
"StrideSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
StrideSet
,
IntSetNode
);
};
}
// namespace arith
}
// namespace tvm
#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_
src/lang/expr_operator.cc
View file @
153417a5
...
@@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) {
...
@@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) {
return
ir
::
Mod
::
make
(
a
,
b
);
return
ir
::
Mod
::
make
(
a
,
b
);
}
}
Expr
min
(
Expr
a
,
Expr
b
)
{
Expr
min
(
Expr
a
,
Expr
b
)
{
// inf-aware simplificaiton
using
arith
::
is_pos_inf
;
using
arith
::
is_neg_inf
;
if
(
is_pos_inf
(
a
))
return
b
;
if
(
is_neg_inf
(
a
))
return
a
;
if
(
is_pos_inf
(
b
))
return
a
;
if
(
is_neg_inf
(
b
))
return
b
;
BinaryOpMatchTypes
(
a
,
b
);
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Min
>
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Min
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
if
(
ret
.
defined
())
return
ret
;
...
@@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) {
...
@@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) {
}
}
Expr
max
(
Expr
a
,
Expr
b
)
{
Expr
max
(
Expr
a
,
Expr
b
)
{
// inf-aware simplificaiton
using
arith
::
is_pos_inf
;
using
arith
::
is_neg_inf
;
if
(
is_pos_inf
(
a
))
return
a
;
if
(
is_neg_inf
(
a
))
return
b
;
if
(
is_pos_inf
(
b
))
return
b
;
if
(
is_neg_inf
(
b
))
return
a
;
BinaryOpMatchTypes
(
a
,
b
);
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Max
>
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Max
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
if
(
ret
.
defined
())
return
ret
;
...
...
src/pass/loop_partition.cc
View file @
153417a5
...
@@ -28,7 +28,7 @@
...
@@ -28,7 +28,7 @@
#include <tvm/arithmetic.h>
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "../arithmetic/int_set
_internal
.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
#include "../runtime/thread_storage_scope.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator {
...
@@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator {
std
::
pair
<
IntSet
,
std
::
unordered_set
<
const
Node
*>>
std
::
pair
<
IntSet
,
std
::
unordered_set
<
const
Node
*>>
GetIntervalAndCondset
(
const
Partition
&
partitions
,
GetIntervalAndCondset
(
const
Partition
&
partitions
,
const
arith
::
Interval
&
for_interval
,
const
arith
::
Interval
Set
&
for_interval
,
bool
cond_value
);
bool
cond_value
);
inline
Stmt
MakeFor
(
const
Node
*
op
,
Expr
extent
,
Stmt
body
);
inline
Stmt
MakeFor
(
const
Node
*
op
,
Expr
extent
,
Stmt
body
);
...
@@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator {
...
@@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator {
/* Candidate IRs that may be partitioned potentially */
/* Candidate IRs that may be partitioned potentially */
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
hint_map_
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
hint_map_
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
relax_map_
;
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
relax_map_
;
arith
::
Analyzer
analyzer_
;
CandidateSelector
selector
;
CandidateSelector
selector
;
};
};
...
@@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator {
...
@@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator {
// given in the second component provably have value given by cond_value
// given in the second component provably have value given by cond_value
std
::
pair
<
IntSet
,
std
::
unordered_set
<
const
Node
*>>
std
::
pair
<
IntSet
,
std
::
unordered_set
<
const
Node
*>>
LoopPartitioner
::
GetIntervalAndCondset
(
const
Partition
&
partitions
,
LoopPartitioner
::
GetIntervalAndCondset
(
const
Partition
&
partitions
,
const
arith
::
Interval
&
for_interval
,
const
arith
::
Interval
Set
&
for_interval
,
bool
cond_value
)
{
bool
cond_value
)
{
Array
<
IntSet
>
sets
;
Array
<
IntSet
>
sets
;
std
::
unordered_set
<
const
Node
*>
cond_set
;
std
::
unordered_set
<
const
Node
*>
cond_set
;
for
(
const
auto
&
kv
:
partitions
)
{
for
(
const
auto
&
kv
:
partitions
)
{
if
(
kv
.
first
.
second
==
cond_value
)
{
if
(
kv
.
first
.
second
==
cond_value
)
{
arith
::
Interval
interval
=
kv
.
second
.
as
<
arith
::
IntervalSet
>
()
->
i
;
arith
::
IntervalSet
interval
=
Downcast
<
arith
::
IntervalSet
>
(
kv
.
second
);
arith
::
Interval
intersection
=
arith
::
Interval
::
make_intersection
(
interval
,
for_interval
);
arith
::
IntervalSet
intersection
=
arith
::
Intersect
(
if
(
!
intersection
.
is_empty
())
{
&
analyzer_
,
interval
,
for_interval
);
if
(
!
intersection
->
IsEmpty
())
{
sets
.
push_back
(
kv
.
second
);
sets
.
push_back
(
kv
.
second
);
cond_set
.
insert
(
kv
.
first
.
first
);
cond_set
.
insert
(
kv
.
first
.
first
);
}
}
...
@@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
max
,
Expr
max
,
Stmt
body
,
Stmt
body
,
bool
partition_thread_scope
)
{
bool
partition_thread_scope
)
{
using
namespace
arith
;
PartitionFinder
finder
(
var
,
hint_map_
,
relax_map_
);
PartitionFinder
finder
(
var
,
hint_map_
,
relax_map_
);
finder
.
Visit
(
body
);
finder
.
Visit
(
body
);
if
(
finder
.
partitions
.
empty
())
return
Stmt
();
if
(
finder
.
partitions
.
empty
())
return
Stmt
();
arith
::
Interval
for_interval
(
min
,
max
);
arith
::
Interval
Set
for_interval
(
min
,
max
);
bool
cond_value
;
bool
cond_value
;
IntSet
middle_interval
;
IntSet
middle_interval
;
std
::
unordered_set
<
const
Node
*>
cond_set
;
std
::
unordered_set
<
const
Node
*>
cond_set
;
...
@@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
cond_value
=
true
;
cond_value
=
true
;
}
}
arith
::
Interval
middle_interval_i
=
middle_interval
.
as
<
arith
::
IntervalSet
>
()
->
i
;
IntervalSet
middle_interval_i
=
Downcast
<
IntervalSet
>
(
middle_interval
)
;
// middle_interval is the subrange of the loop variable range for which a
// middle_interval is the subrange of the loop variable range for which a
// set of conditions are true (or false resp.)
// set of conditions are true (or false resp.)
// The part of the loop variable range that is before (after resp.) that
// The part of the loop variable range that is before (after resp.) that
...
@@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
body_begin
;
Expr
body_begin
;
Stmt
pre_stmt
;
Stmt
pre_stmt
;
bool
pre_stmt_recurse
=
true
;
bool
pre_stmt_recurse
=
true
;
if
(
middle_interval_i
.
has_lower_b
ound
())
{
if
(
middle_interval_i
->
HasLowerB
ound
())
{
body_begin
=
ir
::
Simplify
(
middle_interval
.
min
());
body_begin
=
ir
::
Simplify
(
middle_interval
.
min
());
if
(
!
can_prove
(
body_begin
==
min
))
{
if
(
!
can_prove
(
body_begin
==
min
))
{
Expr
cond
=
(
body_begin
-
min
>=
0
);
Expr
cond
=
(
body_begin
-
min
>=
0
);
...
@@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
...
@@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Expr
post_doubt_begin
;
Expr
post_doubt_begin
;
Stmt
post_stmt
;
Stmt
post_stmt
;
bool
post_stmt_recurse
=
true
;
bool
post_stmt_recurse
=
true
;
if
(
middle_interval_i
.
has_upper_b
ound
())
{
if
(
middle_interval_i
->
HasUpperB
ound
())
{
post_doubt_begin
=
ir
::
Simplify
(
middle_interval
.
max
()
+
1
);
post_doubt_begin
=
ir
::
Simplify
(
middle_interval
.
max
()
+
1
);
if
(
!
can_prove
(
middle_interval
.
max
()
==
max
))
{
if
(
!
can_prove
(
middle_interval
.
max
()
==
max
))
{
// require the extent to be non-negative
// require the extent to be non-negative
...
...
tests/python/unittest/test_arith_deduce_bound.py
0 → 100644
View file @
153417a5
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
def
test_deduce
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
c
=
tvm
.
var
(
'c'
)
d
=
tvm
.
var
(
'd'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
2
,
3
)
c_s
=
tvm
.
arith
.
IntervalSet
(
10
,
15
)
d_s
=
tvm
.
arith
.
IntervalSet
(
-
3
,
-
1
)
zero
=
tvm
.
const
(
0
,
"int32"
)
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
d
-
c
)
/
(
b
*-
1
))
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
e0
=
d
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
0
-
c
)
/
d
+
1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
# expression containing variable a is on rhs
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max_value
))
==
str
(
ans0
)
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
# expression containing variable a is on rhs
e1
=
(
c
>
a
*
4
+
b
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max_value
))
==
str
(
ans1
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
res2
.
max_value
)
==
"neg_inf"
assert
str
(
res2
.
min_value
)
==
"pos_inf"
# expression containing variable a is on rhs
e2
=
(
zero
<
tvm
.
max
(
5
,
a
*
4
))
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
res2
.
max_value
)
==
"neg_inf"
assert
str
(
res2
.
min_value
)
==
"pos_inf"
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min_value
))
==
str
(
ans3
)
def
test_check
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
c
=
tvm
.
var
(
'c'
)
d
=
tvm
.
var
(
'd'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
2
,
3
)
c_s
=
tvm
.
arith
.
IntervalSet
(
5
,
7
)
d_s
=
tvm
.
arith
.
IntervalSet
(
-
3
,
-
1
)
# no compare operator
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
a
+
b
,
{
b
:
b_s
},
{})
assert
res1
.
is_nothing
()
# multiple compare operators
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
(
a
+
b
>
3
)
.
astype
(
c
.
dtype
)
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
assert
res2
.
is_nothing
()
# multiple target variable
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
a
*
2
-
a
>
b
,
{
b
:
b_s
},
{})
assert
res2
.
is_nothing
()
def
test_deduce_basic
():
def
test_basic
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
a1
,
a2
)
e0
=
b
+
a
*
coff
+
3
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
<
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<=
17
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>=
17
))
.
value
==
1
test_basic
(
0
,
4
,
4
)
test_basic
(
1
,
5
,
4
)
test_basic
(
2
,
6
,
4
)
test_basic
(
0
,
4
,
-
4
)
test_basic
(
1
,
5
,
-
4
)
test_basic
(
2
,
6
,
-
4
)
def
test_deduce_complex
():
def
test_complex
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
IntervalSet
(
a1
,
a2
)
e0
=
(
b
*
3
+
a
*
coff
)
*
4
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
>
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<=
63
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
<=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max_value
,
b_s
.
max_value
]
if
coff
<
0
else
[
res1
.
min_value
,
b_s
.
min_value
]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>=
63
))
.
value
==
1
test_complex
(
0
,
4
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
2
,
6
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
1
,
5
,
-
4
)
test_complex
(
2
,
6
,
-
4
)
if
__name__
==
"__main__"
:
test_check
()
test_deduce_basic
()
test_deduce_complex
()
tests/python/unittest/test_arith_intset.py
View file @
153417a5
...
@@ -16,168 +16,87 @@
...
@@ -16,168 +16,87 @@
# under the License.
# under the License.
import
tvm
import
tvm
class
IntSetChecker
:
def
__init__
(
self
):
self
.
analyzer
=
tvm
.
arith
.
Analyzer
()
def
verify
(
self
,
data
,
dmap
,
expected
):
res
=
self
.
analyzer
.
int_set
(
data
,
dmap
)
def
err_msg
():
return
"
\n
data={}
\n
dmap={}
\n
res={}
\n
expected={}"
.
format
(
data
,
dmap
,
res
,
expected
)
def
equal
(
x
,
y
):
res
=
self
.
analyzer
.
canonical_simplify
(
x
-
y
)
return
tvm
.
ir_pass
.
Equal
(
res
,
0
)
assert
equal
(
res
.
min_value
,
expected
[
0
]),
err_msg
()
assert
equal
(
res
.
max_value
,
expected
[
1
]),
err_msg
()
def
test_basic
():
def
test_basic
():
s
=
tvm
.
arith
.
intset_interval
(
2
,
3
)
s
=
tvm
.
arith
.
IntervalSet
(
2
,
3
)
assert
s
.
min
()
.
value
==
2
assert
s
.
min_value
.
value
==
2
assert
s
.
max
()
.
value
==
3
assert
s
.
max_value
.
value
==
3
def
test_vector
():
def
test_vector
():
base
=
10
base
=
10
stride
=
3
stride
=
3
lanes
=
2
lanes
=
2
s
=
tvm
.
arith
.
intset_vector
(
tvm
.
make
.
Ramp
(
base
,
stride
,
lanes
))
s
=
tvm
.
arith
.
intset_vector
(
tvm
.
make
.
Ramp
(
base
,
stride
,
lanes
))
assert
s
.
min
()
.
value
==
base
assert
s
.
min_value
.
value
==
base
assert
s
.
max
()
.
value
==
base
+
stride
*
lanes
-
1
assert
s
.
max_value
.
value
==
base
+
stride
*
lanes
-
1
def
test_deduce
():
a
=
tvm
.
var
(
'a'
)
def
test_add_sub
():
b
=
tvm
.
var
(
'b'
)
ck
=
IntSetChecker
()
c
=
tvm
.
var
(
'c'
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
d
=
tvm
.
var
(
'd'
)
ck
.
verify
(
x
+
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
y
,
10
+
y
))
ck
.
verify
(
x
+
y
,
b_s
=
tvm
.
arith
.
intset_interval
(
2
,
3
)
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
),
y
:
tvm
.
arith
.
IntervalSet
(
1
,
11
)},
c_s
=
tvm
.
arith
.
intset_interval
(
10
,
15
)
(
1
,
21
))
d_s
=
tvm
.
arith
.
intset_interval
(
-
3
,
-
1
)
ck
.
verify
(
x
-
y
,
zero
=
tvm
.
const
(
0
,
"int32"
)
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
),
y
:
tvm
.
arith
.
IntervalSet
(
1
,
11
)},
(
-
11
,
9
))
e0
=
(
-
b
)
*
a
+
c
-
d
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
def
test_mul_div
():
ans0
=
((
d
-
c
)
/
(
b
*-
1
))
ck
=
IntSetChecker
()
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
# expression containing variable a is on rhs
ck
.
verify
(
x
*
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
*
y
))
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ck
.
verify
(
x
*
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
2
,
20
))
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
ck
.
verify
(
x
*
-
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
-
20
,
-
2
))
ck
.
verify
(
x
/
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
10
/
y
))
e0
=
d
*
a
+
c
-
d
ck
.
verify
(
x
/
2
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
5
))
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ans0
=
((
0
-
c
)
/
d
+
1
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
def
test_mod
():
ck
=
IntSetChecker
()
# expression containing variable a is on rhs
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
res0
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ck
.
analyzer
.
update
(
y
,
tvm
.
arith
.
ConstIntBound
(
1
,
100
),
override
=
True
)
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res0
.
max
()))
==
str
(
ans0
)
ck
.
verify
(
x
%
y
,
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
0
,
y
-
1
))
ck
.
verify
(
x
%
10
,
{
x
:
tvm
.
arith
.
IntervalSet
(
1
,
10
)},
(
0
,
9
))
e1
=
(
a
*
4
+
b
<
c
)
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
def
test_max_min
():
ans1
=
(((
c
-
b
)
+
-
1
)
/
4
)
ck
=
IntSetChecker
()
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max
()))
==
str
(
ans1
)
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
ck
.
verify
(
tvm
.
max
(
x
,
x
+
1
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
1
,
11
))
# expression containing variable a is on rhs
ck
.
verify
(
tvm
.
min
(
x
-
1
,
x
+
1
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
-
1
,
9
))
e1
=
(
c
>
a
*
4
+
b
)
ck
.
verify
(
tvm
.
min
(
x
,
y
),
{},
(
tvm
.
min
(
x
,
y
),
tvm
.
min
(
x
,
y
)))
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e1
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ck
.
verify
(
tvm
.
max
(
x
,
y
),
{},
(
tvm
.
max
(
x
,
y
),
tvm
.
max
(
x
,
y
)))
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res1
.
max
()))
==
str
(
ans1
)
e2
=
(
tvm
.
max
(
5
,
a
*
4
)
<
0
)
def
test_select
():
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
ck
=
IntSetChecker
()
assert
str
(
res2
.
max
())
==
"neg_inf"
x
,
y
=
tvm
.
var
(
"x"
),
tvm
.
var
(
"y"
)
assert
str
(
res2
.
min
())
==
"pos_inf"
ck
.
verify
(
tvm
.
expr
.
Select
(
x
>
0
,
x
-
1
,
x
+
1
),
{
x
:
tvm
.
arith
.
IntervalSet
(
0
,
10
)},
(
-
1
,
11
))
# expression containing variable a is on rhs
e2
=
(
zero
<
tvm
.
max
(
5
,
a
*
4
))
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
e2
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{})
assert
str
(
res2
.
max
())
==
"neg_inf"
assert
str
(
res2
.
min
())
==
"pos_inf"
e3
=
(
-
b
)
+
a
*
c
-
d
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
e3
>=
0
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
ans3
=
2
/
c
+
1
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min
()))
==
str
(
ans3
)
res3
=
tvm
.
arith
.
DeduceBound
(
a
,
zero
<=
e3
,
{
b
:
b_s
,
c
:
c_s
,
d
:
d_s
},
{
b
:
b_s
,
d
:
d_s
})
assert
str
(
tvm
.
ir_pass
.
Simplify
(
res3
.
min
()))
==
str
(
ans3
)
def
test_check
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
c
=
tvm
.
var
(
'c'
)
d
=
tvm
.
var
(
'd'
)
b_s
=
tvm
.
arith
.
intset_interval
(
2
,
3
)
c_s
=
tvm
.
arith
.
intset_interval
(
5
,
7
)
d_s
=
tvm
.
arith
.
intset_interval
(
-
3
,
-
1
)
# no compare operator
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
a
+
b
,
{
b
:
b_s
},
{})
assert
res1
.
is_nothing
()
# multiple compare operators
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
(
a
+
b
>
3
)
.
astype
(
c
.
dtype
)
>
c
,
{
b
:
b_s
,
c
:
c_s
},
{})
assert
res2
.
is_nothing
()
# multiple target variable
res2
=
tvm
.
arith
.
DeduceBound
(
a
,
a
*
2
-
a
>
b
,
{
b
:
b_s
},
{})
assert
res2
.
is_nothing
()
def
test_deduce_basic
():
def
test_basic
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
intset_interval
(
a1
,
a2
)
e0
=
b
+
a
*
coff
+
3
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
<
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>
17
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
17
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
<=
17
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>=
17
,
{
b
:
b_s
},
{
b
:
b_s
})
[
x
,
y
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
((
x
*
coff
+
3
+
y
)
>=
17
))
.
value
==
1
test_basic
(
0
,
4
,
4
)
test_basic
(
1
,
5
,
4
)
test_basic
(
2
,
6
,
4
)
test_basic
(
0
,
4
,
-
4
)
test_basic
(
1
,
5
,
-
4
)
test_basic
(
2
,
6
,
-
4
)
def
test_deduce_complex
():
def
test_complex
(
a1
,
a2
,
coff
):
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
b_s
=
tvm
.
arith
.
intset_interval
(
a1
,
a2
)
e0
=
(
b
*
3
+
a
*
coff
)
*
4
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
<
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
>=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
>
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
<=
63
))
.
value
==
1
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
e0
>
63
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>
63
))
.
value
==
1
# expression containing variable a is on rhs
res1
=
tvm
.
arith
.
DeduceBound
(
a
,
tvm
.
const
(
63
,
"int32"
)
<=
e0
,
{
b
:
b_s
},
{
b
:
b_s
})
[
t
,
x
]
=
[
res1
.
max
(),
b_s
.
max
()]
if
coff
<
0
else
[
res1
.
min
(),
b_s
.
min
()]
assert
(
tvm
.
ir_pass
.
Simplify
(((
x
*
3
+
t
*
coff
)
*
4
)
>=
63
))
.
value
==
1
test_complex
(
0
,
4
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
2
,
6
,
4
)
test_complex
(
0
,
4
,
-
4
)
test_complex
(
1
,
5
,
-
4
)
test_complex
(
2
,
6
,
-
4
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_basic
()
test_basic
()
test_vector
()
test_vector
()
test_deduce
()
test_add_sub
()
test_check
()
test_mul_div
()
test_deduce_basic
()
test_max_min
()
test_deduce_complex
()
test_select
()
test_mod
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment