Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
65038950
Commit
65038950
authored
Oct 10, 2017
by
Tianqi Chen
Committed by
GitHub
Oct 10, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Improve detect linear equation (#529)
* [ARITH] Improve detect linear equation * fix doc
parent
46082223
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
236 additions
and
33 deletions
+236
-33
HalideIR
+1
-1
include/tvm/arithmetic.h
+20
-8
src/api/api_arith.cc
+5
-0
src/arithmetic/detect_linear_equation.cc
+156
-10
src/pass/narrow_channel_access.cc
+3
-3
tests/python/unittest/test_arith_detect_clip_bound.py
+21
-0
tests/python/unittest/test_arith_detect_linear_equation.py
+30
-11
No files found.
HalideIR
@
a40a3e2f
Subproject commit
dbf043a8d8bf379b05c56d8aa9025db55f589d6d
Subproject commit
a40a3e2fedee88d2f7b97ba4caf8a9d0eb25886f
include/tvm/arithmetic.h
View file @
65038950
...
@@ -118,7 +118,7 @@ class IntSet : public NodeRef {
...
@@ -118,7 +118,7 @@ class IntSet : public NodeRef {
* \brief Range of a linear integer function.
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
* Use to do specify the possible index values.
*
*
* set = {
base + coeff * x
| x in Z }
* set = {
coeff * x + base
| x in Z }
*
*
* When coeff != 0, it can also be written as
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
* set = { n | n % coeff == base }
...
@@ -127,16 +127,17 @@ class IntSet : public NodeRef {
...
@@ -127,16 +127,17 @@ class IntSet : public NodeRef {
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
*/
struct
ModularEntry
{
struct
ModularEntry
{
/*! \brief The base */
int
base
{
0
};
/*! \brief linear co-efficient */
/*! \brief linear co-efficient */
int
coeff
{
1
};
int
coeff
{
1
};
/*! \brief The base */
int
base
{
0
};
/*! \return entry represent everything */
/*! \return entry represent everything */
static
ModularEntry
everything
()
{
static
ModularEntry
everything
()
{
// always safe to set 0 + x, so it can be everything.
// always safe to set 0 + x, so it can be everything.
ModularEntry
e
;
ModularEntry
e
;
e
.
base
=
0
;
e
.
coeff
=
1
;
e
.
coeff
=
1
;
e
.
base
=
0
;
return
e
;
return
e
;
}
}
/*!
/*!
...
@@ -157,14 +158,25 @@ struct IntSetNode : public Node {
...
@@ -157,14 +158,25 @@ struct IntSetNode : public Node {
TVM_DECLARE_BASE_NODE_INFO
(
IntSetNode
,
Node
);
TVM_DECLARE_BASE_NODE_INFO
(
IntSetNode
,
Node
);
};
};
/*!
/*!
* \brief Detect if e can be rewritten as e =
base + var * coeff
* \brief Detect if e can be rewritten as e =
sum_{i=0}^n var[i] * coeff[i] + coeff[n]
* Where coeff and base are invariant of var.
* Where coeff and base are invariant of var.
*
*
* \return [base, coeff] if it is possible, empty array if it is not.
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array
<
Expr
>
DetectLinearEquation
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
*/
Array
<
Expr
>
Detect
LinearEquation
(
Expr
e
,
Var
var
);
Array
<
Expr
>
Detect
ClipBound
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
);
/*!
/*!
* \brief Find an symbolic integer set that contains all possible values of
* \brief Find an symbolic integer set that contains all possible values of
...
...
src/api/api_arith.cc
View file @
65038950
...
@@ -36,6 +36,11 @@ TVM_REGISTER_API("arith.DetectLinearEquation")
...
@@ -36,6 +36,11 @@ TVM_REGISTER_API("arith.DetectLinearEquation")
*
ret
=
DetectLinearEquation
(
args
[
0
],
args
[
1
]);
*
ret
=
DetectLinearEquation
(
args
[
0
],
args
[
1
]);
});
});
TVM_REGISTER_API
(
"arith.DetectClipBound"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
DetectClipBound
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_API
(
"arith.DeduceBound"
)
TVM_REGISTER_API
(
"arith.DeduceBound"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
DeduceBound
(
args
[
0
],
args
[
1
],
*
ret
=
DeduceBound
(
args
[
0
],
args
[
1
],
...
...
src/arithmetic/detect_linear_equation.cc
View file @
65038950
...
@@ -21,22 +21,27 @@ struct LinearEqEntry {
...
@@ -21,22 +21,27 @@ struct LinearEqEntry {
Expr
coeff
;
Expr
coeff
;
};
};
struct
IntervalEntry
{
Expr
min_value
;
Expr
max_value
;
};
class
LinearEqDetector
class
LinearEqDetector
:
public
ExprFunctor
<
LinearEqEntry
(
const
Expr
&
,
const
Expr
&
)
>
{
:
public
ExprFunctor
<
LinearEqEntry
(
const
Expr
&
,
const
Expr
&
)
>
{
public
:
public
:
explicit
LinearEqDetector
(
Var
var
)
explicit
LinearEqDetector
(
Var
var
)
:
var_
(
var
)
{}
:
var_
(
var
)
{}
Array
<
Expr
>
Detect
(
const
Expr
&
e
)
{
bool
Detect
(
const
Expr
&
e
,
LinearEqEntry
*
ret
)
{
LinearEqEntry
ret
=
VisitExpr
(
e
,
e
);
*
ret
=
VisitExpr
(
e
,
e
);
if
(
fail_
)
return
Array
<
Expr
>
()
;
if
(
fail_
)
return
false
;
if
(
!
ret
.
base
.
defined
())
{
if
(
!
ret
->
base
.
defined
())
{
ret
.
base
=
make_zero
(
var_
.
type
());
ret
->
base
=
make_zero
(
var_
.
type
());
}
}
if
(
!
ret
.
coeff
.
defined
())
{
if
(
!
ret
->
coeff
.
defined
())
{
ret
.
coeff
=
make_zero
(
var_
.
type
());
ret
->
coeff
=
make_zero
(
var_
.
type
());
}
}
return
Array
<
Expr
>
{
ret
.
base
,
ret
.
coeff
}
;
return
true
;
}
}
LinearEqEntry
VisitExpr_
(
const
Add
*
op
,
const
Expr
&
e
)
final
{
LinearEqEntry
VisitExpr_
(
const
Add
*
op
,
const
Expr
&
e
)
final
{
...
@@ -48,6 +53,17 @@ class LinearEqDetector
...
@@ -48,6 +53,17 @@ class LinearEqDetector
ret
.
coeff
=
AddCombine
(
a
.
coeff
,
b
.
coeff
);
ret
.
coeff
=
AddCombine
(
a
.
coeff
,
b
.
coeff
);
return
ret
;
return
ret
;
}
}
LinearEqEntry
VisitExpr_
(
const
Sub
*
op
,
const
Expr
&
e
)
final
{
if
(
fail_
)
return
LinearEqEntry
();
LinearEqEntry
a
=
VisitExpr
(
op
->
a
,
op
->
a
);
LinearEqEntry
b
=
VisitExpr
(
op
->
b
,
op
->
b
);
LinearEqEntry
ret
;
ret
.
base
=
SubCombine
(
a
.
base
,
b
.
base
);
ret
.
coeff
=
SubCombine
(
a
.
coeff
,
b
.
coeff
);
return
ret
;
}
LinearEqEntry
VisitExpr_
(
const
Mul
*
op
,
const
Expr
&
e
)
final
{
LinearEqEntry
VisitExpr_
(
const
Mul
*
op
,
const
Expr
&
e
)
final
{
if
(
fail_
)
return
LinearEqEntry
();
if
(
fail_
)
return
LinearEqEntry
();
LinearEqEntry
a
=
VisitExpr
(
op
->
a
,
op
->
a
);
LinearEqEntry
a
=
VisitExpr
(
op
->
a
,
op
->
a
);
...
@@ -94,6 +110,11 @@ class LinearEqDetector
...
@@ -94,6 +110,11 @@ class LinearEqDetector
if
(
!
b
.
defined
())
return
a
;
if
(
!
b
.
defined
())
return
a
;
return
ComputeExpr
<
Add
>
(
a
,
b
);
return
ComputeExpr
<
Add
>
(
a
,
b
);
}
}
Expr
SubCombine
(
Expr
a
,
Expr
b
)
{
if
(
!
a
.
defined
())
return
-
b
;
if
(
!
b
.
defined
())
return
a
;
return
ComputeExpr
<
Sub
>
(
a
,
b
);
}
Expr
MulCombine
(
Expr
a
,
Expr
b
)
{
Expr
MulCombine
(
Expr
a
,
Expr
b
)
{
if
(
!
a
.
defined
())
return
a
;
if
(
!
a
.
defined
())
return
a
;
if
(
!
b
.
defined
())
return
b
;
if
(
!
b
.
defined
())
return
b
;
...
@@ -101,9 +122,134 @@ class LinearEqDetector
...
@@ -101,9 +122,134 @@ class LinearEqDetector
}
}
};
};
Array
<
Expr
>
DetectLinearEquation
(
Expr
e
,
Var
var
)
{
Array
<
Expr
>
DetectLinearEquation
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
)
{
return
LinearEqDetector
(
var
).
Detect
(
e
);
CHECK_GE
(
vars
.
size
(),
1U
);
Expr
base
=
e
;
Array
<
Expr
>
coeff
;
for
(
Var
v
:
vars
)
{
LinearEqEntry
ret
;
if
(
!
LinearEqDetector
(
v
).
Detect
(
base
,
&
ret
))
{
return
Array
<
Expr
>
();
}
coeff
.
push_back
(
ret
.
coeff
);
base
=
std
::
move
(
ret
.
base
);
}
std
::
unordered_set
<
const
Variable
*>
vset
;
for
(
size_t
i
=
vars
.
size
();
i
!=
1
;
--
i
)
{
vset
.
insert
(
vars
[
i
-
1
].
get
());
// The previous coeff contains the variable
if
(
ExprUseVar
(
coeff
[
i
-
2
],
vset
))
{
return
Array
<
Expr
>
();
}
}
coeff
.
push_back
(
base
);
return
coeff
;
}
// Detect clip condition as min max value
bool
DetectClipBound
(
const
Expr
&
cond
,
std
::
unordered_map
<
const
Variable
*
,
IntervalEntry
>*
bmap
)
{
int
flag
=
0
;
Var
var
;
auto
fvisit
=
[
&
bmap
,
&
flag
,
&
var
](
const
NodeRef
&
n
)
{
if
(
const
Variable
*
v
=
n
.
as
<
Variable
>
())
{
if
(
bmap
->
count
(
v
))
{
if
(
flag
==
0
)
{
var
=
Var
(
n
.
node_
);
flag
=
1
;
}
else
if
(
flag
==
1
)
{
if
(
!
var
.
same_as
(
n
))
{
flag
=
-
1
;
}
}
}
}
};
PostOrderVisit
(
cond
,
fvisit
);
if
(
flag
!=
1
)
return
false
;
// canonical form: exp >= 0
Expr
canonical
;
if
(
const
LT
*
op
=
cond
.
as
<
LT
>
())
{
if
(
!
op
->
a
.
type
().
is_int
())
return
false
;
canonical
=
op
->
b
-
op
->
a
-
make_const
(
op
->
a
.
type
(),
1
);
}
else
if
(
const
LE
*
op
=
cond
.
as
<
LE
>
())
{
if
(
!
op
->
a
.
type
().
is_int
())
return
false
;
canonical
=
op
->
b
-
op
->
a
;
}
else
if
(
const
GT
*
op
=
cond
.
as
<
GT
>
())
{
if
(
!
op
->
a
.
type
().
is_int
())
return
false
;
canonical
=
op
->
a
-
op
->
b
-
make_const
(
op
->
a
.
type
(),
1
);
}
else
if
(
const
GE
*
op
=
cond
.
as
<
GE
>
())
{
if
(
!
op
->
a
.
type
().
is_int
())
return
false
;
canonical
=
op
->
a
-
op
->
b
;
}
else
{
return
false
;
}
LinearEqEntry
ret
;
if
(
!
LinearEqDetector
(
var
).
Detect
(
canonical
,
&
ret
))
return
false
;
ret
.
coeff
=
Simplify
(
ret
.
coeff
);
IntervalEntry
&
p
=
(
*
bmap
)[
var
.
get
()];
if
(
is_one
(
ret
.
coeff
))
{
// var + shift >=0 -> var >= -shift
if
(
p
.
min_value
.
defined
())
{
p
.
min_value
=
ir
::
Max
::
make
(
p
.
min_value
,
-
ret
.
base
);
}
else
{
p
.
min_value
=
-
ret
.
base
;
}
return
true
;
}
if
(
is_const
(
ret
.
coeff
,
-
1
))
{
// -var + shift >=0 -> var <= shift
if
(
p
.
max_value
.
defined
())
{
p
.
max_value
=
ir
::
Min
::
make
(
p
.
max_value
,
ret
.
base
);
}
else
{
p
.
max_value
=
ret
.
base
;
}
return
true
;
}
return
false
;
}
template
<
typename
OP
>
void
SplitCommExpr
(
const
Expr
&
e
,
std
::
vector
<
Expr
>*
ret
)
{
if
(
const
OP
*
op
=
e
.
as
<
OP
>
())
{
SplitCommExpr
<
OP
>
(
op
->
a
,
ret
);
SplitCommExpr
<
OP
>
(
op
->
b
,
ret
);
}
else
{
ret
->
push_back
(
e
);
}
}
// Detect the lower and upper bound from the expression.
// e must be connected by and.
Array
<
Expr
>
DetectClipBound
(
const
Expr
&
e
,
const
Array
<
Var
>&
vars
)
{
std
::
vector
<
Expr
>
splits
;
SplitCommExpr
<
ir
::
And
>
(
e
,
&
splits
);
std
::
unordered_map
<
const
Variable
*
,
IntervalEntry
>
rmap
;
for
(
Var
v
:
vars
)
{
rmap
[
v
.
get
()]
=
IntervalEntry
();
}
for
(
Expr
cond
:
splits
)
{
if
(
!
DetectClipBound
(
cond
,
&
rmap
))
return
Array
<
Expr
>
();
}
Array
<
Expr
>
ret
;
for
(
Var
v
:
vars
)
{
IntervalEntry
e
=
rmap
[
v
.
get
()];
if
(
e
.
min_value
.
defined
())
{
e
.
min_value
=
Simplify
(
e
.
min_value
);
}
if
(
e
.
max_value
.
defined
())
{
e
.
max_value
=
Simplify
(
e
.
max_value
);
}
ret
.
push_back
(
e
.
min_value
);
ret
.
push_back
(
e
.
max_value
);
}
return
ret
;
}
}
}
// namespace arith
}
// namespace arith
}
// namespace tvm
}
// namespace tvm
src/pass/narrow_channel_access.cc
View file @
65038950
...
@@ -175,10 +175,10 @@ class ChannelAccessRewriter : public IRMutator {
...
@@ -175,10 +175,10 @@ class ChannelAccessRewriter : public IRMutator {
r
=
Range
::
make_by_min_extent
(
r
=
Range
::
make_by_min_extent
(
ir
::
Simplify
(
r
->
min
),
ir
::
Simplify
(
r
->
extent
));
ir
::
Simplify
(
r
->
min
),
ir
::
Simplify
(
r
->
extent
));
if
(
ExprUseVar
(
r
->
extent
,
var
))
return
body
;
if
(
ExprUseVar
(
r
->
extent
,
var
))
return
body
;
Array
<
Expr
>
linear_eq
=
DetectLinearEquation
(
r
->
min
,
var
);
Array
<
Expr
>
linear_eq
=
DetectLinearEquation
(
r
->
min
,
{
var
}
);
if
(
linear_eq
.
size
()
==
0
)
return
body
;
if
(
linear_eq
.
size
()
==
0
)
return
body
;
Expr
base
=
linear_eq
[
0
];
Expr
coeff
=
linear_eq
[
0
];
Expr
coeff
=
linear_eq
[
1
];
Expr
base
=
linear_eq
[
1
];
if
(
!
is_zero
(
base
))
return
body
;
if
(
!
is_zero
(
base
))
return
body
;
Expr
left
=
ir
::
Simplify
(
adv_op
->
value
-
coeff
*
for_op
->
extent
);
Expr
left
=
ir
::
Simplify
(
adv_op
->
value
-
coeff
*
for_op
->
extent
);
if
(
!
can_prove
(
left
>=
0
))
return
body
;
if
(
!
can_prove
(
left
>=
0
))
return
body
;
...
...
tests/python/unittest/test_arith_detect_clip_bound.py
0 → 100644
View file @
65038950
import
tvm
def
test_basic
():
a
=
tvm
.
var
(
"a"
)
b
=
tvm
.
var
(
"b"
)
c
=
tvm
.
var
(
"c"
)
m
=
tvm
.
arith
.
DetectClipBound
(
tvm
.
all
(
a
*
1
<
b
*
6
,
a
-
1
>
0
),
[
a
])
assert
tvm
.
ir_pass
.
Simplify
(
m
[
1
]
-
(
b
*
6
-
1
))
.
value
==
0
assert
m
[
0
]
.
value
==
2
m
=
tvm
.
arith
.
DetectClipBound
(
tvm
.
all
(
a
*
1
<
b
*
6
,
a
-
1
>
0
),
[
a
,
b
])
assert
len
(
m
)
==
0
m
=
tvm
.
arith
.
DetectClipBound
(
tvm
.
all
(
a
+
10
*
c
<=
20
,
b
-
1
>
0
),
[
a
,
b
])
assert
tvm
.
ir_pass
.
Simplify
(
m
[
1
]
-
(
20
-
10
*
c
))
.
value
==
0
assert
tvm
.
ir_pass
.
Simplify
(
m
[
2
]
-
2
)
.
value
==
0
if
__name__
==
"__main__"
:
test_basic
()
tests/python/unittest/test_arith_detect_linear_equation.py
View file @
65038950
...
@@ -3,22 +3,41 @@ import tvm
...
@@ -3,22 +3,41 @@ import tvm
def
test_basic
():
def
test_basic
():
a
=
tvm
.
var
(
"a"
)
a
=
tvm
.
var
(
"a"
)
b
=
tvm
.
var
(
"b"
)
b
=
tvm
.
var
(
"b"
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
4
+
b
*
6
+
7
,
a
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
4
+
b
*
6
+
7
,
[
a
]
)
assert
m
[
1
]
.
value
==
4
assert
m
[
0
]
.
value
==
4
assert
tvm
.
ir_pass
.
Simplify
(
m
[
0
]
-
(
b
*
6
+
7
))
.
value
==
0
assert
tvm
.
ir_pass
.
Simplify
(
m
[
1
]
-
(
b
*
6
+
7
))
.
value
==
0
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
4
*
(
a
+
1
)
+
b
*
6
+
7
,
a
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
4
*
(
a
+
1
)
+
b
*
6
+
7
,
[
a
]
)
assert
len
(
m
)
==
0
assert
len
(
m
)
==
0
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
4
+
(
a
+
1
)
+
b
*
6
+
7
,
a
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
4
+
(
a
+
1
)
+
b
*
6
+
7
,
[
a
]
)
assert
m
[
1
]
.
value
==
5
assert
m
[
0
]
.
value
==
5
assert
tvm
.
ir_pass
.
Simplify
(
m
[
0
]
-
(
b
*
6
+
7
+
1
))
.
value
==
0
assert
tvm
.
ir_pass
.
Simplify
(
m
[
1
]
-
(
b
*
6
+
7
+
1
))
.
value
==
0
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
b
+
7
,
a
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
a
*
b
+
7
,
[
a
]
)
assert
m
[
1
]
==
b
assert
m
[
0
]
==
b
m
=
tvm
.
arith
.
DetectLinearEquation
(
b
*
7
,
a
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
b
*
7
,
[
a
])
assert
m
[
1
]
.
value
==
0
assert
m
[
0
]
.
value
==
0
def
test_multivariate
():
v
=
[
tvm
.
var
(
"v
%
d"
%
i
)
for
i
in
range
(
4
)]
b
=
tvm
.
var
(
"b"
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
v
[
0
]
*
(
b
+
4
)
+
v
[
0
]
+
v
[
1
]
*
8
,
v
)
assert
(
tvm
.
ir_pass
.
Equal
(
tvm
.
ir_pass
.
Simplify
(
m
[
0
]),
b
+
5
))
assert
(
m
[
1
]
.
value
==
8
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
v
[
0
]
*
(
b
+
4
)
+
v
[
0
]
+
v
[
1
]
*
8
*
v
[
2
],
v
)
assert
(
len
(
m
)
==
0
)
m
=
tvm
.
arith
.
DetectLinearEquation
(
v
[
0
]
*
(
b
+
4
)
+
v
[
0
]
+
v
[
1
]
*
8
*
v
[
1
]
+
v
[
3
],
v
)
assert
(
len
(
m
)
==
0
)
m
=
tvm
.
arith
.
DetectLinearEquation
(((
v
[
0
]
*
b
+
v
[
1
])
*
8
+
v
[
2
]
+
1
)
*
2
,
v
)
assert
(
m
[
1
]
.
value
==
16
)
assert
(
m
[
2
]
.
value
==
2
)
assert
(
m
[
len
(
m
)
-
1
]
.
value
==
2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_basic
()
test_basic
()
test_multivariate
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment