Commit d781a57f in wenyuanbo/tic, authored Sep 23, 2017 by Tianqi Chen
[RUNTIME] Minimum runtime module (#31)
parent 343c19a5

Showing 7 changed files with 612 additions and 29 deletions:
nnvm/deploy/nnvm_runtime.cc        +3   -1
nnvm/src/compiler/graph_fuse.cc    +2   -1
nnvm/src/compiler/graph_hash.cc    +1   -1
nnvm/src/runtime/graph_executor.cc +2   -3
nnvm/src/runtime/graph_executor.h  +1   -23
nnvm/src/runtime/graph_runtime.cc  +563 -0
nnvm/src/runtime/graph_runtime.h   +40  -0
nnvm/deploy/nnvm_runtime.cc (view file @ d781a57f)

@@ -3,9 +3,11 @@
  * All in one runtime
  * \file nnvm_runtime.cc
  */
+/*
 #include "../src/core/graph.cc"
 #include "../src/core/node.cc"
 #include "../src/core/pass.cc"
 #include "../src/core/op.cc"
 #include "../src/pass/saveload_json.cc"
-#include "../src/runtime/graph_executor.cc"
+#include "../src/runtime/graph_executor.cc"*/
+#include "../src/runtime/graph_runtime.cc"
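With the graph-construction sources commented out, the all-in-one deployment file now reduces to the single minimum-runtime translation unit. As a rough sketch of what that enables (the include paths and flags below are illustrative assumptions, not part of this commit), the whole deployable runtime could be built as one shared library:

g++ -std=c++11 -O2 -fPIC -shared nnvm/deploy/nnvm_runtime.cc \
    -Innvm/include -Idmlc-core/include -Itvm/include \
    -o libnnvm_runtime.so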
nnvm/src/compiler/graph_fuse.cc (view file @ d781a57f)

@@ -335,8 +335,9 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) {
+      // only copy over name since that is sufficient.
       nnvm::NodePtr np = nnvm::Node::Create();
-      np->attrs = inode.source->attrs;
+      np->attrs.name = inode.source->attrs.name;
       old_new[nid] = np;
       continue;
     }
...
nnvm/src/compiler/graph_hash.cc (view file @ d781a57f)

@@ -97,8 +97,8 @@ size_t GraphHash(const Graph& graph) {
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const IndexedGraph::Node& inode = idx[nid];
     // Use name instead of op address so it is deterministic across runs
-    key = dmlc::HashCombine(key, inode.source->op());
     if (inode.source->is_variable()) continue;
+    key = dmlc::HashCombine(key, inode.source->op()->name);
     hash_temp.clear();
     for (const auto& kv : GetAttrDict(inode.source->attrs)) {
       hash_temp.push_back(dmlc::HashCombine(str_hash(kv.first), kv.second));
...
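The one-line fix above matters because Op* pointer values differ from process to process, so hashing the address made GraphHash non-deterministic across runs, while hashing the op name is stable. For intuition, here is a minimal self-contained sketch of boost-style hash mixing in the spirit of dmlc::HashCombine (the mixing constant is illustrative, not necessarily the library's exact implementation):

#include <cstddef>
#include <functional>
#include <string>

// Fold one value's hash into a running hash (boost-style mixing).
inline std::size_t HashCombineSketch(std::size_t key, const std::string& value) {
  std::size_t h = std::hash<std::string>()(value);
  return key ^ (h + 0x9e3779b9 + (key << 6) + (key >> 2));
}
// Hashing inode.source->op()->name this way yields the same key in
// every process; hashing the raw Op* pointer would not.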
nnvm/src/runtime/graph_executor.cc (view file @ d781a57f)

@@ -299,7 +299,6 @@ inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
   attrs->parsed = std::move(param);
 }
-DMLC_REGISTER_PARAMETER(TVMOpParam);
 NNVM_REGISTER_OP(tvm_op)
 .set_attr_parser(TVMOpParamParser)
...
@@ -328,12 +327,12 @@ tvm::runtime::Module RuntimeCreate(std::string sym_json,
   return tvm::runtime::Module(exec);
 }
-TVM_REGISTER_GLOBAL("nnvm.runtime.create")
+TVM_REGISTER_GLOBAL("nnvm.runtime.createx")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     *rv = RuntimeCreate(args[0], args[1], args[2], args[3]);
   });
-TVM_REGISTER_GLOBAL("nnvm.runtime.remote_create")
+TVM_REGISTER_GLOBAL("nnvm.runtime.remote_createx")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     void* mhandle = args[1];
     *rv = RuntimeCreate(args[0],
...

(The old executor drops its TVMOpParam registration, which moves into graph_runtime.cc, and its globals move to the createx names so the new minimum runtime below can own nnvm.runtime.create and nnvm.runtime.remote_create.)
nnvm/src/runtime/graph_executor.h (view file @ d781a57f)

@@ -17,33 +17,11 @@
 #include <nnvm/pass.h>
 #include <vector>
 #include <string>
+#include "./graph_runtime.h"
 namespace nnvm {
 namespace runtime {
-/*! \brief Magic number for NDArray file */
-constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
-/*! \brief Magic number for NDArray list file */
-constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
 /*! \brief DLPack compatible data types */
 using DLTypeVector = std::vector<DLDataType>;
-/*! \brief operator attributes about tvm op */
-struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
-  std::string func_name;
-  uint32_t num_inputs;
-  uint32_t num_outputs;
-  bool flatten_data;
-  DMLC_DECLARE_PARAMETER(TVMOpParam) {
-    DMLC_DECLARE_FIELD(func_name);
-    DMLC_DECLARE_FIELD(num_inputs).set_default(1);
-    DMLC_DECLARE_FIELD(num_outputs).set_default(1);
-    DMLC_DECLARE_FIELD(flatten_data).set_default(false);
-  }
-};
 /*!
  * \brief TVM Graph Executor.
  * This is a minimum graph executor, embedded in TVM runtime
...
nnvm/src/runtime/graph_runtime.cc (new file, 0 → 100644, view file @ d781a57f)

/*!
 * Copyright (c) 2017 by Contributors
 * \file graph_runtime.cc
 */
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <numeric>
#include "./graph_runtime.h"

namespace nnvm {
namespace runtime {

/*! \brief macro to do C API call */
#define TVM_CCALL(func)           \
  {                               \
    int ret = (func);             \
    CHECK_EQ(ret, 0)              \
        << TVMGetLastError();     \
  }

using ::tvm::runtime::PackedFunc;
using ::tvm::runtime::TVMArgs;
using ::tvm::runtime::TVMRetValue;

/*!
 * \brief Minimum graph structure for deployment.
 * This is a minimum graph executor, embedded in the TVM runtime
 * without any framework dependency.
 *
 * This runtime can be accessed from various languages via the
 * TVM runtime PackedFunc API.
 */
class GraphRuntime : public ::tvm::runtime::ModuleNode {
 public:
  ~GraphRuntime() {
    for (DLTensor* t : storage_pool_) {
      TVM_CCALL(TVMArrayFree(t));
    }
  }
  /*!
   * \brief Get member function to front-end.
   * \param name The name of the function.
   * \param sptr_to_self The pointer to the module node.
   * \return The corresponding member function.
   */
  tvm::runtime::PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final;
  /*!
   * \return The type key of the executor.
   */
  const char* type_key() const final {
    return "GraphRuntime";
  }
  void Run() {
    // execute each stored op closure in order.
    for (size_t i = 0; i < op_execs_.size(); ++i) {
      if (op_execs_[i]) op_execs_[i]();
    }
  }
  /*!
   * \brief Initialize the graph executor with graph and context.
   * \param graph_json The execution graph in serialized JSON format.
   * \param module The module containing the compiled functions.
   * \param ctx The context on which the graph runs.
   */
  void Init(const std::string& graph_json,
            tvm::runtime::Module module,
            TVMContext ctx) {
    std::istringstream is(graph_json);
    dmlc::JSONReader reader(&is);
    this->Load(&reader);
    module_ = module;
    ctx_ = ctx;
    this->SetupStorage();
    this->SetupOpExecs();
  }
  /*!
   * \brief Get the input index given the name of input.
   * \param name The name of the input.
   * \return The index of input.
   */
  int GetInputIndex(const std::string& name) {
    for (size_t i = 0; i < input_nodes_.size(); ++i) {
      uint32_t nid = input_nodes_[i];
      if (nodes_[nid].name == name) {
        return static_cast<int>(i);
      }
    }
    LOG(FATAL) << "cannot find " << name << " among input";
    return -1;
  }
  /*!
   * \brief Set the index-th input to the graph.
   * \param index The input index.
   * \param data_in The input data.
   */
  void SetInput(int index, DLTensor* data_in) {
    CHECK_LT(static_cast<size_t>(index), input_nodes_.size());
    uint32_t eid = this->entry_id(input_nodes_[index], 0);
    TVM_CCALL(TVMArrayCopyFromTo(data_in, &data_entry_[eid], nullptr));
  }
  /*!
   * \brief Copy index-th output to data_out.
   * \param index The output index.
   * \param data_out The output data.
   */
  void GetOutput(int index, DLTensor* data_out) {
    CHECK_LT(static_cast<size_t>(index), outputs_.size());
    uint32_t eid = this->entry_id(outputs_[index]);
    TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
  }
  /*!
   * \brief Load parameters from binary stream.
   * \param strm The input stream.
   */
  void LoadParams(dmlc::Stream* strm);
  /*!
   * \brief Load parameters from parameter blob.
   * \param param_blob A binary blob of parameters.
   */
  void LoadParams(const std::string& param_blob) {
    dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
    this->LoadParams(&strm);
  }
 private:
  // Node entry
  struct NodeEntry {
    uint32_t node_id;
    uint32_t index;
    uint32_t version;
    // JSON Loader
    void Load(dmlc::JSONReader* reader) {
      reader->BeginArray();
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&node_id);
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&index);
      if (reader->NextArrayItem()) {
        reader->Read(&version);
        CHECK(!reader->NextArrayItem()) << "invalid json format";
      } else {
        version = 0;
      }
    }
  };
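  // For reference (illustrative values): a serialized NodeEntry is a two-
  // or three-element JSON array such as [3, 0] or [3, 0, 0], read above as
  // (node_id = 3, index = 0, version = 0); when the third element is
  // absent the loader defaults version to 0.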
  // Node
  struct Node {
    // operator type in string
    std::string op_type;
    // name of the op
    std::string name;
    // parameters
    TVMOpParam param;
    // inputs
    std::vector<NodeEntry> inputs;
    // control deps
    std::vector<uint32_t> control_deps;
    // JSON Loader
    void Load(dmlc::JSONReader* reader) {
      reader->BeginObject();
      std::unordered_map<std::string, std::string> dict;
      int bitmask = 0;
      std::string key;
      while (reader->NextObjectItem(&key)) {
        if (key == "op") {
          reader->Read(&op_type);
          bitmask |= 1;
        } else if (key == "name") {
          reader->Read(&name);
          bitmask |= 2;
        } else if (key == "inputs") {
          reader->Read(&inputs);
          bitmask |= 4;
        } else if (key == "attr") {
          reader->Read(&dict);
          param.Init(dict);
        } else if (key == "control_deps") {
          reader->Read(&control_deps);
        } else {
          LOG(FATAL) << "do not support key " << key;
        }
      }
      CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
    }
  };
  struct GraphAttr {
    size_t storage_num_not_alloctaed{0};
    std::vector<int> storage_id;
    std::vector<std::string> dltype;
    std::vector<std::vector<int64_t> > shape;
    // The graph attribute fields.
    void Load(dmlc::JSONReader* reader) {
      reader->BeginObject();
      int bitmask = 0;
      std::string key, type;
      while (reader->NextObjectItem(&key)) {
        if (key == "dltype") {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          CHECK_EQ(type, "list_str");
          CHECK(reader->NextArrayItem());
          reader->Read(&dltype);
          CHECK(!reader->NextArrayItem());
          bitmask |= 1;
        } else if (key == "storage_id") {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          CHECK_EQ(type, "list_int");
          CHECK(reader->NextArrayItem());
          reader->Read(&storage_id);
          CHECK(!reader->NextArrayItem());
          bitmask |= 2;
        } else if (key == "shape") {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          CHECK_EQ(type, "list_shape");
          CHECK(reader->NextArrayItem());
          reader->Read(&shape);
          CHECK(!reader->NextArrayItem());
          bitmask |= 4;
        } else {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          if (type == "list_int") {
            CHECK(reader->NextArrayItem());
            std::vector<int> temp;
            reader->Read(&temp);
          } else if (type == "size_t") {
            CHECK(reader->NextArrayItem());
            size_t temp;
            reader->Read(&temp);
          } else {
            LOG(FATAL) << "cannot skip graph attr " << key;
          }
          CHECK(!reader->NextArrayItem());
        }
      }
      CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
    }
  };
  // The overall graph loader (nodes, arg_nodes, node_row_ptr, heads, attrs).
  void Load(dmlc::JSONReader* reader) {
    reader->BeginObject();
    int bitmask = 0;
    std::string key;
    while (reader->NextObjectItem(&key)) {
      if (key == "nodes") {
        reader->Read(&nodes_);
        bitmask |= 1;
      } else if (key == "arg_nodes") {
        reader->Read(&input_nodes_);
        bitmask |= 2;
      } else if (key == "node_row_ptr") {
        reader->Read(&node_row_ptr_);
        bitmask |= 4;
      } else if (key == "heads") {
        reader->Read(&outputs_);
        bitmask |= 8;
      } else if (key == "attrs") {
        reader->Read(&attrs_);
        bitmask |= 16;
      }
    }
    CHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format";
  }
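  // To make the format concrete: a minimal hand-written graph JSON that
  // this loader accepts might look like the following (illustrative
  // example only; "myexp" is a made-up compiled function name):
  //
  //   {
  //     "nodes": [
  //       {"op": "null", "name": "x", "inputs": []},
  //       {"op": "tvm_op", "name": "exp",
  //        "attr": {"func_name": "myexp", "num_inputs": "1",
  //                 "num_outputs": "1", "flatten_data": "1"},
  //        "inputs": [[0, 0, 0]]}
  //     ],
  //     "arg_nodes": [0],
  //     "node_row_ptr": [0, 1, 2],
  //     "heads": [[1, 0, 0]],
  //     "attrs": {
  //       "dltype": ["list_str", ["float32", "float32"]],
  //       "storage_id": ["list_int", [0, 1]],
  //       "shape": ["list_shape", [[4], [4]]]
  //     }
  //   }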
  bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor);
  /*! \brief Setup the temporal storage */
  void SetupStorage();
  /*! \brief Setup the executors */
  void SetupOpExecs();
  /*!
   * \brief Create an execution function given input.
   * \param attrs The node attributes.
   * \param args The arguments to the functor, including inputs and outputs.
   * \param num_inputs Number of inputs.
   * \return The created executor.
   */
  std::function<void()> CreateTVMOp(const TVMOpParam& attrs,
                                    const std::vector<DLTensor>& args,
                                    size_t num_inputs);
  // Get node entry index.
  uint32_t entry_id(uint32_t nid, uint32_t index) const {
    return node_row_ptr_[nid] + index;
  }
  // Get node entry index.
  uint32_t entry_id(const NodeEntry& e) const {
    return entry_id(e.node_id, e.index);
  }
  // Number of node entries.
  uint32_t num_node_entries() const {
    return node_row_ptr_.back();
  }
  // Number of nodes.
  uint32_t num_nodes() const {
    return static_cast<uint32_t>(nodes_.size());
  }
  // The graph nodes.
  std::vector<Node> nodes_;
  // The argument nodes.
  std::vector<uint32_t> input_nodes_;
  // Used for quick entry indexing.
  std::vector<uint32_t> node_row_ptr_;
  // Output entries.
  std::vector<NodeEntry> outputs_;
  // Additional graph attributes.
  GraphAttr attrs_;
  /*! \brief The code module */
  tvm::runtime::Module module_;
  /*! \brief execution context */
  TVMContext ctx_;
  /*! \brief common storage pool */
  std::vector<DLTensor*> storage_pool_;
  /*! \brief data entry of each node */
  std::vector<DLTensor> data_entry_;
  /*! \brief operator on each node */
  std::vector<std::function<void()> > op_execs_;
};
DMLC_REGISTER_PARAMETER(TVMOpParam);

bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
  uint64_t header, reserved;
  CHECK(strm->Read(&header, sizeof(header)))
      << "Invalid DLTensor file format";
  CHECK(strm->Read(&reserved, sizeof(reserved)))
      << "Invalid DLTensor file format";
  CHECK(header == kTVMNDArrayMagic)
      << "Invalid DLTensor file format";
  CHECK(strm->Read(&tensor->ctx, sizeof(tensor->ctx)))
      << "Invalid DLTensor file format";
  CHECK(strm->Read(&tensor->ndim, sizeof(tensor->ndim)))
      << "Invalid DLTensor file format";
  CHECK(strm->Read(&tensor->dtype, sizeof(tensor->dtype)))
      << "Invalid DLTensor file format";
  int ndim = tensor->ndim;
  CHECK(strm->Read(tensor->shape, sizeof(int64_t) * ndim))
      << "Invalid DLTensor file format";
  int64_t size = 1;
  int type_size = tensor->dtype.bits / 8;
  for (int i = 0; i < ndim; ++i) {
    size *= tensor->shape[i];
  }
  int64_t data_byte_size;
  CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
      << "Invalid DLTensor file format";
  CHECK(data_byte_size == type_size * size)
      << "Invalid DLTensor file format";
  CHECK(strm->Read(tensor->data, type_size * size))
      << "Invalid DLTensor file format";
  return true;
}
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
  uint64_t header, reserved;
  CHECK(strm->Read(&header))
      << "Invalid parameters file format";
  CHECK(header == kTVMNDArrayListMagic)
      << "Invalid parameters file format";
  CHECK(strm->Read(&reserved))
      << "Invalid parameters file format";
  std::vector<std::string> names;
  CHECK(strm->Read(&names))
      << "Invalid parameters file format";
  uint64_t sz;
  strm->Read(&sz, sizeof(sz));
  size_t size = static_cast<size_t>(sz);
  CHECK(size == names.size())
      << "Invalid parameters file format";
  for (size_t i = 0; i < size; ++i) {
    uint32_t in_idx = GetInputIndex(names[i]);
    CHECK(LoadDLTensor(
        strm, &data_entry_[this->entry_id(input_nodes_[in_idx], 0)]))
        << "Invalid parameters file format";
  }
}
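// For reference, the parameter blob consumed above is laid out as:
//   uint64 kTVMNDArrayListMagic, uint64 reserved,
//   vector<string> names (dmlc::Stream serialization),
//   uint64 sz (tensor count, equal to names.size()),
// followed, for each named tensor, by the per-tensor record that
// LoadDLTensor reads:
//   uint64 kTVMNDArrayMagic, uint64 reserved, DLContext ctx, int ndim,
//   DLDataType dtype, int64 shape[ndim], int64 data_byte_size,
//   then data_byte_size bytes of raw tensor data.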
void GraphRuntime::SetupStorage() {
  // Grab saved optimization plan from graph.
  std::vector<TVMType> vtype;
  for (const std::string& s_type : attrs_.dltype) {
    vtype.push_back(tvm::runtime::String2TVMType(s_type));
  }
  data_entry_.resize(num_node_entries());
  // Find the maximum space size.
  int max_id = 0;
  for (size_t i = 0; i < attrs_.shape.size(); ++i) {
    max_id = std::max(attrs_.storage_id[i] + 1, max_id);
  }
  for (uint32_t nid : input_nodes_) {
    attrs_.storage_id[this->entry_id(nid, 0)] = max_id++;
  }
  // size of each storage pool entry
  std::vector<size_t> pool_entry_bytes;
  // Find the maximum space size.
  for (size_t i = 0; i < attrs_.shape.size(); ++i) {
    int storage_id = attrs_.storage_id[i];
    size_t size = 1;
    for (int64_t sz : attrs_.shape[i]) {
      size *= static_cast<size_t>(sz);
    }
    CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
    DLDataType t = vtype[i];
    size_t bits = t.bits * t.lanes;
    CHECK_EQ(bits % 8U, 0U);
    size_t bytes = (bits / 8U) * size;
    size_t sid = static_cast<size_t>(storage_id);
    if (sid >= pool_entry_bytes.size()) {
      pool_entry_bytes.resize(sid + 1, 0);
    }
    pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes);
  }
  // Allocate the space.
  for (size_t i = 0; i < pool_entry_bytes.size(); ++i) {
    int64_t shape[] = {static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
    DLTensor* tensor;
    TVM_CCALL(TVMArrayAlloc(shape, 1, kFloat, 32, 1,
                            ctx_.device_type, ctx_.device_id, &tensor));
    storage_pool_.push_back(tensor);
  }
  // Assign the pooled entries.
  for (size_t i = 0; i < data_entry_.size(); ++i) {
    int storage_id = attrs_.storage_id[i];
    data_entry_[i] = *storage_pool_[storage_id];
    data_entry_[i].shape = const_cast<int64_t*>(attrs_.shape[i].data());
    data_entry_[i].ndim = static_cast<int>(attrs_.shape[i].size());
    data_entry_[i].dtype = vtype[i];
  }
}
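// Worked example of the pooling arithmetic above (illustrative numbers):
// if entries 0 and 2 share storage_id 0 with float32 shapes [2, 2] and
// [4], both need 4 * 4 = 16 bytes, so pool entry 0 holds
// max(16, 16) = 16 bytes and is allocated as a 1-D float32 array of
// (16 + 3) / 4 = 4 elements; each data entry then aliases that buffer
// while keeping its own shape, ndim, and dtype metadata.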
/*! \brief Setup the executors */
void GraphRuntime::SetupOpExecs() {
  op_execs_.resize(this->num_nodes());
  // setup the array and requirements.
  for (uint32_t nid = 0; nid < this->num_nodes(); ++nid) {
    const auto& inode = nodes_[nid];
    if (inode.op_type == "null") continue;
    std::vector<DLTensor> args;
    for (const auto& e : inode.inputs) {
      args.push_back(data_entry_[this->entry_id(e)]);
    }
    for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
      uint32_t eid = this->entry_id(nid, index);
      args.push_back(data_entry_[eid]);
    }
    CHECK_EQ(inode.op_type, "tvm_op")
        << "transform the graph to tvm op";
    op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size());
  }
}
std::function<void()> GraphRuntime::CreateTVMOp(
    const TVMOpParam& param,
    const std::vector<DLTensor>& args,
    size_t num_inputs) {
  struct OpArgs {
    std::vector<DLTensor> args;
    std::vector<TVMValue> arg_values;
    std::vector<int> arg_tcodes;
    std::vector<int64_t> shape_data;
  };
  std::shared_ptr<OpArgs> arg_ptr = std::make_shared<OpArgs>();
  // setup address.
  arg_ptr->args = std::move(args);
  if (param.flatten_data) {
    arg_ptr->shape_data.resize(arg_ptr->args.size());
  }
  for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
    TVMValue v;
    DLTensor* t = &(arg_ptr->args[i]);
    v.v_handle = t;
    arg_ptr->arg_values.push_back(v);
    arg_ptr->arg_tcodes.push_back(kArrayHandle);
    if (param.flatten_data) {
      arg_ptr->shape_data[i] = std::accumulate(
          t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
      t->ndim = 1;
      t->shape = &(arg_ptr->shape_data[i]);
    }
  }
  // get compiled function from module.
  tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, false);
  CHECK(pf != nullptr) << "no such function in module: " << param.func_name;
  auto fexec = [arg_ptr, pf]() {
    TVMRetValue rv;
    TVMArgs targs(arg_ptr->arg_values.data(),
                  arg_ptr->arg_tcodes.data(),
                  static_cast<int>(arg_ptr->arg_values.size()));
    pf.CallPacked(targs, &rv);
  };
  return fexec;
}
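// Note on flatten_data, as handled above (illustrative shape): for an
// argument of shape [2, 3, 4], the loop rewrites the DLTensor in place
// to a 1-D tensor of shape [24] (2 * 3 * 4 via std::accumulate), so one
// elementwise kernel compiled for a flat buffer can serve any shape.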
PackedFunc GraphRuntime::GetFunction(
    const std::string& name,
    const std::shared_ptr<ModuleNode>& sptr_to_self) {
  // return member functions during query.
  if (name == "set_input") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        if (args[0].type_code() == kStr) {
          this->SetInput(this->GetInputIndex(args[0]), args[1]);
        } else {
          this->SetInput(args[0], args[1]);
        }
      });
  } else if (name == "get_output") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        this->GetOutput(args[0], args[1]);
      });
  } else if (name == "run") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        this->Run();
      });
  } else if (name == "load_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        this->LoadParams(args[0].operator std::string());
      });
  } else {
    return PackedFunc();
  }
}
tvm::runtime::Module GraphRuntimeCreate(std::string sym_json,
                                        tvm::runtime::Module m,
                                        int device_type,
                                        int device_id) {
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>();
  exec->Init(sym_json, m, ctx);
  return tvm::runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("nnvm.runtime.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = GraphRuntimeCreate(args[0], args[1], args[2], args[3]);
  });

TVM_REGISTER_GLOBAL("nnvm.runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    void* mhandle = args[1];
    *rv = GraphRuntimeCreate(args[0],
                             *static_cast<tvm::runtime::Module*>(mhandle),
                             args[2], args[3]);
  });
}  // namespace runtime
}  // namespace nnvm
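To make the deployment flow concrete, here is a minimal sketch of a C++ client driving this module end to end through the PackedFunc API. It relies only on the "nnvm.runtime.create" global and the module functions registered above; the library path "deploy.so", the device constants, and the idea of passing preloaded graph/param strings are illustrative assumptions, not part of this commit:

#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <string>

void RunGraphOnce(const std::string& graph_json,
                  const std::string& params_blob,
                  DLTensor* input, DLTensor* output) {
  // Compiled operator library; the file name is an assumption.
  tvm::runtime::Module lib =
      tvm::runtime::Module::LoadFromFile("deploy.so");
  // Global registered by graph_runtime.cc above.
  const tvm::runtime::PackedFunc* create =
      tvm::runtime::Registry::Get("nnvm.runtime.create");
  // device_type = 1 (CPU) and device_id = 0 are illustrative choices.
  tvm::runtime::Module rt = (*create)(graph_json, lib, 1, 0);
  rt.GetFunction("load_params")(params_blob);
  rt.GetFunction("set_input")(0, input);  // an input name also works
  rt.GetFunction("run")();
  rt.GetFunction("get_output")(0, output);
}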
nnvm/src/runtime/graph_runtime.h (new file, 0 → 100644, view file @ d781a57f)
/*!
 * Copyright (c) 2017 by Contributors
 *
 * Runtime module for graph deployment.
 *
 * \file graph_runtime.h
 */
#ifndef NNVM_RUNTIME_GRAPH_RUNTIME_H_
#define NNVM_RUNTIME_GRAPH_RUNTIME_H_

#include <dmlc/parameter.h>
#include <string>

namespace nnvm {
namespace runtime {

/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;

/*! \brief operator attributes about tvm op */
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
  std::string func_name;
  uint32_t num_inputs;
  uint32_t num_outputs;
  uint32_t flatten_data;

  DMLC_DECLARE_PARAMETER(TVMOpParam) {
    DMLC_DECLARE_FIELD(func_name);
    DMLC_DECLARE_FIELD(num_inputs).set_default(1);
    DMLC_DECLARE_FIELD(num_outputs).set_default(1);
    DMLC_DECLARE_FIELD(flatten_data).set_default(0);
  }
};

}  // namespace runtime
}  // namespace nnvm
#endif  // NNVM_RUNTIME_GRAPH_RUNTIME_H_
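Since a node's "attr" dictionary arrives from the graph JSON as string-to-string pairs, dmlc::Parameter does the parsing and default-filling. A small sketch of populating TVMOpParam by hand (all values illustrative; DMLC_REGISTER_PARAMETER(TVMOpParam) in graph_runtime.cc supplies the registration that Init needs at link time):

#include <map>
#include <string>
#include "graph_runtime.h"

nnvm::runtime::TVMOpParam MakeExampleParam() {
  std::map<std::string, std::string> dict = {
    {"func_name", "myadd"},   // required: no default declared
    {"num_inputs", "2"},
    {"flatten_data", "1"}
  };
  nnvm::runtime::TVMOpParam param;
  param.Init(dict);  // num_outputs falls back to its default of 1
  return param;
}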