/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file vta.cc
 * \brief VTA HLS design.
 */

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "vta.h"

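/*!
 * \brief Fetch stage: reads instructions from DRAM, partially decodes the
 *  opcode, and dispatches each instruction to the load, gemm, or store queue.
 */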
void fetch(
  uint32_t insn_count,
  volatile insn_T *insns,
  hls::stream<insn_T> &load_queue,
  hls::stream<insn_T> &gemm_queue,
  hls::stream<insn_T> &store_queue) {
#pragma HLS INTERFACE s_axilite port = insn_count bundle = CONTROL_BUS
#pragma HLS INTERFACE m_axi port = insns offset = slave bundle = ins_port
#pragma HLS INTERFACE axis port = load_queue
#pragma HLS INTERFACE axis port = gemm_queue
#pragma HLS INTERFACE axis port = store_queue
#pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS

  INSN_DECODE: for (int pc = 0; pc < insn_count; pc++) {
#pragma HLS PIPELINE II = 1
    // Read instruction fields
    insn_T insn = insns[pc];
    // Do some partial decoding
    opcode_T opcode = insn.range(VTA_INSN_MEM_0_1, VTA_INSN_MEM_0_0);
    memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0);
    // Push to appropriate instruction queue
    if (opcode == VTA_OPCODE_STORE) {
      store_queue.write(insn);
    } else if (opcode == VTA_OPCODE_LOAD &&
          (memory_type == VTA_MEM_ID_INP || memory_type == VTA_MEM_ID_WGT)) {
      load_queue.write(insn);
    } else {
      gemm_queue.write(insn);
    }
  }
}

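/*!
 * \brief Load stage: pops a load instruction, copies 2D strided input or
 *  weight tiles from DRAM into the on-chip SRAMs, and zero-fills the padded
 *  edges. Synchronizes with the compute stage via the g2l/l2g token queues.
 */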
void load(
  volatile inp_vec_T *inputs,
  volatile wgt_vec_T *weights,
  hls::stream<insn_T> &load_queue,
  hls::stream<bool> &g2l_dep_queue,
  hls::stream<bool> &l2g_dep_queue,
  inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
  wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT]
  ) {
#pragma HLS INTERFACE m_axi port = weights offset = slave bundle = data_port
#pragma HLS INTERFACE m_axi port = inputs offset = slave bundle = data_port
#pragma HLS INTERFACE axis port = load_queue
#pragma HLS INTERFACE axis port = g2l_dep_queue
#pragma HLS INTERFACE axis port = l2g_dep_queue
#pragma HLS INTERFACE bram port = wgt_mem
#pragma HLS INTERFACE bram port = inp_mem
#pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS

  // Pop load instruction
  insn_T insn = load_queue.read();

  // Decode instruction
  bool pop_prev_dependence = insn[VTA_INSN_MEM_1];
  bool pop_next_dependence = insn[VTA_INSN_MEM_2];
  bool push_prev_dependence = insn[VTA_INSN_MEM_3];
  bool push_next_dependence = insn[VTA_INSN_MEM_4];
  memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0);
  memop_sram_T sram_base = insn.range(VTA_INSN_MEM_6_1, VTA_INSN_MEM_6_0);
  memop_dram_T dram_base = insn.range(VTA_INSN_MEM_7_1, VTA_INSN_MEM_7_0);
  memop_size_T y_size = insn.range(VTA_INSN_MEM_8_1, VTA_INSN_MEM_8_0);
  memop_size_T x_size = insn.range(VTA_INSN_MEM_9_1, VTA_INSN_MEM_9_0);
  memop_stride_T x_stride = insn.range(VTA_INSN_MEM_A_1, VTA_INSN_MEM_A_0);
  memop_pad_T y_pad_0 = insn.range(VTA_INSN_MEM_B_1, VTA_INSN_MEM_B_0);
  memop_pad_T y_pad_1 = insn.range(VTA_INSN_MEM_C_1, VTA_INSN_MEM_C_0);
  memop_pad_T x_pad_0 = insn.range(VTA_INSN_MEM_D_1, VTA_INSN_MEM_D_0);
  memop_pad_T x_pad_1 = insn.range(VTA_INSN_MEM_E_1, VTA_INSN_MEM_E_0);

  // Pop dependence token if instructed
  if (pop_next_dependence) {
    g2l_dep_queue.read();
  }

  // Initialize indices
  memop_sram_T sram_idx = sram_base;
  memop_dram_T dram_idx = dram_base;

  // Pre-compute dimensions and offsets
  memop_size_T y_size_total = y_pad_0 + y_size + y_pad_1;
  memop_size_T x_size_total = x_pad_0 + x_size + x_pad_1;
  memop_sram_T y_offset = x_size_total * y_pad_0;
// Force this computation to be done with LUTs to avoid using too many DSPs
#pragma HLS RESOURCE variable = y_offset core = Mul_LUT

  // Skip padding along y dimension
  sram_idx += y_offset;

  // Perform data transfer from DRAM
  for (int y = 0; y < y_size; y++) {
#pragma HLS PIPELINE rewind
    // Skip padding along x dimension
    sram_idx += x_pad_0;
    // Perform data transfer
    if (memory_type == VTA_MEM_ID_INP) {
      memcpy(&inp_mem[sram_idx][0],
             (const inp_vec_T*) &inputs[dram_idx * VTA_BATCH],
             x_size * VTA_INP_ELEM_BYTES);
    } else {
      memcpy(&wgt_mem[sram_idx][0],
             (const wgt_vec_T*) &weights[dram_idx * VTA_BLOCK_OUT],
             x_size * VTA_WGT_ELEM_BYTES);
    }
    sram_idx += x_size;
    dram_idx += x_stride;
    // Skip padding along x dimension
    sram_idx += x_pad_1;
  }

  // Reset SRAM index
  sram_idx = sram_base;
  // Pad x/y edges with zeros
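  // Second pass over the padded region: zero-fill whole rows that fall in
  // the y padding, and the x_pad_0/x_pad_1 columns of every data row.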
  for (int y = 0; y < y_size_total; y++) {
    if (y < y_pad_0 || y >= y_pad_0 + y_size) {
      for (int x = 0; x < x_size_total; x++) {
#pragma HLS PIPELINE II = 1 rewind
        if (memory_type == VTA_MEM_ID_INP) {
          for (int i = 0; i < VTA_BATCH; i++) {
            inp_mem[sram_idx][i] = 0;
          }
        } else {
          for (int i = 0; i < VTA_BLOCK_OUT; i++) {
            wgt_mem[sram_idx][i] = 0;
          }
        }
        sram_idx++;
      }
    } else {
      for (int x = 0; x < x_pad_0; x++) {
#pragma HLS PIPELINE II = 1 rewind
        if (memory_type == VTA_MEM_ID_INP) {
          for (int i = 0; i < VTA_BATCH; i++) {
            inp_mem[sram_idx][i] = 0;
          }
        } else {
          for (int i = 0; i < VTA_BLOCK_OUT; i++) {
            wgt_mem[sram_idx][i] = 0;
          }
        }
        sram_idx++;
      }
      sram_idx += x_size;
      for (int x = 0; x < x_pad_1; x++) {
#pragma HLS PIPELINE II = 1 rewind
        if (memory_type == VTA_MEM_ID_INP) {
          for (int i = 0; i < VTA_BATCH; i++) {
            inp_mem[sram_idx][i] = 0;
          }
        } else {
          for (int i = 0; i < VTA_BLOCK_OUT; i++) {
            wgt_mem[sram_idx][i] = 0;
          }
        }
        sram_idx++;
      }
    }
  }

  // Push dependence token if instructed
  if (push_next_dependence) {
    l2g_dep_queue.write(1);
  }
}

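/*!
 * \brief Compute stage: pops a gemm-queue instruction and either loads
 *  micro-ops or biases, runs the micro-coded GEMM core, runs a vector ALU
 *  operation (min/max/add/shift-right), or raises the done flag on FINISH.
 *  Synchronizes with the load and store stages via the token queues.
 */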
void compute(
  volatile uint32_t &done,
  volatile uop_T *uops,
  volatile acc_vec_T *biases,
  hls::stream<insn_T> &gemm_queue,
  hls::stream<bool> &l2g_dep_queue,
  hls::stream<bool> &s2g_dep_queue,
  hls::stream<bool> &g2l_dep_queue,
  hls::stream<bool> &g2s_dep_queue,
  inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
  wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT],
  out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]
  ) {
#pragma HLS INTERFACE s_axilite port = done bundle = CONTROL_BUS
#pragma HLS INTERFACE m_axi port = uops offset = slave bundle = uop_port
#pragma HLS INTERFACE m_axi port = biases offset = slave bundle = data_port
#pragma HLS INTERFACE axis port = gemm_queue
#pragma HLS INTERFACE axis port = l2g_dep_queue
#pragma HLS INTERFACE axis port = s2g_dep_queue
#pragma HLS INTERFACE axis port = g2l_dep_queue
#pragma HLS INTERFACE axis port = g2s_dep_queue
#pragma HLS INTERFACE bram port = inp_mem
#pragma HLS INTERFACE bram port = wgt_mem
#pragma HLS INTERFACE bram port = out_mem
#pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS
// This is necessary to connect the SRAM to the load module
#pragma HLS RESOURCE variable = wgt_mem core = RAM_1P

  // Micro-op storage
  static uop_T uop_mem[VTA_UOP_BUFF_DEPTH];

  // Accumulator storage
  static acc_vec_T acc_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH];
#pragma HLS ARRAY_PARTITION variable = acc_mem complete dim = 2

  // Pop GEMM instruction
  insn_T insn = gemm_queue.read();

  // Decode
  opcode_T opcode = insn.range(VTA_INSN_MEM_0_1, VTA_INSN_MEM_0_0);
  bool pop_prev_dependence = insn[VTA_INSN_MEM_1];
  bool pop_next_dependence = insn[VTA_INSN_MEM_2];
  bool push_prev_dependence = insn[VTA_INSN_MEM_3];
  bool push_next_dependence = insn[VTA_INSN_MEM_4];

  // Pop dependence token if instructed
  if (pop_prev_dependence) {
    l2g_dep_queue.read();
  }
  if (pop_next_dependence) {
    s2g_dep_queue.read();
  }

  // Perform action based on opcode
  if (opcode == VTA_OPCODE_FINISH) {
    // Set done flag if we reach a FINISH instruction
    done = 1;
  } else if (opcode == VTA_OPCODE_LOAD || opcode == VTA_OPCODE_STORE) {
    // Set done value
    done = 0;

    // Decode instruction
    memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0);
    memop_sram_T sram_base = insn.range(VTA_INSN_MEM_6_1, VTA_INSN_MEM_6_0);
    memop_dram_T dram_base = insn.range(VTA_INSN_MEM_7_1, VTA_INSN_MEM_7_0);
    memop_size_T y_size = insn.range(VTA_INSN_MEM_8_1, VTA_INSN_MEM_8_0);
    memop_size_T x_size = insn.range(VTA_INSN_MEM_9_1, VTA_INSN_MEM_9_0);
    memop_stride_T x_stride = insn.range(VTA_INSN_MEM_A_1, VTA_INSN_MEM_A_0);
    memop_pad_T y_pad_0 = insn.range(VTA_INSN_MEM_B_1, VTA_INSN_MEM_B_0);
    memop_pad_T y_pad_1 = insn.range(VTA_INSN_MEM_C_1, VTA_INSN_MEM_C_0);
    memop_pad_T x_pad_0 = insn.range(VTA_INSN_MEM_D_1, VTA_INSN_MEM_D_0);
    memop_pad_T x_pad_1 = insn.range(VTA_INSN_MEM_E_1, VTA_INSN_MEM_E_0);

    // Initialize indices
    memop_sram_T sram_idx = sram_base;
    memop_dram_T dram_idx = dram_base;

    // Pre-compute dimensions and offsets
    memop_size_T y_size_total = y_pad_0 + y_size + y_pad_1;
    memop_size_T x_size_total = x_pad_0 + x_size + x_pad_1;
    memop_sram_T y_offset = x_size_total * y_pad_0;
// Force this computation to be done with LUTs to avoid using too many DSPs
#pragma HLS RESOURCE variable = y_offset core = Mul_LUT

    if (memory_type == VTA_MEM_ID_UOP) {
      // Perform data transfer
      memcpy(&uop_mem[sram_base],
             (const uop_T*) &uops[dram_base],
             x_size * sizeof(uop_T));
    } else {
      // Skip vertical padding
      sram_idx += y_offset;
      // Perform data transfer from DRAM
      for (int y = 0; y < y_size; y++) {
#pragma HLS PIPELINE rewind
        // Skip padding along x dimension
        sram_idx += x_pad_0;
        // Perform data transfer
        memcpy(&acc_mem[sram_idx][0],
               (const acc_vec_T*) &biases[dram_idx * VTA_BATCH],
               x_size*VTA_ACC_ELEM_BYTES);
        sram_idx += x_size;
        dram_idx += x_stride;
        // Skip padding along x dimension
        sram_idx += x_pad_1;
      }
    }
  } else if (opcode == VTA_OPCODE_GEMM || opcode == VTA_OPCODE_ALU) {
    // Set done value
    done = 0;

    // Decode
    bool reset_out = insn[VTA_INSN_GEM_5];
    uop_idx_T uop_bgn = insn.range(VTA_INSN_GEM_6_1, VTA_INSN_GEM_6_0);
    uop_idx_T uop_end = insn.range(VTA_INSN_GEM_7_1, VTA_INSN_GEM_7_0);
    loop_T iter_out = insn.range(VTA_INSN_GEM_8_1, VTA_INSN_GEM_8_0);
    loop_T iter_in = insn.range(VTA_INSN_GEM_9_1, VTA_INSN_GEM_9_0);
    acc_idx_T dst_factor_out = insn.range(VTA_INSN_GEM_A_1, VTA_INSN_GEM_A_0);
    acc_idx_T dst_factor_in = insn.range(VTA_INSN_GEM_B_1, VTA_INSN_GEM_B_0);
    inp_idx_T src_factor_out = insn.range(VTA_INSN_GEM_C_1, VTA_INSN_GEM_C_0);
    inp_idx_T src_factor_in = insn.range(VTA_INSN_GEM_D_1, VTA_INSN_GEM_D_0);

    // GEMM-specific fields
    wgt_idx_T wgt_factor_out = insn.range(VTA_INSN_GEM_E_1, VTA_INSN_GEM_E_0);
    wgt_idx_T wgt_factor_in = insn.range(VTA_INSN_GEM_F_1, VTA_INSN_GEM_F_0);

    // ALU-specific fields
    aluop_opcode_T alu_opcode = insn.range(VTA_INSN_ALU_E_1, VTA_INSN_ALU_E_0);
    bool use_imm = insn[VTA_INSN_ALU_F];
    aluop_imm_T imm = insn.range(VTA_INSN_ALU_G_1, VTA_INSN_ALU_G_0);
    acc_idx_T dst_offset_out = 0;
    inp_idx_T src_offset_out = 0;
    wgt_idx_T wgt_offset_out = 0;

    // Outer Loop
    EXE_OUT_LOOP: for (int it_out = 0; it_out < iter_out; it_out++) {
#pragma HLS DEPENDENCE variable = acc_mem inter false
      acc_idx_T dst_offset_in = dst_offset_out;
      inp_idx_T src_offset_in = src_offset_out;
      wgt_idx_T wgt_offset_in = wgt_offset_out;

      // Inner Loop
      EXE_IN_LOOP: for (int it_in = 0; it_in < iter_in; it_in++) {
        // Perform appropriate computation based on opcode
        if (opcode == VTA_OPCODE_GEMM) {
          // Iterate over micro op
          READ_GEMM_UOP: for (int upc = uop_bgn; upc < uop_end; upc++) {
#pragma HLS PIPELINE II = 1 rewind

            // Read micro-op fields
            uop_T uop = uop_mem[upc];

            // Decode indices
            acc_idx_T dst_idx =
                uop.range(VTA_UOP_GEM_0_1, VTA_UOP_GEM_0_0) + dst_offset_in;
            inp_idx_T src_idx =
                uop.range(VTA_UOP_GEM_1_1, VTA_UOP_GEM_1_0) + src_offset_in;
            wgt_idx_T wgt_idx =
                uop.range(VTA_UOP_GEM_2_1, VTA_UOP_GEM_2_0) + wgt_offset_in;

            // Read weight matrix
            wgt_vec_T w_matrix[VTA_BLOCK_OUT];
            for (int i = 0; i < VTA_BLOCK_OUT; i++) {
              w_matrix[i] = wgt_mem[wgt_idx][i];
            }
            // Read input matrix and accum matrix
            acc_vec_T o_matrix[VTA_BATCH];
            inp_vec_T i_matrix[VTA_BATCH];
            for (int i = 0; i < VTA_BATCH; i++) {
              o_matrix[i] = acc_mem[dst_idx][i];
              i_matrix[i] = inp_mem[src_idx][i];
            }
            // Result matrices
            acc_vec_T acc_mem_val[VTA_BATCH];
            out_vec_T st_buf_val[VTA_BATCH];

            // Inner GEMM loop
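            // Each (batch, block_out) lane accumulates a BLOCK_IN-wide dot
            // product: acc[dst][i][b] += sum_k i_matrix[i][k] * w_matrix[b][k]
            // (or is reset to zero when reset_out is set).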
            for (int i = 0; i < VTA_BATCH; i++) {
              for (int b = 0; b < VTA_BLOCK_OUT; b++) {
                // Initialize the accumulator values
                acc_T accum =
                  o_matrix[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
                // Dot product sum
                sum_T tmp = 0;
                // Inner matrix multiplication loop (input channel/feature)
                for (int k = 0; k < VTA_BLOCK_IN; k++) {
                  wgt_T w_elem =
                      w_matrix[b].range((k + 1) * VTA_WGT_WIDTH - 1, k * VTA_WGT_WIDTH);
                  inp_T i_elem =
                      i_matrix[i].range((k + 1) * VTA_INP_WIDTH - 1, k * VTA_INP_WIDTH);
                  mul_T prod = i_elem * w_elem;
#ifdef NO_DSP
#pragma HLS RESOURCE variable = prod core = Mul_LUT
#endif //  NO_DSP
                  tmp += (sum_T) prod;
                }
                // Update summation
                accum += (acc_T) tmp;
                // Update result vector
                acc_mem_val[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) =
                    reset_out ? (acc_T) 0 : accum;
                st_buf_val[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
                    (out_T) accum.range(VTA_OUT_WIDTH - 1, 0);
              }
              // Write to buffers
              acc_mem[dst_idx][i] = acc_mem_val[i];
              out_mem[dst_idx][i] = st_buf_val[i];
            }
          }
        }
#ifndef NO_ALU
        else if (opcode == VTA_OPCODE_ALU) {
          // Iterate over micro op
          READ_ALU_UOP: for (int upc = uop_bgn; upc < uop_end; upc++) {
            // Read micro-op fields
            uop_T uop = uop_mem[upc];

            // Decode
            acc_idx_T dst_idx =
                uop.range(VTA_UOP_ALU_0_1, VTA_UOP_ALU_0_0) + dst_offset_in;
            acc_idx_T src_idx =
                uop.range(VTA_UOP_ALU_1_1, VTA_UOP_ALU_1_0) + src_offset_in;

            // Perform ALU op over matrix elements
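            // All three candidate results (min/max, add, shift-right) are
            // computed for every lane; the ALU opcode then selects which
            // result vectors are committed to the SRAMs below.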
            for (int i = 0; i < VTA_BATCH; i++) {
              // Read input matrix and accum matrix
              acc_vec_T dst_vector = acc_mem[dst_idx][i];
              acc_vec_T src_vector = acc_mem[src_idx][i];
              // Result matrices
              acc_vec_T cmp_res;
              acc_vec_T add_res;
              acc_vec_T shr_res;
              out_vec_T short_cmp_res;
              out_vec_T short_add_res;
              out_vec_T short_shr_res;
              // Results vector
              acc_vec_T res_vec = 0;
              for (int b = 0; b < VTA_BLOCK_OUT; b++) {
#pragma HLS PIPELINE II = 1 rewind
                // Read in operands
                acc_T src_0 = dst_vector.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
                acc_T src_1 = use_imm ?
                    (acc_T) imm :
                    src_vector.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
                // Compute Min/Max
                acc_T mix_val = src_0 < src_1 ?
                    (alu_opcode == VTA_ALU_OPCODE_MIN ? src_0 : src_1) :
                    (alu_opcode == VTA_ALU_OPCODE_MIN ? src_1 : src_0);
                cmp_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = mix_val;
                short_cmp_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
                    (out_T) mix_val.range(VTA_OUT_WIDTH - 1, 0);
                // Compute Sum
                acc_T add_val =
                    src_0.range(VTA_ACC_WIDTH - 1, 0) + src_1.range(VTA_ACC_WIDTH - 1, 0);
                add_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = add_val;
                short_add_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
                    (out_T) add_val.range(VTA_OUT_WIDTH - 1, 0);
                // Compute Shift Right
                acc_T shr_val =
                    src_0 >> (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0);
                shr_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = shr_val;
                short_shr_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
                    (out_T) shr_val.range(VTA_OUT_WIDTH-1, 0);
              }

              // Store to accum memory/store buffer
              if (alu_opcode == VTA_ALU_OPCODE_MIN ||
                  alu_opcode == VTA_ALU_OPCODE_MAX) {
                acc_mem[dst_idx][i] = cmp_res;
                out_mem[dst_idx][i] = short_cmp_res;
              } else if (alu_opcode == VTA_ALU_OPCODE_ADD) {
                acc_mem[dst_idx][i] = add_res;
                out_mem[dst_idx][i] = short_add_res;
              } else if (alu_opcode == VTA_ALU_OPCODE_SHR) {
                acc_mem[dst_idx][i] = shr_res;
                out_mem[dst_idx][i] = short_shr_res;
              }
            }
          }
        }
#endif  // NO_ALU

        // Update offsets
        dst_offset_in += dst_factor_in;
        src_offset_in += src_factor_in;
        wgt_offset_in += wgt_factor_in;
      }

      // Update offsets
      dst_offset_out += dst_factor_out;
      src_offset_out += src_factor_out;
      wgt_offset_out += wgt_factor_out;
    }
  }

  // Push dependence token if instructed
  if (push_prev_dependence) {
    g2l_dep_queue.write(1);
  }
  if (push_next_dependence) {
    g2s_dep_queue.write(1);
  }
}

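/*!
 * \brief Store stage: pops a store instruction and copies 2D strided output
 *  tiles from the on-chip output SRAM back to DRAM, synchronizing with the
 *  compute stage via the g2s/s2g token queues.
 */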
void store(
  volatile out_vec_T *outputs,
  hls::stream<insn_T> &store_queue,
  hls::stream<bool> &g2s_dep_queue,
  hls::stream<bool> &s2g_dep_queue,
  out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]
  ) {
#pragma HLS INTERFACE m_axi port = outputs offset = slave bundle = data_port
#pragma HLS INTERFACE axis port = store_queue
#pragma HLS INTERFACE axis port = g2s_dep_queue
#pragma HLS INTERFACE axis port = s2g_dep_queue
#pragma HLS INTERFACE bram port = out_mem
#pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS

  // Pop store instruction
  insn_T insn = store_queue.read();

  // Decode
  bool pop_prev_dependence = insn[VTA_INSN_MEM_1];
  bool pop_next_dependence = insn[VTA_INSN_MEM_2];
  bool push_prev_dependence = insn[VTA_INSN_MEM_3];
  bool push_next_dependence = insn[VTA_INSN_MEM_4];
  memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0);
  memop_sram_T sram_base = insn.range(VTA_INSN_MEM_6_1, VTA_INSN_MEM_6_0);
  memop_dram_T dram_base = insn.range(VTA_INSN_MEM_7_1, VTA_INSN_MEM_7_0);
  memop_size_T y_size = insn.range(VTA_INSN_MEM_8_1, VTA_INSN_MEM_8_0);
  memop_size_T x_size = insn.range(VTA_INSN_MEM_9_1, VTA_INSN_MEM_9_0);
  memop_stride_T x_stride = insn.range(VTA_INSN_MEM_A_1, VTA_INSN_MEM_A_0);
  memop_pad_T y_pad_0 = insn.range(VTA_INSN_MEM_B_1, VTA_INSN_MEM_B_0);
  memop_pad_T y_pad_1 = insn.range(VTA_INSN_MEM_C_1, VTA_INSN_MEM_C_0);
  memop_pad_T x_pad_0 = insn.range(VTA_INSN_MEM_D_1, VTA_INSN_MEM_D_0);
  memop_pad_T x_pad_1 = insn.range(VTA_INSN_MEM_E_1, VTA_INSN_MEM_E_0);

  // Pop dependence token if instructed
  if (pop_prev_dependence) {
    g2s_dep_queue.read();
  }

  // Initialize indices
  memop_sram_T sram_idx = sram_base;
  memop_dram_T dram_idx = dram_base;

  // Skip padding along y dimension
  memop_sram_T y_offset = (x_pad_0 + x_size + x_pad_1) * y_pad_0;
  sram_idx += y_offset;
// Force this computation to be done with LUTs to avoid using too many DSPs
#pragma HLS RESOURCE variable = y_offset core = Mul_LUT

  // Copy along y dimension
  for (int y = 0; y < y_size; y++) {
#pragma HLS PIPELINE rewind
    // Skip padding along x dimension
    sram_idx += x_pad_0;
    // Perform data transfer
    memcpy(
      const_cast<out_vec_T*>(&outputs[dram_idx * VTA_BATCH]),
      (const out_vec_T*) &out_mem[sram_idx][0],
      x_size * VTA_OUT_ELEM_BYTES);
    sram_idx += x_size;
    dram_idx += x_stride;
    // Skip padding along x dimension
    sram_idx += x_pad_1;
  }

  // Push dependence token if instructed
  if (push_prev_dependence) {
    s2g_dep_queue.write(1);
  }
}

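/*!
 * \brief VTA top-level module: fetches all instructions up front, then
 *  emulates the concurrent load/compute/store stages with a sequential
 *  control loop. An instruction is only forwarded to its stage once every
 *  dependence token it pops is available, preserving the same RAW/WAR
 *  ordering a pipelined hardware implementation would enforce.
 */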
void vta(
  uint32_t insn_count,
  volatile insn_T *insns,
  volatile uop_T *uops,
  volatile inp_vec_T *inputs,
  volatile wgt_vec_T *weights,
  volatile acc_vec_T *biases,
  volatile out_vec_T *outputs) {
#pragma HLS INTERFACE s_axilite port = insn_count bundle = CONTROL_BUS
#pragma HLS INTERFACE m_axi port = insns offset = slave bundle = ins_port
#pragma HLS INTERFACE m_axi port = uops offset = slave bundle = uop_port
#pragma HLS INTERFACE m_axi port = inputs offset = slave bundle = data_port
#pragma HLS INTERFACE m_axi port = weights offset = slave bundle = data_port
#pragma HLS INTERFACE m_axi port = biases offset = slave bundle = data_port
#pragma HLS INTERFACE m_axi port = outputs offset = slave bundle = data_port
#pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS

  // Instantiate temporary instruction queues (used for peeking)
  hls::stream<insn_T> tmp_load_queue;
  hls::stream<insn_T> tmp_gemm_queue;
  hls::stream<insn_T> tmp_store_queue;

  // Instantiate physical instruction queues
  hls::stream<insn_T> load_queue;
  hls::stream<insn_T> gemm_queue;
  hls::stream<insn_T> store_queue;

  // Dependence queues
  hls::stream<bool> l2g_dep_queue;
  hls::stream<bool> s2g_dep_queue;
  hls::stream<bool> g2l_dep_queue;
  hls::stream<bool> g2s_dep_queue;
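  // Each dependence queue carries one-token handshakes between adjacent
  // stages: l2g/g2s signal that produced data is ready (RAW), while
  // g2l/s2g signal that a consumed buffer may be overwritten (WAR).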

  // Instantiate memories
  inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH];
  wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT];
  out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH];

  // Push all instructions into the queues
  fetch(insn_count, insns, tmp_load_queue, tmp_gemm_queue, tmp_store_queue);
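
  // Because HLS simulation runs the stages as plain function calls, the loop
  // below acts as a scheduler: it peeks at each pending instruction's
  // dependence flags and dispatches a stage only once the tokens it needs
  // are available, mimicking the stalls of the real pipeline.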

  // Global done indicator
  uint32_t done = 0;

  // Temporary instructions
  insn_T tmp_load;
  insn_T tmp_gemv;
  insn_T tmp_store;

  // Peeking status
  bool tmp_load_popped = false;
  bool tmp_gemm_popped = false;
  bool tmp_store_popped = false;
  int exit_counter = 0;

  // Main control loop
  while (true) {
    // First execute as many load instructions as possible
    while (!tmp_load_queue.empty() || tmp_load_popped == true) {
      // Pop the load instruction
      if (!tmp_load_popped) {
        tmp_load_queue.read(tmp_load);
        tmp_load_popped = true;
      }
      // Check dependences and invoke the load stage
      bool pop_next_dependence = tmp_load[VTA_INSN_MEM_2];
      if ((pop_next_dependence && !g2l_dep_queue.empty()) ||
          !pop_next_dependence) {
        // Push the instruction into the load queue
        load_queue.write(tmp_load);
        tmp_load_popped = false;
        load(inputs, weights, load_queue, g2l_dep_queue, l2g_dep_queue, inp_mem, wgt_mem);
      } else {
        // Execution of load stage pending on completion of other stages, so break here...
        break;
      }
    }
    // Next execute as many gemm instructions as possible
    while (!tmp_gemm_queue.empty() || tmp_gemm_popped == true) {
      // Pop the gemm instruction
      if (!tmp_gemm_popped) {
        tmp_gemm_queue.read(tmp_gemv);
        tmp_gemm_popped = true;
      }
      // Check dependences and invoke the gemm stage
      bool pop_prev_dependence = tmp_gemv[VTA_INSN_MEM_1];
      bool pop_next_dependence = tmp_gemv[VTA_INSN_MEM_2];
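      // The gemm stage may only run once every token it intends to pop
      // (from the load and/or store stages) is available.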
      if (
        (pop_prev_dependence && !l2g_dep_queue.empty() &&
         pop_next_dependence && !s2g_dep_queue.empty()) ||
        (!pop_prev_dependence && pop_next_dependence &&
         !s2g_dep_queue.empty()) ||
        (pop_prev_dependence && !l2g_dep_queue.empty() &&
        !pop_next_dependence) ||
        (!pop_prev_dependence && !pop_next_dependence)
      ) {
        // Push the instruction into the gemm queue
        gemm_queue.write(tmp_gemv);
        tmp_gemm_popped = false;
        compute(done, uops, biases, gemm_queue, l2g_dep_queue, s2g_dep_queue,
                g2l_dep_queue, g2s_dep_queue, inp_mem, wgt_mem, out_mem);
      } else {
        // Execution of gemm stage pending on completion of other stages,
        // so break here...
        break;
      }
    }
    // Finally execute as many store instructions as possible
    while (!tmp_store_queue.empty() || tmp_store_popped == true) {
      // Pop the store instruction
      if (!tmp_store_popped) {
        tmp_store_queue.read(tmp_store);
        tmp_store_popped = true;
      }
      // Check dependences and invoke the store stage
      bool pop_prev_dependence = tmp_store[VTA_INSN_MEM_1];
      if ((pop_prev_dependence && !g2s_dep_queue.empty()) ||
          !pop_prev_dependence) {
        // Push the instruction into the store queue
        store_queue.write(tmp_store);
        tmp_store_popped = false;
        store(outputs, store_queue, g2s_dep_queue, s2g_dep_queue, out_mem);
      } else {
        // Execution of store stage pending on completion of other stages, so break here...
        break;
      }
    }
    // Check if we get a signal that we are done
    if (done) {
      break;
    }
    exit_counter++;
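    // Watchdog: if no progress is made for many iterations, report which
    // dependence token each blocked stage is waiting on and exit.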
    if (exit_counter > 1000) {
      if (tmp_load_popped) {
        if (g2l_dep_queue.empty()) {
          printf("waiting on g2l\n");
        }
      }
      if (tmp_gemm_popped) {
        if (l2g_dep_queue.empty() && tmp_gemv[VTA_INSN_MEM_1]) {
          printf("waiting on l2g\n");
        }
        if (s2g_dep_queue.empty() && tmp_gemv[VTA_INSN_MEM_2]) {
          printf("waiting on s2g\n");
        }
      }
      if (tmp_store_popped) {
        if (g2s_dep_queue.empty()) {
          printf("waiting on g2s\n");
        }
      }
      break;
    }
  }

  // Ensure that no dependence tokens are left behind
  bool tmp_tok;
  int l2g_count = 0;
  int s2g_count = 0;
  int g2l_count = 0;
  int g2s_count = 0;
  while (l2g_dep_queue.read_nb(tmp_tok)) {
    l2g_count++;
  }
  while (s2g_dep_queue.read_nb(tmp_tok)) {
    s2g_count++;
  }
  while (g2l_dep_queue.read_nb(tmp_tok)) {
    g2l_count++;
  }
  while (g2s_dep_queue.read_nb(tmp_tok)) {
    g2s_count++;
  }

  assert(l2g_count == 0 && s2g_count == 0 && g2l_count == 0 && g2s_count == 0);
}