/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
 
package unittest

import chisel3._
import chisel3.util._
import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}
import scala.math.pow
import unittest.util._
import vta.core._

/** Randomized PeekPokeTester for `MatrixVectorMultiplication`.
  *
  * Each trial pokes a random input vector (`io.inp`), a random weight matrix
  * (`io.wgt`) and an all-zero accumulator (`io.acc_i`) into the DUT. It raises
  * the three input valid signals for a single clock cycle and then waits
  * (bounded) for `io.acc_o.data.valid`. Every output lane is then compared
  * against the software reference `mvmRef`.
  */
class TestMatrixVectorMultiplication(c: MatrixVectorMultiplication) extends PeekPokeTester(c) {

  /** Software reference for the hardware matrix-vector product.
    *
    * Computes `res(i) = (sum over j of wgt(i)(j) * inp(j)) << shift` for every
    * `i` in `0 until inp.length`, using wrapping 32-bit `Int` arithmetic.
    *
    * @param inp   input vector; its length sets both the number of result rows
    *              and the length of each dot product
    * @param wgt   weight matrix indexed as `wgt(row)(col)`; must have at least
    *              `inp.length` rows, each with at least `inp.length` columns
    * @param shift left shift applied to each dot product (a multiply by
    *              `2^shift`); must be in `[0, 31]`
    * @return one result per row, `inp.length` entries in total
    * @throws IllegalArgumentException if `shift` is outside `[0, 31]`
    */
  def mvmRef(inp: Array[Int], wgt: Array[Array[Int]], shift: Int): Array[Int] = {
    // Int shifts are taken mod 32 on the JVM, so out-of-range shifts would
    // silently produce a different scaling; reject them instead.
    require(shift >= 0 && shift < 32, s"shift must be in [0, 31], got $shift")
    val size = inp.length
    Array.tabulate(size) { row =>
      val dot = (0 until size).foldLeft(0)((acc, col) => acc + wgt(row)(col) * inp(col))
      // An integer shift is exact for every allowed shift; a floating-point
      // power of two converted with .toInt saturates at 2^31.
      dot << shift
    }
  }

  /** Number of independent random trials to run. */
  val cycles = 5

  /** Maximum number of clock cycles to wait for `acc_o.data.valid` in a trial
    * before reporting a failure, so a hung DUT cannot stall the test forever.
    */
  private val maxWaitCycles = 1000

  for (trial <- 0 until cycles) {
    // Generate random operands sized to the DUT's configured bit widths.
    val inpGen = new RandomArray(c.size, c.inpBits)
    val wgtGen = new RandomArray(c.size, c.wgtBits)
    val in_a = inpGen.any
    val in_b = Array.fill(c.size) { wgtGen.any }
    val res = mvmRef(in_a, in_b, 0)

    // NOTE(review): helper.getMask(n) presumably yields a mask of the low n
    // bits; it is used to fit stimulus and expected values to each port width.
    val inpMask = helper.getMask(c.inpBits)
    val wgtMask = helper.getMask(c.wgtBits)
    val accMask = helper.getMask(c.accBits)

    for (i <- 0 until c.size) {
      poke(c.io.inp.data.bits(0)(i), in_a(i) & inpMask)
      poke(c.io.acc_i.data.bits(0)(i), 0)
      for (j <- 0 until c.size) {
        poke(c.io.wgt.data.bits(i)(j), in_b(i)(j) & wgtMask)
      }
    }

    poke(c.io.reset, 0)

    // Present all three operands for exactly one clock cycle.
    poke(c.io.inp.data.valid, 1)
    poke(c.io.wgt.data.valid, 1)
    poke(c.io.acc_i.data.valid, 1)

    step(1)

    poke(c.io.inp.data.valid, 0)
    poke(c.io.wgt.data.valid, 0)
    poke(c.io.acc_i.data.valid, 0)

    // Wait for the result, but give up after maxWaitCycles clock cycles.
    var waited = 0
    while (peek(c.io.acc_o.data.valid) == BigInt(0) && waited < maxWaitCycles) {
      step(1)
      waited += 1
    }

    if (peek(c.io.acc_o.data.valid) == BigInt(1)) {
      for (i <- 0 until c.size) {
        expect(c.io.acc_o.data.bits(0)(i), res(i) & accMask)
      }
    } else {
      // Report the timeout as a test failure rather than silently skipping the check.
      expect(false, s"trial $trial: acc_o.data.valid not asserted within $maxWaitCycles cycles")
    }
  }
}