Commit 146ebc5e
Authored Aug 02, 2018 by Tatsuya Nishiyama; committed by Tianqi Chen, Aug 01, 2018
[TVM][CUDA] NVIDIA GPU Int8 Support (#1503)

Parent: 217792ec
Showing 4 changed files with 81 additions and 10 deletions:

    src/codegen/codegen_cuda.cc                  +15  -5
    src/codegen/codegen_cuda.h                    +3  -1
    src/codegen/opt/build_cuda_on.cc              +0  -3
    tests/python/unittest/test_codegen_cuda.py   +63  -1
src/codegen/codegen_cuda.cc
@@ -34,6 +34,10 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <cuda_fp16.h>\n";
   }
+
+  if (enable_int8_) {
+    decl_stream << "#include <sm_61_intrinsics.h>\n";
+  }
   return CodeGenC::Finish();
 }
@@ -81,13 +85,19 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
       os << "unsigned ";
     }
   }
-  if (t.bits() == 8 && t.lanes() == 4) {
-    // directly 4 8 bit int in integer.
-    os << "int"; return;
-  }
   switch (t.bits()) {
     case 8: {
-      if (!t.is_uint() && t.lanes() == 1) {
+      if (t.lanes() == 4) {
+        // directly 4 8 bit int in integer.
+        enable_int8_ = true;
+        os << "char4"; return;
+      } else if (t.lanes() == 8) {
+        enable_int8_ = true;
+        os << "int2"; return;
+      } else if (t.lanes() == 16) {
+        enable_int8_ = true;
+        os << "int4"; return;
+      } else if (!t.is_uint() && t.lanes() == 1) {
         os << "signed char"; break;
       } else {
         os << "char"; break;
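For orientation (not part of the diff itself): the PrintType change above is what lets vectorized int8 tensors reach the CUDA backend. An int8x4 value now prints as CUDA's char4 vector type (int8x8 as int2, int8x16 as int4), and doing so sets enable_int8_ so that Finish() emits the sm_61_intrinsics.h include. A minimal sketch of how this surfaces through the Python API of this revision, assuming a CUDA-enabled build; the variable names are illustrative:

    import tvm

    n, lanes = 64, 4
    # Vector dtype on the placeholder, the same pattern the new tests use.
    A = tvm.placeholder((n,), name='A', dtype="int8x%d" % lanes)
    B = tvm.compute((n,), lambda i: A[i], name='B')
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=8)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    fun = tvm.build(s, [A, B], "cuda")
    # The generated CUDA source should now use char4 for the int8x4
    # loads/stores instead of the previous plain "int" mapping.
    print(fun.imported_modules[0].get_source())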
src/codegen/codegen_cuda.h
@@ -20,7 +20,7 @@ class CodeGenCUDA final : public CodeGenC {
   void Init(bool output_ssa);
   void AddFunction(LoweredFunc f);
   std::string Finish();
-  bool need_include_path() { return enable_fp16_; }
+  bool need_include_path() { return (enable_fp16_ || enable_int8_); }
   // override behavior
   void VisitStmt_(const ir::For* op) final;
   void PrintStorageSync(const Call* op) final;

@@ -49,6 +49,8 @@ class CodeGenCUDA final : public CodeGenC {
   std::string vid_global_barrier_expect_;
   // whether enable fp16
   bool enable_fp16_{false};
+  // whether enable int8
+  bool enable_int8_{false};
 };
 }  // namespace codegen
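A side effect worth noting: need_include_path() now also returns true when int8 vector types were emitted, because the generated source then includes sm_61_intrinsics.h and NVRTC must be pointed at the CUDA include directory to resolve it. On the Python side, the capability gate the new tests rely on looks like the sketch below (illustrative only; have_int8 lives in tvm.contrib.nvcc and is expected to report true for GPUs whose compute capability provides the sm_61 int8 intrinsics, i.e. 6.1 or newer):

    import tvm
    from tvm.contrib.nvcc import have_int8

    # Guard pattern from the new tests: run int8/__dp4a kernels only when a
    # CUDA GPU is present and its compute capability has the int8 intrinsics.
    if tvm.gpu(0).exist and tvm.module.enabled("cuda"):
        if have_int8(tvm.gpu(0).compute_version):
            print("int8 / __dp4a supported on this GPU")
        else:
            print("skip because gpu does not support int8")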
src/codegen/opt/build_cuda_on.cc
@@ -64,7 +64,6 @@ std::string FindCUDAIncludePath() {
 std::string NVRTCCompile(const std::string& code, bool include_path = false) {
   std::vector<std::string> compile_params;
   std::vector<const char*> param_cstrings{};
-  int num_options = 0;
   nvrtcProgram prog;
   cudaDeviceProp device_prop;
   std::string cc = "30";

@@ -78,13 +77,11 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
   }
   compile_params.push_back("-arch=compute_" + cc);
-  num_options++;
   if (include_path) {
     std::string include_option = "--include-path=" + FindCUDAIncludePath();
     compile_params.push_back(include_option);
-    num_options++;
   }
   for (const auto& string : compile_params) {
tests/python/unittest/test_codegen_cuda.py
 import tvm
 import numpy as np
-from tvm.contrib.nvcc import have_fp16
+from tvm.contrib.nvcc import have_fp16, have_int8
+from tvm.contrib import nvcc

 def test_cuda_vectorize_add():
     num_thread = 8

@@ -11,6 +12,9 @@ def test_cuda_vectorize_add():
         if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
            print("skip because gpu does not support fp16")
            return
+        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+           print("skip because gpu does not support int8")
+           return
         A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
         B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
         s = tvm.create_schedule(B.op)

@@ -27,6 +31,64 @@ def test_cuda_vectorize_add():
     check_cuda("float32", 64, 2)
     check_cuda("float16", 64, 2)
+    check_cuda("int8", 64, 4)
+
+
+def test_cuda_multiply_add():
+    num_thread = 8
+    def check_cuda(dtype, n, lanes):
+        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+            print("skip because gpu does not support int8")
+            return
+        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.placeholder((n,), name='B', dtype="%sx%d" % (dtype, lanes))
+        C = tvm.placeholder((n,), name='C', dtype="int32")
+        D = tvm.compute((n,),
+                        lambda i: tvm.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]),
+                        name='D')
+        s = tvm.create_schedule(D.op)
+        xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
+        s[D].bind(xo, tvm.thread_axis("blockIdx.x"))
+        s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, B, C, D], "cuda")
+        np_a = np.random.randint(low=-128, high=127, size=(n, lanes))
+        np_b = np.random.randint(low=-128, high=127, size=(n, lanes))
+        np_c = np.random.randint(low=0, high=127, size=(n,))
+        np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)]
+        ctx = tvm.gpu(0)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
+        b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(np_b)
+        c = tvm.nd.empty((n,), C.dtype, ctx).copyfrom(np_c)
+        d = tvm.nd.empty((n,), D.dtype, ctx)
+        fun(a, b, c, d)
+        np.testing.assert_allclose(d.asnumpy(), np_d)
+    check_cuda("int8", 64, 4)
+
+
+def test_cuda_vectorize_load():
+    num_thread = 8
+    def check_cuda(dtype, n, lanes):
+        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        ctx = tvm.gpu(0)
+        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.compute((n,), lambda i: A[i], name='B')
+        s = tvm.create_schedule(B.op)
+        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
+        s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+        s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, B], "cuda", name="vector_load")
+        np_a = np.random.randint(low=-128, high=127, size=(n, lanes))
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
+        b = tvm.nd.empty((n,), B.dtype, ctx)
+        fun(a, b)
+        np.testing.assert_allclose(a.asnumpy(), b.asnumpy())
+    check_cuda("int8", 64, 8)
+    check_cuda("int8", 64, 16)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
+    test_cuda_multiply_add()
+    test_cuda_load_store()
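For reference (not part of the commit), the new test_cuda_multiply_add exercises the __dp4a intrinsic, which computes a four-way int8 dot product and accumulates it into a 32-bit integer; that is exactly what the np_d list comprehension in the test models. A small NumPy sketch of the per-element semantics:

    import numpy as np

    def dp4a_ref(a, b, c):
        # a, b: four int8 lanes; c: int32 accumulator.
        return int(np.dot(a.astype(np.int32), b.astype(np.int32))) + int(c)

    a = np.array([1, -2, 3, -4], dtype=np.int8)
    b = np.array([5, 6, -7, 8], dtype=np.int8)
    print(dp4a_ref(a, b, 10))  # 1*5 + (-2)*6 + 3*(-7) + (-4)*8 + 10 == -50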