Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3c0dc79d
Commit
3c0dc79d
authored
Oct 21, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Simplify for cxx
parent
9595a9c1
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
237 additions
and
64 deletions
+237
-64
include/tvm/expr_util.h
+1
-3
include/tvm/tvm.h
+1
-0
python/tvm/cpp/_ctypes/_api.py
+1
-0
src/c_api/c_api_function.cc
+6
-0
src/expr/expr.cc
+0
-57
src/expr/expr_util.cc
+197
-0
tests/cpp/expr_test.cc
+9
-0
tests/python/test_cpp.py
+22
-4
No files found.
include/tvm/expr_util.h
View file @
3c0dc79d
...
...
@@ -16,9 +16,7 @@ namespace tvm {
* \param src The source expression
* \return the simplified expression.
*/
inline
Expr
Simplify
(
Expr
src
)
{
return
src
;
}
Expr
Simplify
(
Expr
src
);
/*!
* \brief visit the exression node in expr tree in post DFS order.
...
...
include/tvm/tvm.h
View file @
3c0dc79d
...
...
@@ -12,5 +12,6 @@
#include "./tensor.h"
#include "./domain.h"
#include "./array.h"
#include "./expr_util.h"
#endif // TVM_TVM_H_
python/tvm/cpp/_ctypes/_api.py
View file @
3c0dc79d
...
...
@@ -172,6 +172,7 @@ def register_node(type_key):
"""
def
register
(
cls
):
NODE_TYPE
[
type_key
]
=
cls
return
cls
return
register
...
...
src/c_api/c_api_function.cc
View file @
3c0dc79d
...
...
@@ -6,6 +6,7 @@
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/tensor.h>
#include <tvm/expr_util.h>
#include "./c_api_registry.h"
namespace
dmlc
{
...
...
@@ -104,6 +105,11 @@ TVM_REGISTER_API(_TensorInput)
static_cast
<
DataType
>
(
static_cast
<
int
>
(
args
.
at
(
1
))));
});
TVM_REGISTER_API
(
simplify
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Simplify
(
args
.
at
(
0
));
});
// transformations
TVM_REGISTER_API
(
format_str
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
...
...
src/expr/expr.cc
View file @
3c0dc79d
...
...
@@ -9,63 +9,6 @@
namespace
tvm
{
void
Expr
::
Print
(
std
::
ostream
&
os
)
const
{
if
(
is_null
())
{
os
<<
"null"
;
return
;
}
switch
(
this
->
node_type
())
{
case
kVarNode
:
{
os
<<
Get
<
VarNode
>
()
->
name
;
return
;
}
case
kIntNode
:
{
os
<<
Get
<
IntNode
>
()
->
value
;
return
;
}
case
kFloatNode
:
{
os
<<
Get
<
FloatNode
>
()
->
value
;
return
;
}
case
kBinaryOpNode
:
{
const
auto
*
n
=
Get
<
BinaryOpNode
>
();
const
char
*
fname
=
n
->
op
->
FunctionName
();
if
(
fname
[
1
]
==
'\0'
&&
!
isalpha
(
fname
[
0
]))
{
os
<<
'('
;
n
->
lhs
.
Print
(
os
);
os
<<
' '
<<
fname
[
0
]
<<
' '
;
n
->
rhs
.
Print
(
os
);
os
<<
')'
;
}
else
{
os
<<
fname
<<
'('
;
n
->
lhs
.
Print
(
os
);
os
<<
", "
;
n
->
rhs
.
Print
(
os
);
os
<<
')'
;
}
return
;
}
case
kUnaryOpNode
:
{
const
auto
*
n
=
Get
<
UnaryOpNode
>
();
os
<<
n
->
op
->
FunctionName
()
<<
'('
;
n
->
src
.
Print
(
os
);
os
<<
')'
;
return
;
}
case
kReduceNode
:
{
const
auto
*
n
=
Get
<
ReduceNode
>
();
os
<<
"reduce("
<<
n
->
op
->
FunctionName
()
<<
", "
;
n
->
src
.
Print
(
os
);
os
<<
", "
<<
n
->
rdom
<<
')'
;
return
;
}
case
kTensorReadNode
:
{
const
auto
*
n
=
Get
<
TensorReadNode
>
();
os
<<
n
->
tensor
.
name
()
<<
n
->
indices
;
return
;
}
default
:
{
LOG
(
FATAL
)
<<
"not able to handle type "
<<
typeid
(
node_
.
get
()).
name
();
}
}
}
Var
::
Var
(
std
::
string
name
,
DataType
dtype
)
{
auto
node
=
std
::
make_shared
<
VarNode
>
();
node
->
name
=
std
::
move
(
name
);
...
...
src/expr/expr_util.cc
View file @
3c0dc79d
...
...
@@ -3,8 +3,205 @@
* \file expr_util.cc
*/
#include <tvm/expr_util.h>
#include <tvm/op.h>
namespace
tvm
{
inline
bool
is_ingeter
(
DataType
t
)
{
return
t
==
kInt32
;
}
/*! \brief Canonical form of expression */
struct
CanonicalExpr
{
/*! \brief the e->value */
std
::
unordered_map
<
Expr
,
int64_t
>
dict
;
/*! \brief constant value in the expresssion */
int64_t
constant
{
0
};
// change CanonicalExpr as expr
inline
Expr
AsExpr
()
const
{
Expr
e
;
using
KV
=
std
::
pair
<
Expr
,
int64_t
>
;
std
::
vector
<
KV
>
tlist
(
dict
.
begin
(),
dict
.
end
());
std
::
sort
(
tlist
.
begin
(),
tlist
.
end
(),
[](
const
KV
&
lhs
,
const
KV
&
rhs
)
{
return
lhs
.
first
.
hash
()
<
rhs
.
first
.
hash
();
});
for
(
auto
&
kv
:
tlist
)
{
if
(
kv
.
second
==
0
)
continue
;
Expr
tmp
;
if
(
kv
.
second
==
1
)
{
tmp
=
kv
.
first
;
}
else
{
tmp
=
kv
.
first
*
kv
.
second
;
}
if
(
e
.
is_null
())
{
e
=
tmp
;
}
else
{
e
=
e
+
tmp
;
}
}
if
(
e
.
is_null
())
{
return
IntConstant
(
constant
);
}
else
{
if
(
constant
!=
0
)
e
=
e
+
constant
;
return
e
;
}
}
inline
void
Add
(
const
Expr
&
e
,
int
beta
)
{
auto
it
=
dict
.
find
(
e
);
if
(
it
!=
dict
.
end
())
{
it
->
second
+=
beta
;
if
(
it
->
second
==
0
)
dict
.
erase
(
it
);
}
else
{
dict
[
e
]
=
beta
;
}
}
};
// out += beta * Canonicalize(e)
void
AddCanonical
(
const
Expr
&
e
,
CanonicalExpr
*
out
,
int
beta
)
{
static
const
BinaryOp
*
add_op
=
BinaryOp
::
Get
(
"+"
);
static
const
BinaryOp
*
sub_op
=
BinaryOp
::
Get
(
"-"
);
static
const
BinaryOp
*
mul_op
=
BinaryOp
::
Get
(
"*"
);
static
const
BinaryOp
*
max_op
=
BinaryOp
::
Get
(
"max"
);
static
const
BinaryOp
*
min_op
=
BinaryOp
::
Get
(
"min"
);
CHECK
(
!
e
.
is_null
())
<<
"cannot simplify null"
;
switch
(
e
.
node_type
())
{
case
kIntNode
:
{
out
->
constant
+=
(
e
.
Get
<
IntNode
>
()
->
value
)
*
beta
;
return
;
}
case
kBinaryOpNode
:
{
const
auto
*
n
=
e
.
Get
<
BinaryOpNode
>
();
if
(
n
->
op
==
add_op
)
{
AddCanonical
(
n
->
lhs
,
out
,
beta
);
AddCanonical
(
n
->
rhs
,
out
,
beta
);
return
;
}
if
(
n
->
op
==
sub_op
)
{
AddCanonical
(
n
->
lhs
,
out
,
beta
);
AddCanonical
(
n
->
rhs
,
out
,
-
beta
);
return
;
}
if
(
n
->
op
==
mul_op
)
{
if
(
n
->
lhs
.
node_type
()
==
kIntNode
)
{
AddCanonical
(
n
->
rhs
,
out
,
beta
*
(
n
->
lhs
.
Get
<
IntNode
>
()
->
value
));
return
;
}
else
if
(
n
->
rhs
.
node_type
()
==
kIntNode
)
{
AddCanonical
(
n
->
lhs
,
out
,
beta
*
(
n
->
rhs
.
Get
<
IntNode
>
()
->
value
));
return
;
}
CanonicalExpr
clhs
,
crhs
;
AddCanonical
(
n
->
lhs
,
&
clhs
,
1
);
if
(
clhs
.
dict
.
size
()
==
0
)
{
AddCanonical
(
n
->
rhs
,
out
,
beta
*
clhs
.
constant
);
return
;
}
AddCanonical
(
n
->
rhs
,
&
crhs
,
1
);
if
(
crhs
.
dict
.
size
()
==
0
)
{
AddCanonical
(
n
->
lhs
,
out
,
beta
*
crhs
.
constant
);
return
;
}
out
->
Add
(
e
,
beta
);
return
;
}
if
(
n
->
op
==
max_op
)
{
CanonicalExpr
res
;
AddCanonical
(
n
->
lhs
,
&
res
,
1
);
AddCanonical
(
n
->
rhs
,
&
res
,
-
1
);
if
(
res
.
dict
.
size
()
==
0
)
{
if
(
res
.
constant
>
0
)
{
AddCanonical
(
n
->
lhs
,
out
,
beta
);
return
;
}
else
{
AddCanonical
(
n
->
rhs
,
out
,
beta
);
return
;
}
}
else
{
out
->
Add
(
e
,
beta
);
return
;
}
}
if
(
n
->
op
==
min_op
)
{
CanonicalExpr
res
;
AddCanonical
(
n
->
lhs
,
&
res
,
1
);
AddCanonical
(
n
->
rhs
,
&
res
,
-
1
);
if
(
res
.
dict
.
size
()
==
0
)
{
if
(
res
.
constant
<=
0
)
{
AddCanonical
(
n
->
lhs
,
out
,
beta
);
return
;
}
else
{
AddCanonical
(
n
->
rhs
,
out
,
beta
);
return
;
}
}
else
{
out
->
Add
(
e
,
beta
);
return
;
}
}
out
->
Add
(
e
,
beta
);
return
;
}
default:
{
out
->
Add
(
e
,
beta
);
return
;
}
}
}
Expr
Simplify
(
Expr
src
)
{
CanonicalExpr
cexpr
;
AddCanonical
(
src
,
&
cexpr
,
1
);
return
cexpr
.
AsExpr
();
}
void
Expr
::
Print
(
std
::
ostream
&
os
)
const
{
if
(
is_null
())
{
os
<<
"null"
;
return
;
}
switch
(
this
->
node_type
())
{
case
kVarNode
:
{
os
<<
Get
<
VarNode
>
()
->
name
;
return
;
}
case
kIntNode
:
{
os
<<
Get
<
IntNode
>
()
->
value
;
return
;
}
case
kFloatNode
:
{
os
<<
Get
<
FloatNode
>
()
->
value
;
return
;
}
case
kBinaryOpNode
:
{
const
auto
*
n
=
Get
<
BinaryOpNode
>
();
const
char
*
fname
=
n
->
op
->
FunctionName
();
if
(
fname
[
1
]
==
'\0'
&&
!
isalpha
(
fname
[
0
]))
{
os
<<
'('
;
n
->
lhs
.
Print
(
os
);
os
<<
' '
<<
fname
[
0
]
<<
' '
;
n
->
rhs
.
Print
(
os
);
os
<<
')'
;
}
else
{
os
<<
fname
<<
'('
;
n
->
lhs
.
Print
(
os
);
os
<<
", "
;
n
->
rhs
.
Print
(
os
);
os
<<
')'
;
}
return
;
}
case
kUnaryOpNode
:
{
const
auto
*
n
=
Get
<
UnaryOpNode
>
();
os
<<
n
->
op
->
FunctionName
()
<<
'('
;
n
->
src
.
Print
(
os
);
os
<<
')'
;
return
;
}
case
kReduceNode
:
{
const
auto
*
n
=
Get
<
ReduceNode
>
();
os
<<
"reduce("
<<
n
->
op
->
FunctionName
()
<<
", "
;
n
->
src
.
Print
(
os
);
os
<<
", "
<<
n
->
rdom
<<
')'
;
return
;
}
case
kTensorReadNode
:
{
const
auto
*
n
=
Get
<
TensorReadNode
>
();
os
<<
n
->
tensor
.
name
()
<<
n
->
indices
;
return
;
}
default
:
{
LOG
(
FATAL
)
<<
"not able to handle type "
<<
typeid
(
node_
.
get
()).
name
();
}
}
}
}
// namespace tvm
tests/cpp/expr_test.cc
View file @
3c0dc79d
...
...
@@ -21,6 +21,15 @@ TEST(Expr, Reduction) {
CHECK
(
os
.
str
()
==
"reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))"
);
}
TEST
(
Expr
,
Simplify
)
{
using
namespace
tvm
;
Var
x
(
"x"
);
auto
z
=
max
(
x
+
1
+
2
,
x
+
10
)
*
100
;
std
::
ostringstream
os
;
os
<<
Simplify
(
z
);
CHECK
(
os
.
str
()
==
"((x * 100) + 1000)"
);
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
...
...
tests/python/test_cpp.py
View file @
3c0dc79d
from
tvm
import
cpp
as
tvm
def
test_basic
():
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
c
=
a
+
b
assert
a
==
c
.
lhs
assert
c
.
dtype
==
tvm
.
int32
assert
tvm
.
format_str
(
c
)
==
'(
%
s +
%
s)'
%
(
a
.
name
,
b
.
name
)
...
...
@@ -13,11 +13,29 @@ def test_basic():
def
test_array
():
a
=
tvm
.
Var
(
'a'
)
x
=
tvm
.
function
.
_symbol
([
1
,
2
,
a
])
print
type
(
x
)
print
len
(
x
)
print
x
[
4
]
def
assert_equal
(
x
,
y
):
z
=
tvm
.
simplify
(
x
-
y
)
assert
isinstance
(
z
,
tvm
.
expr
.
IntExpr
)
assert
z
.
value
==
0
def
test_simplify
():
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
e1
=
a
*
(
2
+
1
)
+
b
*
1
e2
=
a
*
(
2
+
1
)
-
b
*
1
e3
=
tvm
.
max
(
a
*
3
+
5
,
3
+
3
*
a
)
e4
=
a
-
a
assert_equal
(
e1
,
a
*
3
+
b
)
assert_equal
(
e2
,
a
*
3
-
b
)
assert_equal
(
e3
,
a
*
3
+
5
)
assert_equal
(
e4
,
0
)
if
__name__
==
"__main__"
:
test_basic
()
test_array
()
test_simplify
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment