Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
39c116f0
Commit
39c116f0
authored
Mar 31, 2019
by
Siyuan Feng
Committed by
Lianmin Zheng
Mar 30, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix intersect of modular set (#2904)
Fix comment bugs and code style
parent
fb7fa8e4
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
113 additions
and
62 deletions
+113
-62
src/arithmetic/modular_set.cc
+96
-62
tests/python/unittest/test_arith_modular_set.py
+17
-0
No files found.
src/arithmetic/modular_set.cc
View file @
39c116f0
...
@@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry {
...
@@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry {
int64_t
coeff
{
1
};
int64_t
coeff
{
1
};
int64_t
base
{
0
};
int64_t
base
{
0
};
Entry
()
=
default
;
Entry
(
int64_t
coeff
,
int64_t
base
)
{
CHECK_GE
(
coeff
,
0
);
this
->
coeff
=
coeff
;
if
(
coeff
!=
0
)
{
base
=
base
%
coeff
;
if
(
base
<
0
)
base
+=
coeff
;
}
this
->
base
=
base
;
}
bool
is_const
()
const
{
bool
is_const
()
const
{
return
coeff
==
0
;
return
coeff
==
0
;
}
}
...
@@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl :
...
@@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl :
if
(
!
override
)
{
if
(
!
override
)
{
CHECK
(
!
var_map_
.
count
(
var
));
CHECK
(
!
var_map_
.
count
(
var
));
}
}
Entry
e
;
var_map_
[
var
]
=
Entry
(
info
->
coeff
,
info
->
base
);
e
.
coeff
=
info
->
coeff
;
e
.
base
=
info
->
base
;
var_map_
[
var
]
=
e
;
}
}
// Detect useful constraints and use them in the analysis scope.
// Detect useful constraints and use them in the analysis scope.
...
@@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl :
...
@@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl :
PVar
<
Integer
>
coeff
,
base
;
PVar
<
Integer
>
coeff
,
base
;
// pattern match interesting constraints
// pattern match interesting constraints
if
(((
var
%
coeff
)
==
base
).
Match
(
constraint
))
{
if
(((
var
%
coeff
)
==
base
).
Match
(
constraint
))
{
Entry
entry
;
Entry
entry
(
coeff
.
Eval
()
->
value
,
base
.
Eval
()
->
value
);
entry
.
coeff
=
coeff
.
Eval
()
->
value
;
entry
.
base
=
base
.
Eval
()
->
value
;
return
UpdateByIntersect
(
var
.
Eval
(),
entry
);
return
UpdateByIntersect
(
var
.
Eval
(),
entry
);
}
}
return
nullptr
;
return
nullptr
;
...
@@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl :
...
@@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl :
}
}
Entry
VisitExpr_
(
const
IntImm
*
op
)
final
{
Entry
VisitExpr_
(
const
IntImm
*
op
)
final
{
Entry
ret
;
return
Entry
(
0
,
op
->
value
);
ret
.
base
=
op
->
value
;
ret
.
coeff
=
0
;
return
ret
;
}
}
Entry
VisitExpr_
(
const
UIntImm
*
op
)
final
{
Entry
VisitExpr_
(
const
UIntImm
*
op
)
final
{
if
(
op
->
value
<
std
::
numeric_limits
<
int64_t
>::
max
())
{
if
(
op
->
value
<
std
::
numeric_limits
<
int64_t
>::
max
())
{
Entry
ret
;
return
Entry
(
0
,
static_cast
<
int
>
(
op
->
value
));
ret
.
base
=
static_cast
<
int
>
(
op
->
value
);
ret
.
coeff
=
0
;
return
ret
;
}
else
{
}
else
{
return
Everything
();
return
Everything
();
}
}
...
@@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl :
...
@@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl :
Entry
VisitExpr_
(
const
Add
*
op
)
final
{
Entry
VisitExpr_
(
const
Add
*
op
)
final
{
Entry
a
=
VisitExpr
(
op
->
a
);
Entry
a
=
VisitExpr
(
op
->
a
);
Entry
b
=
VisitExpr
(
op
->
b
);
Entry
b
=
VisitExpr
(
op
->
b
);
Entry
ret
;
int64_t
coeff
=
ZeroAwareGCD
(
a
.
coeff
,
b
.
coeff
);
ret
.
coeff
=
ZeroAwareGCD
(
a
.
coeff
,
b
.
coeff
);
return
Entry
(
coeff
,
a
.
base
+
b
.
base
);
ret
.
base
=
BaseSimplify
(
a
.
base
+
b
.
base
,
ret
.
coeff
);
return
ret
;
}
}
Entry
VisitExpr_
(
const
Sub
*
op
)
final
{
Entry
VisitExpr_
(
const
Sub
*
op
)
final
{
Entry
a
=
VisitExpr
(
op
->
a
);
Entry
a
=
VisitExpr
(
op
->
a
);
Entry
b
=
VisitExpr
(
op
->
b
);
Entry
b
=
VisitExpr
(
op
->
b
);
Entry
ret
;
int64_t
coeff
=
ZeroAwareGCD
(
a
.
coeff
,
b
.
coeff
);
ret
.
coeff
=
ZeroAwareGCD
(
a
.
coeff
,
b
.
coeff
);
return
Entry
(
coeff
,
a
.
base
-
b
.
base
);
ret
.
base
=
BaseSimplify
(
a
.
base
-
b
.
base
,
ret
.
coeff
);
return
ret
;
}
}
Entry
VisitExpr_
(
const
Mul
*
op
)
final
{
Entry
VisitExpr_
(
const
Mul
*
op
)
final
{
...
@@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl :
...
@@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl :
int64_t
pq
=
a
.
coeff
*
b
.
coeff
;
int64_t
pq
=
a
.
coeff
*
b
.
coeff
;
int64_t
pm
=
a
.
coeff
*
b
.
base
;
int64_t
pm
=
a
.
coeff
*
b
.
base
;
int64_t
qn
=
a
.
base
*
b
.
coeff
;
int64_t
qn
=
a
.
base
*
b
.
coeff
;
Entry
ret
;
int64_t
coeff
=
ZeroAwareGCD
(
pq
,
ZeroAwareGCD
(
pm
,
qn
));
ret
.
coeff
=
ZeroAwareGCD
(
pq
,
ZeroAwareGCD
(
pm
,
qn
));
return
Entry
(
coeff
,
a
.
base
*
b
.
base
);
ret
.
base
=
BaseSimplify
(
a
.
base
*
b
.
base
,
ret
.
coeff
);
return
ret
;
}
}
Entry
DivByConst
(
const
Expr
&
lhs
,
Entry
DivByConst
(
const
Expr
&
lhs
,
...
@@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl :
...
@@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl :
Entry
a
=
VisitExpr
(
lhs
);
Entry
a
=
VisitExpr
(
lhs
);
CHECK_NE
(
val
,
0
);
CHECK_NE
(
val
,
0
);
if
(
a
.
coeff
%
val
==
0
)
{
if
(
a
.
coeff
%
val
==
0
)
{
Entry
ret
;
if
(
a
.
base
==
0
)
{
if
(
a
.
base
==
0
)
{
// a c x / c -> a x
// a c x / c -> a x
ret
.
coeff
=
std
::
abs
(
a
.
coeff
/
val
);
return
Entry
(
std
::
abs
(
a
.
coeff
/
val
),
0
);
ret
.
base
=
0
;
return
ret
;
}
}
// positive division have a clear rounding mode.
// positive division have a clear rounding mode.
// Only handle case where we clearly know we need to round down.
// Only handle case where we clearly know we need to round down.
if
(
a
.
base
>
0
&&
val
>
0
&&
if
(
a
.
base
>
0
&&
val
>
0
&&
(
round_down
||
parent_
->
CanProveGreaterEqual
(
lhs
,
0
)))
{
(
round_down
||
parent_
->
CanProveGreaterEqual
(
lhs
,
0
)))
{
ret
.
coeff
=
a
.
coeff
/
val
;
return
Entry
(
a
.
coeff
/
val
,
a
.
base
/
val
);
ret
.
base
=
a
.
base
/
val
;
return
ret
;
}
}
}
}
return
Everything
();
return
Everything
();
...
@@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl :
...
@@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl :
}
}
int64_t
base0
=
a
.
base
%
coeff
;
int64_t
base0
=
a
.
base
%
coeff
;
int64_t
base1
=
b
.
base
%
coeff
;
int64_t
base1
=
b
.
base
%
coeff
;
Entry
ret
;
if
(
base0
==
base1
)
{
if
(
base0
==
base1
)
{
ret
.
coeff
=
coeff
;
return
Entry
(
coeff
,
base0
);
ret
.
base
=
base0
;
return
ret
;
}
else
{
}
else
{
ret
.
coeff
=
ZeroAwareGCD
(
ZeroAwareGCD
(
base0
,
base1
),
coeff
);
return
Entry
(
ZeroAwareGCD
(
ZeroAwareGCD
(
base0
,
base1
),
coeff
),
base0
);
ret
.
base
=
0
;
return
ret
;
}
}
}
}
/*!
/*!
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient.
* \param b The second coefficient.
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
static
int64_t
ExtendedEuclidean
(
int64_t
a
,
int64_t
b
,
int64_t
*
x
,
int64_t
*
y
)
{
// Extended Euclidean algorithm
// if a < 0, the problem can be convert into
// |a|* (-x) + b * y = gcd(|a|, b)
//
// initial condition:
// a * 0 + b * 1 = b
// a * 1 + b * 0 = a
int64_t
s
=
0
,
old_s
=
1
;
int64_t
r
=
b
,
old_r
=
a
>=
0
?
a
:
-
a
;
// Iteration (r2 < r1):
// a * x1 + b * y1 = r1
// a * x2 + b * y2 = r2
// The above two eqs can derive the following eq (q = r1 / r2)
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
// Because r3 < r2, the iteration can eventually terminate
while
(
r
!=
0
)
{
int64_t
q
=
old_r
/
r
;
int64_t
tmp
=
old_r
;
old_r
=
r
;
r
=
tmp
-
q
*
r
;
tmp
=
old_s
;
old_s
=
s
;
s
=
tmp
-
q
*
s
;
}
*
x
=
a
>=
0
?
old_s
:
-
old_s
;
if
(
b
!=
0
)
{
*
y
=
(
old_r
-
(
*
x
)
*
a
)
/
b
;
}
else
{
*
y
=
1
;
}
return
old_r
;
}
/*!
* \brief Create interect of two sets.
* \brief Create interect of two sets.
* \param a The left operand.
* \param a The left operand.
* \param b the right operand.
* \param b the right operand.
*/
*/
static
Entry
Intersect
(
Entry
a
,
Entry
b
)
{
static
Entry
Intersect
(
Entry
a
,
Entry
b
)
{
// simple rule for now: pick higher constraints.
int64_t
x
,
y
;
// TODO(team-team): Use extended euclidean algorithm.
int64_t
c1
=
a
.
coeff
,
b1
=
a
.
base
,
c2
=
b
.
coeff
,
b2
=
b
.
base
;
if
(
a
.
coeff
==
0
)
return
a
;
// z = c1 * p + b1
if
(
b
.
coeff
==
0
)
return
b
;
// z = c2 * q + b2
if
(
a
.
coeff
>=
b
.
coeff
)
return
a
;
// c1 * x + c2 * y = gcd(c1, c2)
return
b
;
// -> c1 * p - c2 * q = b2 - b1
}
// -> p = (b2 - b1) / gcd * x
/*!
// -> q = (b2 - b1) / gcd * (-y)
* \brief Simplify base so that it is in [0, coeff) when coeff != 0.
// -> z = LCM(x, y) * k + (c1 * p + b1)
* \param base The base value.
int64_t
gcd
=
ExtendedEuclidean
(
c1
,
c2
,
&
x
,
&
y
);
* \param coeff The coeff value.
int64_t
v
=
b2
-
b1
;
* \return The simplified base.
if
(
v
%
gcd
==
0
)
{
*/
x
=
v
/
gcd
*
x
;
static
int64_t
BaseSimplify
(
int64_t
base
,
int64_t
coeff
)
{
y
=
v
/
gcd
*
(
-
y
);
if
(
coeff
==
0
)
return
base
;
int64_t
coeff
=
c1
/
gcd
*
c2
;
base
=
base
%
coeff
;
return
Entry
(
coeff
,
x
*
c1
+
b1
);
if
(
base
<
0
)
base
+=
coeff
;
}
else
{
return
base
;
return
Nothing
();
}
}
}
/*!
/*!
* \brief Take GCD of a and b.
* \brief Take GCD of a and b.
...
@@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl :
...
@@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl :
* \return Bound that represent everything dtype can represent.
* \return Bound that represent everything dtype can represent.
*/
*/
static
Entry
Everything
()
{
static
Entry
Everything
()
{
Entry
ret
;
return
Entry
(
1
,
0
);
ret
.
coeff
=
1
;
ret
.
base
=
0
;
}
return
ret
;
/*!
* \brief return an empty set
* \return Bound that represent everything dtype can represent.
*/
static
Entry
Nothing
()
{
return
Entry
(
0
,
1
);
}
}
};
};
...
...
tests/python/unittest/test_arith_modular_set.py
View file @
39c116f0
...
@@ -117,6 +117,22 @@ def test_constraint_scope():
...
@@ -117,6 +117,22 @@ def test_constraint_scope():
assert
m
.
coeff
==
1
assert
m
.
coeff
==
1
assert
m
.
base
==
0
assert
m
.
base
==
0
def
test_intersect
():
a
=
tvm
.
var
(
"a"
)
analyzer
=
tvm
.
arith
.
Analyzer
()
with
analyzer
.
constraint_scope
(
a
%
4
==
1
):
with
analyzer
.
constraint_scope
(
a
%
3
==
1
):
m
=
analyzer
.
modular_set
(
a
)
assert
m
.
coeff
==
12
assert
m
.
base
==
1
with
analyzer
.
constraint_scope
(
a
%
3
==
2
):
with
analyzer
.
constraint_scope
(
a
%
5
==
3
):
with
analyzer
.
constraint_scope
(
a
%
7
==
2
):
m
=
analyzer
.
modular_set
(
a
)
assert
m
.
coeff
==
105
assert
m
.
base
==
23
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_cast
()
test_cast
()
...
@@ -126,3 +142,4 @@ if __name__ == "__main__":
...
@@ -126,3 +142,4 @@ if __name__ == "__main__":
test_min_max_select
()
test_min_max_select
()
test_mix_index
()
test_mix_index
()
test_constraint_scope
()
test_constraint_scope
()
test_intersect
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment