Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
fc4ba796
Commit
fc4ba796
authored
Oct 13, 2016
by
Haichen Shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add expr simplify and canonical
parent
77345051
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
149 additions
and
13 deletions
+149
-13
python/tvm/expr.py
+14
-10
python/tvm/expr_util.py
+28
-0
python/tvm/op.py
+94
-2
tests/python/test_basic.py
+13
-1
No files found.
python/tvm/expr.py
View file @
fc4ba796
"""Base class of symbolic expression"""
"""Base class of symbolic expression"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
numbers
import
Number
as
_Number
from
numbers
import
Number
as
_Number
from
.
import
op
as
_op
from
.
import
var_name
as
_name
from
.
import
var_name
as
_name
__addop__
=
None
__subop__
=
None
__mulop__
=
None
__divop__
=
None
class
Expr
(
object
):
class
Expr
(
object
):
"""Base class of expression.
"""Base class of expression.
...
@@ -20,28 +24,28 @@ class Expr(object):
...
@@ -20,28 +24,28 @@ class Expr(object):
return
()
return
()
def
__add__
(
self
,
other
):
def
__add__
(
self
,
other
):
return
BinaryOpExpr
(
_
op
.
add
,
self
,
other
)
return
BinaryOpExpr
(
_
_addop__
,
self
,
other
)
def
__radd__
(
self
,
other
):
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
def
__sub__
(
self
,
other
):
return
BinaryOpExpr
(
_
op
.
sub
,
self
,
other
)
return
BinaryOpExpr
(
_
_subop__
,
self
,
other
)
def
__rsub__
(
self
,
other
):
def
__rsub__
(
self
,
other
):
return
BinaryOpExpr
(
_
op
.
sub
,
other
,
self
)
return
BinaryOpExpr
(
_
_subop__
,
other
,
self
)
def
__mul__
(
self
,
other
):
def
__mul__
(
self
,
other
):
return
BinaryOpExpr
(
_
op
.
mul
,
self
,
other
)
return
BinaryOpExpr
(
_
_mulop__
,
self
,
other
)
def
__rmul__
(
self
,
other
):
def
__rmul__
(
self
,
other
):
return
BinaryOpExpr
(
_
op
.
mul
,
other
,
self
)
return
BinaryOpExpr
(
_
_mulop__
,
other
,
self
)
def
__div__
(
self
,
other
):
def
__div__
(
self
,
other
):
return
BinaryOpExpr
(
_
op
.
div
,
self
,
other
)
return
BinaryOpExpr
(
_
_divop__
,
self
,
other
)
def
__rdiv__
(
self
,
other
):
def
__rdiv__
(
self
,
other
):
return
BinaryOpExpr
(
_
op
.
div
,
other
,
self
)
return
BinaryOpExpr
(
_
_divop__
,
other
,
self
)
def
__truediv__
(
self
,
other
):
def
__truediv__
(
self
,
other
):
return
self
.
__div__
(
other
)
return
self
.
__div__
(
other
)
...
@@ -75,7 +79,8 @@ class Var(Expr):
...
@@ -75,7 +79,8 @@ class Var(Expr):
optional name to the var.
optional name to the var.
"""
"""
def
__init__
(
self
,
name
=
None
):
def
__init__
(
self
,
name
=
None
):
self
.
name
=
name
if
name
else
_name
.
NameManager
.
current
.
get
(
name
)
if
name
is
None
:
name
=
'i'
self
.
name
=
_name
.
NameManager
.
current
.
get
(
name
)
class
ConstExpr
(
Expr
):
class
ConstExpr
(
Expr
):
...
@@ -95,7 +100,6 @@ class BinaryOpExpr(Expr):
...
@@ -95,7 +100,6 @@ class BinaryOpExpr(Expr):
def
children
(
self
):
def
children
(
self
):
return
(
self
.
lhs
,
self
.
rhs
)
return
(
self
.
lhs
,
self
.
rhs
)
_op
.
binary_op_cls
=
BinaryOpExpr
class
UnaryOpExpr
(
Expr
):
class
UnaryOpExpr
(
Expr
):
"""Unary operator expression."""
"""Unary operator expression."""
...
...
python/tvm/expr_util.py
View file @
fc4ba796
"""Utilities to manipulate expression"""
"""Utilities to manipulate expression"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
from
.
import
expr
as
_expr
from
.
import
op
as
_op
def
expr_with_new_children
(
e
,
children
):
def
expr_with_new_children
(
e
,
children
):
"""Returns same expr as e but with new children
"""Returns same expr as e but with new children
...
@@ -48,6 +49,7 @@ def transform(e, f):
...
@@ -48,6 +49,7 @@ def transform(e, f):
result : return value of f
result : return value of f
The final result of transformation.
The final result of transformation.
"""
"""
assert
isinstance
(
e
,
_expr
.
Expr
)
return
f
(
e
,
[
transform
(
c
,
f
)
for
c
in
e
.
children
()])
return
f
(
e
,
[
transform
(
c
,
f
)
for
c
in
e
.
children
()])
...
@@ -77,6 +79,32 @@ def format_str(expr):
...
@@ -77,6 +79,32 @@ def format_str(expr):
raise
TypeError
(
"Do not know how to handle type "
+
str
(
type
(
e
)))
raise
TypeError
(
"Do not know how to handle type "
+
str
(
type
(
e
)))
return
transform
(
expr
,
make_str
)
return
transform
(
expr
,
make_str
)
def
simplify
(
expr
):
"""simplify expression
Parameters
----------
expr : Expr
Input expression
Returns
-------
e : Expr
Simplified expression
"""
def
canonical
(
e
,
result_children
):
if
isinstance
(
e
,
_expr
.
BinaryOpExpr
):
return
e
.
op
.
canonical
(
result_children
[
0
],
result_children
[
1
])
elif
isinstance
(
e
,
_expr
.
UnaryOpExpr
):
return
e
.
op
.
canonical
(
result_children
[
0
])
elif
isinstance
(
e
,
_expr
.
ConstExpr
):
return
{
_op
.
constant_canonical_key
:
e
.
value
}
elif
isinstance
(
e
,
_expr
.
Var
):
return
{
e
:
1
}
else
:
raise
TypeError
(
"Do not know how to handle type "
+
str
(
type
(
e
)))
return
_op
.
canonical_to_expr
(
transform
(
expr
,
canonical
))
def
bind
(
expr
,
update_dict
):
def
bind
(
expr
,
update_dict
):
"""Replace the variable in e by specification from kwarg
"""Replace the variable in e by specification from kwarg
...
...
python/tvm/op.py
View file @
fc4ba796
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
.
import
expr
as
_expr
_binary_op_cls
=
None
constant_canonical_key
=
'__constant__'
def
canonical_to_expr
(
c
):
elements
=
[]
for
k
,
v
in
sorted
(
c
.
items
()):
if
k
==
constant_canonical_key
:
elements
.
append
(
_expr
.
const
(
v
))
elif
v
==
0
:
continue
elif
v
==
1
:
elements
.
append
(
k
)
else
:
elements
.
append
(
k
*
v
)
if
elements
:
expr
=
elements
[
0
]
for
i
in
range
(
1
,
len
(
elements
)):
expr
=
expr
+
elements
[
i
]
return
expr
else
:
return
_expr
.
const
(
0
)
class
BinaryOp
(
object
):
class
BinaryOp
(
object
):
"""Base class of binary operator"""
"""Base class of binary operator"""
def
__call__
(
self
,
lhs
,
rhs
):
def
__call__
(
self
,
lhs
,
rhs
):
return
_
binary_op_cls
(
self
,
lhs
,
rhs
)
return
_
expr
.
BinaryOpExpr
(
self
,
lhs
,
rhs
)
class
AddOp
(
BinaryOp
):
class
AddOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s +
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s +
%
s)'
%
(
lhs
,
rhs
)
def
canonical
(
self
,
lhs
,
rhs
):
lhs
=
lhs
.
copy
()
for
k
,
v
in
rhs
.
items
():
if
k
in
lhs
:
lhs
[
k
]
+=
v
else
:
lhs
[
k
]
=
v
return
lhs
class
SubOp
(
BinaryOp
):
class
SubOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s -
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s -
%
s)'
%
(
lhs
,
rhs
)
def
canonical
(
self
,
lhs
,
rhs
):
lhs
=
lhs
.
copy
()
for
k
,
v
in
rhs
.
items
():
if
k
in
lhs
:
lhs
[
k
]
-=
v
else
:
lhs
[
k
]
=
-
v
return
lhs
class
MulOp
(
BinaryOp
):
class
MulOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s *
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s *
%
s)'
%
(
lhs
,
rhs
)
def
canonical
(
self
,
lhs
,
rhs
):
elhs
=
canonical_to_expr
(
lhs
)
erhs
=
canonical_to_expr
(
rhs
)
if
isinstance
(
erhs
,
_expr
.
ConstExpr
):
lhs
=
lhs
.
copy
()
for
k
,
v
in
lhs
.
items
():
lhs
[
k
]
*=
erhs
.
value
return
lhs
if
isinstance
(
elhs
,
_expr
.
ConstExpr
):
rhs
=
rhs
.
copy
()
for
k
,
v
in
rhs
.
items
():
rhs
[
k
]
*=
elhs
.
value
return
rhs
return
{
elhs
*
erhs
:
1
}
class
DivOp
(
BinaryOp
):
class
DivOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'(
%
s /
%
s)'
%
(
lhs
,
rhs
)
return
'(
%
s /
%
s)'
%
(
lhs
,
rhs
)
def
canonical
(
self
,
lhs
,
rhs
):
erhs
=
canonical_to_expr
(
rhs
)
if
isinstance
(
erhs
,
_expr
.
ConstExpr
):
lhs
=
lhs
.
copy
()
for
k
,
v
in
lhs
.
items
():
lhs
[
k
]
/=
erhs
.
value
return
lhs
elhs
=
canonical_to_expr
(
lhs
)
return
{
elhs
/
erhs
:
1
}
class
MaxOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'max(
%
s,
%
s)'
%
(
lhs
,
rhs
)
def
canonical
(
self
,
lhs
,
rhs
):
diff
=
SubOp
()
.
canonical
(
lhs
,
rhs
)
ediff
=
canonical_to_expr
(
diff
)
if
isinstance
(
ediff
,
_expr
.
ConstExpr
):
return
lhs
if
ediff
.
value
>=
0
else
rhs
return
{
MaxOp
()(
lhs
,
rhs
):
1
}
class
MinOp
(
BinaryOp
):
def
format_str
(
self
,
lhs
,
rhs
):
return
'min(
%
s,
%
s)'
%
(
lhs
,
rhs
)
def
canonical
(
self
,
lhs
,
rhs
):
diff
=
SubOp
()
.
canonical
(
lhs
,
rhs
)
ediff
=
canonical_to_expr
(
diff
)
if
isinstance
(
ediff
,
_expr
.
ConstExpr
):
return
rhs
if
ediff
.
value
>=
0
else
lhs
return
{
MinOp
()(
lhs
,
rhs
):
1
}
add
=
AddOp
()
add
=
AddOp
()
sub
=
SubOp
()
sub
=
SubOp
()
mul
=
MulOp
()
mul
=
MulOp
()
div
=
DivOp
()
div
=
DivOp
()
max
=
MaxOp
()
min
=
MinOp
()
_expr
.
__addop__
=
add
_expr
.
__subop__
=
sub
_expr
.
__mulop__
=
mul
_expr
.
__divop__
=
div
tests/python/test_basic.py
View file @
fc4ba796
...
@@ -9,12 +9,24 @@ def test_bind():
...
@@ -9,12 +9,24 @@ def test_bind():
def
test_basic
():
def
test_basic
():
a
=
tvm
.
Var
(
'a'
)
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
b
=
tvm
.
Var
(
'b'
)
c
=
a
+
b
c
=
a
+
b
assert
tvm
.
format_str
(
c
)
==
'(
%
s +
%
s)'
%
(
a
.
name
,
b
.
name
)
assert
tvm
.
format_str
(
c
)
==
'(
%
s +
%
s)'
%
(
a
.
name
,
b
.
name
)
def
test_simplify
():
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
e1
=
a
*
(
2
+
1
)
+
b
*
1
e2
=
a
*
(
2
+
1
)
-
b
*
1
e3
=
tvm
.
max
(
a
*
3.3
+
5
,
3
+
3.3
*
a
)
e4
=
a
-
a
assert
tvm
.
format_str
(
tvm
.
simplify
(
e1
))
==
'((
%
s * 3) +
%
s)'
%
(
a
.
name
,
b
.
name
)
assert
tvm
.
format_str
(
tvm
.
simplify
(
e2
))
==
'((
%
s * 3) + (
%
s * -1))'
%
(
a
.
name
,
b
.
name
)
assert
tvm
.
format_str
(
tvm
.
simplify
(
e3
))
==
'((
%
s * 3.3) + 5)'
%
(
a
.
name
)
assert
tvm
.
format_str
(
tvm
.
simplify
(
e4
))
==
'0'
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_simplify
()
test_basic
()
test_basic
()
test_bind
()
test_bind
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment