Unverified Commit f1d815cc by Tianqi Chen Committed by GitHub

Enable bool type as storage type (#1853)

parent ea07f740
......@@ -56,6 +56,8 @@ inline TVMType Type2TVMType(Type t) {
// Get number of bytes considering vector type.
inline int GetVectorBytes(Type dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
......
......@@ -873,6 +873,9 @@ inline const char* TypeCode2Str(int type_code) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
os << TypeCode2Str(t.code);
if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
......@@ -890,7 +893,9 @@ inline std::string TVMType2String(TVMType t) {
os << t;
return os.str();
#else
std::string repr = "";
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
return "bool";
}
repr += TypeCode2Str(t.code);
if (t.code == kHandle) return repr;
repr += std::to_string(static_cast<int>(t.bits));
......@@ -920,6 +925,11 @@ inline TVMType String2TVMType(std::string s) {
t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
......
......@@ -48,6 +48,13 @@ class TVMType(ctypes.Structure):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str == "bool":
self.bits = 1
self.type_code = 1
self.lanes = 1
return
arr = type_str.split("x")
head = arr[0]
self.lanes = int(arr[1]) if len(arr) > 1 else 1
......@@ -73,6 +80,8 @@ class TVMType(ctypes.Structure):
def __repr__(self):
if self.bits == 1 and self.lanes == 1:
return "bool"
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
......
......@@ -77,6 +77,8 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
} else if (t == Bool()) {
os << "bool"; return;
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
if (t.lanes() != 1) {
......
......@@ -141,6 +141,9 @@ void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*)
<< "do not yet support vector types";
os << "void*"; return;
}
if (t == Bool()) {
os << "bool"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
......
......@@ -80,6 +80,9 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*)
<< "do not yet support vector types";
os << "void*"; return;
}
if (t == Bool()) {
os << "bool"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
......
......@@ -438,8 +438,25 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
const tvm::Type& from = value.stype.type;
const tvm::Type& to = dst_type.type;
CHECK_EQ(from.lanes(), to.lanes());
if (from.is_int() && to.is_int()) {
if (from == Bool()) {
if (to.is_int()) {
return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0));
} else if (to.is_uint()) {
return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0));
} else {
LOG(FATAL) << "cannot cast from " << from << " to " << to;
return Value();
}
} else if (to == Bool()) {
if (from.is_int()) {
return NE(value, IntImm(value.stype, 0));
} else if (to.is_uint()) {
return NE(value, UIntImm(value.stype, 0));
} else {
LOG(FATAL) << "cannot cast from " << from << " to " << to;
return Value();
}
} else if (from.is_int() && to.is_int()) {
return MakeValue(spv::OpSConvert, dst_type, value);
} else if (from.is_uint() && to.is_uint()) {
return MakeValue(spv::OpUConvert, dst_type, value);
......
......@@ -260,25 +260,42 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
}
Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
// specially handle bool, stored as Int(8)
const BufferNode* n = operator->();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
if (dtype == Bool()) {
return ir::Cast::make(
Bool(),
ir::Load::make(
Int(8), n->data, BufferOffset(n, begin, Int(8)),
const_true()));
} else {
return ir::Load::make(
dtype, n->data, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
}
Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
// specially handle bool, stored as Int(8)
const BufferNode* n = operator->();
Type dtype = value.type();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
if (value.type() == Bool()) {
return ir::Store::make(n->data,
ir::Cast::make(Int(8), value),
BufferOffset(n, begin, Int(8)),
const_true());
} else {
return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
}
Buffer Buffer::MakeStrideView() const {
......
......@@ -191,10 +191,16 @@ class StorageFlattener : public IRMutator {
buf_map_[key].released = true;
Stmt ret;
Type storage_type = e.buffer->dtype;
// specially handle bool, lower its storage
// type to be Int(8)(byte)
if (storage_type == Bool()) {
storage_type = Int(8);
}
if (strides.size() != 0) {
int first_dim = 0;
ret = Allocate::make(
e.buffer->data, e.buffer->dtype,
e.buffer->data, storage_type,
{arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
make_const(Bool(e.buffer->dtype.lanes()), true), body);
} else {
......@@ -203,7 +209,7 @@ class StorageFlattener : public IRMutator {
shape.push_back(make_const(Int(32), 1));
}
ret = Allocate::make(
e.buffer->data, e.buffer->dtype, shape,
e.buffer->data, storage_type, shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
}
ret = AttrStmt::make(
......
......@@ -3,12 +3,14 @@
* \file builtin_fp16.cc
* \brief Functions for conversion between fp32 and fp16
*/
#include <builtin_fp16.h>
#include <tvm/runtime/c_runtime_api.h>
extern "C" {
// disable under msvc
#ifndef _MSC_VER
TVM_WEAK uint16_t __gnu_f2h_ieee(float a) {
return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(a);
}
......@@ -17,4 +19,5 @@ TVM_WEAK float __gnu_h2f_ieee(uint16_t a) {
return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(a);
}
#endif
}
......@@ -20,6 +20,8 @@ inline void VerifyDataType(DLDataType dtype) {
if (dtype.code == kDLFloat) {
CHECK_EQ(dtype.bits % 8, 0);
} else {
// allow uint1 as a special flag for bool.
if (dtype.bits == 1 && dtype.code == kDLUInt) return;
CHECK_EQ(dtype.bits % 8, 0);
}
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
......
"""codegen related to bool types"""
import tvm
import numpy as np
def test_cmp_load_store():
n = 32
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C')
D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D")
def check_llvm():
if not tvm.module.enabled("llvm"):
return
s = tvm.create_schedule(D.op)
xo, xi = s[C].split(C.op.axis[0], factor=4)
xo1, xo2 = s[C].split(xo, factor=13)
s[C].parallel(xo2)
# BUILD and invoke the kernel.
f = tvm.build(s, [A, B, D], "llvm")
ctx = tvm.cpu(0)
a_np = np.random.uniform(size=n).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
s = tvm.create_schedule(D.op)
for stage in [C, D]:
xo, xi = s[stage].split(stage.op.axis[0], factor=4)
s[stage].bind(xo, tvm.thread_axis("blockIdx.x"))
s[stage].bind(xi, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B, D], device)
a_np = np.random.uniform(size=n).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))
check_llvm()
for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]:
check_device(device)
if __name__ == "__main__":
test_cmp_load_store()
......@@ -79,7 +79,7 @@ def test_dtype():
x = tvm.var('x')
assert x.dtype == 'int32'
y = tvm.var('y')
assert (x > y).dtype == 'uint1'
assert (x > y).dtype == 'bool'
def test_any():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment