Unverified Commit f23ac969 by Orion34C Committed by GitHub

[CODEGEN] Support cuda tensorcore subbyte int data type in auto tensorcore (#4546)

* support cuda tensorcore subbyte int data type in auto tensorcore

* add license

* pass cpplint

* fix code review comments

* merge the int4/int1 codegen tutorial into the existing auto tensorcore tutorial

* using master's new API

* disable tuning when cuda is not enabled

* address cr comment

* do not run the tuning

* fix test failure

* fix cpplint error

* fix bool type reduction bug

* 1. fix an index bug 2. fix the byte size returned for int1/int4/uint4

* fix typo
parent 98e7709f
@@ -230,7 +230,12 @@ class DataType {
inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool()) return 1;
if (dtype == DataType::Bool() ||
dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1)) {
return 1;
}
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
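For orientation, a minimal Python sketch of the rule this hunk implements (the helper name and flags are illustrative, not TVM API): scalar int4/uint4/int1, like bool, report a one-byte footprint, while every other type must still pack into whole bytes.

```python
def get_vector_bytes(bits: int, lanes: int, signed: bool = True, is_bool: bool = False) -> int:
    """Illustrative restatement of GetVectorBytes with the sub-byte special cases."""
    data_bits = bits * lanes
    # bool and the scalar sub-byte types int4/uint4/int1 are reported as one byte
    if is_bool or (lanes == 1 and (bits == 4 or (bits == 1 and signed))):
        return 1
    assert data_bits % 8 == 0, "Need to load/store by multiple of bytes"
    return data_bits // 8

assert get_vector_bytes(4, 1) == 1    # int4 / uint4
assert get_vector_bytes(1, 1) == 1    # int1
assert get_vector_bytes(16, 4) == 8   # e.g. float16x4
```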
@@ -1261,6 +1261,18 @@ constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
*/
constexpr const char* tvm_mma_sync = "tvm_mma_sync";
/*!
* \brief tvm intrinsic for tensor core bmma_sync operators.
*
* void tvm_bmma_sync(Var fragment_d, Expr index_d,
* Var fragment_a, Expr index_a,
* Var fragment_b, Expr index_b,
* Var fragment_c, Expr index_c) {
* nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
* fragment_b[index_b], fragment_c[index_c]);
* }
*/
constexpr const char* tvm_bmma_sync = "tvm_bmma_sync";
/*!
* \brief tvm intrinsic for tensor core fill_fragment operators.
*
* void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
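tvm_bmma_sync is the single-bit counterpart of tvm_mma_sync and lowers to `nvcuda::wmma::bmma_sync`, whose experimental b1 fragments hold 32 one-bit elements per 32-bit word. Below is a hedged numpy illustration of the product usually documented for this operation (XOR followed by population count); the helper and its unpacked 0/1 layout are assumptions made for clarity, not the generated code.

```python
import numpy as np

def bmma_reference(a_bits, b_bits, c):
    """d[i, j] = c[i, j] + popcount(a_bits[i] XOR b_bits[j]).

    a_bits: (m, k) 0/1 matrix, b_bits: (n, k) 0/1 matrix ("TN" layout),
    c: (m, n) int32 accumulator.
    """
    d = np.asarray(c, dtype=np.int32).copy()
    for i in range(a_bits.shape[0]):
        for j in range(b_bits.shape[0]):
            d[i, j] += int(np.sum(a_bits[i] ^ b_bits[j]))
    return d
```

An XOR-based product only agrees with a plain 0/1 dot product in special cases, such as the all-zero int1 inputs the tutorial below happens to generate.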
@@ -44,7 +44,12 @@ inline void VerifyDataType(DLDataType dtype) {
} else {
// allow uint1 as a special flag for bool.
if (dtype.bits == 1 && dtype.code == kDLUInt) return;
CHECK_EQ(dtype.bits % 8, 0);
// allow int1/uint4/int4
else if (dtype.bits == 1 && dtype.code == kDLInt) return;
else if (dtype.bits == 4 && dtype.code == kDLUInt) return;
else if (dtype.bits == 4 && dtype.code == kDLInt) return;
else
CHECK_EQ(dtype.bits % 8, 0);
}
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}
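The relaxed check now admits int1, int4 and uint4 alongside the existing uint1 (bool) escape hatch; every other dtype must still have a byte-aligned, power-of-two bit width. A compact restatement (hypothetical helper, not the runtime API):

```python
def bits_are_valid(code: str, bits: int) -> bool:
    """code is a simplified stand-in for DLDataTypeCode: 'int', 'uint', 'float', ..."""
    # special flags: uint1 (bool) plus the new sub-byte integer types
    if (code, bits) in {("uint", 1), ("int", 1), ("int", 4), ("uint", 4)}:
        return True
    # otherwise the width must be whole bytes and a power of two
    return bits % 8 == 0 and (bits & (bits - 1)) == 0
```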
@@ -170,8 +170,13 @@ std::string CodeGenC::GetBufferRef(
} else {
os << vid;
}
os << '[';
os << "[(";
PrintExpr(index, os);
os << ")";
if (t.bits() == 4 ||
(t.bits() == 1 && t.is_int())) {
os << " / " << (32 / t.bits());
}
os << ']';
} else {
// Buffer declared as vector type.
@@ -205,8 +210,13 @@ std::string CodeGenC::GetBufferRef(
PrintType(t.element_of(), os);
os << "*)";
}
os << vid << " + ";
os << vid << " + (";
PrintExpr(index, os);
os << ")";
if (t.bits() == 4 ||
(t.bits() == 1 && t.is_int())) {
os << " / " << (32 / t.bits());
}
os << "))[0]";
}
return os.str();
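The added `/ (32 / t.bits())` term converts a logical element index into the index of the 32-bit word that stores it, since sub-byte elements are kept packed inside `int` buffers. A small illustrative sketch (plain Python, names are not TVM API):

```python
def word_index(elem_index: int, bits: int) -> int:
    """Map a logical element index to the 32-bit word holding that element."""
    elems_per_word = 32 // bits     # 8 for int4/uint4, 32 for int1
    return elem_index // elems_per_word

assert word_index(7, 4) == 0    # elements 0..7 share one int32 word
assert word_index(8, 4) == 1
assert word_index(31, 1) == 0   # 32 one-bit elements per word
assert word_index(32, 1) == 1
```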
@@ -144,6 +144,37 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
}
switch (t.bits()) {
case 1: {
if (t.lanes() == 1) {
os << "int"; return;
} else if (t.lanes() == 8) {
os << "int8_t"; return;
} else if (t.lanes() == 16) {
os << "int16_t"; return;
} else if (t.lanes() == 32) {
os << "int"; return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 4: {
if (t.lanes() == 1) {
os << "int"; return;
} else if (t.lanes() == 4) {
os << "int16_t"; return;
} else if (t.lanes() == 8) {
// directly 8 4-bit int in integer.
os << "int"; return;
} else if (t.lanes() == 16) {
os << "int2"; return;
} else if (t.lanes() == 32) {
os << "int4"; return;
} else if (t.lanes() == 64) {
os << "int8"; return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 8: {
if (t.lanes() == 4) {
// directly 4 8 bit int in integer.
@@ -182,7 +213,6 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
os << "long"; break;
}
}
case 1: os << "int"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) {
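The new `case 1` and `case 4` branches choose a 32-bit-word-based storage type so a vector of sub-byte elements always occupies whole machine words. The mapping can be summarized as follows (an illustrative table written as a Python dict, not part of the codegen):

```python
# (bits, lanes) -> CUDA storage type emitted for sub-byte integer vectors
SUBBYTE_STORAGE = {
    (1, 1): "int", (1, 8): "int8_t", (1, 16): "int16_t", (1, 32): "int",
    (4, 1): "int", (4, 4): "int16_t", (4, 8): "int",
    (4, 16): "int2", (4, 32): "int4", (4, 64): "int8",
}
# e.g. int4x8 (32 bits of payload) is stored in a single 32-bit int,
# while int4x32 (128 bits) becomes CUDA's built-in int4 struct.
```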
@@ -371,6 +401,16 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", ": ")");
}
} else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::bmma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", ": ")");
}
} else {
CodeGenC::VisitExpr_(op, os);
}
@@ -410,8 +450,12 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8))
<< "Matrix_a and matrix_b only support half or char or unsigned char type for now";
op->dtype == DataType::UInt(8) ||
op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Float(32) ||
@@ -425,6 +469,11 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
stream << ' ';
PrintType(op->dtype, stream);
}
if ((op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) && scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
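Because shared-memory buffers of int4/uint4/int1 are declared with a 32-bit storage type (see PrintType above), the emitted array length is the element count divided by the packing factor. A minimal sketch of that size computation (hypothetical helper):

```python
def shared_mem_words(num_elems: int, bits: int) -> int:
    """Number of 32-bit words declared for a packed sub-byte shared buffer."""
    elems_per_word = 32 // bits
    assert num_elems % elems_per_word == 0, "allocation should fill whole words"
    return num_elems // elems_per_word

assert shared_mem_words(4096, 4) == 512   # int4: 8 elements per word
assert shared_mem_words(4096, 1) == 128   # int1: 32 elements per word
```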
@@ -552,6 +601,24 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes[variable];
if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
type.str(std::string());
if (t.is_int()) {
if (t.bits() == 4) {
type << "nvcuda::wmma::experimental::precision::s4";
} else if (t.bits() == 1) {
type << "nvcuda::wmma::experimental::precision::b1";
} else {
LOG(FATAL) << "Unhandled interger type for wmma fragment!";
}
} else if (t.is_uint()) {
if (t.bits() == 4) {
type << "nvcuda::wmma::experimental::precision::u4";
} else {
LOG(FATAL) << "Unhandled interger type for wmma fragment!";
}
}
}
if (scope == "wmma.matrix_a") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
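For fragment declarations, the scalar sub-byte types are remapped to the experimental precision tags from CUDA's `mma.h`. The branch above amounts to the following mapping (illustrative Python dict, mirroring the C++):

```python
# dtype -> element type used in nvcuda::wmma::fragment<...> template arguments
WMMA_PRECISION = {
    "int4":  "nvcuda::wmma::experimental::precision::s4",
    "uint4": "nvcuda::wmma::experimental::precision::u4",
    "int1":  "nvcuda::wmma::experimental::precision::b1",
}
```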
@@ -184,7 +184,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
IntImm(DataType::UInt(8), dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
IntImm(DataType::UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
if (!(dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1))) {
asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
}
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
@@ -201,6 +205,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
init_nest_.emplace_back(LetStmtNode::make(
v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1)) {
break;
}
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
Bind_(buffer->shape[k],
@@ -138,7 +138,8 @@ class FragmentChecker : public StmtExprVisitor {
void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
if (op->is_intrinsic(intrinsic::tvm_mma_sync) ||
op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
CHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_var_d = op->args[0].as<VarNode>();
const VarNode* buffer_var_a = op->args[2].as<VarNode>();
@@ -199,7 +199,11 @@ class MMAMatcher: public StmtVisitor {
BufferInfo buffer_a;
if (!check_local_buffer_(load_a, &buffer_a)
|| !(buffer_a.dtype == DataType::Float(16) ||
buffer_a.dtype == DataType::Int(8))) {
buffer_a.dtype == DataType::Int(8) ||
buffer_a.dtype == DataType::UInt(8) ||
buffer_a.dtype == DataType::Int(4) ||
buffer_a.dtype == DataType::UInt(4) ||
buffer_a.dtype == DataType::Int(1))) {
return false;
}
@@ -208,7 +212,11 @@ class MMAMatcher: public StmtVisitor {
BufferInfo buffer_b;
if (!check_local_buffer_(load_b, &buffer_b)
|| !(buffer_b.dtype == DataType::Float(16) ||
buffer_b.dtype == DataType::Int(8))) {
buffer_b.dtype == DataType::Int(8) ||
buffer_b.dtype == DataType::UInt(8) ||
buffer_b.dtype == DataType::Int(4) ||
buffer_b.dtype == DataType::UInt(4) ||
buffer_b.dtype == DataType::Int(1))) {
return false;
}
@@ -736,6 +744,17 @@ class BufferAnalyser : public StmtExprVisitor {
warp_tile_.k == 16) {
return true;
}
if (warp_tile_.m == 8 &&
warp_tile_.n == 8 &&
warp_tile_.k == 32) {
return true;
}
if (warp_tile_.m == 8 &&
warp_tile_.n == 8 &&
warp_tile_.k == 128) {
return true;
}
return false;
}
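These two cases extend the accepted warp tile shapes: int4 fragments use 8x8x32 and int1 fragments use 8x8x128, alongside the existing fp16/int8 shapes. A compact summary, assuming the shapes CUDA's WMMA exposes on sm_75 (illustrative, not exhaustive):

```python
# dtype -> warp tile (m, n, k) shapes assumed to be usable with wmma on sm_75
SUPPORTED_WARP_TILES = {
    "float16": [(16, 16, 16), (32, 8, 16), (8, 32, 16)],
    "int8":    [(16, 16, 16), (32, 8, 16), (8, 32, 16)],
    "int4":    [(8, 8, 32)],    # experimental::precision::s4 / u4
    "int1":    [(8, 8, 128)],   # experimental::precision::b1
}
```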
@@ -869,18 +888,29 @@ class TensorCoreIRMutator : public StmtExprMutator {
ObjectPtr<BufferNode> buffer_node_c = make_object<BufferNode>();
auto mma_sync_call =
[&buffer_node_a, &buffer_node_b]
[&buffer_node_a, &buffer_node_b, &ca, &cb]
(const Buffer &buffer) {
Buffer buffer_a(buffer_node_a);
Buffer buffer_b(buffer_node_b);
return EvaluateNode::make(
CallNode::make(DataType::Handle(),
intrinsic::tvm_mma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
return EvaluateNode::make(
CallNode::make(DataType::Handle(),
intrinsic::tvm_bmma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
} else {
return EvaluateNode::make(
CallNode::make(DataType::Handle(),
intrinsic::tvm_mma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
}
};
auto call_add_c =
@@ -56,6 +56,8 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
out_type = 'float'
elif dtype == 'int8':
out_type = 'int'
elif dtype == 'int4' or dtype == 'int1':
out_type = 'int'
if (layout == 'NN'):
return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k))
if (layout == 'NT'):
@@ -123,6 +125,12 @@ def test_gemm(N, L, M, dtype, layout):
if dtype == 'int8':
factor = 32
offset = 16
elif dtype == 'int4':
factor = 64
offset = 32
elif dtype == 'int1':
factor = 256
offset = 128
# create cache stages
AA = s.cache_read(A, "shared", [C])
@@ -139,9 +147,9 @@ def test_gemm(N, L, M, dtype, layout):
cfg = autotvm.get_config()
cfg.define_knob("bx", [2, 4, 8])
cfg.define_knob("by", [16, 32, 64])
cfg.define_knob("step_k", [8, 16, 32])
cfg.define_knob("v", [4, 8])
cfg.define_knob("by", [8, 16, 32, 64])
cfg.define_knob("step_k", [1, 2, 4, 8, 16, 32])
cfg.define_knob("v", [4, 8, 16, 32])
by = cfg['by'].val
bx = cfg['bx'].val
step_k = cfg['step_k'].val
@@ -150,9 +158,17 @@ def test_gemm(N, L, M, dtype, layout):
# thread tile
TX = 8
TY = 1
if dtype == 'int4' or dtype == 'int1':
TX = 2
# warp tile
warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0
warp_tile_k = 16 # it must be 16
warp_tile_k = 16 # it must be 16 for fp16/int8 data type
if dtype == 'int4':
warp_tile_m = 8
warp_tile_k = 32
elif dtype == 'int1':
warp_tile_m = 8
warp_tile_k = 128
# block tile
tile_x = bx * TX
tile_y = by * TY
@@ -219,6 +235,10 @@ def test_gemm(N, L, M, dtype, layout):
# and run the kernel to compare with numpy to check whether the results are correct.
# check whether the gpu has tensorcore
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
sys.exit(0)
ctx = tvm.gpu()
if not nvcc.have_tensorcore(ctx.compute_version):
print('the gpu has no tensorcore, skipping...')
@@ -234,6 +254,15 @@ if len(sys.argv) >= 5:
if len(sys.argv) >= 6:
layout = sys.argv[5]
# check whether the current gpu arch supports the current dtype's wmma codegen
cuda_compute_capability = tvm.runtime._ffi_api.GetDeviceAttr(2, 0, 4)
major, minor = nvcc.parse_compute_version(cuda_compute_capability)
if dtype == 'int8':
assert(major == 7 and minor >= 2)
elif dtype == 'int4' or dtype == 'int1':
# int4/int1 only support layout TN
assert(major == 7 and minor == 5 and layout == 'TN')
def tune_and_evaluate(M, N, L, dtype, layout):
task = autotvm.task.create(test_gemm, args=(N, L, M, dtype, layout), target='cuda')
print(task.config_space)
@@ -305,6 +334,42 @@ def tune_and_evaluate(M, N, L, dtype, layout):
c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
elif (layout == "TT"):
c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
elif dtype == 'int4':
c_np_type = np.int32
a_np_int = np.random.randint(low=-8, high=7, size=shape_a).astype(np.int32)
b_np_int = np.random.randint(low=-8, high=7, size=shape_b).astype(np.int32)
# "TN"
c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
a_np = np.zeros(shape=(N, int(L/8)), dtype = np.int32)
b_np = np.zeros(shape=(M, int(L/8)), dtype = np.int32)
# a_np --> col_major
for i in range(N):
for j in range(int(L/8)):
for k in range(8):
a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
# b_np --> row_major
for i in range(M):
for j in range(int(L/8)):
for k in range(8):
b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
elif dtype == 'int1':
c_np_type = np.int32
a_np_int = np.random.randint(low=0, high=1, size=shape_a).astype(np.int32)
b_np_int = np.random.randint(low=0, high=1, size=shape_b).astype(np.int32)
# "TN"
c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
a_np = np.zeros(shape=(N, int(L/32)), dtype = np.int32)
b_np = np.zeros(shape=(M, int(L/32)), dtype = np.int32)
for i in range(N):
for j in range(int(L/32)):
for k in range(32):
a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 32 + k] & 0xf) << (31 - k))
for i in range(M):
for j in range(int(L/32)):
for k in range(32):
b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 32 + k] & 0xf) << (31 - k))
c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
a_tvm = tvm.nd.array(a_np, ctx=ctx)
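The nested loops above pack eight 4-bit values (or thirty-two 1-bit values) into each int32 word, most-significant element first, before the data is handed to TVM. A vectorized numpy equivalent, offered as a sketch (it assumes the reduction dimension is a multiple of the packing factor):

```python
import numpy as np

def pack_subbyte(x, bits):
    """Pack values along the last axis into int32 words, MSB-first.

    bits=4 packs 8 values per word, bits=1 packs 32 values per word.
    """
    per_word = 32 // bits
    rows, cols = x.shape
    assert cols % per_word == 0
    mask = (1 << bits) - 1
    grouped = (x.astype(np.int64) & mask).reshape(rows, cols // per_word, per_word)
    shifts = (per_word - 1 - np.arange(per_word)) * bits
    packed = np.bitwise_or.reduce(grouped << shifts, axis=2)
    return packed.astype(np.int32)
```

For example, `pack_subbyte(a_np_int, 4)` should reproduce the int4 `a_np` built by the loops above.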