Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
13383928
Commit
13383928
authored
Oct 22, 2016
by
Haichen Shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add var binding for expr
parent
816419be
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
138 additions
and
0 deletions
+138
-0
include/tvm/expr_util.h
+51
-0
src/expr/expr_util.cc
+62
-0
src/expr/op.cc
+6
-0
tests/cpp/expr_test.cc
+19
-0
No files found.
include/tvm/expr_util.h
View file @
13383928
...
@@ -6,6 +6,8 @@
...
@@ -6,6 +6,8 @@
#ifndef TVM_EXPR_UTIL_H_
#ifndef TVM_EXPR_UTIL_H_
#define TVM_EXPR_UTIL_H_
#define TVM_EXPR_UTIL_H_
#include <vector>
#include "./expr.h"
#include "./expr.h"
#include "./expr_node.h"
#include "./expr_node.h"
...
@@ -19,6 +21,14 @@ namespace tvm {
...
@@ -19,6 +21,14 @@ namespace tvm {
Expr
Simplify
(
Expr
src
);
Expr
Simplify
(
Expr
src
);
/*!
/*!
* \brief replace the variables in expression src by specification from dict
* \param src The source expression
* \param dict The specification for variable replacement
* \return the new expression with variable replaced
*/
Expr
Bind
(
Expr
src
,
std
::
unordered_map
<
Expr
,
Expr
>
dict
);
/*!
* \brief visit the exression node in expr tree in post DFS order.
* \brief visit the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param expr The expression tree
* \param fvisit The visit function.
* \param fvisit The visit function.
...
@@ -55,6 +65,47 @@ inline void Visit(const Expr& expr, FVisit fvisit) {
...
@@ -55,6 +65,47 @@ inline void Visit(const Expr& expr, FVisit fvisit) {
fvisit
(
expr
);
fvisit
(
expr
);
}
}
/*!
* \brief transform the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
* \return the new expression after transformation
*/
template
<
typename
FVisit
>
inline
Expr
Transform
(
const
Expr
&
expr
,
FVisit
fvisit
)
{
// TODO(tqchen) change to stack based impl.
std
::
vector
<
Expr
>
children
;
switch
(
expr
.
node_type
())
{
case
kBinaryOpNode
:
{
const
auto
*
n
=
expr
.
Get
<
BinaryOpNode
>
();
Expr
e
=
Transform
(
n
->
lhs
,
fvisit
);
children
.
push_back
(
e
);
children
.
push_back
(
Transform
(
n
->
rhs
,
fvisit
));
break
;
}
case
kUnaryOpNode
:
{
const
auto
*
n
=
expr
.
Get
<
UnaryOpNode
>
();
children
.
push_back
(
Transform
(
n
->
src
,
fvisit
));
break
;
}
case
kReduceNode
:
{
const
auto
*
n
=
expr
.
Get
<
ReduceNode
>
();
children
.
push_back
(
Transform
(
n
->
src
,
fvisit
));
break
;
}
case
kTensorReadNode
:
{
const
auto
*
n
=
expr
.
Get
<
TensorReadNode
>
();
for
(
size_t
i
=
0
;
i
<
n
->
indices
.
size
();
++
i
)
{
children
.
push_back
(
Transform
(
n
->
indices
[
i
],
fvisit
));
}
break
;
}
default
:
break
;
}
Expr
ret
=
fvisit
(
expr
,
children
);
return
ret
;
}
}
// namespace tvm
}
// namespace tvm
#endif // TVM_EXPR_UTIL_H_
#endif // TVM_EXPR_UTIL_H_
src/expr/expr_util.cc
View file @
13383928
...
@@ -146,6 +146,68 @@ Expr Simplify(Expr src) {
...
@@ -146,6 +146,68 @@ Expr Simplify(Expr src) {
return
cexpr
.
AsExpr
();
return
cexpr
.
AsExpr
();
}
}
Expr
ExprWithNewChildren
(
Expr
src
,
std
::
vector
<
Expr
>
children
)
{
if
(
children
.
size
())
{
switch
(
src
.
node_type
())
{
case
kBinaryOpNode
:
{
const
auto
*
n
=
src
.
Get
<
BinaryOpNode
>
();
if
(
n
->
lhs
==
children
[
0
]
&&
n
->
rhs
==
children
[
0
])
return
src
;
return
(
*
n
->
op
)(
children
[
0
],
children
[
1
]);
}
case
kUnaryOpNode
:
{
const
auto
*
n
=
src
.
Get
<
UnaryOpNode
>
();
if
(
n
->
src
==
children
[
0
])
return
src
;
return
(
*
n
->
op
)(
children
[
0
]);
}
case
kReduceNode
:
{
const
auto
*
n
=
src
.
Get
<
ReduceNode
>
();
if
(
n
->
src
==
children
[
0
])
return
src
;
return
(
n
->
op
)
->
Reduce
(
children
[
0
],
n
->
rdom
);
}
case
kTensorReadNode
:
{
const
auto
*
n
=
src
.
Get
<
TensorReadNode
>
();
bool
same
=
true
;
for
(
size_t
i
=
0
;
i
<
n
->
indices
.
size
();
++
i
)
{
if
(
n
->
indices
[
i
]
!=
children
[
i
])
{
same
=
false
;
break
;
}
}
if
(
same
)
return
src
;
Array
<
Expr
>
indices
(
children
);
return
n
->
tensor
(
indices
);
}
default:
{
return
src
;
}
}
}
return
src
;
}
Expr
Bind
(
Expr
src
,
std
::
unordered_map
<
Expr
,
Expr
>
dict
)
{
auto
replace
=
[
&
](
Expr
e
,
std
::
vector
<
Expr
>
children
)
{
switch
(
e
.
node_type
())
{
case
kVarNode
:
{
auto
it
=
dict
.
find
(
e
);
if
(
it
!=
dict
.
end
())
{
return
it
->
second
;
}
return
e
;
}
default:
{
return
ExprWithNewChildren
(
e
,
children
);
}
}
};
return
Transform
(
src
,
replace
);
}
void
Expr
::
Print
(
std
::
ostream
&
os
)
const
{
void
Expr
::
Print
(
std
::
ostream
&
os
)
const
{
if
(
is_null
())
{
if
(
is_null
())
{
os
<<
"null"
;
return
;
os
<<
"null"
;
return
;
...
...
src/expr/op.cc
View file @
13383928
...
@@ -13,6 +13,12 @@ DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
...
@@ -13,6 +13,12 @@ DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
namespace
tvm
{
namespace
tvm
{
Expr
UnaryOp
::
operator
()(
Expr
src
)
const
{
auto
nptr
=
std
::
make_shared
<
UnaryOpNode
>
(
this
,
std
::
move
(
src
));
nptr
->
Verify
();
return
Expr
(
std
::
move
(
nptr
));
}
Expr
BinaryOp
::
operator
()(
Expr
lhs
,
Expr
rhs
)
const
{
Expr
BinaryOp
::
operator
()(
Expr
lhs
,
Expr
rhs
)
const
{
auto
nptr
=
std
::
make_shared
<
BinaryOpNode
>
(
auto
nptr
=
std
::
make_shared
<
BinaryOpNode
>
(
this
,
std
::
move
(
lhs
),
std
::
move
(
rhs
));
this
,
std
::
move
(
lhs
),
std
::
move
(
rhs
));
...
...
tests/cpp/expr_test.cc
View file @
13383928
...
@@ -30,6 +30,25 @@ TEST(Expr, Simplify) {
...
@@ -30,6 +30,25 @@ TEST(Expr, Simplify) {
CHECK
(
os
.
str
()
==
"((x * 100) + 1000)"
);
CHECK
(
os
.
str
()
==
"((x * 100) + 1000)"
);
}
}
TEST
(
Expr
,
Bind
)
{
using
namespace
tvm
;
Var
x
(
"x"
),
y
(
"y"
),
z
(
"z"
);
Var
i
(
"i"
),
j
(
"j"
);
Tensor
A
({
y
,
z
},
"A"
);
Expr
e1
=
x
*
5
;
std
::
unordered_map
<
Expr
,
Expr
>
dict
=
{{
x
,
y
*
10
+
z
}};
std
::
ostringstream
os1
,
os2
;
os1
<<
Bind
(
e1
,
dict
);
CHECK
(
os1
.
str
()
==
"(((y * 10) + z) * 5)"
);
Expr
e2
=
A
(
i
,
j
);
dict
.
clear
();
dict
[
i
]
=
64
*
x
;
dict
[
j
]
=
z
+
16
*
y
;
os2
<<
Bind
(
e2
,
dict
);
CHECK
(
os2
.
str
()
==
"A[(64 * x), (z + (16 * y))]"
);
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment