Commit 41ed49b7 by Tianqi Chen Committed by GitHub

[IR] support general type annotation. (#1480)

parent b55a2d54
...@@ -514,6 +514,16 @@ using HalideIR::Internal::Shuffle; ...@@ -514,6 +514,16 @@ using HalideIR::Internal::Shuffle;
// ir functions // ir functions
using HalideIR::Internal::is_const_power_of_two_integer; using HalideIR::Internal::is_const_power_of_two_integer;
/*!
 * \brief Build an expression that exists only to carry a data type.
 *
 *  The result is a pure intrinsic call named "type_annotation" that takes
 *  no arguments and whose type is dtype.
 *
 * \param dtype The data type to annotate with.
 * \return An expression whose type is dtype.
 */
inline Expr TypeAnnotation(Type dtype) {
  // Name of the intrinsic used to mark the call as a type annotation.
  const char* const intrinsic_name = "type_annotation";
  // The argument list is intentionally empty: only the call's type matters.
  return ir::Call::make(
      dtype, intrinsic_name, {}, ir::Call::PureIntrinsic);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -350,12 +350,12 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr ...@@ -350,12 +350,12 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
} }
Expr elem_offset = self->elem_offset + offset; Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) { if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes)); e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.type(), content_lanes); extent = extent / make_const(self->elem_offset.type(), content_lanes);
elem_offset = self->elem_offset / make_const(self->elem_offset.type(), elem_offset = self->elem_offset / make_const(self->elem_offset.type(),
content_lanes); content_lanes);
} else { } else {
e_dtype = make_zero(self->dtype); e_dtype = ir::TypeAnnotation(self->dtype);
} }
Array<Expr> acc_args{ Array<Expr> acc_args{
e_dtype, self->data, elem_offset, e_dtype, self->data, elem_offset,
......
...@@ -18,9 +18,12 @@ ...@@ -18,9 +18,12 @@
#include <memory> #include <memory>
#include <atomic> #include <atomic>
namespace vta { namespace vta {
// Avoid bad configurations: stop the build when VTA_UOP_WIDTH is not equal
// to sizeof(VTAUop) * 8, i.e. the size of the VTAUop type expressed in bits
// (sizeof yields bytes, hence the factor of 8).
static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8,
"VTA_UOP_WIDTH do not match VTAUop size");
/*! \brief Enable coherent access between VTA and CPU. */ /*! \brief Enable coherent access between VTA and CPU. */
static const bool kBufferCoherent = true; static const bool kBufferCoherent = true;
......
...@@ -245,12 +245,12 @@ class SRAM { ...@@ -245,12 +245,12 @@ class SRAM {
CHECK_LE(sram_end, kMaxNumElem); CHECK_LE(sram_end, kMaxNumElem);
memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0); memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0);
sram_ptr += xtotal * op->y_pad_0; sram_ptr += xtotal * op->y_pad_0;
for (uint32_t y = 0; y < op->y_size; ++y) { for (uint32_t y = 0; y < op->y_size; ++y) {
memset(sram_ptr, 0, kElemBytes * op->x_pad_0); memset(sram_ptr, 0, kElemBytes * op->x_pad_0);
sram_ptr += op->x_pad_0; sram_ptr += op->x_pad_0;
memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size); memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size);
sram_ptr += op->x_size; sram_ptr += op->x_size;
BitPacker<kBits> src(sram_ptr);
memset(sram_ptr, 0, kElemBytes * op->x_pad_1); memset(sram_ptr, 0, kElemBytes * op->x_pad_1);
sram_ptr += op->x_pad_1; sram_ptr += op->x_pad_1;
dram_ptr += kElemBytes * op->x_stride; dram_ptr += kElemBytes * op->x_stride;
...@@ -415,12 +415,14 @@ class Device { ...@@ -415,12 +415,14 @@ class Device {
uint32_t acc_idx = uop_ptr->dst_idx; uint32_t acc_idx = uop_ptr->dst_idx;
uint32_t inp_idx = uop_ptr->src_idx; uint32_t inp_idx = uop_ptr->src_idx;
uint32_t wgt_idx = uop_ptr->wgt_idx; uint32_t wgt_idx = uop_ptr->wgt_idx;
acc_idx += y * op->dst_factor_out + x * op->dst_factor_in; acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
inp_idx += y * op->src_factor_out + x * op->src_factor_in; inp_idx += y * op->src_factor_out + x * op->src_factor_in;
wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in; wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx)); BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx)); BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx));
BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx)); BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx));
// gemm loop // gemm loop
for (uint32_t i = 0; i < VTA_BATCH; ++i) { for (uint32_t i = 0; i < VTA_BATCH; ++i) {
for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) { for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment