Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c88bda51
Commit
c88bda51
authored
Feb 22, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
Feb 22, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] GNF (#2492)
parent
97bae615
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
169 additions
and
26 deletions
+169
-26
include/tvm/relay/pass.h
+12
-1
python/tvm/relay/ir_pass.py
+17
-2
src/relay/pass/to_a_normal_form.cc
+16
-16
src/relay/pass/to_graph_normal_form.cc
+66
-0
tests/python/relay/test_to_a_normal_form.py
+7
-7
tests/python/relay/test_to_graph_normal_form.py
+51
-0
No files found.
include/tvm/relay/pass.h
View file @
c88bda51
...
...
@@ -320,7 +320,18 @@ struct StructuralHash {
*
* \return expression in A-Normal Form
*/
Expr
ToANF
(
const
Expr
&
e
,
const
Module
&
mod
);
Expr
ToANormalForm
(
const
Expr
&
e
,
const
Module
&
mod
);
/*! \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
*
* \param e the expression.
*
* \return the expression in graph normal form.
*/
Expr
ToGraphNormalForm
(
const
Expr
&
e
);
}
// namespace relay
}
// namespace tvm
...
...
python/tvm/relay/ir_pass.py
View file @
c88bda51
...
...
@@ -490,7 +490,7 @@ def collect_device_annotation_ops(expr):
return
_ir_pass
.
CollectDeviceAnnotationOps
(
expr
)
def
to_a
nf
(
expr
,
mod
=
None
):
def
to_a
_normal_form
(
expr
,
mod
=
None
):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
...
...
@@ -513,7 +513,21 @@ def to_anf(expr, mod=None):
expr: tvm.relay.Expr
The output expression.
"""
return
_ir_pass
.
to_anf
(
expr
,
mod
)
return
_ir_pass
.
to_a_normal_form
(
expr
,
mod
)
def
to_graph_normal_form
(
expr
):
"""Turn A Normal Form expression into Graph Normal Form expression
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
expr : tvm.relay.Expr
The output expression
"""
return
_ir_pass
.
to_graph_normal_form
(
expr
)
def
gradient
(
expr
,
mod
=
None
):
...
...
@@ -534,6 +548,7 @@ def gradient(expr, mod=None):
"""
return
_ir_pass
.
first_order_gradient
(
expr
,
mod
)
def
get_total_mac_number
(
expr
):
"""
Count the number of MACs (multiply-accumulate) of a model
...
...
src/relay/pass/to_a
nf
.cc
→
src/relay/pass/to_a
_normal_form
.cc
View file @
c88bda51
...
...
@@ -196,7 +196,7 @@ DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body)
return
Creator
(
arena
).
Create
(
body
);
}
Expr
ToAN
F
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
);
Expr
ToAN
ormalForm
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
);
struct
ScopeNode
;
using
Scope
=
std
::
shared_ptr
<
ScopeNode
>
;
...
...
@@ -258,11 +258,11 @@ bool IsPrimitiveFunction(const Expr& e) {
class
Fill
:
ExprFunctor
<
Expr
(
const
Expr
&
,
const
Var
&
)
>
{
public
:
static
Expr
ToAN
F
(
const
Expr
&
e
,
const
Module
&
m
,
const
DependencyGraph
&
dg
,
std
::
unordered_map
<
DependencyGraph
::
Node
*
,
Scope
>*
node_scope
,
std
::
set
<
GlobalVar
>*
gv
)
{
static
Expr
ToAN
ormalForm
(
const
Expr
&
e
,
const
Module
&
m
,
const
DependencyGraph
&
dg
,
std
::
unordered_map
<
DependencyGraph
::
Node
*
,
Scope
>*
node_scope
,
std
::
set
<
GlobalVar
>*
gv
)
{
Fill
fi
(
m
,
dg
,
node_scope
,
gv
);
return
fi
.
GetScope
(
e
)
->
ll
->
Get
(
fi
.
VisitExpr
(
e
));
}
...
...
@@ -396,7 +396,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
GlobalVar
gv
=
GetRef
<
GlobalVar
>
(
gvn
);
if
(
visited_
->
count
(
gv
)
==
0
)
{
visited_
->
insert
(
gv
);
mod_
->
Update
(
gv
,
Downcast
<
Function
>
(
relay
::
ToAN
F
(
mod_
->
Lookup
(
gv
),
mod_
,
visited_
)));
mod_
->
Update
(
gv
,
Downcast
<
Function
>
(
relay
::
ToAN
ormalForm
(
mod_
->
Lookup
(
gv
),
mod_
,
visited_
)));
}
return
gv
;
}
...
...
@@ -423,7 +423,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
};
Expr
ToAN
F
Aux
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
Expr
ToAN
ormalForm
Aux
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
/* When you lift a lambda, what is inside is also being lift.
*
* So we must determine the scope of the lambda before determining the scope of it's body.
...
...
@@ -446,29 +446,29 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
* We do an additional pass to fill all the LetList and we are done.
*/
std
::
unordered_map
<
DependencyGraph
::
Node
*
,
Scope
>
node_scope
=
CalcScope
(
dg
);
return
Fill
::
ToAN
F
(
e
,
m
,
dg
,
&
node_scope
,
gv
);
return
Fill
::
ToAN
ormalForm
(
e
,
m
,
dg
,
&
node_scope
,
gv
);
}
Expr
ToAN
F
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
Expr
ToAN
ormalForm
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
if
(
const
auto
*
f
=
e
.
as
<
FunctionNode
>
())
{
return
FunctionNode
::
make
(
f
->
params
,
ToAN
F
Aux
(
f
->
body
,
m
,
gv
),
ToAN
ormalForm
Aux
(
f
->
body
,
m
,
gv
),
f
->
ret_type
,
f
->
type_params
,
f
->
attrs
);
}
else
{
return
ToAN
F
Aux
(
e
,
m
,
gv
);
return
ToAN
ormalForm
Aux
(
e
,
m
,
gv
);
}
}
Expr
ToAN
F
(
const
Expr
&
e
,
const
Module
&
m
)
{
Expr
ToAN
ormalForm
(
const
Expr
&
e
,
const
Module
&
m
)
{
std
::
set
<
GlobalVar
>
gv
;
return
ToAN
F
(
e
,
m
,
&
gv
);
return
ToAN
ormalForm
(
e
,
m
,
&
gv
);
}
TVM_REGISTER_API
(
"relay._ir_pass.to_a
nf
"
)
TVM_REGISTER_API
(
"relay._ir_pass.to_a
_normal_form
"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ToAN
F
(
args
[
0
],
args
[
1
]);
*
ret
=
ToAN
ormalForm
(
args
[
0
],
args
[
1
]);
});
}
// namespace relay
...
...
src/relay/pass/to_graph_normal_form.cc
0 → 100644
View file @
c88bda51
/*!
* Copyright (c) 2018 by Contributors
*
* \file to_gnf.cc
*
* \brief Turn A normal form into graph normal form.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
namespace
tvm
{
namespace
relay
{
class
UseVarVisitor
:
public
ExprVisitor
{
public
:
explicit
UseVarVisitor
(
const
Var
&
v
)
:
v
(
v
)
{
}
static
bool
UseVar
(
const
Var
&
v
,
const
Expr
&
e
)
{
UseVarVisitor
uv
(
v
);
uv
(
e
);
return
uv
.
use_var
;
}
private
:
bool
use_var
=
false
;
Var
v
;
void
VisitExpr_
(
const
VarNode
*
vn
)
override
{
use_var
=
use_var
||
(
v
==
GetRef
<
Var
>
(
vn
));
}
};
class
GNF
:
public
ExprMutator
{
private
:
std
::
unordered_map
<
Var
,
Expr
,
NodeHash
,
NodeEqual
>
var_map_
;
Expr
VisitExpr_
(
const
VarNode
*
vn
)
override
{
Var
v
=
GetRef
<
Var
>
(
vn
);
return
var_map_
.
count
(
v
)
==
0
?
v
:
var_map_
.
at
(
v
);
}
static
bool
UseVar
(
const
Var
&
v
,
const
Expr
&
e
)
{
return
UseVarVisitor
::
UseVar
(
v
,
e
);
}
static
Expr
WrapRec
(
const
Var
&
var
,
const
Expr
&
val
)
{
return
UseVar
(
var
,
val
)
?
LetNode
::
make
(
var
,
val
,
var
)
:
val
;
}
Expr
VisitExpr_
(
const
LetNode
*
ln
)
override
{
var_map_
.
insert
(
std
::
pair
<
Var
,
Expr
>
(
ln
->
var
,
VisitExpr
(
WrapRec
(
ln
->
var
,
ln
->
value
))));
return
VisitExpr
(
ln
->
body
);
}
};
Expr
ToGraphNormalForm
(
const
Expr
&
e
)
{
return
GNF
()(
e
);
}
TVM_REGISTER_API
(
"relay._ir_pass.to_graph_normal_form"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ToGraphNormalForm
(
args
[
0
]);
});
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_to_a
nf
.py
→
tests/python/relay/test_to_a
_normal_form
.py
View file @
c88bda51
import
numpy
as
np
import
tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
to_a
nf
,
alpha_equal
,
infer_type
from
tvm.relay.ir_pass
import
to_a
_normal_form
,
alpha_equal
,
infer_type
from
tvm.relay
import
op
,
create_executor
from
tvm.relay.backend.interpreter
import
Value
,
TupleValue
,
ConstructorValue
from
tvm.relay.prelude
import
Prelude
...
...
@@ -21,7 +21,7 @@ def test_explicit_bound():
z
=
op
.
add
(
y
,
y
)
f
=
relay
.
Function
([],
op
.
add
(
z
,
z
))
assert
not
"let"
in
f
.
astext
()
# assert the values are implicitly bounded
anf
=
to_a
nf
(
f
)
anf
=
to_a
_normal_form
(
f
)
assert
"let"
in
anf
.
astext
()
# assert the values are explicitly bounded
check_eval
(
f
(),
8.0
)
check_eval
(
anf
(),
8.0
)
...
...
@@ -35,7 +35,7 @@ def test_order():
x
=
relay
.
const
(
1
)
val
=
x
+
y
*
z
check_eval
(
val
,
7.0
)
anf
=
infer_type
(
to_a
nf
(
val
))
anf
=
infer_type
(
to_a
_normal_form
(
val
))
a
=
relay
.
Var
(
'a'
,
relay
.
IncompleteType
())
b
=
relay
.
Var
(
'b'
,
relay
.
IncompleteType
())
c
=
relay
.
Var
(
'c'
,
relay
.
IncompleteType
())
...
...
@@ -54,7 +54,7 @@ def test_order():
def
test_if
():
cond
=
relay
.
const
(
True
)
x
=
relay
.
If
(
cond
,
relay
.
const
(
2
),
relay
.
const
(
3
))
anf
=
infer_type
(
to_a
nf
(
x
))
anf
=
infer_type
(
to_a
_normal_form
(
x
))
a
=
relay
.
Var
(
'a'
,
relay
.
IncompleteType
())
b
=
relay
.
Var
(
'b'
,
relay
.
IncompleteType
())
c
=
relay
.
Var
(
'c'
,
relay
.
IncompleteType
())
...
...
@@ -96,7 +96,7 @@ def test_recursion():
mod
[
f
]
=
value
check_eval
(
f
(
relay
.
const
(
5
,
'int64'
)),
30.0
,
mod
=
mod
)
old_f
=
mod
[
f
]
f
=
to_a
nf
(
f
,
mod
=
mod
)
f
=
to_a
_normal_form
(
f
,
mod
=
mod
)
check_eval
(
f
(
relay
.
const
(
5
,
'int64'
)),
30.0
,
mod
=
mod
)
...
...
@@ -111,7 +111,7 @@ def test_ref():
body
=
relay
.
Let
(
iv
,
relay
.
RefRead
(
i
),
body
)
body
=
relay
.
Let
(
i
,
relay
.
RefCreate
(
relay
.
const
(
1
)),
body
)
check_eval
(
body
,
3
)
check_eval
(
to_a
nf
(
body
),
3
)
check_eval
(
to_a
_normal_form
(
body
),
3
)
# this is an example of using the adt value in python side
...
...
@@ -135,7 +135,7 @@ def test_add():
intrp
=
create_executor
(
mod
=
mod
,
ctx
=
ctx
,
target
=
"llvm"
)
assert
mod
[
add
]
.
checked_type
==
relay
.
FuncType
([
nat
(),
nat
()],
nat
())
assert
count
(
intrp
.
evaluate
(
add
(
s
(
z
()),
s
(
z
()))))
==
2
assert
count
(
intrp
.
evaluate
(
to_a
nf
(
add
(
s
(
z
()),
s
(
z
())),
mod
)))
==
2
assert
count
(
intrp
.
evaluate
(
to_a
_normal_form
(
add
(
s
(
z
()),
s
(
z
())),
mod
)))
==
2
assert
"let"
in
mod
[
add
]
.
astext
()
if
__name__
==
'__main__'
:
...
...
tests/python/relay/test_to_graph_normal_form.py
0 → 100644
View file @
c88bda51
import
numpy
as
np
import
tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
to_graph_normal_form
,
to_a_normal_form
,
alpha_equal
from
tvm.relay
import
op
,
create_executor
from
tvm.relay.backend.interpreter
import
Value
,
TupleValue
def
check_eval
(
expr
,
args
,
expected_result
,
mod
=
None
,
rtol
=
1e-07
):
if
mod
is
None
:
mod
=
relay
.
Module
()
ctx
=
tvm
.
context
(
"llvm"
,
0
)
intrp
=
create_executor
(
mod
=
mod
,
ctx
=
ctx
,
target
=
"llvm"
)
result
=
intrp
.
evaluate
(
expr
)(
*
args
)
np
.
testing
.
assert_allclose
(
result
.
asnumpy
(),
expected_result
,
rtol
=
rtol
)
def
test_implicit_share
():
x
=
relay
.
Var
(
'x'
)
y
=
relay
.
Var
(
'y'
)
z
=
relay
.
Var
(
'z'
)
body
=
relay
.
Let
(
z
,
op
.
add
(
y
,
y
),
op
.
add
(
z
,
z
))
body
=
relay
.
Let
(
y
,
op
.
add
(
x
,
x
),
body
)
f
=
relay
.
Function
([],
relay
.
Let
(
x
,
relay
.
const
(
1
),
body
))
g
=
to_graph_normal_form
(
f
)
assert
"let"
in
f
.
astext
()
assert
not
"let"
in
g
.
astext
()
check_eval
(
f
,
[],
8.0
)
check_eval
(
g
,
[],
8.0
)
def
test_round_trip
():
x
=
relay
.
Var
(
'x'
)
y
=
relay
.
Var
(
'y'
)
z
=
relay
.
Var
(
'z'
)
body
=
relay
.
Let
(
z
,
op
.
add
(
y
,
y
),
op
.
add
(
z
,
z
))
body
=
relay
.
Let
(
y
,
op
.
add
(
x
,
x
),
body
)
f
=
relay
.
Function
([],
relay
.
Let
(
x
,
relay
.
const
(
1
),
body
))
g
=
to_graph_normal_form
(
f
)
h
=
to_a_normal_form
(
g
)
assert
"let"
in
f
.
astext
()
assert
not
"let"
in
g
.
astext
()
check_eval
(
f
,
[],
8.0
)
check_eval
(
g
,
[],
8.0
)
check_eval
(
h
,
[],
8.0
)
if
__name__
==
'__main__'
:
test_implicit_share
()
test_round_trip
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment