Commit 324a9607 by Siyuan Feng Committed by Leyuan Wang

TensorCore Support using Intrinsic (#4136)

* add tensor core support

* avoid memory bank conflict

* fix thread sync & better performance

* better performance

* add schedule test for conv2d

* extend into BatchMatMul

* support config fragment shape and layout using intrinsic

* add TensorCore tutorial

* add int support and fix lint

* address comment

* add 32*16*8 TensorCore test

* fix wmma include logic
parent 4ab73634
...@@ -1311,6 +1311,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope"; ...@@ -1311,6 +1311,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope";
constexpr const char* device_scope = "device_scope"; constexpr const char* device_scope = "device_scope";
/*! /*!
* \brief Mark the shape of a TensorCore fragment
*/
constexpr const char* fragment_shape = "fragment_shape";
/*!
* \brief Mark the layout of a TensorCore fragment
*/
constexpr const char* fragment_layout = "fragment_layout";
/*!
* \brief Check if attr_key is a pragma key extension * \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared * \param attr_key The attr key to be compared
* \return true if it is a pragma key * \return true if it is a pragma key
...@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; ...@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
* } * }
*/ */
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
/*!
* \brief tvm intrinsic for tensor core load operators.
*
* void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
* Expr index, Expr buffer_ptr, Expr stride,
* StringImm layout) {
* // m, n, k are the shape of wmma fragment.
* // Determine fragment layout(column-major or row major) by layout.
* // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
* nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
* }
*/
constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
/*!
* \brief tvm intrinsic for tensor core mma_sync operators.
*
* void tvm_mma_sync(Var fragment_d, Expr index_d,
* Var fragment_a, Expr index_a,
* Var fragment_b, Expr index_b,
* Var fragment_c, Expr index_c) {
* nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
* fragment_b[index_b], fragment_c[index_c]);
* }
*/
constexpr const char* tvm_mma_sync = "tvm_mma_sync";
/*!
* \brief tvm intrinsic for tensor core fill_fragment operators.
*
* void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
* Expr index, Expr value) {
* // m, n, k are the shape of wmma fragment
* // fragments must be in 'wmma.accumulator' scope.
* nvcuda::wmma::fill_fragment(fragment[index], value);
* }
*/
constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
/*!
* \brief tvm intrinsic for tensor core store operators.
*
* void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
* Expr index, Expr buffer_ptr, Expr stride,
* StringImm layout) {
* // m, n, k are the shape of wmma fragment
* // fragments must be in 'wmma.accumulator' scope.
* nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
* }
*/
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
} // namespace intrinsic } // namespace intrinsic
......
...@@ -514,6 +514,15 @@ LoweredFunc CombineContextCall(LoweredFunc f); ...@@ -514,6 +514,15 @@ LoweredFunc CombineContextCall(LoweredFunc f);
LoweredFunc PointerValueTypeRewrite(LoweredFunc f); LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*! /*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
*
* \param func The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
/*!
* \brief Lower intrinsic function calls. * \brief Lower intrinsic function calls.
* \param f The device function to be lowered. * \param f The device function to be lowered.
* \param target The target device. * \param target The target device.
...@@ -533,6 +542,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); ...@@ -533,6 +542,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
/*! /*!
* \brief Infer the TensorCore fragment information using tensor intrinsics
*
* \param f The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc InferFragment(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type. * \brief Verify if memory accesses are legal for a specific target device type.
* *
* In the case that tgt is cuda, if not all workload is bound with * In the case that tgt is cuda, if not all workload is bound with
......
...@@ -413,7 +413,6 @@ def lower(sch, ...@@ -413,7 +413,6 @@ def lower(sch,
# Phase 3 # Phase 3
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt) stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting: if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt) stmt = ir_pass.RewriteUnsafeSelect(stmt)
...@@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host): ...@@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host):
func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp") func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size) func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)] fsplits = [s for s in ir_pass.SplitHostDevice(func)]
...@@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host): ...@@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host):
assert not fdevice assert not fdevice
target_host = _target.create(target_host) target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost]
......
...@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") ...@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
}); });
}); });
// Registers "ir_pass.LowerStorageAccess": takes a LoweredFunc, copies its node
// and returns the copy with the body run through LowerStorageAccessInfo.
TVM_REGISTER_API("ir_pass.LowerStorageAccess")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    LoweredFunc func = args[0];
    // Copy the node so the incoming function is left untouched.
    auto lowered = make_node<LoweredFuncNode>(*func.operator->());
    lowered->body = LowerStorageAccessInfo(func->body);
    *ret = LoweredFunc(lowered);
  });
// make from two arguments // make from two arguments
#define REGISTER_PASS(PassName) \ #define REGISTER_PASS(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \
...@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice); ...@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite); REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync); REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo); REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(LowerDeviceStorageAccessInfo)
REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer); REGISTER_PASS(InjectDoubleBuffer);
...@@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope); ...@@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch, ...@@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch,
// Phase 2 // Phase 2
stmt = ir::Simplify(stmt); stmt = ir::Simplify(stmt);
stmt = ir::LowerStorageAccessInfo(stmt);
stmt = ir::RemoveNoOp(stmt); stmt = ir::RemoveNoOp(stmt);
if (!(config->disable_select_rewriting)) if (!(config->disable_select_rewriting))
...@@ -517,6 +516,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -517,6 +516,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = ir::BindDeviceType(func, target->device_type); func = ir::BindDeviceType(func, target->device_type);
func = ir::LowerDeviceStorageAccessInfo(func);
func = ir::LowerTVMBuiltin(func); func = ir::LowerTVMBuiltin(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
...@@ -524,6 +524,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -524,6 +524,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = ir::LowerIntrin(func, target_host->target_name); func = ir::LowerIntrin(func, target_host->target_name);
func = ir::LowerDeviceStorageAccessInfo(func);
func = ir::CombineContextCall(func); func = ir::CombineContextCall(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
......
...@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() { ...@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <math_constants.h>\n"; decl_stream << "#include <math_constants.h>\n";
} }
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
...@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) ...@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
bool fail = false; bool fail = false;
if (t.is_float()) { if (t.is_float()) {
switch (t.bits()) { switch (t.bits()) {
case 16: os << "half"; case 16:
enable_fp16_ = true; enable_fp16_ = true;
if (lanes == 1) {
os << "half";
} else if (lanes <= 8) {
CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "float" << lanes / 2;
} else {
fail = true;
}
break; break;
case 32: os << "float"; break; case 32: os << "float"; break;
case 64: os << "double"; break; case 64: os << "double"; break;
default: fail = true; break; default: fail = true; break;
} }
if (!fail && lanes == 1) return; if (!fail && (lanes == 1 || t.bits() == 16)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) { if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return; os << lanes; return;
} }
...@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope( ...@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope(
} }
} }
void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
os << "nvcuda::wmma::fill_fragment(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::load_matrix_sync(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[6], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::store_matrix_sync(";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImm *str = op->args[7].as<StringImm>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
}
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::mma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", ": ")");
}
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::fragment_shape) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* shape_str = op->value.as<StringImm>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* layout_str = op->value.as<StringImm>();
fragment_layouts[buffer] = layout_str->value;
}
CodeGenC::VisitStmt_(op);
}
void CodeGenCUDA::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
PrintType(op->type, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
const Variable* buffer = op->buffer_var.as<Variable>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8))
<< "Matrix_a and matrix_b only support half or char or unsigned char type for now";
} else {
CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32))
<< "Accumulator only support half, float and int type for now";
}
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->type, buffer, stream);
} else {
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->type, stream);
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->type);
this->PrintStmt(op->body);
}
void CodeGenCUDA::VisitStmt_(const Evaluate *op) { void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return; if (is_const(op->value)) return;
const Call* call = op->value.as<Call>(); const Call* call = op->value.as<Call>();
...@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* ...@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
PrintConst(op, os, this); PrintConst(op, os, this);
} }
/*!
 * \brief Print the nvcuda::wmma::fragment type for a buffer in a wmma scope.
 *
 * matrix_a and matrix_b fragments are printed with their recorded layout;
 * accumulator fragments are printed without one. Nothing is printed for any
 * other scope.
 *
 * \param scope The storage scope string of the buffer.
 * \param t The element type of the fragment.
 * \param variable The buffer variable used to look up the recorded shape/layout.
 * \param os The stream receiving the type text.
 */
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t,
                                 const Variable* variable, std::ostream &os) {
  std::stringstream type_buf;
  PrintType(t, type_buf);
  const std::string element_type = type_buf.str();
  std::string shape_str = fragment_shapes[variable];

  const bool is_operand = (scope == "wmma.matrix_a" || scope == "wmma.matrix_b");
  if (is_operand) {
    need_mma_h_ = true;
    std::string layout_str = fragment_layouts[variable];
    // scope.substr(5) drops the "wmma." prefix, leaving matrix_a / matrix_b.
    os << "nvcuda::wmma::fragment<nvcuda::wmma::" << scope.substr(5) << ", "
       << shape_str << ", " << element_type << ", nvcuda::wmma::" << layout_str << ">";
    return;
  }
  if (scope == "wmma.accumulator") {
    need_mma_h_ = true;
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
       << shape_str << ", " << element_type << ">";
  }
}
int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
const Variable* variable, int32_t size) {
std::string shape_str = fragment_shapes[variable];
size_t m, n, k;
size_t last_pos = 0, pos = 0;
pos = shape_str.find(", ", last_pos);
m = std::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
pos = shape_str.find(", ", last_pos);
n = std::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
if (scope == "wmma.matrix_a") {
return size / m / k;
} else if (scope == "wmma.matrix_b") {
return size / n / k;
} else if (scope == "wmma.accumulator") {
return size / m / n;
}
return 0;
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <string> #include <string>
#include <unordered_map>
#include "codegen_c.h" #include "codegen_c.h"
namespace tvm { namespace tvm {
...@@ -40,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -40,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC {
void AddFunction(LoweredFunc f); void AddFunction(LoweredFunc f);
std::string Finish(); std::string Finish();
bool need_include_path() { bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_); return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
} }
// override behavior // override behavior
void VisitStmt_(const ir::For* op) final; void VisitStmt_(const ir::For* op) final;
...@@ -60,7 +61,10 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -60,7 +61,10 @@ class CodeGenCUDA final : public CodeGenC {
void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final; void VisitExpr_(const FloatImm *op, std::ostream& os) final;
void VisitExpr_(const Call *op, std::ostream& os) final;
void VisitStmt_(const Evaluate *op) final; void VisitStmt_(const Evaluate *op) final;
void VisitStmt_(const Allocate *op) final;
void VisitStmt_(const AttrStmt *op) final;
private: private:
// Whether global barrier is needed. // Whether global barrier is needed.
...@@ -75,7 +79,14 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -75,7 +79,14 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_int8_{false}; bool enable_int8_{false};
// whether need math_constants.h // whether need math_constants.h
bool need_math_constants_h_{false}; bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
std::unordered_map<const Variable*, std::string> fragment_shapes;
std::unordered_map<const Variable*, std::string> fragment_layouts;
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os);
int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size);
}; };
} // namespace codegen } // namespace codegen
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \brief Infer TensorCore metadata from tensor intrinsic.
* \file tensorcore_fragment.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
#include "storage_access.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
// Get fragment information from tensor intrinsics
class FragmentGetter : public IRVisitor {
public:
// fragment metadata
struct FragmentInfo {
// fragment shape
int m, n, k;
// fragment layout (row-major or column-major)
std::string layout;
FragmentInfo() = default;
FragmentInfo(int _m, int _n, int _k, const std::string& _layout)
: m(_m), n(_n), k(_k), layout(_layout) {}
};
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
// Get shape and layout information from load and store intrinsic
CHECK_EQ(op->args.size(), 8U);
const Variable* buffer_var = op->args[0].as<Variable>();
CHECK(buffer_var);
// Get shape
const IntImm* m = op->args[1].as<IntImm>();
const IntImm* n = op->args[2].as<IntImm>();
const IntImm* k = op->args[3].as<IntImm>();
const StringImm* layout = op->args[7].as<StringImm>();
CHECK(m);
CHECK(n);
CHECK(k);
CHECK(layout);
std::string scope = scopes[buffer_var];
if (fragments.count(buffer_var)) {
// check if the fragment has met before
FragmentInfo info = fragments[buffer_var];
CHECK_EQ(m->value, info.m);
CHECK_EQ(n->value, info.n);
CHECK_EQ(k->value, info.k);
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK_EQ(layout->value, info.layout);
}
} else {
// store metadata
FragmentInfo info;
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
info = FragmentInfo(m->value, n->value, k->value, layout->value);
} else if (scope == "wmma.accumulator") {
info = FragmentInfo(m->value, n->value, k->value, "");
}
fragments[buffer_var] = info;
}
} else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
// Get shape information from fill intrinsic
CHECK_EQ(op->args.size(), 6U);
const Variable* buffer_var = op->args[0].as<Variable>();
CHECK(buffer_var);
// Get shape
const IntImm* m = op->args[1].as<IntImm>();
const IntImm* n = op->args[2].as<IntImm>();
const IntImm* k = op->args[3].as<IntImm>();
CHECK(m);
CHECK(n);
CHECK(k);
std::string scope = scopes[buffer_var];
// Only wmma.accumulator can use tvm_fill_fragment
CHECK_EQ(scope, "wmma.accumulator");
if (fragments.count(buffer_var)) {
FragmentInfo info = fragments[buffer_var];
CHECK_EQ(m->value, info.m);
CHECK_EQ(n->value, info.n);
CHECK_EQ(k->value, info.k);
} else {
FragmentInfo info(m->value, n->value, k->value, "");
fragments[buffer_var] = info;
}
}
}
// Get memory scope
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buffer = op->node.as<Variable>();
CHECK(buffer);
scopes[buffer] = op->value.as<StringImm>()->value;
}
IRVisitor::Visit_(op);
}
// Memory scope for allocations
std::unordered_map<const Variable*, std::string> scopes;
// Fragment metadata for all fragments
std::unordered_map<const Variable*, FragmentInfo> fragments;
};
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
class FragmentChecker : public IRVisitor {
public:
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
void Visit_(const Call* op) final {
// Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
CHECK_EQ(op->args.size(), 8U);
const Variable* buffer_var_d = op->args[0].as<Variable>();
const Variable* buffer_var_a = op->args[2].as<Variable>();
const Variable* buffer_var_b = op->args[4].as<Variable>();
const Variable* buffer_var_c = op->args[6].as<Variable>();
CHECK(buffer_var_d);
CHECK(buffer_var_a);
CHECK(buffer_var_b);
CHECK(buffer_var_c);
// Check all fragment A, B, C and D have the same shape
CHECK(CheckShape(buffer_var_d, buffer_var_a));
CHECK(CheckShape(buffer_var_d, buffer_var_b));
CHECK(CheckShape(buffer_var_d, buffer_var_c));
}
}
private:
// A tool for checking shapes of two fragments
bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
CHECK(fragment_getter.fragments.count(buffer1));
CHECK(fragment_getter.fragments.count(buffer2));
FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
}
// Fragment infomation
const FragmentGetter &fragment_getter;
};
// Store the metadata into attributes.
// Wraps the allocation of every fragment known to the FragmentGetter in a
// fragment_shape attribute and, when a layout was recorded, a fragment_layout
// attribute around that.
class InferFragmenter : public IRMutator {
 public:
  explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    Stmt mutated = IRMutator::Mutate_(op, s);
    const auto found = fragment_getter.fragments.find(op->buffer_var.get());
    if (found == fragment_getter.fragments.end()) {
      // Not a fragment allocation: nothing to attach.
      return mutated;
    }
    const FragmentGetter::FragmentInfo& info = found->second;

    // Add shape attribute to all fragments, encoded as "m, n, k"
    std::string shape = std::to_string(info.m);
    shape += ", ";
    shape += std::to_string(info.n);
    shape += ", ";
    shape += std::to_string(info.k);
    Stmt with_shape = AttrStmt::make(op->buffer_var, attr::fragment_shape,
                                     StringImm::make(shape), mutated);
    if (info.layout.empty()) {
      return with_shape;
    }
    // Add layout attribute to matrix_a and matrix_b
    return AttrStmt::make(op->buffer_var, attr::fragment_layout,
                          StringImm::make(info.layout), with_shape);
  }

 private:
  // Fragment information
  const FragmentGetter &fragment_getter;
};
// Runs the three fragment passes over a statement: collect fragment metadata,
// check the shapes used by tvm_mma_sync, then attach the metadata as attributes.
Stmt InferFragment(Stmt stmt) {
  FragmentGetter getter;
  getter.Visit(stmt);

  FragmentChecker checker(getter);
  checker.Visit(stmt);

  InferFragmenter annotator(getter);
  return annotator.Mutate(stmt);
}
// LoweredFunc overload: rejects host functions, then returns a copy of the
// function whose body has been processed by InferFragment(Stmt).
LoweredFunc InferFragment(LoweredFunc f) {
  CHECK_NE(f->func_type, kHostFunc);
  const LoweredFuncNode& original = *f.operator->();
  auto updated = make_node<LoweredFuncNode>(original);
  updated->body = InferFragment(f->body);
  return LoweredFunc(updated);
}
} // namespace ir
} // namespace tvm
...@@ -341,5 +341,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { ...@@ -341,5 +341,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower().Mutate(stmt); return StorageAccessInfoLower().Mutate(stmt);
} }
// Returns a copy of the function whose body has been run through
// LowerStorageAccessInfo; the input function is not modified.
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
  const LoweredFuncNode& original = *f.operator->();
  auto updated = make_node<LoweredFuncNode>(original);
  updated->body = LowerStorageAccessInfo(f->body);
  return LoweredFunc(updated);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -263,6 +263,28 @@ class ThreadSyncInserter : public IRMutator { ...@@ -263,6 +263,28 @@ class ThreadSyncInserter : public IRMutator {
} }
} }
// Counts reads and writes made through tvm_access_ptr. Bit 0 of the flag
// argument increments the read count and bit 1 the write count, but only when
// both the sync scope and the buffer's scope are global. Other calls are
// passed to the default mutator unchanged.
Expr Mutate_(const Call* op, const Expr& e) final {
  if (!op->is_intrinsic(intrinsic::tvm_access_ptr)) {
    return IRMutator::Mutate_(op, e);
  }

  Expr mutated = IRMutator::Mutate_(op, e);
  const Call* call = mutated.as<Call>();
  CHECK_EQ(call->args.size(), 5U);
  const Variable* buffer_var = call->args[1].as<Variable>();
  Var var(GetRef<Var>(buffer_var));
  const IntImm* flag = call->args[4].as<IntImm>();

  const bool global_sync = (sync_scope_.rank == StorageRank::kGlobal);
  // GetScope is only consulted once the cheaper tests have passed.
  if ((flag->value & 1) && global_sync &&
      GetScope(buffer_var).rank == StorageRank::kGlobal) {
    ++rw_stats_[var].read_count;
  }
  if ((flag->value & 2) && global_sync &&
      GetScope(buffer_var).rank == StorageRank::kGlobal) {
    ++rw_stats_[var].write_count;
  }
  return mutated;
}
private: private:
// RW statistics about data // RW statistics about data
struct Entry { struct Entry {
......
...@@ -50,7 +50,13 @@ enum class StorageRank { ...@@ -50,7 +50,13 @@ enum class StorageRank {
*/ */
kWarp = 2, kWarp = 2,
/*! \brief thread local memory */ /*! \brief thread local memory */
kLocal = 3 kLocal = 3,
/*! \brief wmma scope memory of matrix_a */
kWMMAMatrixA = 4,
/*! \brief wmma scope memory of matrix_b */
kWMMAMatrixB = 5,
/*! \brief wmma scope memory of accumulator */
kWMMAAccumulator = 6,
}; };
/*! /*!
...@@ -89,6 +95,9 @@ struct StorageScope { ...@@ -89,6 +95,9 @@ struct StorageScope {
case StorageRank::kShared: return "shared" + tag; case StorageRank::kShared: return "shared" + tag;
case StorageRank::kWarp: return "warp" + tag; case StorageRank::kWarp: return "warp" + tag;
case StorageRank::kLocal: return "local" + tag; case StorageRank::kLocal: return "local" + tag;
case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag;
case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag;
case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag;
default: LOG(FATAL) << "unknown storage scope"; return ""; default: LOG(FATAL) << "unknown storage scope"; return "";
} }
} }
...@@ -111,6 +120,15 @@ struct StorageScope { ...@@ -111,6 +120,15 @@ struct StorageScope {
} else if (s.compare(0, 5, "local") == 0) { } else if (s.compare(0, 5, "local") == 0) {
r.rank = StorageRank::kLocal; r.rank = StorageRank::kLocal;
r.tag = s.substr(5, std::string::npos); r.tag = s.substr(5, std::string::npos);
} else if (s.compare(0, 13, "wmma.matrix_a") == 0) {
r.rank = StorageRank::kWMMAMatrixA;
r.tag = s.substr(13, std::string::npos);
} else if (s.compare(0, 13, "wmma.matrix_b") == 0) {
r.rank = StorageRank::kWMMAMatrixB;
r.tag = s.substr(13, std::string::npos);
} else if (s.compare(0, 16, "wmma.accumulator") == 0) {
r.rank = StorageRank::kWMMAAccumulator;
r.tag = s.substr(16, std::string::npos);
} else { } else {
LOG(FATAL) << "unknown storage scope " << s; LOG(FATAL) << "unknown storage scope " << s;
} }
......
...@@ -40,7 +40,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): ...@@ -40,7 +40,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding):
Returns Returns
------- -------
b_np : np.ndarray b_np : np.ndarray
4-D with shape [out_height, out_width, out_channel, batch] 4-D with shape [batch, out_height, out_width, out_channel]
""" """
batch, in_height, in_width, in_channel = a_np.shape batch, in_height, in_width, in_channel = a_np.shape
kernel_h, kernel_w, _, num_filter = w_np.shape kernel_h, kernel_w, _, num_filter = w_np.shape
......
...@@ -80,6 +80,7 @@ def build_config(debug_flag=0, **kwargs): ...@@ -80,6 +80,7 @@ def build_config(debug_flag=0, **kwargs):
if debug_flag: if debug_flag:
pass_list.append((1, add_debug)) pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin)) pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, tvm.ir_pass.LowerStorageAccessInfo))
pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite)) pass_list.append((3, ir_pass.cpu_access_rewrite))
return tvm.build_config(add_lower_pass=pass_list, **kwargs) return tvm.build_config(add_lower_pass=pass_list, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment