Unverified commit f23ac969 by Orion34C, committed by GitHub

[CODEGEN] Support cuda tensorcore subbyte int data type in auto tensorcore (#4546)

* support cuda tensorcore subbyte int data type in auto tensorcore

* add license

* pass cpplint

* fix code review comments

* merge the int4/int1 codegen tutorial into the existing auto tensorcore tutorial

* using master's new API

* disable tuning when cuda is not enabled

* address cr comment

* do not run the tuning

* fix test failure

* fix cpplint error

* fix bool type reduction bug

* 1. fix an index bug 2. fix returned bytes value of int1/int4/uint4

* fix typo
parent 98e7709f
@@ -230,7 +230,12 @@ class DataType {
 inline int GetVectorBytes(DataType dtype) {
   int data_bits = dtype.bits() * dtype.lanes();
   // allow bool to exist
-  if (dtype == DataType::Bool()) return 1;
+  if (dtype == DataType::Bool() ||
+      dtype == DataType::Int(4) ||
+      dtype == DataType::UInt(4) ||
+      dtype == DataType::Int(1)) {
+    return 1;
+  }
   CHECK_EQ(data_bits % 8, 0U)
       << "Need to load/store by multiple of bytes";
   return data_bits / 8;
...
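For context, a minimal Python sketch (not TVM code) of the sizing rule this hunk establishes: bool and the scalar sub-byte int1/int4/uint4 types are reported as occupying one byte, while everything else must still be a whole number of bytes.

```python
# Sketch only: mirrors the GetVectorBytes rule above for illustration.
def get_vector_bytes(bits, lanes, is_bool=False):
    data_bits = bits * lanes
    # bool and scalar int1/int4/uint4 are reported as one byte
    if is_bool or (lanes == 1 and bits in (1, 4)):
        return 1
    assert data_bits % 8 == 0, "need to load/store by multiples of bytes"
    return data_bits // 8
```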
@@ -1261,6 +1261,18 @@ constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
  */
 constexpr const char* tvm_mma_sync = "tvm_mma_sync";
 /*!
+ * \brief tvm intrinsic for tensor core bmma_sync operators.
+ *
+ *  void tvm_bmma_sync(Var fragment_d, Expr index_d,
+ *                     Var fragment_a, Expr index_a,
+ *                     Var fragment_b, Expr index_b,
+ *                     Var fragment_c, Expr index_c) {
+ *    nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
+ *                            fragment_b[index_b], fragment_c[index_c]);
+ *  }
+ */
+constexpr const char* tvm_bmma_sync = "tvm_bmma_sync";
+/*!
  * \brief tvm intrinsic for tensor core fill_fragment operators.
  *
  *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
...
@@ -44,6 +44,11 @@ inline void VerifyDataType(DLDataType dtype) {
   } else {
     // allow uint1 as a special flag for bool.
     if (dtype.bits == 1 && dtype.code == kDLUInt) return;
+    // allow int1/uint4/int4
+    else if (dtype.bits == 1 && dtype.code == kDLInt) return;
+    else if (dtype.bits == 4 && dtype.code == kDLUInt) return;
+    else if (dtype.bits == 4 && dtype.code == kDLInt) return;
+    else
     CHECK_EQ(dtype.bits % 8, 0);
   }
   CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
...
@@ -170,8 +170,13 @@ std::string CodeGenC::GetBufferRef(
     } else {
       os << vid;
     }
-    os << '[';
+    os << "[(";
     PrintExpr(index, os);
+    os << ")";
+    if (t.bits() == 4 ||
+        (t.bits() == 1 && t.is_int())) {
+      os << " / " << (32 / t.bits());
+    }
     os << ']';
   } else {
     // Buffer declared as vector type.
@@ -205,8 +210,13 @@ std::string CodeGenC::GetBufferRef(
       PrintType(t.element_of(), os);
       os << "*)";
     }
-    os << vid << " + ";
+    os << vid << " + (";
     PrintExpr(index, os);
+    os << ")";
+    if (t.bits() == 4 ||
+        (t.bits() == 1 && t.is_int())) {
+      os << " / " << (32 / t.bits());
+    }
     os << "))[0]";
   }
   return os.str();
...
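Presumably the intent of the added `/ (32 / t.bits())` term: sub-byte elements are stored packed into 32-bit words, so the emitted C index addresses the word that holds the element rather than the element itself. A small sketch of that arithmetic, for illustration only:

```python
# Sketch: the emitted index becomes (i) / (32 / bits), i.e. the 32-bit word
# holding logical element i when 32 // bits elements are packed per word.
def word_index(i, bits):
    elems_per_word = 32 // bits      # 8 for int4, 32 for int1
    return i // elems_per_word

assert word_index(13, 4) == 1        # int4 element 13 lives in word 1
assert word_index(40, 1) == 1        # int1 element 40 lives in word 1
```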
@@ -144,6 +144,37 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       }
     }
     switch (t.bits()) {
+      case 1: {
+        if (t.lanes() == 1) {
+          os << "int"; return;
+        } else if (t.lanes() == 8) {
+          os << "int8_t"; return;
+        } else if (t.lanes() == 16) {
+          os << "int16_t"; return;
+        } else if (t.lanes() == 32) {
+          os << "int"; return;
+        } else {
+          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
+        }
+      }
+      case 4: {
+        if (t.lanes() == 1) {
+          os << "int"; return;
+        } else if (t.lanes() == 4) {
+          os << "int16_t"; return;
+        } else if (t.lanes() == 8) {
+          // directly 8 4-bit int in integer.
+          os << "int"; return;
+        } else if (t.lanes() == 16) {
+          os << "int2"; return;
+        } else if (t.lanes() == 32) {
+          os << "int4"; return;
+        } else if (t.lanes() == 64) {
+          os << "int8"; return;
+        } else {
+          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
+        }
+      }
       case 8: {
         if (t.lanes() == 4) {
           // directly 4 8 bit int in integer.
@@ -182,7 +213,6 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
           os << "long"; break;
         }
       }
-      case 1: os << "int"; break;
       default: fail = true; break;
     }
     if (!fail && lanes == 1) {
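The new `case 1` / `case 4` branches choose a built-in CUDA type just wide enough to hold the packed lanes, with scalar int1/int4 falling back to a plain int. The mapping, restated compactly (Python used only as notation):

```python
# (bits, lanes) -> CUDA storage type emitted by PrintType for sub-byte data
SUBBYTE_STORAGE = {
    (1, 1): "int",   (1, 8): "int8_t",   (1, 16): "int16_t", (1, 32): "int",
    (4, 1): "int",   (4, 4): "int16_t",  (4, 8): "int",      (4, 16): "int2",
    (4, 32): "int4", (4, 64): "int8",
}
```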
@@ -371,6 +401,16 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
       this->PrintExpr(op->args[i * 2 + 1], os);
       os << "]" << ((i < 3) ? ", ": ")");
     }
+  } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::bmma_sync(";
+    for (int i = 0; i < 4; ++i) {
+      this->PrintExpr(op->args[i * 2], os);
+      os << "[";
+      this->PrintExpr(op->args[i * 2 + 1], os);
+      os << "]" << ((i < 3) ? ", ": ")");
+    }
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
@@ -410,8 +450,12 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
     if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
       CHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Int(8) ||
-            op->dtype == DataType::UInt(8))
-        << "Matrix_a and matrix_b only support half or char or unsigned char type for now";
+            op->dtype == DataType::UInt(8) ||
+            op->dtype == DataType::Int(4) ||
+            op->dtype == DataType::UInt(4) ||
+            op->dtype == DataType::Int(1))
+        << "Matrix_a and matrix_b only support half or char or unsigned char "
+        << "or uint4 or int4 or int1 type for now";
     } else {
       CHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Float(32) ||
@@ -425,6 +469,11 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
       stream << ' ';
       PrintType(op->dtype, stream);
     }
+    if ((op->dtype == DataType::Int(4) ||
+         op->dtype == DataType::UInt(4) ||
+         op->dtype == DataType::Int(1)) && scope == "shared") {
+      constant_size = constant_size / (32 / op->dtype.bits());
+    }
     stream << ' ' << vid << '['
            << constant_size << "];\n";
   }
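A small illustration of the allocation rewrite above, under the same packing assumption (32/bits elements per 32-bit storage word); the sizes are arbitrary examples:

```python
# Sketch: a "shared" buffer of n sub-byte elements is emitted with
# n / (32 / bits) elements of the 32-bit storage type instead.
def shared_extent(n_elems, bits):
    return n_elems // (32 // bits)

assert shared_extent(4096, 4) == 512    # 4096 int4 values -> 512 ints
assert shared_extent(4096, 1) == 128    # 4096 int1 values -> 128 ints
```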
@@ -552,6 +601,24 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
   std::stringstream type;
   PrintType(t, type);
   std::string shape_str = fragment_shapes[variable];
+  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
+    type.str(std::string());
+    if (t.is_int()) {
+      if (t.bits() == 4) {
+        type << "nvcuda::wmma::experimental::precision::s4";
+      } else if (t.bits() == 1) {
+        type << "nvcuda::wmma::experimental::precision::b1";
+      } else {
+        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
+      }
+    } else if (t.is_uint()) {
+      if (t.bits() == 4) {
+        type << "nvcuda::wmma::experimental::precision::u4";
+      } else {
+        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
+      }
+    }
+  }
   if (scope == "wmma.matrix_a") {
     need_mma_h_ = true;
     std::string layout_str = fragment_layouts[variable];
...
@@ -184,7 +184,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
           IntImm(DataType::UInt(8), dtype.bits()) &&
       TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
           IntImm(DataType::UInt(16), dtype.lanes()));
-  asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+  if (!(dtype == DataType::Int(4) ||
+        dtype == DataType::UInt(4) ||
+        dtype == DataType::Int(1))) {
+    asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+  }
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
             arg_name + ".data", true)) {
@@ -201,6 +205,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     init_nest_.emplace_back(LetStmtNode::make(
         v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
     for (size_t k = 0; k < buffer->shape.size(); ++k) {
+      if (dtype == DataType::Int(4) ||
+          dtype == DataType::UInt(4) ||
+          dtype == DataType::Int(1)) {
+        break;
+      }
       std::ostringstream field_name;
       field_name << v_shape->name_hint << '[' << k << ']';
       Bind_(buffer->shape[k],
...
@@ -138,7 +138,8 @@ class FragmentChecker : public StmtExprVisitor {
   void VisitExpr_(const CallNode* op) final {
     StmtExprVisitor::VisitExpr_(op);
     // Check shape when calling tvm_mma_sync
-    if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
+    if (op->is_intrinsic(intrinsic::tvm_mma_sync) ||
+        op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
       CHECK_EQ(op->args.size(), 8U);
       const VarNode* buffer_var_d = op->args[0].as<VarNode>();
       const VarNode* buffer_var_a = op->args[2].as<VarNode>();
...
@@ -199,7 +199,11 @@ class MMAMatcher: public StmtVisitor {
     BufferInfo buffer_a;
     if (!check_local_buffer_(load_a, &buffer_a)
         || !(buffer_a.dtype == DataType::Float(16) ||
-             buffer_a.dtype == DataType::Int(8))) {
+             buffer_a.dtype == DataType::Int(8) ||
+             buffer_a.dtype == DataType::UInt(8) ||
+             buffer_a.dtype == DataType::Int(4) ||
+             buffer_a.dtype == DataType::UInt(4) ||
+             buffer_a.dtype == DataType::Int(1))) {
       return false;
     }
@@ -208,7 +212,11 @@ class MMAMatcher: public StmtVisitor {
     BufferInfo buffer_b;
     if (!check_local_buffer_(load_b, &buffer_b)
         || !(buffer_b.dtype == DataType::Float(16) ||
-             buffer_b.dtype == DataType::Int(8))) {
+             buffer_b.dtype == DataType::Int(8) ||
+             buffer_b.dtype == DataType::UInt(8) ||
+             buffer_b.dtype == DataType::Int(4) ||
+             buffer_b.dtype == DataType::UInt(4) ||
+             buffer_b.dtype == DataType::Int(1))) {
       return false;
     }
@@ -736,6 +744,17 @@ class BufferAnalyser : public StmtExprVisitor {
         warp_tile_.k == 16) {
       return true;
     }
+    if (warp_tile_.m == 8 &&
+        warp_tile_.n == 8 &&
+        warp_tile_.k == 32) {
+      return true;
+    }
+    if (warp_tile_.m == 8 &&
+        warp_tile_.n == 8 &&
+        warp_tile_.k == 128) {
+      return true;
+    }
     return false;
   }
@@ -869,10 +888,20 @@ class TensorCoreIRMutator : public StmtExprMutator {
     ObjectPtr<BufferNode> buffer_node_c = make_object<BufferNode>();
     auto mma_sync_call =
-      [&buffer_node_a, &buffer_node_b]
+      [&buffer_node_a, &buffer_node_b, &ca, &cb]
         (const Buffer &buffer) {
           Buffer buffer_a(buffer_node_a);
           Buffer buffer_b(buffer_node_b);
+          if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
+            return EvaluateNode::make(
+                    CallNode::make(DataType::Handle(),
+                                   intrinsic::tvm_bmma_sync,
+                                   {buffer->data, buffer->elem_offset,
+                                    buffer_a->data, buffer_a->elem_offset,
+                                    buffer_b->data, buffer_b->elem_offset,
+                                    buffer->data, buffer->elem_offset},
+                                   CallNode::Intrinsic));
+          } else {
           return EvaluateNode::make(
                   CallNode::make(DataType::Handle(),
                                  intrinsic::tvm_mma_sync,
@@ -881,6 +910,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
                                   buffer_b->data, buffer_b->elem_offset,
                                   buffer->data, buffer->elem_offset},
                                  CallNode::Intrinsic));
+          }
         };
     auto call_add_c =
...
@@ -56,6 +56,8 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
         out_type = 'float'
     elif dtype == 'int8':
         out_type = 'int'
+    elif dtype == 'int4' or dtype == 'int1':
+        out_type = 'int'
     if (layout == 'NN'):
         return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k))
     if (layout == 'NT'):
@@ -123,6 +125,12 @@ def test_gemm(N, L, M, dtype, layout):
     if dtype == 'int8':
         factor = 32
         offset = 16
+    elif dtype == 'int4':
+        factor = 64
+        offset = 32
+    elif dtype == 'int1':
+        factor = 256
+        offset = 128

     # create cache stages
     AA = s.cache_read(A, "shared", [C])
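The int4/int1 factor/offset pairs appear to follow the same scaling as the int8 pair above; this is an observation about the numbers, not something stated in the PR: factor doubles each time the element width halves, and offset stays at half of factor.

```python
# Observed pattern only: int8 -> (32, 16), int4 -> (64, 32), int1 -> (256, 128).
def tile_padding(bits):
    factor = 256 // bits
    return factor, factor // 2

assert tile_padding(8) == (32, 16)
assert tile_padding(4) == (64, 32)
assert tile_padding(1) == (256, 128)
```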
@@ -139,9 +147,9 @@ def test_gemm(N, L, M, dtype, layout):
     cfg = autotvm.get_config()
     cfg.define_knob("bx", [2, 4, 8])
-    cfg.define_knob("by", [16, 32, 64])
-    cfg.define_knob("step_k", [8, 16, 32])
-    cfg.define_knob("v", [4, 8])
+    cfg.define_knob("by", [8, 16, 32, 64])
+    cfg.define_knob("step_k", [1, 2, 4, 8, 16, 32])
+    cfg.define_knob("v", [4, 8, 16, 32])
     by = cfg['by'].val
     bx = cfg['bx'].val
     step_k = cfg['step_k'].val
@@ -150,9 +158,17 @@ def test_gemm(N, L, M, dtype, layout):
     # thread tile
     TX = 8
     TY = 1
+    if dtype == 'int4' or dtype == 'int1':
+        TX = 2
     # warp tile
     warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0
-    warp_tile_k = 16 # it must be 16
+    warp_tile_k = 16 # it must be 16 for fp16/int8 data type
+    if dtype == 'int4':
+        warp_tile_m = 8
+        warp_tile_k = 32
+    elif dtype == 'int1':
+        warp_tile_m = 8
+        warp_tile_k = 128
     # block tile
     tile_x = bx * TX
     tile_y = by * TY
@@ -219,6 +235,10 @@ def test_gemm(N, L, M, dtype, layout):
 # and run the kernel to compare with numpy to check whether the results are correct.

 # check whether the gpu has tensorcore
+if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+    print("skip because cuda is not enabled..")
+    sys.exit(0)
+
 ctx = tvm.gpu()
 if not nvcc.have_tensorcore(ctx.compute_version):
     print('the gpu has no tensorcore, skipping...')
@@ -234,6 +254,15 @@ if len(sys.argv) >= 5:
 if len(sys.argv) >= 6:
     layout = sys.argv[5]

+# check whether the current gpu arch supports the wmma codegen for this dtype
+cuda_compute_capability = tvm.runtime._ffi_api.GetDeviceAttr(2, 0, 4)
+major, minor = nvcc.parse_compute_version(cuda_compute_capability)
+if dtype == 'int8':
+    assert(major == 7 and minor >= 2)
+elif dtype == 'int4' or dtype == 'int1':
+    # int4/int1 only support layout TN
+    assert(major == 7 and minor == 5 and layout == 'TN')
+
 def tune_and_evaluate(M, N, L, dtype, layout):
     task = autotvm.task.create(test_gemm, args=(N, L, M, dtype, layout), target='cuda')
     print(task.config_space)
@@ -305,6 +334,42 @@ def tune_and_evaluate(M, N, L, dtype, layout):
             c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
         elif (layout == "TT"):
             c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
+    elif dtype == 'int4':
+        c_np_type = np.int32
+        a_np_int = np.random.randint(low=-8, high=7, size=shape_a).astype(np.int32)
+        b_np_int = np.random.randint(low=-8, high=7, size=shape_b).astype(np.int32)
+        # "TN"
+        c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
+        a_np = np.zeros(shape=(N, int(L/8)), dtype=np.int32)
+        b_np = np.zeros(shape=(M, int(L/8)), dtype=np.int32)
+        # a_np --> col_major
+        for i in range(N):
+            for j in range(int(L/8)):
+                for k in range(8):
+                    a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
+        # b_np --> row_major
+        for i in range(M):
+            for j in range(int(L/8)):
+                for k in range(8):
+                    b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
+    elif dtype == 'int1':
+        c_np_type = np.int32
+        a_np_int = np.random.randint(low=0, high=2, size=shape_a).astype(np.int32)
+        b_np_int = np.random.randint(low=0, high=2, size=shape_b).astype(np.int32)
+        # "TN"
+        c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
+        a_np = np.zeros(shape=(N, int(L/32)), dtype=np.int32)
+        b_np = np.zeros(shape=(M, int(L/32)), dtype=np.int32)
+        for i in range(N):
+            for j in range(int(L/32)):
+                for k in range(32):
+                    a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 32 + k] & 0x1) << (31 - k))
+        for i in range(M):
+            for j in range(int(L/32)):
+                for k in range(32):
+                    b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 32 + k] & 0x1) << (31 - k))

     c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
     a_tvm = tvm.nd.array(a_np, ctx=ctx)
...
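For readers who prefer a vectorized form, a hedged numpy equivalent of the int4 packing loops above (pack_int4 is a name introduced here, not part of the tutorial): it packs eight signed 4-bit values into each 32-bit word, with the first element in the highest nibble, matching the loop's (7 - k) * 4 shift.

```python
import numpy as np

def pack_int4(x):
    # x: int32 array of shape (rows, L) with values in the int4 range [-8, 7]
    rows, L = x.shape
    nibbles = (x & 0xF).astype(np.uint32).reshape(rows, L // 8, 8)
    shifts = np.arange(7, -1, -1, dtype=np.uint32) * 4   # element k -> bits (7 - k) * 4
    words = np.bitwise_or.reduce(nibbles << shifts, axis=2)
    return words.astype(np.int32)                        # same layout as a_np/b_np above

# Example: pack a random 4 x 16 int4 matrix into a 4 x 2 array of 32-bit words.
packed = pack_int4(np.random.randint(low=-8, high=8, size=(4, 16)).astype(np.int32))
```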