Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2bb1d8e4
Commit
2bb1d8e4
authored
Nov 28, 2017
by
Tianqi Chen
Committed by
GitHub
Nov 28, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Upgrade CanonicalSimplify to Simplify Mod (#676)
parent
2e3f8e74
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
103 additions
and
14 deletions
+103
-14
include/tvm/ir_pass.h
+6
-2
src/api/api_pass.cc
+8
-0
src/arithmetic/canonical.cc
+70
-11
src/arithmetic/canonical.h
+1
-1
tests/python/unittest/test_arith_simplify.py
+18
-0
No files found.
include/tvm/ir_pass.h
View file @
2bb1d8e4
...
...
@@ -41,16 +41,20 @@ Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt
CanonicalSimplify
(
Stmt
stmt
);
Stmt
CanonicalSimplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
=
Map
<
Var
,
Range
>
());
/*!
* \brief Simplify by applying canonical form.
* \param expr The statement to be canonically simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized expression.
*/
Expr
CanonicalSimplify
(
Expr
expr
);
Expr
CanonicalSimplify
(
Expr
expr
,
Map
<
Var
,
Range
>
vrange
=
Map
<
Var
,
Range
>
());
/*!
* \brief Deep compare lhs and rhs
...
...
src/api/api_pass.cc
View file @
2bb1d8e4
...
...
@@ -33,10 +33,18 @@ TVM_REGISTER_API("ir_pass.Simplify")
TVM_REGISTER_API
(
"ir_pass.CanonicalSimplify"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
IsNodeType
<
Stmt
>
())
{
if
(
args
.
size
()
>
1
)
{
*
ret
=
CanonicalSimplify
(
args
[
0
].
operator
Stmt
(),
args
[
1
]);
}
else
{
*
ret
=
CanonicalSimplify
(
args
[
0
].
operator
Stmt
());
}
}
else
{
if
(
args
.
size
()
>
1
)
{
*
ret
=
CanonicalSimplify
(
args
[
0
].
operator
Expr
(),
args
[
1
]);
}
else
{
*
ret
=
CanonicalSimplify
(
args
[
0
].
operator
Expr
());
}
}
});
TVM_REGISTER_API
(
"ir_pass.Equal"
)
...
...
src/arithmetic/canonical.cc
View file @
2bb1d8e4
...
...
@@ -129,6 +129,11 @@ inline Expr Binary_(const T* op,
// internal of canonical engine.
class
Canonical
::
Internal
:
public
IRMutator
{
public
:
explicit
Internal
(
Map
<
Var
,
Range
>
vrange
)
{
for
(
auto
kv
:
vrange
)
{
SetRange
(
kv
.
first
,
kv
.
second
,
0
);
}
}
// stack entry.
struct
StackEntry
{
int
max_level
{
0
};
...
...
@@ -300,9 +305,25 @@ class Canonical::Internal : public IRMutator {
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
e
)
final
{
return
Binary
(
op
,
e
);
}
// Mod operator
Expr
Mutate_
(
const
Mod
*
op
,
const
Expr
&
e
)
final
{
if
(
!
EnableOpt
(
op
->
type
))
{
return
Binary
(
op
,
e
);
}
CacheEntry
a
=
Produce
(
op
->
a
);
CacheEntry
b
=
Produce
(
op
->
b
);
if
(
a
.
has_side_effect
||
b
.
has_side_effect
)
{
return
Binary_
(
op
,
e
,
a
.
value
,
b
.
value
);
}
if
(
is_const
(
a
.
value
)
&&
is_const
(
b
.
value
))
{
return
ComputeExpr
<
Mul
>
(
a
.
value
,
b
.
value
);
}
else
if
(
is_const
(
b
.
value
))
{
return
SumModConst
(
a
.
AsSum
(),
b
.
value
);
}
else
{
return
Binary
(
op
,
e
);
}
}
Expr
Mutate_
(
const
And
*
op
,
const
Expr
&
e
)
final
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
And
>
();
...
...
@@ -367,7 +388,7 @@ class Canonical::Internal : public IRMutator {
private
:
template
<
typename
T
>
Expr
Binary
(
const
T
*
op
,
const
Expr
&
e
)
{
Expr
Binary
(
const
T
*
op
,
Expr
e
)
{
Expr
a
=
this
->
Mutate
(
op
->
a
);
Expr
b
=
this
->
Mutate
(
op
->
b
);
BinaryExpr
key
{
static_cast
<
int
>
(
T
::
_type_info
),
a
,
b
};
...
...
@@ -398,8 +419,8 @@ class Canonical::Internal : public IRMutator {
std
::
vector
<
Var
>
var_rec_
;
// level counter
int
level_counter_
{
0
};
//
subroutine to do produc
e
Expr
SumMulConst
(
ComExpr
a
,
Expr
v
)
{
//
get constant int valu
e
int64_t
GetConstIntValue
(
const
Expr
&
v
)
{
int64_t
value
=
0
;
const
int64_t
*
v1
=
as_const_int
(
v
);
const
uint64_t
*
v2
=
as_const_uint
(
v
);
...
...
@@ -411,7 +432,45 @@ class Canonical::Internal : public IRMutator {
static_cast
<
uint64_t
>
(
std
::
numeric_limits
<
int64_t
>::
max
()));
value
=
static_cast
<
int64_t
>
(
*
v2
);
}
return
value
;
}
// subroutine to do produce a % v
Expr
SumModConst
(
ComExpr
a
,
Expr
v
)
{
int64_t
value
=
GetConstIntValue
(
v
);
std
::
shared_ptr
<
ComExprNode
>
n
=
std
::
make_shared
<
ComExprNode
>
();
int
mod_level
=
0
;
n
->
base
=
a
->
base
%
value
;
if
(
n
->
base
!=
0
)
mod_level
=
1
;
for
(
auto
e
:
a
->
elem
)
{
if
(
e
.
scale
%
value
==
0
)
continue
;
e
.
scale
=
e
.
scale
%
value
;
if
(
!
EvalSet
(
v
-
e
.
value
,
var_range_
).
can_prove_positive
())
{
mod_level
=
2
;
}
else
{
++
mod_level
;
}
n
->
elem
.
push_back
(
e
);
}
// cannot remove mode because there are more than two parts
if
(
mod_level
>=
2
)
{
Expr
ret
=
Sum2Expr
(
ComExpr
(
n
),
v
.
type
())
%
v
;
return
Binary
(
ret
.
as
<
Mod
>
(),
ret
);
}
ret_entry_
.
sum
=
ComExpr
(
n
);
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
auto
it
=
cache_sum_
.
find
(
ret_entry_
.
sum
);
if
(
it
!=
cache_sum_
.
end
())
{
ret_entry_
=
it
->
second
;
}
else
{
ret_entry_
.
value
=
Sum2Expr
(
ret_entry_
.
sum
,
v
.
type
());
cache_sum_
[
ret_entry_
.
sum
]
=
ret_entry_
;
}
return
ret_entry_
.
value
;
}
// subroutine to do produce
Expr
SumMulConst
(
ComExpr
a
,
Expr
v
)
{
int64_t
value
=
GetConstIntValue
(
v
);
if
(
value
==
0
)
{
return
make_zero
(
v
.
type
());
}
...
...
@@ -421,9 +480,9 @@ class Canonical::Internal : public IRMutator {
for
(
auto
&
e
:
vsum
->
elem
)
{
e
.
scale
*=
value
;
}
ret_entry_
.
sum
=
ComExpr
(
vsum
);
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
ret_entry_
.
sum
=
ComExpr
(
vsum
);
auto
it
=
cache_sum_
.
find
(
ret_entry_
.
sum
);
if
(
it
!=
cache_sum_
.
end
())
{
ret_entry_
=
it
->
second
;
...
...
@@ -536,8 +595,8 @@ class Canonical::Internal : public IRMutator {
using
CInternal
=
Canonical
::
Internal
;
Canonical
::
Canonical
()
:
ptr_
(
std
::
make_shared
<
Internal
>
())
{}
Canonical
::
Canonical
(
Map
<
Var
,
Range
>
vrange
)
:
ptr_
(
std
::
make_shared
<
Internal
>
(
vrange
))
{}
Expr
Canonical
::
Simplify
(
Expr
expr
)
{
return
ptr_
->
Mutate
(
expr
);
...
...
@@ -553,12 +612,12 @@ void Canonical::SetRange(Var v, Range r, int level) {
}
// namespace arith
namespace
ir
{
Stmt
CanonicalSimplify
(
Stmt
stmt
)
{
return
arith
::
Canonical
().
Simplify
(
stmt
);
Stmt
CanonicalSimplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
)
{
return
arith
::
Canonical
(
vrange
).
Simplify
(
stmt
);
}
Expr
CanonicalSimplify
(
Expr
expr
)
{
return
arith
::
Canonical
().
Simplify
(
expr
);
Expr
CanonicalSimplify
(
Expr
expr
,
Map
<
Var
,
Range
>
vrange
)
{
return
arith
::
Canonical
(
vrange
).
Simplify
(
expr
);
}
template
<
typename
T
>
...
...
src/arithmetic/canonical.h
View file @
2bb1d8e4
...
...
@@ -22,7 +22,7 @@ namespace arith {
class
Canonical
{
public
:
/*! \brief constructor */
Canonical
(
);
explicit
Canonical
(
Map
<
Var
,
Range
>
var_range
);
/*!
* \brief simplify expression e.
* \param expr The expression to be simplified.
...
...
tests/python/unittest/test_arith_simplify.py
View file @
2bb1d8e4
...
...
@@ -20,5 +20,23 @@ def test_simplify():
zz
=
zz
.
a
assert
zz
.
a
==
x
and
zz
.
b
.
value
==
4
def
test_simplify_mod
():
"""Not yet working, mock design"""
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
'n'
)
j
=
tvm
.
var
(
'j'
)
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
with
ib
.
for_range
(
0
,
16
,
name
=
"i"
)
as
i
:
A
[
i
]
=
A
[((
n
*
4
+
j
*
2
)
*
8
+
i
+
1
)
%
16
]
body
=
ib
.
get
()
stmt
=
tvm
.
ir_pass
.
CanonicalSimplify
(
body
)
diff
=
tvm
.
ir_pass
.
CanonicalSimplify
(
stmt
.
body
.
value
.
index
-
(
1
+
i
)
%
16
)
assert
diff
.
value
==
0
index
=
tvm
.
ir_pass
.
CanonicalSimplify
(
(
j
+
n
*
32
)
%
16
,
{
j
:
tvm
.
Range
(
0
,
6
)})
assert
index
==
j
if
__name__
==
"__main__"
:
test_simplify_mod
()
test_simplify
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment