Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c113712d
Unverified
Commit
c113712d
authored
Nov 19, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 19, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][BACKEND] Enable PlanMemory in the graph runtime. (#2120)
parent
6edb3564
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
450 additions
and
25 deletions
+450
-25
include/tvm/relay/expr.h
+2
-0
python/tvm/relay/backend/_backend.py
+1
-0
python/tvm/relay/backend/graph_runtime_codegen.py
+25
-8
python/tvm/relay/base.py
+12
-2
src/relay/backend/graph_plan_memory.cc
+349
-0
src/relay/ir/text_printer.cc
+27
-12
src/relay/pass/fuse_ops.cc
+1
-1
tests/python/relay/test_backend_graph_runtime.py
+33
-2
No files found.
include/tvm/relay/expr.h
View file @
c113712d
...
@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const {
...
@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const {
/*!
/*!
* \brief Print node as text format.
* \brief Print node as text format.
* \param node The node to be printed.
* \param node The node to be printed.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* additional comment block to an expr.
* \return The text representation.
* \return The text representation.
*/
*/
std
::
string
RelayPrint
(
std
::
string
RelayPrint
(
const
NodeRef
&
node
,
const
NodeRef
&
node
,
bool
show_meta_data
=
true
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
=
nullptr
);
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
=
nullptr
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
...
...
python/tvm/relay/backend/_backend.py
View file @
c113712d
...
@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None):
...
@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None):
funcs : List[tvm.LoweredFunc]
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
The list of lowered functions.
target : tvm.Target
target : tvm.Target
The target to run the code on.
The target to run the code on.
...
...
python/tvm/relay/backend/graph_runtime_codegen.py
View file @
c113712d
...
@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
...
@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
json
import
json
import
attr
import
attr
from
.
import
_backend
from
.
import
compile_engine
from
.
import
compile_engine
from
..op
import
Op
from
..op
import
Op
from
..expr
import
Function
,
GlobalVar
,
ExprFunctor
from
..expr
import
Function
,
GlobalVar
,
ExprFunctor
...
@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor):
self
.
nodes
=
[]
self
.
nodes
=
[]
self
.
var_map
=
{}
self
.
var_map
=
{}
self
.
params
=
{}
self
.
params
=
{}
self
.
storage_map
=
None
self
.
compile_engine
=
compile_engine
.
get
()
self
.
compile_engine
=
compile_engine
.
get
()
self
.
lowered_funcs
=
set
()
self
.
lowered_funcs
=
set
()
self
.
_name_map
=
{}
self
.
_name_map
=
{}
def
add_node
(
self
,
node
,
checked_type
):
def
add_node
(
self
,
node
,
expr
):
"""
"""
Add a node to the graph.
Add a node to the graph.
...
@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor):
node: Node
node: Node
The node to add to the graph.
The node to add to the graph.
checked_type: Type
expr: tvm.relay.Expr
The
type of the node
.
The
corresponding expression
.
Returns
Returns
-------
-------
node_ref: Union[NodeRef, List[NodeRef]]
node_ref: Union[NodeRef, List[NodeRef]]
A reference to the node.
A reference to the node.
"""
"""
checked_type
=
expr
.
checked_type
# setup storage ids
assert
expr
in
self
.
storage_map
node
.
attrs
[
"storage_id"
]
=
[
x
.
value
for
x
in
self
.
storage_map
[
expr
]
]
node_id
=
len
(
self
.
nodes
)
node_id
=
len
(
self
.
nodes
)
self
.
nodes
.
append
(
node
)
self
.
nodes
.
append
(
node
)
# Tuple return value, flatten as tuple
# Tuple return value, flatten as tuple
...
@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor):
name
=
"p
%
d"
%
index
name
=
"p
%
d"
%
index
self
.
params
[
name
]
=
op
.
data
self
.
params
[
name
]
=
op
.
data
node
=
InputNode
(
name
,
{})
node
=
InputNode
(
name
,
{})
return
self
.
add_node
(
node
,
op
.
checked_type
)
return
self
.
add_node
(
node
,
op
)
def
visit_function
(
self
,
_
):
def
visit_function
(
self
,
_
):
raise
RuntimeError
(
"function not supported"
)
raise
RuntimeError
(
"function not supported"
)
...
@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor):
op_name
=
cached_func
.
func_name
op_name
=
cached_func
.
func_name
op_node
=
OpNode
(
self
.
_get_unique_name
(
op_name
),
{},
op_node
=
OpNode
(
self
.
_get_unique_name
(
op_name
),
{},
op_name
,
inputs
,
{})
op_name
,
inputs
,
{})
return
self
.
add_node
(
op_node
,
call
.
checked_type
)
return
self
.
add_node
(
op_node
,
call
)
def
_get_json
(
self
):
def
_get_json
(
self
):
"""
"""
...
@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor):
assert
node
.
num_outputs
==
len
(
node
.
attrs
[
"shape"
])
assert
node
.
num_outputs
==
len
(
node
.
attrs
[
"shape"
])
shapes
+=
node
.
attrs
[
"shape"
]
shapes
+=
node
.
attrs
[
"shape"
]
dltypes
+=
node
.
attrs
[
"dtype"
]
dltypes
+=
node
.
attrs
[
"dtype"
]
for
i
in
range
(
node
.
num_outputs
):
storage_ids
+=
node
.
attrs
[
"storage_id"
]
storage_ids
.
append
(
i
+
num_entry
)
num_entry
+=
node
.
num_outputs
num_entry
+=
node
.
num_outputs
node_row_ptr
.
append
(
num_entry
)
node_row_ptr
.
append
(
num_entry
)
...
@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor):
return
json
.
dumps
(
json_dict
,
indent
=
2
)
return
json
.
dumps
(
json_dict
,
indent
=
2
)
def
debug_dump_memory_plan
(
self
,
func
):
"""Debug function to dump memory plan."""
def
_annotate
(
expr
):
if
expr
in
self
.
storage_map
:
return
str
(
self
.
storage_map
[
expr
])
return
""
return
func
.
astext
(
show_meta_data
=
False
,
annotate
=
_annotate
)
def
codegen
(
self
,
func
):
def
codegen
(
self
,
func
):
"""Compile a single function into a graph.
"""Compile a single function into a graph.
...
@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor):
...
@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor):
params : Dict[str, tvm.nd.NDArray]
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
Additional constant parameters.
"""
"""
self
.
storage_map
=
_backend
.
GraphPlanMemory
(
func
)
# First we convert all the parameters into input nodes.
# First we convert all the parameters into input nodes.
for
param
in
func
.
params
:
for
param
in
func
.
params
:
node
=
InputNode
(
param
.
name_hint
,
{})
node
=
InputNode
(
param
.
name_hint
,
{})
self
.
var_map
[
param
]
=
self
.
add_node
(
self
.
var_map
[
param
]
=
self
.
add_node
(
node
,
param
.
type_annotation
)
node
,
param
)
# Then we compile the body into a graph which can depend
# Then we compile the body into a graph which can depend
# on input variables.
# on input variables.
...
...
python/tvm/relay/base.py
View file @
c113712d
...
@@ -23,7 +23,7 @@ def register_relay_node(type_key=None):
...
@@ -23,7 +23,7 @@ def register_relay_node(type_key=None):
class
RelayNode
(
NodeBase
):
class
RelayNode
(
NodeBase
):
"""Base class of all relay node."""
"""Base class of all relay node."""
def
astext
(
self
,
annotate
=
None
):
def
astext
(
self
,
show_meta_data
=
True
,
annotate
=
None
):
"""Get the text format of the expression.
"""Get the text format of the expression.
Returns
Returns
...
@@ -31,11 +31,21 @@ class RelayNode(NodeBase):
...
@@ -31,11 +31,21 @@ class RelayNode(NodeBase):
text : str
text : str
The text format of the expression.
The text format of the expression.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
Optional annotate function to provide additional
information in the comment block.
information in the comment block.
Note
----
meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big(constat weights),
so it can be helpful to skip printing the meta data section.
"""
"""
return
_expr
.
RelayPrint
(
self
,
annotate
)
return
_expr
.
RelayPrint
(
self
,
show_meta_data
,
annotate
)
@register_relay_node
@register_relay_node
...
...
src/relay/backend/graph_plan_memory.cc
0 → 100644
View file @
c113712d
This diff is collapsed.
Click to expand it.
src/relay/ir/text_printer.cc
View file @
c113712d
...
@@ -113,6 +113,11 @@ class TextMetaDataContext {
...
@@ -113,6 +113,11 @@ class TextMetaDataContext {
return
SaveJSON
(
Array
<
NodeRef
>
(
meta_data_
));
return
SaveJSON
(
Array
<
NodeRef
>
(
meta_data_
));
}
}
/*! \return whether the meta data context is empty. */
bool
empty
()
const
{
return
meta_data_
.
empty
();
}
private
:
private
:
/*! \brief additional metadata stored in TVM json format */
/*! \brief additional metadata stored in TVM json format */
std
::
vector
<
NodeRef
>
meta_data_
;
std
::
vector
<
NodeRef
>
meta_data_
;
...
@@ -125,8 +130,9 @@ class TextPrinter :
...
@@ -125,8 +130,9 @@ class TextPrinter :
public
TypeFunctor
<
void
(
const
Type
&
,
std
::
ostream
&
os
)
>
,
// NOLINT(*)
public
TypeFunctor
<
void
(
const
Type
&
,
std
::
ostream
&
os
)
>
,
// NOLINT(*)
public
AttrFunctor
<
void
(
const
NodeRef
&
,
std
::
ostream
&
os
)
>
{
// NOLINT(*)
public
AttrFunctor
<
void
(
const
NodeRef
&
,
std
::
ostream
&
os
)
>
{
// NOLINT(*)
public:
public:
explicit
TextPrinter
(
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
explicit
TextPrinter
(
bool
show_meta_data
,
:
annotate_
(
annotate
)
{}
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
:
show_meta_data_
(
show_meta_data
),
annotate_
(
annotate
)
{}
/*!
/*!
* \brief Print a node to string.
* \brief Print a node to string.
* \param node.
* \param node.
...
@@ -144,13 +150,17 @@ class TextPrinter :
...
@@ -144,13 +150,17 @@ class TextPrinter :
}
else
{
}
else
{
stream_
<<
node
;
stream_
<<
node
;
}
}
std
::
string
meta_json
=
meta_
.
GetMetaSection
();
if
(
!
meta_
.
empty
())
{
if
(
meta_json
.
length
()
!=
0
)
{
if
(
show_meta_data_
)
{
// append meta data in the end.
std
::
string
meta_json
=
meta_
.
GetMetaSection
();
stream_
<<
"# meta data
\n
"
// append meta data in the end.
<<
"r
\"\"\"\n
"
stream_
<<
"# meta data
\n
"
<<
meta_json
<<
"
\n
"
<<
"r
\"\"\"\n
"
<<
"
\"\"\"
"
;
<<
meta_json
<<
"
\n
"
<<
"
\"\"\"
"
;
}
else
{
stream_
<<
"# meta data omitted. you can use show_meta_data=True to include meta-data
\n
"
;
}
}
}
return
stream_
.
str
();
return
stream_
.
str
();
}
}
...
@@ -227,7 +237,9 @@ class TextPrinter :
...
@@ -227,7 +237,9 @@ class TextPrinter :
TextValue
id
=
this
->
AllocTempVar
();
TextValue
id
=
this
->
AllocTempVar
();
this
->
PrintIndent
();
this
->
PrintIndent
();
stream_
<<
id
<<
" = "
<<
meta_
.
GetMetaNode
(
GetRef
<
NodeRef
>
(
op
));
stream_
<<
id
<<
" = "
<<
meta_
.
GetMetaNode
(
GetRef
<
NodeRef
>
(
op
));
this
->
PrintEndInst
(
"
\n
"
);
this
->
PrintEndInst
(
""
);
this
->
PrintOptionalInfo
(
GetRef
<
Expr
>
(
op
));
stream_
<<
'\n'
;
return
id
;
return
id
;
}
}
...
@@ -697,6 +709,8 @@ class TextPrinter :
...
@@ -697,6 +709,8 @@ class TextPrinter :
private
:
private
:
class
AttrPrinter
;
class
AttrPrinter
;
friend
class
AttrPrinter
;
friend
class
AttrPrinter
;
/*! \brief Whether to print meta data. */
bool
show_meta_data_
;
/*! \brief additional comment function */
/*! \brief additional comment function */
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate_
;
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate_
;
/*! \brief meta data context */
/*! \brief meta data context */
...
@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
...
@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
}
}
std
::
string
RelayPrint
(
const
NodeRef
&
node
,
std
::
string
RelayPrint
(
const
NodeRef
&
node
,
bool
show_meta_data
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
return
TextPrinter
(
annotate
).
Print
(
node
);
return
TextPrinter
(
show_meta_data
,
annotate
).
Print
(
node
);
}
}
TVM_REGISTER_API
(
"relay._expr.RelayPrint"
)
TVM_REGISTER_API
(
"relay._expr.RelayPrint"
)
.
set_body_typed
<
std
::
string
(
.
set_body_typed
<
std
::
string
(
const
NodeRef
&
,
const
NodeRef
&
,
bool
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
)
>
(
RelayPrint
);
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
)
>
(
RelayPrint
);
}
// namespace relay
}
// namespace relay
...
...
src/relay/pass/fuse_ops.cc
View file @
c113712d
...
@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator {
...
@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator {
}
}
// Debug function, dump the group assignment in text.
// Debug function, dump the group assignment in text.
void
DebugDumpGroup
(
const
Expr
&
body
)
{
void
DebugDumpGroup
(
const
Expr
&
body
)
{
std
::
string
text
=
RelayPrint
(
body
,
[
this
](
const
Expr
&
expr
)
->
std
::
string
{
std
::
string
text
=
RelayPrint
(
body
,
false
,
[
this
](
const
Expr
&
expr
)
->
std
::
string
{
auto
it
=
gmap_
.
find
(
expr
.
get
());
auto
it
=
gmap_
.
find
(
expr
.
get
());
if
(
it
==
gmap_
.
end
())
return
""
;
if
(
it
==
gmap_
.
end
())
return
""
;
std
::
ostringstream
os
;
std
::
ostringstream
os
;
...
...
tests/python/relay/test_backend_graph_runtime.py
View file @
c113712d
...
@@ -77,7 +77,9 @@ def test_add_op_broadcast():
...
@@ -77,7 +77,9 @@ def test_add_op_broadcast():
def
test_with_params
():
def
test_with_params
():
x
=
relay
.
var
(
'x'
,
shape
=
(
10
,
5
))
x
=
relay
.
var
(
'x'
,
shape
=
(
10
,
5
))
y
=
relay
.
var
(
'y'
,
shape
=
(
1
,
5
))
y
=
relay
.
var
(
'y'
,
shape
=
(
1
,
5
))
func
=
relay
.
Function
([
x
,
y
],
add
(
x
,
y
))
z
=
relay
.
add
(
x
,
y
)
z
=
relay
.
exp
(
z
)
func
=
relay
.
Function
([
x
,
y
],
z
)
x_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
x_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
y_data
=
np
.
random
.
rand
(
1
,
5
)
.
astype
(
'float32'
)
y_data
=
np
.
random
.
rand
(
1
,
5
)
.
astype
(
'float32'
)
params
=
{
"y"
:
y_data
}
params
=
{
"y"
:
y_data
}
...
@@ -87,11 +89,40 @@ def test_with_params():
...
@@ -87,11 +89,40 @@ def test_with_params():
mod
.
set_input
(
x
=
x_data
)
mod
.
set_input
(
x
=
x_data
)
mod
.
run
()
mod
.
run
()
res
=
mod
.
get_output
(
0
)
.
asnumpy
()
res
=
mod
.
get_output
(
0
)
.
asnumpy
()
ref_res
=
y_data
+
x_data
ref_res
=
np
.
exp
(
y_data
+
x_data
)
tvm
.
testing
.
assert_allclose
(
res
,
ref_res
)
tvm
.
testing
.
assert_allclose
(
res
,
ref_res
)
def
test_plan_memory
():
# it is sufficient to cycle through two memories.
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,))
y
=
relay
.
var
(
"x"
,
shape
=
(
1
,))
y2
=
relay
.
exp
(
y
)
z
=
relay
.
add
(
x
,
y2
)
z
=
relay
.
exp
(
z
)
z
=
relay
.
exp
(
z
)
z
=
relay
.
exp
(
z
)
z
=
relay
.
exp
(
z
)
z
=
relay
.
exp
(
z
)
func
=
relay
.
Function
([
x
,
y
],
z
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
fuse_ops
(
func
,
opt_level
=
0
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
smap
=
relay
.
backend
.
_backend
.
GraphPlanMemory
(
func
)
storage_ids
=
set
()
for
k
,
v
in
smap
.
items
():
for
x
in
v
:
storage_ids
.
add
(
x
.
value
)
# Current rule requires vars have unique storage id
# because we don't do inplace, we will need another
# two alternating temporary space.
assert
len
(
storage_ids
)
==
4
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_plan_memory
()
test_with_params
()
test_with_params
()
test_add_op_scalar
()
test_add_op_scalar
()
test_add_op_tensor
()
test_add_op_tensor
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment