Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5912ed03
Commit
5912ed03
authored
Jun 03, 2017
by
Tianqi Chen
Committed by
GitHub
Jun 03, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PERF/TIMER] Add builtin timing logic (#168)
* [PERF/TIMER] Add builtin timing logic * fix lint
parent
46b4a914
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
146 additions
and
26 deletions
+146
-26
python/tvm/_ffi/function.py
+4
-2
python/tvm/module.py
+32
-1
src/runtime/rpc/rpc_module.cc
+45
-12
src/runtime/rpc/rpc_session.cc
+35
-0
src/runtime/rpc/rpc_session.h
+22
-1
tests/python/integration/test_ewise.py
+2
-1
tests/python/integration/test_gemm.py
+3
-8
tests/python/unittest/test_runtime_rpc.py
+3
-1
No files found.
python/tvm/_ffi/function.py
View file @
5912ed03
...
@@ -56,10 +56,12 @@ class Function(_FunctionBase):
...
@@ -56,10 +56,12 @@ class Function(_FunctionBase):
class
ModuleBase
(
object
):
class
ModuleBase
(
object
):
"""Base class for module"""
"""Base class for module"""
__slots__
=
[
"handle"
,
"_entry"
]
__slots__
=
[
"handle"
,
"_entry"
,
"entry_name"
]
def
__init__
(
self
,
handle
):
def
__init__
(
self
,
handle
):
self
.
handle
=
handle
self
.
handle
=
handle
self
.
_entry
=
None
self
.
_entry
=
None
self
.
entry_name
=
"__tvm_main__"
def
__del__
(
self
):
def
__del__
(
self
):
check_call
(
_LIB
.
TVMModFree
(
self
.
handle
))
check_call
(
_LIB
.
TVMModFree
(
self
.
handle
))
...
@@ -75,7 +77,7 @@ class ModuleBase(object):
...
@@ -75,7 +77,7 @@ class ModuleBase(object):
"""
"""
if
self
.
_entry
:
if
self
.
_entry
:
return
self
.
_entry
return
self
.
_entry
self
.
_entry
=
self
.
get_function
(
"__tvm_main__"
)
self
.
_entry
=
self
.
get_function
(
self
.
entry_name
)
return
self
.
_entry
return
self
.
_entry
def
get_function
(
self
,
name
,
query_imports
=
False
):
def
get_function
(
self
,
name
,
query_imports
=
False
):
...
...
python/tvm/module.py
View file @
5912ed03
...
@@ -72,7 +72,7 @@ class Module(ModuleBase):
...
@@ -72,7 +72,7 @@ class Module(ModuleBase):
The name of the shared library.
The name of the shared library.
"""
"""
if
self
.
type_key
!=
"llvm"
:
if
self
.
type_key
!=
"llvm"
:
raise
ValueError
(
"
Only llvm support export shared"
)
raise
ValueError
(
"
Module[
%
s]: Only llvm support export shared"
%
self
.
type_key
)
temp
=
_util
.
tempdir
()
temp
=
_util
.
tempdir
()
path_obj
=
temp
.
relpath
(
"lib.o"
)
path_obj
=
temp
.
relpath
(
"lib.o"
)
self
.
save
(
path_obj
)
self
.
save
(
path_obj
)
...
@@ -84,6 +84,37 @@ class Module(ModuleBase):
...
@@ -84,6 +84,37 @@ class Module(ModuleBase):
files
.
append
(
path_cc
)
files
.
append
(
path_cc
)
_cc
.
create_shared
(
file_name
,
files
)
_cc
.
create_shared
(
file_name
,
files
)
def time_evaluator(self, func_name, ctx, number):
    """Get an evaluator that measures time cost of running function.

    Parameters
    ----------
    func_name: str
        The name of the function in the module.

    ctx: TVMContext
        The context we should run this function on.

    number: int
        The number of times to repeat the evaluation.

    Note
    ----
    The function will be invoked number + 1 times,
    with the first call discarded in case there is lazy initialization.

    Returns
    -------
    ftimer : Function
        The function that takes the same arguments as func
        and returns a float representing seconds per function call.

    Raises
    ------
    NameError
        If the ``_RPCTimeEvaluator`` global is not defined in this module.
    """
    # Resolve the global on its own so that only a missing registration is
    # reported as "RPC not enabled"; a NameError raised while the evaluator
    # itself is being built is no longer mislabeled.
    try:
        feval = _RPCTimeEvaluator
    except NameError:
        raise NameError("time_evaluator is only supported when RPC is enabled")
    return feval(self, func_name, ctx.device_type, ctx.device_id, number)
def
load
(
path
,
fmt
=
""
):
def
load
(
path
,
fmt
=
""
):
"""Load module from file
"""Load module from file
...
...
src/runtime/rpc/rpc_module.cc
View file @
5912ed03
...
@@ -51,18 +51,8 @@ class RPCModuleNode final : public ModuleNode {
...
@@ -51,18 +51,8 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc
GetFunction
(
PackedFunc
GetFunction
(
const
std
::
string
&
name
,
const
std
::
string
&
name
,
const
std
::
shared_ptr
<
ModuleNode
>&
sptr_to_self
)
final
{
const
std
::
shared_ptr
<
ModuleNode
>&
sptr_to_self
)
final
{
RPCFuncHandle
handle
=
nullptr
;
RPCFuncHandle
handle
=
GetFuncHandle
(
name
);
if
(
module_handle_
==
nullptr
)
{
return
WrapRemote
(
handle
);
handle
=
sess_
->
CallRemote
(
RPCCode
::
kGetGlobalFunc
,
name
);
}
else
{
handle
=
sess_
->
CallRemote
(
RPCCode
::
kModuleGetFunc
,
module_handle_
,
name
);
}
if
(
handle
==
nullptr
)
return
PackedFunc
();
auto
wf
=
std
::
make_shared
<
RPCWrappedFunc
>
(
handle
,
sess_
);
return
PackedFunc
([
wf
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
return
wf
->
operator
()(
args
,
rv
);
});
}
}
void
SaveToFile
(
const
std
::
string
&
file_name
,
void
SaveToFile
(
const
std
::
string
&
file_name
,
...
@@ -86,7 +76,34 @@ class RPCModuleNode final : public ModuleNode {
...
@@ -86,7 +76,34 @@ class RPCModuleNode final : public ModuleNode {
return
sess_
;
return
sess_
;
}
}
/*!
 * \brief Build a timer function for a function of this module.
 * \param name The name of the function to be timed.
 * \param ctx The context passed on to the session's time evaluator.
 * \param nstep Number of steps passed on to the session's time evaluator.
 * \return The wrapped timer function, or an empty PackedFunc when the
 *  function handle lookup returns nullptr.
 */
PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int nstep) {
  RPCFuncHandle target = GetFuncHandle(name);
  if (target == nullptr) {
    return PackedFunc();
  }
  // Exchange the function handle for a handle to its timer counterpart.
  RPCFuncHandle timer = sess_->GetTimeEvaluator(target, ctx, nstep);
  return WrapRemote(timer);
}
private
:
private
:
/*!
 * \brief Wrap a remote function handle into a locally callable PackedFunc.
 * \param handle The function handle, may be nullptr.
 * \return An empty PackedFunc when handle is nullptr, otherwise a PackedFunc
 *  that forwards its arguments to an RPCWrappedFunc built from the handle
 *  and the session of this module.
 */
PackedFunc WrapRemote(RPCFuncHandle handle) {
  if (handle == nullptr) {
    return PackedFunc();
  }
  // The closure holds a shared_ptr copy, so the wrapper stays alive
  // for as long as the returned PackedFunc does.
  std::shared_ptr<RPCWrappedFunc> wrapped =
      std::make_shared<RPCWrappedFunc>(handle, sess_);
  return PackedFunc([wrapped](TVMArgs args, TVMRetValue* rv) {
      return (*wrapped)(args, rv);
    });
}
/*!
 * \brief Look up the handle of a function by name through the session.
 * \param name The name of the function.
 * \return The handle returned by the remote call.
 */
RPCFuncHandle GetFuncHandle(const std::string& name) {
  // Without a module handle the name is looked up as a global function.
  if (module_handle_ == nullptr) {
    return sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
  }
  return sess_->CallRemote(RPCCode::kModuleGetFunc, module_handle_, name);
}
// The module handle
// The module handle
void
*
module_handle_
{
nullptr
};
void
*
module_handle_
{
nullptr
};
// The local channel
// The local channel
...
@@ -123,6 +140,22 @@ TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
...
@@ -123,6 +140,22 @@ TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
*
rv
=
RPCConnect
(
args
[
0
],
args
[
1
]);
*
rv
=
RPCConnect
(
args
[
0
],
args
[
1
]);
});
});
// module._RPCTimeEvaluator(module, func_name, device_type, device_id, nstep)
// Returns a timer function for func_name of the given module: rpc modules
// build the timer through their session, every other module type wraps the
// function in a timer directly.
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
    ctx.device_id = args[3];
    if (tkey == "rpc") {
      *rv = static_cast<RPCModuleNode*>(m.operator->())
          ->GetTimeEvaluator(args[1], ctx, args[4]);
    } else {
      // nstep is args[4]; args[3] is the device id and must not be
      // reused as the repeat count.
      *rv = WrapTimeEvaluator(
          m.GetFunction(args[1], false), ctx, args[4]);
    }
  });
TVM_REGISTER_GLOBAL
(
"contrib.rpc._LoadRemoteModule"
)
TVM_REGISTER_GLOBAL
(
"contrib.rpc._LoadRemoteModule"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Module
m
=
args
[
0
];
Module
m
=
args
[
0
];
...
...
src/runtime/rpc/rpc_session.cc
View file @
5912ed03
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <memory>
#include <array>
#include <array>
#include <chrono>
#include "./rpc_session.h"
#include "./rpc_session.h"
#include "../device_api.h"
#include "../device_api.h"
...
@@ -181,6 +182,11 @@ void RPCSession::CopyFromRemote(void* from,
...
@@ -181,6 +182,11 @@ void RPCSession::CopyFromRemote(void* from,
}
}
}
}
// Request a timer function for fhandle by issuing the
// kGetTimeEvaluator system call with (fhandle, ctx, nstep).
RPCFuncHandle RPCSession::GetTimeEvaluator(
    RPCFuncHandle fhandle, TVMContext ctx, int nstep) {
  return CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep);
}
void
RPCSession
::
SendReturnValue
(
void
RPCSession
::
SendReturnValue
(
int
succ
,
TVMValue
ret_value
,
int
ret_tcode
)
{
int
succ
,
TVMValue
ret_value
,
int
ret_tcode
)
{
if
(
succ
==
0
)
{
if
(
succ
==
0
)
{
...
@@ -593,6 +599,13 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
...
@@ -593,6 +599,13 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
*
rv
=
(
*
static_cast
<
Module
*>
(
mhandle
))
->
GetSource
(
fmt
);
*
rv
=
(
*
static_cast
<
Module
*>
(
mhandle
))
->
GetSource
(
fmt
);
}
}
// Handler of kGetTimeEvaluator.
// args[0] is a handle to a heap allocated PackedFunc, args[1] and args[2]
// are forwarded to WrapTimeEvaluator as its ctx and nstep arguments.
// The incoming PackedFunc is deleted here, and a handle to a newly
// allocated timer PackedFunc is returned through rv.
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue* rv) {
  PackedFunc* target = static_cast<PackedFunc*>(args[0].operator void*());
  // WrapTimeEvaluator takes the function by value, so the timer keeps
  // its own copy and the original can be released right afterwards.
  PackedFunc* timer = new PackedFunc(
      WrapTimeEvaluator(*target, args[1], args[2]));
  delete target;
  void* timer_handle = timer;
  *rv = timer_handle;
}
RPCCode
RPCSession
::
HandleNextEvent
(
TVMRetValue
*
rv
)
{
RPCCode
RPCSession
::
HandleNextEvent
(
TVMRetValue
*
rv
)
{
RPCCode
code
;
RPCCode
code
;
CHECK_EQ
(
sock_
.
RecvAll
(
&
code
,
sizeof
(
int
)),
sizeof
(
int
));
CHECK_EQ
(
sock_
.
RecvAll
(
&
code
,
sizeof
(
int
)),
sizeof
(
int
));
...
@@ -604,6 +617,7 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
...
@@ -604,6 +617,7 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
case
RPCCode
:
:
kCopyToRemote
:
HandleCopyToRemote
();
break
;
case
RPCCode
:
:
kCopyToRemote
:
HandleCopyToRemote
();
break
;
case
RPCCode
:
:
kShutdown
:
break
;
case
RPCCode
:
:
kShutdown
:
break
;
// system functions
// system functions
case
RPCCode
:
:
kGetTimeEvaluator
:
CallHandler
(
RPCGetTimeEvaluator
);
break
;
case
RPCCode
:
:
kFreeFunc
:
CallHandler
(
RPCFreeFunc
);
break
;
case
RPCCode
:
:
kFreeFunc
:
CallHandler
(
RPCFreeFunc
);
break
;
case
RPCCode
:
:
kGetGlobalFunc
:
CallHandler
(
RPCGetGlobalFunc
);
break
;
case
RPCCode
:
:
kGetGlobalFunc
:
CallHandler
(
RPCGetGlobalFunc
);
break
;
case
RPCCode
:
:
kDevSetDevice
:
CallHandler
(
RPCDevSetDevice
);
break
;
case
RPCCode
:
:
kDevSetDevice
:
CallHandler
(
RPCDevSetDevice
);
break
;
...
@@ -620,5 +634,26 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
...
@@ -620,5 +634,26 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
}
}
return
code
;
return
code
;
}
}
/*!
 * \brief Wrap a timer function around a given packed function.
 *
 *  The returned function forwards its arguments to pf: one untimed call
 *  first, then nstep timed calls, with a stream synchronization on ctx
 *  before the timing starts and before it stops. It returns the elapsed
 *  seconds divided by nstep as a double.
 *
 * \param pf The function to be timed.
 * \param ctx The context used for stream synchronization.
 * \param nstep Number of timed calls.
 * \note nstep is used as the divisor without validation, callers are
 *  expected to pass a positive value.
 * \return The timer function.
 */
PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) {
  auto ftimer = [pf, ctx, nstep](TVMArgs args, TVMRetValue* rv) {
    // steady_clock is monotonic; high_resolution_clock may be an alias of
    // the system clock, which can jump while a measurement is running.
    using Clock = std::chrono::steady_clock;
    TVMRetValue temp;
    // skip first time call, to activate lazy compilation components.
    pf.CallPacked(args, &temp);
    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
    // start timing
    const Clock::time_point tbegin = Clock::now();
    for (int i = 0; i < nstep; ++i) {
      pf.CallPacked(args, &temp);
    }
    // wait for the pending work on ctx before taking the end time.
    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
    const Clock::time_point tend = Clock::now();
    double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
        tend - tbegin).count() / nstep;
    // return the time.
    *rv = speed;
  };
  return PackedFunc(ftimer);
}
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
src/runtime/rpc/rpc_session.h
View file @
5912ed03
...
@@ -31,6 +31,7 @@ enum class RPCCode : int {
...
@@ -31,6 +31,7 @@ enum class RPCCode : int {
kCopyAck
,
kCopyAck
,
// The following are codes that can be sent over CallRemote
// The following are codes that can be sent over CallRemote
kGetGlobalFunc
,
kGetGlobalFunc
,
kGetTimeEvaluator
,
kFreeFunc
,
kFreeFunc
,
kDevSetDevice
,
kDevSetDevice
,
kDevGetAttr
,
kDevGetAttr
,
...
@@ -93,6 +94,18 @@ class RPCSession {
...
@@ -93,6 +94,18 @@ class RPCSession {
size_t
size
,
size_t
size
,
TVMContext
ctx_from
);
TVMContext
ctx_from
);
/*!
/*!
* \brief Get a remote timer function on ctx.
* This function consumes fhandle, caller should not call Free on fhandle.
*
* \param fhandle The function handle.
* \param ctx The ctx to run measurement on.
* \param nstep Number of steps to run.
* \return A remote timer function
*/
RPCFuncHandle
GetTimeEvaluator
(
RPCFuncHandle
fhandle
,
TVMContext
ctx
,
int
nstep
);
/*!
* \brief Call a remote defined system function with arguments.
* \brief Call a remote defined system function with arguments.
* \param fcode The function code.
* \param fcode The function code.
* \param args The arguments
* \param args The arguments
...
@@ -133,13 +146,13 @@ class RPCSession {
...
@@ -133,13 +146,13 @@ class RPCSession {
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
);
void
SendPackedSeq
(
const
TVMValue
*
arg_values
,
const
int
*
type_codes
,
int
n
);
void
RecvPackedSeq
(
RPCArgBuffer
*
buf
);
void
RecvPackedSeq
(
RPCArgBuffer
*
buf
);
RPCCode
HandleNextEvent
(
TVMRetValue
*
rv
);
RPCCode
HandleNextEvent
(
TVMRetValue
*
rv
);
TVMContext
StripSessMask
(
TVMContext
ctx
);
// special handler.
// special handler.
void
HandleCallFunc
();
void
HandleCallFunc
();
void
HandleException
();
void
HandleException
();
void
HandleCopyFromRemote
();
void
HandleCopyFromRemote
();
void
HandleCopyToRemote
();
void
HandleCopyToRemote
();
void
HandleReturn
(
TVMRetValue
*
rv
);
void
HandleReturn
(
TVMRetValue
*
rv
);
TVMContext
StripSessMask
(
TVMContext
ctx
);
// Internal mutex
// Internal mutex
std
::
recursive_mutex
mutex_
;
std
::
recursive_mutex
mutex_
;
// Internal socket
// Internal socket
...
@@ -152,6 +165,14 @@ class RPCSession {
...
@@ -152,6 +165,14 @@ class RPCSession {
int
table_index_
{
0
};
int
table_index_
{
0
};
};
};
/*!
* \brief Wrap a timer function for a given packed function.
* \param f The function argument.
* \param ctx The context.
* \param nstep Number of repeated steps.
*/
PackedFunc
WrapTimeEvaluator
(
PackedFunc
f
,
TVMContext
ctx
,
int
nstep
);
// Remote space pointer.
// Remote space pointer.
struct
RemoteSpace
{
struct
RemoteSpace
{
void
*
data
;
void
*
data
;
...
...
tests/python/integration/test_ewise.py
View file @
5912ed03
...
@@ -95,7 +95,8 @@ def test_add():
...
@@ -95,7 +95,8 @@ def test_add():
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
vbias
=
np
.
random
.
uniform
()
vbias
=
np
.
random
.
uniform
()
vscale
=
np
.
random
.
uniform
()
vscale
=
np
.
random
.
uniform
()
fadd
(
a
,
b
,
c
,
vbias
,
vscale
)
ftimer
=
fadd
.
time_evaluator
(
fadd
.
entry_name
,
ctx
,
number
=
1000
)
tcost
=
ftimer
(
a
,
b
,
c
,
vbias
,
vscale
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
()
*
vscale
+
vbias
,
rtol
=
1e-6
)
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
()
*
vscale
+
vbias
,
rtol
=
1e-6
)
...
...
tests/python/integration/test_gemm.py
View file @
5912ed03
...
@@ -78,14 +78,9 @@ def test_gemm():
...
@@ -78,14 +78,9 @@ def test_gemm():
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
b_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
b_np
,
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
C
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
C
.
dtype
),
ctx
)
f
(
a
,
b
,
c
)
ftimer
=
f
.
time_evaluator
(
f
.
entry_name
,
ctx
,
number
=
20
)
ctx
.
sync
()
tcost
=
ftimer
(
a
,
b
,
c
)
tbegin
=
time
.
time
()
print
(
"
%
s: exec=
%
g sec/op"
%
(
ctx
,
tcost
))
f
(
a
,
b
,
c
)
tpush
=
time
.
time
()
ctx
.
sync
()
tend
=
time
.
time
()
print
(
"launch=
%
g sec, exec=
%
g sec"
%
(
tpush
-
tbegin
,
tend
-
tbegin
))
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
...
...
tests/python/unittest/test_runtime_rpc.py
View file @
5912ed03
...
@@ -70,7 +70,9 @@ def test_rpc_remote_module():
...
@@ -70,7 +70,9 @@ def test_rpc_remote_module():
f1
=
remote
.
load_module
(
"dev_lib.so"
)
f1
=
remote
.
load_module
(
"dev_lib.so"
)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
1024
)
.
astype
(
A
.
dtype
),
ctx
)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
1024
)
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
1024
,
dtype
=
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
1024
,
dtype
=
A
.
dtype
),
ctx
)
f1
(
a
,
b
)
time_f
=
f1
.
time_evaluator
(
f1
.
entry_name
,
remote
.
cpu
(
0
),
number
=
10
)
cost
=
time_f
(
a
,
b
)
print
(
'
%
g secs/op'
%
cost
)
np
.
testing
.
assert_equal
(
b
.
asnumpy
(),
a
.
asnumpy
()
+
1
)
np
.
testing
.
assert_equal
(
b
.
asnumpy
(),
a
.
asnumpy
()
+
1
)
check_remote
()
check_remote
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment