Commit 324a9607 by Siyuan Feng Committed by Leyuan Wang

TensorCore Support using Intrinsic (#4136)

* add tensor core support

* avoid memory bank conflict

* fix thread sync & better performance

* better performance

* add schedule test for conv2d

* extend into BatchMatMul

* support config fragment shape and layout using intrinsic

* add TensorCore tutorial

* add int support and fix lint

* address comment

* add 32*16*8 TensorCore test

* fix wmma include logic
parent 4ab73634
...@@ -1311,6 +1311,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope"; ...@@ -1311,6 +1311,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope";
constexpr const char* device_scope = "device_scope"; constexpr const char* device_scope = "device_scope";
/*! /*!
* \brief Mark that the shape of TensorCore fragment
*/
constexpr const char* fragment_shape = "fragment_shape";
/*!
* \brief Mark that the layout of TensorCore fragment
*/
constexpr const char* fragment_layout = "fragment_layout";
/*!
* \brief Check if attr_key is a pragma key extension * \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared * \param attr_key The attr key to be compared
* \return true if it is a pragma key * \return true if it is a pragma key
...@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; ...@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
* } * }
*/ */
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
/*!
* \brief tvm intrinsic for tensor core load operators.
*
* void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
* Expr index, Expr buffer_ptr, Expr stride,
* StringImm layout) {
* // m, n, k are the shape of wmma fragment.
* // Determine fragment layout(column-major or row major) by layout.
* // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
* nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
* }
*/
constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
/*!
* \brief tvm intrinsic for tensor core mma_sync operators.
*
* void tvm_mma_sync(Var fragment_d, Expr index_d,
* Var fragment_a, Expr index_a,
* Var fragment_b, Expr index_b,
* Var fragment_c, Expr index_c) {
* nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
* fragment_b[index_b], fragment_c[index_c]);
* }
*/
constexpr const char* tvm_mma_sync = "tvm_mma_sync";
/*!
* \brief tvm intrinsic for tensor core fill_fragment operators.
*
* void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
* Expr index, Expr value) {
* // m, n, k are the shape of wmma fragment
* // fragments must be in 'wmma.accumulator' scope.
* nvcuda::wmma::fill_fragment(fragment[index], value);
* }
*/
constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
/*!
* \brief tvm intrinsic for tensor core store operators.
*
* void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
* Expr index, Expr buffer_ptr, Expr stride,
* StringImm layout) {
* // m, n, k are the shape of wmma fragment
* // fragments must be in 'wmma.accumulator' scope.
* nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
* }
*/
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
} // namespace intrinsic } // namespace intrinsic
......
...@@ -514,6 +514,15 @@ LoweredFunc CombineContextCall(LoweredFunc f); ...@@ -514,6 +514,15 @@ LoweredFunc CombineContextCall(LoweredFunc f);
LoweredFunc PointerValueTypeRewrite(LoweredFunc f); LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*! /*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
*
* \param func The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
/*!
* \brief Lower intrinsic function calls. * \brief Lower intrinsic function calls.
* \param f The device function to be lowered. * \param f The device function to be lowered.
* \param target The target device. * \param target The target device.
...@@ -533,6 +542,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); ...@@ -533,6 +542,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
/*! /*!
* \brief Infer the TensorCore fragment infomation using tensor intrinsics
*
* \param f The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc InferFragment(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type. * \brief Verify if memory accesses are legal for a specific target device type.
* *
* In the case that tgt is cuda, if not all workload is bound with * In the case that tgt is cuda, if not all workload is bound with
......
...@@ -413,7 +413,6 @@ def lower(sch, ...@@ -413,7 +413,6 @@ def lower(sch,
# Phase 3 # Phase 3
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt) stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting: if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt) stmt = ir_pass.RewriteUnsafeSelect(stmt)
...@@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host): ...@@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host):
func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp") func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size) func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)] fsplits = [s for s in ir_pass.SplitHostDevice(func)]
...@@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host): ...@@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host):
assert not fdevice assert not fdevice
target_host = _target.create(target_host) target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost]
......
...@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") ...@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
}); });
}); });
TVM_REGISTER_API("ir_pass.LowerStorageAccess")
.set_body([](TVMArgs args, TVMRetValue *ret) {
LoweredFunc f = args[0];
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
*ret = LoweredFunc(n);
});
// make from two arguments // make from two arguments
#define REGISTER_PASS(PassName) \ #define REGISTER_PASS(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \
...@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice); ...@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite); REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync); REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo); REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(LowerDeviceStorageAccessInfo)
REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer); REGISTER_PASS(InjectDoubleBuffer);
...@@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope); ...@@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch, ...@@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch,
// Phase 2 // Phase 2
stmt = ir::Simplify(stmt); stmt = ir::Simplify(stmt);
stmt = ir::LowerStorageAccessInfo(stmt);
stmt = ir::RemoveNoOp(stmt); stmt = ir::RemoveNoOp(stmt);
if (!(config->disable_select_rewriting)) if (!(config->disable_select_rewriting))
...@@ -517,6 +516,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -517,6 +516,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = ir::BindDeviceType(func, target->device_type); func = ir::BindDeviceType(func, target->device_type);
func = ir::LowerDeviceStorageAccessInfo(func);
func = ir::LowerTVMBuiltin(func); func = ir::LowerTVMBuiltin(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
...@@ -524,6 +524,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -524,6 +524,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = ir::LowerIntrin(func, target_host->target_name); func = ir::LowerIntrin(func, target_host->target_name);
func = ir::LowerDeviceStorageAccessInfo(func);
func = ir::CombineContextCall(func); func = ir::CombineContextCall(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
......
...@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() { ...@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <math_constants.h>\n"; decl_stream << "#include <math_constants.h>\n";
} }
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
...@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) ...@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
bool fail = false; bool fail = false;
if (t.is_float()) { if (t.is_float()) {
switch (t.bits()) { switch (t.bits()) {
case 16: os << "half"; case 16:
enable_fp16_ = true; enable_fp16_ = true;
if (lanes == 1) {
os << "half";
} else if (lanes <= 8) {
CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "float" << lanes / 2;
} else {
fail = true;
}
break; break;
case 32: os << "float"; break; case 32: os << "float"; break;
case 64: os << "double"; break; case 64: os << "double"; break;
default: fail = true; break; default: fail = true; break;
} }
if (!fail && lanes == 1) return; if (!fail && (lanes == 1 || t.bits() == 16)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) { if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return; os << lanes; return;
} }
...@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope( ...@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope(
} }
} }
void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
os << "nvcuda::wmma::fill_fragment(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::load_matrix_sync(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[6], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::store_matrix_sync(";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImm *str = op->args[7].as<StringImm>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
}
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::mma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", ": ")");
}
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::fragment_shape) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* shape_str = op->value.as<StringImm>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* layout_str = op->value.as<StringImm>();
fragment_layouts[buffer] = layout_str->value;
}
CodeGenC::VisitStmt_(op);
}
void CodeGenCUDA::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
PrintType(op->type, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
const Variable* buffer = op->buffer_var.as<Variable>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8))
<< "Matrix_a and matrix_b only support half or char or unsigned char type for now";
} else {
CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32))
<< "Accumulator only support half, float and int type for now";
}
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->type, buffer, stream);
} else {
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->type, stream);
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->type);
this->PrintStmt(op->body);
}
void CodeGenCUDA::VisitStmt_(const Evaluate *op) { void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return; if (is_const(op->value)) return;
const Call* call = op->value.as<Call>(); const Call* call = op->value.as<Call>();
...@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* ...@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
PrintConst(op, os, this); PrintConst(op, os, this);
} }
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t,
const Variable* variable, std::ostream &os) {
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes[variable];
if (scope == "wmma.matrix_a") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, "
<< shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
} else if (scope == "wmma.matrix_b") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, "
<< shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
} else if (scope == "wmma.accumulator") {
need_mma_h_ = true;
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
<< shape_str << ", "<< type.str() << ">";
}
}
int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
const Variable* variable, int32_t size) {
std::string shape_str = fragment_shapes[variable];
size_t m, n, k;
size_t last_pos = 0, pos = 0;
pos = shape_str.find(", ", last_pos);
m = std::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
pos = shape_str.find(", ", last_pos);
n = std::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
if (scope == "wmma.matrix_a") {
return size / m / k;
} else if (scope == "wmma.matrix_b") {
return size / n / k;
} else if (scope == "wmma.accumulator") {
return size / m / n;
}
return 0;
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <string> #include <string>
#include <unordered_map>
#include "codegen_c.h" #include "codegen_c.h"
namespace tvm { namespace tvm {
...@@ -40,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -40,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC {
void AddFunction(LoweredFunc f); void AddFunction(LoweredFunc f);
std::string Finish(); std::string Finish();
bool need_include_path() { bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_); return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
} }
// override behavior // override behavior
void VisitStmt_(const ir::For* op) final; void VisitStmt_(const ir::For* op) final;
...@@ -60,7 +61,10 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -60,7 +61,10 @@ class CodeGenCUDA final : public CodeGenC {
void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final; void VisitExpr_(const FloatImm *op, std::ostream& os) final;
void VisitExpr_(const Call *op, std::ostream& os) final;
void VisitStmt_(const Evaluate *op) final; void VisitStmt_(const Evaluate *op) final;
void VisitStmt_(const Allocate *op) final;
void VisitStmt_(const AttrStmt *op) final;
private: private:
// Whether global barrier is needed. // Whether global barrier is needed.
...@@ -75,7 +79,14 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -75,7 +79,14 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_int8_{false}; bool enable_int8_{false};
// whether need math_constants.h // whether need math_constants.h
bool need_math_constants_h_{false}; bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
std::unordered_map<const Variable*, std::string> fragment_shapes;
std::unordered_map<const Variable*, std::string> fragment_layouts;
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os);
int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size);
}; };
} // namespace codegen } // namespace codegen
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \brief Infer TensorCore metadata from tensor intrinsic.
* \file tensorcore_fragment.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
#include "storage_access.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
// Get fragment information from tensor intrinsics
class FragmentGetter : public IRVisitor {
public:
// fragment metadata
struct FragmentInfo {
// fragment shape
int m, n, k;
// fragment layout (row-major or column-major)
std::string layout;
FragmentInfo() = default;
FragmentInfo(int _m, int _n, int _k, const std::string& _layout)
: m(_m), n(_n), k(_k), layout(_layout) {}
};
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
// Get shape and layout information from load and store intrinsic
CHECK_EQ(op->args.size(), 8U);
const Variable* buffer_var = op->args[0].as<Variable>();
CHECK(buffer_var);
// Get shape
const IntImm* m = op->args[1].as<IntImm>();
const IntImm* n = op->args[2].as<IntImm>();
const IntImm* k = op->args[3].as<IntImm>();
const StringImm* layout = op->args[7].as<StringImm>();
CHECK(m);
CHECK(n);
CHECK(k);
CHECK(layout);
std::string scope = scopes[buffer_var];
if (fragments.count(buffer_var)) {
// check if the fragment has met before
FragmentInfo info = fragments[buffer_var];
CHECK_EQ(m->value, info.m);
CHECK_EQ(n->value, info.n);
CHECK_EQ(k->value, info.k);
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK_EQ(layout->value, info.layout);
}
} else {
// store metadata
FragmentInfo info;
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
info = FragmentInfo(m->value, n->value, k->value, layout->value);
} else if (scope == "wmma.accumulator") {
info = FragmentInfo(m->value, n->value, k->value, "");
}
fragments[buffer_var] = info;
}
} else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
// Get shape information from fill intrinsic
CHECK_EQ(op->args.size(), 6U);
const Variable* buffer_var = op->args[0].as<Variable>();
CHECK(buffer_var);
// Get shape
const IntImm* m = op->args[1].as<IntImm>();
const IntImm* n = op->args[2].as<IntImm>();
const IntImm* k = op->args[3].as<IntImm>();
CHECK(m);
CHECK(n);
CHECK(k);
std::string scope = scopes[buffer_var];
// Only wmma.accumulator can use tvm_fill_fragment
CHECK_EQ(scope, "wmma.accumulator");
if (fragments.count(buffer_var)) {
FragmentInfo info = fragments[buffer_var];
CHECK_EQ(m->value, info.m);
CHECK_EQ(n->value, info.n);
CHECK_EQ(k->value, info.k);
} else {
FragmentInfo info(m->value, n->value, k->value, "");
fragments[buffer_var] = info;
}
}
}
// Get memory scope
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buffer = op->node.as<Variable>();
CHECK(buffer);
scopes[buffer] = op->value.as<StringImm>()->value;
}
IRVisitor::Visit_(op);
}
// Memory scope for allocations
std::unordered_map<const Variable*, std::string> scopes;
// Fragment metadata for all fragments
std::unordered_map<const Variable*, FragmentInfo> fragments;
};
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
class FragmentChecker : public IRVisitor {
public:
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
void Visit_(const Call* op) final {
// Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
CHECK_EQ(op->args.size(), 8U);
const Variable* buffer_var_d = op->args[0].as<Variable>();
const Variable* buffer_var_a = op->args[2].as<Variable>();
const Variable* buffer_var_b = op->args[4].as<Variable>();
const Variable* buffer_var_c = op->args[6].as<Variable>();
CHECK(buffer_var_d);
CHECK(buffer_var_a);
CHECK(buffer_var_b);
CHECK(buffer_var_c);
// Check all fragment A, B, C and D have the same shape
CHECK(CheckShape(buffer_var_d, buffer_var_a));
CHECK(CheckShape(buffer_var_d, buffer_var_b));
CHECK(CheckShape(buffer_var_d, buffer_var_c));
}
}
private:
// A tool for checking shapes of two fragments
bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
CHECK(fragment_getter.fragments.count(buffer1));
CHECK(fragment_getter.fragments.count(buffer2));
FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
}
// Fragment infomation
const FragmentGetter &fragment_getter;
};
// Store the metadata into attributes
class InferFragmenter : public IRMutator {
public:
explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
const Variable* buffer = op->buffer_var.get();
if (fragment_getter.fragments.count(buffer)) {
// Add attribute to fragments allocation
FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer);
// Add shape attribute to all fragments
std::string shape = std::to_string(info.m) + ", " +
std::to_string(info.n) + ", " +
std::to_string(info.k);
Expr shape_expr = StringImm::make(shape);
Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
if (info.layout != "") {
// Add shape attribute to matrix_a and matrix_b
Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout,
StringImm::make(info.layout), shape_attr);
return layout_attr;
} else {
return shape_attr;
}
}
return stmt;
}
private:
// Fragment infomation
const FragmentGetter &fragment_getter;
};
Stmt InferFragment(Stmt stmt) {
FragmentGetter getter;
getter.Visit(stmt);
FragmentChecker(getter).Visit(stmt);
stmt = InferFragmenter(getter).Mutate(stmt);
return stmt;
}
LoweredFunc InferFragment(LoweredFunc f) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = InferFragment(f->body);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
...@@ -341,5 +341,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { ...@@ -341,5 +341,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower().Mutate(stmt); return StorageAccessInfoLower().Mutate(stmt);
} }
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
return LoweredFunc(n);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -263,6 +263,28 @@ class ThreadSyncInserter : public IRMutator { ...@@ -263,6 +263,28 @@ class ThreadSyncInserter : public IRMutator {
} }
} }
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
const Variable* buffer_var = op->args[1].as<Variable>();
Var var(GetRef<Var>(buffer_var));
const IntImm* flag = op->args[4].as<IntImm>();
if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[var].read_count;
}
if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[var].write_count;
}
return expr;
} else {
return IRMutator::Mutate_(op, e);
}
}
private: private:
// RW statistics about data // RW statistics about data
struct Entry { struct Entry {
......
...@@ -50,7 +50,13 @@ enum class StorageRank { ...@@ -50,7 +50,13 @@ enum class StorageRank {
*/ */
kWarp = 2, kWarp = 2,
/*! \brief thread local memory */ /*! \brief thread local memory */
kLocal = 3 kLocal = 3,
/*! \brief wmma scope memory of matrix_a */
kWMMAMatrixA = 4,
/*! \brief wmma scope memory of matrix_b */
kWMMAMatrixB = 5,
/*! \brief wmma scope memory of accumulator */
kWMMAAccumulator = 6,
}; };
/*! /*!
...@@ -89,6 +95,9 @@ struct StorageScope { ...@@ -89,6 +95,9 @@ struct StorageScope {
case StorageRank::kShared: return "shared" + tag; case StorageRank::kShared: return "shared" + tag;
case StorageRank::kWarp: return "warp" + tag; case StorageRank::kWarp: return "warp" + tag;
case StorageRank::kLocal: return "local" + tag; case StorageRank::kLocal: return "local" + tag;
case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag;
case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag;
case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag;
default: LOG(FATAL) << "unknown storage scope"; return ""; default: LOG(FATAL) << "unknown storage scope"; return "";
} }
} }
...@@ -111,6 +120,15 @@ struct StorageScope { ...@@ -111,6 +120,15 @@ struct StorageScope {
} else if (s.compare(0, 5, "local") == 0) { } else if (s.compare(0, 5, "local") == 0) {
r.rank = StorageRank::kLocal; r.rank = StorageRank::kLocal;
r.tag = s.substr(5, std::string::npos); r.tag = s.substr(5, std::string::npos);
} else if (s.compare(0, 13, "wmma.matrix_a") == 0) {
r.rank = StorageRank::kWMMAMatrixA;
r.tag = s.substr(13, std::string::npos);
} else if (s.compare(0, 13, "wmma.matrix_b") == 0) {
r.rank = StorageRank::kWMMAMatrixB;
r.tag = s.substr(13, std::string::npos);
} else if (s.compare(0, 16, "wmma.accumulator") == 0) {
r.rank = StorageRank::kWMMAAccumulator;
r.tag = s.substr(16, std::string::npos);
} else { } else {
LOG(FATAL) << "unknown storage scope " << s; LOG(FATAL) << "unknown storage scope " << s;
} }
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# 'License'); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
from topi.testing import conv2d_nhwc_python
from tvm.contrib import nvcc
VERIFY = True
def intrin_wmma_load_matrix(shape, scope):
n, m, l = shape
if scope == "wmma.matrix_a":
row, col = n, l
elif scope == "wmma.matrix_b":
row, col = l, m
A = tvm.placeholder((row, col), name='A', dtype='float16')
BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=row * col)
C = tvm.compute((row, col), lambda i, j: A[i, j], name='C')
BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=row * col)
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
BA = ins[0]
BC = outs[0]
ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
BC.data, n, m, l, BC.elem_offset // (row * col),
BA.access_ptr('r'), col, 'row_major'))
return ib.get()
return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_gemm(shape):
n, m, l = shape
A = tvm.placeholder((n, l), name='A', dtype='float16')
B = tvm.placeholder((l, m), name='B', dtype='float16')
k = tvm.reduce_axis((0, l), name="k")
C = tvm.compute((n, m),
lambda ii, jj:
tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
name='C')
BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=n * l)
BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=l * m)
BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=n * m)
def intrin_func(ins, outs):
BA, BB = ins
BC, = outs
def init():
ib = tvm.ir_builder.create()
ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0))
return ib.get()
def update():
ib = tvm.ir_builder.create()
ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
BC.data, BC.elem_offset // (n * m),
BA.data, BA.elem_offset // (n * l),
BB.data, BB.elem_offset // (l * m),
BC.data, BC.elem_offset // (n * m)))
return ib.get()
return update(), init(), update()
return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
def intrin_wmma_store_matrix(shape):
n, m, l = shape
A = tvm.placeholder((n, m), name='A', dtype='float32')
BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=n * m)
C = tvm.compute((n, m), lambda i, j: A[i, j], name='C')
BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=n * m)
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
BA = ins[0]
BC = outs[0]
ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
BA.data, n, m, l, BA.elem_offset // (n * m),
BC.access_ptr('w'), m, 'row_major'))
return ib.get()
return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def test_tensor_core_batch_matmal():
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
print("skip because gpu does not support tensor core")
return
batch_size = 4
n = 512
m, l = n, n
assert (n % 32 == 0)
assert (m % 8 == 0)
assert (l % 16 == 0)
nn, mm, ll = n // 32, m // 8, l // 16
A = tvm.placeholder((batch_size, nn, ll, 32, 16), name='A', dtype='float16')
B = tvm.placeholder((batch_size, ll, mm, 16, 8), name='B', dtype='float16')
k1 = tvm.reduce_axis((0, ll), name='k1')
k2 = tvm.reduce_axis((0, 16), name='k2')
C = tvm.compute((batch_size, nn, mm, 32, 8),
lambda b, i, j, ii, jj:
tvm.sum(A[b, i, k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]),
name='Fragment_C')
s = tvm.create_schedule(C.op)
warp_size = 32
kernel_size = 16
block_row_warps = 2
block_col_warps = 4
warp_row_tiles = 4
warp_col_tiles = 2
chunk = 4
block_x = tvm.thread_axis('blockIdx.x')
block_y = tvm.thread_axis('blockIdx.y')
block_z = tvm.thread_axis('blockIdx.z')
thread_x = tvm.thread_axis('threadIdx.x')
thread_y = tvm.thread_axis('threadIdx.y')
thread_z = tvm.thread_axis('threadIdx.z')
AS = s.cache_read(A, 'shared', [C])
BS = s.cache_read(B, 'shared', [C])
AF = s.cache_read(AS, 'wmma.matrix_a', [C])
BF = s.cache_read(BS, 'wmma.matrix_b', [C])
CF = s.cache_write(C, 'wmma.accumulator')
b, i, j, kernel_i, kernel_j = s[C].op.axis
i, ii = s[C].split(i, factor=warp_row_tiles)
block_i, i = s[C].split(i, factor=block_row_warps)
j, jj = s[C].split(j, factor=warp_col_tiles)
block_j, j = s[C].split(j, factor=block_col_warps)
s[C].reorder(block_i, block_j, i, j, ii, jj, kernel_i, kernel_j)
s[C].bind(b, block_z)
s[C].bind(block_i, block_x)
s[C].bind(block_j, block_y)
s[C].bind(i, thread_y)
s[C].bind(j, thread_z)
s[CF].compute_at(s[C], j)
b, warp_i, warp_j, _i, _j = s[CF].op.axis
k, _k = CF.op.reduce_axis
ko, ki = s[CF].split(k, factor=chunk)
s[CF].reorder(ko, ki, warp_i, warp_j, _i, _j, _k)
s[AF].compute_at(s[CF], ki)
s[BF].compute_at(s[CF], ki)
s[AS].compute_at(s[CF], ko)
b, xo, yo, xi, yi = AS.op.axis
tx, xo = s[AS].split(xo, nparts=block_row_warps)
ty, yo = s[AS].split(yo, nparts=block_col_warps)
t = s[AS].fuse(xi, yi)
to, ti = s[AS].split(t, nparts=warp_size)
s[AS].bind(tx, thread_y)
s[AS].bind(ty, thread_z)
s[AS].bind(to, thread_x)
s[BS].compute_at(s[CF], ko)
b, xo, yo, xi, yi = BS.op.axis
tx, xo = s[BS].split(xo, nparts=block_row_warps)
ty, yo = s[BS].split(yo, nparts=block_col_warps)
t = s[BS].fuse(xi, yi)
to, ti = s[BS].split(t, nparts=warp_size)
s[BS].bind(tx, thread_y)
s[BS].bind(ty, thread_z)
s[BS].bind(to, thread_x)
s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_a'))
s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_b'))
s[C].tensorize(kernel_i, intrin_wmma_store_matrix((32, 8, 16)))
s[CF].tensorize(_i, intrin_wmma_gemm((32, 8, 16)))
func = tvm.build(s, [A, B, C], 'cuda')
ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(batch_size, nn, ll, 32, 16)).astype(A.dtype)
b_np = np.random.uniform(size=(batch_size, ll, mm, 16, 8)).astype(B.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((batch_size, nn, mm, 32, 8), dtype=C.dtype), ctx)
func(a, b, c)
evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3))
if VERIFY:
func(a, b, c)
a_np = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
b_np = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
c_np = c.asnumpy().transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4)
def test_tensor_core_batch_conv():
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
print("skip because gpu does not support tensor core")
return
# The sizes of inputs and filters
batch_size = 32
height = 14
width = 14
in_channels = 32
out_channels = 64
kernel_h = 3
kernel_w = 3
pad_h = 1
pad_w = 1
stride_h = 1
stride_w = 1
block_size = 16
block_row_warps = 2
block_col_warps = 4
warp_row_tiles = 4
warp_col_tiles = 2
warp_size = 32
chunk = 2
# Input feature map: (N, H, W, IC, n, ic)
data_shape = (batch_size // block_size,
height,
width,
in_channels // block_size,
block_size,
block_size)
# Kernel: (H, W, IC, OC, ic, oc)
kernel_shape = (kernel_h,
kernel_w,
in_channels // block_size,
out_channels // block_size,
block_size,
block_size)
# Output feature map: (N, H, W, OC, n, oc)
output_shape = (batch_size // block_size,
height,
width,
out_channels // block_size,
block_size,
block_size)
assert (batch_size % block_size == 0)
assert (in_channels % block_size == 0)
assert (out_channels % block_size == 0)
kh = tvm.reduce_axis((0, kernel_h), name='kh')
kw = tvm.reduce_axis((0, kernel_w), name='kw')
ic = tvm.reduce_axis((0, in_channels // block_size), name='ic')
ii = tvm.reduce_axis((0, block_size), name='ii')
# Algorithm
A = tvm.placeholder(data_shape, name='A', dtype="float16")
W = tvm.placeholder(kernel_shape, name='W', dtype="float16")
Apad = tvm.compute(
(batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
block_size),
lambda n, h, w, i, nn, ii: tvm.if_then_else(
tvm.all(h >= pad_h, h - pad_h < height,
w >= pad_w, w - pad_w < width),
A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")),
name='Apad')
Conv = tvm.compute(output_shape,
lambda n, h, w, o, nn, oo: tvm.sum(
Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
W[kh, kw, ic, o, ii, oo].astype("float32"),
axis=[ic, kh, kw, ii]),
name="Conv")
s = tvm.create_schedule(Conv.op)
s[Apad].compute_inline()
AS = s.cache_read(Apad, 'shared', [Conv])
WS = s.cache_read(W, 'shared', [Conv])
AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
ConvF = s.cache_write(Conv, 'wmma.accumulator')
block_x = tvm.thread_axis('blockIdx.x')
block_y = tvm.thread_axis('blockIdx.y')
block_z = tvm.thread_axis('blockIdx.z')
thread_x = tvm.thread_axis('threadIdx.x')
thread_y = tvm.thread_axis('threadIdx.y')
thread_z = tvm.thread_axis('threadIdx.z')
nc, hc, wc, oc, nnc, ooc = Conv.op.axis
block_k = s[Conv].fuse(hc, wc)
s[Conv].bind(block_k, block_z)
nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
block_i, nc = s[Conv].split(nc, factor=block_row_warps)
oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
block_j, oc = s[Conv].split(oc, factor=block_col_warps)
s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
s[Conv].bind(block_i, block_x)
s[Conv].bind(block_j, block_y)
s[Conv].bind(nc, thread_y)
s[Conv].bind(oc, thread_z)
s[ConvF].compute_at(s[Conv], oc)
n, h, w, o, nnf, oof = ConvF.op.axis
ko, ki = s[ConvF].split(ic, factor=chunk)
s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)
s[AF].compute_at(s[ConvF], kw)
s[WF].compute_at(s[ConvF], kw)
s[WS].compute_at(s[ConvF], kh)
s[AS].compute_at(s[ConvF], kh)
n, h, w, i, nn, ii = AS.op.axis
tx, xo = s[AS].split(n, nparts=block_row_warps)
ty, yo = s[AS].split(xo, nparts=block_col_warps)
t = s[AS].fuse(nn, ii)
to, ti = s[AS].split(t, factor=warp_size)
s[AS].bind(tx, thread_y)
s[AS].bind(ty, thread_z)
s[AS].bind(ti, thread_x)
kh, kw, ic, o, ii, oo = WS.op.axis
tx, xo = s[WS].split(o, nparts=block_row_warps)
ty, yo = s[WS].split(xo, nparts=block_col_warps)
t = s[WS].fuse(ii, oo)
to, ti = s[WS].split(t, nparts=warp_size)
s[WS].bind(tx, thread_y)
s[WS].bind(ty, thread_z)
s[WS].bind(to, thread_x)
s[WS].vectorize(ti)
s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_a'))
s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_b'))
s[Conv].tensorize(nnc, intrin_wmma_store_matrix((16, 16, 16)))
s[ConvF].tensorize(nnf, intrin_wmma_gemm((16, 16, 16)))
func = tvm.build(s, [A, W, Conv], 'cuda')
ctx = tvm.gpu(0)
a_np = np.random.uniform(size=data_shape).astype(A.dtype)
w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
if VERIFY:
func(a, w, c)
a_np = a_np.transpose(0, 4, 1, 2, 3, 5).reshape(batch_size, height, width, in_channels)
w_np = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(kernel_h, kernel_w, in_channels, out_channels)
c_np = c.asnumpy().transpose((0, 4, 1, 2, 3, 5)).reshape(batch_size, height, width, out_channels)
c_std = conv2d_nhwc_python(a_np.astype(Conv.dtype),
w_np.astype(Conv.dtype),
(stride_h, stride_w),
(pad_h, pad_w)).astype(Conv.dtype)
np.testing.assert_allclose(c_np, c_std, rtol=1e-4, atol=1e-4)
if __name__ == '__main__':
test_tensor_core_batch_matmal()
test_tensor_core_batch_conv()
...@@ -40,7 +40,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): ...@@ -40,7 +40,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding):
Returns Returns
------- -------
b_np : np.ndarray b_np : np.ndarray
4-D with shape [out_height, out_width, out_channel, batch] 4-D with shape [batch, out_height, out_width, out_channel]
""" """
batch, in_height, in_width, in_channel = a_np.shape batch, in_height, in_width, in_channel = a_np.shape
kernel_h, kernel_w, _, num_filter = w_np.shape kernel_h, kernel_w, _, num_filter = w_np.shape
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _opt-conv-tensorcore:
How to optimize convolution using TensorCores
==================================
**Author**: `Siyuan Feng <https://github.com/Hzfengsy>`_
In this tutorial, we will demonstrate how to write a high performance convolution
schedule using TensorCores in TVM. In this example, we assume the input to
convolution has a large batch. We strongly recommend covering the :ref:`opt-conv-gpu` tutorial first.
"""
################################################################
# TensorCore Introduction
# -------------------------
# Each Tensor Core provides a 4x4x4 matrix processing array that operates
# :code:`D = A * B + C`, where A, B, C and D are 4x4 matrices as Figure shows.
# The matrix multiplication inputs A and B are FP16 matrices, while the accumulation
# matrices C and D may be FP16 or FP32 matrices.
#
# However, CUDA programmers can only use warp-level primitive
# :code:`wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` to perform
# 16x16x16 half-precision matrix multiplication on tensor cores. Before invoking
# the matrix multiplication, programmers must load data from memory into registers
# with primitive :code:`wmma::load_matrix_sync`, explicitly. The NVCC compiler translates
# that primitive into multiple memory load instructions. At run time, every thread loads
# 16 elements from matrix A and 16 elements from B.
################################################################
# Preparation and Algorithm
# --------------------------
# We use the fixed size for input tensors with 256 channels and 14 x 14 dimensions.
# The batch size is 256. Convolution filters contain 512 filters of size 3 x 3.
# We use stride size 1 and padding size 1 for the convolution. In the example, we use
# NHWCnc memory layout.The following code defines the convolution algorithm in TVM.
import tvm
import numpy as np
from tvm.contrib import nvcc
# The sizes of inputs and filters
batch_size = 256
height = 14
width = 14
in_channels = 256
out_channels = 512
kernel_h = 3
kernel_w = 3
pad_h = 1
pad_w = 1
stride_h = 1
stride_w = 1
# TensorCore shape
block_size = 16
assert (batch_size % block_size == 0)
assert (in_channels % block_size == 0)
assert (out_channels % block_size == 0)
# Input feature map: (N, H, W, IC, n, ic)
data_shape = (batch_size // block_size,
height,
width,
in_channels // block_size,
block_size,
block_size)
# Kernel: (H, W, IC, OC, ic, oc)
kernel_shape = (kernel_h,
kernel_w,
in_channels // block_size,
out_channels // block_size,
block_size,
block_size)
# Output feature map: (N, H, W, OC, n, oc)
output_shape = (batch_size // block_size,
height,
width,
out_channels // block_size,
block_size,
block_size)
# Reduction axes
kh = tvm.reduce_axis((0, kernel_h), name='kh')
kw = tvm.reduce_axis((0, kernel_w), name='kw')
ic = tvm.reduce_axis((0, in_channels // block_size), name='ic')
ii = tvm.reduce_axis((0, block_size), name='ii')
# Algorithm
A = tvm.placeholder(data_shape, name='A', dtype="float16")
W = tvm.placeholder(kernel_shape, name='W', dtype="float16")
Apad = tvm.compute(
(batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
block_size),
lambda n, h, w, i, nn, ii: tvm.if_then_else(
tvm.all(h >= pad_h, h - pad_h < height,
w >= pad_w, w - pad_w < width),
A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")),
name='Apad')
Conv = tvm.compute(output_shape,
lambda n, h, w, o, nn, oo: tvm.sum(
Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
W[kh, kw, ic, o, ii, oo].astype("float32"),
axis=[ic, kh, kw, ii]),
name="Conv")
s = tvm.create_schedule(Conv.op)
s[Apad].compute_inline()
###############################################################################
# Memory Scope
# ----------------
#
# In traditional GPU schedule, we have global, shared and local memory scope.
# To support TensorCores, we add another three special memory scope: :code:`wmma.matrix_a`,
# :code:`wmma.matrix_b` and :code:`wmma.accumulator`. On hardware, all fragments scope
# stores at the on-chip registers level, the same place with local memory.
# Designate the memory hierarchy
AS = s.cache_read(Apad, 'shared', [Conv])
WS = s.cache_read(W, 'shared', [Conv])
AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
ConvF = s.cache_write(Conv, 'wmma.accumulator')
###############################################################################
# Define Tensor Intrinsic
# In fact, TensorCore is a special hardware operation. So, we can just use tensorize
# to replace a unit of computation with the TensorCore instruction. The first thing is
# that we need to define tensor intrinsic.
#
# There are four basic operation in TensorCore: :code:`fill_fragment`, :code:`load_matrix`,
# :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync`
# are both used in matrix multiplication, so we can just write following three intrinsics.
def intrin_wmma_load_matrix(scope):
n = 16
A = tvm.placeholder((n, n), name='A', dtype='float16')
BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
BA = ins[0]
BC = outs[0]
ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
BC.data, n, n, n, BC.elem_offset // 256,
BA.access_ptr('r'), n, 'row_major'))
return ib.get()
return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_gemm():
n = 16
A = tvm.placeholder((n, n), name='A', dtype='float16')
B = tvm.placeholder((n, n), name='B', dtype='float16')
k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((n, n),
lambda ii, jj:
tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
name='C')
BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)
def intrin_func(ins, outs):
BA, BB = ins
BC, = outs
def init():
ib = tvm.ir_builder.create()
ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
return ib.get()
def update():
ib = tvm.ir_builder.create()
ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
BC.data, BC.elem_offset // 256,
BA.data, BA.elem_offset // 256,
BB.data, BB.elem_offset // 256,
BC.data, BC.elem_offset // 256))
return ib.get()
return update(), init(), update()
return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
def intrin_wmma_store_matrix():
n = 16
A = tvm.placeholder((n, n), name='A', dtype='float32')
BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)
def intrin_func(ins, outs):
ib = tvm.ir_builder.create()
BA = ins[0]
BC = outs[0]
ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
BA.data, n, n, n, BA.elem_offset // 256,
BC.access_ptr('w'), n, 'row_major'))
return ib.get()
return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
###############################################################################
# Scheduling the Computation
# --------------------------
# To use TensorCores in TVM, we must schedule the computation into specific structure
# to match the tensor intrinsic. The same as traditional GPU programs, we can also use
# shared memory to boost the speed. If you have any questions about blocking and shared
# memory, please refer :ref:`opt-conv-gpu`.
#
# In this example, each block contains 2x4 warps, and each warp calls 4x2 TensorCore
# instructions. Thus, the output shape of each warp is 64x32 and each block outputs
# 128x128 titles. Due to the limit of shared memory space, we only load 2 blocks (2x128x128 tiles)
# one time.
#
# .. note::
#
# *Warp-level Operation*
#
# Note that all TensorCore instructions are warp-level instructions, which means all 32 threads
# in a warp should do this instruction simultaneously. Making theadIdx.x extent=32 is one of the
# easiest way to solve this. Then We can bind threadIdx.x to any loops except those contain
# TensorCore intrinsics directly or indirectly. Also note that it is not the unique solution.
# The only thing we should do is to make sure all threads in a warp can call TensorCore at the same time.
#
# Define tiling sizes
block_row_warps = 4
block_col_warps = 2
warp_row_tiles = 2
warp_col_tiles = 4
warp_size = 32
chunk = 2
block_x = tvm.thread_axis('blockIdx.x')
block_y = tvm.thread_axis('blockIdx.y')
block_z = tvm.thread_axis('blockIdx.z')
thread_x = tvm.thread_axis('threadIdx.x')
thread_y = tvm.thread_axis('threadIdx.y')
thread_z = tvm.thread_axis('threadIdx.z')
nc, hc, wc, oc, nnc, ooc = Conv.op.axis
block_k = s[Conv].fuse(hc, wc)
s[Conv].bind(block_k, block_z)
nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
block_i, nc = s[Conv].split(nc, factor=block_row_warps)
oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
block_j, oc = s[Conv].split(oc, factor=block_col_warps)
s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
s[Conv].bind(block_i, block_x)
s[Conv].bind(block_j, block_y)
s[Conv].bind(nc, thread_y)
s[Conv].bind(oc, thread_z)
# Schedule local computation
s[ConvF].compute_at(s[Conv], oc)
n, h, w, o, nnf, oof = ConvF.op.axis
ko, ki = s[ConvF].split(ic, factor=chunk)
s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)
# Move intermediate computation into each output compute tile
s[AF].compute_at(s[ConvF], kw)
s[WF].compute_at(s[ConvF], kw)
# Schedule for A's share memory
s[AS].compute_at(s[ConvF], kh)
n, h, w, i, nn, ii = AS.op.axis
tx, xo = s[AS].split(n, nparts=block_row_warps)
ty, yo = s[AS].split(xo, nparts=block_col_warps)
t = s[AS].fuse(nn, ii)
to, ti = s[AS].split(t, factor=warp_size)
s[AS].bind(tx, thread_y)
s[AS].bind(ty, thread_z)
s[AS].bind(ti, thread_x)
# Schedule for W's share memory
s[WS].compute_at(s[ConvF], kh)
kh, kw, ic, o, ii, oo = WS.op.axis
tx, xo = s[WS].split(o, nparts=block_row_warps)
ty, yo = s[WS].split(xo, nparts=block_col_warps)
t = s[WS].fuse(ii, oo)
to, ti = s[WS].split(t, nparts=warp_size)
s[WS].bind(tx, thread_y)
s[WS].bind(ty, thread_z)
s[WS].bind(to, thread_x)
s[WS].vectorize(ti)
print(tvm.lower(s, [A, W, Conv], simple_mode=True))
###############################################################################
# Lowering Computation to Intrinsics
# --------------------------
# The last phase is to lower the computation loops down to TensorCore hardware intrinsics
# by mapping the 2D convolution to tensor intrinsics
#
s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a'))
s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
s[Conv].tensorize(nnc, intrin_wmma_store_matrix())
s[ConvF].tensorize(nnf, intrin_wmma_gemm())
print(tvm.lower(s, [A, W, Conv], simple_mode=True))
###############################################################################
# Generate CUDA Kernel
# --------------------
# Finally we use TVM to generate and compile the CUDA kernel, and evaluate the latency of convolution.
# Since TensorCores are only supported in NVIDIA GPU with Compute Capability 7.0 or higher, it may not
# be able to run on our build server
ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
with tvm.build_config(auto_unroll_max_step=16):
func = tvm.build(s, [A, W, Conv], 'cuda')
a_np = np.random.uniform(size=data_shape).astype(A.dtype)
w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
###############################################################################
# Summary
# This tutorial demonstrates how TVM scheduling primitives can be used to
# call TensorCores on specific GPUs.
...@@ -80,6 +80,7 @@ def build_config(debug_flag=0, **kwargs): ...@@ -80,6 +80,7 @@ def build_config(debug_flag=0, **kwargs):
if debug_flag: if debug_flag:
pass_list.append((1, add_debug)) pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin)) pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, tvm.ir_pass.LowerStorageAccessInfo))
pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite)) pass_list.append((3, ir_pass.cpu_access_rewrite))
return tvm.build_config(add_lower_pass=pass_list, **kwargs) return tvm.build_config(add_lower_pass=pass_list, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment