Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2a7aebe5
Unverified
Commit
2a7aebe5
authored
Jul 06, 2019
by
Tianqi Chen
Committed by
GitHub
Jul 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] More recursive rewrite rule, cleanup simplify tests (#3502)
parent
eadc4e38
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
43 additions
and
235 deletions
+43
-235
python/tvm/make.py
+0
-30
src/arithmetic/rewrite_simplify.cc
+2
-2
src/arithmetic/stmt_simplify.cc
+16
-19
tests/python/unittest/test_arith_canonical_simplify.py
+6
-0
tests/python/unittest/test_arith_simplify.py
+0
-125
tests/python/unittest/test_arith_stmt_simplify.py
+0
-58
tests/python/unittest/test_lang_operator.py
+18
-0
topi/python/topi/math.py
+1
-1
No files found.
python/tvm/make.py
View file @
2a7aebe5
...
...
@@ -24,7 +24,6 @@ You can use make function to build the IR node.
"""
from
__future__
import
absolute_import
as
_abs
from
._ffi.function
import
_init_api
from
._ffi.runtime_ctypes
import
TVMType
def
range_by_min_extent
(
min_value
,
extent
):
...
...
@@ -48,35 +47,6 @@ def range_by_min_extent(min_value, extent):
return
_range_by_min_extent
(
min_value
,
extent
)
def
static_cast
(
dtype
,
expr
):
"""Cast expr to dtype.
If expr is scalar and dtype is a corresponding vector
type, a Broadcast is generated. Otherwise it is a Cast.
Parameters
----------
dtype : str
The target data type.
expr : Expr
The expression to be casted.
Returns
-------
casted : Expr
The casted expression.
"""
target_type
=
TVMType
(
dtype
)
src_type
=
TVMType
(
expr
.
dtype
)
if
target_type
.
type_code
==
src_type
.
type_code
and
src_type
.
bits
==
target_type
.
bits
:
if
src_type
.
lanes
==
target_type
.
lanes
:
return
expr
if
src_type
.
lanes
==
1
and
target_type
.
lanes
>
1
:
return
Broadcast
(
expr
,
target_type
.
lanes
)
return
Cast
(
dtype
,
expr
)
def
node
(
type_key
,
**
kwargs
):
"""Make a new DSL node by its type key and fields
...
...
src/arithmetic/rewrite_simplify.cc
View file @
2a7aebe5
...
...
@@ -1194,9 +1194,9 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE
(
c1
-
y
<
x
,
c1
<
x
+
y
);
TVM_TRY_RECURSIVE_REWRITE
(
c1
+
y
<
x
,
c1
<
x
-
y
);
TVM_TRY_RECURSIVE_REWRITE
(
x
+
c1
<
c2
,
x
<
c2
-
c1
);
TVM_TRY_RECURSIVE_REWRITE
(
x
-
c1
<
c2
,
x
<
c2
+
c1
);
TVM_TRY_REWRITE
(
x
-
c1
<
0
,
x
<
c1
);
TVM_TRY_REWRITE
(
x
+
c1
<
c2
,
x
<
c2
-
c1
);
}
return
ret
;
}
...
...
src/arithmetic/stmt_simplify.cc
View file @
2a7aebe5
...
...
@@ -31,11 +31,24 @@
namespace
tvm
{
namespace
arith
{
// statement simplifier
using
namespace
ir
;
class
StmtSimplifier
:
public
IRMutator
{
public
:
using
IRMutator
::
Mutate
;
Expr
Mutate
(
Expr
expr
)
final
{
return
analyzer_
.
Simplify
(
expr
);
}
Stmt
Simplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
)
{
for
(
auto
kv
:
vrange
)
{
analyzer_
.
Bind
(
kv
.
first
,
kv
.
second
);
}
return
Mutate
(
stmt
);
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
Var
loop_var
(
op
->
loop_var
.
node_
);
analyzer_
.
Bind
(
loop_var
,
Range
::
make_by_min_extent
(
op
->
min
,
op
->
extent
));
...
...
@@ -124,28 +137,12 @@ class StmtSimplifier : public IRMutator {
std
::
unordered_map
<
const
Variable
*
,
Range
>
var_dom_
;
};
class
CanonicalStmtSimplifier
:
public
StmtSimplifier
{
public
:
using
StmtSimplifier
::
Mutate
;
Expr
Mutate
(
Expr
expr
)
final
{
return
analyzer_
.
canonical_simplify
(
expr
);
}
Stmt
CanonicalSimplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
)
{
for
(
auto
kv
:
vrange
)
{
analyzer_
.
Bind
(
kv
.
first
,
kv
.
second
);
}
return
Mutate
(
stmt
);
}
};
}
// namespace arith
namespace
ir
{
Stmt
CanonicalSimplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
)
{
return
arith
::
CanonicalStmtSimplifier
().
Canonical
Simplify
(
return
arith
::
StmtSimplifier
().
Simplify
(
stmt
,
vrange
);
}
...
...
@@ -167,7 +164,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
}
Stmt
Simplify
(
Stmt
stmt
,
Map
<
Var
,
Range
>
vrange
)
{
return
arith
::
CanonicalStmtSimplifier
().
Canonical
Simplify
(
return
arith
::
StmtSimplifier
().
Simplify
(
stmt
,
vrange
);
}
}
// namespace ir
...
...
tests/python/unittest/test_arith_canonical_simplify.py
View file @
2a7aebe5
...
...
@@ -81,6 +81,10 @@ def test_canonical_mixed():
z
=
tvm
.
const
(
3
,
"int32"
)
ck
.
verify
(
x
/
(
z
*
z
)
-
x
/
(
z
*
z
),
0
)
ck
.
verify
(
x
/
(
z
+
z
)
-
x
/
(
z
+
z
),
0
)
ck
.
verify
(
x
-
2
<
3
,
x
<
5
)
ck
.
verify
(
tvm
.
max
(
x
,
1
)
-
tvm
.
max
(
x
,
1
),
0
)
ck
.
verify
(
tvm
.
min
(
x
,
1
)
-
tvm
.
min
(
x
,
1
),
0
)
ck
.
verify
(
x
*
x
-
x
*
x
,
0
)
def
test_reduce_combiner_simplify
():
...
...
@@ -211,6 +215,8 @@ def test_complex_cases():
ck
.
verify
(
res3
,
((((
x
*
1024
)
+
y
)
/
256
)
-
(
y
/
256
))
-
(
x
*
4
))
if
__name__
==
"__main__"
:
test_simplify_if_then_else
()
test_div_simplify
()
...
...
tests/python/unittest/test_arith_simplify.py
deleted
100644 → 0
View file @
eadc4e38
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
def
csimplify
(
z
):
return
tvm
.
ir_pass
.
CanonicalSimplify
(
tvm
.
make
.
Evaluate
(
z
))
.
value
def
test_simplify
():
x
=
tvm
.
var
(
'n'
)
z
=
x
*
4
-
x
*
2
zz
=
csimplify
(
z
)
assert
zz
.
b
.
value
==
2
z
=
(
x
/
4
)
*
2
-
(
x
/
4
)
zz
=
csimplify
(
z
)
assert
zz
.
a
==
x
and
zz
.
b
.
value
==
4
z
=
(
x
%
4
)
*
3
+
(
x
%
4
)
zz
=
csimplify
(
z
)
assert
zz
.
b
.
value
==
4
zz
=
zz
.
a
assert
zz
.
a
==
x
and
zz
.
b
.
value
==
4
n
=
tvm
.
var
(
'n'
)
assert
tvm
.
ir_pass
.
Equal
(
tvm
.
ir_pass
.
CanonicalSimplify
(
n
%
1
),
tvm
.
const
(
0
,
"int32"
))
assert
tvm
.
ir_pass
.
Equal
(
tvm
.
ir_pass
.
CanonicalSimplify
(
n
/
1
),
n
)
tvm
.
ir_pass
.
CanonicalSimplify
(
n
/
(
-
1
))
# This is not true in the current implementation
# assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
# tvm.ir_pass.CanonicalSimplify(-n))
def
test_simplify_mod
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
'n'
)
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
with
ib
.
for_range
(
0
,
16
,
name
=
"i"
)
as
i
:
A
[
i
]
=
A
[(
j
*
32
+
i
+
1
)
%
16
]
body
=
ib
.
get
()
stmt
=
tvm
.
ir_pass
.
CanonicalSimplify
(
body
)
diff
=
tvm
.
ir_pass
.
CanonicalSimplify
(
stmt
.
body
.
body
.
value
.
index
-
(
1
+
i
)
%
16
)
assert
diff
.
value
==
0
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
index
=
tvm
.
ir_pass
.
CanonicalSimplify
((
j
+
16
)
%
16
)
assert
index
!=
j
index
=
tvm
.
ir_pass
.
CanonicalSimplify
((
j
+
16
)
%
16
,
{
j
:
tvm
.
Range
(
0
,
6
)})
assert
index
==
j
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index
=
tvm
.
ir_pass
.
CanonicalSimplify
(
(
j
+
n
*
32
)
%
16
,
{
j
:
tvm
.
Range
(
0
,
6
)})
assert
index
!=
j
index
=
tvm
.
ir_pass
.
CanonicalSimplify
(
(
j
+
n
*
32
)
%
16
,
{
j
:
tvm
.
Range
(
0
,
6
),
n
:
tvm
.
Range
(
0
,
10
)})
assert
index
==
j
def
test_simplify_minmax
():
x
=
tvm
.
var
(
'x'
)
e1
=
tvm
.
max
(
x
,
1
)
-
tvm
.
max
(
x
,
1
)
e1s
=
tvm
.
ir_pass
.
CanonicalSimplify
(
e1
)
assert
e1s
.
value
==
0
e2
=
tvm
.
min
(
x
,
1
)
-
tvm
.
min
(
x
,
1
)
e2s
=
tvm
.
ir_pass
.
CanonicalSimplify
(
e2
)
assert
e2s
.
value
==
0
def
test_mul
():
x
=
tvm
.
var
(
'x'
)
e
=
x
*
x
-
x
*
x
es
=
tvm
.
ir_pass
.
CanonicalSimplify
(
e
)
assert
es
.
value
==
0
def
test_modular
():
rx
=
tvm
.
var
(
"rx"
)
ry
=
tvm
.
var
(
"ry"
)
y
=
tvm
.
var
(
"y"
)
x
=
tvm
.
var
(
"x"
)
i32_const
=
lambda
x
:
tvm
.
const
(
x
,
"int32"
)
vmap
=
{
rx
:
tvm
.
Range
(
i32_const
(
0
),
i32_const
(
3
)),
ry
:
tvm
.
Range
(
i32_const
(
0
),
i32_const
(
3
)),
y
:
tvm
.
Range
(
i32_const
(
0
),
i32_const
(
2
)),
x
:
tvm
.
Range
(
i32_const
(
0
),
i32_const
(
14
))}
idx
=
ry
*
16
+
rx
+
y
*
16
+
x
z2
=
tvm
.
ir_pass
.
CanonicalSimplify
(
idx
%
16
,
vmap
)
z1
=
tvm
.
ir_pass
.
CanonicalSimplify
(
idx
//
16
,
vmap
)
assert
tvm
.
ir_pass
.
CanonicalSimplify
(
z1
-
(
ry
+
y
))
.
value
==
0
assert
tvm
.
ir_pass
.
CanonicalSimplify
(
z2
-
(
rx
+
x
))
.
value
==
0
def
test_const_propagation
():
x1
=
tvm
.
const
(
4
,
"int32"
)
x2
=
x1
+
5
assert
isinstance
(
x2
,
tvm
.
expr
.
IntImm
)
and
x2
.
value
==
9
x3
=
x2
/
3
assert
isinstance
(
x3
,
tvm
.
expr
.
IntImm
)
and
x3
.
value
==
3
x4
=
x3
+
0.5
assert
isinstance
(
x4
,
tvm
.
expr
.
FloatImm
)
and
x4
.
value
==
3.5
x5
=
tvm
.
ceil
(
x4
)
assert
isinstance
(
x5
,
tvm
.
expr
.
FloatImm
)
and
x5
.
value
==
4
x6
=
x5
.
astype
(
'int'
)
assert
isinstance
(
x6
,
tvm
.
expr
.
IntImm
)
and
x6
.
value
==
4
y
=
(
tvm
.
round
((
tvm
.
const
(
6.5
,
'float32'
)
-
1
)
/
1.5
)
+
2
)
.
astype
(
'int'
)
assert
isinstance
(
y
,
tvm
.
expr
.
IntImm
)
and
y
.
value
==
6
if
__name__
==
"__main__"
:
test_modular
()
test_simplify
()
test_mul
()
test_simplify_minmax
()
test_const_propagation
()
test_simplify_mod
()
tests/python/unittest/test_arith_stmt_simplify.py
deleted
100644 → 0
View file @
eadc4e38
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import
tvm
import
numpy
from
tvm
import
comm_reducer
from
tvm.ir_pass
import
Simplify
,
CanonicalSimplify
,
Equal
def
test_simplify
():
"""Not yet working, mock design"""
dtype
=
'int64'
n
=
tvm
.
var
(
'n'
)
Ab
=
tvm
.
decl_buffer
((
n
,
),
dtype
)
i
=
tvm
.
var
(
'i'
)
j
=
tvm
.
var
(
'j'
)
# for i in 0 to n-1:
stmt
=
tvm
.
make
.
For
(
i
,
2
,
n
,
0
,
0
,
tvm
.
make
.
For
(
j
,
0
,
n
,
0
,
0
,
tvm
.
make
.
IfThenElse
(
tvm
.
make
.
LT
(
i
+
2
,
n
),
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
+
4
)
+
1
,
(
j
+
1
)
*
4
-
4
*
j
+
i
),
None
)))
stmt
=
tvm
.
ir_pass
.
CanonicalSimplify
(
stmt
)
def
test_basic
():
m
=
tvm
.
var
(
'm'
)
ret
=
tvm
.
ir_pass
.
CanonicalSimplify
(
tvm
.
make
.
Evaluate
(
m
-
1
))
assert
str
(
ret
.
value
)
==
"(m - 1)"
def
test_bound
():
m
=
tvm
.
var
(
'm'
)
vrange
=
tvm
.
convert
({
m
:
tvm
.
Range
(
tvm
.
const
(
0
,
"int32"
),
tvm
.
const
(
10
,
"int32"
))})
ret
=
tvm
.
ir_pass
.
Simplify
(
m
%
10
,
vrange
)
assert
ret
==
m
if
__name__
==
"__main__"
:
test_bound
()
test_basic
()
test_simplify
()
tests/python/unittest/test_lang_operator.py
View file @
2a7aebe5
...
...
@@ -83,7 +83,25 @@ def test_const_fold3():
assert
tvm
.
any
(
x
,
true
)
.
same_as
(
true
)
assert
tvm
.
any
(
true
,
x
)
.
same_as
(
true
)
def
test_const_fold4
():
x1
=
tvm
.
const
(
4
,
"int32"
)
x2
=
x1
+
5
assert
isinstance
(
x2
,
tvm
.
expr
.
IntImm
)
and
x2
.
value
==
9
x3
=
x2
/
3
assert
isinstance
(
x3
,
tvm
.
expr
.
IntImm
)
and
x3
.
value
==
3
x4
=
x3
+
0.55
assert
isinstance
(
x4
,
tvm
.
expr
.
FloatImm
)
and
abs
(
x4
.
value
-
3.55
)
<
1e-6
x5
=
tvm
.
ceil
(
x4
)
assert
isinstance
(
x5
,
tvm
.
expr
.
FloatImm
)
and
x5
.
value
==
4
x6
=
x5
.
astype
(
'int'
)
assert
isinstance
(
x6
,
tvm
.
expr
.
IntImm
)
and
x6
.
value
==
4
,
"x6={}"
.
format
(
x6
)
y
=
(
tvm
.
round
((
tvm
.
const
(
6.5
,
'float32'
)
-
1
)
/
1.5
)
+
2
)
.
astype
(
'int'
)
assert
isinstance
(
y
,
tvm
.
expr
.
IntImm
)
and
y
.
value
==
6
if
__name__
==
"__main__"
:
test_const_fold
()
test_const_fold2
()
test_const_fold3
()
test_const_fold4
()
topi/python/topi/math.py
View file @
2a7aebe5
...
...
@@ -342,4 +342,4 @@ def cast(x, dtype):
if
isinstance
(
x
,
tvm
.
tensor
.
Tensor
):
return
tvm
.
compute
(
x
.
shape
,
lambda
*
i
:
x
(
*
i
)
.
astype
(
dtype
),
tag
=
tag
.
ELEMWISE
)
return
tvm
.
make
.
static
_cast
(
dtype
,
x
)
return
tvm
.
make
.
_cast
(
dtype
,
x
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment