/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file runtime.h
 * \brief VTA runtime library.
 */

25 26
#ifndef VTA_RUNTIME_RUNTIME_H_
#define VTA_RUNTIME_RUNTIME_H_
27 28 29 30 31

#ifdef __cplusplus
extern "C" {
#endif

32
#include <tvm/runtime/c_runtime_api.h>
33
#include <vta/driver.h>
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49

/*!
 * \brief Memory copy kinds.
 * NOTE(review): presumably the values accepted by the kind_mask argument of
 * VTABufferCopy, with H and D standing for host and device -- confirm
 * against the implementation.
 */
#define VTA_MEMCPY_H2D 1
#define VTA_MEMCPY_D2H 2
#define VTA_MEMCPY_D2D 3

/*!
 * \brief Debug flags. Each flag occupies its own bit, so flags can be OR-ed
 * together; bit 0 is not assigned by any flag in this list.
 * NOTE(review): presumably passed as debug_flag to VTASetDebugMode -- confirm.
 */
#define VTA_DEBUG_DUMP_INSN (1 << 1)
#define VTA_DEBUG_DUMP_UOP (1 << 2)
#define VTA_DEBUG_SKIP_READ_BARRIER (1 << 3)
#define VTA_DEBUG_SKIP_WRITE_BARRIER (1 << 4)
#define VTA_DEBUG_FORCE_SERIAL (1 << 5)

/*!
 * \brief Allocate data buffer.
 * \param size Buffer size.
 * \return A pointer to the allocated buffer.
 */
50
TVM_DLL void* VTABufferAlloc(size_t size);
51 52 53 54 55

/*!
 * \brief Free data buffer.
 * \param buffer The data buffer to be freed.
 */
56
TVM_DLL void VTABufferFree(void* buffer);
57 58 59 60 61 62 63 64 65 66

/*!
 * \brief Copy data buffer from one location to another.
 * \param from The source buffer base address.
 * \param from_offset The offset of the source buffer.
 * \param to The target buffer base address.
 * \param to_offset The offset of the target buffer.
 * \param size Size of copy.
 * \param kind_mask The memory copy kind.
 */
67 68 69 70 71 72
TVM_DLL void VTABufferCopy(const void* from,
                           size_t from_offset,
                           void* to,
                           size_t to_offset,
                           size_t size,
                           int kind_mask);
73

/*! \brief VTA command handle, exposed to callers as an untyped pointer. */
typedef void* VTACommandHandle;

/*! \brief Shutdown hook of VTA to cleanup resources */
78
TVM_DLL void VTARuntimeShutdown();
79

80
/*!
81 82 83
 * \brief Get thread local command handle.
 * \return A thread local command handle.
 */
84
TVM_DLL VTACommandHandle VTATLSCommandHandle();
85 86 87

/*!
 * \brief Get the buffer access pointer on CPU.
88
 * \param cmd The VTA command handle.
89 90
 * \param buffer The data buffer.
 * \return The pointer that can be accessed by the CPU.
91
 */
92
TVM_DLL void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
93 94 95 96 97 98 99 100 101

/*!
 * \brief Perform a write barrier to make a memory region visible to the CPU.
 * \param cmd The VTA command handle.
 * \param buffer The head buffer pointer.
 * \param elem_bits The size in bits of each element.
 * \param start The start of the region (in elements).
 * \param extent The end of the region (in elements).
 */
102 103 104 105 106
TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd,
                             void* buffer,
                             uint32_t elem_bits,
                             uint32_t start,
                             uint32_t extent);
107

108 109 110 111 112 113 114 115
/*!
 * \brief Perform a read barrier to a memory region visible to VTA.
 * \param cmd The VTA command handle.
 * \param buffer The head buffer pointer.
 * \param elem_bits The unit bits of each elements.
 * \param start The start of the region (in elements).
 * \param extent The end of the region (in elements).
 */
116 117 118 119 120
TVM_DLL void VTAReadBarrier(VTACommandHandle cmd,
                            void* buffer,
                            uint32_t elem_bits,
                            uint32_t start,
                            uint32_t extent);
121 122 123 124 125 126

/*!
 * \brief Set debug mode on the command handle.
 * \param cmd The VTA command handle.
 * \param debug_flag The debug flag.
 */
127
TVM_DLL void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144

/*!
 * \brief Perform a 2D data load from DRAM.
 *  Sizes are measured in units of vector elements.
 * \param cmd The VTA command handle.
 * \param src_dram_addr Source DRAM address.
 * \param src_elem_offset The source DRAM offset in number of unit elements.
 * \param x_size The lowest dimension (x axis) size in number of unit elements.
 * \param y_size The number of rows (y axis).
 * \param x_stride The x axis stride.
 * \param x_pad_before The start padding on x axis.
 * \param y_pad_before The start padding on y axis.
 * \param x_pad_after The end padding on x axis.
 * \param y_pad_after The end padding of y axis.
 * \param dst_sram_index Destination SRAM index.
 * \param dst_memory_type Destination memory type.
 */
145 146 147 148 149 150 151 152 153 154 155 156
TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd,
                             void* src_dram_addr,
                             uint32_t src_elem_offset,
                             uint32_t x_size,
                             uint32_t y_size,
                             uint32_t x_stride,
                             uint32_t x_pad_before,
                             uint32_t y_pad_before,
                             uint32_t x_pad_after,
                             uint32_t y_pad_after,
                             uint32_t dst_sram_index,
                             uint32_t dst_memory_type);
157 158 159 160 161 162 163 164

/*!
 * \brief Perform a 2D data store into DRAM
 *  Sizes are measured in units of vector elements.
 * \param cmd The VTA command handle.
 * \param src_sram_index Source SRAM index.
 * \param src_memory_type Source memory type.
 * \param dst_dram_addr Destination DRAM address.
165
 * \param dst_elem_offset The destination DRAM offset in number of unit elements.
166 167 168 169
 * \param x_size The lowest dimension (x axis) size in number of unit elements.
 * \param y_size The number of rows.
 * \param x_stride The x axis stride.
 */
170 171 172 173 174 175 176 177
TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd,
                              uint32_t src_sram_index,
                              uint32_t src_memory_type,
                              void* dst_dram_addr,
                              uint32_t dst_elem_offset,
                              uint32_t x_size,
                              uint32_t y_size,
                              uint32_t x_stride);
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209

/*!
 * \brief Push uop into kernel buffer.
 * In GEMM mode, do a blocked GEMM with 2d access pattern.
 * In ALU mode, do a vectorized ALU operation with 2d access pattern.
 *
 *  \code
 *
 *   DType accum[INP_BUFF_DEPTH][l][n];
 *   DType weight[WGT_BUFF_DEPTH][n][m];
 *   DType input[INP_BUFF_DEPTH][l][m];
 *   if reset_out == 1
 *    accum[dst_index] = 0
 *   elif mode == 0
 *    accum[dst_index] += GEMM(input[src_index], weight[wgt_index]);
 *   else
 *    if (use_imm)
 *      accum[dst_index] = opcode(accum[dst_index], imm_val);
 *    else
 *      accum[dst_index] = opcode(accum[dst_index], accum[src_index]);
 *
 *  \endcode
 *
 * \param mode Set to GEMM mode if set to 0, ALU mode is set to 1.
 * \param reset_out Resets the accum to 0.
 * \param dst_index The accum memory index.
 * \param src_index The input memory (gemm) / accum memory (alu) index.
 * \param wgt_index The weight memory index.
 * \param opcode The ALU opcode.
 * \param use_imm Use immediate in ALU mode if set to true.
 * \param imm_val Immediate value in ALU mode.
 */
210 211 212 213 214 215 216 217
TVM_DLL void VTAUopPush(uint32_t mode,
                        uint32_t reset_out,
                        uint32_t dst_index,
                        uint32_t src_index,
                        uint32_t wgt_index,
                        uint32_t opcode,
                        uint32_t use_imm,
                        int32_t imm_val);
218 219 220 221 222 223 224 225

/*!
 * \brief Mark start of a micro op loop.
 * \param extent The extent of the loop.
 * \param dst_factor The accum factor.
 * \param src_factor The input factor.
 * \param wgt_factor The weight factor.
 */
226 227 228 229
TVM_DLL void VTAUopLoopBegin(uint32_t extent,
                             uint32_t dst_factor,
                             uint32_t src_factor,
                             uint32_t wgt_factor);
230 231 232 233

/*!
 * \brief Mark end of a micro op loop.
 */
234
TVM_DLL void VTAUopLoopEnd();
235 236 237 238 239 240 241 242 243

/*!
 * \brief Push GEMM uop kernel into the command handle.
 * \param uop_handle The uop cache handle.
 * \param finit The initalization function to initialize uop.
 * \param signature The closure arguments of the finit.
 * \param nbytes Number of bytes to in the closure arguments.
 * \return 0 if success.
 */
244 245 246 247
TVM_DLL int VTAPushGEMMOp(void** uop_handle,
                          int (*finit)(void*),
                          void* signature,
                          int nbytes);
248 249 250 251 252 253 254 255 256

/*!
 * \brief Push ALU uop kernel into the command handle.
 * \param uop_handle The uop cache handle.
 * \param finit The initalization function to initialize uop.
 * \param signature The closure arguments of the finit.
 * \param nbytes Number of bytes to in the closure arguments.
 * \return 0 if success.
 */
257 258 259 260
TVM_DLL int VTAPushALUOp(void** uop_handle,
                         int (*finit)(void*),
                         void* signature,
                         int nbytes);
261 262 263 264 265 266 267 268

/*!
 * \brief Push dependence token.
 * \param cmd The VTA command handle.
 * \param from_qid The source queue.
 * \param to_qid The destination queue.
 * \return 0 if success.
 */
269
TVM_DLL int VTADepPush(VTACommandHandle cmd, int from_qid, int to_qid);
270 271 272 273 274 275 276 277

/*!
 * \brief Pop dependence signal.
 * \param cmd The VTA command handle.
 * \param from_qid The source queue.
 * \param to_qid The destination queue.
 * \return 0 if success.
 */
278
TVM_DLL int VTADepPop(VTACommandHandle cmd, int from_qid, int to_qid);
279 280 281 282 283 284 285 286 287 288

/*!
 * \brief Synchronize the command handle.
 *  Commit all the instructions to VTA and wait until
 *  the accelerator finishes its job.
 *  Perform all of the out-of-order DRAM stores.
 * \param cmd The VTA command handle.
 * \param wait_cycles The limit of poll cycles.
 *
 */
289
TVM_DLL void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles);
290 291 292 293

#ifdef __cplusplus
}
#endif
294
#endif  // VTA_RUNTIME_RUNTIME_H_