Commit df6fcc50 by Tianqi Chen Committed by GitHub

[CODEGEN] Refactor common codegen, Verilog Codegen (#74)

* [CODEGEN] Refactor common codegen, Verilog Codegen

* fix make

* fix mk

* update enable signal

* change function name to at neg edge

* Move test to correct place
parent 9ebb57b3
Subproject commit 7efe0366e93c053d558415b72f9fe3f6545eb721 Subproject commit ce80d58741688b200f498fed8c7b0ea33e0516c8
...@@ -148,6 +148,15 @@ struct IntSetNode : public Node { ...@@ -148,6 +148,15 @@ struct IntSetNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
}; };
/*!
* \brief Detect if e can be rewritten as e = base + var * coeff
* Where coeff and base are invariant of var.
*
* \return [base, coeff] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(Expr e, Var var);
/*! /*!
* \brief Find an symbolic integer set that contains all possible values of * \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables. * e given the domain of each iteration variables.
......
...@@ -19,6 +19,19 @@ using ::tvm::Node; ...@@ -19,6 +19,19 @@ using ::tvm::Node;
using ::tvm::NodeRef; using ::tvm::NodeRef;
using ::tvm::AttrVisitor; using ::tvm::AttrVisitor;
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public NodeRef { \
public: \
TypeName() {} \
explicit TypeName(std::shared_ptr<Node> n) : NodeRef(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
}; \
/*! /*!
* \brief save the node as well as all the node it depends on as json. * \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object * This can be used to serialize any TVM object
......
...@@ -35,7 +35,6 @@ struct ChannelNode : public Node { ...@@ -35,7 +35,6 @@ struct ChannelNode : public Node {
Var handle_var; Var handle_var;
/*! \brief default data type in read/write */ /*! \brief default data type in read/write */
Type dtype; Type dtype;
// visit all attributes // visit all attributes
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("handle_var", &handle_var); v->Visit("handle_var", &handle_var);
......
...@@ -103,10 +103,16 @@ constexpr const char* extern_op_scope = "extern_op_scope"; ...@@ -103,10 +103,16 @@ constexpr const char* extern_op_scope = "extern_op_scope";
// Pipeline related attributes // Pipeline related attributes
/*! \brief channel read scope */ /*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope"; constexpr const char* channel_read_scope = "channel_read_scope";
/*! \brief Advance step of channel after end of scope */
constexpr const char* channel_read_advance = "channel_read_advance";
/*! \brief channel write scope */ /*! \brief channel write scope */
constexpr const char* channel_write_scope = "channel_write_scope"; constexpr const char* channel_write_scope = "channel_write_scope";
/*! \brief pipeline module scope */ /*! \brief Advance step of channel after end of scope */
constexpr const char* channel_write_advance = "channel_write_advance";
/*! \brief pipeline stage scope, implies always execution */
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope"; constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
/*! \brief pipeline execution scope, implies the scope can be pipelined. */
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
} // namespace attr } // namespace attr
/*! \brief namespace of TVM Intrinsic functions */ /*! \brief namespace of TVM Intrinsic functions */
......
...@@ -56,6 +56,14 @@ bool VerifySSA(const Stmt& ir); ...@@ -56,6 +56,14 @@ bool VerifySSA(const Stmt& ir);
bool HasSideEffect(const Expr& e); bool HasSideEffect(const Expr& e);
/*! /*!
* \brief Whether e expression used var.
* \param e The expression to be checked.
* \param v The variable.
* \return Whether e uses v.
*/
bool ExprUseVar(const Expr& e, const Var& v);
/*!
* \brief Convert a IR node to be SSA form. * \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted. * \param stmt The source statement to be converted.
* \return The converted form. * \return The converted form.
...@@ -115,9 +123,17 @@ Stmt RemoveNoOp(Stmt stmt); ...@@ -115,9 +123,17 @@ Stmt RemoveNoOp(Stmt stmt);
/*! /*!
* \brief Split statement into pipeine stages. * \brief Split statement into pipeine stages.
* \param stmt The stmt to be splitted * \param stmt The stmt to be splitted
* \param split_load Whether split load into its own stage.
* \return Transformed stmt.
*/
Stmt SplitPipeline(Stmt stmt, bool split_load);
/*!
* \brief Narrow channel access to smaller range.
* \param stmt The stmt to do access rewriting.
* \return Transformed stmt. * \return Transformed stmt.
*/ */
Stmt SplitPipeline(Stmt stmt); Stmt NarrowChannelAccess(Stmt stmt);
/*! /*!
* \brief unroll the constant loops * \brief unroll the constant loops
......
...@@ -5,6 +5,7 @@ from __future__ import absolute_import ...@@ -5,6 +5,7 @@ from __future__ import absolute_import
import ctypes import ctypes
import sys import sys
import traceback
from numbers import Number, Integral from numbers import Number, Integral
from .._base import _LIB, check_call from .._base import _LIB, check_call
...@@ -46,7 +47,14 @@ def convert_to_tvm_func(pyfunc): ...@@ -46,7 +47,14 @@ def convert_to_tvm_func(pyfunc):
""" ctypes function """ """ ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)] pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
rv = local_pyfunc(*pyargs) # pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
_LIB.TVMAPISetLastError(c_str(msg))
return -1
if rv is not None: if rv is not None:
if isinstance(rv, tuple): if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one reurn value") raise ValueError("PackedFunction can only support one reurn value")
......
...@@ -4,10 +4,12 @@ from __future__ import absolute_import ...@@ -4,10 +4,12 @@ from __future__ import absolute_import
import subprocess import subprocess
import sys import sys
import os import os
import ctypes
from .. import _api_internal from .. import _api_internal
from .._base import string_types from .._base import string_types
from .._ctypes._node import NodeBase, register_node from .._ctypes._node import NodeBase, register_node
from .._ctypes._function import register_func
from . import testing from . import testing
@register_node @register_node
...@@ -46,7 +48,7 @@ class VPISession(NodeBase): ...@@ -46,7 +48,7 @@ class VPISession(NodeBase):
def __getattr__(self, name): def __getattr__(self, name):
return _api_internal._vpi_SessGetHandleByName(self, name) return _api_internal._vpi_SessGetHandleByName(self, name)
def yield_until_posedge(self): def yield_until_next_cycle(self):
"""Yield until next posedge""" """Yield until next posedge"""
for f in self.yield_callbacks: for f in self.yield_callbacks:
f() f()
...@@ -120,7 +122,8 @@ def search_path(): ...@@ -120,7 +122,8 @@ def search_path():
"""Get the search directory.""" """Get the search directory."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
ver_path = [os.path.join(curr_path, '../../../verilog/')] ver_path = [os.path.join(curr_path, '../../../verilog/')]
ver_path += [os.path.join(curr_path, '../../../tests/verilog/')] ver_path += [os.path.join(curr_path, '../../../tests/verilog/unittest/')]
ver_path += [os.path.join(curr_path, '../../../tests/verilog/integration/')]
return ver_path return ver_path
...@@ -178,29 +181,41 @@ def compile_file(file_name, file_target, options=None): ...@@ -178,29 +181,41 @@ def compile_file(file_name, file_target, options=None):
raise ValueError("Compilation error:\n%s" % out) raise ValueError("Compilation error:\n%s" % out)
def session(file_name): def session(file_names, codes=None):
"""Create a new iverilog session by compile the file. """Create a new iverilog session by compile the file.
Parameters Parameters
---------- ----------
file_name : str or list of str file_names : str or list of str
The name of the file The name of the file
codes : str or list of str
The code in str.
Returns Returns
------- -------
sess : VPISession sess : VPISession
The created session. The created session.
""" """
if isinstance(file_name, string_types): if isinstance(file_names, string_types):
file_name = [file_name] file_names = [file_names]
for name in file_name: path = testing.tempdir()
if codes:
if isinstance(codes, (list, tuple)):
codes = '\n'.join(codes)
fcode = path.relpath("temp_code.v")
with open(fcode, "w") as out_file:
out_file.write(codes)
file_names.append(fcode)
for name in file_names:
if not os.path.exists(name): if not os.path.exists(name):
raise ValueError("Cannot find file %s" % name) raise ValueError("Cannot find file %s" % name)
path = testing.tempdir() target = path.relpath(os.path.basename(file_names[0].rsplit(".", 1)[0]))
target = path.relpath(os.path.basename(file_name[0].rsplit(".", 1)[0])) compile_file(file_names, target)
compile_file(file_name, target)
vpi_path = _find_vpi_path() vpi_path = _find_vpi_path()
cmd = ["vvp"] cmd = ["vvp"]
...@@ -243,3 +258,43 @@ def session(file_name): ...@@ -243,3 +258,43 @@ def session(file_name):
sess.proc = proc sess.proc = proc
sess.execpath = path sess.execpath = path
return sess return sess
@register_func
def tvm_callback_verilog_simulator(code, *args):
"""Callback by TVM runtime to invoke verilog simulator
Parameters
----------
code : str
The verilog code to be simulated
args : list
Additional arguments to be set.
"""
libs = [
find_file("tvm_vpi_mmap.v")
]
sess = session(libs, code)
for i, value in enumerate(args):
vpi_h = sess.main["tvm_arg%d" % i]
if isinstance(value, ctypes.c_void_p):
int_value = int(value.value)
elif isinstance(value, int):
int_value = value
else:
raise ValueError(
"Do not know how to handle value type %s" % type(value))
vpi_h.put_int(int_value)
rst = sess.main.rst
done = sess.main.done
# start driving
rst.put_int(1)
sess.yield_until_next_cycle()
rst.put_int(0)
sess.yield_until_next_cycle()
while not done.get_int():
sess.yield_until_next_cycle()
sess.yield_until_next_cycle()
sess.shutdown()
...@@ -26,6 +26,11 @@ TVM_REGISTER_API(_arith_EvalModular) ...@@ -26,6 +26,11 @@ TVM_REGISTER_API(_arith_EvalModular)
*ret = EvalModular(args[0], Map<Var, IntSet>()); *ret = EvalModular(args[0], Map<Var, IntSet>());
}); });
TVM_REGISTER_API(_arith_DetectLinearEquation)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectLinearEquation(args[0], args[1]);
});
TVM_REGISTER_API(_arith_DeduceBound) TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], *ret = DeduceBound(args[0], args[1],
......
...@@ -63,6 +63,7 @@ REGISTER_PASS1(CanonicalSimplify); ...@@ -63,6 +63,7 @@ REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten); REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop); REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(ExprUseVar);
REGISTER_PASS2(UnrollLoop); REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync); REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI); REGISTER_PASS4(MakeAPI);
...@@ -71,7 +72,7 @@ REGISTER_PASS1(LiftAllocate); ...@@ -71,7 +72,7 @@ REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread); REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition); REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp); REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS1(SplitPipeline); REGISTER_PASS2(SplitPipeline);
REGISTER_PASS1(NarrowChannelAccess);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include "./compute_expr.h"
namespace tvm {
namespace arith {
using namespace ir;
// Linear equation, the components can be undefined.
struct LinearEqEntry {
Expr base;
Expr coeff;
};
class LinearEqDetector
: public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
public:
explicit LinearEqDetector(Var var)
: var_(var) {}
Array<Expr> Detect(const Expr& e) {
LinearEqEntry ret = VisitExpr(e, e);
if (fail_) return Array<Expr>();
if (!ret.base.defined()) {
ret.base = make_zero(var_.type());
}
if (!ret.coeff.defined()) {
ret.coeff = make_zero(var_.type());
}
return Array<Expr>{ret.base, ret.coeff};
}
LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
LinearEqEntry ret;
ret.base = AddCombine(a.base, b.base);
ret.coeff = AddCombine(a.coeff, b.coeff);
return ret;
}
LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
if (a.coeff.defined()) {
std::swap(a, b);
}
if (a.coeff.defined()) {
fail_ = true;
return LinearEqEntry();
}
LinearEqEntry ret;
ret.base = MulCombine(a.base, b.base);
ret.coeff = MulCombine(a.base, b.coeff);
return ret;
}
LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final {
LinearEqEntry ret;
if (op == var_.get()) {
ret.coeff = make_const(op->type, 1);
} else {
ret.base = e;
}
return ret;
}
LinearEqEntry VisitExprDefault_(const Node* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
fail_ = true;
return LinearEqEntry();
} else {
LinearEqEntry ret;
ret.base = e;
return ret;
}
}
private:
Var var_;
bool fail_{false};
// Combine by add
Expr AddCombine(Expr a, Expr b) {
if (!a.defined()) return b;
if (!b.defined()) return a;
return ComputeExpr<Add>(a, b);
}
Expr MulCombine(Expr a, Expr b) {
if (!a.defined()) return a;
if (!b.defined()) return b;
return ComputeExpr<Mul>(a, b);
}
};
Array<Expr> DetectLinearEquation(Expr e, Var var) {
return LinearEqDetector(var).Detect(e);
}
} // namespace arith
} // namespace tvm
...@@ -163,6 +163,7 @@ inline bool MatchPoint(const IntSet& a, ...@@ -163,6 +163,7 @@ inline bool MatchPoint(const IntSet& a,
} }
IntSet Union(const Array<IntSet>& sets) { IntSet Union(const Array<IntSet>& sets) {
if (sets.size() == 0) return IntSet::nothing();
if (sets.size() == 1) return sets[0]; if (sets.size() == 1) return sets[0];
Interval x = sets[0].cover_interval().as<IntervalSet>()->i; Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < sets.size(); ++i) { for (size_t i = 1; i < sets.size(); ++i) {
......
...@@ -18,11 +18,8 @@ void CodeGenC::Init(bool output_ssa) { ...@@ -18,11 +18,8 @@ void CodeGenC::Init(bool output_ssa) {
void CodeGenC::InitFuncState(LoweredFunc f) { void CodeGenC::InitFuncState(LoweredFunc f) {
alloc_storage_scope_.clear(); alloc_storage_scope_.clear();
name_alloc_map_.clear();
ssa_assign_map_.clear();
var_idmap_.clear();
handle_data_type_.clear(); handle_data_type_.clear();
scope_mark_.clear(); CodeGenSourceBase::ClearFuncState();
} }
void CodeGenC::AddFunction(LoweredFunc f) { void CodeGenC::AddFunction(LoweredFunc f) {
// clear previous generated state. // clear previous generated state.
...@@ -67,30 +64,6 @@ std::string CodeGenC::Finish() { ...@@ -67,30 +64,6 @@ std::string CodeGenC::Finish() {
return stream.str(); return stream.str();
} }
std::string CodeGenC::SSAGetID(std::string src, Type t) {
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
if (scope_mark_.at(it->second.scope_id)) {
return it->second.vid;
}
}
this->PrintIndent();
SSAEntry e;
e.vid = GetUniqueName("_");
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
ssa_assign_map_[src] = e;
if (src.length() > 3 &&
src[0] == '(' && src[src.length() - 1] == ')') {
src = src.substr(1, src.length() - 2);
}
PrintType(t, stream);
stream << ' ' << e.vid << " = " << src << ";\n";
return e.vid;
}
void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
if (print_ssa_form_) { if (print_ssa_form_) {
std::ostringstream temp; std::ostringstream temp;
...@@ -101,88 +74,17 @@ void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) ...@@ -101,88 +74,17 @@ void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
} }
} }
std::string CodeGenC::GetUniqueName(std::string prefix) { void CodeGenC::PrintSSAAssign(
auto it = name_alloc_map_.find(prefix); const std::string& target, const std::string& src, Type t) {
if (it != name_alloc_map_.end()) { PrintType(t, stream);
while (true) { stream << ' ' << target << " = ";
std::ostringstream os; if (src.length() > 3 &&
os << prefix << (++it->second); src[0] == '(' && src[src.length() - 1] == ')') {
std::string name = os.str(); stream << src.substr(1, src.length() - 2);
if (name_alloc_map_.count(name) == 0) {
prefix = name;
break;
}
}
}
name_alloc_map_[prefix] = 0;
return prefix;
}
std::string CodeGenC::AllocVarID(const Variable* v) {
CHECK(!var_idmap_.count(v))
<< "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint;
for (size_t i = 0; i < key.size(); ++i) {
if (key[i] == '.') key[i] = '_';
}
std::string vid = GetUniqueName(key);
var_idmap_[v] = vid;
return vid;
}
std::string CodeGenC::GetVarID(const Variable* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint;
return it->second;
}
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) return false;
return it->second == t;
}
void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) {
handle_data_type_[buf_var] = t;
} else { } else {
CHECK(it->second == t) stream << src;
<< "conflicting buf var type";
} }
} stream << ";\n";
void CodeGenC::PrintIndent() {
for (int i = 0; i < this->indent; ++i) {
this->stream << ' ';
}
}
void CodeGenC::MarkConst(std::string vid) {
if (print_ssa_form_) {
auto it = ssa_assign_map_.find(vid);
if (it == ssa_assign_map_.end()) {
SSAEntry e;
e.vid = vid;
e.scope_id = 0;
ssa_assign_map_[vid] = e;
} else {
CHECK_EQ(it->second.vid, vid);
}
}
}
int CodeGenC::BeginScope() {
int sid = static_cast<int>(scope_mark_.size());
scope_mark_.push_back(true);
indent += 2;
return sid;
}
void CodeGenC::EndScope(int scope_id) {
scope_mark_[scope_id] = false;
indent -= 2;
} }
// Print a reference expression to a buffer. // Print a reference expression to a buffer.
...@@ -229,6 +131,23 @@ void CodeGenC::PrintBufferRef( ...@@ -229,6 +131,23 @@ void CodeGenC::PrintBufferRef(
} }
} }
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) return false;
return it->second == t;
}
void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) {
handle_data_type_[buf_var] = t;
} else {
CHECK(it->second == t)
<< "conflicting buf var type";
}
}
void CodeGenC::PrintVecElemLoad(const std::string& vec, void CodeGenC::PrintVecElemLoad(const std::string& vec,
Type t, int i, Type t, int i,
std::ostream& os) { // NOLINT(*) std::ostream& os) { // NOLINT(*)
...@@ -564,29 +483,32 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) { ...@@ -564,29 +483,32 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes(); int lanes = op->type.lanes();
std::string svalue = GetUniqueName("_");
// delcare type.
this->PrintIndent();
this->PrintType(op->type, stream);
stream << ' ' << svalue;
if (op->type.lanes() == 1) { if (op->type.lanes() == 1) {
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os); stream << " = ";
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, stream);
stream << ";\n";
} else { } else {
Expr base; Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) { if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
this->PrintVecLoad(op->buffer_var.get(), op->type, base, os); stream << " = ";
this->PrintVecLoad(op->buffer_var.get(), op->type, base, stream);
stream << ";\n";
} else { } else {
// Load elements seperately // Load elements seperately
stream << ";\n";
std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type()); std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
std::string svalue = GetUniqueName("_");
{
// delcare type.
this->PrintIndent();
this->PrintType(op->type, stream);
stream << ' ' << svalue << ";\n";
}
std::string vid = GetVarID(op->buffer_var.get()); std::string vid = GetVarID(op->buffer_var.get());
Type elem_type = op->type.element_of(); Type elem_type = op->type.element_of();
for (int i = 0; i < lanes; ++i) { for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp; std::ostringstream value_temp;
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
value_temp << "(("; value_temp << "((";
PrintType(elem_type, os); PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')'; value_temp << "*)" << vid << ')';
} else { } else {
value_temp << vid; value_temp << vid;
...@@ -596,9 +518,9 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -596,9 +518,9 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
value_temp << ']'; value_temp << ']';
PrintVecElemStore(svalue, op->type, i, value_temp.str()); PrintVecElemStore(svalue, op->type, i, value_temp.str());
} }
os << svalue;
} }
} }
os << svalue;
} }
void CodeGenC::VisitStmt_(const Store* op) { void CodeGenC::VisitStmt_(const Store* op) {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "./codegen_source_base.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -25,7 +26,8 @@ using namespace ir; ...@@ -25,7 +26,8 @@ using namespace ir;
*/ */
class CodeGenC : class CodeGenC :
public ExprFunctor<void(const Expr&, std::ostream&)>, public ExprFunctor<void(const Expr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> { public StmtFunctor<void(const Stmt&)>,
public CodeGenSourceBase {
public: public:
/*! /*!
* \brief Initialize the code generator. * \brief Initialize the code generator.
...@@ -64,26 +66,6 @@ class CodeGenC : ...@@ -64,26 +66,6 @@ class CodeGenC :
PrintExpr(n, os); PrintExpr(n, os);
return os.str(); return os.str();
} }
/*! \brief print the current indented value */
void PrintIndent();
/*!
* \brief Register constant value appeared in expresion tree
* This avoid generated a ssa id for each appearance of the value
* \param value The constant value.
*/
void MarkConst(std::string value);
/*!
* \brief Allocate a variable name for a newly defined var.
* \param v The variable.
* \return the variable name.
*/
std::string AllocVarID(const Variable* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the variable name.
*/
std::string GetVarID(const Variable* v) const;
// The following parts are overloadable print operations. // The following parts are overloadable print operations.
/*! /*!
* \brief Initialize codegen state for generating f. * \brief Initialize codegen state for generating f.
...@@ -164,44 +146,13 @@ class CodeGenC : ...@@ -164,44 +146,13 @@ class CodeGenC :
virtual void PrintVecElemStore( virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value); const std::string& vec, Type t, int i, const std::string& value);
protected: protected:
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
// print reference to a buffer as type t in index. // print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer, void PrintBufferRef(const Variable* buffer,
Type t, Expr index, Type t, Expr index,
std::ostream& os); // NOLINT(*) std::ostream& os); // NOLINT(*)
/*! /*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
* \param src The source expression
* \param t The type of the expression.
*/
std::string SSAGetID(std::string src, Type t);
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
*/
int BeginScope();
/*!
* \brief mark the end of an old scope.
* \param scope_id The scope id to be ended.
*/
void EndScope(int scope_id);
/*!
* \brief If buffer is allocated as type t. * \brief If buffer is allocated as type t.
* \param buf_var The buffer variable. * \param buf_var The buffer variable.
* \param t The type to be checked. * \param t The type to be checked.
...@@ -213,30 +164,17 @@ class CodeGenC : ...@@ -213,30 +164,17 @@ class CodeGenC :
* \param t The type to be checked. * \param t The type to be checked.
*/ */
void RegisterHandleType(const Variable* buf_var, Type t); void RegisterHandleType(const Variable* buf_var, Type t);
/*! // override
* \brief Get the storage scope of buf_var. void PrintSSAAssign(
* \param buf_var The buf_var to be queryed. const std::string& target, const std::string& src, Type t) final;
* \return The storage scope.
*/
std::string GetStorageScope(const Variable* buf_var) const;
/*! \brief the storage scope of allocation */ /*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_; std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
private: private:
/*! \brief whether to print in SSA form */ /*! \brief whether to print in SSA form */
bool print_ssa_form_{true}; bool print_ssa_form_{true};
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */ /*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_; std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
int indent{0};
}; };
} // namespace codegen } // namespace codegen
......
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_source_base.cc
*/
#include "./codegen_source_base.h"
namespace tvm {
namespace codegen {
void CodeGenSourceBase::ClearFuncState() {
name_alloc_map_.clear();
ssa_assign_map_.clear();
var_idmap_.clear();
scope_mark_.clear();
}
std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
for (size_t i = 0; i < prefix.size(); ++i) {
if (prefix[i] == '.') prefix[i] = '_';
}
auto it = name_alloc_map_.find(prefix);
if (it != name_alloc_map_.end()) {
while (true) {
std::ostringstream os;
os << prefix << (++it->second);
std::string name = os.str();
if (name_alloc_map_.count(name) == 0) {
prefix = name;
break;
}
}
}
name_alloc_map_[prefix] = 0;
return prefix;
}
std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
if (scope_mark_.at(it->second.scope_id)) {
return it->second.vid;
}
}
SSAEntry e;
e.vid = GetUniqueName("_");
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
ssa_assign_map_[src] = e;
this->PrintIndent();
PrintSSAAssign(e.vid, src, t);
return e.vid;
}
std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
CHECK(!var_idmap_.count(v))
<< "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint;
std::string vid = GetUniqueName(key);
var_idmap_[v] = vid;
return vid;
}
std::string CodeGenSourceBase::GetVarID(const Variable* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint;
return it->second;
}
void CodeGenSourceBase::PrintIndent() {
for (int i = 0; i < indent_; ++i) {
this->stream << ' ';
}
}
void CodeGenSourceBase::MarkConst(std::string vid) {
auto it = ssa_assign_map_.find(vid);
if (it == ssa_assign_map_.end()) {
SSAEntry e;
e.vid = vid;
e.scope_id = 0;
ssa_assign_map_[vid] = e;
} else {
CHECK_EQ(it->second.vid, vid);
}
}
int CodeGenSourceBase::BeginScope() {
int sid = static_cast<int>(scope_mark_.size());
scope_mark_.push_back(true);
indent_ += 2;
return sid;
}
void CodeGenSourceBase::EndScope(int scope_id) {
scope_mark_[scope_id] = false;
indent_ -= 2;
}
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_source_base.h
* \brief Common utilities to source code in text form.
*/
#ifndef TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
#define TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
#include <tvm/ir.h>
#include <tvm/codegen.h>
#include <string>
#include <vector>
#include <unordered_map>
namespace tvm {
namespace codegen {
/*!
* \brief A base class to generate source code.
* Contains helper utilities to generate nest and ssa form.
*/
class CodeGenSourceBase {
public:
/*!
* \brief Register constant value appeared in expresion tree
* This avoid generated a ssa id for each appearance of the value
* \param value The constant value.
*/
void MarkConst(std::string value);
protected:
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id, used to check if this entry is invalid. */
int scope_id;
};
/*! \brief Clear the states that might relates to function generation */
void ClearFuncState();
/*! \brief print the current indented value */
void PrintIndent();
/*!
* \brief Allocate a variable name for a newly defined var.
* \param v The variable.
* \return the variable name.
*/
std::string AllocVarID(const Variable* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the variable name.
*/
std::string GetVarID(const Variable* v) const;
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
* \param src The source expression
* \param t The type of the expression.
*/
std::string SSAGetID(std::string src, Type t);
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
*/
int BeginScope();
/*!
* \brief mark the end of an old scope.
* \param scope_id The scope id to be ended.
*/
void EndScope(int scope_id);
/*!
* \brief Print assignment of src to the id in ssa entry.
* \param target id of target variable.
* \param src The source expression.
* \param t The type of target.
*/
virtual void PrintSSAAssign(
const std::string& target, const std::string& src, Type t) = 0;
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
private:
/*! \brief assignment map of ssa */
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
int indent_{0};
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_verilog.cc
*/
#include <tvm/ir_pass.h>
#include <cctype>
#include <sstream>
#include <iostream>
#include "./codegen_verilog.h"
#include "../../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
namespace verilog {
using namespace ir;
void CodeGenVerilog::Init() {
stream << "`include \"tvm_marcos.v\"\n\n";
}
void CodeGenVerilog::InitFuncState(LoweredFunc f) {
CodeGenSourceBase::ClearFuncState();
cmap_.clear();
tvm_vpi_modules_.clear();
done_sigs_.clear();
}
void CodeGenVerilog::AddFunction(LoweredFunc f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
GetUniqueName("_");
GetUniqueName("rst");
GetUniqueName("clk");
GetUniqueName("done");
GetUniqueName("enable");
GetUniqueName("all_input_valid");
// print out function body.
int func_scope = this->BeginScope();
// Stich things up.
stream << "module " << f->name << "(\n";
PrintDecl("clk", kInput, Bool(1), "");
stream << ",\n";
PrintDecl("rst", kInput, Bool(1), "");
VerilogFuncEntry entry;
for (size_t i = 0; i < f->args.size(); ++i) {
stream << ",\n";
Var v = f->args[i];
std::string vid = AllocVarID(v.get());
entry.arg_ids.push_back(vid);
entry.arg_types.push_back(v.type());
PrintDecl(vid, kInput, v.type(), "");
}
stream << ",\n";
PrintDecl("done", kOutput, Bool(1), "");
stream << "\n);\n";
this->CodeGen(MakePipeline(f));
PrintAssignAnd("done", done_sigs_);
this->EndScope(func_scope);
this->PrintIndent();
stream << "endmodule\n";
entry.vpi_modules = std::move(tvm_vpi_modules_);
functions_[f->name] = entry;
}
std::string VerilogCodeGenModule::AppendSimMain(
const std::string& func_name) const {
// Add main function for simulator hook
const VerilogFuncEntry& entry = fmap.at(func_name);
std::ostringstream stream;
stream << code;
stream << "\n"
<< "module main();\n"
<< " `TVM_DEFINE_TEST_SIGNAL(clk, rst)\n";
// print out function body.
std::vector<std::string> sargs;
for (size_t i = 0; i < entry.arg_types.size(); ++i) {
Type t = entry.arg_types[i];
std::ostringstream sarg;
sarg << "tvm_arg" << i;
std::string vid = sarg.str();
stream << " reg";
if (t.bits() > 1) {
stream << "[" << t.bits() - 1 << ":0]";
}
stream << " " << vid << ";\n";
sargs.push_back(vid);
}
stream << " wire done;\n";
stream << "\n " << func_name << " dut(\n"
<< " .clk(clk),\n"
<< " .rst(rst),\n";
for (size_t i = 0; i < entry.arg_ids.size(); ++i) {
stream << " ." << entry.arg_ids[i] << '('
<< sargs[i] << "),\n";
}
stream << " .done(done)\n"
<< " );\n";
stream << " initial begin\n"
<< " $tvm_session(clk";
for (const std::string& mvpi : entry.vpi_modules) {
stream << ", dut." << mvpi;
}
stream << ");\n"
<< " end\n";
stream << "endmodule\n";
return stream.str();
}
VerilogCodeGenModule CodeGenVerilog::Finish() {
VerilogCodeGenModule m;
m.code = stream.str();
m.fmap = std::move(functions_);
return m;
}
void CodeGenVerilog::PrintDecl(
const std::string& vid, VerilogVarType vtype, Type dtype,
const char* suffix, bool indent) {
if (indent) PrintIndent();
switch (vtype) {
case kReg: stream << "reg "; break;
case kWire: stream << "wire "; break;
case kInput: stream << "input "; break;
case kOutput: stream << "output "; break;
default: LOG(FATAL) << "unsupported vtype=" << vtype;
}
int bits = dtype.bits();
// bits for handle type.
if (dtype.is_handle()) {
bits = 64;
}
if (bits > 1) {
stream << "[" << bits - 1 << ":0] ";
}
stream << vid << suffix;
}
void CodeGenVerilog::PrintSSAAssign(
const std::string& target, const std::string& src, Type t) {
// add target to list of declaration.
PrintDecl(target, kWire, t, ";\n", false);
PrintAssign(target, src);
}
void CodeGenVerilog::PrintAssign(
const std::string& target, const std::string& src) {
PrintIndent();
stream << "assign " << target << " = ";
if (src.length() > 3 &&
src[0] == '(' && src[src.length() - 1] == ')') {
stream << src.substr(1, src.length() - 2);
} else {
stream << src;
}
stream << ";\n";
}
void CodeGenVerilog::PrintAssignAnd(
const std::string& target, const std::vector<std::string>& conds) {
if (conds.size() != 0) {
std::ostringstream os_valid;
for (size_t i = 0; i < conds.size(); ++i) {
if (i != 0) os_valid << " && ";
os_valid << conds[i];
}
PrintAssign(target, os_valid.str());
} else {
PrintAssign(target, "1");
}
}
void CodeGenVerilog::PrintLine(const std::string& line) {
PrintIndent();
stream << line << '\n';
}
VerilogValue CodeGenVerilog::MakeBinary(Type t,
VerilogValue a,
VerilogValue b,
const char *opstr) {
CHECK_EQ(t.lanes(), 1)
<< "Do not yet support vectorized op";
CHECK(t.is_int() || t.is_uint())
<< "Only support integer operations";
std::ostringstream os;
os << a.vid << ' ' << opstr << ' '<< b.vid;
return GetSSAValue(os.str(), t);
}
template<typename T>
inline VerilogValue IntConst(const T* op, CodeGenVerilog* p) {
if (op->type.bits() <= 32 && op->type.lanes() == 1) {
std::ostringstream temp;
temp << op->value;
p->MarkConst(temp.str());
return VerilogValue(temp.str(), kConst, op->type);
} else {
LOG(FATAL) << "Do not support integer constant type " << op->type;
return VerilogValue();
}
}
VerilogValue CodeGenVerilog::VisitExpr_(const IntImm *op) {
return IntConst(op, this);
}
VerilogValue CodeGenVerilog::VisitExpr_(const UIntImm *op) {
return IntConst(op, this);
}
VerilogValue CodeGenVerilog::VisitExpr_(const FloatImm *op) {
LOG(FATAL) << "Donot support float constant in Verilog";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const StringImm *op) {
LOG(FATAL) << "Donot support string constant in Verilog";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const Cast *op) {
LOG(FATAL) << "Type cast not supported";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const Variable *op) {
return VerilogValue(GetVarID(op), kReg, op->type);
}
VerilogValue CodeGenVerilog::VisitExpr_(const Add *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "+");
}
VerilogValue CodeGenVerilog::VisitExpr_(const Sub *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "-");
}
VerilogValue CodeGenVerilog::VisitExpr_(const Mul *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "*");
}
VerilogValue CodeGenVerilog::VisitExpr_(const Div *op) {
int shift;
if (is_const_power_of_two_integer(op->b, &shift) &&
(op->type.is_int() || op->type.is_uint())) {
return MakeValue(op->a >> make_const(op->b.type(), shift));
} else {
LOG(FATAL) << "do not support synthesis division";
}
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const Mod *op) {
LOG(FATAL) << "do not support synthesis Mod";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const Min *op) {
LOG(FATAL) << "not supported";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const Max *op) {
LOG(FATAL) << "not supported";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const EQ *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "==");
}
VerilogValue CodeGenVerilog::VisitExpr_(const NE *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "!=");
}
VerilogValue CodeGenVerilog::VisitExpr_(const LT *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "<");
}
VerilogValue CodeGenVerilog::VisitExpr_(const LE *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "<=");
}
VerilogValue CodeGenVerilog::VisitExpr_(const GT *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), ">");
}
VerilogValue CodeGenVerilog::VisitExpr_(const GE *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), ">=");
}
VerilogValue CodeGenVerilog::VisitExpr_(const And *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "&&");
}
VerilogValue CodeGenVerilog::VisitExpr_(const Or *op) {
return MakeBinary(op->type, MakeValue(op->a), MakeValue(op->b), "||");
}
VerilogValue CodeGenVerilog::VisitExpr_(const Not *op) {
VerilogValue value = MakeValue(op->a);
std::ostringstream os;
os << "(!" << value.vid << ")";
return GetSSAValue(os.str(), op->type);
}
VerilogValue CodeGenVerilog::VisitExpr_(const Call *op) {
if (op->is_intrinsic(Call::bitwise_and)) {
return MakeBinary(
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "&");
} else if (op->is_intrinsic(Call::bitwise_xor)) {
return MakeBinary(
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "^");
} else if (op->is_intrinsic(Call::bitwise_or)) {
return MakeBinary(
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "|");
} else if (op->is_intrinsic(Call::bitwise_not)) {
VerilogValue value = MakeValue(op->args[0]);
std::ostringstream os;
os << "(~" << value.vid << ")";
return GetSSAValue(os.str(), op->type);
} else if (op->is_intrinsic(Call::shift_left)) {
return MakeBinary(
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), "<<");
} else if (op->is_intrinsic(Call::shift_right)) {
return MakeBinary(
op->type, MakeValue(op->args[0]), MakeValue(op->args[1]), ">>");
} else {
LOG(FATAL) << "Cannot generate call type " << op->name;
return VerilogValue();
}
}
VerilogValue CodeGenVerilog::VisitExpr_(const Let* op) {
VerilogValue value = MakeValue(op->value);
CHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value.vid;
return value;
}
VerilogValue CodeGenVerilog::VisitExpr_(const Ramp* op) {
LOG(FATAL) << "Ramp: not supported ";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const Broadcast* op) {
LOG(FATAL) << "Broadcast: not supported ";
return VerilogValue();
}
VerilogValue CodeGenVerilog::VisitExpr_(const Select* op) {
LOG(FATAL) << "Select: not supported ";
return VerilogValue();
}
void CodeGenVerilog::CodeGen(const Pipeline& pipeline) {
// setup channel map.
for (auto kv : pipeline->channels) {
ChannelEntry e; e.block = kv.second;
cmap_[kv.first.get()] = e;
}
for (ComputeBlock stage : pipeline->stages) {
const Store* store = stage->body.as<Store>();
CHECK(store);
const Load* load = store->value.as<Load>();
if (load) {
MakeLoadToFIFO(stage, store, load);
} else {
MakeStore(stage, store);
}
}
for (const auto& kv : cmap_) {
MakeChannelUnit(kv.second);
}
}
CodeGenVerilog::SignalEntry
CodeGenVerilog::MakeLoop(const Array<Stmt>& loop) {
SignalEntry sig;
// do not use init signal for now.
std::string init = "0";
std::string lp_ready = GetUniqueName("lp_tmp_sig");
sig.ready = GetUniqueName("loop_ready");
sig.valid = GetUniqueName("loop_valid");
PrintLine("// loop logic");
PrintDecl(lp_ready, kWire, Bool(1));
PrintDecl(sig.ready, kWire, Bool(1));
std::string end_loop = lp_ready;
for (size_t i = loop.size(); i != 0; --i) {
const For* for_op = loop[i - 1].as<For>();
int bits = for_op->loop_var.type().bits();
VerilogValue min = MakeValue(for_op->min);
VerilogValue extent = MakeValue(for_op->extent);
CHECK(min.vtype == kConst && extent.vtype == kConst)
<< "Only support constant loop domain";
std::string vid = AllocVarID(for_op->loop_var.get());
std::string finish = GetUniqueName(vid + "_finish");
this->PrintIndent();
stream <<"`NONSTOP_LOOP(" << vid << ", " << bits << ", " << init
<< ", " << end_loop << ", " << finish
<< ", " << min.vid << ", " << extent.vid << ")\n";
end_loop = finish;
}
if (loop.size() != 0) {
std::string local_ready = GetUniqueName("lp_tmp_sig");
this->PrintIndent();
stream <<"`WRAP_LOOP_ONCE(" << init << ", " << sig.valid
<< ", " << sig.ready << ", " << end_loop << ", " << local_ready << ")\n";
PrintAssign(lp_ready, local_ready);
}
return sig;
}
void CodeGenVerilog::MakeStageInputs(
const ComputeBlock& block,
const std::string& enable,
std::string* out_all_input_valid) {
std::vector<SignalEntry> sigs;
sigs.push_back(MakeLoop(block->loop));
// Input data path.
PrintLine("// stage inputs");
for (auto kv : block->inputs) {
const Var& var = kv.first;
const StageInput& arg = kv.second;
std::string vid = AllocVarID(var.get());
this->PrintDecl(vid, kWire, var.type());
if (arg->input_type == kGlobalConst ||
arg->input_type == kLoopVar) {
PrintAssign(vid, GetVarID(arg->var.get()));
} else if (arg->input_type == kChannel) {
std::string vid_valid = GetUniqueName(vid + "_valid");
std::string vid_ready = GetUniqueName(vid + "_ready");
this->PrintDecl(vid_valid, kWire, Bool(1));
this->PrintDecl(vid_ready, kWire, Bool(1));
ChannelEntry* e = GetChannelInfo(arg->var.get());
// TODO(tqchen, thierry) add one cache here.
e->AssignPort("read_data", vid, var.type());
e->AssignPort("read_valid", vid_valid, Bool(1));
e->AssignPort("read_ready", vid_ready, Bool(1));
e->AssignPort("read_addr", "0", Int(1));
sigs.push_back(SignalEntry{vid_valid, vid_ready});
} else {
LOG(FATAL) << "Unknown input type";
}
}
PrintLine("// stage input stall");
std::string all_input_valid = GetUniqueName("all_input_valid");
this->PrintDecl(all_input_valid, kWire, Bool(1));
// forward all valid
std::vector<std::string> valid_conds;
for (const SignalEntry& e : sigs) {
if (e.valid.length() != 0) {
valid_conds.push_back(e.valid);
}
}
PrintAssignAnd(all_input_valid, valid_conds);
// input ready signal
for (size_t i = 0; i < sigs.size(); ++i) {
if (sigs[i].ready.length() == 0) continue;
std::vector<std::string> conds = {enable};
for (size_t j = 0; j < sigs.size(); ++j) {
if (j != i && sigs[j].valid.length() != 0) {
conds.push_back(sigs[j].valid);
}
}
PrintAssignAnd(sigs[i].ready, conds);
}
*out_all_input_valid = all_input_valid;
}
void CodeGenVerilog::MakeDelay(const std::string& dst,
const std::string& src,
Type dtype,
int delay,
const std::string& enable) {
PrintIndent();
stream << "`DELAY(" << dst << ", " << src << ", "
<< dtype.bits() << ", " << delay << ", " << enable << ")\n";
}
void CodeGenVerilog::MakeStore(const ComputeBlock& block,
const Store* store) {
std::string all_input_valid;
std::string enable = GetUniqueName("enable");
this->PrintDecl(enable, kWire, Bool(1));
MakeStageInputs(block, enable, &all_input_valid);
// Data path
PrintLine("// data path");
VerilogValue value = MakeValue(store->value);
VerilogValue index = MakeValue(store->index);
PrintLine("// control and retiming");
ChannelEntry* write_entry = GetChannelInfo(store->buffer_var.get());
// TODO(tqchen, thierry) add delay model from expression.a
int delay = 2;
std::string ch_name = write_entry->block->channel->handle_var->name_hint;
std::string write_addr = GetUniqueName(ch_name + ".write_addr");
std::string write_ready = GetUniqueName(ch_name + ".write_ready");
std::string write_valid = GetUniqueName(ch_name + ".write_valid");
std::string write_data = GetUniqueName(ch_name + ".write_data");
PrintDecl(write_addr, kWire, store->index.type());
PrintDecl(write_ready, kWire, Bool(1));
PrintDecl(write_valid, kWire, Bool(1));
PrintDecl(write_data, kWire, store->value.type());
MakeDelay(write_addr, index.vid, store->index.type(), delay, enable);
MakeDelay(write_data, value.vid, store->value.type(), delay, enable);
MakeDelay(write_valid, all_input_valid, Bool(1), delay, enable);
PrintAssign(enable, "!" + write_valid + " || " + write_ready);
write_entry->AssignPort("write_addr", write_addr, store->index.type());
write_entry->AssignPort("write_ready", write_ready, Bool(1));
write_entry->AssignPort("write_valid", write_valid, Bool(1));
write_entry->AssignPort("write_data", write_data, store->value.type());
// The triggers
for (size_t i = 0; i < block->triggers.size(); ++i) {
SignalTrigger trigger = block->triggers[i];
CHECK(trigger->predicate.type() == Bool(1));
ChannelEntry* trigger_ch = GetChannelInfo(trigger->channel_var.get());
std::string port = trigger_ch->SignalPortName(trigger->signal_index);
VerilogValue v = MakeValue(trigger->predicate);
// Assign constant trigger.
if (v.vtype == kConst) {
trigger_ch->AssignPort(port, v.vid, Bool(1));
} else {
// non-constant trigger
CHECK_EQ(trigger_ch, write_entry)
<< "Can only triggger conditional event at write channel";
std::string v_trigger = GetUniqueName(ch_name + "." + port);
MakeDelay(v_trigger, v.vid, Bool(1), delay, enable);
write_entry->AssignPort(port, v_trigger, Bool(1));
}
}
stream << "\n";
}
void CodeGenVerilog::MakeLoadToFIFO(const ComputeBlock& block,
const Store* store,
const Load* load) {
ChannelEntry* write_entry = GetChannelInfo(store->buffer_var.get());
ChannelEntry* load_entry = GetChannelInfo(load->buffer_var.get());
std::string all_input_valid;
std::string enable = GetUniqueName("enable");
this->PrintDecl(enable, kWire, Bool(1));
MakeStageInputs(block, enable, &all_input_valid);
// data path
PrintLine("// data path");
VerilogValue index = MakeValue(load->index);
// control and retiming
PrintLine("// control and retiming");
// TODO(tqchen, thierry) add delay model from expression
int delay = 1;
std::string read_ch_name = load_entry->block->channel->handle_var->name_hint;
std::string write_ch_name = write_entry->block->channel->handle_var->name_hint;
std::string read_addr = GetUniqueName(read_ch_name + ".read_addr");
std::string read_data = GetUniqueName(read_ch_name + ".read_data");
std::string read_valid = GetUniqueName(read_ch_name + ".read_valid");
std::string index_valid = GetUniqueName(read_ch_name + ".index_valid");
std::string write_ready = GetUniqueName(write_ch_name + ".write_ready");
std::string data_valid = GetUniqueName(read_ch_name + ".data_valid");
std::string valid_delay = GetUniqueName(read_ch_name + ".valid_delay");
PrintDecl(read_addr, kWire, load->index.type());
PrintDecl(read_data, kWire, load->type);
PrintDecl(read_valid, kWire, Bool(1));
PrintDecl(index_valid, kWire, Bool(1));
PrintDecl(data_valid, kWire, Bool(1));
MakeDelay(read_addr, index.vid, load->index.type(), delay, enable);
MakeDelay(index_valid, all_input_valid, Bool(1), delay, enable);
PrintAssignAnd(data_valid, {read_valid, index_valid});
// The read ports.
load_entry->AssignPort("read_addr", read_addr, load->index.type());
load_entry->AssignPort("read_data", read_data, load->type);
load_entry->AssignPort("read_valid", read_valid, Bool(1));
// The write ports.
write_entry->AssignPort("write_ready", write_ready, Bool(1));
write_entry->AssignPort("write_data", read_data, load->type);
write_entry->AssignPort("write_valid", valid_delay, Bool(1));
write_entry->AssignPort("write_addr", "0", Int(1));
// The not stall condition.
PrintAssignAnd(enable, {write_ready, read_valid});
// The ready signal
PrintIndent();
stream << "`BUFFER_READ_VALID_DELAY(" << valid_delay << ", " << data_valid
<< ", " << write_ready << ")\n";
// The triggers
for (size_t i = 0; i < block->triggers.size(); ++i) {
SignalTrigger trigger = block->triggers[i];
CHECK(trigger->predicate.type() == Bool(1));
ChannelEntry* trigger_ch = GetChannelInfo(trigger->channel_var.get());
std::string port = trigger_ch->SignalPortName(trigger->signal_index);
VerilogValue v = MakeValue(trigger->predicate);
// Assign constant trigger.
if (v.vtype == kConst) {
trigger_ch->AssignPort(port, v.vid, Bool(1));
} else {
// non-constant trigger
CHECK_EQ(trigger_ch, load_entry)
<< "Can only triggger conditional event at load channel";
std::string v_trigger = GetUniqueName(read_ch_name + "." + port);
MakeDelay(v_trigger, v.vid, Bool(1), delay, enable);
load_entry->AssignPort(port, v_trigger, Bool(1));
}
}
stream << "\n";
}
void CodeGenVerilog::MakeChannelUnit(const ChannelEntry& ch) {
if (ch.block->read_window == 0) {
// This is a memory map
MakeChannelMemMap(ch);
} else if (ch.block->read_window == 1 &&
ch.block->write_window == 1) {
MakeChannelFIFO(ch);
} else {
// general Buffer
MakeChannelBuffer(ch);
}
}
void CodeGenVerilog::MakeChannelMemMap(const ChannelEntry& ch) {
Var ch_var = ch.block->channel->handle_var;
std::string dut = GetUniqueName(ch_var->name_hint + ".mmap");
std::string mmap_addr = GetVarID(ch_var.get());
tvm_vpi_modules_.push_back(dut);
if (ch.ports.count("read_addr")) {
CHECK(!ch.ports.count("write_addr"))
<< "Cannot read/write to same RAM";
const PortEntry& read_addr = ch.GetPort("read_addr");
const PortEntry& read_data = ch.GetPort("read_data");
const PortEntry& read_valid = ch.GetPort("read_valid");
stream << " // channel setup for " << ch_var << "\n"
<< " tvm_vpi_read_mmap # (\n"
<< " .DATA_WIDTH(" << read_data.dtype.bits() << "),\n"
<< " .ADDR_WIDTH(" << read_addr.dtype.bits() << "),\n"
<< " .BASE_ADDR_WIDTH(" << ch_var.type().bits() << ")\n"
<< " ) " << dut << " (\n"
<< " .clk(clk),\n"
<< " .rst(rst),\n"
<< " .addr(" << read_addr.value << "),\n"
<< " .data_out(" << read_data.value << "),\n"
<< " .mmap_addr(" << mmap_addr << ")\n"
<< " );\n";
PrintAssign(read_valid.value, "1");
} else if (ch.ports.count("write_addr")) {
const PortEntry& write_addr = ch.GetPort("write_addr");
const PortEntry& write_data = ch.GetPort("write_data");
const PortEntry& write_valid = ch.GetPort("write_valid");
const PortEntry& write_ready = ch.GetPort("write_ready");
stream << " // channel setup for " << ch_var << "\n"
<< " tvm_vpi_write_mmap # (\n"
<< " .DATA_WIDTH(" << write_data.dtype.bits() << "),\n"
<< " .ADDR_WIDTH(" << write_addr.dtype.bits() << "),\n"
<< " .BASE_ADDR_WIDTH(" << ch_var.type().bits() << ")\n"
<< " ) " << dut << " (\n"
<< " .clk(clk),\n"
<< " .rst(rst),\n"
<< " .addr(" << write_addr.value << "),\n"
<< " .data_in(" << write_data.value << "),\n"
<< " .en(" << write_valid.value << "),\n"
<< " .mmap_addr(" << mmap_addr << ")\n"
<< " );\n";
PrintAssign(write_ready.value, "1");
// additional control signals
for (size_t i = 0; i < ch.block->ctrl_signals.size(); ++i) {
ControlSignal sig = ch.block->ctrl_signals[i];
CHECK_EQ(sig->ctrl_type, kComputeFinish);
std::string port = ch.SignalPortName(i);
done_sigs_.push_back(ch.GetPort(port).value);
}
}
}
void CodeGenVerilog::MakeChannelFIFO(const ChannelEntry& ch) {
Var ch_var = ch.block->channel->handle_var;
std::string dut = GetUniqueName(ch_var->name_hint + ".fifo_reg");
const PortEntry& write_data = ch.GetPort("write_data");
const PortEntry& write_valid = ch.GetPort("write_valid");
const PortEntry& write_ready = ch.GetPort("write_ready");
const PortEntry& read_data = ch.GetPort("read_data");
const PortEntry& read_valid = ch.GetPort("read_valid");
const PortEntry& read_ready = ch.GetPort("read_ready");
CHECK_EQ(write_data.dtype, read_data.dtype);
stream << " // channel setup for " << ch_var << "\n"
<< " `CACHE_REG(" << write_data.dtype.bits()
<< ", " << write_data.value
<< ", " << write_valid.value
<< ", " << write_ready.value
<< ", " << read_data.value
<< ", " << read_valid.value
<< ", " << read_ready.value
<< ")\n";
}
void CodeGenVerilog::MakeChannelBuffer(const ChannelEntry& ch) {
LOG(FATAL) << "not implemeneted";
}
CodeGenVerilog::ChannelEntry*
CodeGenVerilog::GetChannelInfo(const Variable* var) {
auto it = cmap_.find(var);
CHECK(it != cmap_.end())
<< "cannot find channel for var " << var->name_hint;
return &(it->second);
}
void CodeGenVerilog::ChannelEntry::AssignPort(
std::string port, std::string value, Type dtype) {
CHECK(!ports.count(port))
<< "port " << port
<< " of channel " << block->channel << " has already been connected";
ports[port] = PortEntry{value, dtype};
}
const CodeGenVerilog::PortEntry&
CodeGenVerilog::ChannelEntry::GetPort(const std::string& port) const {
auto it = ports.find(port);
CHECK(it != ports.end())
<< "port " << port
<< " of channel " << block->channel << " has not been connected";
return it->second;
}
std::string CodeGenVerilog::ChannelEntry::SignalPortName(int index) const {
CHECK_LT(static_cast<size_t>(index), block->ctrl_signals.size());
std::ostringstream os;
os << "ctrl_port" << index;
return os.str();
}
} // namespace verilog
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_verilog.h
* \brief Generate verilog code.
*/
#ifndef TVM_CODEGEN_VERILOG_CODEGEN_VERILOG_H_
#define TVM_CODEGEN_VERILOG_CODEGEN_VERILOG_H_
#include <tvm/base.h>
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <string>
#include <vector>
#include <unordered_map>
#include "./verilog_ir.h"
#include "../codegen_source_base.h"
namespace tvm {
namespace codegen {
namespace verilog {
using namespace ir;
/* \brief The variable type in register.*/
enum VerilogVarType {
kWire,
kInput,
kOutput,
kReg,
kConst
};
/*! \brief The verilog value */
struct VerilogValue {
/*! \brief The variable id */
std::string vid;
/*! \brief The variable type */
VerilogVarType vtype{kReg};
/*! \brief The data type it encodes */
Type dtype;
VerilogValue() {}
VerilogValue(std::string vid, VerilogVarType vtype, Type dtype)
: vid(vid), vtype(vtype), dtype(dtype) {}
};
/*! \brief Information of each procedure function generated */
struct VerilogFuncEntry {
/*! \brief The original functions */
std::vector<Type> arg_types;
/*! \brief The real argument ids of the function */
std::vector<std::string> arg_ids;
/*! \brief The VPI Modules in the function */
std::vector<std::string> vpi_modules;
};
/*!
* \brief The code module of generated verilog code.
*/
class VerilogCodeGenModule {
public:
/*! \brief the code of each modoules */
std::string code;
/*! \brief map of functions */
std::unordered_map<std::string, VerilogFuncEntry> fmap;
/*!
* \brief Generate a code that append simulator function to call func_name.
* \param func_name The function to be called.
* \return The generated code.
*/
std::string AppendSimMain(const std::string& func_name) const;
};
/*!
* \brief Verilog generator
*/
class CodeGenVerilog :
public ExprFunctor<VerilogValue(const Expr&)>,
public CodeGenSourceBase {
public:
/*!
* \brief Initialize the code generator.
* \param output_ssa Whether output SSA.
*/
void Init();
/*!
* \brief Add the function to the generated module.
* \param f The function to be compiled.
*/
void AddFunction(LoweredFunc f);
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
*/
VerilogCodeGenModule Finish();
/*!
* \brief Transform expression to verilog value.
* \param n The expression to be printed.
*/
VerilogValue MakeValue(const Expr& n) {
return VisitExpr(n);
}
// The following parts are overloadable print operations.
// expression
VerilogValue VisitExpr_(const Variable* op) final;
VerilogValue VisitExpr_(const Let* op) final;
VerilogValue VisitExpr_(const Call* op) final;
VerilogValue VisitExpr_(const Add* op) final;
VerilogValue VisitExpr_(const Sub* op) final;
VerilogValue VisitExpr_(const Mul* op) final;
VerilogValue VisitExpr_(const Div* op) final;
VerilogValue VisitExpr_(const Mod* op) final;
VerilogValue VisitExpr_(const Min* op) final;
VerilogValue VisitExpr_(const Max* op) final;
VerilogValue VisitExpr_(const EQ* op) final;
VerilogValue VisitExpr_(const NE* op) final;
VerilogValue VisitExpr_(const LT* op) final;
VerilogValue VisitExpr_(const LE* op) final;
VerilogValue VisitExpr_(const GT* op) final;
VerilogValue VisitExpr_(const GE* op) final;
VerilogValue VisitExpr_(const And* op) final;
VerilogValue VisitExpr_(const Or* op) final;
VerilogValue VisitExpr_(const Cast* op) final;
VerilogValue VisitExpr_(const Not* op) final;
VerilogValue VisitExpr_(const Select* op) final;
VerilogValue VisitExpr_(const Ramp* op) final;
VerilogValue VisitExpr_(const Broadcast* op) final;
VerilogValue VisitExpr_(const IntImm* op) final;
VerilogValue VisitExpr_(const UIntImm* op) final;
VerilogValue VisitExpr_(const FloatImm* op) final;
VerilogValue VisitExpr_(const StringImm* op) final;
protected:
void InitFuncState(LoweredFunc f);
void PrintDecl(const std::string& vid, VerilogVarType vtype, Type dtype,
const char* suffix = ";\n", bool indent = true);
void PrintAssign(
const std::string& target, const std::string& src);
void PrintAssignAnd(
const std::string& target, const std::vector<std::string>& conds);
void PrintLine(const std::string& line);
void PrintSSAAssign(
const std::string& target, const std::string& src, Type t) final;
// make binary op
VerilogValue MakeBinary(Type t, VerilogValue a, VerilogValue b, const char* opstr);
private:
// Hand shake signal name.
// These name can be empty.
// Indicate that the signal is always true
// or do not need to take these signals.
struct SignalEntry {
std::string valid;
std::string ready;
};
// Information about port
struct PortEntry {
// The port value
std::string value;
// The data type
Type dtype;
};
// Channel setup
struct ChannelEntry {
// The channel block
ChannelBlock block;
// The port map, on how port is assigned.
std::unordered_map<std::string, PortEntry> ports;
// Assign port to be valueo
void AssignPort(std::string port, std::string value, Type dtype);
// Assign port to be valueo
const PortEntry& GetPort(const std::string& port) const;
// Signal port name
std::string SignalPortName(int index) const;
};
// Get wire ssa value from s
VerilogValue GetSSAValue(std::string s, Type dtype) {
VerilogValue ret;
ret.vid = SSAGetID(s, dtype);
ret.vtype = kWire;
ret.dtype = dtype;
return ret;
}
void CodeGen(const Pipeline& pipeine);
// codegen the delays
void MakeDelay(const std::string& dst,
const std::string& src,
Type dtype,
int delay,
const std::string& not_stall);
// codegen the loop macros
SignalEntry MakeLoop(const Array<Stmt>& loop);
// codegen the loop macros
void MakeStageInputs(const ComputeBlock& block,
const std::string& not_stall,
std::string* out_all_input_valid);
// codegen compute block
void MakeStore(const ComputeBlock& block, const Store* store);
// Codegen of load statement into FIFO
void MakeLoadToFIFO(const ComputeBlock& block,
const Store* store,
const Load* load);
// Make channel unit.
void MakeChannelUnit(const ChannelEntry& ch);
void MakeChannelFIFO(const ChannelEntry& ch);
void MakeChannelBuffer(const ChannelEntry& ch);
void MakeChannelMemMap(const ChannelEntry& ch);
// Get channel information
ChannelEntry* GetChannelInfo(const Variable* var);
// channel setup map.
std::unordered_map<const Variable*, ChannelEntry> cmap_;
// list of vpi modules to be hooked.
std::vector<std::string> tvm_vpi_modules_;
// The signals for done.
std::vector<std::string> done_sigs_;
// The verilog function.
std::unordered_map<std::string, VerilogFuncEntry> functions_;
};
} // namespace verilog
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_VERILOG_CODEGEN_VERILOG_H_
/*!
* Copyright (c) 2017 by Contributors
* \file verilog_ir.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include "./verilog_ir.h"
#include "../../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
namespace verilog {
using namespace ir;
ControlSignal ControlSignalNode::make(
ControlSignalType type, int advance_size) {
auto n = std::make_shared<ControlSignalNode>();
n->ctrl_type = type;
n->advance_size = advance_size;
return ControlSignal(n);
}
StageInput StageInputNode::make(Var var, StageInputType input_type) {
std::shared_ptr<StageInputNode> n = std::make_shared<StageInputNode>();
n->var = var;
n->input_type = input_type;
return StageInput(n);
}
// Replace stage inputs by placeholder, update the input map.
class StageInputReplacer : public IRMutator {
public:
explicit StageInputReplacer(
const std::unordered_map<const Variable*, StageInput>& var_info)
: var_info_(var_info) {}
Expr Mutate_(const Variable* op, const Expr& e) final {
if (replace_.count(op)) {
return replace_.at(op);
}
auto it = var_info_.find(op);
if (it == var_info_.end()) return e;
Var new_var(it->second->var->name_hint + ".sync", op->type);
inputs_.Set(new_var, it->second);
replace_[op] = new_var;
return new_var;
}
Expr Mutate_(const Load* op, const Expr& e) final {
CHECK(is_zero(op->index))
<< "Load should be in its own stage.";
if (replace_.count(op->buffer_var.get())) {
return replace_.at(op->buffer_var.get());
}
auto it = var_info_.find(op->buffer_var.get());
CHECK(it != var_info_.end())
<< "Load from unknown channel";
Var data(it->second->var->name_hint + ".load.sync", op->type);
inputs_.Set(data, it->second);
replace_[op->buffer_var.get()] = data;
return data;
}
// inputs that get replaced.
Map<Var, StageInput> inputs_;
// replacement map
std::unordered_map<const Variable*, Var> replace_;
// Variable replacement plan.
const std::unordered_map<const Variable*, StageInput>& var_info_;
};
/*! \brief Extract module block */
class PipelineExtractor: public IRVisitor {
public:
Pipeline Extract(LoweredFunc f) {
// Initialize the memory map channels
// TODO(tqchen) move the logic to explicit specification.
for (auto arg : f->args) {
if (arg.type().is_handle()) {
arg_handle_[arg.get()] = arg;
}
}
pipeline_ = std::make_shared<PipelineNode>();
this->Visit(f->body);
// setup channels
for (const auto &kv : cmap_) {
pipeline_->channels.Set(
kv.second.node->channel->handle_var,
ChannelBlock(kv.second.node));
}
pipeline_->args = f->args;
return Pipeline(pipeline_);
}
void Visit_(const AttrStmt* op) final {
if (op->type_key == attr::pipeline_stage_scope) {
CHECK(!in_pipeline_stage_);
in_pipeline_stage_ = true;
trigger_.emplace_back(std::make_pair(loop_.size(), op));
IRVisitor::Visit_(op);
trigger_.pop_back();
in_pipeline_stage_ = false;
} else if (op->type_key == attr::channel_read_advance ||
op->type_key == attr::channel_write_advance) {
trigger_.emplace_back(std::make_pair(loop_.size(), op));
IRVisitor::Visit_(op);
trigger_.pop_back();
} else if (op->type_key == attr::channel_read_scope ||
op->type_key == attr::channel_write_scope) {
Channel ch(op->node.node_);
ChannelEntry& cb = cmap_[ch->handle_var.get()];
if (cb.node != nullptr) {
CHECK(cb.node->channel.same_as(ch));
} else {
cb.node = std::make_shared<ChannelBlockNode>();
cb.node->channel = ch;
}
if (op->type_key == attr::channel_read_scope) {
CHECK_EQ(cb.read_ref_count, 0)
<< "One channel can only be read from one consumer";
++cb.read_ref_count;
CHECK(arith::GetConstInt(op->value, &(cb.node->read_window)))
<< "Only supprt constant read window";
} else {
CHECK_EQ(cb.write_ref_count, 0)
<< "One channel can only be write by one producer";
++cb.write_ref_count;
CHECK(arith::GetConstInt(op->value, &(cb.node->write_window)))
<< "Only supprt constant write window";
}
var_info_[ch->handle_var.get()] =
StageInputNode::make(ch->handle_var, kChannel);
IRVisitor::Visit_(op);
var_info_.erase(ch->handle_var.get());
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Block* op) final {
CHECK(!in_pipeline_stage_)
<< "Do not support serial execution inside pipeline";
IRVisitor::Visit_(op);
}
void Visit_(const IfThenElse* op) final {
LOG(FATAL) << "Not implemeneted";
}
void Visit_(const For* op) final {
if (in_pipeline_stage_) {
loop_.push_back(
For::make(op->loop_var, op->min, op->extent,
op->for_type, op->device_api, Evaluate::make(0)));
var_info_[op->loop_var.get()] =
StageInputNode::make(Var(op->loop_var.node_), kLoopVar);
IRVisitor::Visit_(op);
var_info_.erase(op->loop_var.get());
loop_.pop_back();
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Store* op) final {
// Check the access pattern
Channel arg_write =
CheckArgHandleAccess(op->buffer_var.get(), op->value.type(), false);
this->Visit(op->value);
// The replace logic
StageInputReplacer repl(var_info_);
// Setup the compute block.
std::shared_ptr<ComputeBlockNode> compute =
std::make_shared<ComputeBlockNode>();
compute->loop = Array<Stmt>(loop_);
// setup the advance triggers
for (const auto& e : trigger_) {
const AttrStmt* attr = e.second;
Channel ch;
if (attr->type_key == attr::pipeline_stage_scope) {
ch = arg_write;
if (!ch.defined()) continue;
} else {
ch = Channel(attr->node.node_);
}
std::shared_ptr<SignalTriggerNode> trigger
= std::make_shared<SignalTriggerNode>();
trigger->channel_var = ch->handle_var;
// predicate for the trigger
Expr predicate = const_true();
for (size_t i = e.first; i < loop_.size(); ++i) {
const For* loop = loop_[i].as<For>();
predicate = predicate &&
(loop->loop_var == (loop->extent - 1));
}
trigger->predicate = ir::Simplify(predicate);
// Add the signal back to the channels.
ChannelEntry& cb = cmap_.at(ch->handle_var.get());
trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size());
// Grab the advance constant size.
int trigger_size;
if (attr->type_key == attr::pipeline_stage_scope) {
cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kComputeFinish, 0));
} else if (attr->type_key == attr::channel_read_advance) {
CHECK(arith::GetConstInt(attr->value, &trigger_size))
<< "Only support constant advance size";
cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kReadAdvance, trigger_size));
} else {
CHECK(arith::GetConstInt(attr->value, &trigger_size))
<< "Only support constant advance size";
cb.node->ctrl_signals.push_back(
ControlSignalNode::make(kWriteAdvance, trigger_size));
}
compute->triggers.push_back(SignalTrigger(trigger));
}
// Check if we are writing to FIFO.
const Load* load = op->value.as<Load>();
if (is_zero(op->index) && load) {
compute->body = Store::make(
op->buffer_var,
Load::make(load->type, load->buffer_var, repl.Mutate(load->index)),
op->index);
} else {
compute->body = Store::make(
op->buffer_var, repl.Mutate(op->value), repl.Mutate(op->index));
}
compute->inputs = repl.inputs_;
pipeline_->stages.push_back(ComputeBlock(compute));
}
void Visit_(const LetStmt* op) final {
LOG(FATAL) << "cannot pass through let";
}
void Visit_(const Evaluate* op) final {
LOG(FATAL) << "Not implemeneted";
}
void Visit_(const Allocate* op) final {
CHECK(!in_pipeline_stage_);
}
void Visit_(const AssertStmt* op) final {
LOG(FATAL) << "Not implemeneted";
}
void Visit_(const Load* op) final {
CheckArgHandleAccess(op->buffer_var.get(), op->type, true);
}
Channel CheckArgHandleAccess(const Variable* var, Type dtype, bool read_access) {
if (!arg_handle_.count(var)) return Channel();
CHECK(!cmap_.count(var))
<< "Multiple access to the same handle";
ChannelEntry& cb = cmap_[var];
cb.node = std::make_shared<ChannelBlockNode>();
cb.node->channel = ChannelNode::make(arg_handle_.at(var), dtype);
return cb.node->channel;
}
private:
// The channel information.
struct ChannelEntry {
std::shared_ptr<ChannelBlockNode> node;
int read_ref_count{0};
int write_ref_count{0};
};
// Whether we are inside the pipeline stage.
bool in_pipeline_stage_{false};
// The current loop nest
std::vector<Stmt> loop_;
// Advance signal trigger
std::vector<std::pair<size_t, const AttrStmt*> > trigger_;
// Read write scope
std::vector<const AttrStmt*> channel_scope_;
// The loop index.
std::unordered_map<const Variable*, StageInput> var_info_;
// The channel entry;
std::unordered_map<const Variable*, ChannelEntry> cmap_;
// The argument handle map
std::unordered_map<const Variable*, Var> arg_handle_;
// The result block.
std::shared_ptr<PipelineNode> pipeline_;
};
Pipeline MakePipeline(LoweredFunc f) {
return PipelineExtractor().Extract(f);
}
} // namespace verilog
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file verilog_ir.h
* \brief A lowered IR that resembles verilog blocks,
* This is data structure before final codegen.
*/
#ifndef TVM_CODEGEN_VERILOG_VERILOG_IR_H_
#define TVM_CODEGEN_VERILOG_VERILOG_IR_H_
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/channel.h>
#include <tvm/lowered_func.h>
#include <vector>
#include <memory>
#include <unordered_map>
namespace tvm {
namespace codegen {
namespace verilog {
/*! \brief The data argument type */
enum StageInputType : int {
/*! \brief Data channel input. */
kChannel,
/*! \brief Loop variable generated by compute block. */
kLoopVar,
/*! \brief Global constant. */
kGlobalConst
};
/*! \brief The data argument type */
enum ControlSignalType : int {
// Read advance signal
kReadAdvance,
// Write advance signal
kWriteAdvance,
// Pipeline stage finish signal
kComputeFinish
};
class ControlSignal;
class StageInput;
class SignalTrigger;
/*! \brief The control signal of a channel */
struct ControlSignalNode : public Node {
/*! \brief The control signal type */
ControlSignalType ctrl_type;
/*! \brief Advance size of the signal */
int advance_size{0};
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
v->Visit("ctrl_type", &ctrl_type);
v->Visit("advance_size", &advance_size);
}
static ControlSignal make(ControlSignalType ctrl_type, int advance_size);
static constexpr const char* _type_key = "VerilogControlSignal";
TVM_DECLARE_NODE_TYPE_INFO(ControlSignalNode, Node);
};
TVM_DEFINE_NODE_REF(ControlSignal, ControlSignalNode);
/*! \brief Information about channel. */
struct ChannelBlockNode : public Node {
/*! \brief The channel we are refer to */
Channel channel;
/*! \brief Read window */
int read_window{0};
/*! \brief Write window */
int write_window{0};
/*! \brief Control signals in the channel */
Array<ControlSignal> ctrl_signals;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
v->Visit("channel", &channel);
v->Visit("read_window", &read_window);
v->Visit("write_window", &write_window);
v->Visit("ctrl_signals", &ctrl_signals);
}
static constexpr const char* _type_key = "VerilogChannelBlock";
TVM_DECLARE_NODE_TYPE_INFO(ChannelBlockNode, Node);
};
TVM_DEFINE_NODE_REF(ChannelBlock, ChannelBlockNode);
/*!
* \brief Input to the compute block.
* These represents the data values that need to be shared;
*/
struct StageInputNode : public Node {
/*!
* \brief The corresponding var of the input
* For loop and global const it is the var.
* For channel this corresponds to the channel handle.
*/
Var var;
/*! \brief The type of the input. */
StageInputType input_type;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("input_type", &input_type);
}
// constructor
static StageInput make(Var var, StageInputType input_type);
static constexpr const char* _type_key = "VerilogStageInput";
TVM_DECLARE_NODE_TYPE_INFO(StageInputNode, Node);
};
TVM_DEFINE_NODE_REF(StageInput, StageInputNode);
/*! \brief The trigger signal for certain channel */
struct SignalTriggerNode : public Node {
/*! \brief The channel handle variable */
Var channel_var;
/*! \brief Boolean predicate to trigger the signal */
Expr predicate;
/*! \brief siginal index of the channel */
int signal_index;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
v->Visit("channel_var", &channel_var);
v->Visit("predicate", &predicate);
v->Visit("signal_index", &signal_index);
}
// constructor
static constexpr const char* _type_key = "VerilogSignalTrigger";
TVM_DECLARE_NODE_TYPE_INFO(SignalTriggerNode, Node);
};
TVM_DEFINE_NODE_REF(SignalTrigger, SignalTriggerNode);
/*! \brief compute block for verilog */
struct ComputeBlockNode : public Node {
/*! \brief The body of the block. */
Stmt body;
/*! \brief The loop nest around the body, each is a For with no_op as body */
Array<Stmt> loop;
/*! \brief The channel advance trigger */
Array<SignalTrigger> triggers;
/*! \brief The input variables that need to be synced. */
Map<Var, StageInput> inputs;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
v->Visit("body", &body);
v->Visit("loop", &loop);
v->Visit("triggers", &triggers);
v->Visit("inputs", &inputs);
}
static constexpr const char* _type_key = "VerilogComputeBlock";
TVM_DECLARE_NODE_TYPE_INFO(ComputeBlockNode, Node);
};
TVM_DEFINE_NODE_REF(ComputeBlock, ComputeBlockNode);
/*! \brief Codeblock for verilog module. */
struct PipelineNode : public Node {
/*! \brief arguments to the module */
Array<Var> args;
/*! \brief Computation stages */
Array<ComputeBlock> stages;
/*! \brief The data channels */
Map<Var, ChannelBlock> channels;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
v->Visit("args", &args);
v->Visit("stages", &stages);
v->Visit("channels", &channels);
}
static constexpr const char* _type_key = "VerilogPipeline";
TVM_DECLARE_NODE_TYPE_INFO(PipelineNode, Node);
};
TVM_DEFINE_NODE_REF(Pipeline, PipelineNode);
/*!
* \brief Build a lowered verilog pipeline given function.
* \param f The function to be transformed.
* \param The created verilog pipeline.
*/
Pipeline MakePipeline(LoweredFunc f);
} // namespace verilog
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_VERILOG_VERILOG_IR_H_
/*!
* Copyright (c) 2017 by Contributors
* \file verilog_module.cc
* \brief Build verilog source code.
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
#include <mutex>
#include "./codegen_verilog.h"
#include "../../runtime/file_util.h"
#include "../../runtime/meta_data.h"
namespace tvm {
namespace codegen {
namespace verilog {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
// Simulator function
class VerilogModuleNode : public runtime::ModuleNode {
public:
VerilogModuleNode() : fmt_("v") {}
const char* type_key() const {
return "verilog";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
CHECK(sptr_to_self.get() == this);
if (name == runtime::symbol::tvm_entry_setdevice) {
return PackedFunc([](const TVMArgs& args, TVMRetValue* rv){});
}
CHECK(m_.fmap.count(name)) << "Cannot find function " << name << " in the module";
auto f = [sptr_to_self, name, this](const runtime::TVMArgs& args, TVMRetValue* rv) {
auto* fsim = runtime::Registry::Get("tvm_callback_verilog_simulator");
CHECK(fsim != nullptr)
<< "tvm_callback_verilog_simulator is not registered,"
<<" did you import tvm.addon.verilog?";
std::string code = m_.AppendSimMain(name);
if (const auto* f = runtime::Registry::Get("tvm_callback_verilog_postproc")) {
code = (*f)(code).operator std::string();
}
std::vector<TVMValue> values;
std::vector<int> codes;
TVMValue v;
v.v_str = code.c_str();
values.push_back(v);
codes.push_back(kStr);
for (int i = 0; i < args.num_args; ++i) {
values.push_back(args.values[i]);
codes.push_back(args.type_codes[i]);
}
fsim->CallPacked(TVMArgs(&values[0], &codes[0], args.num_args + 1), rv);
};
return PackedFunc(f);
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "not implemented";
}
std::string GetSource(const std::string& format) final {
return m_.code;
}
void Init(const Array<LoweredFunc>& funcs) {
CodeGenVerilog cg;
cg.Init();
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
m_ = cg.Finish();
}
private:
// the verilog code. data
VerilogCodeGenModule m_;
// format;
std::string fmt_;
};
TVM_REGISTER_API(_codegen_build_verilog)
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<VerilogModuleNode> n =
std::make_shared<VerilogModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
});
} // namespace verilog
} // namespace codegen
} // namespace tvm
...@@ -186,8 +186,8 @@ class VPIMemoryInterface { ...@@ -186,8 +186,8 @@ class VPIMemoryInterface {
read_unit_bytes_ = read_bits / 8U; read_unit_bytes_ = read_bits / 8U;
write_unit_bytes_ = write_bits / 8U; write_unit_bytes_ = write_bits / 8U;
} }
// Callback at post-edge. // Callback at neg-edge.
void AtPosEedge() { void AtNegEdge() {
// reset // reset
if (in_rst_.get_int()) { if (in_rst_.get_int()) {
CHECK_EQ(pending_read_.size, 0U); CHECK_EQ(pending_read_.size, 0U);
...@@ -358,7 +358,7 @@ class VPIReadMemMap : public VPIMemMapBase { ...@@ -358,7 +358,7 @@ class VPIReadMemMap : public VPIMemMapBase {
void Init(VPIHandle module) { void Init(VPIHandle module) {
VPIMemMapBase::Init(module, "reg_data"); VPIMemMapBase::Init(module, "reg_data");
} }
void AtPosEedge() { void AtNegEdge() {
void* ptr = RealAddr(); void* ptr = RealAddr();
if (ptr == nullptr) return; if (ptr == nullptr) return;
size_t nwords = (unit_bytes_ + 3) / 4; size_t nwords = (unit_bytes_ + 3) / 4;
...@@ -373,7 +373,7 @@ class VPIWriteMemMap : public VPIMemMapBase { ...@@ -373,7 +373,7 @@ class VPIWriteMemMap : public VPIMemMapBase {
VPIMemMapBase::Init(module, "data_in"); VPIMemMapBase::Init(module, "data_in");
enable_ = module["en"]; enable_ = module["en"];
} }
void AtPosEedge() { void AtNegEdge() {
if (!enable_.get_int() || rst_.get_int()) return; if (!enable_.get_int() || rst_.get_int()) return;
void* ptr = RealAddr(); void* ptr = RealAddr();
CHECK(ptr != nullptr) CHECK(ptr != nullptr)
...@@ -398,7 +398,7 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) { ...@@ -398,7 +398,7 @@ void TVMVPIHook(runtime::TVMArgs args, runtime::TVMRetValue* rv) {
p->Init(m); p->Init(m);
LOG(INFO) << "Hook " << m.name() << " to tvm vpi simulation..."; LOG(INFO) << "Hook " << m.name() << " to tvm vpi simulation...";
PackedFunc pf([p](const runtime::TVMArgs&, runtime::TVMRetValue*) { PackedFunc pf([p](const runtime::TVMArgs&, runtime::TVMRetValue*) {
p->AtPosEedge(); p->AtNegEdge();
}); });
*rv = pf; *rv = pf;
} }
......
...@@ -139,13 +139,25 @@ MakeLoopNest(const Stage& stage, ...@@ -139,13 +139,25 @@ MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op)); AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var; value_map[iv] = var;
} else if (iv->thread_tag == "pipeline") {
// pipeline marker.
CHECK(is_zero(dom->min));
CHECK(is_one(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
value_map[iv] = dom->min;
} else { } else {
// Always restrict threaded IterVar to starts from 0. // Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min)); CHECK(is_zero(dom->min));
// annotate the extent of the IterVar // annotate the extent of the IterVar
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op)); AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
value_map[iv] = var; if (is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = var;
}
} }
// annotate the extent of the IterVar // annotate the extent of the IterVar
if (!new_loop_var) { if (!new_loop_var) {
......
/*!
* Copyright (c) 2017 by Contributors
* \file narrow_channel_access.cc
* \brief Narrow channel access to a smaller range
* when possible by bringing it to the internal loop.
*/
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <tvm/channel.h>
#include "./ir_util.h"
namespace tvm {
namespace ir {
using namespace arith;
// Bound deducer for channel access.
class ChannelAccessBound : public IRVisitor {
public:
ChannelAccessBound(const Variable* buf_var, bool read_access)
: buf_var_(buf_var), read_access_(read_access) {}
void Visit_(const Store* op) final {
if (!read_access_ && buf_var_ == op->buffer_var.get()) {
ret_.emplace_back(EvalSet(op->index, dom_map_));
}
IRVisitor::Visit_(op);
}
void Visit_(const For* op) final {
CHECK(is_zero(op->min));
// We know that the extent of the loop won't depend on relaxed scope.
// TODO(tqchen) have a verification pass.
dom_map_[op->loop_var.get()] = IntSet::interval(op->min, op->extent - 1);
IRVisitor::Visit_(op);
}
void Visit_(const Load* op) final {
if (read_access_ && buf_var_ == op->buffer_var.get()) {
ret_.emplace_back(EvalSet(op->index, dom_map_));
}
IRVisitor::Visit_(op);
}
void Visit_(const Let* op) final {
LOG(FATAL) << "cannot pass through let";
}
void Visit_(const LetStmt* op) final {
LOG(FATAL) << "cannot pass through let";
}
IntSet Eval(const Stmt& stmt) {
Visit(stmt);
return Union(ret_);
}
private:
// The buffer variable.
const Variable* buf_var_;
// read or write
bool read_access_{true};
// Box
std::vector<IntSet> ret_;
// Domain map.
std::unordered_map<const Variable*, IntSet> dom_map_;
};
class ChannelAccessIndexRewriter : public IRMutator {
public:
ChannelAccessIndexRewriter(const Variable* buf_var,
Expr min,
bool read_access)
: buf_var_(buf_var), min_(min), read_access_(read_access) {}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
if (read_access_ && buf_var_ == op->buffer_var.get()) {
return Load::make(
op->type, op->buffer_var, ir::Simplify(op->index - min_));
} else {
return expr;
}
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
if (!read_access_ && buf_var_ == op->buffer_var.get()) {
return Store::make(
op->buffer_var, op->value, ir::Simplify(op->index - min_));
} else {
return stmt;
}
}
private:
// The buffer variable.
const Variable* buf_var_;
// The min bound.
Expr min_;
// read or write
bool read_access_{true};
};
// Rewrite channel access pattern.
class ChannelAccessRewriter : public IRMutator {
public:
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt ret;
const AttrStmt* adv = op->body.as<AttrStmt>();
if ((op->type_key == ir::attr::channel_read_scope &&
adv && adv->type_key == ir::attr::channel_read_advance) ||
(op->type_key == ir::attr::channel_write_scope &&
adv && adv->type_key == ir::attr::channel_write_advance)) {
RewriteEntry e;
e.window = op;
e.advance = adv;
e.read_access = op->type_key == ir::attr::channel_read_scope;
tasks_.push_back(e);
ret = IRMutator::Mutate_(op, s);
if (tasks_.back().rewrite_success) {
ret = ret.as<AttrStmt>()->body.as<AttrStmt>()->body;
}
tasks_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
std::vector<RewriteEntry> tasks;
std::swap(tasks_, tasks);
Stmt body = op->body;
std::vector<Stmt> nest;
for (RewriteEntry& e : tasks) {
body = RewriteAccess(op, body, &e, &nest);
}
if (!body.same_as(op->body)) {
body = Mutate(body);
body = For::make(
op->loop_var, op->min, op->extent,
op->for_type, op->device_api, body);
body = MergeNest(nest, body);
} else {
CHECK_EQ(nest.size(), 0U);
body = IRMutator::Mutate_(op, s);
}
std::swap(tasks_, tasks);
return body;
}
private:
struct RewriteEntry {
bool read_access;
const AttrStmt* window;
const AttrStmt* advance;
bool rewrite_success{false};
};
Stmt RewriteAccess(const For* for_op,
Stmt body,
RewriteEntry* e,
std::vector<Stmt>* outer_nest) {
const AttrStmt* adv_op = e->advance;
const Expr& window = e->window->value;
bool read_access = e->read_access;
Var var(for_op->loop_var);
Channel ch(adv_op->node.node_);
ChannelAccessBound acc(ch->handle_var.get(), read_access);
IntSet iset = acc.Eval(for_op->body);
Range r = iset.cover_range(Range::make_with_min_extent(0, window));
r = Range::make_with_min_extent(
ir::Simplify(r->min), ir::Simplify(r->extent));
if (ExprUseVar(r->extent, var)) return body;
Array<Expr> linear_eq = DetectLinearEquation(r->min, var);
if (linear_eq.size() == 0) return body;
Expr base = linear_eq[0];
Expr coeff = linear_eq[1];
if (!is_zero(base)) return body;
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
if (!can_prove(left >= 0)) return body;
// rewrite access index.
ChannelAccessIndexRewriter rw(
ch->handle_var.get(), var * coeff, read_access);
body = rw.Mutate(body);
if (read_access) {
body = AttrStmt::make(
ch, ir::attr::channel_read_scope, r->extent,
AttrStmt::make(ch, ir::attr::channel_read_advance, coeff,
body));
} else {
body = AttrStmt::make(
ch, ir::attr::channel_write_scope, r->extent,
AttrStmt::make(ch, ir::attr::channel_write_advance, coeff,
body));
}
if (!is_zero(left)) {
Stmt no_op = Evaluate::make(0);
if (read_access) {
outer_nest->emplace_back(
AttrStmt::make(ch, ir::attr::channel_read_advance, left, no_op));
} else {
outer_nest->emplace_back(
AttrStmt::make(ch, ir::attr::channel_write_advance, left, no_op));
}
}
e->rewrite_success = true;
return body;
}
std::vector<RewriteEntry> tasks_;
};
Stmt NarrowChannelAccess(Stmt stmt) {
return ChannelAccessRewriter().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
...@@ -15,6 +15,7 @@ class IRSideEffect : public IRVisitor { ...@@ -15,6 +15,7 @@ class IRSideEffect : public IRVisitor {
public: public:
void Visit(const NodeRef& e) final { void Visit(const NodeRef& e) final {
if (has_side_effect_) return; if (has_side_effect_) return;
IRVisitor::Visit(e);
} }
void Visit_(const Call* op) final { void Visit_(const Call* op) final {
...@@ -55,5 +56,39 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) { ...@@ -55,5 +56,39 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
} }
return m.Mutate(stmt); return m.Mutate(stmt);
} }
class ExprUseVarVisitor : public IRVisitor {
public:
explicit ExprUseVarVisitor(const Variable* var)
: var_(var) {}
void Visit(const NodeRef& e) final {
if (use_var_) return;
IRVisitor::Visit(e);
}
void Visit_(const Variable* op) final {
if (op == var_) {
use_var_ = true;
}
}
void Visit_(const Load* op) final {
if (op->buffer_var.get() == var_) {
use_var_ = true;
}
IRVisitor::Visit_(op);
}
const Variable* var_;
bool use_var_{false};
};
bool ExprUseVar(const Expr& e, const Var& v) {
ExprUseVarVisitor visitor(v.get());
visitor.Visit(e);
return visitor.use_var_;
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -147,8 +147,8 @@ class IRUseDefAnalysis : public IRMutator {
class HostDeviceSplitter : public IRMutator { class HostDeviceSplitter : public IRMutator {
public: public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == "thread_extent") { if (op->type_key == attr::thread_extent ||
IterVar iv(op->node.node_); op->type_key == attr::pipeline_exec_scope) {
return SplitDeviceFunc(s); return SplitDeviceFunc(s);
} }
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
...@@ -195,7 +195,6 @@ class HostDeviceSplitter : public IRMutator { ...@@ -195,7 +195,6 @@ class HostDeviceSplitter : public IRMutator {
n->name = os.str(); n->name = os.str();
n->args = m.undefined_; n->args = m.undefined_;
n->thread_axis = m.thread_axis_; n->thread_axis = m.thread_axis_;
CHECK_NE(m.thread_extent_.size(), 0U);
// improve the handle data type // improve the handle data type
for (Var arg : n->args) { for (Var arg : n->args) {
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/channel.h> #include <tvm/channel.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "./ir_util.h" #include "./ir_util.h"
namespace tvm { namespace tvm {
...@@ -18,14 +19,38 @@ namespace ir { ...@@ -18,14 +19,38 @@ namespace ir {
class MarkChannelAccess : public IRMutator { class MarkChannelAccess : public IRMutator {
public: public:
MarkChannelAccess( MarkChannelAccess(
const std::unordered_map<const Variable*, Channel>& cmap) const std::unordered_map<const Variable*, Channel>& cmap,
: cmap_(cmap) {} const std::unordered_map<const Variable*, Channel>& fifo_map)
: cmap_(cmap), fifo_map_(fifo_map) {}
using IRMutator::Mutate;
Stmt Mutate(Stmt stmt) final {
Stmt ret = IRMutator::Mutate(stmt);
if (read_fifos_.size() != 0) {
for (const Variable* v : read_fifos_) {
Channel ch = fifo_map_.at(v);
ret = ReadChannel(ch, 1, ret);
}
read_fifos_.clear();
}
if (write_fifos_.size() != 0) {
for (const Variable* v : write_fifos_) {
Channel ch = fifo_map_.at(v);
ret = WriteChannel(ch, 1, ret);
}
write_fifos_.clear();
}
return ret;
}
Expr Mutate_(const Load *op, const Expr& e) final { Expr Mutate_(const Load *op, const Expr& e) final {
auto it = rmap_.find(op->buffer_var.get()); auto it = rmap_.find(op->buffer_var.get());
if (it != rmap_.end()) { if (it != rmap_.end()) {
++it->second.read_count; ++it->second.read_count;
} }
if (fifo_map_.count(op->buffer_var.get())) {
read_fifos_.insert(op->buffer_var.get());
CHECK(!write_fifos_.count(op->buffer_var.get()));
}
return IRMutator::Mutate_(op, e); return IRMutator::Mutate_(op, e);
} }
Stmt Mutate_(const Store *op, const Stmt& s) final { Stmt Mutate_(const Store *op, const Stmt& s) final {
...@@ -33,6 +58,10 @@ class MarkChannelAccess : public IRMutator { ...@@ -33,6 +58,10 @@ class MarkChannelAccess : public IRMutator {
if (it != rmap_.end()) { if (it != rmap_.end()) {
++it->second.write_count; ++it->second.write_count;
} }
if (fifo_map_.count(op->buffer_var.get())) {
write_fifos_.insert(op->buffer_var.get());
CHECK(!read_fifos_.count(op->buffer_var.get()));
}
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt Mutate_(const Allocate* op, const Stmt& s) final {
...@@ -79,51 +108,90 @@ class MarkChannelAccess : public IRMutator { ...@@ -79,51 +108,90 @@ class MarkChannelAccess : public IRMutator {
} }
if (rw.write_count) { if (rw.write_count) {
return AttrStmt::make( return WriteChannel(ch, alloc_size, body);
ch, ir::attr::channel_write_scope, alloc_size, body);
} else { } else {
CHECK(rw.read_count); CHECK(rw.read_count);
return AttrStmt::make( return ReadChannel(ch, alloc_size, body);
ch, ir::attr::channel_read_scope, alloc_size, body);
} }
} }
Stmt ReadChannel(Channel ch, Expr size, Stmt body) {
return AttrStmt::make(
ch, ir::attr::channel_read_scope, size,
AttrStmt::make(ch, ir::attr::channel_read_advance, size,
body));
}
Stmt WriteChannel(Channel ch, Expr size, Stmt body) {
return AttrStmt::make(
ch, ir::attr::channel_write_scope, size,
AttrStmt::make(ch, ir::attr::channel_write_advance, size,
body));
}
struct Entry { struct Entry {
int read_count{0}; int read_count{0};
int write_count{0}; int write_count{0};
}; };
// The channels of each allocation. // The channels of each allocation.
const std::unordered_map<const Variable*, Channel>& cmap_; const std::unordered_map<const Variable*, Channel>& cmap_;
// FIFO map.
const std::unordered_map<const Variable*, Channel>& fifo_map_;
// the result. // the result.
std::unordered_map<const Variable*, Entry> rmap_; std::unordered_map<const Variable*, Entry> rmap_;
// Accessed FIFOs
std::unordered_set<const Variable*> read_fifos_, write_fifos_;
}; };
// Mark the statment of each stage. // Mark the statment of each stage.
class StageSplitter : public IRMutator { class StageSplitter : public IRMutator {
public: public:
using IRMutator::Mutate;
explicit StageSplitter(bool split_load)
: split_load_(split_load) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
nest_.push_back(stmt); nest_.push_back(stmt);
Stmt ret = IRMutator::Mutate(stmt); Stmt ret = IRMutator::Mutate(stmt);
nest_.pop_back(); nest_.pop_back();
return ret; return ret;
} }
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) { Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (!op->is_producer) return IRMutator::Mutate_(op, s); if (!op->is_producer) {
return Mutate(op->body);
}
Stmt body = Mutate(op->body); Stmt body = Mutate(op->body);
stages_.emplace_back(BuildStage(body, op->func)); stages_.emplace_back(BuildStage(body, op->func));
return Evaluate::make(0); return Evaluate::make(0);
} }
Expr Mutate_(const Load* op, const Expr& e) final {
if (!split_load_) return IRMutator::Mutate_(op, e);
std::ostringstream cname;
cname << "fifo." << temp_fifo_count_++;
// Create FIFO channel for load.
Channel ch = ChannelNode::make(Var(cname.str(), Handle()), op->type);
Expr index = Mutate(op->index);
Stmt provide = Store::make(
ch->handle_var,
Load::make(op->type, op->buffer_var, index), 0);
Stmt temp = nest_.back(); nest_.pop_back();
stages_.emplace_back(BuildStage(provide, ch));
nest_.push_back(temp);
fifo_map_[ch->handle_var.get()] = ch;
return Load::make(op->type, ch->handle_var, 0);
}
Stmt Split(Stmt stmt) { Stmt Split(Stmt stmt, const ProducerConsumer* env) {
stmt = Mutate(stmt); stmt = Mutate(stmt);
stmt = RemoveNoOp(stmt); if (env) {
CHECK(is_no_op(stmt)); stages_.emplace_back(BuildStage(stmt, env->func));
} else {
stmt = RemoveNoOp(stmt);
CHECK(is_no_op(stmt));
}
CHECK_NE(stages_.size(), 0); CHECK_NE(stages_.size(), 0);
stmt = stages_.back(); stmt = stages_.back();
for (size_t i = stages_.size() - 1; i != 0; --i) { for (size_t i = stages_.size() - 1; i != 0; --i) {
stmt = Block::make(stages_[i - 1], stmt); stmt = Block::make(stages_[i - 1], stmt);
} }
stmt = MarkChannelAccess(cmap_).Mutate(stmt); stmt = MarkChannelAccess(cmap_, fifo_map_).Mutate(stmt);
return RemoveNoOp(stmt); return RemoveNoOp(stmt);
} }
...@@ -184,10 +252,52 @@ class StageSplitter : public IRMutator { ...@@ -184,10 +252,52 @@ class StageSplitter : public IRMutator {
std::vector<Stmt> stages_; std::vector<Stmt> stages_;
// channel map // channel map
std::unordered_map<const Variable*, Channel> cmap_; std::unordered_map<const Variable*, Channel> cmap_;
// Whether split load into a temp fifo.
bool split_load_{true};
// Counter for temp FIFOs.
size_t temp_fifo_count_{0};
// fifo map
std::unordered_map<const Variable*, Channel> fifo_map_;
};
class PipelineSplitter : public IRMutator {
public:
explicit PipelineSplitter(bool split_load)
: split_load_(split_load) {}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == ir::attr::pipeline_exec_scope) {
CHECK_LE(env_.size(), 1U);
const ProducerConsumer* env = nullptr;
if (env_.size() == 1) {
std::swap(env_[0], env);
}
Stmt body = StageSplitter(split_load_).Split(
op->body, env);
if (body.same_as(op->body)) return s;
return AttrStmt::make(
op->node, op->type_key, op->value, body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) {
env_.push_back(op);
Stmt ret = IRMutator::Mutate_(op, s);
if (env_.back() == nullptr) {
ret = ret.as<ProducerConsumer>()->body;
}
env_.pop_back();
return ret;
}
private:
bool split_load_;
std::vector<const ProducerConsumer *> env_;
}; };
Stmt SplitPipeline(Stmt stmt) { Stmt SplitPipeline(Stmt stmt, bool split_load) {
return StageSplitter().Split(stmt); return PipelineSplitter(split_load).Mutate(stmt);
} }
} // namespace ir } // namespace ir
......
...@@ -283,8 +283,9 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -283,8 +283,9 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
if (fin == nullptr) { if (fin == nullptr) {
*out = new PackedFunc( *out = new PackedFunc(
[func, resource_handle](TVMArgs args, TVMRetValue* rv) { [func, resource_handle](TVMArgs args, TVMRetValue* rv) {
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle); args.num_args, rv, resource_handle);
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError();
}); });
} else { } else {
// wrap it in a shared_ptr, with fin as deleter. // wrap it in a shared_ptr, with fin as deleter.
...@@ -292,8 +293,9 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -292,8 +293,9 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
std::shared_ptr<void> rpack(resource_handle, fin); std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc( *out = new PackedFunc(
[func, rpack](TVMArgs args, TVMRetValue* rv) { [func, rpack](TVMArgs args, TVMRetValue* rv) {
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get()); args.num_args, rv, rpack.get());
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError();
}); });
} }
API_END(); API_END();
......
...@@ -275,8 +275,7 @@ void InferRootBound(const Stage& stage, ...@@ -275,8 +275,7 @@ void InferRootBound(const Stage& stage,
// special optimization to remove trivial loop // special optimization to remove trivial loop
if (is_one(vrange->extent)) { if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min); up_state[iv] = IntSet::single_point(vrange->min);
} } else if (fix_value && !ScopeRelax(iv, stage->scope)) {
if (fix_value && !ScopeRelax(iv, stage->scope)) {
up_state[iv] = IntSet::single_point(iv->var); up_state[iv] = IntSet::single_point(iv->var);
} else { } else {
up_state[iv] = IntSet::range(vrange); up_state[iv] = IntSet::range(vrange);
......
...@@ -26,7 +26,7 @@ def mybuild(fapi, target="llvm"): ...@@ -26,7 +26,7 @@ def mybuild(fapi, target="llvm"):
def test_dot(): def test_dot():
nn = 12 nn = 12
n = tvm.Var('n') n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
k = tvm.reduce_axis((0, n), 'k') k = tvm.reduce_axis((0, n), 'k')
......
import tvm
def test_basic():
a = tvm.Var("a")
b = tvm.Var("b")
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, a)
assert m[1].value == 4
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7)).value == 0
m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, a)
assert len(m) == 0
m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, a)
assert m[1].value == 5
assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0
if __name__ == "__main__":
test_basic()
...@@ -29,3 +29,13 @@ def test_convert_ssa(): ...@@ -29,3 +29,13 @@ def test_convert_ssa():
assert(not tvm.ir_pass.VerifySSA(z)) assert(not tvm.ir_pass.VerifySSA(z))
z_ssa = tvm.ir_pass.ConvertSSA(z) z_ssa = tvm.ir_pass.ConvertSSA(z)
assert(tvm.ir_pass.VerifySSA(z_ssa)) assert(tvm.ir_pass.VerifySSA(z_ssa))
def test_expr_use_var():
x = tvm.Var('x')
assert(tvm.ir_pass.ExprUseVar(x+1, x))
assert(not tvm.ir_pass.ExprUseVar(1+10, x))
if __name__ == "__main__":
test_expr_use_var()
import tvm import tvm
def lower(s, args):
binds = {}
arg_list = []
for x in args:
assert isinstance(x, tvm.tensor.Tensor)
buf = tvm.Buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf
arg_list.append(buf)
s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
return stmt
def test_basic_pipeline(): def test_basic_pipeline():
n = tvm.convert(128) n = tvm.convert(128)
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
...@@ -12,20 +29,37 @@ def test_basic_pipeline(): ...@@ -12,20 +29,37 @@ def test_basic_pipeline():
B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k) B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4) px = tvm.thread_axis((0, 1), "pipeline")
xo, xi = s[B].split(B.op.axis[0], outer=px)
xo, xi = s[B].split(xi, factor=4)
for S in stages: for S in stages:
s[S].compute_at(s[B], xo) s[S].compute_at(s[B], xo)
# Lowering stmt = lower(s, [A, B])
bounds = tvm.schedule.InferBound(s) stmt = tvm.ir_pass.SplitPipeline(stmt, False)
stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt)
Ab = tvm.Buffer(A.shape, A.dtype, name='A') stmt = tvm.ir_pass.NarrowChannelAccess(stmt)
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb})
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.SplitPipeline(stmt)
print(stmt) print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt)) assert(tvm.ir_pass.VerifySSA(stmt))
def test_conv1d():
n = tvm.Var('n')
A = tvm.compute((n+2), lambda i: 1, name='A')
def computeB(ii):
i = ii + 1
return A[i-1] + A[i] + A[i+1]
B = tvm.compute(n, computeB, name='B')
s = tvm.Schedule(B.op)
px = tvm.thread_axis((0, 1), "pipeline")
xo, xi = s[B].split(B.op.axis[0], outer=px)
s[A].compute_at(s[B], px)
stmt = lower(s, [B])
stmt = tvm.ir_pass.SplitPipeline(stmt, False)
print(stmt)
stmt = tvm.ir_pass.NarrowChannelAccess(stmt)
print(stmt)
if __name__ == "__main__": if __name__ == "__main__":
test_basic_pipeline() test_basic_pipeline()
test_conv1d()
...@@ -36,7 +36,8 @@ if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then ...@@ -36,7 +36,8 @@ if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then
make -f tests/travis/packages.mk iverilog make -f tests/travis/packages.mk iverilog
make verilog || exit -1 make verilog || exit -1
make all || exit -1 make all || exit -1
nosetests -v tests/verilog || exit -1 nosetests -v tests/verilog/unittest || exit -1
nosetests -v tests/verilog/integration || exit -1
fi fi
fi fi
......
import tvm
from tvm.addon import testing, verilog
import numpy as np
def lower(s, args, name):
binds = {}
arg_list = []
for x in args:
assert isinstance(x, tvm.tensor.Tensor)
buf = tvm.Buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf
arg_list.append(buf)
s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.SplitPipeline(stmt, True)
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0)
return fapi
@tvm.register_func
def tvm_callback_verilog_postproc(code):
"""Hook to inspect the verilog code before actually run it"""
print(code)
return code
def test_add_pipeline():
nn = 128
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A', dtype='int32')
B = tvm.placeholder((n,), name='B', dtype='int32')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
s = tvm.Schedule(C.op)
grid_x = tvm.thread_axis((0, 1), "pipeline")
_, x = s[C].split(C.op.axis[0], outer=grid_x)
fapi = lower(s, [A, B, C], "myadd")
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
print(fsplits[1].body)
print("------")
def check_target(device, host="stackvm"):
if not tvm.codegen.enabled(host):
return
if not tvm.codegen.enabled(device):
return
ctx = tvm.vpi(0)
mhost = tvm.codegen.build(fsplits[0], host)
mdev = tvm.codegen.build(fsplits[1:], device)
mhost.import_module(mdev)
code = mdev.get_source()
f = mhost.entry_func
# launch the kernel.
n = nn
a = tvm.nd.array((np.random.uniform(size=n) * 128).astype(A.dtype), ctx)
b = tvm.nd.array((np.random.uniform(size=n) * 128).astype(A.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
f(a, b, c)
print("Check correctness...")
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_target("verilog")
if __name__ == "__main__":
test_add_pipeline()
`include "tvm_marcos.v"
module main();
parameter PER = 10;
reg clk;
reg rst;
wire init;
wire done;
wire enable;
`NORMAL_LOOP_LEAF(iter0, 4, init0, enable, iter0_done, 0, 4, 1)
`NORMAL_LOOP_NEST(iter1, 4, init, iter0_done, iter1_done, 0, 3, 1, init0)
assign done = iter0_done;
always begin
#(PER/2) clk =~ clk;
end
initial begin
// This will allow tvm session to be called every cycle.
$tvm_session(clk);
end
endmodule
import tvm
from tvm.addon import verilog
from testing_util import FIFODelayedWriter, FIFODelayedReader
def run_with_lag(n, read_lag, write_lag):
data = list(range(n))
# head ptr of a
sess = verilog.session([
verilog.find_file("test_cache_reg.v")
])
rst = sess.main.rst
in_data = sess.main.in_data
in_valid = sess.main.in_valid
in_ready = sess.main.in_ready
out_data = sess.main.out_data
out_valid = sess.main.out_valid
out_ready = sess.main.out_ready
# hook up reader
reader = FIFODelayedReader(out_data, out_valid, out_ready, read_lag)
writer = FIFODelayedWriter(in_data, in_valid, in_ready, data, write_lag)
rst.put_int(1)
sess.yield_until_next_cycle()
rst.put_int(0)
sess.yield_until_next_cycle()
sess.yield_callbacks.append(reader)
sess.yield_callbacks.append(writer)
timeout = sum(read_lag) + sum(write_lag) + n + 10
for t in range(timeout):
sess.yield_until_next_cycle()
if len(reader.data) == n:
break
assert tuple(reader.data) == tuple(range(n))
assert len(writer.data) == 0
sess.shutdown()
def test_fifo():
n = 20
# slow reader
run_with_lag(n, read_lag=[3,4,8], write_lag=[])
# slow writer
run_with_lag(n, read_lag=[0], write_lag=[0, 2, 10])
# mix
run_with_lag(n, read_lag=[3, 4, 8], write_lag=[0, 2, 10])
if __name__ == "__main__":
test_fifo()
`include "tvm_marcos.v"
module main();
`TVM_DEFINE_TEST_SIGNAL(clk, rst)
reg[31:0] in_data;
wire[31:0] out_data;
wire in_ready;
reg in_valid;
reg out_ready;
wire out_valid;
`CACHE_REG(32, in_data, in_valid, in_ready,
out_data, out_valid, out_ready)
initial begin
// This will allow tvm session to be called every cycle.
$tvm_session(clk);
end
endmodule
...@@ -16,14 +16,38 @@ def test_counter(): ...@@ -16,14 +16,38 @@ def test_counter():
assert(counter.size == 4) assert(counter.size == 4)
rst.put_int(1) rst.put_int(1)
# This will advance the cycle to next pos-edge of clk. # This will advance the cycle to next pos-edge of clk.
sess.yield_until_posedge() sess.yield_until_next_cycle()
rst.put_int(0) rst.put_int(0)
sess.yield_until_next_cycle()
for i in range(10): for i in range(10):
# get value of counter. # get value of counter.
assert(counter.get_int() == i) assert(counter.get_int() == i)
sess.yield_until_posedge() sess.yield_until_next_cycle()
def test_scratch():
sess = verilog.session([
verilog.find_file("test_counter.v"),
verilog.find_file("example_counter.v")
])
# Get the handles by their names
rst = sess.main.rst
counter = sess.main.counter
rst.put_int(1)
# This will advance the cycle to next pos-edge of clk.
sess.yield_until_next_cycle()
rst.put_int(0)
temp = 0
for i in range(10):
if rst.get_int():
rst.put_int(0)
temp = counter.get_int()
elif counter.get_int() == 3:
rst.put_int(1)
print("counter=%d, temp=%d" % (counter.get_int(), temp))
sess.yield_until_next_cycle()
if __name__ == "__main__": if __name__ == "__main__":
test_scratch()
test_counter() test_counter()
`include "tvm_marcos.v"
module main(); module main();
parameter PER = 10; `TVM_DEFINE_TEST_SIGNAL(clk, rst)
reg clk;
reg rst;
wire [3:0] counter;
wire[3:0] counter;
counter counter_unit1(.clk(clk), .rst(rst), .out(counter)); counter counter_unit1(.clk(clk), .rst(rst), .out(counter));
always begin
#(PER/2) clk =~ clk;
end
initial begin initial begin
// This will allow tvm session to be called every cycle. // This will allow tvm session to be called every cycle.
$tvm_session(clk); $tvm_session(clk);
......
...@@ -7,28 +7,23 @@ def test_loop(): ...@@ -7,28 +7,23 @@ def test_loop():
]) ])
# Get the handles by their names # Get the handles by their names
rst = sess.main.rst rst = sess.main.rst
init = sess.main.init
iter0 = sess.main.iter0 iter0 = sess.main.iter0
iter1 = sess.main.iter1 iter1 = sess.main.iter1
enable = sess.main.enable ready = sess.main.ready
invalid = sess.main.done
rst.put_int(1) rst.put_int(1)
ready.put_int(1)
# This will advance the cycle to next pos-edge of clk. # This will advance the cycle to next pos-edge of clk.
sess.yield_until_posedge() sess.yield_until_next_cycle()
rst.put_int(0) rst.put_int(0)
init.put_int(1) sess.yield_until_next_cycle()
sess.yield_until_posedge()
enable.put_int(1)
init.put_int(0)
for i in range(0, 3): for k in range(0, 1):
for j in range(0, 4): for i in range(0, 3):
while invalid.get_int(): for j in range(0, 4):
sess.yield_until_posedge() assert(iter1.get_int() == i)
assert(iter1.get_int() == i) assert(iter0.get_int() == j)
assert(iter0.get_int() == j) sess.yield_until_next_cycle()
sess.yield_until_posedge()
if __name__ == "__main__": if __name__ == "__main__":
......
`include "tvm_marcos.v"
module main();
`TVM_DEFINE_TEST_SIGNAL(clk, rst)
reg ready;
wire lp_ready;
`NONSTOP_LOOP(iter0, 4, 0, lp_ready, iter0_finish, 0, 4)
`NONSTOP_LOOP(iter1, 4, 0, iter0_finish, iter1_finish, 0, 3)
`WRAP_LOOP_ONCE(0, valid, ready, iter1_finish, loop_ready)
assign lp_ready = loop_ready;
initial begin
// This will allow tvm session to be called every cycle.
$tvm_session(clk);
end
endmodule
...@@ -51,7 +51,7 @@ def test_ram_read(): ...@@ -51,7 +51,7 @@ def test_ram_read():
host_read_addr = sess.main.read_addr host_read_addr = sess.main.read_addr
host_read_size = sess.main.read_size host_read_size = sess.main.read_size
rst.put_int(1) rst.put_int(1)
sess.yield_until_posedge() sess.yield_until_next_cycle()
rst.put_int(0) rst.put_int(0)
# hook up reader # hook up reader
reader = FIFOReader(read_data, read_valid) reader = FIFOReader(read_data, read_valid)
...@@ -61,18 +61,18 @@ def test_ram_read(): ...@@ -61,18 +61,18 @@ def test_ram_read():
host_read_addr.put_int(a_ptr) host_read_addr.put_int(a_ptr)
host_read_size.put_int(a.shape[0]) host_read_size.put_int(a.shape[0])
sess.yield_until_posedge() sess.yield_until_next_cycle()
# second read request # second read request
host_read_addr.put_int(a_ptr + 2) host_read_addr.put_int(a_ptr + 2)
host_read_size.put_int(a.shape[0] - 2) host_read_size.put_int(a.shape[0] - 2)
sess.yield_until_posedge() sess.yield_until_next_cycle()
host_read_req.put_int(0) host_read_req.put_int(0)
read_en.put_int(1) read_en.put_int(1)
# yield until read is done # yield until read is done
for i in range(a.shape[0] * 3): for i in range(a.shape[0] * 3):
sess.yield_until_posedge() sess.yield_until_next_cycle()
# check if result matches # check if result matches
r = np.concatenate((a_np, a_np[2:])) r = np.concatenate((a_np, a_np[2:]))
np.testing.assert_equal(np.array(reader.data), r) np.testing.assert_equal(np.array(reader.data), r)
...@@ -105,7 +105,7 @@ def test_ram_write(): ...@@ -105,7 +105,7 @@ def test_ram_write():
host_write_size = sess.main.write_size host_write_size = sess.main.write_size
rst.put_int(1) rst.put_int(1)
sess.yield_until_posedge() sess.yield_until_next_cycle()
rst.put_int(0) rst.put_int(0)
# hook up writeer # hook up writeer
writer = FIFOWriter(write_data, write_en, write_ready, w_data) writer = FIFOWriter(write_data, write_en, write_ready, w_data)
...@@ -116,12 +116,12 @@ def test_ram_write(): ...@@ -116,12 +116,12 @@ def test_ram_write():
host_write_addr.put_int(a_ptr + offset) host_write_addr.put_int(a_ptr + offset)
host_write_size.put_int(a.shape[0] - offset) host_write_size.put_int(a.shape[0] - offset)
sess.yield_until_posedge() sess.yield_until_next_cycle()
host_write_req.put_int(0) host_write_req.put_int(0)
# yield until write is done # yield until write is done
for i in range(a.shape[0]+2): for i in range(a.shape[0]+2):
sess.yield_until_posedge() sess.yield_until_next_cycle()
# check if result matches # check if result matches
np.testing.assert_equal(a.asnumpy()[2:], r_data) np.testing.assert_equal(a.asnumpy()[2:], r_data)
......
...@@ -25,18 +25,18 @@ def test_mmap(): ...@@ -25,18 +25,18 @@ def test_mmap():
# setup memory map. # setup memory map.
rst.put_int(1) rst.put_int(1)
sess.yield_until_posedge() sess.yield_until_next_cycle()
rst.put_int(0) rst.put_int(0)
write_en.put_int(0) write_en.put_int(0)
mmap_addr.put_int(a_ptr) mmap_addr.put_int(a_ptr)
sess.yield_until_posedge() sess.yield_until_next_cycle()
# read test # read test
for i in range(n): for i in range(n):
read_addr.put_int(i) read_addr.put_int(i)
sess.yield_until_posedge() sess.yield_until_next_cycle()
# read addr get set this cycle # read addr get set this cycle
sess.yield_until_posedge() sess.yield_until_next_cycle()
# get the data out # get the data out
assert(read_data.get_int() == i) assert(read_data.get_int() == i)
...@@ -45,9 +45,9 @@ def test_mmap(): ...@@ -45,9 +45,9 @@ def test_mmap():
write_addr.put_int(i) write_addr.put_int(i)
write_en.put_int(1) write_en.put_int(1)
write_data.put_int(i + 1) write_data.put_int(i + 1)
sess.yield_until_posedge() sess.yield_until_next_cycle()
write_en.put_int(0) write_en.put_int(0)
sess.yield_until_posedge() sess.yield_until_next_cycle()
np.testing.assert_equal(a.asnumpy(), a_np + 1) np.testing.assert_equal(a.asnumpy(), a_np + 1)
......
"""Common utilities for test"""
class FIFODelayedReader(object):
"""Reader that have specified ready lag."""
def __init__(self, read_data, read_valid, read_ready, lag):
self.read_data = read_data
self.read_valid = read_valid
self.read_ready = read_ready
self.read_ready.put_int(1)
self.lag = list(reversed(lag))
self.data = []
self.wait_counter = 0
self.wait_state = False
def __call__(self):
"""Logic as if always at pos-edge"""
if not self.wait_state:
if (self.read_ready.get_int() and
self.read_valid.get_int()):
self.data.append(self.read_data.get_int())
self.wait_counter = self.lag.pop() if self.lag else 0
self.wait_state = True
if self.wait_state:
if self.wait_counter == 0:
self.read_ready.put_int(1)
self.wait_state = False
else:
self.wait_counter -= 1
self.read_ready.put_int(0)
class FIFODelayedWriter(object):
"""Auxiliary class to write to FIFO """
def __init__(self, write_data, write_valid, write_ready, data, lag):
self.write_data = write_data
self.write_valid = write_valid
self.write_ready = write_ready
self.write_valid.put_int(0)
self.lag = list(reversed(lag))
self.data = list(reversed(data))
self.wait_counter = 0
self.wait_state = True
def __call__(self):
"""Logic as if always at pos-edge"""
if not self.wait_state:
if self.write_ready.get_int():
self.wait_counter = self.lag.pop() if self.lag else 0
self.wait_state = True
if self.wait_state:
if self.wait_counter == 0:
if self.data:
self.write_valid.put_int(1)
self.write_data.put_int(self.data.pop())
self.wait_state = False
else:
self.write_valid.put_int(0)
else:
self.write_valid.put_int(0)
self.wait_counter -= 1
// Leaf of a normal loop nest // Nonstop version of loop
// Starts at done = 1 // Always keeps looping when increase == true
// Need init to reset to done = 0 // At end is a signal to indicate the next cycle is end
// increases when enabled = 1 // Use that to signal parent loop to advance.
`define NORMAL_LOOP_LEAF(iter, width, init, enable, done, min, max, incr)\ `define NONSTOP_LOOP(iter, width, init, ready, finish, min, extent)\
reg [width-1:0] iter;\ reg [width-1:0] iter;\
reg valid;\ wire finish;\
reg done;\
always@(posedge clk) begin\ always@(posedge clk) begin\
if(rst) begin\ if (rst || init) begin\
iter <= 0;\
done <= 1;\
end else if(init) begin\
iter <= (min);\ iter <= (min);\
done <= 0;\ end else if(ready) begin\
end else if(done) begin\ if (iter != ((extent)-1)) begin\
iter <= 0;\ iter <= iter + 1;\
done <= 1;\
end else if(enable) begin\
if (iter < ((max)-(incr))) begin\
iter <= iter + (incr);\
done <= 0;\
end else begin\ end else begin\
iter <= 0;\ iter <= (min);\
done <= 1;\
end\ end\
end else begin\ end else begin\
iter <= iter;\ iter <= iter;\
done <= done;\
end\ end\
end end\
assign finish = (ready && (iter == (extent) - 1));
// Normal loop nest that can connect to a child which is a normal loop
`define NORMAL_LOOP_NEST(iter, width, init, body_done, done, min, max, incr, body_init)\ // Wrap a nonstop loop to normal loop that loop only once.
reg [width-1:0] iter;\ // Use done signal to control the non-stop body to stop.
reg done;\ // The init and done behaves like normal loop
reg body_init;\ `define WRAP_LOOP_ONCE(init, valid, ready, body_finish, body_ready)\
reg valid;\
wire body_ready;\
always@(posedge clk) begin\
if (rst || init) begin\
valid <= 1;\
end else if(body_finish) begin\
valid <= 0;\
end else begin\
valid <= valid;\
end\
end\
assign body_ready = (valid && ready);
// Assign dst as src delayed by specific cycles.
`define DELAY(dst, src, width, delay, not_stall)\
reg [(width)*(delay)-1:0] src``_dly_chain;\
always@(posedge clk) begin\ always@(posedge clk) begin\
if(rst) begin\ if(rst) begin\
iter <= 0;\ src``_dly_chain <= 0;\
done <= 1;\ end else if (not_stall) begin\
body_init <= 0;\ src``_dly_chain[(width)-1:0] <= src;\
end else if(init) begin\ if((delay) != 1) begin\
iter <= (min);\ src``_dly_chain[(delay)*(width)-1:(width)] <= src``_dly_chain[((delay)-1)*(width)-1:0];\
done <= 0;\ end\
body_init <= 1;\ end else begin\
end else if(done) begin\ src``_dly_chain <= src``_dly_chain;\
iter <= 0;\ end\
done <= 1;\ end\
body_init <= 0;\ assign dst = src``_dly_chain[(delay)*(width)-1:((delay)-1)*(width)];
end else if (body_init) begin\
iter <= iter;\ // TVM generate clock signal
done <= done;\ `define TVM_DEFINE_TEST_SIGNAL(clk, rst)\
body_init <= 0;\ parameter PER = 10;\
end else if (body_done) begin\ reg clk;\
if (iter < ((max)-(incr))) begin\ reg rst;\
iter <= iter + (incr);\ always begin\
done <= 0;\ #(PER/2) clk =~ clk;\
body_init <= 1;\ end
// Control logic on buffer/RAM read valid.
// This delays the valid signal by one cycle and retain it when write_ready == 0
`define BUFFER_READ_VALID_DELAY(dst, data_valid, write_ready)\
reg dst;\
always@(posedge clk) begin\
if(rst) begin\
dst <= 0;\
end else if (write_ready) begin\
dst <= (data_valid);\
end else begin\
dst <= dst;\
end\
end\
// A cache register that add one cycle lag to the ready signal
// This allows the signal to flow more smoothly
`define CACHE_REG(width, in_data, in_valid, in_ready, out_data, out_valid, out_ready)\
reg [width-1:0] out_data``_state_;\
reg [width-1:0] out_data``_overflow_;\
reg out_valid``_state_;\
reg out_valid``_overflow_;\
always@(posedge clk) begin\
if(rst) begin\
out_valid``_overflow_ <= 0;\
out_valid``_state_ <= 0;\
end else if (out_valid``_overflow_) begin\
if (out_ready) begin\
out_valid``_state_ <= 1;\
out_data``_state_ <= out_data``_overflow_;\
out_valid``_overflow_ <= 0;\
out_data``_overflow_ <= 0;\
end else begin\ end else begin\
iter <= 0;\ out_valid``_state_ <= 1;\
done <= 1;\ out_data``_state_ <= out_data``_state_;\
body_init <= 0;\ out_valid``_overflow_ <= out_valid``_overflow_;\
out_data``_overflow_ <= out_data``_overflow_;\
end\ end\
end else begin\ end else begin\
iter <= iter;\ if (!out_ready && out_valid``_state_) begin\
done <= done;\ out_valid``_state_ <= 1;\
body_init <= 0;\ out_data``_state_ <= out_data``_state_;\
out_valid``_overflow_ <= in_valid;\
out_data``_overflow_ <= in_data;\
end else begin\
out_valid``_state_ <= in_valid;\
out_data``_state_ <= in_data;\
out_valid``_overflow_ <= out_valid``_overflow_;\
out_data``_overflow_ <= out_data``_overflow_;\
end\
end\ end\
end end\ // always@ (posedge clk)
assign in_ready = !out_valid``_overflow_;\
assign out_data = out_data``_state_;\
assign out_valid = out_valid``_state_;
...@@ -43,9 +43,9 @@ class IPCClient { ...@@ -43,9 +43,9 @@ class IPCClient {
PutInt(clock_, 0); PutInt(clock_, 0);
} }
int Callback() { int Callback() {
if (GetInt(clock_)) { if (!GetInt(clock_)) {
try { try {
return AtPosEedge(); return AtNegEdge();
} catch (const std::runtime_error& e) { } catch (const std::runtime_error& e) {
reader_.Close(); reader_.Close();
writer_.Close(); writer_.Close();
...@@ -57,8 +57,11 @@ class IPCClient { ...@@ -57,8 +57,11 @@ class IPCClient {
return 0; return 0;
} }
} }
// called at positive edge. // called at neg edge.
int AtPosEedge() { int AtNegEdge() {
// This is actually called at neg-edge
// The put values won't take effect until next neg-edge.
// This allow us to see the registers before snc
writer_.Write(kPosEdgeTrigger); writer_.Write(kPosEdgeTrigger);
VPICallCode rcode; VPICallCode rcode;
VPIRawHandle handle; VPIRawHandle handle;
...@@ -149,10 +152,10 @@ class IPCClient { ...@@ -149,10 +152,10 @@ class IPCClient {
s_vpi_time time_s; s_vpi_time time_s;
time_s.type = vpiSimTime; time_s.type = vpiSimTime;
time_s.high = 0; time_s.high = 0;
time_s.low = 0; time_s.low = 10;
value_s.format = vpiVectorVal; value_s.format = vpiVectorVal;
value_s.value.vector = &svec_buf_[0]; value_s.value.vector = &svec_buf_[0];
vpi_put_value(h, &value_s, &time_s, vpiInertialDelay); vpi_put_value(h, &value_s, &time_s, vpiTransportDelay);
writer_.Write(kSuccess); writer_.Write(kSuccess);
break; break;
} }
...@@ -202,10 +205,10 @@ class IPCClient { ...@@ -202,10 +205,10 @@ class IPCClient {
s_vpi_time time_s; s_vpi_time time_s;
time_s.type = vpiSimTime; time_s.type = vpiSimTime;
time_s.high = 0; time_s.high = 0;
time_s.low = 0; time_s.low = 10;
value_s.format = vpiIntVal; value_s.format = vpiIntVal;
value_s.value.integer = value; value_s.value.integer = value;
vpi_put_value(h, &value_s, &time_s, vpiInertialDelay); vpi_put_value(h, &value_s, &time_s, vpiTransportDelay);
} }
// Handles // Handles
vpiHandle clock_; vpiHandle clock_;
......
VPI_CFLAGS=`iverilog-vpi --cflags` VPI_CFLAGS=`iverilog-vpi --cflags`
VPI_LDLAGS=`iverilog-vpi --ldlags` VPI_LDFLAGS=`iverilog-vpi --ldflags`
VER_SRCS = $(wildcard verilog/*.v) VER_SRCS = $(wildcard verilog/*.v)
...@@ -7,4 +7,4 @@ VER_LIBS=lib/tvm_vpi.vpi ...@@ -7,4 +7,4 @@ VER_LIBS=lib/tvm_vpi.vpi
lib/tvm_vpi.vpi: verilog/tvm_vpi.cc verilog/tvm_vpi.h lib/tvm_vpi.vpi: verilog/tvm_vpi.cc verilog/tvm_vpi.h
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(VPI_CFLAGS) $(CFLAGS) -shared -o $@ $(filter %.cc, $^) $(LDFLAGS) $(VPI_LDFLAGS) $(CXX) $(VPI_CFLAGS) $(CFLAGS) -o $@ $(filter %.cc, $^) $(LDFLAGS) $(VPI_LDFLAGS)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment