Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
72fa4c1d
Commit
72fa4c1d
authored
Jul 18, 2018
by
Tianqi Chen
Committed by
GitHub
Jul 18, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NODE][REFLECTION] Support NDArray as field (#1452)
parent
6fbda22d
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
582 additions
and
170 deletions
+582
-170
HalideIR
+1
-1
include/tvm/runtime/ndarray.h
+138
-5
include/tvm/runtime/serializer.h
+1
-0
nnvm/python/nnvm/compiler/param_dict.py
+2
-10
nnvm/src/compiler/graph_runtime.cc
+12
-97
nnvm/src/compiler/graph_runtime.h
+24
-2
nnvm/tests/python/compiler/test_param_dict.py
+20
-0
python/tvm/_ffi/_ctypes/function.py
+1
-0
python/tvm/_ffi/_cython/node.pxi
+2
-0
python/tvm/_ffi/ndarray.py
+26
-0
src/api/dsl_api.cc
+12
-3
src/common/base64.h
+284
-0
src/lang/reflection.cc
+55
-3
src/runtime/graph/graph_runtime.cc
+4
-47
src/runtime/graph/graph_runtime.h
+0
-2
No files found.
HalideIR
@
a5a80bdc
Subproject commit
9204453ae8de77e7dfc32c4d80f58dd788ad75ff
Subproject commit
a5a80bdc8232c9dbfe508bb5c46e8f58cdf7ec20
include/tvm/runtime/ndarray.h
View file @
72fa4c1d
...
...
@@ -10,6 +10,7 @@
#include <vector>
#include <utility>
#include "./c_runtime_api.h"
#include "./serializer.h"
namespace
tvm
{
namespace
runtime
{
...
...
@@ -103,8 +104,25 @@ class NDArray {
* \note The copy may happen asynchrously if it involves a GPU context.
* TVMSynchronize is necessary.
*/
inline
void
CopyTo
(
DLTensor
*
other
);
inline
void
CopyTo
(
const
NDArray
&
other
);
inline
void
CopyTo
(
DLTensor
*
other
)
const
;
inline
void
CopyTo
(
const
NDArray
&
other
)
const
;
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The array under another context.
*/
inline
NDArray
CopyTo
(
const
DLContext
&
ctx
)
const
;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
* \return Whether load is successful
*/
inline
bool
Load
(
dmlc
::
Stream
*
stream
);
/*!
* \brief Save NDArray to stream
* \param stream The output data stream
*/
inline
void
Save
(
dmlc
::
Stream
*
stream
)
const
;
/*!
* \brief Create a NDArray that shares the data memory with the current one.
* \param shape The shape of the new array.
...
...
@@ -162,6 +180,13 @@ class NDArray {
};
/*!
* \brief Save a DLTensor to stream
* \param strm The outpu stream
* \param tensor The tensor to be saved.
*/
inline
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
const
DLTensor
*
tensor
);
/*!
* \brief Reference counted Container object used to back NDArray.
*
* This object is DLTensor compatible:
...
...
@@ -260,17 +285,26 @@ inline void NDArray::CopyFrom(const NDArray& other) {
CopyFromTo
(
&
(
other
.
data_
->
dl_tensor
),
&
(
data_
->
dl_tensor
));
}
inline
void
NDArray
::
CopyTo
(
DLTensor
*
other
)
{
inline
void
NDArray
::
CopyTo
(
DLTensor
*
other
)
const
{
CHECK
(
data_
!=
nullptr
);
CopyFromTo
(
&
(
data_
->
dl_tensor
),
other
);
}
inline
void
NDArray
::
CopyTo
(
const
NDArray
&
other
)
{
inline
void
NDArray
::
CopyTo
(
const
NDArray
&
other
)
const
{
CHECK
(
data_
!=
nullptr
);
CHECK
(
other
.
data_
!=
nullptr
);
CopyFromTo
(
&
(
data_
->
dl_tensor
),
&
(
other
.
data_
->
dl_tensor
));
}
inline
NDArray
NDArray
::
CopyTo
(
const
DLContext
&
ctx
)
const
{
CHECK
(
data_
!=
nullptr
);
const
DLTensor
*
dptr
=
operator
->
();
NDArray
ret
=
Empty
(
std
::
vector
<
int64_t
>
(
dptr
->
shape
,
dptr
->
shape
+
dptr
->
ndim
),
dptr
->
dtype
,
ctx
);
this
->
CopyTo
(
ret
);
return
ret
;
}
inline
int
NDArray
::
use_count
()
const
{
if
(
data_
==
nullptr
)
return
0
;
return
data_
->
ref_counter_
.
load
(
std
::
memory_order_relaxed
);
...
...
@@ -280,7 +314,106 @@ inline const DLTensor* NDArray::operator->() const {
return
&
(
data_
->
dl_tensor
);
}
/*! \brief Magic number for NDArray file */
constexpr
uint64_t
kTVMNDArrayMagic
=
0xDD5E40F096B4A13F
;
inline
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
strm
->
Write
(
header
);
strm
->
Write
(
reserved
);
// Always save data as CPU context
//
// Parameters that get serialized should be in CPU by default.
// So even the array's context is GPU, it will be stored as CPU array.
// This is used to prevent case when another user loads the parameters
// back on machine that do not have GPU or related context.
//
// We can always do array.CopyTo(target_ctx) to get a corresponding
// array in the target context.
DLContext
cpu_ctx
;
cpu_ctx
.
device_type
=
kDLCPU
;
cpu_ctx
.
device_id
=
0
;
strm
->
Write
(
cpu_ctx
);
strm
->
Write
(
tensor
->
ndim
);
strm
->
Write
(
tensor
->
dtype
);
int
ndim
=
tensor
->
ndim
;
strm
->
WriteArray
(
tensor
->
shape
,
ndim
);
int
type_bytes
=
tensor
->
dtype
.
bits
/
8
;
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
num_elems
*=
tensor
->
shape
[
i
];
}
int64_t
data_byte_size
=
type_bytes
*
num_elems
;
strm
->
Write
(
data_byte_size
);
if
(
DMLC_IO_NO_ENDIAN_SWAP
&&
tensor
->
ctx
.
device_type
==
kDLCPU
&&
tensor
->
strides
==
nullptr
&&
tensor
->
byte_offset
==
0
)
{
// quick path
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
}
else
{
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
);
CHECK_EQ
(
TVMArrayCopyToBytes
(
tensor
,
dmlc
::
BeginPtr
(
bytes
),
data_byte_size
),
0
)
<<
TVMGetLastError
();
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
dmlc
::
BeginPtr
(
bytes
),
type_bytes
,
num_elems
);
}
strm
->
Write
(
dmlc
::
BeginPtr
(
bytes
),
data_byte_size
);
}
return
true
;
}
inline
void
NDArray
::
Save
(
dmlc
::
Stream
*
strm
)
const
{
SaveDLTensor
(
strm
,
const_cast
<
DLTensor
*>
(
operator
->
()));
}
inline
bool
NDArray
::
Load
(
dmlc
::
Stream
*
strm
)
{
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
DLContext
ctx
;
int
ndim
;
DLDataType
dtype
;
CHECK
(
strm
->
Read
(
&
ctx
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
ndim
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
dtype
))
<<
"Invalid DLTensor file format"
;
CHECK_EQ
(
ctx
.
device_type
,
kDLCPU
)
<<
"Invalid DLTensor context: can only save as CPU tensor"
;
std
::
vector
<
int64_t
>
shape
(
ndim
);
if
(
ndim
!=
0
)
{
CHECK
(
strm
->
ReadArray
(
&
shape
[
0
],
ndim
))
<<
"Invalid DLTensor file format"
;
}
NDArray
ret
=
NDArray
::
Empty
(
shape
,
dtype
,
ctx
);
int64_t
num_elems
=
1
;
int
elem_bytes
=
(
ret
->
dtype
.
bits
+
7
)
/
8
;
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
num_elems
*=
ret
->
shape
[
i
];
}
int64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
num_elems
*
elem_bytes
)
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
ret
->
data
,
data_byte_size
))
<<
"Invalid DLTensor file format"
;
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
ret
->
data
,
elem_bytes
,
num_elems
);
}
*
this
=
ret
;
return
true
;
}
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_NDARRAY_H_
include/tvm/runtime/serializer.h
View file @
72fa4c1d
...
...
@@ -10,6 +10,7 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "./c_runtime_api.h"
#include "./ndarray.h"
namespace
dmlc
{
namespace
serializer
{
...
...
nnvm/python/nnvm/compiler/param_dict.py
View file @
72fa4c1d
# pylint: disable=invalid-name
"""Helper utility to save parameter dict"""
import
ctypes
import
tvm
from
tvm._ffi.runtime_ctypes
import
TVMArrayHandle
_save_param_dict
=
tvm
.
get_global_func
(
"nnvm.compiler._save_param_dict"
)
_load_param_dict
=
tvm
.
get_global_func
(
"nnvm.compiler._load_param_dict"
)
...
...
@@ -59,11 +57,5 @@ def load_param_dict(param_bytes):
"""
if
isinstance
(
param_bytes
,
(
bytes
,
str
)):
param_bytes
=
bytearray
(
param_bytes
)
load_mod
=
_load_param_dict
(
param_bytes
)
size
=
load_mod
(
0
)
param_dict
=
{}
for
i
in
range
(
size
):
key
=
load_mod
(
1
,
i
)
dltensor_handle
=
ctypes
.
cast
(
load_mod
(
2
,
i
),
TVMArrayHandle
)
param_dict
[
key
]
=
tvm
.
nd
.
NDArray
(
dltensor_handle
,
False
)
return
param_dict
load_arr
=
_load_param_dict
(
param_bytes
)
return
{
v
.
name
:
v
.
array
for
v
in
load_arr
}
nnvm/src/compiler/graph_runtime.cc
View file @
72fa4c1d
...
...
@@ -4,10 +4,6 @@
* \brief Interface code with TVM graph runtime.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include "./graph_runtime.h"
namespace
nnvm
{
...
...
@@ -37,81 +33,6 @@ NNVM_REGISTER_OP(tvm_op)
return
param
.
num_outputs
;
});
bool
SaveDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
tensor
)
{
uint64_t
header
=
kTVMNDArrayMagic
,
reserved
=
0
;
strm
->
Write
(
header
);
strm
->
Write
(
reserved
);
strm
->
Write
(
tensor
->
ctx
);
strm
->
Write
(
tensor
->
ndim
);
strm
->
Write
(
tensor
->
dtype
);
int
ndim
=
tensor
->
ndim
;
strm
->
WriteArray
(
tensor
->
shape
,
ndim
);
int
type_bytes
=
tensor
->
dtype
.
bits
/
8
;
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
num_elems
*=
tensor
->
shape
[
i
];
}
int64_t
data_byte_size
=
type_bytes
*
num_elems
;
strm
->
Write
(
data_byte_size
);
// handle endianness of data correctly.
if
(
DMLC_IO_NO_ENDIAN_SWAP
)
{
strm
->
Write
(
tensor
->
data
,
data_byte_size
);
}
else
{
uint8_t
*
dptr
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
data
);
std
::
vector
<
uint8_t
>
bytes
(
dptr
,
dptr
+
data_byte_size
);
dmlc
::
ByteSwap
(
dmlc
::
BeginPtr
(
bytes
),
type_bytes
,
num_elems
);
strm
->
Write
(
dmlc
::
BeginPtr
(
bytes
),
data_byte_size
);
}
return
true
;
}
DLTensor
*
LoadDLTensor
(
dmlc
::
Stream
*
strm
)
{
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
ReadArray
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
}
DLTensor
*
ret
;
CHECK_EQ
(
TVMArrayAlloc
(
shape
.
data
(),
tensor
.
ndim
,
tensor
.
dtype
.
code
,
tensor
.
dtype
.
bits
,
tensor
.
dtype
.
lanes
,
static_cast
<
int
>
(
tensor
.
ctx
.
device_type
),
tensor
.
ctx
.
device_id
,
&
ret
),
0
)
<<
TVMGetLastError
();
int64_t
num_elems
=
1
;
int
elem_bytes
=
(
ret
->
dtype
.
bits
+
7
)
/
8
;
for
(
int
i
=
0
;
i
<
ret
->
ndim
;
++
i
)
{
num_elems
*=
ret
->
shape
[
i
];
}
int64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
CHECK
(
data_byte_size
==
num_elems
*
elem_bytes
)
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
ret
->
data
,
data_byte_size
))
<<
"Invalid DLTensor file format"
;
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
ret
->
data
,
elem_bytes
,
num_elems
);
}
return
ret
;
}
TVM_REGISTER_GLOBAL
(
"nnvm.compiler._save_param_dict"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
...
...
@@ -136,7 +57,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
uint64_t
sz
=
static_cast
<
uint64_t
>
(
arrays
.
size
());
fo
->
Write
(
sz
);
for
(
size_t
i
=
0
;
i
<
sz
;
++
i
)
{
SaveDLTensor
(
fo
,
arrays
[
i
]);
tvm
::
runtime
::
SaveDLTensor
(
fo
,
arrays
[
i
]);
}
}
TVMByteArray
arr
;
...
...
@@ -149,11 +70,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
TVM_REGISTER_GLOBAL
(
"nnvm.compiler._load_param_dict"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
string
bytes
=
args
[
0
];
std
::
vector
<
DLTensor
*>
data
;
std
::
vector
<
std
::
string
>
names
;
dmlc
::
MemoryStringStream
memstrm
(
&
bytes
);
dmlc
::
Stream
*
strm
=
&
memstrm
;
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid parameters file format"
;
...
...
@@ -168,23 +87,19 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
size_t
size
=
static_cast
<
size_t
>
(
sz
);
CHECK
(
size
==
names
.
size
())
<<
"Invalid parameters file format"
;
tvm
::
Array
<
NDArrayWrapper
>
ret
;
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
data
.
push_back
(
LoadDLTensor
(
strm
));
}
auto
packed
=
[
data
,
names
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
int
code
=
args
[
0
];
if
(
code
==
0
)
{
*
rv
=
static_cast
<
int64_t
>
(
data
.
size
());
}
else
if
(
code
==
1
)
{
int
index
=
args
[
1
];
*
rv
=
names
[
index
];
}
else
{
CHECK_EQ
(
code
,
2
);
int
index
=
args
[
1
];
*
rv
=
static_cast
<
void
*>
(
data
[
index
]);
tvm
::
runtime
::
NDArray
temp
;
temp
.
Load
(
strm
);
std
::
shared_ptr
<
NDArrayWrapperNode
>
n
=
std
::
make_shared
<
NDArrayWrapperNode
>
();
n
->
name
=
std
::
move
(
names
[
i
]);
n
->
array
=
temp
;
ret
.
push_back
(
NDArrayWrapper
(
n
));
}
};
*
rv
=
PackedFunc
(
packed
);
*
rv
=
ret
;
});
TVM_REGISTER_NODE_TYPE
(
NDArrayWrapperNode
);
}
// namespace compiler
}
// namespace nnvm
nnvm/src/compiler/graph_runtime.h
View file @
72fa4c1d
...
...
@@ -7,14 +7,16 @@
#define NNVM_COMPILER_GRAPH_RUNTIME_H_
#include <nnvm/graph.h>
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <vector>
#include <string>
namespace
nnvm
{
namespace
compiler
{
/*! \brief Magic number for NDArray file */
constexpr
uint64_t
kTVMNDArrayMagic
=
0xDD5E40F096B4A13F
;
/*! \brief Magic number for NDArray list file */
constexpr
uint64_t
kTVMNDArrayListMagic
=
0xF7E58D4F05049CB7
;
...
...
@@ -32,6 +34,26 @@ struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
}
};
/*!
* \brief wrapper node container for exchange.
*/
struct
NDArrayWrapperNode
:
public
::
tvm
::
Node
{
std
::
string
name
;
tvm
::
runtime
::
NDArray
array
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"array"
,
&
array
);
}
static
constexpr
const
char
*
_type_key
=
"NDArrayWrapper"
;
TVM_DECLARE_NODE_TYPE_INFO
(
NDArrayWrapperNode
,
Node
);
};
TVM_DEFINE_NODE_REF
(
NDArrayWrapper
,
NDArrayWrapperNode
);
}
// namespace compiler
}
// namespace nnvm
#endif // NNVM_COMPILER_GRAPH_RUNTIME_H_
nnvm/tests/python/compiler/test_param_dict.py
View file @
72fa4c1d
...
...
@@ -2,6 +2,9 @@ import os
import
numpy
as
np
import
nnvm.compiler
import
tvm
import
json
import
base64
from
tvm._ffi.base
import
py_str
from
tvm
import
rpc
from
tvm.contrib
import
util
,
graph_runtime
...
...
@@ -20,6 +23,22 @@ def test_save_load():
np
.
testing
.
assert_equal
(
param2
[
"y"
]
.
asnumpy
(),
y
)
def
test_ndarray_reflection
():
x
=
np
.
random
.
uniform
(
size
=
(
10
,
2
))
.
astype
(
"float32"
)
xx
=
tvm
.
nd
.
array
(
x
)
xnode
=
tvm
.
make
.
node
(
"NDArrayWrapper"
,
name
=
"xx"
,
array
=
xx
)
xnode2
=
tvm
.
make
.
node
(
"NDArrayWrapper"
,
name
=
"x2"
,
array
=
xx
)
assert
xnode
.
array
.
same_as
(
xx
)
json_str
=
tvm
.
save_json
([
xnode
,
xnode2
])
json_dict
=
json
.
loads
(
json_str
)
b64_str
=
json_dict
[
"b64ndarrays"
][
0
]
decoded
=
py_str
(
base64
.
b64encode
(
base64
.
b64decode
(
b64_str
)))
assert
b64_str
==
decoded
xlist
=
tvm
.
load_json
(
json_str
)
np
.
testing
.
assert_equal
(
xlist
[
0
]
.
array
.
asnumpy
(),
xx
.
asnumpy
())
assert
xlist
[
1
]
.
array
==
xlist
[
0
]
.
array
def
test_bigendian_rpc_param
():
"""Test big endian rpc when there is a PowerPC RPC server available"""
host
=
os
.
environ
.
get
(
"TVM_POWERPC_TEST_HOST"
,
None
)
...
...
@@ -60,5 +79,6 @@ def test_bigendian_rpc_param():
if
__name__
==
"__main__"
:
test_ndarray_reflection
()
test_save_load
()
test_bigendian_rpc_param
()
python/tvm/_ffi/_ctypes/function.py
View file @
72fa4c1d
...
...
@@ -204,6 +204,7 @@ def _handle_return_func(x):
# setup return handle for function type
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_wrap_arg_func
(
_handle_return_func
,
TypeCode
.
FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_wrap_arg_func
(
...
...
python/tvm/_ffi/_cython/node.pxi
View file @
72fa4c1d
...
...
@@ -23,6 +23,8 @@ cdef inline object make_ret_node(void* chandle):
obj = cls(None)
else:
obj = NodeBase(None)
else:
obj = NodeBase(None)
(<NodeBase>obj).chandle = chandle
return obj
...
...
python/tvm/_ffi/ndarray.py
View file @
72fa4c1d
...
...
@@ -134,6 +134,32 @@ class NDArrayBase(_NDArrayBase):
"""context of this array"""
return
self
.
ctx
def
__hash__
(
self
):
return
ctypes
.
cast
(
self
.
handle
,
ctypes
.
c_void_p
)
.
value
def
__eq__
(
self
,
other
):
return
self
.
same_as
(
other
)
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
def
same_as
(
self
,
other
):
"""Check object identity equality
Parameters
----------
other : object
The other object to compare to
Returns
-------
same : bool
Whether other is same as self.
"""
if
not
isinstance
(
other
,
NDArrayBase
):
return
False
return
self
.
__hash__
()
==
other
.
__hash__
()
def
__setitem__
(
self
,
in_slice
,
value
):
"""Set ndarray value"""
if
(
not
isinstance
(
in_slice
,
slice
)
or
...
...
src/api/dsl_api.cc
View file @
72fa4c1d
...
...
@@ -32,7 +32,7 @@ using TVMAPINode = std::shared_ptr<Node>;
struct
APIAttrGetter
:
public
AttrVisitor
{
std
::
string
skey
;
TVMRetValue
*
ret
;
bool
found_
node_ref
{
false
};
bool
found_
ref_object
{
false
};
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
if
(
skey
==
key
)
*
ret
=
value
[
0
];
...
...
@@ -63,7 +63,13 @@ struct APIAttrGetter : public AttrVisitor {
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
if
(
skey
==
key
)
{
*
ret
=
value
[
0
];
found_node_ref
=
true
;
found_ref_object
=
true
;
}
}
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
if
(
skey
==
key
)
{
*
ret
=
value
[
0
];
found_ref_object
=
true
;
}
}
};
...
...
@@ -98,6 +104,9 @@ struct APIAttrDir : public AttrVisitor {
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
names
->
push_back
(
key
);
}
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
names
->
push_back
(
key
);
}
};
class
DSLAPIImpl
:
public
DSLAPI
{
...
...
@@ -130,7 +139,7 @@ class DSLAPIImpl : public DSLAPI {
*
ret_success
=
1
;
}
else
{
(
*
tnode
)
->
VisitAttrs
(
&
getter
);
*
ret_success
=
getter
.
found_
node_ref
||
rv
.
type_code
()
!=
kNull
;
*
ret_success
=
getter
.
found_
ref_object
||
rv
.
type_code
()
!=
kNull
;
if
(
rv
.
type_code
()
==
kStr
||
rv
.
type_code
()
==
kTVMType
)
{
TVMAPIThreadLocalEntry
*
e
=
TVMAPIThreadLocalStore
::
Get
();
...
...
src/common/base64.h
0 → 100644
View file @
72fa4c1d
/*!
* Copyright 2018 by Contributors
*
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
*/
#ifndef TVM_COMMON_BASE64_H_
#define TVM_COMMON_BASE64_H_
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <cctype>
#include <cstdio>
#include <string>
namespace
tvm
{
namespace
common
{
/*! \brief namespace of base64 decoding and encoding table */
namespace
base64
{
// decoding table
const
char
DecodeTable
[]
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
62
,
// '+'
0
,
0
,
0
,
63
,
// '/'
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
,
60
,
61
,
// '0'-'9'
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
// 'A'-'Z'
0
,
0
,
0
,
0
,
0
,
0
,
26
,
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
,
50
,
51
,
// 'a'-'z'
};
// encoding table
static
const
char
EncodeTable
[]
=
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
;
}
// namespace base64
/*!
* \brief Buffer reader from stream to avoid
* virtual call overhead on each read.
*/
class
StreamBufferReader
{
public
:
explicit
StreamBufferReader
(
size_t
buffer_size
)
{
buffer_
.
resize
(
buffer_size
);
}
/*!
* \brief set input stream
* \param stream The stream to be set
*/
void
set_stream
(
dmlc
::
Stream
*
stream
)
{
stream_
=
stream
;
read_len_
=
read_ptr_
=
1
;
}
/*!
* \return allows quick read using get char
*/
char
GetChar
()
{
while
(
true
)
{
if
(
read_ptr_
<
read_len_
)
{
return
buffer_
[
read_ptr_
++
];
}
else
{
read_len_
=
stream_
->
Read
(
&
buffer_
[
0
],
buffer_
.
length
());
if
(
read_len_
==
0
)
return
EOF
;
read_ptr_
=
0
;
}
}
}
/*! \return whether we are reaching the end of file */
bool
AtEnd
()
const
{
return
read_len_
==
0
;
}
private
:
/*! \brief the underlying stream */
dmlc
::
Stream
*
stream_
{
nullptr
};
/*! \brief buffer to hold data */
std
::
string
buffer_
;
/*! \brief length of valid data in buffer */
size_t
read_len_
{
1
};
/*! \brief pointer in the buffer */
size_t
read_ptr_
{
1
};
};
/*!
* \brief Input stream from base64 encoding
*/
class
Base64InStream
:
public
dmlc
::
Stream
{
public
:
explicit
Base64InStream
(
dmlc
::
Stream
*
fs
)
:
reader_
(
256
)
{
reader_
.
set_stream
(
fs
);
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* \note call this function before actually start read
*/
void
InitPosition
(
void
)
{
// get a character
do
{
temp_ch_
=
reader_
.
GetChar
();
}
while
(
isspace
(
temp_ch_
));
}
/*! \brief whether current position is end of a base64 stream */
bool
IsEOF
(
void
)
const
{
return
num_prev_
==
0
&&
(
temp_ch_
==
EOF
||
isspace
(
temp_ch_
));
}
// override read function.
virtual
size_t
Read
(
void
*
ptr
,
size_t
size
)
{
using
base64
::
DecodeTable
;
if
(
size
==
0
)
return
0
;
// use tlen to record left size
size_t
tlen
=
size
;
unsigned
char
*
cptr
=
static_cast
<
unsigned
char
*>
(
ptr
);
// if anything left, load from previous buffered result
if
(
num_prev_
!=
0
)
{
if
(
num_prev_
==
2
)
{
if
(
tlen
>=
2
)
{
*
cptr
++
=
buf_prev
[
0
];
*
cptr
++
=
buf_prev
[
1
];
tlen
-=
2
;
num_prev_
=
0
;
}
else
{
// assert tlen == 1
*
cptr
++
=
buf_prev
[
0
];
--
tlen
;
buf_prev
[
0
]
=
buf_prev
[
1
];
num_prev_
=
1
;
}
}
else
{
// assert num_prev_ == 1
*
cptr
++
=
buf_prev
[
0
];
--
tlen
;
num_prev_
=
0
;
}
}
if
(
tlen
==
0
)
return
size
;
int
nvalue
;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while
(
tlen
&&
temp_ch_
!=
EOF
&&
!
isspace
(
temp_ch_
))
{
// first byte
nvalue
=
DecodeTable
[
temp_ch_
]
<<
18
;
{
// second byte
temp_ch_
=
reader_
.
GetChar
();
CHECK
(
temp_ch_
!=
EOF
&&
!
isspace
(
temp_ch_
))
<<
"invalid base64 format"
;
nvalue
|=
DecodeTable
[
temp_ch_
]
<<
12
;
*
cptr
++
=
(
nvalue
>>
16
)
&
0xFF
;
--
tlen
;
}
{
// third byte
temp_ch_
=
reader_
.
GetChar
();
CHECK
(
temp_ch_
!=
EOF
&&
!
isspace
(
temp_ch_
))
<<
"invalid base64 format"
;
// handle termination
if
(
temp_ch_
==
'='
)
{
temp_ch_
=
reader_
.
GetChar
();
CHECK
(
temp_ch_
==
'='
)
<<
"invalid base64 format"
;
temp_ch_
=
reader_
.
GetChar
();
CHECK
(
temp_ch_
==
EOF
||
isspace
(
temp_ch_
))
<<
"invalid base64 format"
;
break
;
}
nvalue
|=
DecodeTable
[
temp_ch_
]
<<
6
;
if
(
tlen
)
{
*
cptr
++
=
(
nvalue
>>
8
)
&
0xFF
;
--
tlen
;
}
else
{
buf_prev
[
num_prev_
++
]
=
(
nvalue
>>
8
)
&
0xFF
;
}
}
{
// fourth byte
temp_ch_
=
reader_
.
GetChar
();
CHECK
(
temp_ch_
!=
EOF
&&
!
isspace
(
temp_ch_
))
<<
"invalid base64 format"
;
if
(
temp_ch_
==
'='
)
{
temp_ch_
=
reader_
.
GetChar
();
CHECK
(
temp_ch_
==
EOF
||
isspace
(
temp_ch_
))
<<
"invalid base64 format"
;
break
;
}
nvalue
|=
DecodeTable
[
temp_ch_
];
if
(
tlen
)
{
*
cptr
++
=
nvalue
&
0xFF
;
--
tlen
;
}
else
{
buf_prev
[
num_prev_
++
]
=
nvalue
&
0xFF
;
}
}
// get next char
temp_ch_
=
reader_
.
GetChar
();
}
if
(
kStrictCheck
)
{
CHECK_EQ
(
tlen
,
0
)
<<
"Base64InStream: read incomplete"
;
}
return
size
-
tlen
;
}
virtual
void
Write
(
const
void
*
ptr
,
size_t
size
)
{
LOG
(
FATAL
)
<<
"Base64InStream do not support write"
;
}
private
:
// internal reader
StreamBufferReader
reader_
;
int
temp_ch_
{
0
};
int
num_prev_
{
0
};
unsigned
char
buf_prev
[
2
];
// whether we need to do strict check
static
const
bool
kStrictCheck
=
false
;
};
/*!
* \brief Stream to write to base64 format.
*/
class
Base64OutStream
:
public
dmlc
::
Stream
{
public
:
explicit
Base64OutStream
(
dmlc
::
Stream
*
fp
)
:
fp_
(
fp
)
{
}
virtual
void
Write
(
const
void
*
ptr
,
size_t
size
)
{
using
base64
::
EncodeTable
;
size_t
tlen
=
size
;
const
unsigned
char
*
cptr
=
static_cast
<
const
unsigned
char
*>
(
ptr
);
while
(
tlen
)
{
while
(
buf__top_
<
3
&&
tlen
!=
0
)
{
buf_
[
++
buf__top_
]
=
*
cptr
++
;
--
tlen
;
}
if
(
buf__top_
==
3
)
{
// flush 4 bytes out
PutChar
(
EncodeTable
[
buf_
[
1
]
>>
2
]);
PutChar
(
EncodeTable
[((
buf_
[
1
]
<<
4
)
|
(
buf_
[
2
]
>>
4
))
&
0x3F
]);
PutChar
(
EncodeTable
[((
buf_
[
2
]
<<
2
)
|
(
buf_
[
3
]
>>
6
))
&
0x3F
]);
PutChar
(
EncodeTable
[
buf_
[
3
]
&
0x3F
]);
buf__top_
=
0
;
}
}
}
virtual
size_t
Read
(
void
*
ptr
,
size_t
size
)
{
LOG
(
FATAL
)
<<
"Base64OutStream do not support read"
;
return
0
;
}
/*!
* \brief finish writing of all current base64 stream, do some post processing
* \param endch character to put to end of stream, if it is EOF, then nothing will be appended.
*/
void
Finish
(
char
endch
=
EOF
)
{
using
base64
::
EncodeTable
;
if
(
buf__top_
==
1
)
{
PutChar
(
EncodeTable
[
buf_
[
1
]
>>
2
]);
PutChar
(
EncodeTable
[(
buf_
[
1
]
<<
4
)
&
0x3F
]);
PutChar
(
'='
);
PutChar
(
'='
);
}
if
(
buf__top_
==
2
)
{
PutChar
(
EncodeTable
[
buf_
[
1
]
>>
2
]);
PutChar
(
EncodeTable
[((
buf_
[
1
]
<<
4
)
|
(
buf_
[
2
]
>>
4
))
&
0x3F
]);
PutChar
(
EncodeTable
[(
buf_
[
2
]
<<
2
)
&
0x3F
]);
PutChar
(
'='
);
}
buf__top_
=
0
;
if
(
endch
!=
EOF
)
PutChar
(
endch
);
this
->
Flush
();
}
private
:
static
constexpr
size_t
kBufferSize
=
256
;
dmlc
::
Stream
*
fp_
{
nullptr
};
int
buf__top_
{
0
};
unsigned
char
buf_
[
4
];
std
::
string
out_buf_
;
void
PutChar
(
char
ch
)
{
out_buf_
+=
ch
;
if
(
out_buf_
.
length
()
>=
kBufferSize
)
Flush
();
}
void
Flush
(
void
)
{
if
(
out_buf_
.
length
()
!=
0
)
{
fp_
->
Write
(
&
out_buf_
[
0
],
out_buf_
.
length
());
out_buf_
.
clear
();
}
}
};
}
// namespace common
}
// namespace tvm
#endif // TVM_COMMON_BASE64_H_
src/lang/reflection.cc
View file @
72fa4c1d
...
...
@@ -7,8 +7,11 @@
#include <tvm/expr.h>
#include <tvm/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <string>
#include "../common/base64.h"
namespace
dmlc
{
DMLC_REGISTRY_ENABLE
(
::
tvm
::
NodeFactoryReg
);
...
...
@@ -23,6 +26,7 @@ inline std::string Type2String(const Type& t) {
return
os
.
str
();
}
inline
Type
String2Type
(
std
::
string
s
)
{
std
::
istringstream
is
(
s
);
halideir_type_code_t
code
=
Type
::
Int
;
...
...
@@ -52,6 +56,8 @@ class NodeIndexer : public AttrVisitor {
public
:
std
::
unordered_map
<
Node
*
,
size_t
>
node_index
{{
nullptr
,
0
}};
std
::
vector
<
Node
*>
node_list
{
nullptr
};
std
::
unordered_map
<
DLTensor
*
,
size_t
>
tensor_index
;
std
::
vector
<
DLTensor
*>
tensor_list
;
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{}
...
...
@@ -64,7 +70,13 @@ class NodeIndexer : public AttrVisitor {
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
MakeIndex
(
value
->
node_
.
get
());
}
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
DLTensor
*
ptr
=
const_cast
<
DLTensor
*>
((
*
value
).
operator
->
());
if
(
tensor_index
.
count
(
ptr
))
return
;
CHECK_EQ
(
tensor_index
.
size
(),
tensor_list
.
size
());
tensor_index
[
ptr
]
=
tensor_list
.
size
();
tensor_list
.
push_back
(
ptr
);
}
// make index of all the children of node
void
MakeIndex
(
Node
*
node
)
{
if
(
node
==
nullptr
)
return
;
...
...
@@ -140,6 +152,7 @@ struct JSONNode {
class
JSONAttrGetter
:
public
AttrVisitor
{
public
:
const
std
::
unordered_map
<
Node
*
,
size_t
>*
node_index_
;
const
std
::
unordered_map
<
DLTensor
*
,
size_t
>*
tensor_index_
;
JSONNode
*
node_
;
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
...
...
@@ -170,6 +183,10 @@ class JSONAttrGetter : public AttrVisitor {
node_
->
attrs
[
key
]
=
std
::
to_string
(
node_index_
->
at
(
value
->
node_
.
get
()));
}
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
tensor_index_
->
at
(
const_cast
<
DLTensor
*>
((
*
value
).
operator
->
())));
}
// Get the node
void
Get
(
Node
*
node
)
{
if
(
node
==
nullptr
)
{
...
...
@@ -209,6 +226,7 @@ class JSONAttrGetter : public AttrVisitor {
class
JSONAttrSetter
:
public
AttrVisitor
{
public
:
const
std
::
vector
<
std
::
shared_ptr
<
Node
>
>*
node_list_
;
const
std
::
vector
<
runtime
::
NDArray
>*
tensor_list_
;
JSONNode
*
node_
;
std
::
string
GetValue
(
const
char
*
key
)
const
{
...
...
@@ -254,10 +272,16 @@ class JSONAttrSetter : public AttrVisitor {
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
size_t
index
;
ParseValue
(
key
,
&
index
);
CHECK_LE
(
index
,
node_list_
->
size
());
value
->
node_
=
node_list_
->
at
(
index
);
}
// Get the node
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
size_t
index
;
ParseValue
(
key
,
&
index
);
CHECK_LE
(
index
,
tensor_list_
->
size
());
*
value
=
tensor_list_
->
at
(
index
);
}
// set node to be current JSONNode
void
Set
(
Node
*
node
)
{
if
(
node
==
nullptr
)
return
;
if
(
node
->
is_type
<
ArrayNode
>
())
{
...
...
@@ -292,6 +316,8 @@ struct JSONGraph {
size_t
root
;
// the nodes of the graph
std
::
vector
<
JSONNode
>
nodes
;
// base64 b64ndarrays of arrays
std
::
vector
<
std
::
string
>
b64ndarrays
;
// global attributes
AttrMap
attrs
;
...
...
@@ -299,6 +325,7 @@ struct JSONGraph {
writer
->
BeginObject
();
writer
->
WriteObjectKeyValue
(
"root"
,
root
);
writer
->
WriteObjectKeyValue
(
"nodes"
,
nodes
);
writer
->
WriteObjectKeyValue
(
"b64ndarrays"
,
b64ndarrays
);
if
(
attrs
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"attrs"
,
attrs
);
}
...
...
@@ -310,6 +337,7 @@ struct JSONGraph {
dmlc
::
JSONObjectReadHelper
helper
;
helper
.
DeclareField
(
"root"
,
&
root
);
helper
.
DeclareField
(
"nodes"
,
&
nodes
);
helper
.
DeclareOptionalField
(
"b64ndarrays"
,
&
b64ndarrays
);
helper
.
DeclareOptionalField
(
"attrs"
,
&
attrs
);
helper
.
ReadAllFields
(
reader
);
}
...
...
@@ -320,6 +348,7 @@ struct JSONGraph {
indexer
.
MakeIndex
(
root
.
node_
.
get
());
JSONAttrGetter
getter
;
getter
.
node_index_
=
&
indexer
.
node_index
;
getter
.
tensor_index_
=
&
indexer
.
tensor_index
;
for
(
Node
*
n
:
indexer
.
node_list
)
{
JSONNode
jnode
;
getter
.
node_
=
&
jnode
;
...
...
@@ -328,6 +357,15 @@ struct JSONGraph {
}
g
.
attrs
[
"tvm_version"
]
=
TVM_VERSION
;
g
.
root
=
indexer
.
node_index
.
at
(
root
.
node_
.
get
());
// serialize tensor
for
(
DLTensor
*
tensor
:
indexer
.
tensor_list
)
{
std
::
string
blob
;
dmlc
::
MemoryStringStream
mstrm
(
&
blob
);
common
::
Base64OutStream
b64strm
(
&
mstrm
);
runtime
::
SaveDLTensor
(
&
b64strm
,
tensor
);
b64strm
.
Finish
();
g
.
b64ndarrays
.
emplace_back
(
std
::
move
(
blob
));
}
return
g
;
}
};
...
...
@@ -347,6 +385,16 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
// load in json graph.
jgraph
.
Load
(
&
reader
);
std
::
vector
<
std
::
shared_ptr
<
Node
>
>
nodes
;
std
::
vector
<
runtime
::
NDArray
>
tensors
;
// load in tensors
for
(
const
std
::
string
&
blob
:
jgraph
.
b64ndarrays
)
{
dmlc
::
MemoryStringStream
mstrm
(
const_cast
<
std
::
string
*>
(
&
blob
));
common
::
Base64InStream
b64strm
(
&
mstrm
);
b64strm
.
InitPosition
();
runtime
::
NDArray
temp
;
CHECK
(
temp
.
Load
(
&
b64strm
));
tensors
.
emplace_back
(
temp
);
}
// node 0 is always null
nodes
.
reserve
(
jgraph
.
nodes
.
size
());
for
(
const
JSONNode
&
jnode
:
jgraph
.
nodes
)
{
...
...
@@ -362,6 +410,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
CHECK_EQ
(
nodes
.
size
(),
jgraph
.
nodes
.
size
());
JSONAttrSetter
setter
;
setter
.
node_list_
=
&
nodes
;
setter
.
tensor_list_
=
&
tensors
;
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
setter
.
node_
=
&
jgraph
.
nodes
[
i
];
...
...
@@ -402,6 +451,9 @@ class NodeAttrSetter : public AttrVisitor {
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
*
value
=
GetAttr
(
key
).
operator
NodeRef
();
}
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
*
value
=
GetAttr
(
key
).
operator
runtime
::
NDArray
();
}
private
:
runtime
::
TVMArgValue
GetAttr
(
const
char
*
key
)
{
...
...
src/runtime/graph/graph_runtime.cc
View file @
72fa4c1d
...
...
@@ -4,7 +4,7 @@
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/
serializer
.h>
#include <tvm/runtime/
ndarray
.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <numeric>
...
...
@@ -399,52 +399,9 @@ class GraphRuntime : public ModuleNode {
void
GraphRuntime
::
LoadDLTensor
(
dmlc
::
Stream
*
strm
,
DLTensor
*
dst
)
{
// always use strm->Read to maintain endianness conversion
uint64_t
header
,
reserved
;
CHECK
(
strm
->
Read
(
&
header
))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
reserved
))
<<
"Invalid DLTensor file format"
;
CHECK
(
header
==
kTVMNDArrayMagic
)
<<
"Invalid DLTensor file format"
;
DLTensor
tensor
;
CHECK
(
strm
->
Read
(
&
(
tensor
.
ctx
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
(
tensor
.
ndim
)))
<<
"Invalid DLTensor file format"
;
CHECK
(
strm
->
Read
(
&
(
tensor
.
dtype
)))
<<
"Invalid DLTensor file format"
;
std
::
vector
<
int64_t
>
shape
(
tensor
.
ndim
);
if
(
tensor
.
ndim
!=
0
)
{
CHECK
(
strm
->
ReadArray
(
&
shape
[
0
],
tensor
.
ndim
))
<<
"Invalid DLTensor file format"
;
}
CHECK_EQ
(
tensor
.
ndim
,
dst
->
ndim
)
<<
"param dimension mismatch"
;
CHECK
(
tensor
.
dtype
.
bits
==
dst
->
dtype
.
bits
&&
tensor
.
dtype
.
code
==
dst
->
dtype
.
code
&&
tensor
.
dtype
.
lanes
==
dst
->
dtype
.
lanes
)
<<
"param type mismatch"
;
for
(
int
i
=
0
;
i
<
tensor
.
ndim
;
++
i
)
{
CHECK_EQ
(
shape
[
i
],
dst
->
shape
[
i
])
<<
"param shape mismatch"
;
}
size_t
bits
=
dst
->
dtype
.
bits
*
dst
->
dtype
.
lanes
;
size_t
elem_bytes
=
(
bits
+
7
)
/
8
;
size_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
dst
->
ndim
;
++
i
)
{
num_elems
*=
dst
->
shape
[
i
];
}
uint64_t
data_byte_size
;
CHECK
(
strm
->
Read
(
&
data_byte_size
))
<<
"Invalid DLTensor file format"
;
CHECK_EQ
(
data_byte_size
,
elem_bytes
*
num_elems
)
<<
"Invalid DLTensor file format"
;
std
::
vector
<
uint8_t
>
bytes
(
data_byte_size
+
1
);
CHECK
(
strm
->
Read
(
&
bytes
[
0
],
data_byte_size
))
<<
"Invalid DLTensor file format"
;
// explicitly swap endian when necessary.
if
(
!
DMLC_IO_NO_ENDIAN_SWAP
)
{
dmlc
::
ByteSwap
(
&
bytes
[
0
],
elem_bytes
,
num_elems
);
}
TVM_CCALL
(
TVMArrayCopyFromBytes
(
dst
,
&
bytes
[
0
],
data_byte_size
));
NDArray
temp
;
temp
.
Load
(
strm
);
temp
.
CopyTo
(
dst
);
}
void
GraphRuntime
::
LoadParams
(
dmlc
::
Stream
*
strm
)
{
...
...
src/runtime/graph/graph_runtime.h
View file @
72fa4c1d
...
...
@@ -13,8 +13,6 @@
namespace
tvm
{
namespace
runtime
{
/*! \brief Magic number for NDArray file */
constexpr
uint64_t
kTVMNDArrayMagic
=
0xDD5E40F096B4A13F
;
/*! \brief Magic number for NDArray list file */
constexpr
uint64_t
kTVMNDArrayListMagic
=
0xF7E58D4F05049CB7
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment