Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
dc97e527
Commit
dc97e527
authored
Apr 12, 2019
by
Josh Pollock
Committed by
Tianqi Chen
Apr 12, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Text Format] Pretty Printer Smart Inlining (#2881)
parent
fefbb006
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
324 additions
and
290 deletions
+324
-290
python/tvm/relay/ir_pass.py
+0
-33
src/relay/ir/pretty_printer.cc
+63
-49
src/relay/pass/dependency_graph.cc
+165
-0
src/relay/pass/dependency_graph.h
+57
-0
src/relay/pass/to_a_normal_form.cc
+3
-185
tests/python/relay/test_ir_text_printer.py
+17
-5
tests/python/relay/test_op_level1.py
+2
-2
tests/python/relay/test_op_level4.py
+1
-1
tests/python/relay/test_type_infer.py
+16
-15
No files found.
python/tvm/relay/ir_pass.py
View file @
dc97e527
...
@@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None):
...
@@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None):
"""
"""
return
_ir_pass
.
eliminate_common_subexpr
(
expr
,
fskip
)
return
_ir_pass
.
eliminate_common_subexpr
(
expr
,
fskip
)
def
pass_debug_print
(
ast
,
show_meta_data
=
True
,
annotate
=
None
,
gnf
=
True
):
"""
THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT!
USE `.astext()` INSTEAD!
A version of the pretty printer intended for debugging passes. Contains
advanced printing options.
Parameters
----------
ast : Union[relay.Expr, relay.Module, relay.Type]
The relay fragment to be turned into text.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
gnf : bool
Whether to print in GNF. If it is disabled, pointers are left implicit.
Returns
-------
text : str
A text representation of `ast`.
"""
return
_ir_pass
.
pass_debug_print
(
ast
,
show_meta_data
,
annotate
,
gnf
)
def
partial_evaluate
(
expr
):
def
partial_evaluate
(
expr
):
"""
"""
Evaluate the static fragment of the code.
Evaluate the static fragment of the code.
...
...
src/relay/ir/pretty_printer.cc
View file @
dc97e527
...
@@ -22,12 +22,22 @@
...
@@ -22,12 +22,22 @@
* \file pretty_printer.cc
* \file pretty_printer.cc
* \brief Pretty printer for Relay programs
* \brief Pretty printer for Relay programs
* Supports ANF, GNF, and metadata.
* Supports ANF, GNF, and metadata.
*
* Inlining heuristics:
* - Always inline:
* - GlobalVar
* - Constant
* - Op
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "doc.h"
#include "type_functor.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h"
#include "../../lang/attr_functor.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -135,10 +145,8 @@ class PrettyPrinter :
...
@@ -135,10 +145,8 @@ class PrettyPrinter :
public
TypeFunctor
<
Doc
(
const
Type
&
)
>
,
public
TypeFunctor
<
Doc
(
const
Type
&
)
>
,
public
AttrFunctor
<
Doc
(
const
NodeRef
&
)
>
{
public
AttrFunctor
<
Doc
(
const
NodeRef
&
)
>
{
public
:
public
:
explicit
PrettyPrinter
(
bool
GNF
,
explicit
PrettyPrinter
(
bool
show_meta_data
,
bool
show_meta_data
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
:
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
:
GNF_
(
GNF
),
show_meta_data_
(
show_meta_data
),
show_meta_data_
(
show_meta_data
),
annotate_
(
annotate
)
{}
annotate_
(
annotate
)
{}
...
@@ -150,10 +158,9 @@ class PrettyPrinter :
...
@@ -150,10 +158,9 @@ class PrettyPrinter :
Doc
doc
;
Doc
doc
;
// additional information in comment.
// additional information in comment.
if
(
annotate_
!=
nullptr
)
{
if
(
annotate_
!=
nullptr
)
{
return
doc
<<
" /
/ "
<<
annotate_
(
expr
)
;
return
doc
<<
" /
* "
<<
annotate_
(
expr
)
<<
" */"
;
}
else
if
(
expr
->
checked_type_
.
defined
())
{
}
else
if
(
expr
->
checked_type_
.
defined
())
{
doc
<<
" // ty="
;
return
doc
<<
" /* ty="
<<
Print
(
expr
->
checked_type
())
<<
" */"
;
return
doc
<<
Print
(
expr
->
checked_type
());
}
else
{
}
else
{
return
doc
;
return
doc
;
}
}
...
@@ -176,13 +183,18 @@ class PrettyPrinter :
...
@@ -176,13 +183,18 @@ class PrettyPrinter :
// print in a new scope
// print in a new scope
doc_stack_
.
push_back
(
Doc
());
doc_stack_
.
push_back
(
Doc
());
// must print first so doc_stack_.back() reference doesn't become stale
// must print first so doc_stack_.back() reference doesn't become stale
Doc
doc
=
Print
(
node
);
Doc
doc
=
Print
(
node
,
false
,
true
);
doc
=
doc_stack_
.
back
()
<<
doc
;
doc
=
doc_stack_
.
back
()
<<
doc
;
doc_stack_
.
pop_back
();
doc_stack_
.
pop_back
();
return
doc
;
return
doc
;
}
}
Doc
PrintFinal
(
const
NodeRef
&
node
)
{
Doc
PrintFinal
(
const
NodeRef
&
node
)
{
if
(
node
.
as_derived
<
ExprNode
>
())
{
Expr
expr
=
Downcast
<
Expr
>
(
node
);
dg_
=
DependencyGraph
::
Create
(
&
arena_
,
expr
);
}
Doc
doc
;
Doc
doc
;
doc
<<
PrintScope
(
node
);
doc
<<
PrintScope
(
node
);
if
(
!
meta_
.
empty
())
{
if
(
!
meta_
.
empty
())
{
...
@@ -200,9 +212,9 @@ class PrettyPrinter :
...
@@ -200,9 +212,9 @@ class PrettyPrinter :
Doc
PrintAttrs
(
const
Attrs
&
attrs
,
const
Expr
&
op
);
Doc
PrintAttrs
(
const
Attrs
&
attrs
,
const
Expr
&
op
);
Doc
Print
(
const
NodeRef
&
node
,
bool
meta
=
false
)
{
Doc
Print
(
const
NodeRef
&
node
,
bool
meta
=
false
,
bool
try_inline
=
false
)
{
if
(
node
.
as_derived
<
ExprNode
>
())
{
if
(
node
.
as_derived
<
ExprNode
>
())
{
return
PrintExpr
(
Downcast
<
Expr
>
(
node
),
meta
);
return
PrintExpr
(
Downcast
<
Expr
>
(
node
),
meta
,
try_inline
);
}
else
if
(
node
.
as_derived
<
TypeNode
>
())
{
}
else
if
(
node
.
as_derived
<
TypeNode
>
())
{
return
PrintType
(
Downcast
<
Type
>
(
node
),
meta
);
return
PrintType
(
Downcast
<
Type
>
(
node
),
meta
);
}
else
if
(
node
.
as_derived
<
ModuleNode
>
())
{
}
else
if
(
node
.
as_derived
<
ModuleNode
>
())
{
...
@@ -308,7 +320,12 @@ class PrettyPrinter :
...
@@ -308,7 +320,12 @@ class PrettyPrinter :
return
val
;
return
val
;
}
}
inline
bool
IsAtomicExpr
(
const
Expr
&
expr
)
{
bool
IsUnique
(
const
Expr
&
expr
)
{
return
!
(
dg_
.
expr_node
.
at
(
expr
)
->
parents
.
head
&&
dg_
.
expr_node
.
at
(
expr
)
->
parents
.
head
->
next
);
}
bool
AlwaysInline
(
const
Expr
&
expr
)
{
return
expr
.
as
<
GlobalVarNode
>
()
||
expr
.
as
<
ConstantNode
>
()
||
return
expr
.
as
<
GlobalVarNode
>
()
||
expr
.
as
<
ConstantNode
>
()
||
expr
.
as
<
OpNode
>
()
||
expr
.
as
<
VarNode
>
();
expr
.
as
<
OpNode
>
()
||
expr
.
as
<
VarNode
>
();
}
}
...
@@ -316,17 +333,25 @@ class PrettyPrinter :
...
@@ -316,17 +333,25 @@ class PrettyPrinter :
//------------------------------------
//------------------------------------
// Overload of Expr printing functions
// Overload of Expr printing functions
//------------------------------------
//------------------------------------
Doc
PrintExpr
(
const
Expr
&
expr
,
bool
meta
)
{
Doc
PrintExpr
(
const
Expr
&
expr
,
bool
meta
,
bool
try_inline
)
{
// Exploit memoization to print GNF.
// Exploit memoization to print GNF.
// The first time we visit an expression, we need to allocate a temp var
// The first time we visit an expression, we need to allocate a temp var
// for it. Every subsequent time we can just use its assigned variable.
// for it. Every subsequent time we can just use its assigned variable.
// This works since hashing uses pointer equality.
// This works since hashing uses pointer equality.
// determine whether to inline
bool
inline_expr
=
AlwaysInline
(
expr
);
if
(
try_inline
)
{
inline_expr
|=
IsUnique
(
expr
);
}
auto
it
=
memo_
.
find
(
expr
);
auto
it
=
memo_
.
find
(
expr
);
if
(
it
!=
memo_
.
end
())
return
it
->
second
;
if
(
it
!=
memo_
.
end
())
return
it
->
second
;
Doc
printed_expr
;
Doc
printed_expr
;
if
(
meta
)
{
if
(
meta
)
{
printed_expr
=
meta_
.
GetMetaNode
(
GetRef
<
NodeRef
>
(
expr
.
get
()));
printed_expr
=
meta_
.
GetMetaNode
(
GetRef
<
NodeRef
>
(
expr
.
get
()));
}
else
if
(
GNF_
&&
expr
.
as
<
LetNode
>
())
{
}
else
if
(
!
inline_expr
&&
expr
.
as
<
LetNode
>
())
{
// wrap GNFed let in brackets
// wrap GNFed let in brackets
Doc
body
;
Doc
body
;
printed_expr
<<
"{"
;
printed_expr
<<
"{"
;
...
@@ -335,28 +360,26 @@ class PrettyPrinter :
...
@@ -335,28 +360,26 @@ class PrettyPrinter :
}
else
{
}
else
{
printed_expr
=
VisitExpr
(
expr
);
printed_expr
=
VisitExpr
(
expr
);
}
}
// we choose to inline atomic exprs
if
(
GNF_
&&
!
IsAtomicExpr
(
expr
))
{
if
(
expr
.
as
<
CallNode
>
())
{
Doc
temp_var
=
AllocTemp
();
printed_expr
<<
PrintOptionalInfo
(
expr
);
memo_
[
expr
]
=
temp_var
;
}
doc_stack_
.
back
()
<<
temp_var
<<
" = "
<<
printed_expr
;
if
(
expr
.
as
<
CallNode
>
())
{
// add expr to doc
doc_stack_
.
back
()
<<
PrintOptionalInfo
(
expr
);
if
(
expr
.
as
<
VarNode
>
())
{
}
doc_stack_
.
back
()
<<
"
\n
"
;
return
temp_var
;
}
else
if
(
expr
.
as
<
VarNode
>
())
{
// This is our first time visiting the var and we hit the VarNode case
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
// in the visitor. Thus the variable is free.
doc_stack_
.
back
()
<<
"free_var "
<<
printed_expr
<<
"
\n
"
;
doc_stack_
.
back
()
<<
"free_var "
<<
printed_expr
<<
"
\n
"
;
// Memoization is done in AllocVar.
// Memoization is done in AllocVar.
return
memo_
[
expr
];
return
memo_
[
expr
];
}
else
{
}
else
if
(
inline_expr
)
{
memo_
[
expr
]
=
printed_expr
;
memo_
[
expr
]
=
printed_expr
;
if
(
GNF_
&&
expr
.
as
<
CallNode
>
())
{
printed_expr
<<
PrintOptionalInfo
(
expr
);
}
return
printed_expr
;
return
printed_expr
;
}
else
{
Doc
temp_var
=
AllocTemp
();
memo_
[
expr
]
=
temp_var
;
doc_stack_
.
back
()
<<
temp_var
<<
" = "
<<
printed_expr
<<
"
\n
"
;
return
temp_var
;
}
}
}
}
...
@@ -420,8 +443,9 @@ class PrettyPrinter :
...
@@ -420,8 +443,9 @@ class PrettyPrinter :
Doc
VisitExpr_
(
const
LetNode
*
op
)
final
{
Doc
VisitExpr_
(
const
LetNode
*
op
)
final
{
Doc
doc
;
Doc
doc
;
doc
<<
"let "
<<
AllocVar
(
op
->
var
)
<<
" = "
<<
Print
(
op
->
value
)
<<
"
\n
"
;
doc
<<
"let "
<<
AllocVar
(
op
->
var
)
<<
" = "
<<
Print
(
op
->
value
,
false
,
true
)
<<
"
\n
"
;
// we use a scope here so GNF hoisting doesn't escape too far
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc
<<
PrintScope
(
op
->
body
);
doc
<<
PrintScope
(
op
->
body
);
return
doc
;
return
doc
;
}
}
...
@@ -456,6 +480,8 @@ class PrettyPrinter :
...
@@ -456,6 +480,8 @@ class PrettyPrinter :
Doc
doc
;
Doc
doc
;
int
counter
=
0
;
int
counter
=
0
;
for
(
const
auto
&
kv
:
mod
->
functions
)
{
for
(
const
auto
&
kv
:
mod
->
functions
)
{
dg_
=
DependencyGraph
::
Create
(
&
arena_
,
kv
.
second
);
std
::
ostringstream
os
;
std
::
ostringstream
os
;
if
(
counter
++
!=
0
)
{
if
(
counter
++
!=
0
)
{
doc
<<
"
\n
"
;
doc
<<
"
\n
"
;
...
@@ -664,8 +690,6 @@ class PrettyPrinter :
...
@@ -664,8 +690,6 @@ class PrettyPrinter :
}
}
private
:
private
:
/*! \brief Whether to use GNF. */
bool
GNF_
;
/*! \brief Whether to print meta data. */
/*! \brief Whether to print meta data. */
bool
show_meta_data_
;
bool
show_meta_data_
;
/*! \brief additional comment function */
/*! \brief additional comment function */
...
@@ -682,6 +706,10 @@ class PrettyPrinter :
...
@@ -682,6 +706,10 @@ class PrettyPrinter :
TextMetaDataContext
meta_
;
TextMetaDataContext
meta_
;
/*! \brief counter of temporary variable */
/*! \brief counter of temporary variable */
size_t
temp_var_counter_
{
0
};
size_t
temp_var_counter_
{
0
};
/*! \brief arena for dependency graph */
common
::
Arena
arena_
;
/*! \brief dependency graph of the expr */
DependencyGraph
dg_
;
class
AttrPrinter
;
class
AttrPrinter
;
friend
class
AttrPrinter
;
friend
class
AttrPrinter
;
};
};
...
@@ -751,25 +779,17 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {
...
@@ -751,25 +779,17 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {
std
::
string
PrettyPrint_
(
const
NodeRef
&
node
,
std
::
string
PrettyPrint_
(
const
NodeRef
&
node
,
bool
show_meta_data
,
bool
show_meta_data
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
bool
gnf
)
{
Doc
doc
;
Doc
doc
;
doc
<<
"v0.0.1"
<<
"
\n
"
doc
<<
"v0.0.1"
<<
"
\n
"
<<
PrettyPrinter
(
gnf
,
show_meta_data
,
annotate
).
PrintFinal
(
node
);
<<
PrettyPrinter
(
show_meta_data
,
annotate
).
PrintFinal
(
node
);
return
doc
.
str
();
return
doc
.
str
();
}
}
std
::
string
AsText
(
const
NodeRef
&
node
,
std
::
string
AsText
(
const
NodeRef
&
node
,
bool
show_meta_data
,
bool
show_meta_data
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
return
PrettyPrint_
(
node
,
show_meta_data
,
annotate
,
true
);
return
PrettyPrint_
(
node
,
show_meta_data
,
annotate
);
}
std
::
string
PassDebugPrint
(
const
NodeRef
&
node
,
bool
show_meta_data
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
,
bool
gnf
)
{
return
PrettyPrint_
(
node
,
show_meta_data
,
annotate
,
gnf
);
}
}
TVM_REGISTER_API
(
"relay._expr.AsText"
)
TVM_REGISTER_API
(
"relay._expr.AsText"
)
...
@@ -777,11 +797,5 @@ TVM_REGISTER_API("relay._expr.AsText")
...
@@ -777,11 +797,5 @@ TVM_REGISTER_API("relay._expr.AsText")
bool
,
bool
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
)
>
(
AsText
);
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
)
>
(
AsText
);
TVM_REGISTER_API
(
"relay._ir_pass.pass_debug_print"
)
.
set_body_typed
<
std
::
string
(
const
NodeRef
&
,
bool
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
,
bool
)
>
(
PassDebugPrint
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/pass/dependency_graph.cc
0 → 100644
View file @
dc97e527
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/pass/dependency_graph.cc
* \brief
*/
#include "dependency_graph.h"
#include <tvm/relay/expr_functor.h>
#include <unordered_set>
#include <utility>
namespace
tvm
{
namespace
relay
{
// Creator of DependencyGraph
class
DependencyGraph
::
Creator
:
private
ExprFunctor
<
void
(
const
Expr
&
e
)
>
{
public
:
explicit
Creator
(
common
::
Arena
*
arena
)
:
arena_
(
arena
)
{}
DependencyGraph
Create
(
const
Expr
&
body
)
{
this
->
VisitExpr
(
body
);
return
std
::
move
(
graph_
);
}
private
:
/*! \brief allocator of all the internal node object */
common
::
Arena
*
arena_
;
// The output.
DependencyGraph
graph_
;
// Update the message stored at the node.
void
Depend
(
DependencyGraph
::
Node
*
parent
,
const
Expr
&
child
)
{
VisitExpr
(
child
);
CHECK_NE
(
graph_
.
expr_node
.
count
(
child
),
0
);
Depend
(
parent
,
graph_
.
expr_node
[
child
]);
}
void
Depend
(
DependencyGraph
::
Node
*
parent
,
DependencyGraph
::
Node
*
child
)
{
auto
*
parent_link
=
arena_
->
make
<
LinkNode
<
DependencyGraph
::
Node
*>
>
();
parent_link
->
value
=
parent
;
child
->
parents
.
Push
(
parent_link
);
auto
*
child_link
=
arena_
->
make
<
LinkNode
<
DependencyGraph
::
Node
*>
>
();
child_link
->
value
=
child
;
parent
->
children
.
Push
(
child_link
);
}
std
::
unordered_set
<
Expr
,
NodeHash
,
NodeEqual
>
visited_
;
DependencyGraph
::
Node
*
NewNode
(
bool
new_scope
)
{
auto
*
ret
=
arena_
->
make
<
DependencyGraph
::
Node
>
();
ret
->
new_scope
=
new_scope
;
return
ret
;
}
void
VisitExpr
(
const
Expr
&
e
)
final
{
if
(
visited_
.
count
(
e
)
==
0
)
{
if
(
graph_
.
expr_node
.
count
(
e
)
==
0
)
{
graph_
.
expr_node
[
e
]
=
NewNode
(
false
);
}
visited_
.
insert
(
e
);
ExprFunctor
<
void
(
const
Expr
&
)
>::
VisitExpr
(
e
);
graph_
.
post_dfs_order
.
push_back
(
graph_
.
expr_node
[
e
]);
}
}
void
VisitExpr_
(
const
CallNode
*
c
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
c
)];
Depend
(
n
,
c
->
op
);
for
(
const
auto
&
a
:
c
->
args
)
{
Depend
(
n
,
a
);
}
}
void
VisitExpr_
(
const
TupleNode
*
t
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
t
)];
for
(
const
auto
&
a
:
t
->
fields
)
{
Depend
(
n
,
a
);
}
}
void
VisitExpr_
(
const
TupleGetItemNode
*
t
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
t
)];
Depend
(
n
,
t
->
tuple
);
}
void
VisitExpr_
(
const
RefCreateNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
value
);
}
void
VisitExpr_
(
const
RefReadNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
ref
);
}
void
VisitExpr_
(
const
RefWriteNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
ref
);
Depend
(
n
,
r
->
value
);
}
void
VisitExpr_
(
const
IfNode
*
i
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
i
)];
DependencyGraph
::
Node
*
t
=
NewNode
(
true
);
DependencyGraph
::
Node
*
f
=
NewNode
(
true
);
Depend
(
n
,
i
->
cond
);
Depend
(
n
,
t
);
Depend
(
n
,
f
);
Depend
(
t
,
i
->
true_branch
);
Depend
(
f
,
i
->
false_branch
);
graph_
.
post_dfs_order
.
push_back
(
f
);
graph_
.
post_dfs_order
.
push_back
(
t
);
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
f
)];
DependencyGraph
::
Node
*
b
=
NewNode
(
true
);
Depend
(
n
,
b
);
Depend
(
b
,
f
->
body
);
graph_
.
post_dfs_order
.
push_back
(
b
);
}
void
VisitExpr_
(
const
LetNode
*
l
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
l
)];
DependencyGraph
::
Node
*
b
=
NewNode
(
true
);
Depend
(
n
,
b
);
Depend
(
b
,
l
->
value
);
Depend
(
b
,
l
->
body
);
graph_
.
post_dfs_order
.
push_back
(
b
);
}
void
VisitExpr_
(
const
MatchNode
*
m
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
m
)];
Depend
(
n
,
m
->
data
);
std
::
vector
<
DependencyGraph
::
Node
*>
v
;
for
(
const
Clause
&
c
:
m
->
clauses
)
{
DependencyGraph
::
Node
*
b
=
NewNode
(
true
);
Depend
(
n
,
b
);
Depend
(
b
,
c
->
rhs
);
v
.
push_back
(
b
);
}
for
(
auto
it
=
v
.
rbegin
();
it
!=
v
.
rend
();
++
it
)
{
graph_
.
post_dfs_order
.
push_back
(
*
it
);
}
}
void
VisitExpr_
(
const
VarNode
*
v
)
final
{
}
void
VisitExpr_
(
const
GlobalVarNode
*
v
)
final
{
}
void
VisitExpr_
(
const
ConstantNode
*
c
)
final
{
}
void
VisitExpr_
(
const
OpNode
*
o
)
final
{
}
void
VisitExpr_
(
const
ConstructorNode
*
c
)
final
{
}
};
DependencyGraph
DependencyGraph
::
Create
(
common
::
Arena
*
arena
,
const
Expr
&
body
)
{
return
Creator
(
arena
).
Create
(
body
);
}
}
// namespace relay
}
// namespace tvm
src/relay/pass/dependency_graph.h
0 → 100644
View file @
dc97e527
/*!
* Copyright (c) 2019 by Contributors.
* \file tvm/relay/pass/dependency_graph.h
* \brief
*/
#ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#include <tvm/relay/expr.h>
#include <unordered_map>
#include <vector>
#include "let_list.h"
#include "../../common/arena.h"
namespace
tvm
{
namespace
relay
{
using
common
::
LinkNode
;
using
common
::
LinkedList
;
/* DependencyGraph track input and output of an Expr.
* Additionally, dummy scope is created to model scope.
* It allow us to traverse the graph in reverse order.
*/
class
DependencyGraph
{
public
:
/*! \brief A node in the graph. */
struct
Node
{
// Determine scope boundaries. Used for calculating scopes, not for
// constructing dependency graph.
bool
new_scope
=
false
;
// incoming edges
LinkedList
<
Node
*>
children
;
// outgoing edges
LinkedList
<
Node
*>
parents
;
};
/*! \brief Maps a Relay Expr to its node in the dependency graph. */
std
::
unordered_map
<
Expr
,
Node
*
,
NodeHash
,
NodeEqual
>
expr_node
;
/*! \brief The dependency graph in post DFS order. */
std
::
vector
<
Node
*>
post_dfs_order
;
/*!
* \brief Create a dependency graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static
DependencyGraph
Create
(
common
::
Arena
*
arena
,
const
Expr
&
body
);
private
:
class
Creator
;
};
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
src/relay/pass/to_a_normal_form.cc
View file @
dc97e527
...
@@ -29,193 +29,11 @@
...
@@ -29,193 +29,11 @@
#include "let_list.h"
#include "let_list.h"
#include "../../common/arena.h"
#include "../../common/arena.h"
#include "pass_util.h"
#include "pass_util.h"
#include "dependency_graph.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
using
common
::
LinkNode
;
using
common
::
LinkedList
;
/* DependencyGraph track input and output of an Expr.
* Additionally, dummy scope is created to model scope.
* It allow us to traverse the graph in reverse order.
*/
class
DependencyGraph
{
public
:
/*! \brief A node in the graph. */
struct
Node
{
bool
new_scope
=
false
;
LinkedList
<
Node
*>
input
;
LinkedList
<
Node
*>
output
;
};
/*! \brief The node map that maps node to graph */
std
::
unordered_map
<
Expr
,
Node
*
,
NodeHash
,
NodeEqual
>
expr_node
;
/*! \brief All the nodes in post DFS order */
std
::
vector
<
Node
*>
post_dfs_order
;
/*!
* \brief create a dependency graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static
DependencyGraph
Create
(
common
::
Arena
*
arena
,
const
Expr
&
body
);
private
:
class
Creator
;
};
// Creator of DependencyGraph
class
DependencyGraph
::
Creator
:
private
ExprFunctor
<
void
(
const
Expr
&
e
)
>
{
public
:
explicit
Creator
(
common
::
Arena
*
arena
)
:
arena_
(
arena
)
{}
DependencyGraph
Create
(
const
Expr
&
body
)
{
this
->
VisitExpr
(
body
);
return
std
::
move
(
graph_
);
}
private
:
/*! \brief allocator of all the internal node object */
common
::
Arena
*
arena_
;
// The output.
DependencyGraph
graph_
;
// Update the message stored at the node.
void
Depend
(
DependencyGraph
::
Node
*
parent
,
const
Expr
&
child
)
{
VisitExpr
(
child
);
CHECK_NE
(
graph_
.
expr_node
.
count
(
child
),
0
);
Depend
(
parent
,
graph_
.
expr_node
[
child
]);
}
void
Depend
(
DependencyGraph
::
Node
*
parent
,
DependencyGraph
::
Node
*
child
)
{
auto
*
parent_link
=
arena_
->
make
<
LinkNode
<
DependencyGraph
::
Node
*>
>
();
parent_link
->
value
=
parent
;
child
->
output
.
Push
(
parent_link
);
auto
*
child_link
=
arena_
->
make
<
LinkNode
<
DependencyGraph
::
Node
*>
>
();
child_link
->
value
=
child
;
parent
->
input
.
Push
(
child_link
);
}
std
::
unordered_set
<
Expr
,
NodeHash
,
NodeEqual
>
visited_
;
DependencyGraph
::
Node
*
NewNode
(
bool
new_scope
)
{
auto
*
ret
=
arena_
->
make
<
DependencyGraph
::
Node
>
();
ret
->
new_scope
=
new_scope
;
return
ret
;
}
void
VisitExpr
(
const
Expr
&
e
)
final
{
if
(
visited_
.
count
(
e
)
==
0
)
{
if
(
graph_
.
expr_node
.
count
(
e
)
==
0
)
{
graph_
.
expr_node
[
e
]
=
NewNode
(
false
);
}
visited_
.
insert
(
e
);
ExprFunctor
<
void
(
const
Expr
&
)
>::
VisitExpr
(
e
);
graph_
.
post_dfs_order
.
push_back
(
graph_
.
expr_node
[
e
]);
}
}
void
VisitExpr_
(
const
CallNode
*
c
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
c
)];
Depend
(
n
,
c
->
op
);
for
(
const
auto
&
a
:
c
->
args
)
{
Depend
(
n
,
a
);
}
}
void
VisitExpr_
(
const
TupleNode
*
t
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
t
)];
for
(
const
auto
&
a
:
t
->
fields
)
{
Depend
(
n
,
a
);
}
}
void
VisitExpr_
(
const
TupleGetItemNode
*
t
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
t
)];
Depend
(
n
,
t
->
tuple
);
}
void
VisitExpr_
(
const
RefCreateNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
value
);
}
void
VisitExpr_
(
const
RefReadNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
ref
);
}
void
VisitExpr_
(
const
RefWriteNode
*
r
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
r
)];
Depend
(
n
,
r
->
ref
);
Depend
(
n
,
r
->
value
);
}
void
VisitExpr_
(
const
IfNode
*
i
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
i
)];
DependencyGraph
::
Node
*
t
=
NewNode
(
true
);
DependencyGraph
::
Node
*
f
=
NewNode
(
true
);
Depend
(
n
,
i
->
cond
);
Depend
(
n
,
t
);
Depend
(
n
,
f
);
Depend
(
t
,
i
->
true_branch
);
Depend
(
f
,
i
->
false_branch
);
graph_
.
post_dfs_order
.
push_back
(
f
);
graph_
.
post_dfs_order
.
push_back
(
t
);
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
f
)];
DependencyGraph
::
Node
*
b
=
NewNode
(
true
);
Depend
(
n
,
b
);
Depend
(
b
,
f
->
body
);
graph_
.
post_dfs_order
.
push_back
(
b
);
}
void
VisitExpr_
(
const
LetNode
*
l
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
l
)];
DependencyGraph
::
Node
*
b
=
NewNode
(
true
);
Depend
(
n
,
b
);
Depend
(
b
,
l
->
value
);
Depend
(
b
,
l
->
body
);
graph_
.
post_dfs_order
.
push_back
(
b
);
}
void
VisitExpr_
(
const
MatchNode
*
m
)
final
{
DependencyGraph
::
Node
*
n
=
graph_
.
expr_node
[
GetRef
<
Expr
>
(
m
)];
Depend
(
n
,
m
->
data
);
std
::
vector
<
DependencyGraph
::
Node
*>
v
;
for
(
const
Clause
&
c
:
m
->
clauses
)
{
DependencyGraph
::
Node
*
b
=
NewNode
(
true
);
Depend
(
n
,
b
);
Depend
(
b
,
c
->
rhs
);
v
.
push_back
(
b
);
}
for
(
auto
it
=
v
.
rbegin
();
it
!=
v
.
rend
();
++
it
)
{
graph_
.
post_dfs_order
.
push_back
(
*
it
);
}
}
void
VisitExpr_
(
const
VarNode
*
v
)
final
{
}
void
VisitExpr_
(
const
GlobalVarNode
*
v
)
final
{
}
void
VisitExpr_
(
const
ConstantNode
*
c
)
final
{
}
void
VisitExpr_
(
const
OpNode
*
o
)
final
{
}
void
VisitExpr_
(
const
ConstructorNode
*
c
)
final
{
}
};
DependencyGraph
DependencyGraph
::
Create
(
common
::
Arena
*
arena
,
const
Expr
&
body
)
{
return
Creator
(
arena
).
Create
(
body
);
}
Expr
ToANormalForm
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
);
Expr
ToANormalForm
(
const
Expr
&
e
,
const
Module
&
m
,
std
::
set
<
GlobalVar
>*
gv
);
struct
ScopeNode
;
struct
ScopeNode
;
...
@@ -256,7 +74,7 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap
...
@@ -256,7 +74,7 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap
Scope
global_scope
=
std
::
make_shared
<
ScopeNode
>
();
Scope
global_scope
=
std
::
make_shared
<
ScopeNode
>
();
for
(
auto
it
=
dg
.
post_dfs_order
.
rbegin
();
it
!=
dg
.
post_dfs_order
.
rend
();
++
it
)
{
for
(
auto
it
=
dg
.
post_dfs_order
.
rbegin
();
it
!=
dg
.
post_dfs_order
.
rend
();
++
it
)
{
DependencyGraph
::
Node
*
n
=
*
it
;
DependencyGraph
::
Node
*
n
=
*
it
;
auto
iit
=
n
->
output
.
head
;
auto
iit
=
n
->
parents
.
head
;
Scope
s
;
Scope
s
;
if
(
iit
==
nullptr
)
{
if
(
iit
==
nullptr
)
{
s
=
global_scope
;
s
=
global_scope
;
...
@@ -313,7 +131,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
...
@@ -313,7 +131,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Scope
GetSubScope
(
const
Expr
&
e
,
size_t
i
)
{
Scope
GetSubScope
(
const
Expr
&
e
,
size_t
i
)
{
DependencyGraph
::
Node
*
n
=
dg_
.
expr_node
.
at
(
e
);
DependencyGraph
::
Node
*
n
=
dg_
.
expr_node
.
at
(
e
);
auto
h
=
n
->
input
.
head
;
auto
h
=
n
->
children
.
head
;
while
(
i
!=
0
)
{
while
(
i
!=
0
)
{
CHECK
(
h
);
CHECK
(
h
);
--
i
;
--
i
;
...
...
tests/python/relay/test_ir_text_printer.py
View file @
dc97e527
...
@@ -50,8 +50,8 @@ def test_env():
...
@@ -50,8 +50,8 @@ def test_env():
text
=
env
.
astext
()
text
=
env
.
astext
()
assert
"def @myf"
in
text
assert
"def @myf"
in
text
assert
"def @myf"
in
str
(
env
)
assert
"def @myf"
in
str
(
env
)
assert
"
%1
= add(
%0
,
%0
) // ty=float32
"
in
text
assert
"
add(
%0
,
%0
) /* ty=float32 */
"
in
text
assert
"
%1
= add(
%0
,
%0
) // ty=float32
"
in
str
(
env
)
assert
"
add(
%0
,
%0
) /* ty=float32 */
"
in
str
(
env
)
show
(
env
.
astext
(
annotate
=
lambda
x
:
str
(
x
.
checked_type
.
dtype
)))
show
(
env
.
astext
(
annotate
=
lambda
x
:
str
(
x
.
checked_type
.
dtype
)))
show
(
text
)
show
(
text
)
...
@@ -112,7 +112,7 @@ def test_let_if_scope():
...
@@ -112,7 +112,7 @@ def test_let_if_scope():
f
=
relay
.
Function
([
x
,
y
,
cond
],
result
)
f
=
relay
.
Function
([
x
,
y
,
cond
],
result
)
text
=
f
.
astext
()
text
=
f
.
astext
()
assert
text
.
count
(
"{"
)
==
6
assert
text
.
count
(
"{"
)
==
4
assert
"
%
cond: bool"
in
text
assert
"
%
cond: bool"
in
text
show
(
f
.
astext
())
show
(
f
.
astext
())
...
@@ -180,8 +180,19 @@ def test_call_node_order():
...
@@ -180,8 +180,19 @@ def test_call_node_order():
"
%2
= fn (
%
x) {
\n
"
"
%2
= fn (
%
x) {
\n
"
"
%
x
\n
"
"
%
x
\n
"
"}
\n
"
"}
\n
"
"
%3
=
%2
(
%1
)
\n
"
"
%2
(
%1
)"
)
"
%3
"
)
def
test_let_inlining
():
tup
=
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
0
)])
x
=
relay
.
var
(
"x"
)
assert
relay
.
Let
(
x
,
tup
,
tup
)
.
astext
()
==
SEMVER
+
\
(
"
%0
= (0, 0)
\n
"
"let
%
x =
%0
\n
"
"
%0
"
)
assert
relay
.
Let
(
x
,
tup
,
x
)
.
astext
()
==
SEMVER
+
\
(
"let
%
x = (0, 0)
\n
"
"
%
x"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
do_print
[
0
]
=
True
do_print
[
0
]
=
True
...
@@ -201,3 +212,4 @@ if __name__ == "__main__":
...
@@ -201,3 +212,4 @@ if __name__ == "__main__":
test_let_if_scope
()
test_let_if_scope
()
test_variable_name
()
test_variable_name
()
test_call_node_order
()
test_call_node_order
()
test_let_inlining
()
tests/python/relay/test_op_level1.py
View file @
dc97e527
...
@@ -38,7 +38,7 @@ def test_unary_op():
...
@@ -38,7 +38,7 @@ def test_unary_op():
x
=
relay
.
var
(
"x"
,
tp
)
x
=
relay
.
var
(
"x"
,
tp
)
y
=
opfunc
(
x
)
y
=
opfunc
(
x
)
# test printer
# test printer
assert
(
"
%0
=
{}(
%
x)"
.
format
(
y
.
op
.
name
))
in
y
.
astext
()
assert
(
"{}(
%
x)"
.
format
(
y
.
op
.
name
))
in
y
.
astext
()
# test type inference
# test type inference
assert
relay
.
ir_pass
.
infer_type
(
y
)
.
checked_type
==
tp
assert
relay
.
ir_pass
.
infer_type
(
y
)
.
checked_type
==
tp
...
@@ -78,7 +78,7 @@ def test_binary_op():
...
@@ -78,7 +78,7 @@ def test_binary_op():
y
=
relay
.
var
(
"y"
,
t2
)
y
=
relay
.
var
(
"y"
,
t2
)
z
=
opfunc
(
x
,
y
)
z
=
opfunc
(
x
,
y
)
# test printer
# test printer
assert
(
"
%0
=
{}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
(
"{}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
if
ref
is
not
None
:
if
ref
is
not
None
:
...
...
tests/python/relay/test_op_level4.py
View file @
dc97e527
...
@@ -29,7 +29,7 @@ def test_binary_op():
...
@@ -29,7 +29,7 @@ def test_binary_op():
y
=
relay
.
var
(
"y"
,
t2
)
y
=
relay
.
var
(
"y"
,
t2
)
z
=
opfunc
(
x
,
y
)
z
=
opfunc
(
x
,
y
)
# test printer
# test printer
assert
(
"
%0
=
{}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
(
"{}(
%
x,
%
y)"
.
format
(
z
.
op
.
name
))
in
z
.
astext
()
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
assert
relay
.
ir_pass
.
infer_type
(
z
)
.
checked_type
==
t1
if
ref
is
not
None
:
if
ref
is
not
None
:
...
...
tests/python/relay/test_type_infer.py
View file @
dc97e527
...
@@ -44,7 +44,7 @@ def initialize_box_adt(mod):
...
@@ -44,7 +44,7 @@ def initialize_box_adt(mod):
def
test_monomorphic_let
():
def
test_monomorphic_let
():
"Program: let
x = 1;
x"
"Program: let
%
x = 1;
%
x"
sb
=
relay
.
ScopeBuilder
()
sb
=
relay
.
ScopeBuilder
()
x
=
sb
.
let
(
'x'
,
relay
.
const
(
1.0
,
"float64"
))
x
=
sb
.
let
(
'x'
,
relay
.
const
(
1.0
,
"float64"
))
sb
.
ret
(
x
)
sb
.
ret
(
x
)
...
@@ -53,7 +53,7 @@ def test_monomorphic_let():
...
@@ -53,7 +53,7 @@ def test_monomorphic_let():
def
test_single_op
():
def
test_single_op
():
"Program: fn (
x : float32) { let t1 = f(x);
t1 }"
"Program: fn (
%
x : float32) { let
%
t1 = f(
%
x);
%
t1 }"
x
=
relay
.
var
(
'x'
,
shape
=
[])
x
=
relay
.
var
(
'x'
,
shape
=
[])
func
=
relay
.
Function
([
x
],
op
.
log
(
x
))
func
=
relay
.
Function
([
x
],
op
.
log
(
x
))
ttype
=
relay
.
TensorType
([],
dtype
=
'float32'
)
ttype
=
relay
.
TensorType
([],
dtype
=
'float32'
)
...
@@ -63,8 +63,9 @@ def test_single_op():
...
@@ -63,8 +63,9 @@ def test_single_op():
def
test_add_broadcast_op
():
def
test_add_broadcast_op
():
"""
"""
Program:
Program:
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
fn (
%
x: Tensor[(10, 4), float32],
%
y: Tensor[(5, 10, 1), float32])
x + y
-> Tensor[(5, 10, 4), float32] {
%
x +
%
y
}
}
"""
"""
x
=
relay
.
var
(
'x'
,
shape
=
(
10
,
4
))
x
=
relay
.
var
(
'x'
,
shape
=
(
10
,
4
))
...
@@ -80,10 +81,10 @@ def test_add_broadcast_op():
...
@@ -80,10 +81,10 @@ def test_add_broadcast_op():
def
test_dual_op
():
def
test_dual_op
():
"""Program:
"""Program:
fn (
x : Tensor[f32, (10, 10)
]) {
fn (
%
x : Tensor[(10, 10), float32
]) {
let t1 = log(x);
let
%
t1 = log(x);
let
t2 = add(t1,
x);
let
%
t2 = add(
%
t1,
%
x);
t1
%
t1
}
}
"""
"""
tp
=
relay
.
TensorType
((
10
,
10
),
"float32"
)
tp
=
relay
.
TensorType
((
10
,
10
),
"float32"
)
...
@@ -99,8 +100,8 @@ def test_dual_op():
...
@@ -99,8 +100,8 @@ def test_dual_op():
def
test_decl
():
def
test_decl
():
"""Program:
"""Program:
def
f(x : Tensor[(10, 10), f
32]) {
def
@f(
%
x : Tensor[(10, 10), float
32]) {
log(x)
log(
%
x)
}
}
"""
"""
tp
=
relay
.
TensorType
((
10
,
10
))
tp
=
relay
.
TensorType
((
10
,
10
))
...
@@ -113,11 +114,11 @@ def test_decl():
...
@@ -113,11 +114,11 @@ def test_decl():
def
test_recursion
():
def
test_recursion
():
"""
"""
Program:
Program:
def
f(n: i32, data: f32) -> f
32 {
def
@f(
%
n: int32,
%
data: float32) -> float
32 {
if (n == 0) {
if (
%
n == 0) {
data
%
data
} else {
} else {
f(n - 1, log(
data))
@f(
%
n - 1, log(
%
data))
}
}
}
}
"""
"""
...
@@ -134,7 +135,7 @@ def test_recursion():
...
@@ -134,7 +135,7 @@ def test_recursion():
sb
.
ret
(
f
(
relay
.
subtract
(
n
,
relay
.
const
(
1
,
ti32
)),
relay
.
log
(
data
)))
sb
.
ret
(
f
(
relay
.
subtract
(
n
,
relay
.
const
(
1
,
ti32
)),
relay
.
log
(
data
)))
mod
=
relay
.
Module
()
mod
=
relay
.
Module
()
mod
[
f
]
=
relay
.
Function
([
n
,
data
],
sb
.
get
())
mod
[
f
]
=
relay
.
Function
([
n
,
data
],
sb
.
get
())
assert
"
%3
= @f(
%1
,
%2
)
"
in
mod
.
astext
()
assert
"
@f(
%1
,
%2
) /* ty=float32 */
"
in
mod
.
astext
()
assert
mod
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
assert
mod
[
f
]
.
checked_type
==
relay
.
FuncType
([
ti32
,
tf32
],
tf32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment