wenyuanbo / tic · Commits

Commit cf81f9f9
Authored Nov 30, 2017 by Tianqi Chen; committed by GitHub, Nov 30, 2017.
Parent: f5a6e5e2

[CUDA] Enable int64 (#683)

* [CUDA] Enable int64
* [PYTHON] Fix rpc tutorial with opencl
* OK
* update

Showing 5 changed files with 101 additions and 54 deletions (+101 −54):

  python/tvm/contrib/rpc.py                          +42  −4
  src/codegen/codegen_cuda.cc                        +11  −7
  tests/python/integration/test_ewise.py             +46 −41
  tests/scripts/task_python_docs.sh                   +1  −1
  tutorials/deployment/cross_compilation_and_rpc.py   +1  −1
python/tvm/contrib/rpc.py
@@ -15,6 +15,8 @@ import socket
 import struct
 import logging
 import multiprocessing
+import subprocess
+import time
 from . import util, cc, tar
 from ..module import load as _load_module
 from .._ffi.function import _init_api, register_func
@@ -117,6 +119,17 @@ def _connect_proxy_loop(addr, key):
     process.join()
 
+def _popen(cmd):
+    proc = subprocess.Popen(cmd,
+                            stdout=subprocess.PIPE,
+                            stderr=subprocess.STDOUT,
+                            env=os.environ)
+    (out, _) = proc.communicate()
+    if proc.returncode != 0:
+        msg = "Server invoke error:\n"
+        msg += out
+        raise RuntimeError(msg)
+
 class Server(object):
     """Start RPC server on a seperate process.
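For context, a rough sketch of how the new helper behaves (not part of the commit; `_popen` is module-private, and in this Python 2-era code `out` is a str): it runs a command, captures combined stdout/stderr, stays quiet on success, and surfaces the output in the exception on failure.

    # Hypothetical usage of the private _popen helper added above.
    from tvm.contrib import rpc

    rpc._popen(["python", "-c", "print('ok')"])      # succeeds quietly
    try:
        rpc._popen(["python", "-c", "import sys; sys.exit(1)"])
    except RuntimeError as err:
        print(err)                                   # "Server invoke error: ..."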
@@ -140,15 +153,36 @@ class Server(object):
         If this is true, the host and port actually corresponds to the
         address of the proxy server.
 
+    use_popen : bool, optional
+        Whether to use Popen to start a fresh new process instead of fork.
+        This is recommended to switch on if we want to do local RPC demonstration
+        for GPU devices to avoid fork safety issues.
+
     key : str, optional
         The key used to identify the server in Proxy connection.
     """
-    def __init__(self, host, port=9091, port_end=9199, is_proxy=False, key=""):
+    def __init__(self, host, port=9091, port_end=9199,
+                 is_proxy=False, use_popen=False, key=""):
         self.host = host
         self.port = port
         self.libs = []
 
-        if not is_proxy:
+        if use_popen:
+            cmd = ["python",
+                   "-m", "tvm.exec.rpc_server",
+                   "--host=%s" % host,
+                   "--port=%s" % port]
+            self.proc = multiprocessing.Process(
+                target=subprocess.check_call, args=(cmd,))
+            self.proc.deamon = True
+            self.proc.start()
+            time.sleep(1)
+        elif not is_proxy:
             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             self.port = None
             for my_port in range(port, port_end):
@@ -168,11 +202,15 @@ class Server(object):
             self.sock = sock
             self.proc = multiprocessing.Process(target=_listen_loop, args=(self.sock,))
+            self.proc.deamon = True
+            self.proc.start()
         else:
             self.proc = multiprocessing.Process(
                 target=_connect_proxy_loop, args=((host, port), key))
-        self.proc.deamon = True
-        self.proc.start()
+            self.proc.deamon = True
+            self.proc.start()
 
     def terminate(self):
         """Terminate the server process"""
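Taken together, the new flag lets a local demo start the RPC server in a fresh interpreter instead of a forked child, which matters once a GPU driver context exists in the parent process: forking after CUDA/OpenCL initialization can leave the child with corrupt driver state. A minimal usage sketch, assuming this era's `rpc.connect` client API (not part of the commit):

    from tvm.contrib import rpc

    # Start the server via Popen in a clean process rather than fork.
    server = rpc.Server(host='0.0.0.0', port=9090, use_popen=True)
    remote = rpc.connect('0.0.0.0', 9090)   # ordinary client session
    # ... upload modules, run kernels ...
    server.terminate()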
src/codegen/codegen_cuda.cc
@@ -66,7 +66,11 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) const {  // NOLINT(*)
     }
   } else if (t.is_uint() || t.is_int()) {
     if (t.is_uint()) {
-      os << 'u';
+      if (t.lanes() != 1) {
+        os << "u";
+      } else {
+        os << "unsigned ";
+      }
     }
     if (t.bits() == 8 && t.lanes() == 4) {
       // directly 4 8 bit int in integer.
@@ -77,16 +81,16 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) const {  // NOLINT(*)
       case 16: os << "short"; break;
       case 32: os << "int"; break;
       case 64: {
-        CHECK(sizeof(long) == 8)  // NOLINT(*)
-            << "CUDA not support int64 int in 32 bit system";
-        os << "long"; break;
+        if (lanes != 1 && sizeof(long) == 64) {  // NOLINT(*)
+          os << "long"; break;
+        } else {
+          os << "int64_t"; break;
+        }
       }
       case 1: os << "int"; break;
       default: fail = true; break;
     }
-    if (!fail && lanes == 1) return;
+    if (!fail && lanes == 1) { return; }
     if (!fail && (lanes >= 2 && lanes <= 4)) {
       os << lanes; return;
     }
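The net effect of the two hunks on integer type names, restated as a hypothetical Python helper (not TVM source; the helper name, and the handling of bit widths not shown in the hunks, are assumptions for illustration):

    # Hypothetical restatement of what PrintType now emits for integer types.
    def cuda_int_name(bits, lanes, unsigned):
        # 'u'/'unsigned' prefix: vector types use the short "u" (e.g. uint4),
        # scalars use the spelled-out "unsigned " keyword.
        prefix = ("u" if lanes != 1 else "unsigned ") if unsigned else ""
        if bits == 64:
            # Vectors keep "long" so CUDA's built-in longN vector types apply;
            # scalars use the exact-width int64_t, newly allowed by this commit.
            base = "long" if lanes != 1 else "int64_t"
        else:
            base = {1: "int", 16: "short", 32: "int"}[bits]
        suffix = str(lanes) if 2 <= lanes <= 4 else ""
        return prefix + base + suffix

    print(cuda_int_name(64, 1, False))   # int64_t
    print(cuda_int_name(32, 4, True))    # uint4
    print(cuda_int_name(32, 1, True))    # unsigned int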
tests/python/integration/test_ewise.py
@@ -80,53 +80,58 @@ def test_popcount_llvm():
         b.asnumpy(),
         list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
 
 def test_add():
-    # graph
-    n = tvm.var('n')
-    A = tvm.placeholder((n,), name='A')
-    B = tvm.placeholder((n,), name='B')
-    bias = tvm.var("bias", dtype="float32")
-    scale = tvm.var("scale", dtype="float32")
-    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
-    # schedule
-    s = tvm.create_schedule(C.op)
-    # create iter var and assign them tags.
-    num_thread = 32
-    bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
-    tx, x = s[C].split(x, nparts=num_thread)
-    _, x = s[C].split(x, factor=4)
-    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
-    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
-    s[C].vectorize(x)
+    def run(dtype):
+        # graph
+        n = tvm.var('n')
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        B = tvm.placeholder((n,), name='B', dtype=dtype)
+        bias = tvm.var("bias", dtype=dtype)
+        scale = tvm.var("scale", dtype=dtype)
+        C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+        # schedule
+        s = tvm.create_schedule(C.op)
+        # create iter var and assign them tags.
+        num_thread = 16
+        bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
+        tx, x = s[C].split(x, nparts=num_thread)
+        _, x = s[C].split(x, factor=4)
+        s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+        s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[C].vectorize(x)
 
-    # one line to build the function.
-    def check_device(device):
-        if not tvm.module.enabled(device):
-            print("skip because %s is not enabled.." % device)
-            return
-        fadd = tvm.build(s, [A, B, C, bias, scale], device, name="myadd")
-        ctx = tvm.context(device, 0)
-        # launch the kernel.
-        n = 1024
-        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-        vbias = np.random.uniform()
-        vscale = np.random.uniform()
-        ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=10)
-        tcost = ftimer(a, b, c, vbias, vscale).mean
-        np.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)
-    check_device("opencl")
-    check_device("metal")
-    check_device("cuda")
+        # one line to build the function.
+        def check_device(device):
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled.." % device)
+                return
+            fadd = tvm.build(s, [A, B, C], device, name="myadd")
+            print(fadd.imported_modules[0].get_source())
+            ctx = tvm.context(device, 0)
+            # launch the kernel.
+            n = 1024
+            a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx)
+            b = tvm.nd.array((np.random.uniform(size=n) * 256).astype(B.dtype), ctx)
+            c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+            ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1)
+            tcost = ftimer(a, b, c).mean
+            np.testing.assert_allclose(
+                c.asnumpy(), a.asnumpy() + b.asnumpy(), rtol=1e-6)
+        check_device("opencl")
+        check_device("metal")
+        check_device("cuda")
+    run("float32")
+    run("int32")
+    run("int64")
+    run("uint64")
 
 if __name__ == "__main__":
+    test_add()
     test_log_pow_llvm()
     test_popcount_llvm()
     test_exp()
-    test_add()
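One detail worth calling out: the rewritten test scales its random inputs by 256 before casting. np.random.uniform() draws from [0, 1), so a straight cast to an integer dtype would zero every element, and the int64/uint64 paths this commit enables would be tested vacuously. A quick illustration (not part of the commit):

    import numpy as np

    x = np.random.uniform(size=4)
    print(x.astype("int64"))          # [0 0 0 0] -- values in [0, 1) truncate to zero
    print((x * 256).astype("int64"))  # nonzero values, so c = a + b is a real check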
tests/scripts/task_python_docs.sh
@@ -11,7 +11,7 @@ mv out docs/_build/html/jsdoc || exit -1
 rm -rf python/tvm/*.pyc python/tvm/*/*.pyc
 
 cd docs
-PYTHONPATH=../python make html || exit -1
+PYTHONPATH=`pwd`/../python make html || exit -1
 cd _build/html
 tar czf docs.tgz *
 mv docs.tgz ../../../
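Switching from ../python to `pwd`/../python pins the PYTHONPATH entry to an absolute path, so imports keep resolving even if something executed during the docs build (such as a tutorial script) changes the working directory. A sketch of the failure mode (illustrative only, not from the commit):

    import os, sys

    sys.path.insert(0, "../python")                   # relative: resolved at import time
    sys.path.insert(0, os.path.abspath("../python"))  # absolute: pinned to the checkout
    os.chdir("/tmp")                                  # e.g. a tutorial script chdirs
    # An "import tvm" issued now searches /tmp/../python via the relative entry,
    # while the absolute entry still points at the original checkout.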
tutorials/deployment/cross_compilation_and_rpc.py
@@ -101,7 +101,7 @@ from tvm.contrib import rpc, util
 # same machine, for demonstration. This line can be omitted if we
 # started an remote server.
 #
-server = rpc.Server(host='0.0.0.0', port=9090)
+server = rpc.Server(host='0.0.0.0', port=9090, use_popen=True)
 
 ######################################################################
 # Declare and Cross Compile Kernel on Local Machine