Commit 13b28566 by Benjamin Tu, committed by Jared Roesch

[VTA][Chisel] TSIM VTA Source Refactor (#4163)

* app init push

* fix on readme

* change name, add bit serial explanation

* rm serialLoadMM, change doc

* syntax change for readme

* add parallel test functionality

* fix readme

* add python doc

* syntax

* init commit

* fix empty line

* fix typo
parent 07606e4c
......@@ -22,21 +22,31 @@ package accel
import chisel3._
import chisel3.util._
import vta.dpi._
import vta.core._
import vta.util.config._
import vta.shell._
class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
/** Compute
*
* Bit Slice GEMM:
*
* 1. Wait for launch to be asserted
* 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
* 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
* 3. Wait for the value
* 4. Increment read-address for next value
* 5. Wait for sliced accumulator
* 6. Check if counter (cnt) is equal to length process,
otherwise goto step 2
* 7. Check if reset slice accumulator
* 8. Wait for overall accumulator
* 8. Issue a write request for 8-byte value at out_baddr address
* 5. Repeat until all inp1 data have been read
* 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
* 7. Wait for the value
* 8. Increment read-address for next value
* 9. Repeat until all inp2 data have been read
* 10. Wait for output to be calculated
* 11. Issue a write request for 8-byte value at out_baddr address
* 12. Increment write-address for next value to write
* 13. Check if counter (cntout) is equal to length to assert finish,
 otherwise go to step 11
*/
class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
......@@ -47,19 +57,24 @@ class Compute(implicit config: AccelConfig) extends Module {
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster
})
val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
implicit val p: Parameters = new TestConfig
val sIdle :: sReadAReq :: sReadAData :: sReadADone :: sReadBReq :: sReadBData :: sReadBDone :: sInpDone :: sWait :: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
val state = RegInit(sIdle)
val shift = io.vals(0)
val length = io.vals(1)
val rstAccum = io.vals(2)
val startDot = io.vals(3)
val cycles = RegInit(0.U(config.regBits.W))
val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
val cnt = Reg(UInt(config.regBits.W))
val mvc = Module(new MatrixVectorMultiplication)
val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
val cntwgt = Reg(UInt(config.regBits.W))
val cntinp = Reg(UInt(config.regBits.W))
val cntout = Reg(UInt(config.regBits.W))
val raddr1 = Reg(UInt(config.ptrBits.W))
val raddr2 = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))
switch (state) {
is (sIdle) {
......@@ -73,7 +88,14 @@ class Compute(implicit config: AccelConfig) extends Module {
}
is (sReadAData) {
when (io.mem.rd.valid) {
state := sReadADone
}
}
is (sReadADone) {
when (cntwgt === (length * length) - 1.U) {
state := sReadBReq
} .otherwise {
state := sReadAReq
}
}
is (sReadBReq) {
......@@ -81,6 +103,23 @@ class Compute(implicit config: AccelConfig) extends Module {
}
is (sReadBData) {
when (io.mem.rd.valid) {
state := sReadBDone
}
}
is (sReadBDone) {
when (cntinp === length-1.U) {
state := sInpDone
} .otherwise {
state := sReadBReq
}
}
// Both inputs have been processed
is (sInpDone) {
state := sWait
}
// Wait for computation
is (sWait) {
when (accum.io.ready) {
state := sWriteReq
}
}
......@@ -89,15 +128,18 @@ class Compute(implicit config: AccelConfig) extends Module {
state := sWriteData
}
is (sWriteData) {
when (cnt === (length - 1.U)) {
state := sWriteDone
}
is (sWriteDone) {
when (cntout === (length - 1.U)) {
state := sIdle
} .otherwise {
state := sReadAReq
state := sWriteReq
}
}
}
val last = state === sWriteData && cnt === (length - 1.U)
val last = state === sWriteDone && cntout === (length - 1.U)
// cycle counter
when (state === sIdle) {
......@@ -114,10 +156,12 @@ class Compute(implicit config: AccelConfig) extends Module {
raddr1 := io.ptrs(0)
raddr2 := io.ptrs(1)
waddr := io.ptrs(2)
} .elsewhen (state === sWriteData) { // increment input array by 1-byte
} .elsewhen (state === sReadADone) { // increment input array by 1-byte
raddr1 := raddr1 + 1.U
} .elsewhen (state === sReadBDone) { // increment input array by 1-byte
raddr2 := raddr2 + 1.U
waddr := waddr
} .elsewhen (state === sWriteDone) {
waddr := waddr + 4.U // writing 4 bytes
}
// create request
......@@ -128,59 +172,70 @@ class Compute(implicit config: AccelConfig) extends Module {
// read
when (state === sReadAData && io.mem.rd.valid) {
reg1 := io.mem.rd.bits(7, 0)
reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
}
when (state === sReadBData && io.mem.rd.valid) {
reg2 := io.mem.rd.bits(7, 0)
reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
}
io.mem.rd.ready := state === sReadAData | state === sReadBData
mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.bits <> reg1
mvc.io.inp.data.bits <> reg2
// Modify when shift operation is supported
mvc.io.reset := false.B
mvc.io.acc_i.data.valid := true.B
for (i <- 0 until p(CoreKey).blockOut) {
mvc.io.acc_i.data.bits(0)(i) := 0.U
}
val sliceAccum = Module(new Accumulator(63))
val overallAccum = Module(new Accumulator(64))
sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
sliceAccum.io.in := reg1 * reg2
sliceAccum.io.clear := startDot
overallAccum.io.clear := rstAccum
overallAccum.io.valid := last // last element has been processed
overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
accum.io.in := mvc.io.acc_o.data.bits
accum.io.shift := shift
accum.io.clear := rstAccum
accum.io.valid := mvc.io.acc_o.data.valid
// write
io.mem.wr.valid := overallAccum.io.ready
io.mem.wr.bits := overallAccum.io.sum
io.mem.wr.valid := state === sWriteData
io.mem.wr.bits := accum.io.sum(cntout)
// count read/write
when (state === sIdle) {
cnt := 0.U
} .elsewhen (state === sWriteData) {
cnt := cnt + 1.U
cntwgt := 0.U
cntinp := 0.U
cntout := 0.U
} .elsewhen (state === sReadADone) {
cntwgt := cntwgt + 1.U
} .elsewhen (state === sReadBDone) {
cntinp := cntinp + 1.U
} .elsewhen (state === sWriteDone) {
cntout := cntout + 1.U
}
io.finish := overallAccum.io.ready // data has been added
io.finish := last // data has been added
}
class Accumulator(dataBits: Int = 8) extends Module {
// Applies the shift operation until it is supported in MVM
class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
val io = IO(new Bundle {
val clear = Input(Bool())
val valid = Input(Bool())
val ready = Output(Bool())
val in = Input(UInt(dataBits.W))
val sum = Output(UInt((dataBits).W))
val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
val shift = Input(UInt(8.W))
val sum = Output(Vec(size, (UInt(accBits.W))))
})
val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))
val reg = RegInit(0.U((dataBits).W))
val ready = RegNext(io.valid)
when (io.clear) {
reg := 0.U
} .elsewhen (io.valid) {
reg := reg + io.in
}
io.ready := ready
io.sum := reg
for (i <- 0 until size) {
when (io.clear) {
reg(i) := 0.U
} .elsewhen(io.valid) {
reg(i) := reg(i) + (io.in(0)(i) << io.shift)
}
}
io.ready := RegNext(io.valid)
io.sum := reg
}
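Taken together, the refactored Compute FSM and the Accmulator above implement one launch of a bit-serial matrix-vector step: read a length x length matrix and a length vector byte by byte, run them through MatrixVectorMultiplication, then shift every output lane and add it into a persistent accumulator. Below is a minimal behavioural sketch in NumPy, assuming uint8 inputs and a 16-lane uint32 accumulator; the function name and signature are illustrative and not part of the VTA source.

```python
import numpy as np

def compute_launch(matrix, vector, shift, accum, reset):
    """Behavioural model of one Compute launch (illustrative only).

    matrix : (length, length) uint8, streamed in during the sReadA* states
    vector : (length,) uint8, streamed in during the sReadB* states
    shift  : per-launch shift amount (Accmulator.io.shift)
    accum  : (length,) uint32 running sums (cleared when reset is asserted)
    """
    if reset:
        accum[:] = 0  # rstAccum clears every accumulator lane
    # MatrixVectorMultiplication: one partial sum per output lane
    partial = matrix.astype(np.uint64) @ vector.astype(np.uint64)
    # Accmulator: reg(i) := reg(i) + (in(0)(i) << shift)
    accum += (partial << np.uint64(shift)).astype(np.uint32)
    return accum

# a single 16x16 launch with no shift, starting from a cleared accumulator
acc = np.zeros(16, dtype=np.uint32)
A = np.random.randint(0, 128, size=(16, 16), dtype=np.uint8)
x = np.random.randint(0, 128, size=16, dtype=np.uint8)
compute_launch(A, x, shift=0, accum=acc, reset=1)
```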
......@@ -35,13 +35,9 @@ import vta.dpi._
* Shift value | 0x08
* Vector length | 0x0c
* Reset Accumulator | 0x10
* Reset Dot Module | 0x14
* Input1 pointer lsb | 0x18
* Input1 pointer msb | 0x1c
* Input2 pointer lsb | 0x20
* Input2 pointer msb | 0x24
* Output pointer lsb | 0x28
* Output pointer msb | 0x2c
* Input1 pointer | 0x18
* Input2 pointer | 0x20
* Output pointer | 0x28
* -------------------------------
* ------------------------------
......
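For host-side reference, the trimmed register map above, plus the launch register at offset 0x00 used by the C++ driver further down, can be summarised as a simple offset table. The names in this sketch are illustrative and not part of the VTA driver API.

```python
# Offsets from the register map above; 0x00 (launch) appears in Device::Launch.
TSIM_REGS = {
    "launch":      0x00,
    "shift":       0x08,
    "length":      0x0c,
    "reset_accum": 0x10,
    "inp1_ptr":    0x18,  # matrix input
    "inp2_ptr":    0x20,  # vector input
    "out_ptr":     0x28,  # vector output
}

def launch(write_reg, length, shift, inp1, inp2, out):
    """Mirror of Device::Launch; write_reg is any (offset, value) callable."""
    write_reg(TSIM_REGS["shift"], shift)
    write_reg(TSIM_REGS["length"], length)
    write_reg(TSIM_REGS["inp1_ptr"], inp1)
    write_reg(TSIM_REGS["inp2_ptr"], inp2)
    write_reg(TSIM_REGS["out_ptr"], out)
    write_reg(TSIM_REGS["launch"], 1)   # pulse launch
    write_reg(TSIM_REGS["launch"], 0)
```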
......@@ -66,10 +66,12 @@ class Device {
uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
uint32_t cycles;
uint32_t length = inp1->shape[0];
size_t size1 = (inp1->dtype.bits >> 3) * length;
uint32_t length = inp2->shape[0];
// 1 matrix 1 vector input
size_t size1 = (inp1->dtype.bits >> 3) * length * length;
size_t size2 = (inp2->dtype.bits >> 3) * length;
size_t size3 = (64 >> 3);
// 1 vector output
size_t size3 = (32 >> 3) * length;
inp1_ = this->MemAlloc(size1);
inp2_ = this->MemAlloc(size2);
out_ = this->MemAlloc(size3);
......@@ -115,19 +117,17 @@ class Device {
void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
dpi_->WriteReg(0x08, shiftVal);
dpi_->WriteReg(0x0c, length); // vector length
dpi_->WriteReg(0x0c, length); // tensor size
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
dpi_->WriteReg(0x00, 0x1); // launch
dpi_->WriteReg(0x00, 0x0); // launch
dpi_->WriteReg(0x00, 0x0);
if (reset == 1) {
dpi_->WriteReg(0x10, 0x1); // reset accum
dpi_->WriteReg(0x10, 0x0); // stop reset accum
dpi_->WriteReg(0x10, 0x1); // reset accumulator
dpi_->WriteReg(0x10, 0x0);
}
dpi_->WriteReg(0x14, 0x1); // reset dot
dpi_->WriteReg(0x14, 0x0); // stop reset dot
}
uint32_t WaitForCompletion() {
......
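To make the revised buffer sizing in Run() concrete, here is the arithmetic for the uint8, length-16 case exercised by the Python tests further down; this is a plain sanity check, not driver code.

```python
bits, length = 8, 16                     # uint8 elements, 16x16 matrix, 16-wide vector
size1 = (bits >> 3) * length * length    # matrix input          -> 256 bytes
size2 = (bits >> 3) * length             # vector input          ->  16 bytes
size3 = (32 >> 3) * length               # uint32 vector output  ->  64 bytes
assert (size1, size2, size3) == (256, 16, 64)
```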
......@@ -26,7 +26,7 @@ Parameters
A : Vector to be sliced and packed
slice_width : slice width
Returnsi
Returns
---------
C: 2d matrix where each column (because of bit packing) represents each bit slice of A
"""
......@@ -39,7 +39,7 @@ def slice(A, slice_width):
elif dtype is np.uint16: row = 16 // slice_width
elif dtype is np.uint32: row = 32 // slice_width
elif dtype is np.uint64: row = 64 // slice_width
else: raise ValueError("datatype " + str(dtype) + "currently not supported")
else: raise ValueError("datatype currently not supported")
if (row >= 8):
dtype = 'uint' + str(row)
else:
......@@ -55,64 +55,88 @@ def slice(A, slice_width):
C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask)
return C
def slice_mat(A, slice_width):
assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
dtype = type(A[0][0])
row = 0
# currently only supports uint
if dtype is np.uint8: row = 8 // slice_width
elif dtype is np.uint16: row = 16 // slice_width
elif dtype is np.uint32: row = 32 // slice_width
elif dtype is np.uint64: row = 64 // slice_width
else: raise ValueError("datatype currently not supported")
if (row >= 8):
dtype = 'uint' + str(row)
else:
dtype = 'uint8'
# 3d array (bits, row, clmn)
C = np.zeros((row, A.shape[0], A.shape[1])).astype(dtype) # sliced and transformed
# create mask
slice_mask = 2**(slice_width)-1
# slice and pack
for z in range(A.shape[0]):
C[:, z, :] = slice(A[z], slice_width)
return C
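A short usage sketch of the two slicing helpers, assuming slice and slice_mat from this script are in scope; the sample values are arbitrary.

```python
import numpy as np

# a uint8 vector sliced into 2-bit pieces -> 4 slices per element
a = np.array([0b11011000, 0b00100111], dtype=np.uint8)
a_sliced = slice(a, 2)              # shape (4, 2); a_sliced[y][x] = bits [2y+1:2y] of a[x]

# reconstructing each element from its slices recovers the original vector
recon = sum(a_sliced[y].astype(np.uint64) << np.uint64(2 * y) for y in range(4))
assert np.array_equal(recon, a.astype(np.uint64))

# matrix version: one (rows, cols) plane per bit slice
m = np.random.randint(0, 256, size=(3, 2), dtype=np.uint8)
m_sliced = slice_mat(m, 2)          # shape (4, 3, 2)
```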
""" Matrix Multiplication Function
Parameters
----------
A : Matrix A
B: Matrix B
w_width : weight slice width
a_width : activation slice width
i_width : input (activation) slice width
w_width : weight slice width
Returns
---------
C: result of A * B
"""
# A is a n*m matrix, B is a m*p matrix(not transposed yet)
def matrix_multiply(A, B, w_width, a_width):
def matrix_multiply(A, B, i_width, w_width):
assert A.shape[1] == B.shape[0], "can't perform multiplication"
BT = B.transpose()
cycles = 0
B_sliced = slice_mat(BT, w_width)
C = np.zeros((A.shape[0], B.shape[1])).astype('uint64')
for i in range(A.shape[0]):
for j in range(B.shape[1]):
# C[i, j] = A[i].dot(BT[j])
A_sliced = slice(A[i], w_width)
B_sliced = slice(BT[j], a_width)
C[i, j] = compute(A_sliced, B_sliced, w_width, a_width)
test = test_accel(A_sliced, B_sliced, w_width, a_width)
cycles += test[1]
np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j]))
print("PASS SW serial & parallel")
np.testing.assert_equal(test[0], C[i, j])
print("PASS SW & HW bit serial")
np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j]))
print("PASS SW bit parallel & HW bit parallel")
A_sliced = slice(A[i], i_width)
test = test_accel(A_sliced, B_sliced, i_width, w_width)
C[i] = test[0]
cycles += test[1]
np.testing.assert_array_equal(C[i], compute(A_sliced, B_sliced, i_width, w_width))
print("PASS row " + str(i))
np.testing.assert_array_equal(C, np.matmul(A.astype('uint64'),B))
print("result: ")
print(C)
print("ALL TESTS PASSED, cycles: " + str(cycles))
print("TEST PASSED, cycles: " + str(cycles))
return C
""" Software Verification Function"""
# takes 2 matrix input (sliced and packed)
def compute(A, B, w_width, a_width):
""" Software Verification Function
Parameter Dimensions
---------
A (bits, y) and B (bits, y, x) (transposed)
Takes 1 vector and 1 matrix input (sliced and packed)
Returns
---------
Resulting vector
"""
def compute(A, B, i_width, w_width):
assert A.shape[1] == B.shape[1], "sliced shape not match"
# reset hardware accumulator
accum = 0
accum = np.zeros(A.shape[1])
for x in range(A.shape[0]):
for y in range(B.shape[0]):
# hardware implementation
accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width)
accum += np.matmul(A[x].astype('uint64'), B[y].transpose()) << np.uint64(x*i_width + y*w_width)
# get value from accumulator
return accum
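The shift-and-add identity that compute relies on is easiest to check with scalars: split both operands into slices, multiply every pair of slices, and shift each product by the sum of the two slice offsets. A hand-arithmetic example with 2-bit slices of uint8 values, for illustration only:

```python
a, w, width = 6, 5, 2                                      # 6 = 0b0110, 5 = 0b0101
a_slices = [(a >> (width * x)) & 0b11 for x in range(4)]   # [2, 1, 0, 0]
w_slices = [(w >> (width * y)) & 0b11 for y in range(4)]   # [1, 1, 0, 0]
product = sum((ax * wy) << (width * x + width * y)
              for x, ax in enumerate(a_slices)
              for y, wy in enumerate(w_slices))
assert product == a * w == 30
# compute() does the same thing lane-wise, with vectors of input slices and
# matrices of weight slices in place of the scalars above.
```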
"""Testing Function for Dot Product"""
def test_accel(A, B, w_width, a_width):
assert A.shape[1] == B.shape[1], "sliced shape not match"
"""Testing Function for Matrix Vector Multiplication"""
def test_accel(A, B, i_width, w_width):
assert A.shape[1] == B.shape[2], "sliced shape not match"
dtype = A.dtype
ctx = tvm.cpu(0)
f = tsim.load_module()
......@@ -126,57 +150,54 @@ def test_accel(A, B, w_width, a_width):
a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx))
for i in range(B.shape[0]):
list_b = np.zeros(B.shape[1]).astype(dtype)
for j in range(B.shape[1]):
list_b[j] = B[i][j]
# transpose
list_b = np.zeros((B.shape[2], B.shape[1])).astype(dtype)
for j in range(B.shape[2]):
for k in range(B.shape[1]):
list_b[j][k] = B[i][j][k]
b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx))
cycles = 0
accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx)
accum = tvm.nd.array(np.zeros(A.shape[1]).astype("uint32"), ctx)
for i in range(len(a_arr)):
for j in range(len(b_arr)):
shift = np.uint8(i*w_width + j*a_width)
shift = np.uint8(i*i_width + j*w_width)
if i == 0 and j == 0:
cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator
cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(1)) # reset accumulator
else:
cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset
cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(0)) # no reset
return (accum.asnumpy()[0], cycles)
return (accum.asnumpy(), cycles)
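For the 'serial' configuration below (4-bit input slices, 2-bit weight slices), the loop above launches the accelerator once per slice pair, with the shift and the accumulator-reset flag scheduled as follows; the tuple layout is just for illustration.

```python
i_width, w_width = 4, 2            # 'serial' case: 2 input slices, 4 weight slices
launches = [(i, j, i * i_width + j * w_width, int(i == 0 and j == 0))
            for i in range(8 // i_width)     # input (vector) slices
            for j in range(8 // w_width)]    # weight (matrix) slices
# each tuple is (input_slice, weight_slice, shift, reset_accumulator)
assert launches[0] == (0, 0, 0, 1)           # first call clears the accumulator
assert launches[-1] == (1, 3, 10, 0)         # last call: shift = 1*4 + 3*2
```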
""" Matrix Generator
Parameters
----------
dtype : String, datatype generated (supports only uint)
w_width : weight bit slices(needs to be less than actual bit width)
a_width : activation bit slices(needs to be less than actual bit width)
i_width : input bit slice width (must not exceed the actual bit width)
w_width : weight bit slice width (must not exceed the actual bit width)
"""
def top_test(dtype, w_width, a_width):
rmax = np.random.randint(256)
# random matrix generation (dimension up to 8)
rrow = np.random.randint(7) + 1
rclmn = np.random.randint(7) + 1
rrow2 = np.random.randint(7) + 1
A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype)
B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype)
def top_test(dtype, i_width, w_width):
print("A: ")
print(A)
print("\n")
print("B: ")
print(B)
print("\n")
matrix_multiply(A, B, w_width, a_width)
# only supports positive values (up to 2**(bits-1))
rmax = 127
# (m,16) * (16,16) GEMM
rrow = np.random.randint(7) + 1
clmn = 16
A = np.random.randint(rmax, size=(rrow,clmn)).astype(dtype)
B = np.random.randint(rmax, size=(clmn,clmn)).astype(dtype)
print("A: " + str(A))
print("B: " + str(B))
# perform GEMM
matrix_multiply(A, B, i_width, w_width)
if __name__ == "__main__":
tsim.init("chisel")
for i in range(1):
# reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits
# reg1 and reg2 bits in hardware/chisel/src/main/Compute.scala must be modified for slices greater than 8 bits
if sys.argv[1] == 'serial':
# generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation
top_test("uint8",4, 2)
# generates a random uint8 GEMM with 4-bit input slices (8/4 = 2 slices) and 2-bit weight slices (8/2 = 4 slices)
top_test("uint8", 4, 2)
elif sys.argv[1] == 'parallel':
# generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel)
top_test('uint8', 1, 1)
# generates a random uint8 GEMM with 8-bit input and 8-bit weight (bit parallel)
top_test('uint8', 8, 8)