Commit 13b28566 by Benjamin Tu Committed by Jared Roesch

[VTA][Chisel] TSIM VTA Source Refactor (#4163)

* app init push

* fix on readme

* change name, add bit serial explanation

* rm serialLoadMM, change doc

* syntax change for readme

* add parallel test functionality

* fix readme

* add python doc

* syntax

* init commit

* fix empty line

* fix typo
parent 07606e4c
...@@ -22,21 +22,31 @@ package accel ...@@ -22,21 +22,31 @@ package accel
import chisel3._ import chisel3._
import chisel3.util._ import chisel3.util._
import vta.dpi._ import vta.dpi._
import vta.core._
import vta.util.config._
import vta.shell._
class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
/** Compute /** Compute
* *
* Bit Slice GEMM: * Bit Slice GEMM:
* *
* 1. Wait for launch to be asserted * 1. Wait for launch to be asserted
* 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address * 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
* 3. Wait for the value * 3. Wait for the value
* 4. Increment read-address for next value * 4. Increment read-address for next value
* 5. Wait for sliced accumulator * 5. Repeat until all inp1 data have been read
* 6. Check if counter (cnt) is equal to length process,
otherwise goto step 2 * 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
* 7. Check if reset slice accumulator * 7. Wait for the value
* 8. Wait for overall accumulator * 8. Increment read-address for next value
* 8. Issue a write request for 8-byte value at out_baddr address * 9. Repeat until all inp2 data have been read
* 10. Wait for output to be calculated
* 11. Issue a write request for 8-byte value at out_baddr address
* 12. Increment write-address for next value to write
* 13. Check if counter (cntout) is equal to length to assert finish,
otherwise go to step 11
*/ */
class Compute(implicit config: AccelConfig) extends Module { class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module {
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W))) val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster val mem = new VTAMemDPIMaster
}) })
val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7) implicit val p: Parameters = new TestConfig
val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
val state = RegInit(sIdle) val state = RegInit(sIdle)
val shift = io.vals(0) val shift = io.vals(0)
val length = io.vals(1) val length = io.vals(1)
val rstAccum = io.vals(2) val rstAccum = io.vals(2)
val startDot = io.vals(3) val startDot = io.vals(3)
val cycles = RegInit(0.U(config.regBits.W)) val cycles = RegInit(0.U(config.regBits.W))
val reg1 = Reg(chiselTypeOf(io.mem.rd.bits)) val mvc = Module(new MatrixVectorMultiplication)
val reg2 = Reg(chiselTypeOf(io.mem.rd.bits)) val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
val cnt = Reg(UInt(config.regBits.W)) val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
val cntwgt = Reg(UInt(config.regBits.W))
val cntinp = Reg(UInt(config.regBits.W))
val cntout = Reg(UInt(config.regBits.W))
val raddr1 = Reg(UInt(config.ptrBits.W)) val raddr1 = Reg(UInt(config.ptrBits.W))
val raddr2 = Reg(UInt(config.ptrBits.W)) val raddr2 = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W)) val waddr = Reg(UInt(config.ptrBits.W))
val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))
switch (state) { switch (state) {
is (sIdle) { is (sIdle) {
...@@ -73,7 +88,14 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -73,7 +88,14 @@ class Compute(implicit config: AccelConfig) extends Module {
} }
is (sReadAData) { is (sReadAData) {
when (io.mem.rd.valid) { when (io.mem.rd.valid) {
state := sReadADone
}
}
is (sReadADone) {
when (cntwgt === (length * length) - 1.U) {
state := sReadBReq state := sReadBReq
} .otherwise {
state := sReadAReq
} }
} }
is (sReadBReq) { is (sReadBReq) {
...@@ -81,6 +103,23 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -81,6 +103,23 @@ class Compute(implicit config: AccelConfig) extends Module {
} }
is (sReadBData) { is (sReadBData) {
when (io.mem.rd.valid) { when (io.mem.rd.valid) {
state := sReadBDone
}
}
is (sReadBDone) {
when (cntinp === length-1.U) {
state := sInpDone
} .otherwise {
state := sReadBReq
}
}
// Both inputs are processed
is (sInpDone) {
state := sWait
}
// Wait for computation
is (sWait) {
when (accum.io.ready) {
state := sWriteReq state := sWriteReq
} }
} }
...@@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module {
state := sWriteData state := sWriteData
} }
is (sWriteData) { is (sWriteData) {
when (cnt === (length - 1.U)) { state := sWriteDone
}
is (sWriteDone) {
when (cntout === (length - 1.U)) {
state := sIdle state := sIdle
} .otherwise { } .otherwise {
state := sReadAReq state := sWriteReq
} }
} }
} }
val last = state === sWriteData && cnt === (length - 1.U) val last = state === sWriteDone && cntout === (length - 1.U)
// cycle counter // cycle counter
when (state === sIdle) { when (state === sIdle) {
...@@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module {
raddr1 := io.ptrs(0) raddr1 := io.ptrs(0)
raddr2 := io.ptrs(1) raddr2 := io.ptrs(1)
waddr := io.ptrs(2) waddr := io.ptrs(2)
} .elsewhen (state === sWriteData) { // increment input array by 1-byte } .elsewhen (state === sReadADone) { // increment input array by 1-byte
raddr1 := raddr1 + 1.U raddr1 := raddr1 + 1.U
} .elsewhen (state === sReadBDone) { // increment input array by 1-byte
raddr2 := raddr2 + 1.U raddr2 := raddr2 + 1.U
waddr := waddr } .elsewhen (state === sWriteDone) {
waddr := waddr + 4.U // writing 4 bytes
} }
// create request // create request
...@@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module {
// read // read
when (state === sReadAData && io.mem.rd.valid) { when (state === sReadAData && io.mem.rd.valid) {
reg1 := io.mem.rd.bits(7, 0) reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
} }
when (state === sReadBData && io.mem.rd.valid) { when (state === sReadBData && io.mem.rd.valid) {
reg2 := io.mem.rd.bits(7, 0) reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
} }
io.mem.rd.ready := state === sReadAData | state === sReadBData io.mem.rd.ready := state === sReadAData | state === sReadBData
mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.bits <> reg1
mvc.io.inp.data.bits <> reg2
// Modify when shift operation is supported
mvc.io.reset := false.B
mvc.io.acc_i.data.valid := true.B
for (i <- 0 until p(CoreKey).blockOut) {
mvc.io.acc_i.data.bits(0)(i) := 0.U
}
accum.io.in := mvc.io.acc_o.data.bits
val sliceAccum = Module(new Accumulator(63)) accum.io.shift := shift
val overallAccum = Module(new Accumulator(64)) accum.io.clear := rstAccum
accum.io.valid := mvc.io.acc_o.data.valid
sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
sliceAccum.io.in := reg1 * reg2
sliceAccum.io.clear := startDot
overallAccum.io.clear := rstAccum
overallAccum.io.valid := last // last element has been processed
overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
// write // write
io.mem.wr.valid := overallAccum.io.ready io.mem.wr.valid := state === sWriteData
io.mem.wr.bits := overallAccum.io.sum io.mem.wr.bits := accum.io.sum(cntout)
// count read/write // count read/write
when (state === sIdle) { when (state === sIdle) {
cnt := 0.U cntwgt := 0.U
} .elsewhen (state === sWriteData) { cntinp := 0.U
cnt := cnt + 1.U cntout := 0.U
} .elsewhen (state === sReadADone) {
cntwgt := cntwgt + 1.U
} .elsewhen (state === sReadBDone) {
cntinp := cntinp + 1.U
} .elsewhen (state === sWriteDone) {
cntout := cntout + 1.U
} }
io.finish := overallAccum.io.ready // data has been added io.finish := last // data has been added
} }
// Shift operation until supported in MVM
class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
class Accumulator(dataBits: Int = 8) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val clear = Input(Bool()) val clear = Input(Bool())
val valid = Input(Bool()) val valid = Input(Bool())
val ready = Output(Bool()) val ready = Output(Bool())
val in = Input(UInt(dataBits.W)) val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
val sum = Output(UInt((dataBits).W)) val shift = Input(UInt(8.W))
val sum = Output(Vec(size, (UInt(accBits.W))))
}) })
val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))
val reg = RegInit(0.U((dataBits).W)) for (i <- 0 until size) {
val ready = RegNext(io.valid) when (io.clear) {
when (io.clear) { reg(i) := 0.U
reg := 0.U } .elsewhen(io.valid) {
} .elsewhen (io.valid) { reg(i) := reg(i) + (io.in(0)(i) << io.shift)
reg := reg + io.in }
} }
io.ready := ready io.ready := RegNext(io.valid)
io.sum := reg io.sum := reg
} }
...@@ -35,13 +35,9 @@ import vta.dpi._ ...@@ -35,13 +35,9 @@ import vta.dpi._
* Shift value | 0x08 * Shift value | 0x08
* Vector length | 0x0c * Vector length | 0x0c
* Reset Accumulator | 0x10 * Reset Accumulator | 0x10
* Reset Dot Module | 0x14 * Input1 pointer | 0x18
* Input1 pointer lsb | 0x18 * Input2 pointer | 0x20
* Input1 pointer msb | 0x1c * Output pointer | 0x28
* Input2 pointer lsb | 0x20
* Input2 pointer msb | 0x24
* Output pointer lsb | 0x28
* Output pointer msb | 0x2c
* ------------------------------- * -------------------------------
* ------------------------------ * ------------------------------
......
...@@ -66,10 +66,12 @@ class Device { ...@@ -66,10 +66,12 @@ class Device {
uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) { uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
uint32_t cycles; uint32_t cycles;
uint32_t length = inp1->shape[0]; uint32_t length = inp2->shape[0];
size_t size1 = (inp1->dtype.bits >> 3) * length; // 1 matrix 1 vector input
size_t size1 = (inp1->dtype.bits >> 3) * length * length;
size_t size2 = (inp2->dtype.bits >> 3) * length; size_t size2 = (inp2->dtype.bits >> 3) * length;
size_t size3 = (64 >> 3); // 1 vector output
size_t size3 = (32 >> 3) * length;
inp1_ = this->MemAlloc(size1); inp1_ = this->MemAlloc(size1);
inp2_ = this->MemAlloc(size2); inp2_ = this->MemAlloc(size2);
out_ = this->MemAlloc(size3); out_ = this->MemAlloc(size3);
...@@ -115,19 +117,17 @@ class Device { ...@@ -115,19 +117,17 @@ class Device {
void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) { void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
dpi_->WriteReg(0x08, shiftVal); dpi_->WriteReg(0x08, shiftVal);
dpi_->WriteReg(0x0c, length); // vector length dpi_->WriteReg(0x0c, length); // tensor size
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_)); dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_)); dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_)); dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
dpi_->WriteReg(0x00, 0x1); // launch dpi_->WriteReg(0x00, 0x1); // launch
dpi_->WriteReg(0x00, 0x0); // launch dpi_->WriteReg(0x00, 0x0);
if (reset == 1) { if (reset == 1) {
dpi_->WriteReg(0x10, 0x1); // reset accum dpi_->WriteReg(0x10, 0x1); // reset accumulator
dpi_->WriteReg(0x10, 0x0); // stop reset accum dpi_->WriteReg(0x10, 0x0);
} }
dpi_->WriteReg(0x14, 0x1); // reset dot
dpi_->WriteReg(0x14, 0x0); // stop reset dot
} }
uint32_t WaitForCompletion() { uint32_t WaitForCompletion() {
......
...@@ -26,7 +26,7 @@ Parameters ...@@ -26,7 +26,7 @@ Parameters
A : Vector to be sliced and packed A : Vector to be sliced and packed
slice_width : slice width slice_width : slice width
Returnsi Returns
--------- ---------
C: 2d matrix where each column (because of bit packing) represents each bit slice of A C: 2d matrix where each column (because of bit packing) represents each bit slice of A
""" """
...@@ -39,7 +39,7 @@ def slice(A, slice_width): ...@@ -39,7 +39,7 @@ def slice(A, slice_width):
elif dtype is np.uint16: row = 16 // slice_width elif dtype is np.uint16: row = 16 // slice_width
elif dtype is np.uint32: row = 32 // slice_width elif dtype is np.uint32: row = 32 // slice_width
elif dtype is np.uint64: row = 64 // slice_width elif dtype is np.uint64: row = 64 // slice_width
else: raise ValueError("datatype " + str(dtype) + "currently not supported") else: raise ValueError("datatype currently not supported")
if (row >= 8): if (row >= 8):
dtype = 'uint' + str(row) dtype = 'uint' + str(row)
else: else:
...@@ -55,64 +55,88 @@ def slice(A, slice_width): ...@@ -55,64 +55,88 @@ def slice(A, slice_width):
C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask) C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask)
return C return C
def slice_mat(A, slice_width):
    """Bit-slice every row of a 2-D unsigned-integer matrix.

    Parameters
    ----------
    A : 2-D numpy array whose elements are np.uint8, np.uint16, np.uint32 or np.uint64
    slice_width : slice width in bits (must be a power of 2)

    Returns
    ---------
    C: 3d array indexed (bit slice, row, column); C[:, z, :] holds the result of
       calling this module's slice() on row A[z]

    Raises
    ---------
    AssertionError : slice_width is not a power of 2
    ValueError : the element type of A is not one of the supported uint types
    """
    assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
    # currently only supports uint; map the element type to its bit width
    bits_by_type = {np.uint8: 8, np.uint16: 16, np.uint32: 32, np.uint64: 64}
    total_bits = bits_by_type.get(type(A[0][0]))
    if total_bits is None:
        raise ValueError("datatype currently not supported")
    # number of bit slices produced per element
    row = total_bits // slice_width
    # the output dtype name is derived from the slice count, with uint8 as the floor
    dtype = 'uint' + str(row) if row >= 8 else 'uint8'
    # 3d array (bits, row, clmn)
    C = np.zeros((row, A.shape[0], A.shape[1])).astype(dtype)  # sliced and transform
    # slice and pack one matrix row at a time
    # (`slice` is the module-level helper, not the Python builtin)
    for z in range(A.shape[0]):
        C[:, z, :] = slice(A[z], slice_width)
    return C
""" Matrix Multiplication Function """ Matrix Multiplication Function
Parameters Parameters
---------- ----------
A : Matrix A A : Matrix A
B: Matrix B B: Matrix B
w_width : weight slice width i_width : input slice width
a_width : activation slice width w_width : weight slice width
Returns Returns
--------- ---------
C: result of A * B C: result of A * B
""" """
# A is a n*m matrix, B is a m*p matrix(not transposed yet) # A is a n*m matrix, B is a m*p matrix(not transposed yet)
def matrix_multiply(A, B, w_width, a_width): def matrix_multiply(A, B, i_width, w_width):
assert A.shape[1] == B.shape[0], "can't perform multiplication" assert A.shape[1] == B.shape[0], "can't perform multiplication"
BT = B.transpose() BT = B.transpose()
cycles = 0 cycles = 0
B_sliced = slice_mat(BT, w_width)
C = np.zeros((A.shape[0], B.shape[1])).astype('uint64') C = np.zeros((A.shape[0], B.shape[1])).astype('uint64')
for i in range(A.shape[0]): for i in range(A.shape[0]):
for j in range(B.shape[1]): A_sliced = slice(A[i], i_width)
# C[i, j] = A[i].dot(BT[j]) test = test_accel(A_sliced, B_sliced, i_width, w_width)
A_sliced = slice(A[i], w_width) C[i] = test[0]
B_sliced = slice(BT[j], a_width) cycles += test[1]
np.testing.assert_array_equal(C[i], compute(A_sliced, B_sliced, i_width, w_width))
C[i, j] = compute(A_sliced, B_sliced, w_width, a_width) print("PASS row " + str(i))
test = test_accel(A_sliced, B_sliced, w_width, a_width)
cycles += test[1] np.testing.assert_array_equal(C, np.matmul(A.astype('uint64'),B))
np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j]))
print("PASS SW serial & parallel")
np.testing.assert_equal(test[0], C[i, j])
print("PASS SW & HW bit serial")
np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j]))
print("PASS SW bit parallel & HW bit parallel")
print("result: ") print("result: ")
print(C) print(C)
print("ALL TESTS PASSED, cycles: " + str(cycles)) print("TEST PASSED, cycles: " + str(cycles))
return C return C
""" Software Verification Function""" """ Software Verification Function
# takes 2 matrix input (sliced and packed) Parameter Dimensions
def compute(A, B, w_width, a_width): ---------
A (bits, y) and B (bits, y, x) (transposed)
Takes 1 vector and 1 matrix input (sliced and packed)
Returns
---------
Resulting vector
"""
def compute(A, B, i_width, w_width):
assert A.shape[1] == B.shape[1], "sliced shape not match" assert A.shape[1] == B.shape[1], "sliced shape not match"
# reset hardware accumulator # reset hardware accumulator
accum = 0 accum = np.zeros(A.shape[1])
for x in range(A.shape[0]): for x in range(A.shape[0]):
for y in range(B.shape[0]): for y in range(B.shape[0]):
# hardware implementation accum += np.matmul(A[x].astype('uint64'), B[y].transpose()) << np.uint64(x*i_width + y*w_width)
accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width)
# get value from accumulator # get value from accumulator
return accum return accum
"""Testing Function for Dot Product""" """Testing Function for Matrix Vector Multiplication"""
def test_accel(A, B, w_width, a_width): def test_accel(A, B, i_width, w_width):
assert A.shape[1] == B.shape[1], "sliced shape not match" assert A.shape[1] == B.shape[2], "sliced shape not match"
dtype = A.dtype dtype = A.dtype
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tsim.load_module() f = tsim.load_module()
...@@ -126,57 +150,54 @@ def test_accel(A, B, w_width, a_width): ...@@ -126,57 +150,54 @@ def test_accel(A, B, w_width, a_width):
a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx)) a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx))
for i in range(B.shape[0]): for i in range(B.shape[0]):
list_b = np.zeros(B.shape[1]).astype(dtype) # transpose
for j in range(B.shape[1]): list_b = np.zeros((B.shape[2], B.shape[1])).astype(dtype)
list_b[j] = B[i][j] for j in range(B.shape[2]):
for k in range(B.shape[1]):
list_b[j][k] = B[i][j][k]
b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx)) b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx))
cycles = 0 cycles = 0
accum = tvm.nd.array(np.zeros(A.shape[1]).astype("uint32"), ctx)
accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx)
for i in range(len(a_arr)): for i in range(len(a_arr)):
for j in range(len(b_arr)): for j in range(len(b_arr)):
shift = np.uint8(i*w_width + j*a_width) shift = np.uint8(i*i_width + j*w_width)
if i == 0 and j == 0: if i == 0 and j == 0:
cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(1)) # reset accumulator
else: else:
cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(0)) # no reset
return (accum.asnumpy()[0], cycles) return (accum.asnumpy(), cycles)
""" Matrix Generator """ Matrix Generator
Parameters Parameters
---------- ----------
dtype : String, datatype generated (supports only uint) dtype : String, datatype generated (supports only uint)
w_width : weight bit slices(needs to be less than actual bit width) i_width : input bit slices(needs to be less than actual bit width)
a_width : activation bit slices(needs to be less than actual bit width) w_width : weight bit slices(needs to be less than actual bit width)
""" """
def top_test(dtype, w_width, a_width): def top_test(dtype, i_width, w_width):
rmax = np.random.randint(256)
# random matrix generation (dimension up to 8)
rrow = np.random.randint(7) + 1
rclmn = np.random.randint(7) + 1
rrow2 = np.random.randint(7) + 1
A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype)
B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype)
print("A: ") # only supports positive values (up to 2**(bits-1))
print(A) rmax = 127
print("\n") # (m,16) * (16,16) GEMM
print("B: ") rrow = np.random.randint(7) + 1
print(B) clmn = 16
print("\n") A = np.random.randint(rmax, size=(rrow,clmn)).astype(dtype)
matrix_multiply(A, B, w_width, a_width) B = np.random.randint(rmax, size=(clmn,clmn)).astype(dtype)
print("A: " + str(A))
print("B: " + str(B))
# perform GEMM
matrix_multiply(A, B, i_width, w_width)
if __name__ == "__main__": if __name__ == "__main__":
tsim.init("chisel") tsim.init("chisel")
for i in range(1): for i in range(1):
# reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits # reg1 and reg2 bits in hardware/chisel/src/main/Compute.scala must be modified for slices greater than 8 bits
if sys.argv[1] == 'serial': if sys.argv[1] == 'serial':
# generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation # generates a random uint8 GEMM with 2-bit(8/4) input and 4-bit(8/2) weight
top_test("uint8",4, 2) top_test("uint8", 4, 2)
elif sys.argv[1] == 'parallel': elif sys.argv[1] == 'parallel':
# generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel) # generates a random uint8 GEMM with 8-bit input and 8-bit weight (bit parallel)
top_test('uint8', 1, 1) top_test('uint8', 8, 8)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment