Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f1d815cc
Unverified
Commit
f1d815cc
authored
Oct 06, 2018
by
Tianqi Chen
Committed by
GitHub
Oct 06, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable bool type as storage type (#1853)
parent
ea07f740
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
139 additions
and
7 deletions
+139
-7
include/tvm/expr.h
+2
-0
include/tvm/runtime/packed_func.h
+11
-1
python/tvm/_ffi/runtime_ctypes.py
+9
-0
src/codegen/codegen_cuda.cc
+2
-0
src/codegen/codegen_metal.cc
+3
-0
src/codegen/codegen_opencl.cc
+3
-0
src/codegen/spirv/ir_builder.cc
+19
-2
src/lang/buffer.cc
+17
-0
src/pass/storage_flatten.cc
+8
-2
src/runtime/builtin_fp16.cc
+4
-1
src/runtime/ndarray.cc
+2
-0
tests/python/unittest/test_codegen_bool.py
+58
-0
tests/python/unittest/test_lang_basic.py
+1
-1
No files found.
include/tvm/expr.h
View file @
f1d815cc
...
...
@@ -56,6 +56,8 @@ inline TVMType Type2TVMType(Type t) {
// Get number of bytes considering vector type.
inline
int
GetVectorBytes
(
Type
dtype
)
{
int
data_bits
=
dtype
.
bits
()
*
dtype
.
lanes
();
// allow bool to exist
if
(
dtype
==
Bool
())
return
1
;
CHECK_EQ
(
data_bits
%
8
,
0U
)
<<
"Need to load/store by multiple of bytes"
;
return
data_bits
/
8
;
...
...
include/tvm/runtime/packed_func.h
View file @
f1d815cc
...
...
@@ -873,6 +873,9 @@ inline const char* TypeCode2Str(int type_code) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
TVMType
t
)
{
// NOLINT(*)
if
(
t
.
bits
==
1
&&
t
.
lanes
==
1
&&
t
.
code
==
kDLUInt
)
{
os
<<
"bool"
;
return
os
;
}
os
<<
TypeCode2Str
(
t
.
code
);
if
(
t
.
code
==
kHandle
)
return
os
;
os
<<
static_cast
<
int
>
(
t
.
bits
);
...
...
@@ -890,7 +893,9 @@ inline std::string TVMType2String(TVMType t) {
os
<<
t
;
return
os
.
str
();
#else
std
::
string
repr
=
""
;
if
(
t
.
bits
==
1
&&
t
.
lanes
==
1
&&
t
.
code
==
kDLUInt
)
{
return
"bool"
;
}
repr
+=
TypeCode2Str
(
t
.
code
);
if
(
t
.
code
==
kHandle
)
return
repr
;
repr
+=
std
::
to_string
(
static_cast
<
int
>
(
t
.
bits
));
...
...
@@ -920,6 +925,11 @@ inline TVMType String2TVMType(std::string s) {
t
.
code
=
kHandle
;
t
.
bits
=
64
;
// handle uses 64 bit by default.
scan
=
s
.
c_str
()
+
6
;
}
else
if
(
s
==
"bool"
)
{
t
.
code
=
kDLUInt
;
t
.
bits
=
1
;
t
.
lanes
=
1
;
return
t
;
}
else
{
scan
=
s
.
c_str
();
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
...
...
python/tvm/_ffi/runtime_ctypes.py
View file @
f1d815cc
...
...
@@ -48,6 +48,13 @@ class TVMType(ctypes.Structure):
super
(
TVMType
,
self
)
.
__init__
()
if
isinstance
(
type_str
,
np
.
dtype
):
type_str
=
str
(
type_str
)
if
type_str
==
"bool"
:
self
.
bits
=
1
self
.
type_code
=
1
self
.
lanes
=
1
return
arr
=
type_str
.
split
(
"x"
)
head
=
arr
[
0
]
self
.
lanes
=
int
(
arr
[
1
])
if
len
(
arr
)
>
1
else
1
...
...
@@ -73,6 +80,8 @@ class TVMType(ctypes.Structure):
def
__repr__
(
self
):
if
self
.
bits
==
1
and
self
.
lanes
==
1
:
return
"bool"
x
=
"
%
s
%
d"
%
(
TVMType
.
CODE2STR
[
self
.
type_code
],
self
.
bits
)
if
self
.
lanes
!=
1
:
x
+=
"x
%
d"
%
self
.
lanes
...
...
src/codegen/codegen_cuda.cc
View file @
f1d815cc
...
...
@@ -77,6 +77,8 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
if
(
!
fail
&&
(
lanes
>=
2
&&
lanes
<=
4
))
{
os
<<
lanes
;
return
;
}
}
else
if
(
t
==
Bool
())
{
os
<<
"bool"
;
return
;
}
else
if
(
t
.
is_uint
()
||
t
.
is_int
())
{
if
(
t
.
is_uint
())
{
if
(
t
.
lanes
()
!=
1
)
{
...
...
src/codegen/codegen_metal.cc
View file @
f1d815cc
...
...
@@ -141,6 +141,9 @@ void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*)
<<
"do not yet support vector types"
;
os
<<
"void*"
;
return
;
}
if
(
t
==
Bool
())
{
os
<<
"bool"
;
return
;
}
bool
fail
=
false
;
if
(
t
.
is_float
())
{
switch
(
t
.
bits
())
{
...
...
src/codegen/codegen_opencl.cc
View file @
f1d815cc
...
...
@@ -80,6 +80,9 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*)
<<
"do not yet support vector types"
;
os
<<
"void*"
;
return
;
}
if
(
t
==
Bool
())
{
os
<<
"bool"
;
return
;
}
bool
fail
=
false
;
if
(
t
.
is_float
())
{
switch
(
t
.
bits
())
{
...
...
src/codegen/spirv/ir_builder.cc
View file @
f1d815cc
...
...
@@ -438,8 +438,25 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
const
tvm
::
Type
&
from
=
value
.
stype
.
type
;
const
tvm
::
Type
&
to
=
dst_type
.
type
;
CHECK_EQ
(
from
.
lanes
(),
to
.
lanes
());
if
(
from
.
is_int
()
&&
to
.
is_int
())
{
if
(
from
==
Bool
())
{
if
(
to
.
is_int
())
{
return
Select
(
value
,
IntImm
(
dst_type
,
1
),
IntImm
(
dst_type
,
0
));
}
else
if
(
to
.
is_uint
())
{
return
Select
(
value
,
UIntImm
(
dst_type
,
1
),
UIntImm
(
dst_type
,
0
));
}
else
{
LOG
(
FATAL
)
<<
"cannot cast from "
<<
from
<<
" to "
<<
to
;
return
Value
();
}
}
else
if
(
to
==
Bool
())
{
if
(
from
.
is_int
())
{
return
NE
(
value
,
IntImm
(
value
.
stype
,
0
));
}
else
if
(
to
.
is_uint
())
{
return
NE
(
value
,
UIntImm
(
value
.
stype
,
0
));
}
else
{
LOG
(
FATAL
)
<<
"cannot cast from "
<<
from
<<
" to "
<<
to
;
return
Value
();
}
}
else
if
(
from
.
is_int
()
&&
to
.
is_int
())
{
return
MakeValue
(
spv
::
OpSConvert
,
dst_type
,
value
);
}
else
if
(
from
.
is_uint
()
&&
to
.
is_uint
())
{
return
MakeValue
(
spv
::
OpUConvert
,
dst_type
,
value
);
...
...
src/lang/buffer.cc
View file @
f1d815cc
...
...
@@ -260,25 +260,42 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
}
Expr
Buffer
::
vload
(
Array
<
Expr
>
begin
,
Type
dtype
)
const
{
// specially handle bool, stored as Int(8)
const
BufferNode
*
n
=
operator
->
();
CHECK
(
dtype
.
element_of
()
==
n
->
dtype
.
element_of
()
&&
dtype
.
lanes
()
%
n
->
dtype
.
lanes
()
==
0
)
<<
"Cannot load "
<<
dtype
<<
" from buffer of "
<<
n
->
dtype
;
if
(
dtype
==
Bool
())
{
return
ir
::
Cast
::
make
(
Bool
(),
ir
::
Load
::
make
(
Int
(
8
),
n
->
data
,
BufferOffset
(
n
,
begin
,
Int
(
8
)),
const_true
()));
}
else
{
return
ir
::
Load
::
make
(
dtype
,
n
->
data
,
BufferOffset
(
n
,
begin
,
dtype
),
const_true
(
dtype
.
lanes
()));
}
}
Stmt
Buffer
::
vstore
(
Array
<
Expr
>
begin
,
Expr
value
)
const
{
// specially handle bool, stored as Int(8)
const
BufferNode
*
n
=
operator
->
();
Type
dtype
=
value
.
type
();
CHECK
(
dtype
.
element_of
()
==
n
->
dtype
.
element_of
()
&&
dtype
.
lanes
()
%
n
->
dtype
.
lanes
()
==
0
)
<<
"Cannot load "
<<
dtype
<<
" from buffer of "
<<
n
->
dtype
;
if
(
value
.
type
()
==
Bool
())
{
return
ir
::
Store
::
make
(
n
->
data
,
ir
::
Cast
::
make
(
Int
(
8
),
value
),
BufferOffset
(
n
,
begin
,
Int
(
8
)),
const_true
());
}
else
{
return
ir
::
Store
::
make
(
n
->
data
,
value
,
BufferOffset
(
n
,
begin
,
dtype
),
const_true
(
dtype
.
lanes
()));
}
}
Buffer
Buffer
::
MakeStrideView
()
const
{
...
...
src/pass/storage_flatten.cc
View file @
f1d815cc
...
...
@@ -191,10 +191,16 @@ class StorageFlattener : public IRMutator {
buf_map_
[
key
].
released
=
true
;
Stmt
ret
;
Type
storage_type
=
e
.
buffer
->
dtype
;
// specially handle bool, lower its storage
// type to be Int(8)(byte)
if
(
storage_type
==
Bool
())
{
storage_type
=
Int
(
8
);
}
if
(
strides
.
size
()
!=
0
)
{
int
first_dim
=
0
;
ret
=
Allocate
::
make
(
e
.
buffer
->
data
,
e
.
buffer
->
d
type
,
e
.
buffer
->
data
,
storage_
type
,
{
arith
::
ComputeExpr
<
Mul
>
(
e
.
buffer
->
strides
[
first_dim
],
e
.
buffer
->
shape
[
first_dim
])},
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
body
);
}
else
{
...
...
@@ -203,7 +209,7 @@ class StorageFlattener : public IRMutator {
shape
.
push_back
(
make_const
(
Int
(
32
),
1
));
}
ret
=
Allocate
::
make
(
e
.
buffer
->
data
,
e
.
buffer
->
d
type
,
shape
,
e
.
buffer
->
data
,
storage_
type
,
shape
,
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
body
);
}
ret
=
AttrStmt
::
make
(
...
...
src/runtime/builtin_fp16.cc
View file @
f1d815cc
...
...
@@ -3,12 +3,14 @@
* \file builtin_fp16.cc
* \brief Functions for conversion between fp32 and fp16
*/
#include <builtin_fp16.h>
#include <tvm/runtime/c_runtime_api.h>
extern
"C"
{
// disable under msvc
#ifndef _MSC_VER
TVM_WEAK
uint16_t
__gnu_f2h_ieee
(
float
a
)
{
return
__truncXfYf2__
<
float
,
uint32_t
,
23
,
uint16_t
,
uint16_t
,
10
>
(
a
);
}
...
...
@@ -17,4 +19,5 @@ TVM_WEAK float __gnu_h2f_ieee(uint16_t a) {
return
__extendXfYf2__
<
uint16_t
,
uint16_t
,
10
,
float
,
uint32_t
,
23
>
(
a
);
}
#endif
}
src/runtime/ndarray.cc
View file @
f1d815cc
...
...
@@ -20,6 +20,8 @@ inline void VerifyDataType(DLDataType dtype) {
if
(
dtype
.
code
==
kDLFloat
)
{
CHECK_EQ
(
dtype
.
bits
%
8
,
0
);
}
else
{
// allow uint1 as a special flag for bool.
if
(
dtype
.
bits
==
1
&&
dtype
.
code
==
kDLUInt
)
return
;
CHECK_EQ
(
dtype
.
bits
%
8
,
0
);
}
CHECK_EQ
(
dtype
.
bits
&
(
dtype
.
bits
-
1
),
0
);
...
...
tests/python/unittest/test_codegen_bool.py
0 → 100644
View file @
f1d815cc
"""codegen related to bool types"""
import
tvm
import
numpy
as
np
def
test_cmp_load_store
():
n
=
32
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
C
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
A
(
*
i
)
>
B
(
*
i
),
name
=
'C'
)
D
=
tvm
.
compute
(
C
.
shape
,
lambda
*
i
:
tvm
.
all
(
C
(
*
i
),
A
(
*
i
)
>
1
),
name
=
"D"
)
def
check_llvm
():
if
not
tvm
.
module
.
enabled
(
"llvm"
):
return
s
=
tvm
.
create_schedule
(
D
.
op
)
xo
,
xi
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
0
],
factor
=
4
)
xo1
,
xo2
=
s
[
C
]
.
split
(
xo
,
factor
=
13
)
s
[
C
]
.
parallel
(
xo2
)
# BUILD and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
A
,
B
,
D
],
"llvm"
)
ctx
=
tvm
.
cpu
(
0
)
a_np
=
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
B
.
dtype
),
ctx
)
d
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
D
.
dtype
),
ctx
)
f
(
a
,
b
,
d
)
np
.
testing
.
assert_equal
(
d
.
asnumpy
(),
np
.
logical_and
(
a
.
asnumpy
()
>
b
.
asnumpy
(),
a
.
asnumpy
()
>
1
))
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
return
s
=
tvm
.
create_schedule
(
D
.
op
)
for
stage
in
[
C
,
D
]:
xo
,
xi
=
s
[
stage
]
.
split
(
stage
.
op
.
axis
[
0
],
factor
=
4
)
s
[
stage
]
.
bind
(
xo
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
stage
]
.
bind
(
xi
,
tvm
.
thread_axis
(
"threadIdx.x"
))
f
=
tvm
.
build
(
s
,
[
A
,
B
,
D
],
device
)
a_np
=
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
B
.
dtype
),
ctx
)
d
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
D
.
dtype
),
ctx
)
f
(
a
,
b
,
d
)
np
.
testing
.
assert_equal
(
d
.
asnumpy
(),
np
.
logical_and
(
a
.
asnumpy
()
>
b
.
asnumpy
(),
a
.
asnumpy
()
>
1
))
check_llvm
()
for
device
in
[
"vulkan"
,
"opencl"
,
"cuda"
,
"rocm"
,
"metal"
]:
check_device
(
device
)
if
__name__
==
"__main__"
:
test_cmp_load_store
()
tests/python/unittest/test_lang_basic.py
View file @
f1d815cc
...
...
@@ -79,7 +79,7 @@ def test_dtype():
x
=
tvm
.
var
(
'x'
)
assert
x
.
dtype
==
'int32'
y
=
tvm
.
var
(
'y'
)
assert
(
x
>
y
)
.
dtype
==
'
uint1
'
assert
(
x
>
y
)
.
dtype
==
'
bool
'
def
test_any
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment