Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b3f09b01
Commit
b3f09b01
authored
Apr 12, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 12, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RPC] LocalSession to provide RPCSession back by local env (#1102)
parent
56a6ef31
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
102 additions
and
44 deletions
+102
-44
python/tvm/contrib/rpc/__init__.py
+1
-1
python/tvm/contrib/rpc/client.py
+60
-26
python/tvm/contrib/rpc/server.py
+1
-12
python/tvm/module.py
+18
-1
tests/python/unittest/test_runtime_rpc.py
+22
-4
No files found.
python/tvm/contrib/rpc/__init__.py
View file @
b3f09b01
...
...
@@ -10,4 +10,4 @@ upload and run remote RPC server, get the result back to verify correctness.
"""
from
.server
import
Server
from
.client
import
RPCSession
,
connect
,
connect_tracker
from
.client
import
RPCSession
,
LocalSession
,
connect
,
connect_tracker
python/tvm/contrib/rpc/client.py
View file @
b3f09b01
...
...
@@ -7,8 +7,11 @@ import struct
import
time
from
.
import
base
from
..
import
util
from
..._ffi.base
import
TVMError
from
..._ffi.ndarray
import
context
as
_context
from
..._ffi
import
function
as
function
from
..._ffi
import
ndarray
as
nd
from
...module
import
load
as
_load_module
class
RPCSession
(
object
):
...
...
@@ -51,36 +54,12 @@ class RPCSession(object):
ctx: TVMContext
The corresponding encoded remote context.
"""
ctx
=
_
context
(
dev_type
,
dev_id
)
ctx
=
nd
.
context
(
dev_type
,
dev_id
)
encode
=
(
self
.
_tbl_index
+
1
)
*
base
.
RPC_SESS_MASK
ctx
.
device_type
+=
encode
ctx
.
_rpc_sess
=
self
return
ctx
def
cpu
(
self
,
dev_id
=
0
):
"""Construct remote CPU device."""
return
self
.
context
(
1
,
dev_id
)
def
gpu
(
self
,
dev_id
=
0
):
"""Construct remote GPU device."""
return
self
.
context
(
2
,
dev_id
)
def
cl
(
self
,
dev_id
=
0
):
"""Construct remote OpenCL device."""
return
self
.
context
(
4
,
dev_id
)
def
metal
(
self
,
dev_id
=
0
):
"""Construct remote Metal device."""
return
self
.
context
(
8
,
dev_id
)
def
opengl
(
self
,
dev_id
=
0
):
"""Construct remote OpenGL device."""
return
self
.
context
(
11
,
dev_id
)
def
ext_dev
(
self
,
dev_id
=
0
):
"""Construct remote extension device."""
return
self
.
context
(
12
,
dev_id
)
def
upload
(
self
,
data
,
target
=
None
):
"""Upload file to remote runtime temp folder
...
...
@@ -139,6 +118,61 @@ class RPCSession(object):
"""
return
base
.
_LoadRemoteModule
(
self
.
_sess
,
path
)
def
cpu
(
self
,
dev_id
=
0
):
"""Construct CPU device."""
return
self
.
context
(
1
,
dev_id
)
def
gpu
(
self
,
dev_id
=
0
):
"""Construct GPU device."""
return
self
.
context
(
2
,
dev_id
)
def
cl
(
self
,
dev_id
=
0
):
"""Construct OpenCL device."""
return
self
.
context
(
4
,
dev_id
)
def
metal
(
self
,
dev_id
=
0
):
"""Construct Metal device."""
return
self
.
context
(
8
,
dev_id
)
def
opengl
(
self
,
dev_id
=
0
):
"""Construct OpenGL device."""
return
self
.
context
(
11
,
dev_id
)
def
ext_dev
(
self
,
dev_id
=
0
):
"""Construct extension device."""
return
self
.
context
(
12
,
dev_id
)
class
LocalSession
(
RPCSession
):
"""RPCSession interface backed by local environment.
This class can be used to implement functions that
need to be ran both locally and remotely.
"""
def
__init__
(
self
):
# pylint: disable=super-init-not-called
self
.
context
=
nd
.
context
self
.
get_function
=
function
.
get_global_func
self
.
_temp
=
util
.
tempdir
()
def
upload
(
self
,
data
,
target
=
None
):
if
isinstance
(
data
,
bytearray
):
if
not
target
:
raise
ValueError
(
"target must present when file is a bytearray"
)
blob
=
data
else
:
blob
=
bytearray
(
open
(
data
,
"rb"
)
.
read
())
if
not
target
:
target
=
os
.
path
.
basename
(
data
)
with
open
(
self
.
_temp
.
relpath
(
target
),
"wb"
)
as
f
:
f
.
write
(
blob
)
def
download
(
self
,
path
):
return
bytearray
(
open
(
self
.
_temp
.
relpath
(
path
),
"rb"
)
.
read
())
def
load_module
(
self
,
path
):
return
_load_module
(
self
.
_temp
.
relpath
(
path
))
class
TrackerSession
(
object
):
"""Tracker client session.
...
...
python/tvm/contrib/rpc/server.py
View file @
b3f09b01
...
...
@@ -22,7 +22,7 @@ import time
from
..._ffi.function
import
register_func
from
..._ffi.base
import
py_str
from
...module
import
load
as
_load_module
from
..
import
util
,
cc
,
tar
from
..
import
util
from
.
import
base
from
.
base
import
TrackerCode
...
...
@@ -38,17 +38,6 @@ def _server_env():
def
load_module
(
file_name
):
"""Load module from remote side."""
path
=
temp
.
relpath
(
file_name
)
# Try create a shared library in remote
if
path
.
endswith
(
".o"
):
logging
.
info
(
"Create shared library based on
%
s"
,
path
)
cc
.
create_shared
(
path
+
".so"
,
path
)
path
+=
".so"
elif
path
.
endswith
(
".tar"
):
tar_temp
=
util
.
tempdir
()
tar
.
untar
(
path
,
tar_temp
.
temp_dir
)
files
=
[
tar_temp
.
relpath
(
x
)
for
x
in
tar_temp
.
listdir
()]
cc
.
create_shared
(
path
+
".so"
,
files
)
path
+=
".so"
m
=
_load_module
(
path
)
logging
.
info
(
"load_module
%
s"
,
path
)
return
m
...
...
python/tvm/module.py
View file @
b3f09b01
...
...
@@ -186,7 +186,7 @@ def system_lib():
def
load
(
path
,
fmt
=
""
):
"""Load module from file
"""Load module from file
.
Parameters
----------
...
...
@@ -201,7 +201,24 @@ def load(path, fmt=""):
-------
module : Module
The loaded module
Note
----
This function will automatically call
cc.create_shared if the path is in format .o or .tar
"""
# High level handling for .o and .tar file.
# We support this to be consistent with RPC module load.
if
path
.
endswith
(
".o"
):
_cc
.
create_shared
(
path
+
".so"
,
path
)
path
+=
".so"
elif
path
.
endswith
(
".tar"
):
tar_temp
=
_util
.
tempdir
()
_tar
.
untar
(
path
,
tar_temp
.
temp_dir
)
files
=
[
tar_temp
.
relpath
(
x
)
for
x
in
tar_temp
.
listdir
()]
_cc
.
create_shared
(
path
+
".so"
,
files
)
path
+=
".so"
# Redirect to the load API
return
_LoadFromFile
(
path
,
fmt
)
...
...
tests/python/unittest/test_runtime_rpc.py
View file @
b3f09b01
...
...
@@ -61,14 +61,14 @@ def test_rpc_remote_module():
if
not
tvm
.
module
.
enabled
(
"rpc"
):
return
server
=
rpc
.
Server
(
"localhost"
)
remote
=
rpc
.
connect
(
server
.
host
,
server
.
port
)
client
=
rpc
.
connect
(
server
.
host
,
server
.
port
)
# graph
n
=
tvm
.
convert
(
1024
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
A
(
*
i
)
+
1.0
,
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
def
check_remote
():
def
check_remote
(
remote
):
if
not
tvm
.
module
.
enabled
(
"llvm"
):
print
(
"Skip because llvm is not enabled"
)
return
...
...
@@ -86,7 +86,7 @@ def test_rpc_remote_module():
print
(
'
%
g secs/op'
%
cost
)
np
.
testing
.
assert_equal
(
b
.
asnumpy
(),
a
.
asnumpy
()
+
1
)
def
check_remote_link_cl
():
def
check_remote_link_cl
(
remote
):
"""Test function to run remote code such as cl
This is not enabled because there is forking issue
...
...
@@ -134,7 +134,9 @@ def test_rpc_remote_module():
fhost
(
a
,
b
)
np
.
testing
.
assert_equal
(
b
.
asnumpy
(),
a
.
asnumpy
()
+
1
)
check_remote
()
check_remote
(
client
)
check_remote
(
rpc
.
LocalSession
())
def
test_rpc_return_func
():
@tvm.register_func
(
"rpc.test.remote_func"
)
...
...
@@ -147,6 +149,21 @@ def test_rpc_return_func():
assert
fadd
(
12
)
==
22
def
test_local_func
():
@tvm.register_func
(
"rpc.test.remote_func2"
)
def
addone
(
x
):
return
lambda
y
:
x
+
y
client
=
rpc
.
LocalSession
()
f1
=
client
.
get_function
(
"rpc.test.remote_func2"
)
fadd
=
f1
(
10
)
assert
fadd
(
12
)
==
22
blob
=
bytearray
(
np
.
random
.
randint
(
0
,
10
,
size
=
(
10
)))
client
.
upload
(
blob
,
"dat.bin"
)
rev
=
client
.
download
(
"dat.bin"
)
assert
rev
==
blob
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
INFO
)
test_rpc_remote_module
()
...
...
@@ -154,3 +171,4 @@ if __name__ == "__main__":
test_rpc_file_exchange
()
test_rpc_array
()
test_rpc_simple
()
test_local_func
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment