Commit 12aca82e
Authored Jan 24, 2019 by 雾雨魔理沙; committed by Tianqi Chen on Jan 24, 2019
[Relay] A Normal Form Canonicalization (#2251)
parent 911c3a36
Showing 6 changed files with 577 additions and 11 deletions (+577, -11):

include/tvm/relay/pass.h (+20, -0)
python/tvm/relay/ir_pass.py (+34, -8)
src/relay/pass/let_list.h (+5, -1)
src/relay/pass/to_anf.cc (+410, -0)
tests/python/relay/test_pass_dead_code_elimination.py (+2, -2)
tests/python/relay/test_to_anf.py (+106, -0)
include/tvm/relay/pass.h
...
@@ -296,6 +296,26 @@ struct StructuralHash {
  size_t operator()(const Expr& expr) const;
};

/*! \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
 *
 * It will turn an expression that is in graph form (with implicit sharing)
 * into an expression with explicit sharing (A-Normal Form).
 *
 * The scope of the root expression is the global scope.
 * The scope of any non-root expression is the least common ancestor of the scopes of all its uses.
 *
 * Values are ordered by post-DFS order in each scope.
 *
 * \param e the expression to observably share
 *
 * \param mod The module used for referencing global functions; can be None.
 *
 * \return expression in A-Normal Form
 */
Expr ToANF(const Expr& e, const Module& mod);

}  // namespace relay
}  // namespace tvm
...
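To make the scoping rule above concrete, here is a small illustrative sketch (not part of the commit) that uses the Python binding to_anf added in this change; it assumes the tvm.relay API at this revision. A value shared by both branches of an If should be let-bound once, before the If, because the least common ancestor of the two branch scopes is the enclosing scope:

from tvm import relay
from tvm.relay import op
from tvm.relay.ir_pass import to_anf

# graph-form expression whose subexpression is shared by both branches
shared = op.add(relay.const(1), relay.const(2))
body = relay.If(relay.const(True), op.add(shared, shared), shared)

# after conversion, the let binding for `shared` is expected to appear
# before the If rather than inside either branch
print(to_anf(body).astext())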
python/tvm/relay/ir_pass.py
...
@@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
    ----------
    expr : tvm.relay.Expr
        The input expression.

    fvisit : function
        The visitor function to be applied.
    """
...
@@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
    mod: Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    checked_expr : tvm.relay.Expr
...
@@ -112,11 +112,11 @@ def check_kind(t, mod=None):
    Parameters
    ----------
-   t: tvm.relay.Type
+   t : tvm.relay.Type
        The type to check

-   mod : tvm.relay.Module, optional
-       The global module
+   mod : Optional[tvm.relay.Module]
+       The global module.

    Returns
    -------
...
@@ -480,8 +480,35 @@ def collect_device_annotation_ops(expr):
    return _ir_pass.CollectDeviceAnnotationOps(expr)


def to_anf(expr, mod=None):
    """
    Turn Graph Normal Form expression into A Normal Form Expression.

    The scope of the root expression is the global scope.

    The scope of any non-root expression is the least common ancestor of the scopes of all its uses.

    Values are ordered by post-DFS order in each scope.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    mod: Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    expr: tvm.relay.Expr
        The output expression.
    """
    return _ir_pass.to_anf(expr, mod)


def gradient(expr, mod=None):
-   """.
+   """
    Transform a function to return original result paired with gradient of input.

    Parameters
    ----------
...
@@ -489,11 +516,10 @@ def gradient(expr, mod=None):
        The input expression, which is a Function or a GlobalVar.

    mod : Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
-   ret : tvm.relay.Expr
-       A function that calculate the original result paired with gradient.
+   expr : tvm.relay.Expr
+       The output expression.
    """
    return _ir_pass.first_order_gradient(expr, mod)
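A quick usage sketch for the new to_anf binding above (not part of the diff), mirroring the test_explicit_bound test added later in this commit: the pass makes sharing explicit while preserving the evaluation result.

import tvm
from tvm import relay
from tvm.relay import op, create_executor
from tvm.relay.ir_pass import to_anf

x = relay.const(1)
y = op.add(x, x)
z = op.add(y, y)
f = relay.Function([], op.add(z, z))   # graph form: y and z are shared implicitly
anf = to_anf(f)
assert "let" in anf.astext()           # sharing is now explicit let bindings

ctx = tvm.context("llvm", 0)
intrp = create_executor(ctx=ctx, target="llvm")
assert intrp.evaluate(f()).asnumpy() == intrp.evaluate(anf()).asnumpy()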
src/relay/pass/let_list.h
...
@@ -36,6 +36,7 @@ class LetList {
   * \return a Var that holds the inserted expr.
   */
  Var Push(Var pv, Expr expr) {
    CHECK(!used_);
    lets_.emplace_back(std::make_pair(pv, expr));
    return pv;
  }
...
@@ -71,11 +72,13 @@ class LetList {
   *
   * \return the wrapped expr.
   */
- Expr Get(const Expr& body) const {
+ Expr Get(const Expr& body) {
    CHECK(!used_);
    Expr ret = body;
    for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
      ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
    }
    used_ = true;
    return ret;
  }
...
@@ -108,6 +111,7 @@ class LetList {
 private:
  std::vector<std::pair<Var, Expr> > lets_;
  bool used_ = false;
};

}  // namespace relay
...
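As an aside (not part of the diff), the wrapping that LetList::Get above performs has a direct Python analogue using relay.Let, assuming a list of (var, value) bindings collected in insertion order:

from tvm import relay

def wrap_with_lets(bindings, body):
    # hypothetical analogue of LetList::Get: nest the collected bindings
    # around `body`, with the earliest binding outermost
    ret = body
    for var, value in reversed(bindings):
        ret = relay.Let(var, value, ret)
    return ret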
src/relay/pass/to_anf.cc (new file, mode 100644)
/*!
* Copyright (c) 2018 by Contributors
*
* \file to_anf.cc
*
* \brief Turn implicit sharing into observable sharing.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
#include "../../common/arena.h"

namespace tvm {
namespace relay {

using common::LinkNode;
using common::LinkedList;

/* DependencyGraph tracks the inputs and outputs of an Expr.
 * Additionally, dummy scope nodes are created to model scopes.
 * This allows us to traverse the graph in reverse order.
 */
class DependencyGraph {
 public:
  /*! \brief A node in the graph. */
  struct Node {
    bool new_scope = false;
    LinkedList<Node*> input;
    LinkedList<Node*> output;
  };

  /*! \brief The node map that maps node to graph */
  std::unordered_map<Expr, Node*, NodeHash, NodeEqual> expr_node;

  /*! \brief All the nodes in post DFS order */
  std::vector<Node*> post_dfs_order;

  /*!
   * \brief create a dependency graph.
   * \param arena The arena used for data allocation.
   * \param body The body of the expression to create a graph.
   */
  static DependencyGraph Create(common::Arena* arena, const Expr& body);

 private:
  class Creator;
};

// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
 public:
  explicit Creator(common::Arena* arena) : arena_(arena) {}

  DependencyGraph Create(const Expr& body) {
    this->VisitExpr(body);
    return std::move(graph_);
  }

 private:
  /*! \brief allocator of all the internal node object */
  common::Arena* arena_;
  // The output.
  DependencyGraph graph_;

  // Update the message stored at the node.
  void Depend(DependencyGraph::Node* parent, const Expr& child) {
    VisitExpr(child);
    CHECK_NE(graph_.expr_node.count(child), 0);
    Depend(parent, graph_.expr_node[child]);
  }

  void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
    auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
    parent_link->value = parent;
    child->output.Push(parent_link);
    auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
    child_link->value = child;
    parent->input.Push(child_link);
  }

  std::unordered_set<Expr, NodeHash, NodeEqual> visited_;

  DependencyGraph::Node* NewNode(bool new_scope) {
    auto* ret = arena_->make<DependencyGraph::Node>();
    ret->new_scope = new_scope;
    return ret;
  }

  void VisitExpr(const Expr& e) final {
    if (visited_.count(e) == 0) {
      if (graph_.expr_node.count(e) == 0) {
        graph_.expr_node[e] = NewNode(false);
      }
      visited_.insert(e);
      ExprFunctor<void(const Expr&)>::VisitExpr(e);
      graph_.post_dfs_order.push_back(graph_.expr_node[e]);
    }
  }

  void VisitExpr_(const CallNode* c) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
    Depend(n, c->op);
    for (const auto& a : c->args) {
      Depend(n, a);
    }
  }

  void VisitExpr_(const TupleNode* t) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
    for (const auto& a : t->fields) {
      Depend(n, a);
    }
  }

  void VisitExpr_(const TupleGetItemNode* t) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
    Depend(n, t->tuple);
  }

  void VisitExpr_(const IfNode* i) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
    DependencyGraph::Node* t = NewNode(true);
    DependencyGraph::Node* f = NewNode(true);
    Depend(n, i->cond);
    Depend(n, t);
    Depend(n, f);
    Depend(t, i->true_branch);
    Depend(f, i->false_branch);
    graph_.post_dfs_order.push_back(f);
    graph_.post_dfs_order.push_back(t);
  }

  void VisitExpr_(const FunctionNode* f) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
    DependencyGraph::Node* b = NewNode(true);
    Depend(n, b);
    Depend(b, f->body);
    graph_.post_dfs_order.push_back(b);
  }

  void VisitExpr_(const LetNode* l) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
    DependencyGraph::Node* b = NewNode(true);
    Depend(n, b);
    Depend(b, l->value);
    Depend(b, l->body);
    graph_.post_dfs_order.push_back(b);
  }

  void VisitExpr_(const VarNode* v) final { }

  void VisitExpr_(const GlobalVarNode* v) final { }

  void VisitExpr_(const ConstantNode* c) final { }

  void VisitExpr_(const OpNode* o) final { }
};

DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
  return Creator(arena).Create(body);
}

Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv);

struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;

/* Invariant: when parent is null, level is 0
 *
 * Invariant: when parent is not null, level is 1 + parent->level
 */
struct ScopeNode {
  size_t level;
  Scope parent;
  std::shared_ptr<LetList> ll = std::make_shared<LetList>();
  explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { }
  ScopeNode() : level(0) { }
};

Scope ChildScope(const Scope& s) {
  return std::make_shared<ScopeNode>(s);
}

Scope LCA(Scope lhs, Scope rhs) {
  while (lhs != rhs) {
    if (lhs->level > rhs->level) {
      lhs = lhs->parent;
    } else if (lhs->level < rhs->level) {
      rhs = rhs->parent;
    } else {
      lhs = lhs->parent;
      rhs = rhs->parent;
    }
  }
  return lhs;
}

std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
  std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
  Scope global_scope = std::make_shared<ScopeNode>();
  for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
    DependencyGraph::Node* n = *it;
    auto iit = n->output.head;
    Scope s;
    if (iit == nullptr) {
      s = global_scope;
    } else {
      s = expr_scope.at(iit->value);
      iit = iit->next;
      for (; iit != nullptr; iit = iit->next) {
        s = LCA(s, expr_scope.at(iit->value));
      }
    }
    expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
  }
  return expr_scope;
}

bool IsPrimitiveFunction(const Expr& e) {
  return e.as<FunctionNode>() && Downcast<Function>(e)->IsPrimitive();
}

class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
 public:
  static Expr ToANF(const Expr& e,
                    const Module& m,
                    const DependencyGraph& dg,
                    std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
                    std::set<GlobalVar>* gv) {
    Fill fi(m, dg, node_scope, gv);
    return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
  }

 private:
  Module mod_;
  const DependencyGraph& dg_;
  std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
  std::set<GlobalVar>* visited_;
  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;

  Fill(Module mod,
       const DependencyGraph& dg,
       std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
       std::set<GlobalVar>* visited) :
    mod_(mod), dg_(dg), node_scope_(node_scope), visited_(visited) { }

  Scope GetScope(const Expr& e) {
    return node_scope_->at(dg_.expr_node.at(e));
  }

  Scope GetSubScope(const Expr& e, size_t i) {
    DependencyGraph::Node* n = dg_.expr_node.at(e);
    auto h = n->input.head;
    while (i != 0) {
      CHECK(h);
      --i;
      h = h->next;
    }
    CHECK(h);
    return node_scope_->at(h->value);
  }

  Expr VisitExpr(const Expr& e, const Var& v) final {
    if (memo.count(e) == 0) {
      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
    }
    return memo.at(e);
  }

  Expr VisitExpr(const Expr& e) {
    Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType));
    return this->VisitExpr(e, v);
  }

  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
    return GetScope(orig)->ll->Push(v, now);
  }

  Expr VisitExpr_(const CallNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    std::vector<Expr> args;
    for (const auto& a : c->args) {
      args.push_back(VisitExpr(a));
    }
    return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v);
  }

  Expr VisitExpr_(const TupleNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    std::vector<Expr> fields;
    for (const auto& a : t->fields) {
      fields.push_back(VisitExpr(a));
    }
    return Compound(e, TupleNode::make(fields), v);
  }

  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v);
  }

  Expr VisitExpr_(const IfNode* i, const Var& v) final {
    Expr e = GetRef<Expr>(i);
    Expr ret = IfNode::make(VisitExpr(i->cond),
                            GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
                            GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
    Expr e = GetRef<Expr>(f);
    Expr ret;
    if (IsPrimitiveFunction(e)) {
      ret = e;
    } else {
      ret = FunctionNode::make(f->params,
                               GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
                               f->ret_type,
                               f->type_params,
                               f->attrs);
    }
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const LetNode* l, const Var& v) final {
    Expr e = GetRef<Expr>(l);
    VisitExpr(l->value, l->var);
    Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    return Compound(e, e, v);
  }

  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
    return GetRef<Expr>(vn);
  }

  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
    GlobalVar gv = GetRef<GlobalVar>(gvn);
    if (visited_->count(gv) == 0) {
      visited_->insert(gv);
      mod_->Update(gv, Downcast<Function>(relay::ToANF(mod_->Lookup(gv), mod_, visited_)));
    }
    return gv;
  }

  Expr VisitExpr_(const OpNode* op, const Var& v) final {
    return GetRef<Expr>(op);
  }
};

Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
  /* When you lift a lambda, what is inside is also being lifted.
   *
   * So we must determine the scope of the lambda before determining the scope of its body.
   *
   * To make this more principled,
   * we always determine the scope of a parent before determining the scope of its children.
   *
   * So we calculate all the dependencies between nodes.
   */
  common::Arena arena;
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* In order to model the new subscopes created by lambdas, if-else, and pattern matching,
   * we also assign a scope to each edge.
   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
   *
   * So, the scope of the whole expr is global.
   * The scope of any subexpr is the lowest common ancestor of all its incoming edges.
   *
   * Every scope additionally contains a LetList which collects all values of that scope.
   * We do an additional pass to fill all the LetLists and we are done.
   */
  std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
  return Fill::ToANF(e, m, dg, &node_scope, gv);
}

Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
  if (auto f = e.as<FunctionNode>()) {
    return FunctionNode::make(f->params,
                              ToANFAux(f->body, m, gv),
                              f->ret_type,
                              f->type_params,
                              f->attrs);
  } else {
    return ToANFAux(e, m, gv);
  }
}

Expr ToANF(const Expr& e, const Module& m) {
  std::set<GlobalVar> gv;
  return ToANF(e, m, &gv);
}

TVM_REGISTER_API("relay._ir_pass.to_anf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ToANF(args[0], args[1]);
  });

}  // namespace relay
}  // namespace tvm
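The two comments in ToANFAux above summarize the algorithm: build the dependency graph, assign each node the least common ancestor of its users' scopes while walking nodes in reverse post-DFS order (opening a child scope for function bodies and branches), then fill each scope's LetList. Here is a hedged, self-contained Python sketch of just the scope-assignment step, over a generic graph with hypothetical users/opens_new_scope callables; it is not the commit's code:

from functools import reduce

class Scope(object):
    # mirrors ScopeNode: a scope knows its parent and its depth (level)
    def __init__(self, parent=None):
        self.parent = parent
        self.level = 0 if parent is None else parent.level + 1

def lca(a, b):
    # walk the deeper scope up until both levels match, then climb together
    while a is not b:
        if a.level > b.level:
            a = a.parent
        elif a.level < b.level:
            b = b.parent
        else:
            a, b = a.parent, b.parent
    return a

def calc_scope(post_dfs_order, users, opens_new_scope):
    # reverse post-DFS guarantees every user is scoped before the node itself
    global_scope = Scope()
    scope = {}
    for n in reversed(post_dfs_order):
        user_scopes = [scope[u] for u in users(n)]
        s = global_scope if not user_scopes else reduce(lca, user_scopes)
        scope[n] = Scope(s) if opens_new_scope(n) else s
    return scope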
tests/python/relay/test_pass_dead_code_elimination.py
...
@@ -62,9 +62,9 @@ def test_recursion():
                       relay.Call(f, [subtract(n, relay.const(1.0)), log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
-   orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
+   orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
    assert alpha_equal(dead_code_elimination(orig), orig)
-   assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
+   assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)


def test_op_let():
...
tests/python/relay/test_to_anf.py (new file, mode 100644)
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue


def check_eval(expr, expected_result, mod=None, rtol=1e-07):
    ctx = tvm.context("llvm", 0)
    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")

    result = intrp.evaluate(expr)
    np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)


def test_explicit_bound():
    x = relay.const(1)
    y = op.add(x, x)
    z = op.add(y, y)
    f = relay.Function([], op.add(z, z))
    assert not "let" in f.astext()  # assert the values are implicitly bound
    anf = to_anf(f)
    assert "let" in anf.astext()  # assert the values are explicitly bound
    check_eval(f(), 8.0)
    check_eval(anf(), 8.0)


# test that the construction order does not matter,
# and is instead ordered by the scope and by post-dfs ordering.
def test_order():
    z = relay.const(3)
    y = relay.const(2)
    x = relay.const(1)
    val = x + y * z
    check_eval(val, 7.0)
    anf = infer_type(to_anf(val))
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    e = relay.Var('e', relay.IncompleteType())
    expected_output = e
    expected_output = relay.Let(e, a + d, expected_output)
    expected_output = relay.Let(d, b * c, expected_output)
    expected_output = relay.Let(c, z, expected_output)
    expected_output = relay.Let(b, y, expected_output)
    expected_output = relay.Let(a, x, expected_output)
    expected_output = infer_type(expected_output)
    assert alpha_equal(anf, expected_output)


def test_if():
    cond = relay.const(True)
    x = relay.If(cond, relay.const(2), relay.const(3))
    anf = infer_type(to_anf(x))
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    true_branch = relay.Let(a, relay.const(2), a)
    false_branch = relay.Let(b, relay.const(3), b)
    expected_output = relay.If(c, true_branch, false_branch)
    expected_output = relay.Let(d, expected_output, d)
    expected_output = relay.Let(c, cond, expected_output)
    expected_output = infer_type(expected_output)
    assert alpha_equal(anf, expected_output)


# make sure we don't infinite loop.
# it is too large so we won't check for the exact program.
def test_recursion():
    """
    Program:
       let sum_twice(n: i32) -> i32 = {
         m = (n * 2)
         if (n == 0) {
             return m;
         } else {
             return m + sum(n - 1);
         }
       }
       sum_twice(5);
    """
    return  # cannot be run as fuse_ops needs to recursively visit
    mod = relay.Module()
    i64 = relay.TensorType((), 'int64')
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    m = n * relay.const(2, 'int64')
    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
                        m,
                        m + f(n - relay.const(1, 'int64')))
    value = relay.Function([n], funcbody, i64, [])
    mod[f] = value
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
    old_f = mod[f]
    f = to_anf(f, mod=mod)
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)


if __name__ == '__main__':
    test_explicit_bound()
    test_order()
    test_if()
    test_recursion()