Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
32f74f31
Commit
32f74f31
authored
Jun 05, 2019
by
Luis Vega
Committed by
Tianqi Chen
Jun 05, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[VTA] [Hardware] Chisel implementation (#3258)
parent
1f62d956
Hide whitespace changes
Inline
Side-by-side
Showing
43 changed files
with
4784 additions
and
23 deletions
+4784
-23
cmake/config.cmake
+0
-3
cmake/modules/VTA.cmake
+8
-8
vta/apps/tsim_example/README.md
+1
-1
vta/apps/tsim_example/cmake/modules/hw.cmake
+1
-1
vta/hardware/chisel/Makefile
+76
-0
vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v
+1
-1
vta/hardware/chisel/src/main/scala/core/Compute.scala
+201
-0
vta/hardware/chisel/src/main/scala/core/Configs.scala
+46
-0
vta/hardware/chisel/src/main/scala/core/Core.scala
+109
-0
vta/hardware/chisel/src/main/scala/core/Decode.scala
+229
-0
vta/hardware/chisel/src/main/scala/core/Fetch.scala
+197
-0
vta/hardware/chisel/src/main/scala/core/ISA.scala
+93
-0
vta/hardware/chisel/src/main/scala/core/Load.scala
+131
-0
vta/hardware/chisel/src/main/scala/core/LoadUop.scala
+214
-0
vta/hardware/chisel/src/main/scala/core/Semaphore.scala
+42
-0
vta/hardware/chisel/src/main/scala/core/Store.scala
+114
-0
vta/hardware/chisel/src/main/scala/core/TensorAlu.scala
+295
-0
vta/hardware/chisel/src/main/scala/core/TensorGemm.scala
+364
-0
vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
+278
-0
vta/hardware/chisel/src/main/scala/core/TensorStore.scala
+224
-0
vta/hardware/chisel/src/main/scala/core/TensorUtil.scala
+304
-0
vta/hardware/chisel/src/main/scala/core/package.scala
+23
-0
vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala
+83
-0
vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala
+98
-0
vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala
+312
-0
vta/hardware/chisel/src/main/scala/shell/Configs.scala
+51
-0
vta/hardware/chisel/src/main/scala/shell/SimShell.scala
+78
-0
vta/hardware/chisel/src/main/scala/shell/VCR.scala
+242
-0
vta/hardware/chisel/src/main/scala/shell/VME.scala
+254
-0
vta/hardware/chisel/src/main/scala/shell/VTAShell.scala
+57
-0
vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala
+117
-0
vta/hardware/chisel/src/main/scala/test/Test.scala
+33
-0
vta/hardware/chisel/src/main/scala/util/Config.scala
+104
-0
vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala
+40
-0
vta/hardware/chisel/src/main/scala/vta/Configs.scala
+51
-0
vta/hardware/dpi/tsim_device.cc
+10
-0
vta/include/vta/driver.h
+16
-0
vta/python/vta/environment.py
+1
-1
vta/python/vta/testing/simulator.py
+19
-0
vta/python/vta/testing/util.py
+3
-2
vta/src/runtime.cc
+59
-4
vta/src/tsim/tsim_driver.cc
+179
-0
vta/tests/python/unittest/test_vta_insn.py
+26
-2
No files found.
cmake/config.cmake
View file @
32f74f31
...
...
@@ -132,9 +132,6 @@ set(USE_SORT ON)
# Build ANTLR parser for Relay text format
set
(
USE_ANTLR OFF
)
# Build TSIM for VTA
set
(
USE_VTA_TSIM OFF
)
# Whether use Relay debug mode
set
(
USE_RELAY_DEBUG OFF
)
cmake/modules/VTA.cmake
View file @
32f74f31
...
...
@@ -29,8 +29,7 @@ elseif(PYTHON)
--use-cfg=
${
CMAKE_CURRENT_BINARY_DIR
}
/vta_config.json
)
endif
()
execute_process
(
COMMAND
${
VTA_CONFIG
}
--target OUTPUT_VARIABLE __vta_target
)
string
(
STRIP
${
__vta_target
}
VTA_TARGET
)
execute_process
(
COMMAND
${
VTA_CONFIG
}
--target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE
)
message
(
STATUS
"Build VTA runtime with target: "
${
VTA_TARGET
}
)
...
...
@@ -44,6 +43,13 @@ elseif(PYTHON)
add_library
(
vta SHARED
${
VTA_RUNTIME_SRCS
}
)
if
(
${
VTA_TARGET
}
STREQUAL
"tsim"
)
target_compile_definitions
(
vta PUBLIC USE_TSIM
)
include_directories
(
"vta/include"
)
file
(
GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc
)
list
(
APPEND RUNTIME_SRCS
${
RUNTIME_DPI_SRCS
}
)
endif
()
target_include_directories
(
vta PUBLIC vta/include
)
foreach
(
__def
${
VTA_DEFINITIONS
}
)
...
...
@@ -61,12 +67,6 @@ elseif(PYTHON)
target_link_libraries
(
vta
${
__cma_lib
}
)
endif
()
if
(
NOT USE_VTA_TSIM STREQUAL
"OFF"
)
include_directories
(
"vta/include"
)
file
(
GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc
)
list
(
APPEND RUNTIME_SRCS
${
RUNTIME_DPI_SRCS
}
)
endif
()
else
()
message
(
STATUS
"Cannot found python in env, VTA build is skipped.."
)
endif
()
vta/apps/tsim_example/README.md
View file @
32f74f31
...
...
@@ -49,7 +49,7 @@ sudo apt install verilator sbt
## Setup in TVM
1.
Install
`verilator`
and
`sbt`
as described above
2.
Enable VTA TSIM by turning on the switch
`USE_VTA_TSIM`
in config.cmake
2.
Set the VTA TARGET to
`tsim`
on
`<tvm-root>/vta/config/vta_config.json`
3.
Build tvm
## How to run VTA TSIM examples
...
...
vta/apps/tsim_example/cmake/modules/hw.cmake
View file @
32f74f31
...
...
@@ -124,7 +124,7 @@ else()
file
(
GLOB VERILATOR_SRC
${
VTA_HW_DPI_DIR
}
/tsim_device.cc
)
add_library
(
hw SHARED
${
VERILATOR_LIB_SRC
}
${
VERILATOR_GEN_SRC
}
${
VERILATOR_SRC
}
)
set
(
VERILATOR_DEF VL_TSIM_NAME=V
${
TSIM_TOP_NAME
}
VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0
)
set
(
VERILATOR_DEF VL_
USER_FINISH VL_
TSIM_NAME=V
${
TSIM_TOP_NAME
}
VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0
)
if
(
NOT TSIM_USE_TRACE STREQUAL
"OFF"
)
list
(
APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=
${
TSIM_BUILD_DIR
}
/
${
TSIM_TRACE_NAME
}
.vcd
)
else
()
...
...
vta/hardware/chisel/Makefile
View file @
32f74f31
...
...
@@ -15,5 +15,81 @@
# specific language governing permissions and limitations
# under the License.
CONFIG
=
DefaultF1Config
TOP
=
VTA
TOP_TEST
=
Test
BUILD_NAME
=
build
USE_TRACE
=
0
VTA_LIBNAME
=
libvta_hw
config_test
=
$(TOP_TEST)$(CONFIG)
vta_dir
=
$
(
abspath ../../
)
tvm_dir
=
$
(
abspath ../../../
)
verilator_inc_dir
=
/usr/local/share/verilator/include
verilator_build_dir
=
$(vta_dir)
/
$(BUILD_NAME)
/verilator
chisel_build_dir
=
$(vta_dir)
/
$(BUILD_NAME)
/chisel
verilator_opt
=
--cc
verilator_opt
+=
+define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt
+=
+define+RANDOMIZE_REG_INIT
verilator_opt
+=
+define+RANDOMIZE_MEM_INIT
verilator_opt
+=
--x-assign
unique
verilator_opt
+=
--output-split
20000
verilator_opt
+=
--output-split-cfuncs
20000
verilator_opt
+=
--top-module
${
TOP_TEST
}
verilator_opt
+=
-Mdir
${
verilator_build_dir
}
verilator_opt
+=
-I
$(chisel_build_dir)
cxx_flags
=
-O2
-Wall
-fPIC
-shared
cxx_flags
+=
-fvisibility
=
hidden
-std
=
c++11
cxx_flags
+=
-DVL_TSIM_NAME
=
V
$(TOP_TEST)
cxx_flags
+=
-DVL_PRINTF
=
printf
cxx_flags
+=
-DVL_USER_FINISH
cxx_flags
+=
-DVM_COVERAGE
=
0
cxx_flags
+=
-DVM_SC
=
0
cxx_flags
+=
-Wno-sign-compare
cxx_flags
+=
-include
V
$(TOP_TEST)
.h
cxx_flags
+=
-I
$(verilator_build_dir)
cxx_flags
+=
-I
$(verilator_inc_dir)
cxx_flags
+=
-I
$(verilator_inc_dir)
/vltstd
cxx_flags
+=
-I
$(vta_dir)
/include
cxx_flags
+=
-I
$(tvm_dir)
/include
cxx_flags
+=
-I
$(tvm_dir)
/3rdparty/dlpack/include
cxx_files
=
$(verilator_inc_dir)
/verilated.cpp
cxx_files
+=
$(verilator_inc_dir)
/verilated_dpi.cpp
cxx_files
+=
$
(
wildcard
$(verilator_build_dir)
/
*
.cpp
)
cxx_files
+=
$(vta_dir)
/hardware/dpi/tsim_device.cc
ifneq
($(USE_TRACE),
0)
verilator_opt
+=
--trace
cxx_flags
+=
-DVM_TRACE
=
1
cxx_flags
+=
-DTSIM_TRACE_FILE
=
$(verilator_build_dir)
/
$(TOP_TEST)
.vcd
cxx_files
+=
$(verilator_inc_dir)
/verilated_vcd_c.cpp
else
cxx_flags
+=
-DVM_TRACE
=
0
endif
default
:
lib
lib
:
$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
:
$(verilator_build_dir)/V$(TOP_TEST).cpp
g++
$(cxx_flags)
$(cxx_files)
-o
$@
verilator
:
$(verilator_build_dir)/V$(TOP_TEST).cpp
$(verilator_build_dir)/V$(TOP_TEST).cpp
:
$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
verilator
$(verilator_opt)
$<
verilog
:
$(chisel_build_dir)/$(TOP).$(CONFIG).v
$(chisel_build_dir)/$(TOP).$(CONFIG).v
:
sbt
'runMain vta.
$(CONFIG)
--target-dir
$(chisel_build_dir)
--top-name
$(TOP)
.
$(CONFIG)
'
verilog_test
:
$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
:
sbt
'runMain vta.
$(config_test)
--target-dir
$(chisel_build_dir)
--top-name
$(TOP_TEST)
.
$(CONFIG)
'
clean
:
-
rm
-rf
target project/target project/project
cleanall
:
-
rm
-rf
$(vta_dir)
/
$(BUILD_NAME)
vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v
View file @
32f74f31
...
...
@@ -112,7 +112,7 @@ module VTAHostDPI #
always_ff
@
(
posedge
clock
)
begin
if
(
__
exit
==
'd1
)
begin
$
display
(
"[
DONE]
at cycle:%016d"
,
cycles
)
;
$
display
(
"[
TSIM] Verilog $finish called
at cycle:%016d"
,
cycles
)
;
$
finish
;
end
end
...
...
vta/hardware/chisel/src/main/scala/core/Compute.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** Compute.
*
* The compute unit is in charge of the following:
* - Loading micro-ops from memory (loadUop module)
* - Loading biases (acc) from memory (tensorAcc module)
* - Compute ALU instructions (tensorAlu module)
* - Compute GEMM instructions (tensorGemm module)
*/
class
Compute
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
i_post
=
Vec
(
2
,
Input
(
Bool
()))
val
o_post
=
Vec
(
2
,
Output
(
Bool
()))
val
inst
=
Flipped
(
Decoupled
(
UInt
(
INST_BITS
.
W
)))
val
uop_baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
acc_baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
vme_rd
=
Vec
(
2
,
new
VMEReadMaster
)
val
inp
=
new
TensorMaster
(
tensorType
=
"inp"
)
val
wgt
=
new
TensorMaster
(
tensorType
=
"wgt"
)
val
out
=
new
TensorMaster
(
tensorType
=
"out"
)
val
finish
=
Output
(
Bool
())
})
val
sIdle
::
sSync
::
sExe
::
Nil
=
Enum
(
3
)
val
state
=
RegInit
(
sIdle
)
val
s
=
Seq
.
tabulate
(
2
)(
_
=>
Module
(
new
Semaphore
(
counterBits
=
8
,
counterInitValue
=
0
)))
val
loadUop
=
Module
(
new
LoadUop
)
val
tensorAcc
=
Module
(
new
TensorLoad
(
tensorType
=
"acc"
))
val
tensorGemm
=
Module
(
new
TensorGemm
)
val
tensorAlu
=
Module
(
new
TensorAlu
)
val
inst_q
=
Module
(
new
Queue
(
UInt
(
INST_BITS
.
W
),
p
(
CoreKey
).
instQueueEntries
))
// decode
val
dec
=
Module
(
new
ComputeDecode
)
dec
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
val
inst_type
=
Cat
(
dec
.
io
.
isFinish
,
dec
.
io
.
isAlu
,
dec
.
io
.
isGemm
,
dec
.
io
.
isLoadAcc
,
dec
.
io
.
isLoadUop
).
asUInt
val
sprev
=
inst_q
.
io
.
deq
.
valid
&
Mux
(
dec
.
io
.
pop_prev
,
s
(
0
).
io
.
sready
,
true
.
B
)
val
snext
=
inst_q
.
io
.
deq
.
valid
&
Mux
(
dec
.
io
.
pop_next
,
s
(
1
).
io
.
sready
,
true
.
B
)
val
start
=
snext
&
sprev
val
done
=
MuxLookup
(
inst_type
,
false
.
B
,
// default
Array
(
"h_01"
.
U
->
loadUop
.
io
.
done
,
"h_02"
.
U
->
tensorAcc
.
io
.
done
,
"h_04"
.
U
->
tensorGemm
.
io
.
done
,
"h_08"
.
U
->
tensorAlu
.
io
.
done
,
"h_10"
.
U
->
true
.
B
// Finish
)
)
// control
switch
(
state
)
{
is
(
sIdle
)
{
when
(
start
)
{
when
(
dec
.
io
.
isSync
)
{
state
:=
sSync
}
.
elsewhen
(
inst_type
.
orR
)
{
state
:=
sExe
}
}
}
is
(
sSync
)
{
state
:=
sIdle
}
is
(
sExe
)
{
when
(
done
)
{
state
:=
sIdle
}
}
}
// instructions
inst_q
.
io
.
enq
<>
io
.
inst
inst_q
.
io
.
deq
.
ready
:=
(
state
===
sExe
&
done
)
|
(
state
===
sSync
)
// uop
loadUop
.
io
.
start
:=
state
===
sIdle
&
start
&
dec
.
io
.
isLoadUop
loadUop
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
loadUop
.
io
.
baddr
:=
io
.
uop_baddr
io
.
vme_rd
(
0
)
<>
loadUop
.
io
.
vme_rd
loadUop
.
io
.
uop
.
idx
<>
Mux
(
dec
.
io
.
isGemm
,
tensorGemm
.
io
.
uop
.
idx
,
tensorAlu
.
io
.
uop
.
idx
)
// acc
tensorAcc
.
io
.
start
:=
state
===
sIdle
&
start
&
dec
.
io
.
isLoadAcc
tensorAcc
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
tensorAcc
.
io
.
baddr
:=
io
.
acc_baddr
tensorAcc
.
io
.
tensor
.
rd
.
idx
<>
Mux
(
dec
.
io
.
isGemm
,
tensorGemm
.
io
.
acc
.
rd
.
idx
,
tensorAlu
.
io
.
acc
.
rd
.
idx
)
tensorAcc
.
io
.
tensor
.
wr
<>
Mux
(
dec
.
io
.
isGemm
,
tensorGemm
.
io
.
acc
.
wr
,
tensorAlu
.
io
.
acc
.
wr
)
io
.
vme_rd
(
1
)
<>
tensorAcc
.
io
.
vme_rd
// gemm
tensorGemm
.
io
.
start
:=
state
===
sIdle
&
start
&
dec
.
io
.
isGemm
tensorGemm
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
tensorGemm
.
io
.
uop
.
data
.
valid
:=
loadUop
.
io
.
uop
.
data
.
valid
&
dec
.
io
.
isGemm
tensorGemm
.
io
.
uop
.
data
.
bits
<>
loadUop
.
io
.
uop
.
data
.
bits
tensorGemm
.
io
.
inp
<>
io
.
inp
tensorGemm
.
io
.
wgt
<>
io
.
wgt
tensorGemm
.
io
.
acc
.
rd
.
data
.
valid
:=
tensorAcc
.
io
.
tensor
.
rd
.
data
.
valid
&
dec
.
io
.
isGemm
tensorGemm
.
io
.
acc
.
rd
.
data
.
bits
<>
tensorAcc
.
io
.
tensor
.
rd
.
data
.
bits
tensorGemm
.
io
.
out
.
rd
.
data
.
valid
:=
io
.
out
.
rd
.
data
.
valid
&
dec
.
io
.
isGemm
tensorGemm
.
io
.
out
.
rd
.
data
.
bits
<>
io
.
out
.
rd
.
data
.
bits
// alu
tensorAlu
.
io
.
start
:=
state
===
sIdle
&
start
&
dec
.
io
.
isAlu
tensorAlu
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
tensorAlu
.
io
.
uop
.
data
.
valid
:=
loadUop
.
io
.
uop
.
data
.
valid
&
dec
.
io
.
isAlu
tensorAlu
.
io
.
uop
.
data
.
bits
<>
loadUop
.
io
.
uop
.
data
.
bits
tensorAlu
.
io
.
acc
.
rd
.
data
.
valid
:=
tensorAcc
.
io
.
tensor
.
rd
.
data
.
valid
&
dec
.
io
.
isAlu
tensorAlu
.
io
.
acc
.
rd
.
data
.
bits
<>
tensorAcc
.
io
.
tensor
.
rd
.
data
.
bits
tensorAlu
.
io
.
out
.
rd
.
data
.
valid
:=
io
.
out
.
rd
.
data
.
valid
&
dec
.
io
.
isAlu
tensorAlu
.
io
.
out
.
rd
.
data
.
bits
<>
io
.
out
.
rd
.
data
.
bits
// out
io
.
out
.
rd
.
idx
<>
Mux
(
dec
.
io
.
isGemm
,
tensorGemm
.
io
.
out
.
rd
.
idx
,
tensorAlu
.
io
.
out
.
rd
.
idx
)
io
.
out
.
wr
<>
Mux
(
dec
.
io
.
isGemm
,
tensorGemm
.
io
.
out
.
wr
,
tensorAlu
.
io
.
out
.
wr
)
// semaphore
s
(
0
).
io
.
spost
:=
io
.
i_post
(
0
)
s
(
1
).
io
.
spost
:=
io
.
i_post
(
1
)
s
(
0
).
io
.
swait
:=
dec
.
io
.
pop_prev
&
(
state
===
sIdle
&
start
)
s
(
1
).
io
.
swait
:=
dec
.
io
.
pop_next
&
(
state
===
sIdle
&
start
)
io
.
o_post
(
0
)
:=
dec
.
io
.
push_prev
&
((
state
===
sExe
&
done
)
|
(
state
===
sSync
))
io
.
o_post
(
1
)
:=
dec
.
io
.
push_next
&
((
state
===
sExe
&
done
)
|
(
state
===
sSync
))
// finish
io
.
finish
:=
state
===
sExe
&
done
&
dec
.
io
.
isFinish
// debug
if
(
debug
)
{
// start
when
(
state
===
sIdle
&&
start
)
{
when
(
dec
.
io
.
isSync
)
{
printf
(
"[Compute] start sync\n"
)
}
.
elsewhen
(
dec
.
io
.
isLoadUop
)
{
printf
(
"[Compute] start load uop\n"
)
}
.
elsewhen
(
dec
.
io
.
isLoadAcc
)
{
printf
(
"[Compute] start load acc\n"
)
}
.
elsewhen
(
dec
.
io
.
isGemm
)
{
printf
(
"[Compute] start gemm\n"
)
}
.
elsewhen
(
dec
.
io
.
isAlu
)
{
printf
(
"[Compute] start alu\n"
)
}
.
elsewhen
(
dec
.
io
.
isFinish
)
{
printf
(
"[Compute] start finish\n"
)
}
}
// done
when
(
state
===
sSync
)
{
printf
(
"[Compute] done sync\n"
)
}
when
(
state
===
sExe
)
{
when
(
done
)
{
when
(
dec
.
io
.
isLoadUop
)
{
printf
(
"[Compute] done load uop\n"
)
}
.
elsewhen
(
dec
.
io
.
isLoadAcc
)
{
printf
(
"[Compute] done load acc\n"
)
}
.
elsewhen
(
dec
.
io
.
isGemm
)
{
printf
(
"[Compute] done gemm\n"
)
}
.
elsewhen
(
dec
.
io
.
isAlu
)
{
printf
(
"[Compute] done alu\n"
)
}
.
elsewhen
(
dec
.
io
.
isFinish
)
{
printf
(
"[Compute] done finish\n"
)
}
}
}
}
}
vta/hardware/chisel/src/main/scala/core/Configs.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
vta.util.config._
/** CoreConfig.
*
* This is one supported configuration for VTA. This file will
* be eventually filled out with class configurations that can be
* mixed/matched with Shell configurations for different backends.
*/
class
CoreConfig
extends
Config
((
site
,
here
,
up
)
=>
{
case
CoreKey
=>
CoreParams
(
batch
=
1
,
blockOut
=
16
,
blockIn
=
16
,
inpBits
=
8
,
wgtBits
=
8
,
uopBits
=
32
,
accBits
=
32
,
outBits
=
8
,
uopMemDepth
=
2048
,
inpMemDepth
=
2048
,
wgtMemDepth
=
1024
,
accMemDepth
=
2048
,
outMemDepth
=
2048
,
instQueueEntries
=
512
)
})
vta/hardware/chisel/src/main/scala/core/Core.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
vta.util.config._
import
vta.shell._
/** Core parameters */
case
class
CoreParams
(
batch
:
Int
=
1
,
blockOut
:
Int
=
16
,
blockIn
:
Int
=
16
,
inpBits
:
Int
=
8
,
wgtBits
:
Int
=
8
,
uopBits
:
Int
=
32
,
accBits
:
Int
=
32
,
outBits
:
Int
=
8
,
uopMemDepth
:
Int
=
512
,
inpMemDepth
:
Int
=
512
,
wgtMemDepth
:
Int
=
512
,
accMemDepth
:
Int
=
512
,
outMemDepth
:
Int
=
512
,
instQueueEntries
:
Int
=
32
)
case
object
CoreKey
extends
Field
[
CoreParams
]
/** Core.
*
* The core defines the current VTA architecture by connecting memory and
* compute modules together such as load/store and compute. Most of the
* connections in the core are bulk (<>), and we should try to keep it this
* way, because it is easier to understand what is going on.
*
* Also, the core must be instantiated by a shell using the
* VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces.
* More info about these interfaces and modules can be found in the shell
* directory.
*/
class
Core
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
vcr
=
new
VCRClient
val
vme
=
new
VMEMaster
})
val
fetch
=
Module
(
new
Fetch
)
val
load
=
Module
(
new
Load
)
val
compute
=
Module
(
new
Compute
)
val
store
=
Module
(
new
Store
)
// Read(rd) and write(wr) from/to memory (i.e. DRAM)
io
.
vme
.
rd
(
0
)
<>
fetch
.
io
.
vme_rd
io
.
vme
.
rd
(
1
)
<>
compute
.
io
.
vme_rd
(
0
)
io
.
vme
.
rd
(
2
)
<>
load
.
io
.
vme_rd
(
0
)
io
.
vme
.
rd
(
3
)
<>
load
.
io
.
vme_rd
(
1
)
io
.
vme
.
rd
(
4
)
<>
compute
.
io
.
vme_rd
(
1
)
io
.
vme
.
wr
(
0
)
<>
store
.
io
.
vme_wr
// Fetch instructions (tasks) from memory (DRAM) into queues (SRAMs)
fetch
.
io
.
launch
:=
io
.
vcr
.
launch
fetch
.
io
.
ins_baddr
:=
io
.
vcr
.
ptrs
(
0
)
fetch
.
io
.
ins_count
:=
io
.
vcr
.
vals
(
0
)
// Load inputs and weights from memory (DRAM) into scratchpads (SRAMs)
load
.
io
.
i_post
:=
compute
.
io
.
o_post
(
0
)
load
.
io
.
inst
<>
fetch
.
io
.
inst
.
ld
load
.
io
.
inp_baddr
:=
io
.
vcr
.
ptrs
(
2
)
load
.
io
.
wgt_baddr
:=
io
.
vcr
.
ptrs
(
3
)
// The compute module performs the following:
// - Load micro-ops (uops) and accumulations (acc)
// - Compute dense and ALU instructions (tasks)
compute
.
io
.
i_post
(
0
)
:=
load
.
io
.
o_post
compute
.
io
.
i_post
(
1
)
:=
store
.
io
.
o_post
compute
.
io
.
inst
<>
fetch
.
io
.
inst
.
co
compute
.
io
.
uop_baddr
:=
io
.
vcr
.
ptrs
(
1
)
compute
.
io
.
acc_baddr
:=
io
.
vcr
.
ptrs
(
4
)
compute
.
io
.
inp
<>
load
.
io
.
inp
compute
.
io
.
wgt
<>
load
.
io
.
wgt
// The store module performs the following:
// - Writes results from compute into scratchpads (SRAMs)
// - Store results from scratchpads (SRAMs) to memory (DRAM)
store
.
io
.
i_post
:=
compute
.
io
.
o_post
(
1
)
store
.
io
.
inst
<>
fetch
.
io
.
inst
.
st
store
.
io
.
out_baddr
:=
io
.
vcr
.
ptrs
(
5
)
store
.
io
.
out
<>
compute
.
io
.
out
// Finish instruction is executed and asserts the VCR finish flag
val
finish
=
RegNext
(
compute
.
io
.
finish
)
io
.
vcr
.
finish
:=
finish
}
vta/hardware/chisel/src/main/scala/core/Decode.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
ISA._
/** MemDecode.
*
* Decode memory instructions with a Bundle. This is similar to an union,
* therefore order matters when declaring fields. These are the instructions
* decoded with this bundle:
* - LUOP
* - LWGT
* - LINP
* - LACC
* - SOUT
*/
class
MemDecode
extends
Bundle
{
val
xpad_1
=
UInt
(
M_PAD_BITS
.
W
)
val
xpad_0
=
UInt
(
M_PAD_BITS
.
W
)
val
ypad_1
=
UInt
(
M_PAD_BITS
.
W
)
val
ypad_0
=
UInt
(
M_PAD_BITS
.
W
)
val
xstride
=
UInt
(
M_STRIDE_BITS
.
W
)
val
xsize
=
UInt
(
M_SIZE_BITS
.
W
)
val
ysize
=
UInt
(
M_SIZE_BITS
.
W
)
val
empty_0
=
UInt
(
7.
W
)
// derive this
val
dram_offset
=
UInt
(
M_DRAM_OFFSET_BITS
.
W
)
val
sram_offset
=
UInt
(
M_SRAM_OFFSET_BITS
.
W
)
val
id
=
UInt
(
M_ID_BITS
.
W
)
val
push_next
=
Bool
()
val
push_prev
=
Bool
()
val
pop_next
=
Bool
()
val
pop_prev
=
Bool
()
val
op
=
UInt
(
OP_BITS
.
W
)
}
/** GemmDecode.
*
* Decode GEMM instruction with a Bundle. This is similar to an union,
* therefore order matters when declaring fields.
*/
class
GemmDecode
extends
Bundle
{
val
wgt_1
=
UInt
(
C_WIDX_BITS
.
W
)
val
wgt_0
=
UInt
(
C_WIDX_BITS
.
W
)
val
inp_1
=
UInt
(
C_IIDX_BITS
.
W
)
val
inp_0
=
UInt
(
C_IIDX_BITS
.
W
)
val
acc_1
=
UInt
(
C_AIDX_BITS
.
W
)
val
acc_0
=
UInt
(
C_AIDX_BITS
.
W
)
val
empty_0
=
Bool
()
val
lp_1
=
UInt
(
C_ITER_BITS
.
W
)
val
lp_0
=
UInt
(
C_ITER_BITS
.
W
)
val
uop_end
=
UInt
(
C_UOP_END_BITS
.
W
)
val
uop_begin
=
UInt
(
C_UOP_BGN_BITS
.
W
)
val
reset
=
Bool
()
val
push_next
=
Bool
()
val
push_prev
=
Bool
()
val
pop_next
=
Bool
()
val
pop_prev
=
Bool
()
val
op
=
UInt
(
OP_BITS
.
W
)
}
/** AluDecode.
*
* Decode ALU instructions with a Bundle. This is similar to an union,
* therefore order matters when declaring fields. These are the instructions
* decoded with this bundle:
* - VMIN
* - VMAX
* - VADD
* - VSHX
*/
class
AluDecode
extends
Bundle
{
val
empty_1
=
Bool
()
val
alu_imm
=
UInt
(
C_ALU_IMM_BITS
.
W
)
val
alu_use_imm
=
Bool
()
val
alu_op
=
UInt
(
C_ALU_DEC_BITS
.
W
)
val
src_1
=
UInt
(
C_IIDX_BITS
.
W
)
val
src_0
=
UInt
(
C_IIDX_BITS
.
W
)
val
dst_1
=
UInt
(
C_AIDX_BITS
.
W
)
val
dst_0
=
UInt
(
C_AIDX_BITS
.
W
)
val
empty_0
=
Bool
()
val
lp_1
=
UInt
(
C_ITER_BITS
.
W
)
val
lp_0
=
UInt
(
C_ITER_BITS
.
W
)
val
uop_end
=
UInt
(
C_UOP_END_BITS
.
W
)
val
uop_begin
=
UInt
(
C_UOP_BGN_BITS
.
W
)
val
reset
=
Bool
()
val
push_next
=
Bool
()
val
push_prev
=
Bool
()
val
pop_next
=
Bool
()
val
pop_prev
=
Bool
()
val
op
=
UInt
(
OP_BITS
.
W
)
}
/** UopDecode.
*
* Decode micro-ops (uops).
*/
class
UopDecode
extends
Bundle
{
val
u2
=
UInt
(
10.
W
)
val
u1
=
UInt
(
11.
W
)
val
u0
=
UInt
(
11.
W
)
}
/** FetchDecode.
*
* Partial decoding for dispatching instructions to Load, Compute, and Store.
*/
class
FetchDecode
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
isLoad
=
Output
(
Bool
())
val
isCompute
=
Output
(
Bool
())
val
isStore
=
Output
(
Bool
())
})
val
csignals
=
ListLookup
(
io
.
inst
,
List
(
N
,
OP_X
),
Array
(
LUOP
->
List
(
Y
,
OP_G
),
LWGT
->
List
(
Y
,
OP_L
),
LINP
->
List
(
Y
,
OP_L
),
LACC
->
List
(
Y
,
OP_G
),
SOUT
->
List
(
Y
,
OP_S
),
GEMM
->
List
(
Y
,
OP_G
),
FNSH
->
List
(
Y
,
OP_G
),
VMIN
->
List
(
Y
,
OP_G
),
VMAX
->
List
(
Y
,
OP_G
),
VADD
->
List
(
Y
,
OP_G
),
VSHX
->
List
(
Y
,
OP_G
)
)
)
val
(
cs_val_inst
:
Bool
)
::
cs_op_type
::
Nil
=
csignals
io
.
isLoad
:=
cs_val_inst
&
cs_op_type
===
OP_L
io
.
isCompute
:=
cs_val_inst
&
cs_op_type
===
OP_G
io
.
isStore
:=
cs_val_inst
&
cs_op_type
===
OP_S
}
/** LoadDecode.
*
* Decode dependencies, type and sync for Load module.
*/
class
LoadDecode
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
push_next
=
Output
(
Bool
())
val
pop_next
=
Output
(
Bool
())
val
isInput
=
Output
(
Bool
())
val
isWeight
=
Output
(
Bool
())
val
isSync
=
Output
(
Bool
())
})
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
io
.
push_next
:=
dec
.
push_next
io
.
pop_next
:=
dec
.
pop_next
io
.
isInput
:=
io
.
inst
===
LINP
&
dec
.
xsize
=/=
0.
U
io
.
isWeight
:=
io
.
inst
===
LWGT
&
dec
.
xsize
=/=
0.
U
io
.
isSync
:=
(
io
.
inst
===
LINP
|
io
.
inst
===
LWGT
)
&
dec
.
xsize
===
0.
U
}
/** ComputeDecode.
*
* Decode dependencies, type and sync for Compute module.
*/
class
ComputeDecode
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
push_next
=
Output
(
Bool
())
val
push_prev
=
Output
(
Bool
())
val
pop_next
=
Output
(
Bool
())
val
pop_prev
=
Output
(
Bool
())
val
isLoadAcc
=
Output
(
Bool
())
val
isLoadUop
=
Output
(
Bool
())
val
isSync
=
Output
(
Bool
())
val
isAlu
=
Output
(
Bool
())
val
isGemm
=
Output
(
Bool
())
val
isFinish
=
Output
(
Bool
())
})
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
io
.
push_next
:=
dec
.
push_next
io
.
push_prev
:=
dec
.
push_prev
io
.
pop_next
:=
dec
.
pop_next
io
.
pop_prev
:=
dec
.
pop_prev
io
.
isLoadAcc
:=
io
.
inst
===
LACC
&
dec
.
xsize
=/=
0.
U
io
.
isLoadUop
:=
io
.
inst
===
LUOP
&
dec
.
xsize
=/=
0.
U
io
.
isSync
:=
(
io
.
inst
===
LACC
|
io
.
inst
===
LUOP
)
&
dec
.
xsize
===
0.
U
io
.
isAlu
:=
io
.
inst
===
VMIN
|
io
.
inst
===
VMAX
|
io
.
inst
===
VADD
|
io
.
inst
===
VSHX
io
.
isGemm
:=
io
.
inst
===
GEMM
io
.
isFinish
:=
io
.
inst
===
FNSH
}
/** StoreDecode.
*
* Decode dependencies, type and sync for Store module.
*/
class
StoreDecode
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
push_prev
=
Output
(
Bool
())
val
pop_prev
=
Output
(
Bool
())
val
isStore
=
Output
(
Bool
())
val
isSync
=
Output
(
Bool
())
})
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
io
.
push_prev
:=
dec
.
push_prev
io
.
pop_prev
:=
dec
.
pop_prev
io
.
isStore
:=
io
.
inst
===
SOUT
&
dec
.
xsize
=/=
0.
U
io
.
isSync
:=
io
.
inst
===
SOUT
&
dec
.
xsize
===
0.
U
}
vta/hardware/chisel/src/main/scala/core/Fetch.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** Fetch.
*
* The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
* VTA Memory Engine (VME), and push them into an instruction queue called
* inst_q. Once the instruction queue is full, instructions are dispatched to
* the Load, Compute and Store module queues based on the instruction opcode.
* After draining the queue, the fetch unit checks if there are more instructions
* via the ins_count register which is written by the host.
*
* Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB)
* because we are using a DRAM payload of 8-bytes or half of a VTA instruction.
* This should be configurable for larger payloads, i.e. 64-bytes, which can load
* more than one instruction at the time. Finally, the instruction queue is
* sized (entries_q), depending on the maximum burst allowed in the memory.
*/
class
Fetch
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
vp
=
p
(
ShellKey
).
vcrParams
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
launch
=
Input
(
Bool
())
val
ins_baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
ins_count
=
Input
(
UInt
(
vp
.
regBits
.
W
))
val
vme_rd
=
new
VMEReadMaster
val
inst
=
new
Bundle
{
val
ld
=
Decoupled
(
UInt
(
INST_BITS
.
W
))
val
co
=
Decoupled
(
UInt
(
INST_BITS
.
W
))
val
st
=
Decoupled
(
UInt
(
INST_BITS
.
W
))
}
})
val
entries_q
=
1
<<
(
mp
.
lenBits
-
1
)
// one-instr-every-two-vme-word
val
inst_q
=
Module
(
new
Queue
(
UInt
(
INST_BITS
.
W
),
entries_q
))
val
dec
=
Module
(
new
FetchDecode
)
val
s1_launch
=
RegNext
(
io
.
launch
)
val
pulse
=
io
.
launch
&
~
s1_launch
val
raddr
=
Reg
(
chiselTypeOf
(
io
.
vme_rd
.
cmd
.
bits
.
addr
))
val
rlen
=
Reg
(
chiselTypeOf
(
io
.
vme_rd
.
cmd
.
bits
.
len
))
val
ilen
=
Reg
(
chiselTypeOf
(
io
.
vme_rd
.
cmd
.
bits
.
len
))
val
xrem
=
Reg
(
chiselTypeOf
(
io
.
ins_count
))
val
xsize
=
(
io
.
ins_count
<<
1.
U
)
-
1.
U
val
xmax
=
(
1
<<
mp
.
lenBits
).
U
val
xmax_bytes
=
((
1
<<
mp
.
lenBits
)*
mp
.
dataBits
/
8
).
U
val
sIdle
::
sReadCmd
::
sReadLSB
::
sReadMSB
::
sDrain
::
Nil
=
Enum
(
5
)
val
state
=
RegInit
(
sIdle
)
// control
switch
(
state
)
{
is
(
sIdle
)
{
when
(
pulse
)
{
state
:=
sReadCmd
when
(
xsize
<
xmax
)
{
rlen
:=
xsize
ilen
:=
xsize
>>
1.
U
xrem
:=
0.
U
}
.
otherwise
{
rlen
:=
xmax
-
1.
U
ilen
:=
(
xmax
>>
1.
U
)
-
1.
U
xrem
:=
xsize
-
xmax
}
}
}
is
(
sReadCmd
)
{
when
(
io
.
vme_rd
.
cmd
.
ready
)
{
state
:=
sReadLSB
}
}
is
(
sReadLSB
)
{
when
(
io
.
vme_rd
.
data
.
valid
)
{
state
:=
sReadMSB
}
}
is
(
sReadMSB
)
{
when
(
io
.
vme_rd
.
data
.
valid
)
{
when
(
inst_q
.
io
.
count
===
ilen
)
{
state
:=
sDrain
}
.
otherwise
{
state
:=
sReadLSB
}
}
}
is
(
sDrain
)
{
when
(
inst_q
.
io
.
count
===
0.
U
)
{
when
(
xrem
===
0.
U
)
{
state
:=
sIdle
}
.
elsewhen
(
xrem
<
xmax
)
{
state
:=
sReadCmd
rlen
:=
xrem
ilen
:=
xrem
>>
1.
U
xrem
:=
0.
U
}
.
otherwise
{
state
:=
sReadCmd
rlen
:=
xmax
-
1.
U
ilen
:=
(
xmax
>>
1.
U
)
-
1.
U
xrem
:=
xrem
-
xmax
}
}
}
}
// read instructions from dram
when
(
state
===
sIdle
)
{
raddr
:=
io
.
ins_baddr
}
.
elsewhen
(
state
===
sDrain
&&
inst_q
.
io
.
count
===
0.
U
&&
xrem
=/=
0.
U
)
{
raddr
:=
raddr
+
xmax_bytes
}
io
.
vme_rd
.
cmd
.
valid
:=
state
===
sReadCmd
io
.
vme_rd
.
cmd
.
bits
.
addr
:=
raddr
io
.
vme_rd
.
cmd
.
bits
.
len
:=
rlen
io
.
vme_rd
.
data
.
ready
:=
inst_q
.
io
.
enq
.
ready
val
lsb
=
Reg
(
chiselTypeOf
(
io
.
vme_rd
.
data
.
bits
))
val
msb
=
io
.
vme_rd
.
data
.
bits
val
inst
=
Cat
(
msb
,
lsb
)
when
(
state
===
sReadLSB
)
{
lsb
:=
io
.
vme_rd
.
data
.
bits
}
inst_q
.
io
.
enq
.
valid
:=
io
.
vme_rd
.
data
.
valid
&
state
===
sReadMSB
inst_q
.
io
.
enq
.
bits
:=
inst
// decode
dec
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
// instruction queues
io
.
inst
.
ld
.
valid
:=
dec
.
io
.
isLoad
&
inst_q
.
io
.
deq
.
valid
&
state
===
sDrain
io
.
inst
.
co
.
valid
:=
dec
.
io
.
isCompute
&
inst_q
.
io
.
deq
.
valid
&
state
===
sDrain
io
.
inst
.
st
.
valid
:=
dec
.
io
.
isStore
&
inst_q
.
io
.
deq
.
valid
&
state
===
sDrain
io
.
inst
.
ld
.
bits
:=
inst_q
.
io
.
deq
.
bits
io
.
inst
.
co
.
bits
:=
inst_q
.
io
.
deq
.
bits
io
.
inst
.
st
.
bits
:=
inst_q
.
io
.
deq
.
bits
// check if selected queue is ready
val
deq_sel
=
Cat
(
dec
.
io
.
isCompute
,
dec
.
io
.
isStore
,
dec
.
io
.
isLoad
).
asUInt
val
deq_ready
=
MuxLookup
(
deq_sel
,
false
.
B
,
// default
Array
(
"h_01"
.
U
->
io
.
inst
.
ld
.
ready
,
"h_02"
.
U
->
io
.
inst
.
st
.
ready
,
"h_04"
.
U
->
io
.
inst
.
co
.
ready
)
)
// dequeue instruction
inst_q
.
io
.
deq
.
ready
:=
deq_ready
&
inst_q
.
io
.
deq
.
valid
&
state
===
sDrain
// debug
if
(
debug
)
{
when
(
state
===
sIdle
&&
pulse
)
{
printf
(
"[Fetch] Launch\n"
)
}
// instruction
when
(
inst_q
.
io
.
deq
.
fire
())
{
when
(
dec
.
io
.
isLoad
)
{
printf
(
"[Fetch] [instruction decode] [L] %x\n"
,
inst_q
.
io
.
deq
.
bits
)
}
when
(
dec
.
io
.
isCompute
)
{
printf
(
"[Fetch] [instruction decode] [C] %x\n"
,
inst_q
.
io
.
deq
.
bits
)
}
when
(
dec
.
io
.
isStore
)
{
printf
(
"[Fetch] [instruction decode] [S] %x\n"
,
inst_q
.
io
.
deq
.
bits
)
}
}
}
}
vta/hardware/chisel/src/main/scala/core/ISA.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
/** ISAConstants.
*
* These constants are used for decoding (parsing) fields on instructions.
*/
trait
ISAConstants
{
val
INST_BITS
=
128
val
OP_BITS
=
3
val
M_DEP_BITS
=
4
val
M_ID_BITS
=
2
val
M_SRAM_OFFSET_BITS
=
16
val
M_DRAM_OFFSET_BITS
=
32
val
M_SIZE_BITS
=
16
val
M_STRIDE_BITS
=
16
val
M_PAD_BITS
=
4
val
C_UOP_BGN_BITS
=
13
val
C_UOP_END_BITS
=
14
val
C_ITER_BITS
=
14
val
C_AIDX_BITS
=
11
val
C_IIDX_BITS
=
11
val
C_WIDX_BITS
=
10
val
C_ALU_DEC_BITS
=
2
// FIXME: there should be a SHL and SHR instruction
val
C_ALU_OP_BITS
=
3
val
C_ALU_IMM_BITS
=
16
val
Y
=
true
.
B
val
N
=
false
.
B
val
OP_L
=
0.
asUInt
(
OP_BITS
.
W
)
val
OP_S
=
1.
asUInt
(
OP_BITS
.
W
)
val
OP_G
=
2.
asUInt
(
OP_BITS
.
W
)
val
OP_F
=
3.
asUInt
(
OP_BITS
.
W
)
val
OP_A
=
4.
asUInt
(
OP_BITS
.
W
)
val
OP_X
=
5.
asUInt
(
OP_BITS
.
W
)
val
ALU_OP_NUM
=
5
val
ALU_OP
=
Enum
(
ALU_OP_NUM
)
val
M_ID_U
=
0.
asUInt
(
M_ID_BITS
.
W
)
val
M_ID_W
=
1.
asUInt
(
M_ID_BITS
.
W
)
val
M_ID_I
=
2.
asUInt
(
M_ID_BITS
.
W
)
val
M_ID_A
=
3.
asUInt
(
M_ID_BITS
.
W
)
}
/** ISA.
*
* This is the VTA ISA, here we specify the cares and dont-cares that makes
* decoding easier. Since instructions are quite long 128-bit, we could generate
* these based on ISAConstants.
*
* FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler
* TODO: Add VXOR to clear accumulator
*/
object
ISA
{
def
LUOP
=
BitPat
(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000"
)
def
LWGT
=
BitPat
(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000"
)
def
LINP
=
BitPat
(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000"
)
def
LACC
=
BitPat
(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000"
)
def
SOUT
=
BitPat
(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001"
)
def
GEMM
=
BitPat
(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010"
)
def
VMIN
=
BitPat
(
"b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100"
)
def
VMAX
=
BitPat
(
"b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100"
)
def
VADD
=
BitPat
(
"b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100"
)
def
VSHX
=
BitPat
(
"b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100"
)
def
FNSH
=
BitPat
(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011"
)
}
vta/hardware/chisel/src/main/scala/core/Load.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** Load.
*
* Load inputs and weights from memory (DRAM) into scratchpads (SRAMs).
* This module instantiate the TensorLoad unit which is in charge of
* loading 1D and 2D tensors to scratchpads, so it can be used by
* other modules such as Compute.
*/
class
Load
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
i_post
=
Input
(
Bool
())
val
o_post
=
Output
(
Bool
())
val
inst
=
Flipped
(
Decoupled
(
UInt
(
INST_BITS
.
W
)))
val
inp_baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
wgt_baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
vme_rd
=
Vec
(
2
,
new
VMEReadMaster
)
val
inp
=
new
TensorClient
(
tensorType
=
"inp"
)
val
wgt
=
new
TensorClient
(
tensorType
=
"wgt"
)
})
val
sIdle
::
sSync
::
sExe
::
Nil
=
Enum
(
3
)
val
state
=
RegInit
(
sIdle
)
val
s
=
Module
(
new
Semaphore
(
counterBits
=
8
,
counterInitValue
=
0
))
val
inst_q
=
Module
(
new
Queue
(
UInt
(
INST_BITS
.
W
),
p
(
CoreKey
).
instQueueEntries
))
val
dec
=
Module
(
new
LoadDecode
)
dec
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
val
tensorType
=
Seq
(
"inp"
,
"wgt"
)
val
tensorDec
=
Seq
(
dec
.
io
.
isInput
,
dec
.
io
.
isWeight
)
val
tensorLoad
=
Seq
.
tabulate
(
2
)(
i
=>
Module
(
new
TensorLoad
(
tensorType
=
tensorType
(
i
))))
val
start
=
inst_q
.
io
.
deq
.
valid
&
Mux
(
dec
.
io
.
pop_next
,
s
.
io
.
sready
,
true
.
B
)
val
done
=
Mux
(
dec
.
io
.
isInput
,
tensorLoad
(
0
).
io
.
done
,
tensorLoad
(
1
).
io
.
done
)
// control
switch
(
state
)
{
is
(
sIdle
)
{
when
(
start
)
{
when
(
dec
.
io
.
isSync
)
{
state
:=
sSync
}
.
elsewhen
(
dec
.
io
.
isInput
||
dec
.
io
.
isWeight
)
{
state
:=
sExe
}
}
}
is
(
sSync
)
{
state
:=
sIdle
}
is
(
sExe
)
{
when
(
done
)
{
state
:=
sIdle
}
}
}
// instructions
inst_q
.
io
.
enq
<>
io
.
inst
inst_q
.
io
.
deq
.
ready
:=
(
state
===
sExe
&
done
)
|
(
state
===
sSync
)
// load tensor
// [0] input (inp)
// [1] weight (wgt)
val
ptr
=
Seq
(
io
.
inp_baddr
,
io
.
wgt_baddr
)
val
tsor
=
Seq
(
io
.
inp
,
io
.
wgt
)
for
(
i
<-
0
until
2
)
{
tensorLoad
(
i
).
io
.
start
:=
state
===
sIdle
&
start
&
tensorDec
(
i
)
tensorLoad
(
i
).
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
tensorLoad
(
i
).
io
.
baddr
:=
ptr
(
i
)
tensorLoad
(
i
).
io
.
tensor
<>
tsor
(
i
)
io
.
vme_rd
(
i
)
<>
tensorLoad
(
i
).
io
.
vme_rd
}
// semaphore
s
.
io
.
spost
:=
io
.
i_post
s
.
io
.
swait
:=
dec
.
io
.
pop_next
&
(
state
===
sIdle
&
start
)
io
.
o_post
:=
dec
.
io
.
push_next
&
((
state
===
sExe
&
done
)
|
(
state
===
sSync
))
// debug
if
(
debug
)
{
// start
when
(
state
===
sIdle
&&
start
)
{
when
(
dec
.
io
.
isSync
)
{
printf
(
"[Load] start sync\n"
)
}
.
elsewhen
(
dec
.
io
.
isInput
)
{
printf
(
"[Load] start input\n"
)
}
.
elsewhen
(
dec
.
io
.
isWeight
)
{
printf
(
"[Load] start weight\n"
)
}
}
// done
when
(
state
===
sSync
)
{
printf
(
"[Load] done sync\n"
)
}
when
(
state
===
sExe
)
{
when
(
done
)
{
when
(
dec
.
io
.
isInput
)
{
printf
(
"[Load] done input\n"
)
}
.
elsewhen
(
dec
.
io
.
isWeight
)
{
printf
(
"[Load] done weight\n"
)
}
}
}
}
}
vta/hardware/chisel/src/main/scala/core/LoadUop.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** UopMaster.
*
* Uop interface used by a master module, i.e. TensorAlu or TensorGemm,
* to request a micro-op (uop) from the uop-scratchpad. The index (idx) is
* used as an address to find the uop in the uop-scratchpad.
*/
class
UopMaster
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
addrBits
=
log2Ceil
(
p
(
CoreKey
).
uopMemDepth
)
val
idx
=
ValidIO
(
UInt
(
addrBits
.
W
))
val
data
=
Flipped
(
ValidIO
(
new
UopDecode
))
override
def
cloneType
=
new
UopMaster
().
asInstanceOf
[
this.
type
]
}
/** UopClient.
*
* Uop interface used by a client module, i.e. LoadUop, to receive
* a request from a master module, i.e. TensorAlu or TensorGemm.
* The index (idx) is used as an address to find the uop in the uop-scratchpad.
*/
class
UopClient
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
addrBits
=
log2Ceil
(
p
(
CoreKey
).
uopMemDepth
)
val
idx
=
Flipped
(
ValidIO
(
UInt
(
addrBits
.
W
)))
val
data
=
ValidIO
(
new
UopDecode
)
override
def
cloneType
=
new
UopClient
().
asInstanceOf
[
this.
type
]
}
/** LoadUop.
*
* Load micro-ops (uops) from memory, i.e. DRAM, and store them in the
* uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in
* group of 2 given the fact that the DRAM payload is 8-bytes. This module
* should be modified later on to support different DRAM sizes efficiently.
*/
class
LoadUop
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
start
=
Input
(
Bool
())
val
done
=
Output
(
Bool
())
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
vme_rd
=
new
VMEReadMaster
val
uop
=
new
UopClient
})
val
numUop
=
2
// store two uops per sram word
val
uopBits
=
p
(
CoreKey
).
uopBits
val
uopDepth
=
p
(
CoreKey
).
uopMemDepth
/
numUop
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
val
raddr
=
Reg
(
chiselTypeOf
(
io
.
vme_rd
.
cmd
.
bits
.
addr
))
val
xcnt
=
Reg
(
chiselTypeOf
(
io
.
vme_rd
.
cmd
.
bits
.
len
))
val
xlen
=
Reg
(
chiselTypeOf
(
io
.
vme_rd
.
cmd
.
bits
.
len
))
val
xrem
=
Reg
(
chiselTypeOf
(
dec
.
xsize
))
val
xsize
=
dec
.
xsize
(
0
)
+
(
dec
.
xsize
>>
log2Ceil
(
numUop
))
-
1.
U
val
xmax
=
(
1
<<
mp
.
lenBits
).
U
val
xmax_bytes
=
((
1
<<
mp
.
lenBits
)*
mp
.
dataBits
/
8
).
U
val
offsetIsEven
=
(
dec
.
sram_offset
%
2.
U
)
===
0.
U
val
sizeIsEven
=
(
dec
.
xsize
%
2.
U
)
===
0.
U
val
sIdle
::
sReadCmd
::
sReadData
::
Nil
=
Enum
(
3
)
val
state
=
RegInit
(
sIdle
)
// control
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
start
)
{
state
:=
sReadCmd
when
(
xsize
<
xmax
)
{
xlen
:=
xsize
xrem
:=
0.
U
}
.
otherwise
{
xlen
:=
xmax
-
1.
U
xrem
:=
xsize
-
xmax
}
}
}
is
(
sReadCmd
)
{
when
(
io
.
vme_rd
.
cmd
.
ready
)
{
state
:=
sReadData
}
}
is
(
sReadData
)
{
when
(
io
.
vme_rd
.
data
.
valid
)
{
when
(
xcnt
===
xlen
)
{
when
(
xrem
===
0.
U
)
{
state
:=
sIdle
}
.
elsewhen
(
xrem
<
xmax
)
{
state
:=
sReadCmd
xlen
:=
xrem
xrem
:=
0.
U
}
.
otherwise
{
state
:=
sReadCmd
xlen
:=
xmax
-
1.
U
xrem
:=
xrem
-
xmax
}
}
}
}
}
// read-from-dram
when
(
state
===
sIdle
)
{
when
(
offsetIsEven
)
{
raddr
:=
io
.
baddr
+
dec
.
dram_offset
}
.
otherwise
{
raddr
:=
io
.
baddr
+
dec
.
dram_offset
-
4.
U
}
}
.
elsewhen
(
state
===
sReadData
&&
xcnt
===
xlen
&&
xrem
=/=
0.
U
)
{
raddr
:=
raddr
+
xmax_bytes
}
io
.
vme_rd
.
cmd
.
valid
:=
state
===
sReadCmd
io
.
vme_rd
.
cmd
.
bits
.
addr
:=
raddr
io
.
vme_rd
.
cmd
.
bits
.
len
:=
xlen
io
.
vme_rd
.
data
.
ready
:=
state
===
sReadData
when
(
state
=/=
sReadData
)
{
xcnt
:=
0.
U
}
.
elsewhen
(
io
.
vme_rd
.
data
.
fire
())
{
xcnt
:=
xcnt
+
1.
U
}
val
waddr
=
Reg
(
UInt
(
log2Ceil
(
uopDepth
).
W
))
when
(
state
===
sIdle
)
{
waddr
:=
dec
.
sram_offset
>>
log2Ceil
(
numUop
)
}
.
elsewhen
(
io
.
vme_rd
.
data
.
fire
())
{
waddr
:=
waddr
+
1.
U
}
val
wdata
=
Wire
(
Vec
(
numUop
,
UInt
(
uopBits
.
W
)))
val
mem
=
SyncReadMem
(
uopDepth
,
chiselTypeOf
(
wdata
))
val
wmask
=
Reg
(
Vec
(
numUop
,
Bool
()))
when
(
offsetIsEven
)
{
when
(
sizeIsEven
)
{
wmask
:=
"b_11"
.
U
.
asTypeOf
(
wmask
)
}
.
elsewhen
(
io
.
vme_rd
.
cmd
.
fire
())
{
when
(
dec
.
xsize
===
1.
U
)
{
wmask
:=
"b_01"
.
U
.
asTypeOf
(
wmask
)
}
.
otherwise
{
wmask
:=
"b_11"
.
U
.
asTypeOf
(
wmask
)
}
}
.
elsewhen
(
io
.
vme_rd
.
data
.
fire
())
{
when
(
xcnt
===
xlen
-
1.
U
)
{
wmask
:=
"b_01"
.
U
.
asTypeOf
(
wmask
)
}
.
otherwise
{
wmask
:=
"b_11"
.
U
.
asTypeOf
(
wmask
)
}
}
}
.
otherwise
{
when
(
io
.
vme_rd
.
cmd
.
fire
())
{
wmask
:=
"b_10"
.
U
.
asTypeOf
(
wmask
)
}
.
elsewhen
(
io
.
vme_rd
.
data
.
fire
())
{
when
(
sizeIsEven
&&
xcnt
===
xlen
-
1.
U
)
{
wmask
:=
"b_01"
.
U
.
asTypeOf
(
wmask
)
}
.
otherwise
{
wmask
:=
"b_11"
.
U
.
asTypeOf
(
wmask
)
}
}
}
wdata
:=
io
.
vme_rd
.
data
.
bits
.
asTypeOf
(
wdata
)
when
(
io
.
vme_rd
.
data
.
fire
())
{
mem
.
write
(
waddr
,
wdata
,
wmask
)
}
// read-from-sram
io
.
uop
.
data
.
valid
:=
RegNext
(
io
.
uop
.
idx
.
valid
)
val
sIdx
=
io
.
uop
.
idx
.
bits
%
numUop
.
U
val
rIdx
=
io
.
uop
.
idx
.
bits
>>
log2Ceil
(
numUop
)
val
memRead
=
mem
.
read
(
rIdx
,
io
.
uop
.
idx
.
valid
)
val
sWord
=
memRead
.
asUInt
.
asTypeOf
(
wdata
)
val
sUop
=
sWord
(
sIdx
).
asTypeOf
(
io
.
uop
.
data
.
bits
)
io
.
uop
.
data
.
bits
<>
sUop
// done
io
.
done
:=
state
===
sReadData
&
io
.
vme_rd
.
data
.
valid
&
xcnt
===
xlen
&
xrem
===
0.
U
// debug
if
(
debug
)
{
when
(
io
.
vme_rd
.
cmd
.
fire
())
{
printf
(
"[LoadUop] cmd addr:%x len:%x rem:%x\n"
,
raddr
,
xlen
,
xrem
)
}
}
}
vta/hardware/chisel/src/main/scala/core/Semaphore.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
/** Semaphore.
*
* This semaphore is used instead of push/pop fifo, used in the initial
* version of VTA. This semaphore is incremented (spost) or decremented (swait)
* depending on the push and pop fields on instructions to prevent RAW and WAR
* hazards.
*/
class
Semaphore
(
counterBits
:
Int
=
1
,
counterInitValue
:
Int
=
1
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
spost
=
Input
(
Bool
())
val
swait
=
Input
(
Bool
())
val
sready
=
Output
(
Bool
())
})
val
cnt
=
RegInit
(
counterInitValue
.
U
(
counterBits
.
W
))
when
(
io
.
spost
&&
!
io
.
swait
&&
cnt
=/=
((
1
<<
counterBits
)
-
1
).
asUInt
)
{
cnt
:=
cnt
+
1.
U
}
when
(!
io
.
spost
&&
io
.
swait
&&
cnt
=/=
0.
U
)
{
cnt
:=
cnt
-
1.
U
}
io
.
sready
:=
cnt
=/=
0.
U
}
vta/hardware/chisel/src/main/scala/core/Store.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** Store.
*
* Store results back to memory (DRAM) from scratchpads (SRAMs).
* This module instantiate the TensorStore unit which is in charge
* of storing 1D and 2D tensors to main memory.
*/
class
Store
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
i_post
=
Input
(
Bool
())
val
o_post
=
Output
(
Bool
())
val
inst
=
Flipped
(
Decoupled
(
UInt
(
INST_BITS
.
W
)))
val
out_baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
vme_wr
=
new
VMEWriteMaster
val
out
=
new
TensorClient
(
tensorType
=
"out"
)
})
val
sIdle
::
sSync
::
sExe
::
Nil
=
Enum
(
3
)
val
state
=
RegInit
(
sIdle
)
val
s
=
Module
(
new
Semaphore
(
counterBits
=
8
,
counterInitValue
=
0
))
val
inst_q
=
Module
(
new
Queue
(
UInt
(
INST_BITS
.
W
),
p
(
CoreKey
).
instQueueEntries
))
val
dec
=
Module
(
new
StoreDecode
)
dec
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
val
tensorStore
=
Module
(
new
TensorStore
(
tensorType
=
"out"
))
val
start
=
inst_q
.
io
.
deq
.
valid
&
Mux
(
dec
.
io
.
pop_prev
,
s
.
io
.
sready
,
true
.
B
)
val
done
=
tensorStore
.
io
.
done
// control
switch
(
state
)
{
is
(
sIdle
)
{
when
(
start
)
{
when
(
dec
.
io
.
isSync
)
{
state
:=
sSync
}
.
elsewhen
(
dec
.
io
.
isStore
)
{
state
:=
sExe
}
}
}
is
(
sSync
)
{
state
:=
sIdle
}
is
(
sExe
)
{
when
(
done
)
{
state
:=
sIdle
}
}
}
// instructions
inst_q
.
io
.
enq
<>
io
.
inst
inst_q
.
io
.
deq
.
ready
:=
(
state
===
sExe
&
done
)
|
(
state
===
sSync
)
// store
tensorStore
.
io
.
start
:=
state
===
sIdle
&
start
&
dec
.
io
.
isStore
tensorStore
.
io
.
inst
:=
inst_q
.
io
.
deq
.
bits
tensorStore
.
io
.
baddr
:=
io
.
out_baddr
io
.
vme_wr
<>
tensorStore
.
io
.
vme_wr
tensorStore
.
io
.
tensor
<>
io
.
out
// semaphore
s
.
io
.
spost
:=
io
.
i_post
s
.
io
.
swait
:=
dec
.
io
.
pop_prev
&
(
state
===
sIdle
&
start
)
io
.
o_post
:=
dec
.
io
.
push_prev
&
((
state
===
sExe
&
done
)
|
(
state
===
sSync
))
// debug
if
(
debug
)
{
// start
when
(
state
===
sIdle
&&
start
)
{
when
(
dec
.
io
.
isSync
)
{
printf
(
"[Store] start sync\n"
)
}
.
elsewhen
(
dec
.
io
.
isStore
)
{
printf
(
"[Store] start\n"
)
}
}
// done
when
(
state
===
sSync
)
{
printf
(
"[Store] done sync\n"
)
}
when
(
state
===
sExe
)
{
when
(
done
)
{
printf
(
"[Store] done\n"
)
}
}
}
}
vta/hardware/chisel/src/main/scala/core/TensorAlu.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
/** ALU datapath */
class
Alu
(
implicit
p
:
Parameters
)
extends
Module
{
val
aluBits
=
p
(
CoreKey
).
accBits
val
io
=
IO
(
new
Bundle
{
val
opcode
=
Input
(
UInt
(
C_ALU_OP_BITS
.
W
))
val
a
=
Input
(
SInt
(
aluBits
.
W
))
val
b
=
Input
(
SInt
(
aluBits
.
W
))
val
y
=
Output
(
SInt
(
aluBits
.
W
))
})
// FIXME: the following three will change once we support properly SHR and SHL
val
ub
=
io
.
b
.
asUInt
val
width
=
log2Ceil
(
aluBits
)
val
m
=
~
ub
(
width
-
1
,
0
)
+
1.
U
val
n
=
ub
(
width
-
1
,
0
)
val
fop
=
Seq
(
Mux
(
io
.
a
<
io
.
b
,
io
.
a
,
io
.
b
),
Mux
(
io
.
a
<
io
.
b
,
io
.
b
,
io
.
a
),
io
.
a
+
io
.
b
,
io
.
a
>>
n
,
io
.
a
<<
m
)
val
opmux
=
Seq
.
tabulate
(
ALU_OP_NUM
)(
i
=>
ALU_OP
(
i
)
->
fop
(
i
))
io
.
y
:=
MuxLookup
(
io
.
opcode
,
io
.
a
,
opmux
)
}
/** Pipelined ALU */
class
AluReg
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
opcode
=
Input
(
UInt
(
C_ALU_OP_BITS
.
W
))
val
a
=
Flipped
(
ValidIO
(
UInt
(
p
(
CoreKey
).
accBits
.
W
)))
val
b
=
Flipped
(
ValidIO
(
UInt
(
p
(
CoreKey
).
accBits
.
W
)))
val
y
=
ValidIO
(
UInt
(
p
(
CoreKey
).
accBits
.
W
))
})
val
alu
=
Module
(
new
Alu
)
val
rA
=
RegEnable
(
io
.
a
.
bits
,
io
.
a
.
valid
)
val
rB
=
RegEnable
(
io
.
b
.
bits
,
io
.
b
.
valid
)
val
valid
=
RegNext
(
io
.
b
.
valid
)
alu
.
io
.
opcode
:=
io
.
opcode
// register input
alu
.
io
.
a
:=
rA
.
asSInt
alu
.
io
.
b
:=
rB
.
asSInt
// output
io
.
y
.
valid
:=
valid
io
.
y
.
bits
:=
alu
.
io
.
y
.
asUInt
}
/** Vector of pipeline ALUs */
class
AluVector
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
opcode
=
Input
(
UInt
(
C_ALU_OP_BITS
.
W
))
val
acc_a
=
new
TensorMasterData
(
tensorType
=
"acc"
)
val
acc_b
=
new
TensorMasterData
(
tensorType
=
"acc"
)
val
acc_y
=
new
TensorClientData
(
tensorType
=
"acc"
)
val
out
=
new
TensorClientData
(
tensorType
=
"out"
)
})
val
blockOut
=
p
(
CoreKey
).
blockOut
val
f
=
Seq
.
fill
(
blockOut
)(
Module
(
new
AluReg
))
val
valid
=
Wire
(
Vec
(
blockOut
,
Bool
()))
for
(
i
<-
0
until
blockOut
)
{
f
(
i
).
io
.
opcode
:=
io
.
opcode
f
(
i
).
io
.
a
.
valid
:=
io
.
acc_a
.
data
.
valid
f
(
i
).
io
.
a
.
bits
:=
io
.
acc_a
.
data
.
bits
(
0
)(
i
)
f
(
i
).
io
.
b
.
valid
:=
io
.
acc_b
.
data
.
valid
f
(
i
).
io
.
b
.
bits
:=
io
.
acc_b
.
data
.
bits
(
0
)(
i
)
valid
(
i
)
:=
f
(
i
).
io
.
y
.
valid
io
.
acc_y
.
data
.
bits
(
0
)(
i
)
:=
f
(
i
).
io
.
y
.
bits
io
.
out
.
data
.
bits
(
0
)(
i
)
:=
f
(
i
).
io
.
y
.
bits
}
io
.
acc_y
.
data
.
valid
:=
valid
.
asUInt
.
andR
io
.
out
.
data
.
valid
:=
valid
.
asUInt
.
andR
}
/** TensorAlu.
*
* This unit instantiate the ALU vector unit (AluVector) and go over the
* micro-ops (uops) which are used to read the source operands (vectors)
* from the acc-scratchpad and then they are written back the same
* acc-scratchpad.
*/
class
TensorAlu
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
start
=
Input
(
Bool
())
val
done
=
Output
(
Bool
())
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
uop
=
new
UopMaster
val
acc
=
new
TensorMaster
(
tensorType
=
"acc"
)
val
out
=
new
TensorMaster
(
tensorType
=
"out"
)
})
val
sIdle
::
sReadUop
::
sComputeIdx
::
sReadTensorA
::
sReadTensorB
::
sExe
::
Nil
=
Enum
(
6
)
val
state
=
RegInit
(
sIdle
)
val
alu
=
Module
(
new
AluVector
)
val
dec
=
io
.
inst
.
asTypeOf
(
new
AluDecode
)
val
uop_idx
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
uop_end
=
dec
.
uop_end
val
uop_dst
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
uop_src
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
cnt_o
=
Reg
(
chiselTypeOf
(
dec
.
lp_0
))
val
dst_o
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
src_o
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
cnt_i
=
Reg
(
chiselTypeOf
(
dec
.
lp_1
))
val
dst_i
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
src_i
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
done
=
state
===
sExe
&
alu
.
io
.
out
.
data
.
valid
&
(
cnt_o
===
dec
.
lp_0
-
1.
U
)
&
(
cnt_i
===
dec
.
lp_1
-
1.
U
)
&
(
uop_idx
===
uop_end
-
1.
U
)
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
start
)
{
state
:=
sReadUop
}
}
is
(
sReadUop
)
{
state
:=
sComputeIdx
}
is
(
sComputeIdx
)
{
state
:=
sReadTensorA
}
is
(
sReadTensorA
)
{
state
:=
sReadTensorB
}
is
(
sReadTensorB
)
{
state
:=
sExe
}
is
(
sExe
)
{
when
(
alu
.
io
.
out
.
data
.
valid
)
{
when
((
cnt_o
===
dec
.
lp_0
-
1.
U
)
&&
(
cnt_i
===
dec
.
lp_1
-
1.
U
)
&&
(
uop_idx
===
uop_end
-
1.
U
))
{
state
:=
sIdle
}
.
otherwise
{
state
:=
sReadUop
}
}
}
}
when
(
state
===
sIdle
||
(
state
===
sExe
&&
alu
.
io
.
out
.
data
.
valid
&&
uop_idx
===
uop_end
-
1.
U
))
{
uop_idx
:=
dec
.
uop_begin
}
.
elsewhen
(
state
===
sExe
&&
alu
.
io
.
out
.
data
.
valid
)
{
uop_idx
:=
uop_idx
+
1.
U
}
when
(
state
===
sIdle
)
{
cnt_o
:=
0.
U
dst_o
:=
0.
U
src_o
:=
0.
U
}
.
elsewhen
(
state
===
sExe
&&
alu
.
io
.
out
.
data
.
valid
&&
uop_idx
===
uop_end
-
1.
U
&&
cnt_i
===
dec
.
lp_1
-
1.
U
)
{
cnt_o
:=
cnt_o
+
1.
U
dst_o
:=
dst_o
+
dec
.
dst_0
src_o
:=
src_o
+
dec
.
src_0
}
when
(
state
===
sIdle
)
{
cnt_i
:=
0.
U
dst_i
:=
0.
U
src_i
:=
0.
U
}
.
elsewhen
(
state
===
sReadUop
&&
cnt_i
===
dec
.
lp_1
)
{
cnt_i
:=
0.
U
dst_i
:=
dst_o
src_i
:=
src_o
}
.
elsewhen
(
state
===
sExe
&&
alu
.
io
.
out
.
data
.
valid
&&
uop_idx
===
uop_end
-
1.
U
)
{
cnt_i
:=
cnt_i
+
1.
U
dst_i
:=
dst_i
+
dec
.
dst_1
src_i
:=
src_i
+
dec
.
src_1
}
when
(
state
===
sComputeIdx
&&
io
.
uop
.
data
.
valid
)
{
uop_dst
:=
io
.
uop
.
data
.
bits
.
u0
+
dst_i
uop_src
:=
io
.
uop
.
data
.
bits
.
u1
+
src_i
}
// uop
io
.
uop
.
idx
.
valid
:=
state
===
sReadUop
io
.
uop
.
idx
.
bits
:=
uop_idx
// acc_i
io
.
acc
.
rd
.
idx
.
valid
:=
state
===
sReadTensorA
|
(
state
===
sReadTensorB
&
~
dec
.
alu_use_imm
)
io
.
acc
.
rd
.
idx
.
bits
:=
Mux
(
state
===
sReadTensorA
,
uop_dst
,
uop_src
)
// imm
val
tensorImm
=
Wire
(
new
TensorClientData
(
tensorType
=
"acc"
))
tensorImm
.
data
.
valid
:=
state
===
sReadTensorB
tensorImm
.
data
.
bits
.
foreach
{
b
=>
b
.
foreach
{
c
=>
c
:=
dec
.
alu_imm
}
}
// alu
val
isSHR
=
dec
.
alu_op
===
ALU_OP
(
3
)
val
neg_shift
=
isSHR
&
dec
.
alu_imm
(
C_ALU_IMM_BITS
-
1
)
val
fixme_alu_op
=
Cat
(
neg_shift
,
Mux
(
neg_shift
,
0.
U
,
dec
.
alu_op
))
alu
.
io
.
opcode
:=
fixme_alu_op
alu
.
io
.
acc_a
.
data
.
valid
:=
io
.
acc
.
rd
.
data
.
valid
&
state
===
sReadTensorB
alu
.
io
.
acc_a
.
data
.
bits
<>
io
.
acc
.
rd
.
data
.
bits
alu
.
io
.
acc_b
.
data
.
valid
:=
Mux
(
dec
.
alu_use_imm
,
tensorImm
.
data
.
valid
,
io
.
acc
.
rd
.
data
.
valid
&
state
===
sExe
)
alu
.
io
.
acc_b
.
data
.
bits
<>
Mux
(
dec
.
alu_use_imm
,
tensorImm
.
data
.
bits
,
io
.
acc
.
rd
.
data
.
bits
)
// acc_o
io
.
acc
.
wr
.
valid
:=
alu
.
io
.
acc_y
.
data
.
valid
io
.
acc
.
wr
.
bits
.
idx
:=
uop_dst
io
.
acc
.
wr
.
bits
.
data
<>
alu
.
io
.
acc_y
.
data
.
bits
// out
io
.
out
.
wr
.
valid
:=
alu
.
io
.
out
.
data
.
valid
io
.
out
.
wr
.
bits
.
idx
:=
uop_dst
io
.
out
.
wr
.
bits
.
data
<>
alu
.
io
.
out
.
data
.
bits
io
.
out
.
tieoffRead
()
// write-only
io
.
done
:=
done
if
(
debug
)
{
when
(
state
===
sReadUop
)
{
printf
(
"[TensorAlu] [uop] idx:%x\n"
,
uop_idx
)
}
when
(
state
===
sReadTensorA
)
{
printf
(
"[TensorAlu] [uop] dst:%x src:%x\n"
,
uop_dst
,
uop_src
)
}
when
(
state
===
sIdle
&&
io
.
start
)
{
printf
(
p
"[TensorAlu] decode:$dec\n"
)
}
alu
.
io
.
acc_a
.
data
.
bits
.
foreach
{
tensor
=>
tensor
.
zipWithIndex
.
foreach
{
case
(
elem
,
i
)
=>
when
(
alu
.
io
.
acc_a
.
data
.
valid
)
{
printf
(
"[TensorAlu] [a] i:%x val:%x\n"
,
i
.
U
,
elem
)
}
}
}
alu
.
io
.
acc_b
.
data
.
bits
.
foreach
{
tensor
=>
tensor
.
zipWithIndex
.
foreach
{
case
(
elem
,
i
)
=>
when
(
alu
.
io
.
acc_b
.
data
.
valid
)
{
printf
(
"[TensorAlu] [b] i:%x val:%x\n"
,
i
.
U
,
elem
)
}
}
}
alu
.
io
.
acc_y
.
data
.
bits
.
foreach
{
tensor
=>
tensor
.
zipWithIndex
.
foreach
{
case
(
elem
,
i
)
=>
when
(
alu
.
io
.
acc_y
.
data
.
valid
)
{
printf
(
"[TensorAlu] [y] i:%x val:%x\n"
,
i
.
U
,
elem
)
}
}
}
alu
.
io
.
out
.
data
.
bits
.
foreach
{
tensor
=>
tensor
.
zipWithIndex
.
foreach
{
case
(
elem
,
i
)
=>
when
(
alu
.
io
.
out
.
data
.
valid
)
{
printf
(
"[TensorAlu] [out] i:%x val:%x\n"
,
i
.
U
,
elem
)
}
}
}
}
}
vta/hardware/chisel/src/main/scala/core/TensorGemm.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
chisel3.experimental._
import
vta.util.config._
import
scala.math.pow
/** Pipelined multiply and accumulate */
class
MAC
(
dataBits
:
Int
=
8
,
cBits
:
Int
=
16
,
outBits
:
Int
=
17
)
extends
Module
{
require
(
cBits
>=
dataBits
*
2
)
require
(
outBits
>=
dataBits
*
2
)
val
io
=
IO
(
new
Bundle
{
val
a
=
Input
(
SInt
(
dataBits
.
W
))
val
b
=
Input
(
SInt
(
dataBits
.
W
))
val
c
=
Input
(
SInt
(
cBits
.
W
))
val
y
=
Output
(
SInt
(
outBits
.
W
))
})
val
mult
=
Wire
(
SInt
(
cBits
.
W
))
val
add
=
Wire
(
SInt
(
outBits
.
W
))
val
rA
=
RegNext
(
io
.
a
)
val
rB
=
RegNext
(
io
.
b
)
val
rC
=
RegNext
(
io
.
c
)
mult
:=
rA
*
rB
add
:=
rC
+
mult
io
.
y
:=
add
}
/** Pipelined adder */
class
Adder
(
dataBits
:
Int
=
8
,
outBits
:
Int
=
17
)
extends
Module
{
require
(
outBits
>=
dataBits
)
val
io
=
IO
(
new
Bundle
{
val
a
=
Input
(
SInt
(
dataBits
.
W
))
val
b
=
Input
(
SInt
(
dataBits
.
W
))
val
y
=
Output
(
SInt
(
outBits
.
W
))
})
val
add
=
Wire
(
SInt
(
outBits
.
W
))
val
rA
=
RegNext
(
io
.
a
)
val
rB
=
RegNext
(
io
.
b
)
add
:=
rA
+
rB
io
.
y
:=
add
}
/** Pipelined DotProduct based on MAC and Adder */
class
DotProduct
(
dataBits
:
Int
=
8
,
size
:
Int
=
16
)
extends
Module
{
val
errMsg
=
s
"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
require
(
size
>=
4
&&
isPow2
(
size
),
errMsg
)
val
b
=
dataBits
*
2
val
outBits
=
b
+
log2Ceil
(
size
)
+
1
val
io
=
IO
(
new
Bundle
{
val
a
=
Input
(
Vec
(
size
,
SInt
(
dataBits
.
W
)))
val
b
=
Input
(
Vec
(
size
,
SInt
(
dataBits
.
W
)))
val
y
=
Output
(
SInt
(
outBits
.
W
))
})
val
p
=
log2Ceil
(
size
/
2
)
val
s
=
Seq
.
tabulate
(
log2Ceil
(
size
))(
i
=>
pow
(
2
,
p
-
i
).
toInt
)
val
da
=
Seq
.
tabulate
(
s
(
0
))(
i
=>
RegNext
(
io
.
a
(
s
(
0
)
+
i
)))
val
db
=
Seq
.
tabulate
(
s
(
0
))(
i
=>
RegNext
(
io
.
b
(
s
(
0
)
+
i
)))
val
m
=
Seq
.
tabulate
(
2
)(
i
=>
Seq
.
fill
(
s
(
0
))(
Module
(
new
MAC
(
dataBits
=
dataBits
,
cBits
=
b
+
i
,
outBits
=
b
+
i
+
1
)))
)
val
a
=
Seq
.
tabulate
(
p
)(
i
=>
Seq
.
fill
(
s
(
i
+
1
))(
Module
(
new
Adder
(
dataBits
=
b
+
i
+
2
,
outBits
=
b
+
i
+
3
)))
)
for
(
i
<-
0
until
log2Ceil
(
size
))
{
for
(
j
<-
0
until
s
(
i
))
{
if
(
i
==
0
)
{
m
(
i
)(
j
).
io
.
a
:=
io
.
a
(
j
)
m
(
i
)(
j
).
io
.
b
:=
io
.
b
(
j
)
m
(
i
)(
j
).
io
.
c
:=
0.
S
m
(
i
+
1
)(
j
).
io
.
a
:=
da
(
j
)
m
(
i
+
1
)(
j
).
io
.
b
:=
db
(
j
)
m
(
i
+
1
)(
j
).
io
.
c
:=
m
(
i
)(
j
).
io
.
y
}
else
if
(
i
==
1
)
{
a
(
i
-
1
)(
j
).
io
.
a
:=
m
(
i
)(
2
*
j
).
io
.
y
a
(
i
-
1
)(
j
).
io
.
b
:=
m
(
i
)(
2
*
j
+
1
).
io
.
y
}
else
{
a
(
i
-
1
)(
j
).
io
.
a
:=
a
(
i
-
2
)(
2
*
j
).
io
.
y
a
(
i
-
1
)(
j
).
io
.
b
:=
a
(
i
-
2
)(
2
*
j
+
1
).
io
.
y
}
}
}
io
.
y
:=
a
(
p
-
1
)(
0
).
io
.
y
}
/** Perform matric-vector-multiplication based on DotProduct */
class
MatrixVectorCore
(
implicit
p
:
Parameters
)
extends
Module
{
val
accBits
=
p
(
CoreKey
).
accBits
val
size
=
p
(
CoreKey
).
blockOut
val
dataBits
=
p
(
CoreKey
).
inpBits
val
io
=
IO
(
new
Bundle
{
val
reset
=
Input
(
Bool
())
// FIXME: reset should be replaced by a load-acc instr
val
inp
=
new
TensorMasterData
(
tensorType
=
"inp"
)
val
wgt
=
new
TensorMasterData
(
tensorType
=
"wgt"
)
val
acc_i
=
new
TensorMasterData
(
tensorType
=
"acc"
)
val
acc_o
=
new
TensorClientData
(
tensorType
=
"acc"
)
val
out
=
new
TensorClientData
(
tensorType
=
"out"
)
})
val
dot
=
Seq
.
fill
(
size
)(
Module
(
new
DotProduct
(
dataBits
,
size
)))
val
acc
=
Seq
.
fill
(
size
)(
Module
(
new
Pipe
(
UInt
(
accBits
.
W
),
latency
=
log2Ceil
(
size
)
+
1
)))
val
add
=
Seq
.
fill
(
size
)(
Wire
(
SInt
(
accBits
.
W
)))
val
vld
=
Wire
(
Vec
(
size
,
Bool
()))
for
(
i
<-
0
until
size
)
{
acc
(
i
).
io
.
enq
.
valid
:=
io
.
inp
.
data
.
valid
&
io
.
wgt
.
data
.
valid
&
io
.
acc_i
.
data
.
valid
&
~
io
.
reset
acc
(
i
).
io
.
enq
.
bits
:=
io
.
acc_i
.
data
.
bits
(
0
)(
i
)
for
(
j
<-
0
until
size
)
{
dot
(
i
).
io
.
a
(
j
)
:=
io
.
inp
.
data
.
bits
(
0
)(
j
).
asSInt
dot
(
i
).
io
.
b
(
j
)
:=
io
.
wgt
.
data
.
bits
(
i
)(
j
).
asSInt
}
add
(
i
)
:=
acc
(
i
).
io
.
deq
.
bits
.
asSInt
+
dot
(
i
).
io
.
y
io
.
acc_o
.
data
.
bits
(
0
)(
i
)
:=
Mux
(
io
.
reset
,
0.
U
,
add
(
i
).
asUInt
)
io
.
out
.
data
.
bits
(
0
)(
i
)
:=
add
(
i
).
asUInt
vld
(
i
)
:=
acc
(
i
).
io
.
deq
.
valid
}
io
.
acc_o
.
data
.
valid
:=
vld
.
asUInt
.
andR
|
io
.
reset
io
.
out
.
data
.
valid
:=
vld
.
asUInt
.
andR
}
/** TensorGemm.
*
* This unit instantiate the MatrixVectorCore and go over the
* micro-ops (uops) which are used to read inputs, weights and biases,
* and writes results back to the acc and out scratchpads.
*
* Also, the TensorGemm uses the reset field in the Gemm instruction to
* clear or zero-out the acc-scratchpad locations based on the micro-ops.
*/
class
TensorGemm
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
start
=
Input
(
Bool
())
val
done
=
Output
(
Bool
())
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
uop
=
new
UopMaster
val
inp
=
new
TensorMaster
(
tensorType
=
"inp"
)
val
wgt
=
new
TensorMaster
(
tensorType
=
"wgt"
)
val
acc
=
new
TensorMaster
(
tensorType
=
"acc"
)
val
out
=
new
TensorMaster
(
tensorType
=
"out"
)
})
val
sIdle
::
sReadUop
::
sComputeIdx
::
sReadTensor
::
sExe
::
sWait
::
Nil
=
Enum
(
6
)
val
state
=
RegInit
(
sIdle
)
val
mvc
=
Module
(
new
MatrixVectorCore
)
val
dec
=
io
.
inst
.
asTypeOf
(
new
GemmDecode
)
val
uop_idx
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
uop_end
=
dec
.
uop_end
val
uop_acc
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
uop_inp
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
uop_wgt
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
cnt_o
=
Reg
(
chiselTypeOf
(
dec
.
lp_0
))
val
acc_o
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
inp_o
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
wgt_o
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
cnt_i
=
Reg
(
chiselTypeOf
(
dec
.
lp_1
))
val
acc_i
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
inp_i
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
wgt_i
=
Reg
(
chiselTypeOf
(
dec
.
uop_end
))
val
pBits
=
log2Ceil
(
p
(
CoreKey
).
blockOut
)
+
1
val
inflight
=
Reg
(
UInt
(
pBits
.
W
))
val
wrpipe
=
Module
(
new
Pipe
(
chiselTypeOf
(
dec
.
uop_end
),
latency
=
pBits
))
val
done
=
inflight
===
0.
U
&
((
state
===
sExe
&
cnt_o
===
dec
.
lp_0
-
1.
U
&
cnt_i
===
dec
.
lp_1
-
1.
U
&
uop_idx
===
uop_end
-
1.
U
&
inflight
===
0.
U
)
|
state
===
sWait
)
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
start
)
{
state
:=
sReadUop
}
}
is
(
sReadUop
)
{
state
:=
sComputeIdx
}
is
(
sComputeIdx
)
{
state
:=
sReadTensor
}
is
(
sReadTensor
)
{
state
:=
sExe
}
is
(
sExe
)
{
when
((
cnt_o
===
dec
.
lp_0
-
1.
U
)
&&
(
cnt_i
===
dec
.
lp_1
-
1.
U
)
&&
(
uop_idx
===
uop_end
-
1.
U
))
{
when
(
inflight
=/=
0.
U
)
{
state
:=
sWait
}
.
otherwise
{
state
:=
sIdle
}
}
.
otherwise
{
state
:=
sReadUop
}
}
is
(
sWait
)
{
when
(
inflight
===
0.
U
)
{
state
:=
sIdle
}
}
}
when
(
state
===
sIdle
)
{
inflight
:=
0.
U
}
.
elsewhen
(!
dec
.
reset
)
{
when
(
state
===
sExe
&&
inflight
=/=
((
1
<<
pBits
)
-
1
).
asUInt
)
{
// overflow check
inflight
:=
inflight
+
1.
U
}
.
elsewhen
(
mvc
.
io
.
acc_o
.
data
.
valid
&&
inflight
=/=
0.
U
)
{
// underflow check
inflight
:=
inflight
-
1.
U
}
}
when
(
state
===
sIdle
||
(
state
===
sExe
&&
uop_idx
===
uop_end
-
1.
U
))
{
uop_idx
:=
dec
.
uop_begin
}
.
elsewhen
(
state
===
sExe
)
{
uop_idx
:=
uop_idx
+
1.
U
}
when
(
state
===
sIdle
)
{
cnt_o
:=
0.
U
acc_o
:=
0.
U
inp_o
:=
0.
U
wgt_o
:=
0.
U
}
.
elsewhen
(
state
===
sExe
&&
uop_idx
===
uop_end
-
1.
U
&&
cnt_i
===
dec
.
lp_1
-
1.
U
)
{
cnt_o
:=
cnt_o
+
1.
U
acc_o
:=
acc_o
+
dec
.
acc_0
inp_o
:=
inp_o
+
dec
.
inp_0
wgt_o
:=
wgt_o
+
dec
.
wgt_0
}
when
(
state
===
sIdle
)
{
cnt_i
:=
0.
U
acc_i
:=
0.
U
inp_i
:=
0.
U
wgt_i
:=
0.
U
}
.
elsewhen
(
state
===
sReadUop
&&
cnt_i
===
dec
.
lp_1
)
{
cnt_i
:=
0.
U
acc_i
:=
acc_o
inp_i
:=
inp_o
wgt_i
:=
wgt_o
}
.
elsewhen
(
state
===
sExe
&&
uop_idx
===
uop_end
-
1.
U
)
{
cnt_i
:=
cnt_i
+
1.
U
acc_i
:=
acc_i
+
dec
.
acc_1
inp_i
:=
inp_i
+
dec
.
inp_1
wgt_i
:=
wgt_i
+
dec
.
wgt_1
}
when
(
state
===
sComputeIdx
&&
io
.
uop
.
data
.
valid
)
{
uop_acc
:=
io
.
uop
.
data
.
bits
.
u0
+
acc_i
uop_inp
:=
io
.
uop
.
data
.
bits
.
u1
+
inp_i
uop_wgt
:=
io
.
uop
.
data
.
bits
.
u2
+
wgt_i
}
wrpipe
.
io
.
enq
.
valid
:=
state
===
sExe
&
~
dec
.
reset
wrpipe
.
io
.
enq
.
bits
:=
uop_acc
// uop
io
.
uop
.
idx
.
valid
:=
state
===
sReadUop
io
.
uop
.
idx
.
bits
:=
uop_idx
// inp
io
.
inp
.
rd
.
idx
.
valid
:=
state
===
sReadTensor
io
.
inp
.
rd
.
idx
.
bits
:=
uop_inp
io
.
inp
.
tieoffWrite
()
// read-only
// wgt
io
.
wgt
.
rd
.
idx
.
valid
:=
state
===
sReadTensor
io
.
wgt
.
rd
.
idx
.
bits
:=
uop_wgt
io
.
wgt
.
tieoffWrite
()
// read-only
// acc_i
io
.
acc
.
rd
.
idx
.
valid
:=
state
===
sReadTensor
io
.
acc
.
rd
.
idx
.
bits
:=
uop_acc
// mvc
mvc
.
io
.
reset
:=
dec
.
reset
&
state
===
sExe
mvc
.
io
.
inp
.
data
<>
io
.
inp
.
rd
.
data
mvc
.
io
.
wgt
.
data
<>
io
.
wgt
.
rd
.
data
mvc
.
io
.
acc_i
.
data
<>
io
.
acc
.
rd
.
data
// acc_o
io
.
acc
.
wr
.
valid
:=
mvc
.
io
.
acc_o
.
data
.
valid
&
Mux
(
dec
.
reset
,
true
.
B
,
wrpipe
.
io
.
deq
.
valid
)
io
.
acc
.
wr
.
bits
.
idx
:=
Mux
(
dec
.
reset
,
uop_acc
,
wrpipe
.
io
.
deq
.
bits
)
io
.
acc
.
wr
.
bits
.
data
<>
mvc
.
io
.
acc_o
.
data
.
bits
// out
io
.
out
.
wr
.
valid
:=
mvc
.
io
.
out
.
data
.
valid
&
wrpipe
.
io
.
deq
.
valid
io
.
out
.
wr
.
bits
.
idx
:=
wrpipe
.
io
.
deq
.
bits
io
.
out
.
wr
.
bits
.
data
<>
mvc
.
io
.
out
.
data
.
bits
io
.
out
.
tieoffRead
()
// write-only
io
.
done
:=
done
if
(
debug
)
{
when
(
state
===
sReadUop
&&
~
dec
.
reset
)
{
printf
(
"[TensorGemm] [uop] idx:%x\n"
,
uop_idx
)
}
when
(
state
===
sReadTensor
&&
~
dec
.
reset
)
{
printf
(
"[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n"
,
uop_acc
,
uop_inp
,
uop_wgt
)
}
io
.
inp
.
rd
.
data
.
bits
.
zipWithIndex
.
foreach
{
case
(
r
,
i
)
=>
when
(
io
.
inp
.
rd
.
data
.
valid
&&
~
dec
.
reset
)
{
printf
(
"[TensorGemm] [inp] i:%x val:%x\n"
,
i
.
U
,
r
.
asUInt
)
}
}
io
.
wgt
.
rd
.
data
.
bits
.
zipWithIndex
.
foreach
{
case
(
r
,
i
)
=>
when
(
io
.
wgt
.
rd
.
data
.
valid
&&
~
dec
.
reset
)
{
printf
(
"[TensorGemm] [wgt] i:%x val:%x\n"
,
i
.
U
,
r
.
asUInt
)
}
}
io
.
acc
.
rd
.
data
.
bits
.
foreach
{
tensor
=>
tensor
.
zipWithIndex
.
foreach
{
case
(
elem
,
i
)
=>
when
(
io
.
acc
.
rd
.
data
.
valid
&&
~
dec
.
reset
)
{
printf
(
"[TensorGemm] [acc_i] i:%x val:%x\n"
,
i
.
U
,
elem
)
}
}
}
mvc
.
io
.
acc_o
.
data
.
bits
.
foreach
{
tensor
=>
tensor
.
zipWithIndex
.
foreach
{
case
(
elem
,
i
)
=>
when
(
mvc
.
io
.
acc_o
.
data
.
valid
&&
~
dec
.
reset
)
{
printf
(
"[TensorGemm] [acc_o] i:%x val:%x\n"
,
i
.
U
,
elem
)
}
}
}
mvc
.
io
.
out
.
data
.
bits
.
foreach
{
tensor
=>
tensor
.
zipWithIndex
.
foreach
{
case
(
elem
,
i
)
=>
when
(
mvc
.
io
.
out
.
data
.
valid
&&
~
dec
.
reset
)
{
printf
(
"[TensorGemm] [out] i:%x val:%x\n"
,
i
.
U
,
elem
)
}
}
}
}
}
vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** TensorStore.
*
* Load 1D and 2D tensors from main memory (DRAM) to input/weight
* scratchpads (SRAM). Also, there is support for zero padding, while
* doing the load. Zero-padding works on the y and x axis, and it is
* managed by TensorPadCtrl. The TensorDataCtrl is in charge of
* handling the way tensors are stored on the scratchpads.
*/
class
TensorLoad
(
tensorType
:
String
=
"none"
,
debug
:
Boolean
=
false
)
(
implicit
p
:
Parameters
)
extends
Module
{
val
tp
=
new
TensorParams
(
tensorType
)
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
start
=
Input
(
Bool
())
val
done
=
Output
(
Bool
())
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
vme_rd
=
new
VMEReadMaster
val
tensor
=
new
TensorClient
(
tensorType
)
})
val
sizeFactor
=
tp
.
tensorLength
*
tp
.
numMemBlock
val
strideFactor
=
tp
.
tensorLength
*
tp
.
tensorWidth
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
val
dataCtrl
=
Module
(
new
TensorDataCtrl
(
sizeFactor
,
strideFactor
))
val
dataCtrlDone
=
RegInit
(
false
.
B
)
val
yPadCtrl0
=
Module
(
new
TensorPadCtrl
(
padType
=
"YPad0"
,
sizeFactor
))
val
yPadCtrl1
=
Module
(
new
TensorPadCtrl
(
padType
=
"YPad1"
,
sizeFactor
))
val
xPadCtrl0
=
Module
(
new
TensorPadCtrl
(
padType
=
"XPad0"
,
sizeFactor
))
val
xPadCtrl1
=
Module
(
new
TensorPadCtrl
(
padType
=
"XPad1"
,
sizeFactor
))
val
tag
=
Reg
(
UInt
(
8.
W
))
val
set
=
Reg
(
UInt
(
8.
W
))
val
sIdle
::
sYPad0
::
sXPad0
::
sReadCmd
::
sReadData
::
sXPad1
::
sYPad1
::
Nil
=
Enum
(
7
)
val
state
=
RegInit
(
sIdle
)
// control
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
start
)
{
when
(
dec
.
ypad_0
=/=
0.
U
)
{
state
:=
sYPad0
}
.
elsewhen
(
dec
.
xpad_0
=/=
0.
U
)
{
state
:=
sXPad0
}
.
otherwise
{
state
:=
sReadCmd
}
}
}
is
(
sYPad0
)
{
when
(
yPadCtrl0
.
io
.
done
)
{
when
(
dec
.
xpad_0
=/=
0.
U
)
{
state
:=
sXPad0
}
.
otherwise
{
state
:=
sReadCmd
}
}
}
is
(
sXPad0
)
{
when
(
xPadCtrl0
.
io
.
done
)
{
state
:=
sReadCmd
}
}
is
(
sReadCmd
)
{
when
(
io
.
vme_rd
.
cmd
.
ready
)
{
state
:=
sReadData
}
}
is
(
sReadData
)
{
when
(
io
.
vme_rd
.
data
.
valid
)
{
when
(
dataCtrl
.
io
.
done
)
{
when
(
dec
.
xpad_1
=/=
0.
U
)
{
state
:=
sXPad1
}
.
elsewhen
(
dec
.
ypad_1
=/=
0.
U
)
{
state
:=
sYPad1
}
.
otherwise
{
state
:=
sIdle
}
}
.
elsewhen
(
dataCtrl
.
io
.
stride
||
dataCtrl
.
io
.
split
)
{
when
(
dec
.
xpad_1
=/=
0.
U
)
{
state
:=
sXPad1
}
.
elsewhen
(
dec
.
xpad_0
=/=
0.
U
)
{
state
:=
sXPad0
}
.
otherwise
{
state
:=
sReadCmd
}
}
}
}
is
(
sXPad1
)
{
when
(
xPadCtrl1
.
io
.
done
)
{
when
(
dataCtrlDone
)
{
when
(
dec
.
ypad_1
=/=
0.
U
)
{
state
:=
sYPad1
}
.
otherwise
{
state
:=
sIdle
}
}
.
otherwise
{
when
(
dec
.
xpad_0
=/=
0.
U
)
{
state
:=
sXPad0
}
.
otherwise
{
state
:=
sReadCmd
}
}
}
}
is
(
sYPad1
)
{
when
(
yPadCtrl1
.
io
.
done
&&
dataCtrlDone
)
{
state
:=
sIdle
}
}
}
// data controller
dataCtrl
.
io
.
start
:=
state
===
sIdle
&
io
.
start
dataCtrl
.
io
.
inst
:=
io
.
inst
dataCtrl
.
io
.
baddr
:=
io
.
baddr
dataCtrl
.
io
.
xinit
:=
io
.
vme_rd
.
cmd
.
fire
()
dataCtrl
.
io
.
xupdate
:=
io
.
vme_rd
.
data
.
fire
()
dataCtrl
.
io
.
yupdate
:=
io
.
vme_rd
.
data
.
fire
()
when
(
state
===
sIdle
)
{
dataCtrlDone
:=
false
.
B
}
.
elsewhen
(
io
.
vme_rd
.
data
.
fire
()
&&
dataCtrl
.
io
.
done
)
{
dataCtrlDone
:=
true
.
B
}
// pad
yPadCtrl0
.
io
.
start
:=
dec
.
ypad_0
=/=
0.
U
&
state
===
sIdle
&
io
.
start
yPadCtrl1
.
io
.
start
:=
dec
.
ypad_1
=/=
0.
U
&
((
io
.
vme_rd
.
data
.
fire
()
&
dataCtrl
.
io
.
done
&
dec
.
xpad_1
===
0.
U
)
|
(
state
===
sXPad1
&
xPadCtrl1
.
io
.
done
&
dataCtrlDone
))
xPadCtrl0
.
io
.
start
:=
dec
.
xpad_0
=/=
0.
U
&
((
state
===
sIdle
&
io
.
start
)
|
(
state
===
sYPad0
&
yPadCtrl0
.
io
.
done
)
|
(
io
.
vme_rd
.
data
.
fire
()
&
~
dataCtrlDone
&
(
dataCtrl
.
io
.
stride
|
dataCtrl
.
io
.
split
)
&
dec
.
xpad_1
===
0.
U
)
|
(
state
===
sXPad1
&
xPadCtrl1
.
io
.
done
&
~
dataCtrlDone
))
xPadCtrl1
.
io
.
start
:=
dec
.
xpad_1
=/=
0.
U
&
io
.
vme_rd
.
data
.
fire
()
&
((
dataCtrl
.
io
.
done
)
|
(~
dataCtrl
.
io
.
done
&
(
dataCtrl
.
io
.
stride
|
dataCtrl
.
io
.
split
)
&
dec
.
xpad_1
=/=
0.
U
))
yPadCtrl0
.
io
.
inst
:=
io
.
inst
yPadCtrl1
.
io
.
inst
:=
io
.
inst
xPadCtrl0
.
io
.
inst
:=
io
.
inst
xPadCtrl1
.
io
.
inst
:=
io
.
inst
// read-from-dram
io
.
vme_rd
.
cmd
.
valid
:=
state
===
sReadCmd
io
.
vme_rd
.
cmd
.
bits
.
addr
:=
dataCtrl
.
io
.
addr
io
.
vme_rd
.
cmd
.
bits
.
len
:=
dataCtrl
.
io
.
len
io
.
vme_rd
.
data
.
ready
:=
state
===
sReadData
// write-to-sram
val
isZeroPad
=
state
===
sYPad0
|
state
===
sXPad0
|
state
===
sXPad1
|
state
===
sYPad1
when
(
state
===
sIdle
||
state
===
sReadCmd
||
tag
===
(
tp
.
numMemBlock
-
1
).
U
)
{
tag
:=
0.
U
}
.
elsewhen
(
io
.
vme_rd
.
data
.
fire
()
||
isZeroPad
)
{
tag
:=
tag
+
1.
U
}
when
(
state
===
sIdle
||
state
===
sReadCmd
||
(
set
===
(
tp
.
tensorLength
-
1
).
U
&&
tag
===
(
tp
.
numMemBlock
-
1
).
U
))
{
set
:=
0.
U
}
.
elsewhen
((
io
.
vme_rd
.
data
.
fire
()
||
isZeroPad
)
&&
tag
===
(
tp
.
numMemBlock
-
1
).
U
)
{
set
:=
set
+
1.
U
}
val
waddr_cur
=
Reg
(
UInt
(
tp
.
memAddrBits
.
W
))
val
waddr_nxt
=
Reg
(
UInt
(
tp
.
memAddrBits
.
W
))
when
(
state
===
sIdle
)
{
waddr_cur
:=
dec
.
sram_offset
waddr_nxt
:=
dec
.
sram_offset
}
.
elsewhen
((
io
.
vme_rd
.
data
.
fire
()
||
isZeroPad
)
&&
set
===
(
tp
.
tensorLength
-
1
).
U
&&
tag
===
(
tp
.
numMemBlock
-
1
).
U
)
{
waddr_cur
:=
waddr_cur
+
1.
U
}
.
elsewhen
(
dataCtrl
.
io
.
stride
)
{
waddr_cur
:=
waddr_nxt
+
dec
.
xsize
waddr_nxt
:=
waddr_nxt
+
dec
.
xsize
}
val
tensorFile
=
Seq
.
fill
(
tp
.
tensorLength
)
{
SyncReadMem
(
tp
.
memDepth
,
Vec
(
tp
.
numMemBlock
,
UInt
(
tp
.
memBlockBits
.
W
)))
}
val
wmask
=
Seq
.
fill
(
tp
.
tensorLength
)
{
Wire
(
Vec
(
tp
.
numMemBlock
,
Bool
()))
}
val
wdata
=
Seq
.
fill
(
tp
.
tensorLength
)
{
Wire
(
Vec
(
tp
.
numMemBlock
,
UInt
(
tp
.
memBlockBits
.
W
)))
}
val
no_mask
=
Wire
(
Vec
(
tp
.
numMemBlock
,
Bool
()))
no_mask
.
foreach
{
m
=>
m
:=
true
.
B
}
for
(
i
<-
0
until
tp
.
tensorLength
)
{
for
(
j
<-
0
until
tp
.
numMemBlock
)
{
wmask
(
i
)(
j
)
:=
tag
===
j
.
U
wdata
(
i
)(
j
)
:=
Mux
(
isZeroPad
,
0.
U
,
io
.
vme_rd
.
data
.
bits
)
}
val
tdata
=
io
.
tensor
.
wr
.
bits
.
data
(
i
).
asUInt
.
asTypeOf
(
wdata
(
i
))
val
muxWen
=
Mux
(
state
===
sIdle
,
io
.
tensor
.
wr
.
valid
,
(
io
.
vme_rd
.
data
.
fire
()
|
isZeroPad
)
&
set
===
i
.
U
)
val
muxWaddr
=
Mux
(
state
===
sIdle
,
io
.
tensor
.
wr
.
bits
.
idx
,
waddr_cur
)
val
muxWdata
=
Mux
(
state
===
sIdle
,
tdata
,
wdata
(
i
))
val
muxWmask
=
Mux
(
state
===
sIdle
,
no_mask
,
wmask
(
i
))
when
(
muxWen
)
{
tensorFile
(
i
).
write
(
muxWaddr
,
muxWdata
,
muxWmask
)
}
}
// read-from-sram
val
rvalid
=
RegNext
(
io
.
tensor
.
rd
.
idx
.
valid
)
io
.
tensor
.
rd
.
data
.
valid
:=
rvalid
val
rdata
=
tensorFile
.
map
(
_
.
read
(
io
.
tensor
.
rd
.
idx
.
bits
,
io
.
tensor
.
rd
.
idx
.
valid
))
rdata
.
zipWithIndex
.
foreach
{
case
(
r
,
i
)
=>
io
.
tensor
.
rd
.
data
.
bits
(
i
)
:=
r
.
asUInt
.
asTypeOf
(
io
.
tensor
.
rd
.
data
.
bits
(
i
))
}
// done
val
done_no_pad
=
io
.
vme_rd
.
data
.
fire
()
&
dataCtrl
.
io
.
done
&
dec
.
xpad_1
===
0.
U
&
dec
.
ypad_1
===
0.
U
val
done_x_pad
=
state
===
sXPad1
&
xPadCtrl1
.
io
.
done
&
dataCtrlDone
&
dec
.
ypad_1
===
0.
U
val
done_y_pad
=
state
===
sYPad1
&
dataCtrlDone
&
yPadCtrl1
.
io
.
done
io
.
done
:=
done_no_pad
|
done_x_pad
|
done_y_pad
// debug
if
(
debug
)
{
if
(
tensorType
==
"inp"
)
{
when
(
io
.
vme_rd
.
cmd
.
fire
())
{
printf
(
"[TensorLoad] [inp] cmd addr:%x len:%x\n"
,
dataCtrl
.
io
.
addr
,
dataCtrl
.
io
.
len
)
}
when
(
state
===
sYPad0
)
{
printf
(
"[TensorLoad] [inp] sYPad0\n"
)
}
when
(
state
===
sYPad1
)
{
printf
(
"[TensorLoad] [inp] sYPad1\n"
)
}
when
(
state
===
sXPad0
)
{
printf
(
"[TensorLoad] [inp] sXPad0\n"
)
}
when
(
state
===
sXPad1
)
{
printf
(
"[TensorLoad] [inp] sXPad1\n"
)
}
}
else
if
(
tensorType
==
"wgt"
)
{
when
(
io
.
vme_rd
.
cmd
.
fire
())
{
printf
(
"[TensorLoad] [wgt] cmd addr:%x len:%x\n"
,
dataCtrl
.
io
.
addr
,
dataCtrl
.
io
.
len
)
}
}
else
if
(
tensorType
==
"acc"
)
{
when
(
io
.
vme_rd
.
cmd
.
fire
())
{
printf
(
"[TensorLoad] [acc] cmd addr:%x len:%x\n"
,
dataCtrl
.
io
.
addr
,
dataCtrl
.
io
.
len
)
}
}
}
}
vta/hardware/chisel/src/main/scala/core/TensorStore.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** TensorStore.
*
* Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
*/
class
TensorStore
(
tensorType
:
String
=
"true"
,
debug
:
Boolean
=
false
)
(
implicit
p
:
Parameters
)
extends
Module
{
val
tp
=
new
TensorParams
(
tensorType
)
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
start
=
Input
(
Bool
())
val
done
=
Output
(
Bool
())
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
vme_wr
=
new
VMEWriteMaster
val
tensor
=
new
TensorClient
(
tensorType
)
})
val
tensorLength
=
tp
.
tensorLength
val
tensorWidth
=
tp
.
tensorWidth
val
tensorElemBits
=
tp
.
tensorElemBits
val
memBlockBits
=
tp
.
memBlockBits
val
memDepth
=
tp
.
memDepth
val
numMemBlock
=
tp
.
numMemBlock
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
val
waddr_cur
=
Reg
(
chiselTypeOf
(
io
.
vme_wr
.
cmd
.
bits
.
addr
))
val
waddr_nxt
=
Reg
(
chiselTypeOf
(
io
.
vme_wr
.
cmd
.
bits
.
addr
))
val
xcnt
=
Reg
(
chiselTypeOf
(
io
.
vme_wr
.
cmd
.
bits
.
len
))
val
xlen
=
Reg
(
chiselTypeOf
(
io
.
vme_wr
.
cmd
.
bits
.
len
))
val
xrem
=
Reg
(
chiselTypeOf
(
dec
.
xsize
))
val
xsize
=
(
dec
.
xsize
<<
log2Ceil
(
tensorLength
*
numMemBlock
))
-
1.
U
val
xmax
=
(
1
<<
mp
.
lenBits
).
U
val
xmax_bytes
=
((
1
<<
mp
.
lenBits
)*
mp
.
dataBits
/
8
).
U
val
ycnt
=
Reg
(
chiselTypeOf
(
dec
.
ysize
))
val
ysize
=
dec
.
ysize
val
tag
=
Reg
(
UInt
(
8.
W
))
val
set
=
Reg
(
UInt
(
8.
W
))
val
sIdle
::
sWriteCmd
::
sWriteData
::
sReadMem
::
sWriteAck
::
Nil
=
Enum
(
5
)
val
state
=
RegInit
(
sIdle
)
// control
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
start
)
{
state
:=
sWriteCmd
when
(
xsize
<
xmax
)
{
xlen
:=
xsize
xrem
:=
0.
U
}
.
otherwise
{
xlen
:=
xmax
-
1.
U
xrem
:=
xsize
-
xmax
}
}
}
is
(
sWriteCmd
)
{
when
(
io
.
vme_wr
.
cmd
.
ready
)
{
state
:=
sWriteData
}
}
is
(
sWriteData
)
{
when
(
io
.
vme_wr
.
data
.
ready
)
{
when
(
xcnt
===
xlen
)
{
state
:=
sWriteAck
}
.
elsewhen
(
tag
===
(
numMemBlock
-
1
).
U
)
{
state
:=
sReadMem
}
}
}
is
(
sReadMem
)
{
state
:=
sWriteData
}
is
(
sWriteAck
)
{
when
(
io
.
vme_wr
.
ack
)
{
when
(
xrem
===
0.
U
)
{
when
(
ycnt
===
ysize
-
1.
U
)
{
state
:=
sIdle
}
.
otherwise
{
state
:=
sWriteCmd
when
(
xsize
<
xmax
)
{
xlen
:=
xsize
xrem
:=
0.
U
}
.
otherwise
{
xlen
:=
xmax
-
1.
U
xrem
:=
xsize
-
xmax
}
}
}
.
elsewhen
(
xrem
<
xmax
)
{
state
:=
sWriteCmd
xlen
:=
xrem
xrem
:=
0.
U
}
.
otherwise
{
state
:=
sWriteCmd
xlen
:=
xmax
-
1.
U
xrem
:=
xrem
-
xmax
}
}
}
}
// write-to-sram
val
tensorFile
=
Seq
.
fill
(
tensorLength
)
{
SyncReadMem
(
memDepth
,
Vec
(
numMemBlock
,
UInt
(
memBlockBits
.
W
)))
}
val
wdata_t
=
Wire
(
Vec
(
numMemBlock
,
UInt
(
memBlockBits
.
W
)))
val
no_mask
=
Wire
(
Vec
(
numMemBlock
,
Bool
()))
wdata_t
:=
DontCare
no_mask
.
foreach
{
m
=>
m
:=
true
.
B
}
for
(
i
<-
0
until
tensorLength
)
{
val
inWrData
=
io
.
tensor
.
wr
.
bits
.
data
(
i
).
asUInt
.
asTypeOf
(
wdata_t
)
when
(
io
.
tensor
.
wr
.
valid
)
{
tensorFile
(
i
).
write
(
io
.
tensor
.
wr
.
bits
.
idx
,
inWrData
,
no_mask
)
}
}
// read-from-sram
val
stride
=
state
===
sWriteAck
&
io
.
vme_wr
.
ack
&
xcnt
===
xlen
+
1.
U
&
xrem
===
0.
U
&
ycnt
=/=
ysize
-
1.
U
when
(
state
===
sIdle
)
{
ycnt
:=
0.
U
}
.
elsewhen
(
stride
)
{
ycnt
:=
ycnt
+
1.
U
}
when
(
state
===
sWriteCmd
||
tag
===
(
numMemBlock
-
1
).
U
)
{
tag
:=
0.
U
}
.
elsewhen
(
io
.
vme_wr
.
data
.
fire
())
{
tag
:=
tag
+
1.
U
}
when
(
state
===
sWriteCmd
||
(
set
===
(
tensorLength
-
1
).
U
&&
tag
===
(
numMemBlock
-
1
).
U
))
{
set
:=
0.
U
}
.
elsewhen
(
io
.
vme_wr
.
data
.
fire
()
&&
tag
===
(
numMemBlock
-
1
).
U
)
{
set
:=
set
+
1.
U
}
val
raddr_cur
=
Reg
(
UInt
(
tp
.
memAddrBits
.
W
))
val
raddr_nxt
=
Reg
(
UInt
(
tp
.
memAddrBits
.
W
))
when
(
state
===
sIdle
)
{
raddr_cur
:=
dec
.
sram_offset
raddr_nxt
:=
dec
.
sram_offset
}
.
elsewhen
(
io
.
vme_wr
.
data
.
fire
()
&&
set
===
(
tensorLength
-
1
).
U
&&
tag
===
(
numMemBlock
-
1
).
U
)
{
raddr_cur
:=
raddr_cur
+
1.
U
}
.
elsewhen
(
stride
)
{
raddr_cur
:=
raddr_nxt
+
dec
.
xsize
raddr_nxt
:=
raddr_nxt
+
dec
.
xsize
}
val
tread
=
Seq
.
tabulate
(
tensorLength
)
{
i
=>
i
.
U
->
tensorFile
(
i
).
read
(
raddr_cur
,
state
===
sWriteCmd
|
state
===
sReadMem
)
}
val
mdata
=
MuxLookup
(
set
,
0.
U
.
asTypeOf
(
chiselTypeOf
(
wdata_t
)),
tread
)
// write-to-dram
when
(
state
===
sIdle
)
{
waddr_cur
:=
io
.
baddr
+
dec
.
dram_offset
waddr_nxt
:=
io
.
baddr
+
dec
.
dram_offset
}
.
elsewhen
(
state
===
sWriteAck
&&
io
.
vme_wr
.
ack
&&
xrem
=/=
0.
U
)
{
waddr_cur
:=
waddr_cur
+
xmax_bytes
}
.
elsewhen
(
stride
)
{
waddr_cur
:=
waddr_nxt
+
(
dec
.
xstride
<<
log2Ceil
(
tensorLength
*
tensorWidth
))
waddr_nxt
:=
waddr_nxt
+
(
dec
.
xstride
<<
log2Ceil
(
tensorLength
*
tensorWidth
))
}
io
.
vme_wr
.
cmd
.
valid
:=
state
===
sWriteCmd
io
.
vme_wr
.
cmd
.
bits
.
addr
:=
waddr_cur
io
.
vme_wr
.
cmd
.
bits
.
len
:=
xlen
io
.
vme_wr
.
data
.
valid
:=
state
===
sWriteData
io
.
vme_wr
.
data
.
bits
:=
mdata
(
tag
)
when
(
state
===
sWriteCmd
)
{
xcnt
:=
0.
U
}
.
elsewhen
(
io
.
vme_wr
.
data
.
fire
())
{
xcnt
:=
xcnt
+
1.
U
}
// disable external read-from-sram requests
io
.
tensor
.
tieoffRead
()
// done
io
.
done
:=
state
===
sWriteAck
&
io
.
vme_wr
.
ack
&
xrem
===
0.
U
&
ycnt
===
ysize
-
1.
U
// debug
if
(
debug
)
{
when
(
io
.
vme_wr
.
cmd
.
fire
())
{
printf
(
"[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n"
,
ysize
,
ycnt
,
raddr_cur
,
waddr_cur
,
xlen
,
xrem
)
}
when
(
io
.
vme_wr
.
data
.
fire
())
{
printf
(
"[TensorStore] data:%x\n"
,
io
.
vme_wr
.
data
.
bits
)
}
when
(
io
.
vme_wr
.
ack
)
{
printf
(
"[TensorStore] ack\n"
)
}
}
}
vta/hardware/chisel/src/main/scala/core/TensorUtil.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.core
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.shell._
/** TensorParams.
*
* This Bundle derives parameters for each tensorType, including inputs (inp),
* weights (wgt), biases (acc), and outputs (out). This is used to avoid
* doing the same boring calculations over and over again.
*/
class
TensorParams
(
tensorType
:
String
=
"none"
)(
implicit
p
:
Parameters
)
extends
Bundle
{
val
errorMsg
=
s
"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
require
(
tensorType
==
"inp"
||
tensorType
==
"wgt"
||
tensorType
==
"acc"
||
tensorType
==
"out"
,
errorMsg
)
val
(
tensorLength
,
tensorWidth
,
tensorElemBits
)
=
if
(
tensorType
==
"inp"
)
(
p
(
CoreKey
).
batch
,
p
(
CoreKey
).
blockIn
,
p
(
CoreKey
).
inpBits
)
else
if
(
tensorType
==
"wgt"
)
(
p
(
CoreKey
).
blockOut
,
p
(
CoreKey
).
blockIn
,
p
(
CoreKey
).
wgtBits
)
else
if
(
tensorType
==
"acc"
)
(
p
(
CoreKey
).
batch
,
p
(
CoreKey
).
blockOut
,
p
(
CoreKey
).
accBits
)
else
(
p
(
CoreKey
).
batch
,
p
(
CoreKey
).
blockOut
,
p
(
CoreKey
).
outBits
)
val
memBlockBits
=
p
(
ShellKey
).
memParams
.
dataBits
val
numMemBlock
=
(
tensorWidth
*
tensorElemBits
)
/
memBlockBits
val
memDepth
=
if
(
tensorType
==
"inp"
)
p
(
CoreKey
).
inpMemDepth
else
if
(
tensorType
==
"wgt"
)
p
(
CoreKey
).
wgtMemDepth
else
if
(
tensorType
==
"acc"
)
p
(
CoreKey
).
accMemDepth
else
p
(
CoreKey
).
outMemDepth
val
memAddrBits
=
log2Ceil
(
memDepth
)
}
/** TensorMaster.
*
* This interface issue read and write tensor-requests to scratchpads. For example,
* The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt),
* biases (acc), and outputs (out).
*
*/
class
TensorMaster
(
tensorType
:
String
=
"none"
)
(
implicit
p
:
Parameters
)
extends
TensorParams
(
tensorType
)
{
val
rd
=
new
Bundle
{
val
idx
=
ValidIO
(
UInt
(
memAddrBits
.
W
))
val
data
=
Flipped
(
ValidIO
(
Vec
(
tensorLength
,
Vec
(
tensorWidth
,
UInt
(
tensorElemBits
.
W
)))))
}
val
wr
=
ValidIO
(
new
Bundle
{
val
idx
=
UInt
(
memAddrBits
.
W
)
val
data
=
Vec
(
tensorLength
,
Vec
(
tensorWidth
,
UInt
(
tensorElemBits
.
W
)))
})
def
tieoffRead
()
{
rd
.
idx
.
valid
:=
false
.
B
rd
.
idx
.
bits
:=
0.
U
}
def
tieoffWrite
()
{
wr
.
valid
:=
false
.
B
wr
.
bits
.
idx
:=
0.
U
wr
.
bits
.
data
.
foreach
{
b
=>
b
.
foreach
{
c
=>
c
:=
0.
U
}
}
}
override
def
cloneType
=
new
TensorMaster
(
tensorType
).
asInstanceOf
[
this.
type
]
}
/** TensorClient.
*
* This interface receives read and write tensor-requests to scratchpads. For example,
* The TensorLoad unit uses this interface for receiving read and write requests from
* the TensorGemm unit.
*/
class
TensorClient
(
tensorType
:
String
=
"none"
)
(
implicit
p
:
Parameters
)
extends
TensorParams
(
tensorType
)
{
val
rd
=
new
Bundle
{
val
idx
=
Flipped
(
ValidIO
(
UInt
(
memAddrBits
.
W
)))
val
data
=
ValidIO
(
Vec
(
tensorLength
,
Vec
(
tensorWidth
,
UInt
(
tensorElemBits
.
W
))))
}
val
wr
=
Flipped
(
ValidIO
(
new
Bundle
{
val
idx
=
UInt
(
memAddrBits
.
W
)
val
data
=
Vec
(
tensorLength
,
Vec
(
tensorWidth
,
UInt
(
tensorElemBits
.
W
)))
}))
def
tieoffRead
()
{
rd
.
data
.
valid
:=
false
.
B
rd
.
data
.
bits
.
foreach
{
b
=>
b
.
foreach
{
c
=>
c
:=
0.
U
}
}
}
override
def
cloneType
=
new
TensorClient
(
tensorType
).
asInstanceOf
[
this.
type
]
}
/** TensorMasterData.
*
* This interface is only used for datapath only purposes and the direction convention
* is based on the TensorMaster interface, which means this is an input. This interface
* is used on datapath only module such MatrixVectorCore or AluVector.
*/
class
TensorMasterData
(
tensorType
:
String
=
"none"
)
(
implicit
p
:
Parameters
)
extends
TensorParams
(
tensorType
)
{
val
data
=
Flipped
(
ValidIO
(
Vec
(
tensorLength
,
Vec
(
tensorWidth
,
UInt
(
tensorElemBits
.
W
)))))
override
def
cloneType
=
new
TensorMasterData
(
tensorType
).
asInstanceOf
[
this.
type
]
}
/** TensorClientData.
*
* This interface is only used for datapath only purposes and the direction convention
* is based on the TensorClient interface, which means this is an output. This interface
* is used on datapath only module such MatrixVectorCore or AluVector.
*/
class
TensorClientData
(
tensorType
:
String
=
"none"
)
(
implicit
p
:
Parameters
)
extends
TensorParams
(
tensorType
)
{
val
data
=
ValidIO
(
Vec
(
tensorLength
,
Vec
(
tensorWidth
,
UInt
(
tensorElemBits
.
W
))))
override
def
cloneType
=
new
TensorClientData
(
tensorType
).
asInstanceOf
[
this.
type
]
}
/** TensorPadCtrl. Zero-padding controller for TensorLoad. */
class
TensorPadCtrl
(
padType
:
String
=
"none"
,
sizeFactor
:
Int
=
1
)
extends
Module
{
val
errorMsg
=
s
"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
require
(
padType
==
"YPad0"
||
padType
==
"YPad1"
||
padType
==
"XPad0"
||
padType
==
"XPad1"
,
errorMsg
)
val
io
=
IO
(
new
Bundle
{
val
start
=
Input
(
Bool
())
val
done
=
Output
(
Bool
())
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
})
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
val
xmax
=
Reg
(
chiselTypeOf
(
dec
.
xsize
))
val
ymax
=
Reg
(
chiselTypeOf
(
dec
.
ypad_0
))
val
xcnt
=
Reg
(
chiselTypeOf
(
dec
.
xsize
))
val
ycnt
=
Reg
(
chiselTypeOf
(
dec
.
ypad_0
))
val
xval
=
if
(
padType
==
"YPad0"
||
padType
==
"YPad1"
)
((
dec
.
xpad_0
+
dec
.
xsize
+
dec
.
xpad_1
)
<<
log2Ceil
(
sizeFactor
))
-
1.
U
else
if
(
padType
==
"XPad0"
)
(
dec
.
xpad_0
<<
log2Ceil
(
sizeFactor
))
-
1.
U
else
(
dec
.
xpad_1
<<
log2Ceil
(
sizeFactor
))
-
1.
U
val
yval
=
if
(
padType
==
"YPad0"
)
Mux
(
dec
.
ypad_0
=/=
0.
U
,
dec
.
ypad_0
-
1.
U
,
0.
U
)
else
if
(
padType
==
"YPad1"
)
Mux
(
dec
.
ypad_1
=/=
0.
U
,
dec
.
ypad_1
-
1.
U
,
0.
U
)
else
0.
U
val
sIdle
::
sActive
::
Nil
=
Enum
(
2
)
val
state
=
RegInit
(
sIdle
)
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
start
)
{
state
:=
sActive
}
}
is
(
sActive
)
{
when
(
ycnt
===
ymax
&&
xcnt
===
xmax
)
{
state
:=
sIdle
}
}
}
when
(
state
===
sIdle
)
{
xmax
:=
xval
ymax
:=
yval
}
when
(
state
===
sIdle
||
xcnt
===
xmax
)
{
xcnt
:=
0.
U
}
.
elsewhen
(
state
===
sActive
)
{
xcnt
:=
xcnt
+
1.
U
}
when
(
state
===
sIdle
||
ymax
===
0.
U
)
{
ycnt
:=
0.
U
}
.
elsewhen
(
state
===
sActive
&&
xcnt
===
xmax
)
{
ycnt
:=
ycnt
+
1.
U
}
io
.
done
:=
state
===
sActive
&
ycnt
===
ymax
&
xcnt
===
xmax
}
/** TensorDataCtrl. Data controller for TensorLoad. */
class
TensorDataCtrl
(
sizeFactor
:
Int
=
1
,
strideFactor
:
Int
=
1
)(
implicit
p
:
Parameters
)
extends
Module
{
val
mp
=
p
(
ShellKey
).
memParams
val
io
=
IO
(
new
Bundle
{
val
start
=
Input
(
Bool
())
val
done
=
Output
(
Bool
())
val
inst
=
Input
(
UInt
(
INST_BITS
.
W
))
val
baddr
=
Input
(
UInt
(
mp
.
addrBits
.
W
))
val
xinit
=
Input
(
Bool
())
val
xupdate
=
Input
(
Bool
())
val
yupdate
=
Input
(
Bool
())
val
stride
=
Output
(
Bool
())
val
split
=
Output
(
Bool
())
val
commit
=
Output
(
Bool
())
val
addr
=
Output
(
UInt
(
mp
.
addrBits
.
W
))
val
len
=
Output
(
UInt
(
mp
.
lenBits
.
W
))
})
val
dec
=
io
.
inst
.
asTypeOf
(
new
MemDecode
)
val
caddr
=
Reg
(
UInt
(
mp
.
addrBits
.
W
))
val
baddr
=
Reg
(
UInt
(
mp
.
addrBits
.
W
))
val
len
=
Reg
(
UInt
(
mp
.
lenBits
.
W
))
val
xmax_bytes
=
((
1
<<
mp
.
lenBits
)*
mp
.
dataBits
/
8
).
U
val
xcnt
=
Reg
(
UInt
(
mp
.
lenBits
.
W
))
val
xrem
=
Reg
(
chiselTypeOf
(
dec
.
xsize
))
val
xsize
=
(
dec
.
xsize
<<
log2Ceil
(
sizeFactor
))
-
1.
U
val
xmax
=
(
1
<<
mp
.
lenBits
).
U
val
ycnt
=
Reg
(
chiselTypeOf
(
dec
.
ysize
))
val
stride
=
xcnt
===
len
&
xrem
===
0.
U
&
ycnt
=/=
dec
.
ysize
-
1.
U
val
split
=
xcnt
===
len
&
xrem
=/=
0.
U
when
(
io
.
start
||
(
io
.
xupdate
&&
stride
))
{
when
(
xsize
<
xmax
)
{
len
:=
xsize
xrem
:=
0.
U
}
.
otherwise
{
len
:=
xmax
-
1.
U
xrem
:=
xsize
-
xmax
}
}
.
elsewhen
(
io
.
xupdate
&&
split
)
{
when
(
xrem
<
xmax
)
{
len
:=
xrem
xrem
:=
0.
U
}
.
otherwise
{
len
:=
xmax
-
1.
U
xrem
:=
xrem
-
xmax
}
}
when
(
io
.
xinit
)
{
xcnt
:=
0.
U
}
.
elsewhen
(
io
.
xupdate
)
{
xcnt
:=
xcnt
+
1.
U
}
when
(
io
.
start
)
{
ycnt
:=
0.
U
}
.
elsewhen
(
io
.
yupdate
&&
stride
)
{
ycnt
:=
ycnt
+
1.
U
}
when
(
io
.
start
)
{
caddr
:=
io
.
baddr
+
dec
.
dram_offset
baddr
:=
io
.
baddr
+
dec
.
dram_offset
}
.
elsewhen
(
io
.
yupdate
)
{
when
(
split
)
{
caddr
:=
caddr
+
xmax_bytes
}
.
elsewhen
(
stride
)
{
caddr
:=
baddr
+
(
dec
.
xstride
<<
log2Ceil
(
strideFactor
))
baddr
:=
baddr
+
(
dec
.
xstride
<<
log2Ceil
(
strideFactor
))
}
}
io
.
stride
:=
stride
io
.
split
:=
split
io
.
commit
:=
xcnt
===
len
io
.
addr
:=
caddr
io
.
len
:=
len
io
.
done
:=
xcnt
===
len
&
xrem
===
0.
U
&
ycnt
===
dec
.
ysize
-
1.
U
}
vta/hardware/chisel/src/main/scala/core/package.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta
/** This trick makes ISAConstants globally available */
package
object
core
extends
vta
.
core
.
ISAConstants
vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala
View file @
32f74f31
...
...
@@ -21,6 +21,9 @@ package vta.dpi
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.interface.axi._
import
vta.shell._
/** Host DPI parameters */
trait
VTAHostDPIParams
{
...
...
@@ -70,3 +73,83 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
})
setResource
(
"/verilog/VTAHostDPI.v"
)
}
/** Host DPI to AXI Converter.
*
* Convert Host DPI to AXI for VTAShell
*/
class
VTAHostDPIToAXI
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
dpi
=
new
VTAHostDPIClient
val
axi
=
new
AXILiteMaster
(
p
(
ShellKey
).
hostParams
)
})
val
addr
=
RegInit
(
0.
U
.
asTypeOf
(
chiselTypeOf
(
io
.
dpi
.
req
.
addr
)))
val
data
=
RegInit
(
0.
U
.
asTypeOf
(
chiselTypeOf
(
io
.
dpi
.
req
.
value
)))
val
sIdle
::
sReadAddress
::
sReadData
::
sWriteAddress
::
sWriteData
::
sWriteResponse
::
Nil
=
Enum
(
6
)
val
state
=
RegInit
(
sIdle
)
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
dpi
.
req
.
valid
)
{
when
(
io
.
dpi
.
req
.
opcode
)
{
state
:=
sWriteAddress
}
.
otherwise
{
state
:=
sReadAddress
}
}
}
is
(
sReadAddress
)
{
when
(
io
.
axi
.
ar
.
ready
)
{
state
:=
sReadData
}
}
is
(
sReadData
)
{
when
(
io
.
axi
.
r
.
valid
)
{
state
:=
sIdle
}
}
is
(
sWriteAddress
)
{
when
(
io
.
axi
.
aw
.
ready
)
{
state
:=
sWriteData
}
}
is
(
sWriteData
)
{
when
(
io
.
axi
.
w
.
ready
)
{
state
:=
sWriteResponse
}
}
is
(
sWriteResponse
)
{
when
(
io
.
axi
.
b
.
valid
)
{
state
:=
sIdle
}
}
}
when
(
state
===
sIdle
&&
io
.
dpi
.
req
.
valid
)
{
addr
:=
io
.
dpi
.
req
.
addr
data
:=
io
.
dpi
.
req
.
value
}
io
.
axi
.
aw
.
valid
:=
state
===
sWriteAddress
io
.
axi
.
aw
.
bits
.
addr
:=
addr
io
.
axi
.
w
.
valid
:=
state
===
sWriteData
io
.
axi
.
w
.
bits
.
data
:=
data
io
.
axi
.
w
.
bits
.
strb
:=
"h_f"
.
U
io
.
axi
.
b
.
ready
:=
state
===
sWriteResponse
io
.
axi
.
ar
.
valid
:=
state
===
sReadAddress
io
.
axi
.
ar
.
bits
.
addr
:=
addr
io
.
axi
.
r
.
ready
:=
state
===
sReadData
io
.
dpi
.
req
.
deq
:=
(
state
===
sReadAddress
&
io
.
axi
.
ar
.
ready
)
|
(
state
===
sWriteAddress
&
io
.
axi
.
aw
.
ready
)
io
.
dpi
.
resp
.
valid
:=
io
.
axi
.
r
.
valid
io
.
dpi
.
resp
.
bits
:=
io
.
axi
.
r
.
bits
.
data
if
(
debug
)
{
when
(
state
===
sWriteAddress
&&
io
.
axi
.
aw
.
ready
)
{
printf
(
"[VTAHostDPIToAXI] [AW] addr:%x\n"
,
addr
)
}
when
(
state
===
sReadAddress
&&
io
.
axi
.
ar
.
ready
)
{
printf
(
"[VTAHostDPIToAXI] [AR] addr:%x\n"
,
addr
)
}
when
(
io
.
axi
.
r
.
fire
())
{
printf
(
"[VTAHostDPIToAXI] [R] value:%x\n"
,
io
.
axi
.
r
.
bits
.
data
)
}
when
(
io
.
axi
.
w
.
fire
())
{
printf
(
"[VTAHostDPIToAXI] [W] value:%x\n"
,
io
.
axi
.
w
.
bits
.
data
)
}
}
}
vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala
View file @
32f74f31
...
...
@@ -21,6 +21,9 @@ package vta.dpi
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.interface.axi._
import
vta.shell._
/** Memory DPI parameters */
trait
VTAMemDPIParams
{
...
...
@@ -71,3 +74,98 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
})
setResource
(
"/verilog/VTAMemDPI.v"
)
}
class
VTAMemDPIToAXI
(
debug
:
Boolean
=
false
)(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
dpi
=
new
VTAMemDPIMaster
val
axi
=
new
AXIClient
(
p
(
ShellKey
).
memParams
)
})
val
opcode
=
RegInit
(
false
.
B
)
val
len
=
RegInit
(
0.
U
.
asTypeOf
(
chiselTypeOf
(
io
.
dpi
.
req
.
len
)))
val
addr
=
RegInit
(
0.
U
.
asTypeOf
(
chiselTypeOf
(
io
.
dpi
.
req
.
addr
)))
val
sIdle
::
sReadAddress
::
sReadData
::
sWriteAddress
::
sWriteData
::
sWriteResponse
::
Nil
=
Enum
(
6
)
val
state
=
RegInit
(
sIdle
)
switch
(
state
)
{
is
(
sIdle
)
{
when
(
io
.
axi
.
ar
.
valid
)
{
state
:=
sReadAddress
}
.
elsewhen
(
io
.
axi
.
aw
.
valid
)
{
state
:=
sWriteAddress
}
}
is
(
sReadAddress
)
{
when
(
io
.
axi
.
ar
.
valid
)
{
state
:=
sReadData
}
}
is
(
sReadData
)
{
when
(
io
.
axi
.
r
.
ready
&&
io
.
dpi
.
rd
.
valid
&&
len
===
0.
U
)
{
state
:=
sIdle
}
}
is
(
sWriteAddress
)
{
when
(
io
.
axi
.
aw
.
valid
)
{
state
:=
sWriteData
}
}
is
(
sWriteData
)
{
when
(
io
.
axi
.
w
.
valid
&&
io
.
axi
.
w
.
bits
.
last
)
{
state
:=
sWriteResponse
}
}
is
(
sWriteResponse
)
{
when
(
io
.
axi
.
b
.
ready
)
{
state
:=
sIdle
}
}
}
when
(
state
===
sIdle
)
{
when
(
io
.
axi
.
ar
.
valid
)
{
opcode
:=
false
.
B
len
:=
io
.
axi
.
ar
.
bits
.
len
addr
:=
io
.
axi
.
ar
.
bits
.
addr
}
.
elsewhen
(
io
.
axi
.
aw
.
valid
)
{
opcode
:=
true
.
B
len
:=
io
.
axi
.
aw
.
bits
.
len
addr
:=
io
.
axi
.
aw
.
bits
.
addr
}
}
.
elsewhen
(
state
===
sReadData
)
{
when
(
io
.
axi
.
r
.
ready
&&
io
.
dpi
.
rd
.
valid
&&
len
=/=
0.
U
)
{
len
:=
len
-
1.
U
}
}
io
.
dpi
.
req
.
valid
:=
(
state
===
sReadAddress
&
io
.
axi
.
ar
.
valid
)
|
(
state
===
sWriteAddress
&
io
.
axi
.
aw
.
valid
)
io
.
dpi
.
req
.
opcode
:=
opcode
io
.
dpi
.
req
.
len
:=
len
io
.
dpi
.
req
.
addr
:=
addr
io
.
axi
.
ar
.
ready
:=
state
===
sReadAddress
io
.
axi
.
aw
.
ready
:=
state
===
sWriteAddress
io
.
axi
.
r
.
valid
:=
state
===
sReadData
&
io
.
dpi
.
rd
.
valid
io
.
axi
.
r
.
bits
.
data
:=
io
.
dpi
.
rd
.
bits
io
.
axi
.
r
.
bits
.
last
:=
len
===
0.
U
io
.
axi
.
r
.
bits
.
resp
:=
0.
U
io
.
axi
.
r
.
bits
.
user
:=
0.
U
io
.
axi
.
r
.
bits
.
id
:=
0.
U
io
.
dpi
.
rd
.
ready
:=
state
===
sReadData
&
io
.
axi
.
r
.
ready
io
.
dpi
.
wr
.
valid
:=
state
===
sWriteData
&
io
.
axi
.
w
.
valid
io
.
dpi
.
wr
.
bits
:=
io
.
axi
.
w
.
bits
.
data
io
.
axi
.
w
.
ready
:=
state
===
sWriteData
io
.
axi
.
b
.
valid
:=
state
===
sWriteResponse
io
.
axi
.
b
.
bits
.
resp
:=
0.
U
io
.
axi
.
b
.
bits
.
user
:=
0.
U
io
.
axi
.
b
.
bits
.
id
:=
0.
U
if
(
debug
)
{
when
(
state
===
sReadAddress
&&
io
.
axi
.
ar
.
valid
)
{
printf
(
"[VTAMemDPIToAXI] [AR] addr:%x len:%x\n"
,
addr
,
len
)
}
when
(
state
===
sWriteAddress
&&
io
.
axi
.
aw
.
valid
)
{
printf
(
"[VTAMemDPIToAXI] [AW] addr:%x len:%x\n"
,
addr
,
len
)
}
when
(
io
.
axi
.
r
.
fire
())
{
printf
(
"[VTAMemDPIToAXI] [R] last:%x data:%x\n"
,
io
.
axi
.
r
.
bits
.
last
,
io
.
axi
.
r
.
bits
.
data
)
}
when
(
io
.
axi
.
w
.
fire
())
{
printf
(
"[VTAMemDPIToAXI] [W] last:%x data:%x\n"
,
io
.
axi
.
w
.
bits
.
last
,
io
.
axi
.
w
.
bits
.
data
)
}
}
}
vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.interface.axi
import
chisel3._
import
chisel3.util._
import
vta.util.genericbundle._
case
class
AXIParams
(
addrBits
:
Int
=
32
,
dataBits
:
Int
=
64
)
{
require
(
addrBits
>
0
)
require
(
dataBits
>=
8
&&
dataBits
%
2
==
0
)
val
idBits
=
1
val
userBits
=
1
val
strbBits
=
dataBits
/
8
val
lenBits
=
8
val
sizeBits
=
3
val
burstBits
=
2
val
lockBits
=
2
val
cacheBits
=
4
val
protBits
=
3
val
qosBits
=
4
val
regionBits
=
4
val
respBits
=
2
val
sizeConst
=
log2Ceil
(
dataBits
/
8
)
val
idConst
=
0
val
userConst
=
0
val
burstConst
=
1
val
lockConst
=
0
val
cacheConst
=
3
val
protConst
=
0
val
qosConst
=
0
val
regionConst
=
0
}
abstract
class
AXIBase
(
params
:
AXIParams
)
extends
GenericParameterizedBundle
(
params
)
// AXILite
class
AXILiteAddress
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
addr
=
UInt
(
params
.
addrBits
.
W
)
}
class
AXILiteWriteData
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
data
=
UInt
(
params
.
dataBits
.
W
)
val
strb
=
UInt
(
params
.
strbBits
.
W
)
}
class
AXILiteWriteResponse
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
resp
=
UInt
(
params
.
respBits
.
W
)
}
class
AXILiteReadData
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
data
=
UInt
(
params
.
dataBits
.
W
)
val
resp
=
UInt
(
params
.
respBits
.
W
)
}
class
AXILiteMaster
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
aw
=
Decoupled
(
new
AXILiteAddress
(
params
))
val
w
=
Decoupled
(
new
AXILiteWriteData
(
params
))
val
b
=
Flipped
(
Decoupled
(
new
AXILiteWriteResponse
(
params
)))
val
ar
=
Decoupled
(
new
AXILiteAddress
(
params
))
val
r
=
Flipped
(
Decoupled
(
new
AXILiteReadData
(
params
)))
def
tieoff
()
{
aw
.
valid
:=
false
.
B
aw
.
bits
.
addr
:=
0.
U
w
.
valid
:=
false
.
B
w
.
bits
.
data
:=
0.
U
w
.
bits
.
strb
:=
0.
U
b
.
ready
:=
false
.
B
ar
.
valid
:=
false
.
B
ar
.
bits
.
addr
:=
0.
U
r
.
ready
:=
false
.
B
}
}
class
AXILiteClient
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
aw
=
Flipped
(
Decoupled
(
new
AXILiteAddress
(
params
)))
val
w
=
Flipped
(
Decoupled
(
new
AXILiteWriteData
(
params
)))
val
b
=
Decoupled
(
new
AXILiteWriteResponse
(
params
))
val
ar
=
Flipped
(
Decoupled
(
new
AXILiteAddress
(
params
)))
val
r
=
Decoupled
(
new
AXILiteReadData
(
params
))
def
tieoff
()
{
aw
.
ready
:=
false
.
B
w
.
ready
:=
false
.
B
b
.
valid
:=
false
.
B
b
.
bits
.
resp
:=
0.
U
ar
.
ready
:=
false
.
B
r
.
valid
:=
false
.
B
r
.
bits
.
resp
:=
0.
U
r
.
bits
.
data
:=
0.
U
}
}
// AXI extends AXILite
class
AXIAddress
(
params
:
AXIParams
)
extends
AXILiteAddress
(
params
)
{
val
id
=
UInt
(
params
.
idBits
.
W
)
val
user
=
UInt
(
params
.
userBits
.
W
)
val
len
=
UInt
(
params
.
lenBits
.
W
)
val
size
=
UInt
(
params
.
sizeBits
.
W
)
val
burst
=
UInt
(
params
.
burstBits
.
W
)
val
lock
=
UInt
(
params
.
lockBits
.
W
)
val
cache
=
UInt
(
params
.
cacheBits
.
W
)
val
prot
=
UInt
(
params
.
protBits
.
W
)
val
qos
=
UInt
(
params
.
qosBits
.
W
)
val
region
=
UInt
(
params
.
regionBits
.
W
)
}
class
AXIWriteData
(
params
:
AXIParams
)
extends
AXILiteWriteData
(
params
)
{
val
last
=
Bool
()
val
id
=
UInt
(
params
.
idBits
.
W
)
val
user
=
UInt
(
params
.
userBits
.
W
)
}
class
AXIWriteResponse
(
params
:
AXIParams
)
extends
AXILiteWriteResponse
(
params
)
{
val
id
=
UInt
(
params
.
idBits
.
W
)
val
user
=
UInt
(
params
.
userBits
.
W
)
}
class
AXIReadData
(
params
:
AXIParams
)
extends
AXILiteReadData
(
params
)
{
val
last
=
Bool
()
val
id
=
UInt
(
params
.
idBits
.
W
)
val
user
=
UInt
(
params
.
userBits
.
W
)
}
class
AXIMaster
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
aw
=
Decoupled
(
new
AXIAddress
(
params
))
val
w
=
Decoupled
(
new
AXIWriteData
(
params
))
val
b
=
Flipped
(
Decoupled
(
new
AXIWriteResponse
(
params
)))
val
ar
=
Decoupled
(
new
AXIAddress
(
params
))
val
r
=
Flipped
(
Decoupled
(
new
AXIReadData
(
params
)))
def
tieoff
()
{
aw
.
valid
:=
false
.
B
aw
.
bits
.
addr
:=
0.
U
aw
.
bits
.
id
:=
0.
U
aw
.
bits
.
user
:=
0.
U
aw
.
bits
.
len
:=
0.
U
aw
.
bits
.
size
:=
0.
U
aw
.
bits
.
burst
:=
0.
U
aw
.
bits
.
lock
:=
0.
U
aw
.
bits
.
cache
:=
0.
U
aw
.
bits
.
prot
:=
0.
U
aw
.
bits
.
qos
:=
0.
U
aw
.
bits
.
region
:=
0.
U
w
.
valid
:=
false
.
B
w
.
bits
.
data
:=
0.
U
w
.
bits
.
strb
:=
0.
U
w
.
bits
.
last
:=
false
.
B
w
.
bits
.
id
:=
0.
U
w
.
bits
.
user
:=
0.
U
b
.
ready
:=
false
.
B
ar
.
valid
:=
false
.
B
ar
.
bits
.
addr
:=
0.
U
ar
.
bits
.
id
:=
0.
U
ar
.
bits
.
user
:=
0.
U
ar
.
bits
.
len
:=
0.
U
ar
.
bits
.
size
:=
0.
U
ar
.
bits
.
burst
:=
0.
U
ar
.
bits
.
lock
:=
0.
U
ar
.
bits
.
cache
:=
0.
U
ar
.
bits
.
prot
:=
0.
U
ar
.
bits
.
qos
:=
0.
U
ar
.
bits
.
region
:=
0.
U
r
.
ready
:=
false
.
B
}
def
setConst
()
{
aw
.
bits
.
user
:=
params
.
userConst
.
U
aw
.
bits
.
burst
:=
params
.
burstConst
.
U
aw
.
bits
.
lock
:=
params
.
lockConst
.
U
aw
.
bits
.
cache
:=
params
.
cacheConst
.
U
aw
.
bits
.
prot
:=
params
.
protConst
.
U
aw
.
bits
.
qos
:=
params
.
qosConst
.
U
aw
.
bits
.
region
:=
params
.
regionConst
.
U
aw
.
bits
.
size
:=
params
.
sizeConst
.
U
aw
.
bits
.
id
:=
params
.
idConst
.
U
w
.
bits
.
id
:=
params
.
idConst
.
U
w
.
bits
.
user
:=
params
.
userConst
.
U
w
.
bits
.
strb
:=
Fill
(
params
.
strbBits
,
true
.
B
)
ar
.
bits
.
user
:=
params
.
userConst
.
U
ar
.
bits
.
burst
:=
params
.
burstConst
.
U
ar
.
bits
.
lock
:=
params
.
lockConst
.
U
ar
.
bits
.
cache
:=
params
.
cacheConst
.
U
ar
.
bits
.
prot
:=
params
.
protConst
.
U
ar
.
bits
.
qos
:=
params
.
qosConst
.
U
ar
.
bits
.
region
:=
params
.
regionConst
.
U
ar
.
bits
.
size
:=
params
.
sizeConst
.
U
ar
.
bits
.
id
:=
params
.
idConst
.
U
}
}
class
AXIClient
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
aw
=
Flipped
(
Decoupled
(
new
AXIAddress
(
params
)))
val
w
=
Flipped
(
Decoupled
(
new
AXIWriteData
(
params
)))
val
b
=
Decoupled
(
new
AXIWriteResponse
(
params
))
val
ar
=
Flipped
(
Decoupled
(
new
AXIAddress
(
params
)))
val
r
=
Decoupled
(
new
AXIReadData
(
params
))
def
tieoff
()
{
aw
.
ready
:=
false
.
B
w
.
ready
:=
false
.
B
b
.
valid
:=
false
.
B
b
.
bits
.
resp
:=
0.
U
b
.
bits
.
user
:=
0.
U
b
.
bits
.
id
:=
0.
U
ar
.
ready
:=
false
.
B
r
.
valid
:=
false
.
B
r
.
bits
.
resp
:=
0.
U
r
.
bits
.
data
:=
0.
U
r
.
bits
.
user
:=
0.
U
r
.
bits
.
last
:=
false
.
B
r
.
bits
.
id
:=
0.
U
}
}
// XilinxAXILiteClient and XilinxAXIMaster bundles are needed
// for wrapper purposes, because the package RTL tool in Xilinx Vivado
// only allows certain name formats
class
XilinxAXILiteClient
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
AWVALID
=
Input
(
Bool
())
val
AWREADY
=
Output
(
Bool
())
val
AWADDR
=
Input
(
UInt
(
params
.
addrBits
.
W
))
val
WVALID
=
Input
(
Bool
())
val
WREADY
=
Output
(
Bool
())
val
WDATA
=
Input
(
UInt
(
params
.
dataBits
.
W
))
val
WSTRB
=
Input
(
UInt
(
params
.
strbBits
.
W
))
val
BVALID
=
Output
(
Bool
())
val
BREADY
=
Input
(
Bool
())
val
BRESP
=
Output
(
UInt
(
params
.
respBits
.
W
))
val
ARVALID
=
Input
(
Bool
())
val
ARREADY
=
Output
(
Bool
())
val
ARADDR
=
Input
(
UInt
(
params
.
addrBits
.
W
))
val
RVALID
=
Output
(
Bool
())
val
RREADY
=
Input
(
Bool
())
val
RDATA
=
Output
(
UInt
(
params
.
dataBits
.
W
))
val
RRESP
=
Output
(
UInt
(
params
.
respBits
.
W
))
}
class
XilinxAXIMaster
(
params
:
AXIParams
)
extends
AXIBase
(
params
)
{
val
AWVALID
=
Output
(
Bool
())
val
AWREADY
=
Input
(
Bool
())
val
AWADDR
=
Output
(
UInt
(
params
.
addrBits
.
W
))
val
AWID
=
Output
(
UInt
(
params
.
idBits
.
W
))
val
AWUSER
=
Output
(
UInt
(
params
.
userBits
.
W
))
val
AWLEN
=
Output
(
UInt
(
params
.
lenBits
.
W
))
val
AWSIZE
=
Output
(
UInt
(
params
.
sizeBits
.
W
))
val
AWBURST
=
Output
(
UInt
(
params
.
burstBits
.
W
))
val
AWLOCK
=
Output
(
UInt
(
params
.
lockBits
.
W
))
val
AWCACHE
=
Output
(
UInt
(
params
.
cacheBits
.
W
))
val
AWPROT
=
Output
(
UInt
(
params
.
protBits
.
W
))
val
AWQOS
=
Output
(
UInt
(
params
.
qosBits
.
W
))
val
AWREGION
=
Output
(
UInt
(
params
.
regionBits
.
W
))
val
WVALID
=
Output
(
Bool
())
val
WREADY
=
Input
(
Bool
())
val
WDATA
=
Output
(
UInt
(
params
.
dataBits
.
W
))
val
WSTRB
=
Output
(
UInt
(
params
.
strbBits
.
W
))
val
WLAST
=
Output
(
Bool
())
val
WID
=
Output
(
UInt
(
params
.
idBits
.
W
))
val
WUSER
=
Output
(
UInt
(
params
.
userBits
.
W
))
val
BVALID
=
Input
(
Bool
())
val
BREADY
=
Output
(
Bool
())
val
BRESP
=
Input
(
UInt
(
params
.
respBits
.
W
))
val
BID
=
Input
(
UInt
(
params
.
idBits
.
W
))
val
BUSER
=
Input
(
UInt
(
params
.
userBits
.
W
))
val
ARVALID
=
Output
(
Bool
())
val
ARREADY
=
Input
(
Bool
())
val
ARADDR
=
Output
(
UInt
(
params
.
addrBits
.
W
))
val
ARID
=
Output
(
UInt
(
params
.
idBits
.
W
))
val
ARUSER
=
Output
(
UInt
(
params
.
userBits
.
W
))
val
ARLEN
=
Output
(
UInt
(
params
.
lenBits
.
W
))
val
ARSIZE
=
Output
(
UInt
(
params
.
sizeBits
.
W
))
val
ARBURST
=
Output
(
UInt
(
params
.
burstBits
.
W
))
val
ARLOCK
=
Output
(
UInt
(
params
.
lockBits
.
W
))
val
ARCACHE
=
Output
(
UInt
(
params
.
cacheBits
.
W
))
val
ARPROT
=
Output
(
UInt
(
params
.
protBits
.
W
))
val
ARQOS
=
Output
(
UInt
(
params
.
qosBits
.
W
))
val
ARREGION
=
Output
(
UInt
(
params
.
regionBits
.
W
))
val
RVALID
=
Input
(
Bool
())
val
RREADY
=
Output
(
Bool
())
val
RDATA
=
Input
(
UInt
(
params
.
dataBits
.
W
))
val
RRESP
=
Input
(
UInt
(
params
.
respBits
.
W
))
val
RLAST
=
Input
(
Bool
())
val
RID
=
Input
(
UInt
(
params
.
idBits
.
W
))
val
RUSER
=
Input
(
UInt
(
params
.
userBits
.
W
))
}
vta/hardware/chisel/src/main/scala/shell/Configs.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.shell
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.interface.axi._
/** PynqConfig. Shell configuration for Pynq */
class
PynqConfig
extends
Config
((
site
,
here
,
up
)
=>
{
case
ShellKey
=>
ShellParams
(
hostParams
=
AXIParams
(
addrBits
=
16
,
dataBits
=
32
),
memParams
=
AXIParams
(
addrBits
=
32
,
dataBits
=
64
),
vcrParams
=
VCRParams
(),
vmeParams
=
VMEParams
())
})
/** F1Config. Shell configuration for F1 */
class
F1Config
extends
Config
((
site
,
here
,
up
)
=>
{
case
ShellKey
=>
ShellParams
(
hostParams
=
AXIParams
(
addrBits
=
16
,
dataBits
=
32
),
memParams
=
AXIParams
(
addrBits
=
64
,
dataBits
=
64
),
vcrParams
=
VCRParams
(),
vmeParams
=
VMEParams
())
})
vta/hardware/chisel/src/main/scala/shell/SimShell.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.shell
import
chisel3._
import
vta.util.config._
import
vta.interface.axi._
import
vta.shell._
import
vta.dpi._
/** VTAHost.
*
* This module translate the DPI protocol into AXI. This is a simulation only
* module and used to test host-to-VTA communication. This module should be updated
* for testing hosts using a different bus protocol, other than AXI.
*/
class
VTAHost
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
axi
=
new
AXILiteMaster
(
p
(
ShellKey
).
hostParams
)
})
val
host_dpi
=
Module
(
new
VTAHostDPI
)
val
host_axi
=
Module
(
new
VTAHostDPIToAXI
)
host_dpi
.
io
.
reset
:=
reset
host_dpi
.
io
.
clock
:=
clock
host_axi
.
io
.
dpi
<>
host_dpi
.
io
.
dpi
io
.
axi
<>
host_axi
.
io
.
axi
}
/** VTAMem.
*
* This module translate the DPI protocol into AXI. This is a simulation only
* module and used to test VTA-to-memory communication. This module should be updated
* for testing memories using a different bus protocol, other than AXI.
*/
class
VTAMem
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
axi
=
new
AXIClient
(
p
(
ShellKey
).
memParams
)
})
val
mem_dpi
=
Module
(
new
VTAMemDPI
)
val
mem_axi
=
Module
(
new
VTAMemDPIToAXI
)
mem_dpi
.
io
.
reset
:=
reset
mem_dpi
.
io
.
clock
:=
clock
mem_dpi
.
io
.
dpi
<>
mem_axi
.
io
.
dpi
mem_axi
.
io
.
axi
<>
io
.
axi
}
/** SimShell.
*
* The simulation shell instantiate a host and memory simulation modules and it is
* intended to be connected to the VTAShell.
*/
class
SimShell
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
mem
=
new
AXIClient
(
p
(
ShellKey
).
memParams
)
val
host
=
new
AXILiteMaster
(
p
(
ShellKey
).
hostParams
)
})
val
host
=
Module
(
new
VTAHost
)
val
mem
=
Module
(
new
VTAMem
)
io
.
mem
<>
mem
.
io
.
axi
io
.
host
<>
host
.
io
.
axi
}
vta/hardware/chisel/src/main/scala/shell/VCR.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.shell
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.util.genericbundle._
import
scala.collection.mutable.ListBuffer
import
scala.collection.mutable.LinkedHashMap
import
vta.interface.axi._
/** VCR parameters.
*
* These parameters are used on VCR interfaces and modules.
*/
case
class
VCRParams
()
{
val
nValsReg
:
Int
=
1
val
nPtrsReg
:
Int
=
6
val
regBits
:
Int
=
32
val
nCtrlReg
:
Int
=
4
val
ctrlBaseAddr
:
Int
=
0
require
(
nValsReg
>
0
)
require
(
nPtrsReg
>
0
)
}
/** VCRBase. Parametrize base class. */
abstract
class
VCRBase
(
implicit
p
:
Parameters
)
extends
GenericParameterizedBundle
(
p
)
/** VCRMaster.
*
* This is the master interface used by VCR in the VTAShell to control
* the Core unit.
*/
class
VCRMaster
(
implicit
p
:
Parameters
)
extends
VCRBase
{
val
vp
=
p
(
ShellKey
).
vcrParams
val
mp
=
p
(
ShellKey
).
memParams
val
launch
=
Output
(
Bool
())
val
finish
=
Input
(
Bool
())
val
irq
=
Output
(
Bool
())
val
ptrs
=
Output
(
Vec
(
vp
.
nPtrsReg
,
UInt
(
mp
.
addrBits
.
W
)))
val
vals
=
Output
(
Vec
(
vp
.
nValsReg
,
UInt
(
vp
.
regBits
.
W
)))
}
/** VCRClient.
*
* This is the client interface used by the Core module to communicate
* to the VCR in the VTAShell.
*/
class
VCRClient
(
implicit
p
:
Parameters
)
extends
VCRBase
{
val
vp
=
p
(
ShellKey
).
vcrParams
val
mp
=
p
(
ShellKey
).
memParams
val
launch
=
Input
(
Bool
())
val
finish
=
Output
(
Bool
())
val
irq
=
Input
(
Bool
())
val
ptrs
=
Input
(
Vec
(
vp
.
nPtrsReg
,
UInt
(
mp
.
addrBits
.
W
)))
val
vals
=
Input
(
Vec
(
vp
.
nValsReg
,
UInt
(
vp
.
regBits
.
W
)))
}
/** VTA Control Registers (VCR).
*
* This unit provides control registers (32 and 64 bits) to be used by a control'
* unit, typically a host processor. These registers are read-only by the core
* at the moment but this will likely change once we add support to general purpose
* registers that could be used as event counters by the Core unit.
*/
class
VCR
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
host
=
new
AXILiteClient
(
p
(
ShellKey
).
hostParams
)
val
vcr
=
new
VCRMaster
})
val
vp
=
p
(
ShellKey
).
vcrParams
val
mp
=
p
(
ShellKey
).
memParams
val
hp
=
p
(
ShellKey
).
hostParams
// Write control (AW, W, B)
val
waddr
=
RegInit
(
"h_ffff"
.
U
(
hp
.
addrBits
.
W
))
// init with invalid address
val
wdata
=
io
.
host
.
w
.
bits
.
data
val
wstrb
=
io
.
host
.
w
.
bits
.
strb
val
wmask
=
Cat
(
Fill
(
8
,
wstrb
(
3
)),
Fill
(
8
,
wstrb
(
2
)),
Fill
(
8
,
wstrb
(
1
)),
Fill
(
8
,
wstrb
(
0
)))
val
sWriteAddress
::
sWriteData
::
sWriteResponse
::
Nil
=
Enum
(
3
)
val
wstate
=
RegInit
(
sWriteAddress
)
switch
(
wstate
)
{
is
(
sWriteAddress
)
{
when
(
io
.
host
.
aw
.
valid
)
{
wstate
:=
sWriteData
}
}
is
(
sWriteData
)
{
when
(
io
.
host
.
w
.
valid
)
{
wstate
:=
sWriteResponse
}
}
is
(
sWriteResponse
)
{
when
(
io
.
host
.
b
.
ready
)
{
wstate
:=
sWriteAddress
}
}
}
when
(
io
.
host
.
aw
.
fire
())
{
waddr
:=
io
.
host
.
aw
.
bits
.
addr
}
io
.
host
.
aw
.
ready
:=
wstate
===
sWriteAddress
io
.
host
.
w
.
ready
:=
wstate
===
sWriteData
io
.
host
.
b
.
valid
:=
wstate
===
sWriteResponse
io
.
host
.
b
.
bits
.
resp
:=
"h_0"
.
U
// read control (AR, R)
val
sReadAddress
::
sReadData
::
Nil
=
Enum
(
2
)
val
rstate
=
RegInit
(
sReadAddress
)
switch
(
rstate
)
{
is
(
sReadAddress
)
{
when
(
io
.
host
.
ar
.
valid
)
{
rstate
:=
sReadData
}
}
is
(
sReadData
)
{
when
(
io
.
host
.
r
.
ready
)
{
rstate
:=
sReadAddress
}
}
}
io
.
host
.
ar
.
ready
:=
rstate
===
sReadAddress
io
.
host
.
r
.
valid
:=
rstate
===
sReadData
val
nPtrsReg
=
vp
.
nPtrsReg
val
nValsReg
=
vp
.
nValsReg
val
regBits
=
vp
.
regBits
val
ptrsBits
=
mp
.
addrBits
val
nCtrlReg
=
vp
.
nCtrlReg
val
rStride
=
regBits
/
8
val
pStride
=
ptrsBits
/
8
val
ctrlBaseAddr
=
vp
.
ctrlBaseAddr
val
valsBaseAddr
=
ctrlBaseAddr
+
nCtrlReg
*
rStride
val
ptrsBaseAddr
=
valsBaseAddr
+
nValsReg
*
rStride
val
ctrlAddr
=
Seq
.
tabulate
(
nCtrlReg
)(
i
=>
i
*
rStride
+
ctrlBaseAddr
)
val
valsAddr
=
Seq
.
tabulate
(
nValsReg
)(
i
=>
i
*
rStride
+
valsBaseAddr
)
val
ptrsAddr
=
new
ListBuffer
[
Int
]()
for
(
i
<-
0
until
nPtrsReg
)
{
ptrsAddr
+=
i
*
pStride
+
ptrsBaseAddr
if
(
ptrsBits
==
64
)
{
ptrsAddr
+=
i
*
pStride
+
rStride
+
ptrsBaseAddr
}
}
// AP register
val
c0
=
RegInit
(
VecInit
(
Seq
.
fill
(
regBits
)(
false
.
B
)))
// ap start
when
(
io
.
host
.
w
.
fire
()
&&
waddr
===
ctrlAddr
(
0
).
asUInt
&&
wstrb
(
0
)
&&
wdata
(
0
))
{
c0
(
0
)
:=
true
.
B
}
.
elsewhen
(
io
.
vcr
.
finish
)
{
c0
(
0
)
:=
false
.
B
}
// ap done = finish
when
(
io
.
vcr
.
finish
)
{
c0
(
1
)
:=
true
.
B
}
.
elsewhen
(
io
.
host
.
ar
.
fire
()
&&
io
.
host
.
ar
.
bits
.
addr
===
ctrlAddr
(
0
).
asUInt
)
{
c0
(
1
)
:=
false
.
B
}
val
c1
=
0.
U
val
c2
=
0.
U
val
c3
=
0.
U
val
ctrlRegList
=
List
(
c0
,
c1
,
c2
,
c3
)
io
.
vcr
.
launch
:=
c0
(
0
)
// interrupts not supported atm
io
.
vcr
.
irq
:=
false
.
B
// Write pointer and value registers
val
pvAddr
=
valsAddr
++
ptrsAddr
val
pvNumReg
=
if
(
ptrsBits
==
64
)
nValsReg
+
nPtrsReg
*
2
else
nValsReg
+
nPtrsReg
val
pvReg
=
RegInit
(
VecInit
(
Seq
.
fill
(
pvNumReg
)(
0.
U
(
regBits
.
W
))))
val
pvRegList
=
new
ListBuffer
[
UInt
]()
for
(
i
<-
0
until
pvNumReg
)
{
when
(
io
.
host
.
w
.
fire
()
&&
(
waddr
===
pvAddr
(
i
).
U
))
{
pvReg
(
i
)
:=
(
wdata
&
wmask
)
|
(
pvReg
(
i
)
&
~
wmask
)
}
pvRegList
+=
pvReg
(
i
)
}
for
(
i
<-
0
until
nValsReg
)
{
io
.
vcr
.
vals
(
i
)
:=
pvReg
(
i
)
}
for
(
i
<-
0
until
nPtrsReg
)
{
if
(
ptrsBits
==
64
)
{
io
.
vcr
.
ptrs
(
i
)
:=
Cat
(
pvReg
(
nValsReg
+
i
*
2
+
1
),
pvReg
(
nValsReg
+
i
*
2
))
}
else
{
io
.
vcr
.
ptrs
(
i
)
:=
pvReg
(
nValsReg
+
i
)
}
}
// Read pointer and value registers
val
mapAddr
=
ctrlAddr
++
valsAddr
++
ptrsAddr
val
mapRegList
=
ctrlRegList
++
pvRegList
val
rdata
=
RegInit
(
0.
U
(
regBits
.
W
))
val
rmap
=
LinkedHashMap
[
Int
,
UInt
]()
val
totalReg
=
mapRegList
.
length
for
(
i
<-
0
until
totalReg
)
{
rmap
+=
mapAddr
(
i
)
->
mapRegList
(
i
).
asUInt
}
val
decodeAddr
=
rmap
map
{
case
(
k
,
_
)
=>
k
->
(
io
.
host
.
ar
.
bits
.
addr
===
k
.
asUInt
)
}
when
(
io
.
host
.
ar
.
fire
())
{
rdata
:=
Mux1H
(
for
((
k
,
v
)
<-
rmap
)
yield
decodeAddr
(
k
)
->
v
)
}
io
.
host
.
r
.
bits
.
resp
:=
0.
U
io
.
host
.
r
.
bits
.
data
:=
rdata
}
vta/hardware/chisel/src/main/scala/shell/VME.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.shell
import
chisel3._
import
chisel3.util._
import
vta.util.config._
import
vta.util.genericbundle._
import
vta.interface.axi._
/** VME parameters.
*
* These parameters are used on VME interfaces and modules.
*/
case
class
VMEParams
()
{
val
nReadClients
:
Int
=
5
val
nWriteClients
:
Int
=
1
require
(
nReadClients
>
0
,
s
"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n"
)
require
(
nWriteClients
==
1
,
s
"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n"
)
}
/** VMEBase. Parametrize base class. */
abstract
class
VMEBase
(
implicit
p
:
Parameters
)
extends
GenericParameterizedBundle
(
p
)
/** VMECmd.
*
* This interface is used for creating write and read requests to memory.
*/
class
VMECmd
(
implicit
p
:
Parameters
)
extends
VMEBase
{
val
addrBits
=
p
(
ShellKey
).
memParams
.
addrBits
val
lenBits
=
p
(
ShellKey
).
memParams
.
lenBits
val
addr
=
UInt
(
addrBits
.
W
)
val
len
=
UInt
(
lenBits
.
W
)
}
/** VMEReadMaster.
*
* This interface is used by modules inside the core to generate read requests
* and receive responses from VME.
*/
class
VMEReadMaster
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
dataBits
=
p
(
ShellKey
).
memParams
.
dataBits
val
cmd
=
Decoupled
(
new
VMECmd
)
val
data
=
Flipped
(
Decoupled
(
UInt
(
dataBits
.
W
)))
override
def
cloneType
=
new
VMEReadMaster
().
asInstanceOf
[
this.
type
]
}
/** VMEReadClient.
*
* This interface is used by the VME to receive read requests and generate
* responses to modules inside the core.
*/
class
VMEReadClient
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
dataBits
=
p
(
ShellKey
).
memParams
.
dataBits
val
cmd
=
Flipped
(
Decoupled
(
new
VMECmd
))
val
data
=
Decoupled
(
UInt
(
dataBits
.
W
))
override
def
cloneType
=
new
VMEReadClient
().
asInstanceOf
[
this.
type
]
}
/** VMEWriteMaster.
*
* This interface is used by modules inside the core to generate write requests
* to the VME.
*/
class
VMEWriteMaster
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
dataBits
=
p
(
ShellKey
).
memParams
.
dataBits
val
cmd
=
Decoupled
(
new
VMECmd
)
val
data
=
Decoupled
(
UInt
(
dataBits
.
W
))
val
ack
=
Input
(
Bool
())
override
def
cloneType
=
new
VMEWriteMaster
().
asInstanceOf
[
this.
type
]
}
/** VMEWriteClient.
*
* This interface is used by the VME to handle write requests from modules inside
* the core.
*/
class
VMEWriteClient
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
dataBits
=
p
(
ShellKey
).
memParams
.
dataBits
val
cmd
=
Flipped
(
Decoupled
(
new
VMECmd
))
val
data
=
Flipped
(
Decoupled
(
UInt
(
dataBits
.
W
)))
val
ack
=
Output
(
Bool
())
override
def
cloneType
=
new
VMEWriteClient
().
asInstanceOf
[
this.
type
]
}
/** VMEMaster.
*
* Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster
* interfaces.
*/
class
VMEMaster
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
nRd
=
p
(
ShellKey
).
vmeParams
.
nReadClients
val
nWr
=
p
(
ShellKey
).
vmeParams
.
nWriteClients
val
rd
=
Vec
(
nRd
,
new
VMEReadMaster
)
val
wr
=
Vec
(
nWr
,
new
VMEWriteMaster
)
}
/** VMEClient.
*
* Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient
* interfaces.
*/
class
VMEClient
(
implicit
p
:
Parameters
)
extends
Bundle
{
val
nRd
=
p
(
ShellKey
).
vmeParams
.
nReadClients
val
nWr
=
p
(
ShellKey
).
vmeParams
.
nWriteClients
val
rd
=
Vec
(
nRd
,
new
VMEReadClient
)
val
wr
=
Vec
(
nWr
,
new
VMEWriteClient
)
}
/** VTA Memory Engine (VME).
*
* This unit multiplexes the memory controller interface for the Core. Currently,
* it supports single-writer and multiple-reader mode and it is also based on AXI.
*/
class
VME
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
mem
=
new
AXIMaster
(
p
(
ShellKey
).
memParams
)
val
vme
=
new
VMEClient
})
val
nReadClients
=
p
(
ShellKey
).
vmeParams
.
nReadClients
val
rd_arb
=
Module
(
new
Arbiter
(
new
VMECmd
,
nReadClients
))
val
rd_arb_chosen
=
RegEnable
(
rd_arb
.
io
.
chosen
,
rd_arb
.
io
.
out
.
fire
())
for
(
i
<-
0
until
nReadClients
)
{
rd_arb
.
io
.
in
(
i
)
<>
io
.
vme
.
rd
(
i
).
cmd
}
val
sReadIdle
::
sReadAddr
::
sReadData
::
Nil
=
Enum
(
3
)
val
rstate
=
RegInit
(
sReadIdle
)
switch
(
rstate
)
{
is
(
sReadIdle
)
{
when
(
rd_arb
.
io
.
out
.
valid
)
{
rstate
:=
sReadAddr
}
}
is
(
sReadAddr
)
{
when
(
io
.
mem
.
ar
.
ready
)
{
rstate
:=
sReadData
}
}
is
(
sReadData
)
{
when
(
io
.
mem
.
r
.
fire
()
&&
io
.
mem
.
r
.
bits
.
last
)
{
rstate
:=
sReadIdle
}
}
}
val
sWriteIdle
::
sWriteAddr
::
sWriteData
::
sWriteResp
::
Nil
=
Enum
(
4
)
val
wstate
=
RegInit
(
sWriteIdle
)
val
addrBits
=
p
(
ShellKey
).
memParams
.
addrBits
val
lenBits
=
p
(
ShellKey
).
memParams
.
lenBits
val
wr_cnt
=
RegInit
(
0.
U
(
lenBits
.
W
))
when
(
wstate
===
sWriteIdle
)
{
wr_cnt
:=
0.
U
}
.
elsewhen
(
io
.
mem
.
w
.
fire
())
{
wr_cnt
:=
wr_cnt
+
1.
U
}
switch
(
wstate
)
{
is
(
sWriteIdle
)
{
when
(
io
.
vme
.
wr
(
0
).
cmd
.
valid
)
{
wstate
:=
sWriteAddr
}
}
is
(
sWriteAddr
)
{
when
(
io
.
mem
.
aw
.
ready
)
{
wstate
:=
sWriteData
}
}
is
(
sWriteData
)
{
when
(
io
.
mem
.
w
.
ready
&&
wr_cnt
===
io
.
vme
.
wr
(
0
).
cmd
.
bits
.
len
)
{
wstate
:=
sWriteResp
}
}
is
(
sWriteResp
)
{
when
(
io
.
mem
.
b
.
valid
)
{
wstate
:=
sWriteIdle
}
}
}
// registers storing read/write cmds
val
rd_len
=
RegInit
(
0.
U
(
lenBits
.
W
))
val
wr_len
=
RegInit
(
0.
U
(
lenBits
.
W
))
val
rd_addr
=
RegInit
(
0.
U
(
addrBits
.
W
))
val
wr_addr
=
RegInit
(
0.
U
(
addrBits
.
W
))
when
(
rd_arb
.
io
.
out
.
fire
())
{
rd_len
:=
rd_arb
.
io
.
out
.
bits
.
len
rd_addr
:=
rd_arb
.
io
.
out
.
bits
.
addr
}
when
(
io
.
vme
.
wr
(
0
).
cmd
.
fire
())
{
wr_len
:=
io
.
vme
.
wr
(
0
).
cmd
.
bits
.
len
wr_addr
:=
io
.
vme
.
wr
(
0
).
cmd
.
bits
.
addr
}
// rd arb
rd_arb
.
io
.
out
.
ready
:=
rstate
===
sReadIdle
// vme
for
(
i
<-
0
until
nReadClients
)
{
io
.
vme
.
rd
(
i
).
data
.
valid
:=
rd_arb_chosen
===
i
.
asUInt
&
io
.
mem
.
r
.
valid
io
.
vme
.
rd
(
i
).
data
.
bits
:=
io
.
mem
.
r
.
bits
.
data
}
io
.
vme
.
wr
(
0
).
cmd
.
ready
:=
wstate
===
sWriteIdle
io
.
vme
.
wr
(
0
).
ack
:=
io
.
mem
.
b
.
fire
()
io
.
vme
.
wr
(
0
).
data
.
ready
:=
wstate
===
sWriteData
&
io
.
mem
.
w
.
ready
// mem
io
.
mem
.
aw
.
valid
:=
wstate
===
sWriteAddr
io
.
mem
.
aw
.
bits
.
addr
:=
wr_addr
io
.
mem
.
aw
.
bits
.
len
:=
wr_len
io
.
mem
.
w
.
valid
:=
wstate
===
sWriteData
&
io
.
vme
.
wr
(
0
).
data
.
valid
io
.
mem
.
w
.
bits
.
data
:=
io
.
vme
.
wr
(
0
).
data
.
bits
io
.
mem
.
w
.
bits
.
last
:=
wr_cnt
===
io
.
vme
.
wr
(
0
).
cmd
.
bits
.
len
io
.
mem
.
b
.
ready
:=
wstate
===
sWriteResp
io
.
mem
.
ar
.
valid
:=
rstate
===
sReadAddr
io
.
mem
.
ar
.
bits
.
addr
:=
rd_addr
io
.
mem
.
ar
.
bits
.
len
:=
rd_len
io
.
mem
.
r
.
ready
:=
rstate
===
sReadData
&
io
.
vme
.
rd
(
rd_arb_chosen
).
data
.
ready
// AXI constants - statically defined
io
.
mem
.
setConst
()
}
vta/hardware/chisel/src/main/scala/shell/VTAShell.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.shell
import
chisel3._
import
vta.util.config._
import
vta.interface.axi._
import
vta.core._
/** Shell parameters. */
case
class
ShellParams
(
hostParams
:
AXIParams
,
memParams
:
AXIParams
,
vcrParams
:
VCRParams
,
vmeParams
:
VMEParams
)
case
object
ShellKey
extends
Field
[
ShellParams
]
/** VTAShell.
*
* The VTAShell is based on a VME, VCR and core. This creates a complete VTA
* system that can be used for simulation or real hardware.
*/
class
VTAShell
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{
val
host
=
new
AXILiteClient
(
p
(
ShellKey
).
hostParams
)
val
mem
=
new
AXIMaster
(
p
(
ShellKey
).
memParams
)
})
val
vcr
=
Module
(
new
VCR
)
val
vme
=
Module
(
new
VME
)
val
core
=
Module
(
new
Core
)
core
.
io
.
vcr
<>
vcr
.
io
.
vcr
vme
.
io
.
vme
<>
core
.
io
.
vme
vcr
.
io
.
host
<>
io
.
host
io
.
mem
<>
vme
.
io
.
mem
}
vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.shell
import
chisel3._
import
chisel3.experimental.
{
RawModule
,
withClockAndReset
}
import
vta.util.config._
import
vta.interface.axi._
/** XilinxShell.
*
* This is a wrapper shell mostly used to match Xilinx convention naming,
* therefore we can pack VTA as an IP for IPI based flows.
*/
class
XilinxShell
(
implicit
p
:
Parameters
)
extends
RawModule
{
val
hp
=
p
(
ShellKey
).
hostParams
val
mp
=
p
(
ShellKey
).
memParams
val
ap_clk
=
IO
(
Input
(
Clock
()))
val
ap_rst_n
=
IO
(
Input
(
Bool
()))
val
m_axi_gmem
=
IO
(
new
XilinxAXIMaster
(
mp
))
val
s_axi_control
=
IO
(
new
XilinxAXILiteClient
(
hp
))
val
shell
=
withClockAndReset
(
clock
=
ap_clk
,
reset
=
~
ap_rst_n
)
{
Module
(
new
VTAShell
)
}
// memory
m_axi_gmem
.
AWVALID
:=
shell
.
io
.
mem
.
aw
.
valid
shell
.
io
.
mem
.
aw
.
ready
:=
m_axi_gmem
.
AWREADY
m_axi_gmem
.
AWADDR
:=
shell
.
io
.
mem
.
aw
.
bits
.
addr
m_axi_gmem
.
AWID
:=
shell
.
io
.
mem
.
aw
.
bits
.
id
m_axi_gmem
.
AWUSER
:=
shell
.
io
.
mem
.
aw
.
bits
.
user
m_axi_gmem
.
AWLEN
:=
shell
.
io
.
mem
.
aw
.
bits
.
len
m_axi_gmem
.
AWSIZE
:=
shell
.
io
.
mem
.
aw
.
bits
.
size
m_axi_gmem
.
AWBURST
:=
shell
.
io
.
mem
.
aw
.
bits
.
burst
m_axi_gmem
.
AWLOCK
:=
shell
.
io
.
mem
.
aw
.
bits
.
lock
m_axi_gmem
.
AWCACHE
:=
shell
.
io
.
mem
.
aw
.
bits
.
cache
m_axi_gmem
.
AWPROT
:=
shell
.
io
.
mem
.
aw
.
bits
.
prot
m_axi_gmem
.
AWQOS
:=
shell
.
io
.
mem
.
aw
.
bits
.
qos
m_axi_gmem
.
AWREGION
:=
shell
.
io
.
mem
.
aw
.
bits
.
region
m_axi_gmem
.
WVALID
:=
shell
.
io
.
mem
.
w
.
valid
shell
.
io
.
mem
.
w
.
ready
:=
m_axi_gmem
.
WREADY
m_axi_gmem
.
WDATA
:=
shell
.
io
.
mem
.
w
.
bits
.
data
m_axi_gmem
.
WSTRB
:=
shell
.
io
.
mem
.
w
.
bits
.
strb
m_axi_gmem
.
WLAST
:=
shell
.
io
.
mem
.
w
.
bits
.
last
m_axi_gmem
.
WID
:=
shell
.
io
.
mem
.
w
.
bits
.
id
m_axi_gmem
.
WUSER
:=
shell
.
io
.
mem
.
w
.
bits
.
user
shell
.
io
.
mem
.
b
.
valid
:=
m_axi_gmem
.
BVALID
m_axi_gmem
.
BREADY
:=
shell
.
io
.
mem
.
b
.
valid
shell
.
io
.
mem
.
b
.
bits
.
resp
:=
m_axi_gmem
.
BRESP
shell
.
io
.
mem
.
b
.
bits
.
id
:=
m_axi_gmem
.
BID
shell
.
io
.
mem
.
b
.
bits
.
user
:=
m_axi_gmem
.
BUSER
m_axi_gmem
.
ARVALID
:=
shell
.
io
.
mem
.
ar
.
valid
shell
.
io
.
mem
.
ar
.
ready
:=
m_axi_gmem
.
ARREADY
m_axi_gmem
.
ARADDR
:=
shell
.
io
.
mem
.
ar
.
bits
.
addr
m_axi_gmem
.
ARID
:=
shell
.
io
.
mem
.
ar
.
bits
.
id
m_axi_gmem
.
ARUSER
:=
shell
.
io
.
mem
.
ar
.
bits
.
user
m_axi_gmem
.
ARLEN
:=
shell
.
io
.
mem
.
ar
.
bits
.
len
m_axi_gmem
.
ARSIZE
:=
shell
.
io
.
mem
.
ar
.
bits
.
size
m_axi_gmem
.
ARBURST
:=
shell
.
io
.
mem
.
ar
.
bits
.
burst
m_axi_gmem
.
ARLOCK
:=
shell
.
io
.
mem
.
ar
.
bits
.
lock
m_axi_gmem
.
ARCACHE
:=
shell
.
io
.
mem
.
ar
.
bits
.
cache
m_axi_gmem
.
ARPROT
:=
shell
.
io
.
mem
.
ar
.
bits
.
prot
m_axi_gmem
.
ARQOS
:=
shell
.
io
.
mem
.
ar
.
bits
.
qos
m_axi_gmem
.
ARREGION
:=
shell
.
io
.
mem
.
ar
.
bits
.
region
shell
.
io
.
mem
.
r
.
valid
:=
m_axi_gmem
.
RVALID
m_axi_gmem
.
RREADY
:=
shell
.
io
.
mem
.
r
.
ready
shell
.
io
.
mem
.
r
.
bits
.
data
:=
m_axi_gmem
.
RDATA
shell
.
io
.
mem
.
r
.
bits
.
resp
:=
m_axi_gmem
.
RRESP
shell
.
io
.
mem
.
r
.
bits
.
last
:=
m_axi_gmem
.
RLAST
shell
.
io
.
mem
.
r
.
bits
.
id
:=
m_axi_gmem
.
RID
shell
.
io
.
mem
.
r
.
bits
.
user
:=
m_axi_gmem
.
RUSER
// host
shell
.
io
.
host
.
aw
.
valid
:=
s_axi_control
.
AWVALID
s_axi_control
.
AWREADY
:=
shell
.
io
.
host
.
aw
.
ready
shell
.
io
.
host
.
aw
.
bits
.
addr
:=
s_axi_control
.
AWADDR
shell
.
io
.
host
.
w
.
valid
:=
s_axi_control
.
WVALID
s_axi_control
.
WREADY
:=
shell
.
io
.
host
.
w
.
ready
shell
.
io
.
host
.
w
.
bits
.
data
:=
s_axi_control
.
WDATA
shell
.
io
.
host
.
w
.
bits
.
strb
:=
s_axi_control
.
WSTRB
s_axi_control
.
BVALID
:=
shell
.
io
.
host
.
b
.
valid
shell
.
io
.
host
.
b
.
ready
:=
s_axi_control
.
BREADY
s_axi_control
.
BRESP
:=
shell
.
io
.
host
.
b
.
bits
.
resp
shell
.
io
.
host
.
ar
.
valid
:=
s_axi_control
.
ARVALID
s_axi_control
.
ARREADY
:=
shell
.
io
.
host
.
ar
.
ready
shell
.
io
.
host
.
ar
.
bits
.
addr
:=
s_axi_control
.
ARADDR
s_axi_control
.
RVALID
:=
shell
.
io
.
host
.
r
.
valid
shell
.
io
.
host
.
r
.
ready
:=
s_axi_control
.
RREADY
s_axi_control
.
RDATA
:=
shell
.
io
.
host
.
r
.
bits
.
data
s_axi_control
.
RRESP
:=
shell
.
io
.
host
.
r
.
bits
.
resp
}
vta/hardware/chisel/src/main/scala/test/Test.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.test
import
chisel3._
import
vta.util.config._
import
vta.shell._
/** Test. This generates a testbench file for simulation */
class
Test
(
implicit
p
:
Parameters
)
extends
Module
{
val
io
=
IO
(
new
Bundle
{})
val
sim_shell
=
Module
(
new
SimShell
)
val
vta_shell
=
Module
(
new
VTAShell
)
vta_shell
.
io
.
host
<>
sim_shell
.
io
.
host
sim_shell
.
io
.
mem
<>
vta_shell
.
io
.
mem
}
vta/hardware/chisel/src/main/scala/util/Config.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.util.config
// taken from https://github.com/vta.roject/rocket-chip
abstract
class
Field
[
T
]
private
(
val
default
:
Option
[
T
])
{
def
this
()
=
this
(
None
)
def
this
(
default
:
T
)
=
this
(
Some
(
default
))
}
abstract
class
View
{
final
def
apply
[
T
](
pname
:
Field
[
T
])
:
T
=
apply
(
pname
,
this
)
final
def
apply
[
T
](
pname
:
Field
[
T
],
site
:
View
)
:
T
=
{
val
out
=
find
(
pname
,
site
)
require
(
out
.
isDefined
,
s
"Key ${pname} is not defined in Parameters"
)
out
.
get
}
final
def
lift
[
T
](
pname
:
Field
[
T
])
:
Option
[
T
]
=
lift
(
pname
,
this
)
final
def
lift
[
T
](
pname
:
Field
[
T
],
site
:
View
)
:
Option
[
T
]
=
find
(
pname
,
site
).
map
(
_
.
asInstanceOf
[
T
])
protected
[
config
]
def
find
[
T
](
pname
:
Field
[
T
],
site
:
View
)
:
Option
[
T
]
}
abstract
class
Parameters
extends
View
{
final
def
++
(
x
:
Parameters
)
:
Parameters
=
new
ChainParameters
(
this
,
x
)
final
def
alter
(
f
:
(
View
,
View
,
View
)
=>
PartialFunction
[
Any
,
Any
])
:
Parameters
=
Parameters
(
f
)
++
this
final
def
alterPartial
(
f
:
PartialFunction
[
Any
,
Any
])
:
Parameters
=
Parameters
((
_
,
_
,
_
)
=>
f
)
++
this
final
def
alterMap
(
m
:
Map
[
Any
,
Any
])
:
Parameters
=
new
MapParameters
(
m
)
++
this
protected
[
config
]
def
chain
[
T
](
site
:
View
,
tail
:
View
,
pname
:
Field
[
T
])
:
Option
[
T
]
protected
[
config
]
def
find
[
T
](
pname
:
Field
[
T
],
site
:
View
)
=
chain
(
site
,
new
TerminalView
,
pname
)
}
object
Parameters
{
def
empty
:
Parameters
=
new
EmptyParameters
def
apply
(
f
:
(
View
,
View
,
View
)
=>
PartialFunction
[
Any
,
Any
])
:
Parameters
=
new
PartialParameters
(
f
)
}
class
Config
(
p
:
Parameters
)
extends
Parameters
{
def
this
(
f
:
(
View
,
View
,
View
)
=>
PartialFunction
[
Any
,
Any
])
=
this
(
Parameters
(
f
))
protected
[
config
]
def
chain
[
T
](
site
:
View
,
tail
:
View
,
pname
:
Field
[
T
])
=
p
.
chain
(
site
,
tail
,
pname
)
override
def
toString
=
this
.
getClass
.
getSimpleName
def
toInstance
=
this
}
// Internal implementation:
private
class
TerminalView
extends
View
{
def
find
[
T
](
pname
:
Field
[
T
],
site
:
View
)
:
Option
[
T
]
=
pname
.
default
}
private
class
ChainView
(
head
:
Parameters
,
tail
:
View
)
extends
View
{
def
find
[
T
](
pname
:
Field
[
T
],
site
:
View
)
=
head
.
chain
(
site
,
tail
,
pname
)
}
private
class
ChainParameters
(
x
:
Parameters
,
y
:
Parameters
)
extends
Parameters
{
def
chain
[
T
](
site
:
View
,
tail
:
View
,
pname
:
Field
[
T
])
=
x
.
chain
(
site
,
new
ChainView
(
y
,
tail
),
pname
)
}
private
class
EmptyParameters
extends
Parameters
{
def
chain
[
T
](
site
:
View
,
tail
:
View
,
pname
:
Field
[
T
])
=
tail
.
find
(
pname
,
site
)
}
private
class
PartialParameters
(
f
:
(
View
,
View
,
View
)
=>
PartialFunction
[
Any
,
Any
])
extends
Parameters
{
protected
[
config
]
def
chain
[
T
](
site
:
View
,
tail
:
View
,
pname
:
Field
[
T
])
=
{
val
g
=
f
(
site
,
this
,
tail
)
if
(
g
.
isDefinedAt
(
pname
))
Some
(
g
.
apply
(
pname
).
asInstanceOf
[
T
])
else
tail
.
find
(
pname
,
site
)
}
}
private
class
MapParameters
(
map
:
Map
[
Any
,
Any
])
extends
Parameters
{
protected
[
config
]
def
chain
[
T
](
site
:
View
,
tail
:
View
,
pname
:
Field
[
T
])
=
{
val
g
=
map
.
get
(
pname
)
if
(
g
.
isDefined
)
Some
(
g
.
get
.
asInstanceOf
[
T
])
else
tail
.
find
(
pname
,
site
)
}
}
vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta.util.genericbundle
// taken from https://github.com/vta.roject/rocket-chip
import
chisel3._
abstract
class
GenericParameterizedBundle
[
+T
<:
Object
](
val
params
:
T
)
extends
Bundle
{
override
def
cloneType
=
{
try
{
this
.
getClass
.
getConstructors
.
head
.
newInstance
(
params
).
asInstanceOf
[
this.
type
]
}
catch
{
case
e
:
java.lang.IllegalArgumentException
=>
throw
new
Exception
(
"Unable to use GenericParameterizedBundle.cloneType on "
+
this
.
getClass
+
", probably because "
+
this
.
getClass
+
"() takes more than one argument. Consider overriding "
+
"cloneType() on "
+
this
.
getClass
,
e
)
}
}
}
vta/hardware/chisel/src/main/scala/vta/Configs.scala
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package
vta
import
chisel3._
import
vta.util.config._
import
vta.shell._
import
vta.core._
import
vta.test._
/** VTA.
*
* This file contains all the configurations supported by VTA.
* These configurations are built in a mix/match form based on core
* and shell configurations.
*/
class
DefaultPynqConfig
extends
Config
(
new
CoreConfig
++
new
PynqConfig
)
class
DefaultF1Config
extends
Config
(
new
CoreConfig
++
new
F1Config
)
object
DefaultPynqConfig
extends
App
{
implicit
val
p
:
Parameters
=
new
DefaultPynqConfig
chisel3
.
Driver
.
execute
(
args
,
()
=>
new
XilinxShell
)
}
object
DefaultF1Config
extends
App
{
implicit
val
p
:
Parameters
=
new
DefaultF1Config
chisel3
.
Driver
.
execute
(
args
,
()
=>
new
XilinxShell
)
}
object
TestDefaultF1Config
extends
App
{
implicit
val
p
:
Parameters
=
new
DefaultF1Config
chisel3
.
Driver
.
execute
(
args
,
()
=>
new
Test
)
}
vta/hardware/dpi/tsim_device.cc
View file @
32f74f31
...
...
@@ -70,8 +70,18 @@ void VTADPIInit(VTAContextHandle handle,
_mem_dpi
=
mem_dpi
;
}
// Override Verilator finish definition
// VL_USER_FINISH needs to be defined when compiling Verilator code
void
vl_finish
(
const
char
*
filename
,
int
linenum
,
const
char
*
hier
)
{
Verilated
::
gotFinish
(
true
);
VL_PRINTF
(
"[TSIM] exiting simulation
\n
"
);
}
int
VTADPISim
(
uint64_t
max_cycles
)
{
uint64_t
trace_count
=
0
;
Verilated
::
flushCall
();
Verilated
::
gotFinish
(
false
);
#if VM_TRACE
uint64_t
start
=
0
;
...
...
vta/include/vta/driver.h
View file @
32f74f31
...
...
@@ -53,7 +53,11 @@ extern "C" {
typedef
void
*
VTADeviceHandle
;
/*! \brief physical address */
#ifdef USE_TSIM
typedef
uint64_t
vta_phy_addr_t
;
#else
typedef
uint32_t
vta_phy_addr_t
;
#endif
/*!
* \brief Allocate a device resource handle
...
...
@@ -76,10 +80,22 @@ void VTADeviceFree(VTADeviceHandle handle);
*
* \return 0 if running is successful, 1 if timeout.
*/
#ifdef USE_TSIM
int
VTADeviceRun
(
VTADeviceHandle
device
,
vta_phy_addr_t
insn_phy_addr
,
vta_phy_addr_t
uop_phy_addr
,
vta_phy_addr_t
inp_phy_addr
,
vta_phy_addr_t
wgt_phy_addr
,
vta_phy_addr_t
acc_phy_addr
,
vta_phy_addr_t
out_phy_addr
,
uint32_t
insn_count
,
uint32_t
wait_cycles
);
#else
int
VTADeviceRun
(
VTADeviceHandle
device
,
vta_phy_addr_t
insn_phy_addr
,
uint32_t
insn_count
,
uint32_t
wait_cycles
);
#endif
/*!
* \brief Allocates physically contiguous region in memory (limited by MAX_XFER).
...
...
vta/python/vta/environment.py
View file @
32f74f31
...
...
@@ -239,7 +239,7 @@ class Environment(object):
"""The target host"""
if
self
.
TARGET
==
"pynq"
:
return
"llvm -target=armv7-none-linux-gnueabihf"
if
self
.
TARGET
==
"sim"
:
if
self
.
TARGET
==
"sim"
or
self
.
TARGET
==
"tsim"
:
return
"llvm"
raise
ValueError
(
"Unknown target
%
s"
%
self
.
TARGET
)
...
...
vta/python/vta/testing/simulator.py
View file @
32f74f31
...
...
@@ -17,6 +17,8 @@
"""Utilities to start simulator."""
import
ctypes
import
json
import
sys
import
os
import
tvm
from
..libinfo
import
find_libvta
...
...
@@ -55,5 +57,22 @@ def stats():
x
=
tvm
.
get_global_func
(
"vta.simulator.profiler_status"
)()
return
json
.
loads
(
x
)
def
tsim_init
(
hw_lib
):
"""Init hardware shared library for TSIM
Parameters
------------
hw_lib : str
Name of hardware shared library
"""
cur_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
__file__
)))
vta_build_path
=
os
.
path
.
join
(
cur_path
,
".."
,
".."
,
".."
,
"build"
)
if
not
hw_lib
.
endswith
((
"dylib"
,
"so"
)):
hw_lib
+=
".dylib"
if
sys
.
platform
==
"darwin"
else
".so"
lib
=
os
.
path
.
join
(
vta_build_path
,
hw_lib
)
f
=
tvm
.
get_global_func
(
"tvm.vta.tsim.init"
)
m
=
tvm
.
module
.
load
(
lib
,
"vta-tsim"
)
f
(
m
)
LIBS
=
_load_lib
()
vta/python/vta/testing/util.py
View file @
32f74f31
...
...
@@ -31,7 +31,7 @@ def run(run_func):
"""
env
=
get_env
()
if
env
.
TARGET
==
"sim"
:
if
env
.
TARGET
in
[
"sim"
,
"tsim"
]
:
# Talk to local RPC if necessary to debug RPC server.
# Compile vta on your host with make at the root.
...
...
@@ -48,7 +48,8 @@ def run(run_func):
# Make sure simulation library exists
# If this fails, build vta on host (make)
# with TARGET="sim" in the json.config file.
assert
simulator
.
enabled
()
if
env
.
TARGET
==
"sim"
:
assert
simulator
.
enabled
()
run_func
(
env
,
rpc
.
LocalSession
())
elif
env
.
TARGET
==
"pynq"
:
...
...
vta/src/runtime.cc
View file @
32f74f31
...
...
@@ -56,7 +56,7 @@ struct DataBuffer {
return
data_
;
}
/*! \return Physical address of the data. */
uint32
_t
phy_addr
()
const
{
vta_phy_addr
_t
phy_addr
()
const
{
return
phy_addr_
;
}
/*!
...
...
@@ -113,7 +113,7 @@ struct DataBuffer {
/*! \brief The internal data. */
void
*
data_
;
/*! \brief The physical address of the buffer, excluding header. */
uint32
_t
phy_addr_
;
vta_phy_addr
_t
phy_addr_
;
};
/*!
...
...
@@ -302,7 +302,7 @@ class BaseQueue {
return
dram_buffer_
;
}
/*! \return Physical address of DRAM. */
uint32
_t
dram_phy_addr
()
const
{
vta_phy_addr
_t
dram_phy_addr
()
const
{
return
dram_phy_addr_
;
}
/*! \return Whether there is pending information. */
...
...
@@ -367,7 +367,7 @@ class BaseQueue {
// The buffer in DRAM
char
*
dram_buffer_
{
nullptr
};
// Physics address of the buffer
uint32
_t
dram_phy_addr_
;
vta_phy_addr
_t
dram_phy_addr_
;
};
/*!
...
...
@@ -424,7 +424,11 @@ class UopQueue : public BaseQueue {
CHECK
((
dram_end_
-
dram_begin_
)
==
(
sram_end_
-
sram_begin_
));
insn
->
memory_type
=
VTA_MEM_ID_UOP
;
insn
->
sram_base
=
sram_begin_
;
#ifdef USE_TSIM
insn
->
dram_base
=
(
uint32_t
)
dram_phy_addr_
+
dram_begin_
*
kElemBytes
;
#else
insn
->
dram_base
=
dram_phy_addr_
/
kElemBytes
+
dram_begin_
;
#endif
insn
->
y_size
=
1
;
insn
->
x_size
=
(
dram_end_
-
dram_begin_
);
insn
->
x_stride
=
(
dram_end_
-
dram_begin_
);
...
...
@@ -958,7 +962,11 @@ class CommandQueue {
insn
->
memory_type
=
dst_memory_type
;
insn
->
sram_base
=
dst_sram_index
;
DataBuffer
*
src
=
DataBuffer
::
FromHandle
(
src_dram_addr
);
#ifdef USE_TSIM
insn
->
dram_base
=
(
uint32_t
)
src
->
phy_addr
()
+
src_elem_offset
*
GetElemBytes
(
dst_memory_type
);
#else
insn
->
dram_base
=
src
->
phy_addr
()
/
GetElemBytes
(
dst_memory_type
)
+
src_elem_offset
;
#endif
insn
->
y_size
=
y_size
;
insn
->
x_size
=
x_size
;
insn
->
x_stride
=
x_stride
;
...
...
@@ -981,7 +989,11 @@ class CommandQueue {
insn
->
memory_type
=
src_memory_type
;
insn
->
sram_base
=
src_sram_index
;
DataBuffer
*
dst
=
DataBuffer
::
FromHandle
(
dst_dram_addr
);
#ifdef USE_TSIM
insn
->
dram_base
=
(
uint32_t
)
dst
->
phy_addr
()
+
dst_elem_offset
*
GetElemBytes
(
src_memory_type
);
#else
insn
->
dram_base
=
dst
->
phy_addr
()
/
GetElemBytes
(
src_memory_type
)
+
dst_elem_offset
;
#endif
insn
->
y_size
=
y_size
;
insn
->
x_size
=
x_size
;
insn
->
x_stride
=
x_stride
;
...
...
@@ -1046,11 +1058,24 @@ class CommandQueue {
// Make sure that we don't exceed contiguous physical memory limits
CHECK
(
insn_queue_
.
count
()
*
sizeof
(
VTAGenericInsn
)
<
VTA_MAX_XFER
);
#ifdef USE_TSIM
int
timeout
=
VTADeviceRun
(
device_
,
insn_queue_
.
dram_phy_addr
(),
uop_queue_
.
dram_phy_addr
(),
inp_phy_addr_
,
wgt_phy_addr_
,
acc_phy_addr_
,
out_phy_addr_
,
insn_queue_
.
count
(),
wait_cycles
);
#else
int
timeout
=
VTADeviceRun
(
device_
,
insn_queue_
.
dram_phy_addr
(),
insn_queue_
.
count
(),
wait_cycles
);
#endif
CHECK_EQ
(
timeout
,
0
);
// Reset buffers
uop_queue_
.
Reset
();
...
...
@@ -1125,6 +1150,18 @@ class CommandQueue {
ThreadLocal
().
reset
();
}
#ifdef USE_TSIM
void
SetBufPhyAddr
(
uint32_t
type
,
vta_phy_addr_t
addr
)
{
switch
(
type
)
{
case
VTA_MEM_ID_INP
:
inp_phy_addr_
=
addr
;
case
VTA_MEM_ID_WGT
:
wgt_phy_addr_
=
addr
;
case
VTA_MEM_ID_ACC
:
acc_phy_addr_
=
addr
;
case
VTA_MEM_ID_OUT
:
out_phy_addr_
=
addr
;
default
:
break
;
}
}
#endif
private
:
// Push GEMM uop to the command buffer
void
PushGEMMOp
(
UopKernel
*
kernel
)
{
...
...
@@ -1229,6 +1266,16 @@ class CommandQueue {
InsnQueue
<
VTA_MAX_XFER
,
true
,
true
>
insn_queue_
;
// Device handle
VTADeviceHandle
device_
{
nullptr
};
#ifdef USE_TSIM
// Input phy addr
vta_phy_addr_t
inp_phy_addr_
{
0
};
// Weight phy addr
vta_phy_addr_t
wgt_phy_addr_
{
0
};
// Accumulator phy addr
vta_phy_addr_t
acc_phy_addr_
{
0
};
// Output phy addr
vta_phy_addr_t
out_phy_addr_
{
0
};
#endif
};
}
// namespace vta
...
...
@@ -1317,6 +1364,10 @@ void VTALoadBuffer2D(VTACommandHandle cmd,
uint32_t
y_pad_after
,
uint32_t
dst_sram_index
,
uint32_t
dst_memory_type
)
{
#ifdef USE_TSIM
vta
::
DataBuffer
*
src
=
vta
::
DataBuffer
::
FromHandle
(
src_dram_addr
);
static_cast
<
vta
::
CommandQueue
*>
(
cmd
)
->
SetBufPhyAddr
(
dst_memory_type
,
src
->
phy_addr
());
#endif
static_cast
<
vta
::
CommandQueue
*>
(
cmd
)
->
LoadBuffer2D
(
src_dram_addr
,
src_elem_offset
,
x_size
,
y_size
,
x_stride
,
...
...
@@ -1333,6 +1384,10 @@ void VTAStoreBuffer2D(VTACommandHandle cmd,
uint32_t
x_size
,
uint32_t
y_size
,
uint32_t
x_stride
)
{
#ifdef USE_TSIM
vta
::
DataBuffer
*
dst
=
vta
::
DataBuffer
::
FromHandle
(
dst_dram_addr
);
static_cast
<
vta
::
CommandQueue
*>
(
cmd
)
->
SetBufPhyAddr
(
src_memory_type
,
dst
->
phy_addr
());
#endif
static_cast
<
vta
::
CommandQueue
*>
(
cmd
)
->
StoreBuffer2D
(
src_sram_index
,
src_memory_type
,
dst_dram_addr
,
dst_elem_offset
,
...
...
vta/src/tsim/tsim_driver.cc
0 → 100644
View file @
32f74f31
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <vta/driver.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>
namespace
vta
{
namespace
tsim
{
using
vta
::
dpi
::
DPIModuleNode
;
using
tvm
::
runtime
::
Module
;
class
DPILoader
{
public
:
void
Init
(
Module
module
)
{
mod_
=
module
;
}
DPIModuleNode
*
Get
()
{
return
static_cast
<
DPIModuleNode
*>
(
mod_
.
operator
->
());
}
static
DPILoader
*
Global
()
{
static
DPILoader
inst
;
return
&
inst
;
}
Module
mod_
;
};
class
Device
{
public
:
Device
()
{
dpi_
=
DPILoader
::
Global
();
}
int
Run
(
vta_phy_addr_t
insn_phy_addr
,
vta_phy_addr_t
uop_phy_addr
,
vta_phy_addr_t
inp_phy_addr
,
vta_phy_addr_t
wgt_phy_addr
,
vta_phy_addr_t
acc_phy_addr
,
vta_phy_addr_t
out_phy_addr
,
uint32_t
insn_count
,
uint32_t
wait_cycles
)
{
this
->
Init
();
this
->
Launch
(
insn_phy_addr
,
uop_phy_addr
,
inp_phy_addr
,
wgt_phy_addr
,
acc_phy_addr
,
out_phy_addr
,
insn_count
,
wait_cycles
);
this
->
WaitForCompletion
(
wait_cycles
);
dev_
->
Finish
();
return
0
;
}
private
:
void
Init
()
{
dev_
=
dpi_
->
Get
();
}
void
Launch
(
vta_phy_addr_t
insn_phy_addr
,
vta_phy_addr_t
uop_phy_addr
,
vta_phy_addr_t
inp_phy_addr
,
vta_phy_addr_t
wgt_phy_addr
,
vta_phy_addr_t
acc_phy_addr
,
vta_phy_addr_t
out_phy_addr
,
uint32_t
insn_count
,
uint32_t
wait_cycles
)
{
// launch simulation thread
dev_
->
Launch
(
wait_cycles
);
dev_
->
WriteReg
(
0x10
,
insn_count
);
dev_
->
WriteReg
(
0x14
,
insn_phy_addr
);
dev_
->
WriteReg
(
0x18
,
insn_phy_addr
>>
32
);
dev_
->
WriteReg
(
0x1c
,
0
);
dev_
->
WriteReg
(
0x20
,
uop_phy_addr
>>
32
);
dev_
->
WriteReg
(
0x24
,
0
);
dev_
->
WriteReg
(
0x28
,
inp_phy_addr
>>
32
);
dev_
->
WriteReg
(
0x2c
,
0
);
dev_
->
WriteReg
(
0x30
,
wgt_phy_addr
>>
32
);
dev_
->
WriteReg
(
0x34
,
0
);
dev_
->
WriteReg
(
0x38
,
acc_phy_addr
>>
32
);
dev_
->
WriteReg
(
0x3c
,
0
);
dev_
->
WriteReg
(
0x40
,
out_phy_addr
>>
32
);
// start
dev_
->
WriteReg
(
0x00
,
0x1
);
}
void
WaitForCompletion
(
uint32_t
wait_cycles
)
{
uint32_t
i
,
val
;
for
(
i
=
0
;
i
<
wait_cycles
;
i
++
)
{
val
=
dev_
->
ReadReg
(
0x00
);
val
&=
0x2
;
if
(
val
==
0x2
)
break
;
// finish
}
}
DPILoader
*
dpi_
;
DPIModuleNode
*
dev_
;
};
using
tvm
::
runtime
::
TVMRetValue
;
using
tvm
::
runtime
::
TVMArgs
;
TVM_REGISTER_GLOBAL
(
"tvm.vta.tsim.init"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Module
m
=
args
[
0
];
DPILoader
::
Global
()
->
Init
(
m
);
});
}
// namespace tsim
}
// namespace vta
void
*
VTAMemAlloc
(
size_t
size
,
int
cached
)
{
void
*
p
=
malloc
(
size
);
return
p
;
}
void
VTAMemFree
(
void
*
buf
)
{
free
(
buf
);
}
vta_phy_addr_t
VTAMemGetPhyAddr
(
void
*
buf
)
{
return
reinterpret_cast
<
uint64_t
>
(
reinterpret_cast
<
uint64_t
*>
(
buf
));
}
void
VTAFlushCache
(
vta_phy_addr_t
buf
,
int
size
)
{
}
void
VTAInvalidateCache
(
vta_phy_addr_t
buf
,
int
size
)
{
}
VTADeviceHandle
VTADeviceAlloc
()
{
return
new
vta
::
tsim
::
Device
();
}
void
VTADeviceFree
(
VTADeviceHandle
handle
)
{
delete
static_cast
<
vta
::
tsim
::
Device
*>
(
handle
);
}
int
VTADeviceRun
(
VTADeviceHandle
handle
,
vta_phy_addr_t
insn_phy_addr
,
vta_phy_addr_t
uop_phy_addr
,
vta_phy_addr_t
inp_phy_addr
,
vta_phy_addr_t
wgt_phy_addr
,
vta_phy_addr_t
acc_phy_addr
,
vta_phy_addr_t
out_phy_addr
,
uint32_t
insn_count
,
uint32_t
wait_cycles
)
{
return
static_cast
<
vta
::
tsim
::
Device
*>
(
handle
)
->
Run
(
insn_phy_addr
,
uop_phy_addr
,
inp_phy_addr
,
wgt_phy_addr
,
acc_phy_addr
,
out_phy_addr
,
insn_count
,
wait_cycles
);
}
vta/tests/python/unittest/test_vta_insn.py
View file @
32f74f31
...
...
@@ -68,6 +68,10 @@ def test_save_load_out():
y_np
=
x_np
.
astype
(
y
.
dtype
)
x_nd
=
tvm
.
nd
.
array
(
x_np
,
ctx
)
y_nd
=
tvm
.
nd
.
empty
(
y_np
.
shape
,
ctx
=
ctx
,
dtype
=
y_np
.
dtype
)
if
env
.
TARGET
==
"tsim"
:
simulator
.
tsim_init
(
"libvta_hw"
)
f
(
x_nd
,
y_nd
)
np
.
testing
.
assert_equal
(
y_np
,
y_nd
.
asnumpy
())
...
...
@@ -126,6 +130,10 @@ def test_padded_load():
:]
=
x_np
x_nd
=
tvm
.
nd
.
array
(
x_np
,
ctx
)
y_nd
=
tvm
.
nd
.
empty
(
y_np
.
shape
,
ctx
=
ctx
,
dtype
=
y_np
.
dtype
)
if
env
.
TARGET
==
"tsim"
:
simulator
.
tsim_init
(
"libvta_hw"
)
f
(
x_nd
,
y_nd
)
np
.
testing
.
assert_equal
(
y_np
,
y_nd
.
asnumpy
())
...
...
@@ -197,6 +205,9 @@ def test_gemm():
y_np
=
np
.
right_shift
(
y_np
,
8
)
y_np
=
np
.
clip
(
y_np
,
0
,
(
1
<<
(
env
.
INP_WIDTH
-
1
))
-
1
)
.
astype
(
y
.
dtype
)
if
env
.
TARGET
==
"tsim"
:
simulator
.
tsim_init
(
"libvta_hw"
)
if
env
.
TARGET
==
"sim"
:
simulator
.
clear_stats
()
f
(
x_nd
,
w_nd
,
y_nd
)
...
...
@@ -351,6 +362,10 @@ def test_alu():
a_nd
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
res_nd
=
tvm
.
nd
.
array
(
np
.
zeros
((
m
,
n
,
env
.
BATCH
,
env
.
BLOCK_OUT
))
.
astype
(
res
.
dtype
),
ctx
)
if
env
.
TARGET
==
"tsim"
:
simulator
.
tsim_init
(
"libvta_hw"
)
if
use_imm
:
f
(
a_nd
,
res_nd
)
else
:
...
...
@@ -420,6 +435,10 @@ def test_relu():
a_nd
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
res_nd
=
tvm
.
nd
.
array
(
np
.
zeros
((
m
,
n
,
env
.
BATCH
,
env
.
BLOCK_OUT
))
.
astype
(
res
.
dtype
),
ctx
)
if
env
.
TARGET
==
"tsim"
:
simulator
.
tsim_init
(
"libvta_hw"
)
f
(
a_nd
,
res_nd
)
np
.
testing
.
assert_equal
(
res_np
,
res_nd
.
asnumpy
())
...
...
@@ -479,6 +498,10 @@ def test_shift_and_scale():
a_nd
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
res_nd
=
tvm
.
nd
.
array
(
np
.
zeros
((
m
,
n
,
env
.
BATCH
,
env
.
BLOCK_OUT
))
.
astype
(
res
.
dtype
),
ctx
)
if
env
.
TARGET
==
"tsim"
:
simulator
.
tsim_init
(
"libvta_hw"
)
f
(
a_nd
,
res_nd
)
np
.
testing
.
assert_equal
(
res_np
,
res_nd
.
asnumpy
())
...
...
@@ -503,11 +526,12 @@ if __name__ == "__main__":
print
(
"Load/store test"
)
test_save_load_out
()
print
(
"Padded load test"
)
#
test_padded_load()
test_padded_load
()
print
(
"GEMM test"
)
test_gemm
()
test_alu
()
print
(
"ALU test"
)
test_alu
()
print
(
"Relu test"
)
test_relu
()
print
(
"Shift and scale"
)
test_shift_and_scale
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment