Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
dee8cf9b
Commit
dee8cf9b
authored
Feb 22, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
Feb 22, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] fix anf for reference and pattern matching (#2637)
parent
cc5a3cf0
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
1 deletions
+107
-1
src/relay/pass/to_anf.cc
+65
-0
tests/python/relay/test_to_anf.py
+42
-1
No files found.
src/relay/pass/to_anf.cc
View file @
dee8cf9b
...
...
@@ -120,6 +120,22 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
Depend
(
n
,
t
->
tuple
);
}
void
VisitExpr_
(
const
RefCreateNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
value
);
}
void
VisitExpr_
(
const
RefReadNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
ref
);
}
void
VisitExpr_
(
const
RefWriteNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
ref
);
Depend
(
n
,
r
->
value
);
}
void
VisitExpr_
(
const
IfNode
*
i
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
i
)];
DependencyGraph
::
Node
*
t
=
NewNode
(
true
);
...
...
@@ -150,6 +166,21 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
graph_
.
post_dfs_order
.
push_back
(
b
);
}
void
VisitExpr_
(
const
MatchNode
*
m
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
m
)];
Depend
(
n
,
m
->
data
);
std
::
vector
<
DependencyGraph
::
Node
*>
v
;
for
(
const
Clause
&
c
:
m
->
clauses
)
{
DependencyGraph
::
Node
*
b
=
NewNode
(
true
);
Depend
(
n
,
b
);
Depend
(
b
,
c
->
rhs
);
v
.
push_back
(
b
);
}
for
(
auto
it
=
v
.
rbegin
();
it
!=
v
.
rend
();
++
it
)
{
graph_
.
post_dfs_order
.
push_back
(
*
it
);
}
}
void
VisitExpr_
(
const
VarNode
*
v
)
final
{
}
void
VisitExpr_
(
const
GlobalVarNode
*
v
)
final
{
}
...
...
@@ -157,6 +188,8 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
void
VisitExpr_
(
const
ConstantNode
*
c
)
final
{
}
void
VisitExpr_
(
const
OpNode
*
o
)
final
{
}
void
VisitExpr_
(
const
ConstructorNode
*
c
)
final
{
}
};
DependencyGraph
DependencyGraph
::
Create
(
common
::
Arena
*
arena
,
const
Expr
&
body
)
{
...
...
@@ -305,6 +338,21 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
return
Compound
(
e
,
TupleGetItemNode
::
make
(
VisitExpr
(
t
->
tuple
),
t
->
index
),
v
);
}
Expr
VisitExpr_
(
const
RefCreateNode
*
r
,
const
Var
&
v
)
final
{
Expr
e
=
GetRef
<
Expr
>
(
r
);
return
Compound
(
e
,
RefCreateNode
::
make
(
VisitExpr
(
r
->
value
)),
v
);
}
Expr
VisitExpr_
(
const
RefReadNode
*
r
,
const
Var
&
v
)
final
{
Expr
e
=
GetRef
<
Expr
>
(
r
);
return
Compound
(
e
,
RefReadNode
::
make
(
VisitExpr
(
r
->
ref
)),
v
);
}
Expr
VisitExpr_
(
const
RefWriteNode
*
r
,
const
Var
&
v
)
final
{
Expr
e
=
GetRef
<
Expr
>
(
r
);
return
Compound
(
e
,
RefWriteNode
::
make
(
VisitExpr
(
r
->
ref
),
VisitExpr
(
r
->
value
)),
v
);
}
Expr
VisitExpr_
(
const
IfNode
*
i
,
const
Var
&
v
)
final
{
Expr
e
=
GetRef
<
Expr
>
(
i
);
Expr
ret
=
IfNode
::
make
(
VisitExpr
(
i
->
cond
),
...
...
@@ -356,6 +404,23 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr
VisitExpr_
(
const
OpNode
*
op
,
const
Var
&
v
)
final
{
return
GetRef
<
Expr
>
(
op
);
}
Expr
VisitExpr_
(
const
ConstructorNode
*
c
,
const
Var
&
v
)
final
{
return
GetRef
<
Expr
>
(
c
);
}
Expr
VisitExpr_
(
const
MatchNode
*
m
,
const
Var
&
v
)
final
{
Expr
e
=
GetRef
<
Expr
>
(
m
);
Expr
data
=
VisitExpr
(
m
->
data
);
std
::
vector
<
Clause
>
clauses
;
for
(
const
Clause
&
c
:
m
->
clauses
)
{
clauses
.
push_back
(
ClauseNode
::
make
(
c
->
lhs
,
GetSubScope
(
e
,
1
+
clauses
.
size
())
->
ll
->
Get
(
VisitExpr
(
c
->
rhs
))));
}
Expr
r
=
Compound
(
e
,
MatchNode
::
make
(
data
,
clauses
),
v
);
return
r
;
}
};
Expr
ToANFAux
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
)
{
...
...
tests/python/relay/test_to_anf.py
View file @
dee8cf9b
...
...
@@ -3,7 +3,8 @@ import tvm
from
tvm
import
relay
from
tvm.relay.ir_pass
import
to_anf
,
alpha_equal
,
infer_type
from
tvm.relay
import
op
,
create_executor
from
tvm.relay.backend.interpreter
import
Value
,
TupleValue
from
tvm.relay.backend.interpreter
import
Value
,
TupleValue
,
ConstructorValue
from
tvm.relay.prelude
import
Prelude
def
check_eval
(
expr
,
expected_result
,
mod
=
None
,
rtol
=
1e-07
):
...
...
@@ -99,8 +100,48 @@ def test_recursion():
check_eval
(
f
(
relay
.
const
(
5
,
'int64'
)),
30.0
,
mod
=
mod
)
def
test_ref
():
i
=
relay
.
Var
(
'i'
)
iv
=
relay
.
Var
(
'iv'
)
u
=
relay
.
Var
(
'u'
)
uv
=
relay
.
Var
(
'uv'
)
body
=
relay
.
add
(
iv
,
uv
)
body
=
relay
.
Let
(
uv
,
relay
.
RefRead
(
i
),
body
)
body
=
relay
.
Let
(
u
,
relay
.
RefWrite
(
i
,
relay
.
const
(
2
)),
body
)
body
=
relay
.
Let
(
iv
,
relay
.
RefRead
(
i
),
body
)
body
=
relay
.
Let
(
i
,
relay
.
RefCreate
(
relay
.
const
(
1
)),
body
)
check_eval
(
body
,
3
)
check_eval
(
to_anf
(
body
),
3
)
# this is an example of using the adt value in python side
def
count
(
n
):
assert
isinstance
(
n
,
ConstructorValue
)
if
n
.
constructor
.
name_hint
==
's'
:
return
1
+
count
(
n
.
fields
[
0
])
else
:
assert
n
.
constructor
.
name_hint
==
'z'
return
0
def
test_add
():
mod
=
relay
.
Module
()
p
=
Prelude
(
mod
)
nat
=
p
.
nat
add
=
p
.
add
s
=
p
.
s
z
=
p
.
z
ctx
=
tvm
.
context
(
"llvm"
,
0
)
intrp
=
create_executor
(
mod
=
mod
,
ctx
=
ctx
,
target
=
"llvm"
)
assert
mod
[
add
]
.
checked_type
==
relay
.
FuncType
([
nat
(),
nat
()],
nat
())
assert
count
(
intrp
.
evaluate
(
add
(
s
(
z
()),
s
(
z
()))))
==
2
assert
count
(
intrp
.
evaluate
(
to_anf
(
add
(
s
(
z
()),
s
(
z
())),
mod
)))
==
2
assert
"let"
in
mod
[
add
]
.
astext
()
if
__name__
==
'__main__'
:
test_explicit_bound
()
test_order
()
test_if
()
test_recursion
()
test_ref
()
test_add
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment