Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ba8d00c2
Commit
ba8d00c2
authored
Jan 27, 2018
by
Tianqi Chen
Committed by
GitHub
Jan 27, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TIMER] Enhance time evaluator to create multiple results (#830)
parent
a7cd0a89
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
50 additions
and
26 deletions
+50
-26
python/tvm/module.py
+15
-7
src/runtime/rpc/rpc_module.cc
+5
-4
src/runtime/rpc/rpc_session.cc
+17
-8
src/runtime/rpc/rpc_session.h
+7
-4
tests/python/integration/test_ewise.py
+6
-3
No files found.
python/tvm/module.py
View file @
ba8d00c2
"""Container of compiled functions of TVM."""
from
__future__
import
absolute_import
as
_abs
import
struct
from
collections
import
namedtuple
from
._ffi.function
import
ModuleBase
,
_set_class_module
from
._ffi.function
import
_init_api
from
.contrib
import
cc
as
_cc
,
tar
as
_tar
,
util
as
_util
ProfileResult
=
namedtuple
(
"ProfileResult"
,
[
"mean"
])
ProfileResult
=
namedtuple
(
"ProfileResult"
,
[
"mean"
,
"results"
])
class
Module
(
ModuleBase
):
...
...
@@ -110,7 +111,7 @@ class Module(ModuleBase):
fcompile
=
_cc
.
create_shared
fcompile
(
file_name
,
files
,
**
kwargs
)
def
time_evaluator
(
self
,
func_name
,
ctx
,
number
):
def
time_evaluator
(
self
,
func_name
,
ctx
,
number
,
repeat
=
1
):
"""Get an evaluator that measures time cost of running function.
Parameters
...
...
@@ -122,11 +123,15 @@ class Module(ModuleBase):
The context we should run this function on.
number: int
The number of repeative times to run evaluation.
The number of steps used in measuring each time interval
repeat: int, optional
Number of times to run the timer measurement
If repeat equals 3, then we will get 3 numbers in the ProfileResult.
Note
----
The function will be invoked number + 1 times,
The function will be invoked
repeat *
number + 1 times,
with the first call discarded in case there is lazy initialization.
Returns
...
...
@@ -137,13 +142,16 @@ class Module(ModuleBase):
"""
try
:
feval
=
_RPCTimeEvaluator
(
self
,
func_name
,
ctx
.
device_type
,
ctx
.
device_id
,
number
)
self
,
func_name
,
ctx
.
device_type
,
ctx
.
device_id
,
number
,
repeat
)
def
evaluator
(
*
args
):
"""Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future.
mean
=
feval
(
*
args
)
return
ProfileResult
(
mean
=
mean
)
blob
=
feval
(
*
args
)
fmt
=
"@"
+
(
"d"
*
repeat
)
results
=
struct
.
unpack
(
fmt
,
blob
)
mean
=
sum
(
results
)
/
float
(
repeat
)
return
ProfileResult
(
mean
=
mean
,
results
=
results
)
return
evaluator
except
NameError
:
...
...
src/runtime/rpc/rpc_module.cc
View file @
ba8d00c2
...
...
@@ -77,10 +77,11 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc
GetTimeEvaluator
(
const
std
::
string
&
name
,
TVMContext
ctx
,
int
nstep
)
{
int
number
,
int
repeat
)
{
RPCFuncHandle
handle
=
GetFuncHandle
(
name
);
if
(
handle
==
nullptr
)
return
PackedFunc
();
handle
=
sess_
->
GetTimeEvaluator
(
handle
,
ctx
,
n
step
);
handle
=
sess_
->
GetTimeEvaluator
(
handle
,
ctx
,
n
umber
,
repeat
);
return
WrapRemote
(
handle
);
}
...
...
@@ -148,10 +149,10 @@ TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
ctx
.
device_id
=
args
[
3
];
if
(
tkey
==
"rpc"
)
{
*
rv
=
static_cast
<
RPCModuleNode
*>
(
m
.
operator
->
())
->
GetTimeEvaluator
(
args
[
1
],
ctx
,
args
[
4
]);
->
GetTimeEvaluator
(
args
[
1
],
ctx
,
args
[
4
]
,
args
[
5
]
);
}
else
{
*
rv
=
WrapTimeEvaluator
(
m
.
GetFunction
(
args
[
1
],
false
),
ctx
,
args
[
4
]);
m
.
GetFunction
(
args
[
1
],
false
),
ctx
,
args
[
4
]
,
args
[
5
]
);
}
});
...
...
src/runtime/rpc/rpc_session.cc
View file @
ba8d00c2
...
...
@@ -844,8 +844,9 @@ void RPCSession::CopyFromRemote(void* from,
}
RPCFuncHandle
RPCSession
::
GetTimeEvaluator
(
RPCFuncHandle
fhandle
,
TVMContext
ctx
,
int
nstep
)
{
return
this
->
CallRemote
(
RPCCode
::
kGetTimeEvaluator
,
fhandle
,
ctx
,
nstep
);
RPCFuncHandle
fhandle
,
TVMContext
ctx
,
int
number
,
int
repeat
)
{
return
this
->
CallRemote
(
RPCCode
::
kGetTimeEvaluator
,
fhandle
,
ctx
,
number
,
repeat
);
}
// Event handler functions
...
...
@@ -973,7 +974,7 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void
RPCGetTimeEvaluator
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
PackedFunc
*
pf
=
static_cast
<
PackedFunc
*>
(
args
[
0
].
operator
void
*
());
void
*
fhandle
=
new
PackedFunc
(
WrapTimeEvaluator
(
*
pf
,
args
[
1
],
args
[
2
]));
void
*
fhandle
=
new
PackedFunc
(
WrapTimeEvaluator
(
*
pf
,
args
[
1
],
args
[
2
]
,
args
[
3
]
));
delete
pf
;
*
rv
=
fhandle
;
}
...
...
@@ -1024,23 +1025,31 @@ void RPCSession::EventHandler::HandlePackedCall() {
CHECK_EQ
(
state_
,
kRecvCode
);
}
PackedFunc
WrapTimeEvaluator
(
PackedFunc
pf
,
TVMContext
ctx
,
int
n
step
)
{
auto
ftimer
=
[
pf
,
ctx
,
n
step
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
PackedFunc
WrapTimeEvaluator
(
PackedFunc
pf
,
TVMContext
ctx
,
int
n
umber
,
int
repeat
)
{
auto
ftimer
=
[
pf
,
ctx
,
n
umber
,
repeat
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
TVMRetValue
temp
;
std
::
ostringstream
os
;
// skip first time call, to activate lazy compilation components.
pf
.
CallPacked
(
args
,
&
temp
);
DeviceAPI
::
Get
(
ctx
)
->
StreamSync
(
ctx
,
nullptr
);
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
// start timing
auto
tbegin
=
std
::
chrono
::
high_resolution_clock
::
now
();
for
(
int
i
=
0
;
i
<
nstep
;
++
i
)
{
for
(
int
i
=
0
;
i
<
number
;
++
i
)
{
pf
.
CallPacked
(
args
,
&
temp
);
}
DeviceAPI
::
Get
(
ctx
)
->
StreamSync
(
ctx
,
nullptr
);
auto
tend
=
std
::
chrono
::
high_resolution_clock
::
now
();
double
speed
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
duration
<
double
>
>
(
tend
-
tbegin
).
count
()
/
nstep
;
tend
-
tbegin
).
count
()
/
number
;
os
.
write
(
reinterpret_cast
<
char
*>
(
&
speed
),
sizeof
(
speed
));
}
std
::
string
blob
=
os
.
str
();
TVMByteArray
arr
;
arr
.
size
=
blob
.
length
();
arr
.
data
=
blob
.
data
();
// return the time.
*
rv
=
speed
;
*
rv
=
arr
;
};
return
PackedFunc
(
ftimer
);
}
...
...
src/runtime/rpc/rpc_session.h
View file @
ba8d00c2
...
...
@@ -146,12 +146,14 @@ class RPCSession {
*
* \param fhandle The function handle.
* \param ctx The ctx to run measurement on.
* \param nstep Number of steps to run.
* \param number How many steps to run in each time evaluation
* \param repeat How many times to repeat the timer
* \return A remote timer function
*/
RPCFuncHandle
GetTimeEvaluator
(
RPCFuncHandle
fhandle
,
TVMContext
ctx
,
int
nstep
);
int
number
,
int
repeat
);
/*!
* \brief Call a remote defined system function with arguments.
* \param fcode The function code.
...
...
@@ -212,9 +214,10 @@ class RPCSession {
* \brief Wrap a timer function for a given packed function.
* \param f The function argument.
* \param ctx The context.
* \param nstep Number of repeative steps.
* \param number Number of steps in the inner iteration
* \param repeat How many steps to repeat the time evaluation.
*/
PackedFunc
WrapTimeEvaluator
(
PackedFunc
f
,
TVMContext
ctx
,
int
n
step
);
PackedFunc
WrapTimeEvaluator
(
PackedFunc
f
,
TVMContext
ctx
,
int
n
umber
,
int
repeat
);
/*!
* \brief Create a Global RPC module that refers to the session.
...
...
tests/python/integration/test_ewise.py
View file @
ba8d00c2
...
...
@@ -55,7 +55,10 @@ def test_log_pow_llvm():
n
=
1028
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
B
.
dtype
),
ctx
)
flog
(
a
,
b
)
repeat
=
10
ftimer
=
flog
.
time_evaluator
(
flog
.
entry_name
,
ctx
,
number
=
1
,
repeat
=
repeat
)
res
=
ftimer
(
a
,
b
)
assert
(
len
(
res
.
results
)
==
repeat
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
np
.
power
(
np
.
log
(
a
.
asnumpy
()),
2.0
),
rtol
=
1e-5
)
...
...
@@ -146,7 +149,7 @@ def test_add():
if
__name__
==
"__main__"
:
test_add
()
test_log_pow_llvm
()
test_popcount
()
test_exp
()
test_add
()
test_popcount
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment