Commit d64bf6b5 by Minmin Sun (孙敏敏) Committed by Tianqi Chen

Auto TensorCore CodeGen (#4234)

* Add Auto TensorCore Unit Test

* Rebase to tvm master branch & Add auto tensor core

* Code Refine

* Add tensor core switch by pragma

* Add pragma in tensor core example code

* Get real tile size to replace hard coded 16

* support more than 2 dimensions (e.g. batch matmul) for buffer bind scope

* support batch matmul

* Move cuda env check to tensor_core.cc

* Coderefine for tensor_core.cc

* Refine comments

* Some refinements of code and comment

* Update TensorCore UT to pass the CPU test

* remove redundant code

* matmul's storage align for different layout

* Add support for different position of type cast

* Add formal tutorial for auto tensorcore codegen

* move tensorcore check up to tutorial code

* code and doc refine

* comment out tune_and_evaluate in tutorial

* fix cpplint error
parent 281f643c
......@@ -1248,6 +1248,8 @@ constexpr const char* reduce_scope = "reduce_scope";
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
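// Note (illustrative): from the Python schedule API this attribute is attached as, e.g.,
//   s[CL].pragma(ko, 'tensor_core')
// (see the unit test added in this commit); the RewriteForTensorCore pass checks for
// this attribute before attempting any rewrite.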
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
......
......@@ -206,6 +206,20 @@ Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);
/*!
* \brief Try to modify the AST to support TensorCore
*
* \param stmt The stmt to be transformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifying external
* buffer assignment of inputs and outputs.
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer);
/*!
* \brief Verify if there is any argument bound to compact buffer.
*
......
......@@ -387,6 +387,7 @@ def lower(sch,
binds, arg_list = get_binds(args, compact, binds)
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
......
......@@ -94,6 +94,12 @@ TVM_REGISTER_API("ir_pass.StorageFlatten")
}
});
TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
.set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});
TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqual()(lhs, rhs);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tensor_core.cc
*/
// IR Passes for TensorCore CodeGen
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/target_info.h>
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;
struct Tile {
int m{-1};
int n{-1};
int k{-1};
};
std::string simplify_name(std::string input) {
auto pos = input.find(".");
if (pos != std::string::npos) {
return input.substr(0, pos);
} else {
return input;
}
}
Expr unpack_type_cast(const Expr &input, const Type &target_type) {
auto cast = input.as<Cast>();
if (cast == nullptr) {
return input;
} else if (cast->type == target_type) {
return cast->value;
}
return Expr();
}
// MMAMatcher matches C = Cast(A)*Cast(B)+C,
// where A & B are fp16/int8 local buffers,
// and C is fp32/int32 local buffer.
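// Illustration (taken from the unit test in this commit): the pattern corresponds to
//   C = tvm.compute((n, m), lambda i, j:
//         tvm.sum(A[i, k].astype('float32') * B[k, j].astype('float32'), axis=k))
// with A/B cache_read into "local" and C cache_write into "local".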
class MMAMatcher: public IRVisitor {
public:
explicit MMAMatcher(Map<Tensor, Buffer> extern_buffer) {
for (auto kv : extern_buffer) {
BufferInfo bi;
bi.name = kv.second->name;
bi.dtype = kv.second->dtype;
bi.external = true;
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi;
}
}
using IRVisitor::Visit_;
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::pragma_tensor_core) {
tensor_core_on_ = true;
IRVisitor::Visit_(op);
} else if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
Visit(op->body);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Provide* op) final {
IRVisitor::Visit_(op);
auto it = buf_map_.find(TensorKey{op->func, op->value_index});
if (it == buf_map_.end()) {
return;
}
const BufferInfo& bi = it->second;
if (bi.released) {
return;
}
if (tensor_core_on_ && mma_sync_match_(op, bi)) {
matched_ = true;
}
}
void Visit_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
if (!buf_map_.at(key).external) {
return;
}
Visit(op->body);
} else {
BufferInfo bi;
bi.name = key.GetName();
bi.dtype = op->type;
buf_map_[key] = bi;
Visit(op->body);
buf_map_[key].released = true;
}
}
inline bool Matched() const {return matched_;}
friend class ScheduleAnalyser;
friend class BufferAnalyser;
private:
struct BufferInfo {
std::string name;
Type dtype;
bool external{false};
bool released{false};
bool same_as(const BufferInfo &bi) {
if (this->dtype != bi.dtype) return false;
if (this->name != bi.name) return false;
if (this->external != bi.external) return false;
if (this->released != bi.released) return false;
return true;
}
};
// Check whether the storage scope is local
bool check_local_buffer_(const Call* op, BufferInfo* bi) {
if (op->call_type == Call::Halide) {
auto it = storage_scope_.find(op->func.get());
if (it == storage_scope_.end()) {
return false;
}
const std::string& strkey = it->second;
if (strkey != "local") {
return false;
}
auto it1 = buf_map_.find(TensorKey{op->func, op->value_index});
if (it1 == buf_map_.end()) {
return false;
}
*bi = it1->second;
if (bi->released) {
return false;
}
return true;
}
return false;
}
// Do the pattern matching
bool mma_sync_match_(const Provide* op, BufferInfo store_buffer) {
auto* add = op->value.as<Add>();
if (add == nullptr) {
return false;
}
auto* load_c = add->a.as<Call>();
BufferInfo buffer_c;
if (!check_local_buffer_(load_c, &buffer_c)
|| !buffer_c.same_as(store_buffer)
|| !(buffer_c.dtype == Float(32) ||
buffer_c.dtype == Int(32))) {
return false;
}
auto mul = unpack_type_cast(add->b, buffer_c.dtype).as<Mul>();
if (mul == nullptr) {
return false;
}
auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype);
auto load_a = load_a_expr.as<Call>();
BufferInfo buffer_a;
if (!check_local_buffer_(load_a, &buffer_a)
|| !(buffer_a.dtype == Float(16) ||
buffer_a.dtype == Int(8))) {
return false;
}
auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype);
auto load_b = load_b_expr.as<Call>();
BufferInfo buffer_b;
if (!check_local_buffer_(load_b, &buffer_b)
|| !(buffer_b.dtype == Float(16) ||
buffer_b.dtype == Int(8))) {
return false;
}
frag_reg_.insert(buffer_c.name);
frag_reg_.insert(buffer_a.name);
frag_reg_.insert(buffer_b.name);
buf_name_.insert(std::make_pair(load_a, buffer_a.name));
buf_name_.insert(std::make_pair(load_b, buffer_b.name));
mma_sync_.insert(std::make_pair(op,
Array<Expr>{load_a_expr, load_b_expr, add->a}));
return true;
}
std::unordered_map<TensorKey, BufferInfo> buf_map_;
std::unordered_map<const Node*, std::string> storage_scope_;
std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
std::unordered_map<const Node*, std::string> buf_name_;
std::unordered_set<std::string> frag_reg_;
bool matched_{false};
bool tensor_core_on_{false};
};
// BodyVisitor visits the body stmt of original ComputeOp
// to get the access indices of input matrices,
// if it is recognized as matrix multiply.
class BodyVisitor : public IRVisitor {
public:
BodyVisitor() {}
using IRVisitor::Visit_;
void Visit_(const Reduce* op) final {
auto* comm_add = op->combiner->result[0].as<Add>();
if (comm_add == nullptr || op->combiner->result.size() > 1) {
return;
}
for (Expr source : op->source) {
auto mul_0 = unpack_type_cast(source, Float(32)).as<Mul>();
auto mul_1 = unpack_type_cast(source, Int(32)).as<Mul>();
if (mul_0 == nullptr && mul_1 == nullptr) {
continue;
}
tensorcore_candidate_ = true;
IRVisitor::Visit(source);
}
}
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
args_.insert(std::make_pair(op->name, op->args));
}
friend class ScheduleAnalyser;
private:
std::unordered_map<std::string, Array<Expr>> args_;
bool tensorcore_candidate_{false};
};
// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major
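// For an output C(..., i, j) reduced over k, the last two access indices of each
// input decide its role and layout (see MatrixIdentify below):
//   X(..., k, j) -> matrix_a, col_major      X(..., k, i) -> matrix_b, row_major
//   X(..., j, k) -> matrix_a, row_major      X(..., i, k) -> matrix_b, col_major
// The output itself is recorded as the accumulator.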
class ScheduleAnalyser {
public:
explicit ScheduleAnalyser(const MMAMatcher &mma_matcher)
: mma_sync_(mma_matcher.mma_sync_),
buf_name_(mma_matcher.buf_name_) {}
bool MatrixIdentify(Schedule schedule) {
// TODO(minmin): handle the case where MatMul is not the output stage
for (Operation output : schedule->outputs) {
const ComputeOpNode* compute = output.as<ComputeOpNode>();
if (compute == nullptr) {
// Not a ComputeOp
continue;
}
auto axis = compute->axis;
auto reduce_axis = compute->reduce_axis;
if (axis.size() < 2 || reduce_axis.size() != 1) {
continue;
}
const Variable* axis_var[2];
const Variable* reduce_axis_var;
axis_var[0] = axis[axis.size()-2]->var.as<Variable>();
axis_var[1] = axis[axis.size()-1]->var.as<Variable>();
reduce_axis_var = reduce_axis[0]->var.as<Variable>();
BodyVisitor body_visitor;
for (Expr expr : compute->body) {
body_visitor.Visit(expr);
}
if (!body_visitor.tensorcore_candidate_) {
continue;
}
for (auto iter : body_visitor.args_) {
auto name = iter.first;
auto args = iter.second;
if (args.size() < 2) {
continue;
}
const Variable* var0 = args[args.size() - 2].as<Variable>();
const Variable* var1 = args[args.size() - 1].as<Variable>();
if (var0 == nullptr || var1 == nullptr) {
continue;
}
std::string matrix_abc, major;
if (var0 == reduce_axis_var && var1 == axis_var[1]) {
matrix_abc = "matrix_a";
major = "col_major";
} else if (var0 == reduce_axis_var && var1 == axis_var[0]) {
matrix_abc = "matrix_b";
major = "row_major";
} else if (var0 == axis_var[1] && var1 == reduce_axis_var) {
matrix_abc = "matrix_a";
major = "row_major";
} else if (var0 == axis_var[0] && var1 == reduce_axis_var) {
matrix_abc = "matrix_b";
major = "col_major";
}
matrix_abc_.insert(std::make_pair(name, matrix_abc));
matrix_major_.insert(std::make_pair(name, major));
}
matrix_abc_.insert(std::make_pair(compute->name, "accumulator"));
matrix_major_.insert(std::make_pair(compute->name, "col_major"));
}
for (auto &mma_sync : mma_sync_) {
auto &operands = mma_sync.second;
auto* load_a = operands[0].as<Call>();
auto* load_b = operands[1].as<Call>();
auto input0 = simplify_name(buf_name_.find(load_a)->second);
auto input1 = simplify_name(buf_name_.find(load_b)->second);
auto it0 = matrix_abc_.find(input0);
auto it1 = matrix_abc_.find(input1);
if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) {
return false;
}
if (it0->second == "matrix_a" && it1->second == "matrix_b") {
return true;
} else if (it0->second == "matrix_b" && it1->second == "matrix_a") {
mma_sync.second = Array<Expr>{operands[1], operands[0], operands[2]};
} else {
return false;
}
}
return true;
}
friend class BufferAnalyser;
friend class TensorCoreIRMutator;
private:
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
std::unordered_map<const Node*, std::string> buf_name_;
};
// IndexVisitor visits access index of fragment
// to record variable for loop scaling
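// Every loop variable appearing in the last two indices of a fragment access is
// recorded together with a scaling factor (the tile size of that dimension,
// 16 by default); the For mutator later divides the loop extent by this factor.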
class IndexVisitor : public IRVisitor {
public:
IndexVisitor() {}
using IRVisitor::Visit_;
void Visit_(const Variable* op) final {
loop_scaling_.insert(std::make_pair(op, scaling_factor_));
}
friend class BufferAnalyser;
friend class TensorCoreIRMutator;
private:
std::unordered_map<const Variable*, unsigned> loop_scaling_;
unsigned scaling_factor_{0};
};
// BufferAnalyser gets buffer info,
// e.g. thread tile and warp tile, for TensorCore CodeGen
class BufferAnalyser : public IRVisitor {
public:
explicit BufferAnalyser(Map<Tensor, Buffer> extern_buffer,
const ScheduleAnalyser &schedule_analyser,
const MMAMatcher &mma_matcher)
: matrix_abc_(schedule_analyser.matrix_abc_),
matrix_major_(schedule_analyser.matrix_major_),
frag_reg_(mma_matcher.frag_reg_) {
for (auto kv : extern_buffer) {
BufferInfo bi;
bi.name = kv.second->name;
bi.dtype = kv.second->dtype;
bi.strides = kv.second->strides;
bi.shape = kv.second->shape;
bi.external = true;
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi;
}
}
using IRVisitor::Visit_;
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::thread_extent) {
if (const IntImm* value = op->value.as<IntImm>()) {
thread_extent_.insert(
std::make_pair(
op->node.as<IterVarNode>()->var->name_hint,
value->value));
}
IRVisitor::Visit_(op);
} else if (op->attr_key == attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
Visit(op->body);
} else if (op->attr_key == attr::buffer_dim_align) {
Tensor tensor = Downcast<Tensor>(op->node);
const Call* tuple = op->value.as<Call>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}];
size_t dim = tuple->args[0].as<IntImm>()->value;
if (dim >= vinfo.size()) {
vinfo.resize(dim + 1);
}
vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
Visit(op->body);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Provide* op) final {
IRVisitor::Visit_(op);
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f;
const BufferInfo& bi = it->second;
CHECK(!bi.released)
<< "Read a buffer that is already out of scope";
if (matrix_abc_.count(key.GetName())) {
if (bi.shape.size() < 2) {
invalid_ = true;
return;
}
for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
const IntImm* shape = bi.shape[i].as<IntImm>();
if (shape == nullptr || shape->value % 16 != 0) {
invalid_ = true;
return;
}
}
}
Array<Expr> strides;
if (bi.strides.size() > 0) {
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
Expr stride = IntImm::make(Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = Mul::make(stride, bi.shape[j]);
}
strides.push_back(stride);
}
strides.push_back(make_const(Int(32), 1));
}
strides_.insert(std::make_pair(key.GetName(), strides));
if (frag_reg_.count(bi.name)) {
Expr dst = Call::make(bi.dtype,
bi.name,
op->args,
Call::Halide,
op->func,
0);
frag_load_.insert(std::make_pair(op, dst));
auto rel_index = bi.RelIndex(op->args);
if (op->args.size() < 2) {
invalid_ = true;
return;
}
std::vector<int> tile_size;
for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
index_visitor.scaling_factor_ = 16;
if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
tile_size.push_back(shape->value);
index_visitor.scaling_factor_ = shape->value;
} else {
invalid_ = true;
return;
}
auto index = rel_index[i];
auto simplified_index = ir::Simplify(index);
index_visitor.Visit(simplified_index);
}
std::string input_name = simplify_name(bi.name);
auto it = matrix_abc_.find(input_name);
auto it2 = matrix_major_.find(input_name);
bool ret = true;
if (it != matrix_abc_.end() && it2 != matrix_major_.end()) {
if (it->second == "matrix_a" && it2->second == "col_major") {
ret &= assign_or_check_(&thread_tile_.m, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.k, tile_size[1]);
}
if (it->second == "matrix_a" && it2->second == "row_major") {
ret &= assign_or_check_(&thread_tile_.k, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.m, tile_size[1]);
}
if (it->second == "matrix_b" && it2->second == "col_major") {
ret &= assign_or_check_(&thread_tile_.k, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.n, tile_size[1]);
}
if (it->second == "matrix_b" && it2->second == "row_major") {
ret &= assign_or_check_(&thread_tile_.n, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.k, tile_size[1]);
}
if (it->second == "accumulator") {
ret &= assign_or_check_(&thread_tile_.m, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.n, tile_size[1]);
}
if (!ret) {
invalid_ = true;
return;
}
}
}
const Call* value = op->value.as<Call>();
if (value != nullptr && frag_reg_.count(value->name)) {
Expr dst = Call::make(bi.dtype,
bi.name,
op->args,
Call::Halide,
op->func,
0);
frag_store_.insert(std::make_pair(op, dst));
}
}
void Visit_(const Call* op) final {
IRVisitor::Visit_(op);
if (op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f;
const BufferInfo& bi = it->second;
CHECK(!bi.released)
<< "Read a buffer that is already out of scope";
if (matrix_abc_.count(op->name)) {
if (bi.shape.size() < 2) {
invalid_ = true;
return;
}
for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
const IntImm* shape = bi.shape[i].as<IntImm>();
if (shape == nullptr || shape->value % 16 != 0) {
invalid_ = true;
return;
}
}
}
Array<Expr> strides;
if (bi.strides.size() > 0) {
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
Expr stride = IntImm::make(Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = Mul::make(stride, bi.shape[j]);
}
strides.push_back(stride);
}
strides.push_back(make_const(Int(32), 1));
}
strides_.insert(std::make_pair(key.GetName(), strides));
if (!frag_reg_.count(bi.name)) {
return;
}
auto rel_index = bi.RelIndex(op->args);
if (op->args.size() < 2) {
invalid_ = true;
return;
}
for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
index_visitor.scaling_factor_ = 16;
if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
index_visitor.scaling_factor_ = shape->value;
}
auto index = rel_index[i];
auto simplified_index = ir::Simplify(index);
index_visitor.Visit(simplified_index);
}
}
}
void Visit_(const Realize* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
Visit(op->body);
} else {
// create a buffer entry
BufferInfo bi;
bi.bounds = op->bounds;
Array<Expr> shape;
for (auto r : bi.bounds) {
shape.push_back(r->extent);
}
Array<Expr> strides;
if (dim_align_.count(key) != 0 && shape.size() != 0) {
std::vector<Expr> rstrides;
const std::vector<DimAlignInfo>& avec = dim_align_[key];
int first_dim = 0;
Expr stride = make_const(shape[first_dim].type(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
if (dim < avec.size() && avec[dim].align_factor != 0) {
Expr factor = make_const(stride.type(), avec[dim].align_factor);
Expr offset = make_const(stride.type(), avec[dim].align_offset);
stride = stride + \
indexmod(factor + offset - indexmod(stride, factor), factor);
stride = ir::Simplify(stride);
}
rstrides.push_back(stride);
stride = stride * shape[dim];
}
strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
}
bi.name = key.GetName();
bi.dtype = op->type;
bi.strides = strides;
bi.shape = shape;
buf_map_[key] = bi;
Visit(op->body);
buf_map_[key].released = true;
}
}
// Derive warp tile from thread tile,
// and check whether it is qualified for TensorCore.
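// Roughly: warp_tile.m = extent(threadIdx.x) * thread_tile.m,
//          warp_tile.n = (32 / extent(threadIdx.x)) * thread_tile.n,
//          warp_tile.k = thread_tile.k,
// and only the warp tiles 16x16x16, 32x8x16 and 8x32x16 are accepted
// (see supported_warp_tile_).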
bool QualifiedForTensorCore() {
if (invalid_) {
return false;
}
auto itx = thread_extent_.find("threadIdx.x");
if (itx == thread_extent_.end()) {
return false;
}
int warp_threads_x = itx->second;
warp_tile_.m = warp_threads_x * thread_tile_.m;
warp_threads_y_ = 32 / warp_threads_x;
auto ity = thread_extent_.find("threadIdx.y");
if (ity == thread_extent_.end()) {
return false;
}
if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) {
return false;
}
warp_tile_.n = warp_threads_y_ * thread_tile_.n;
warp_tile_.k = thread_tile_.k;
return supported_warp_tile_();
}
friend class TensorCoreIRMutator;
private:
struct DimAlignInfo {
int align_factor{0};
int align_offset{0};
};
struct BufferInfo {
std::string name;
Type dtype;
Array<Expr> strides;
Array<Expr> shape;
Region bounds;
bool external{false};
bool released{false};
inline Array<Expr> RelIndex(Array<Expr> args) const {
if (bounds.size() != 0) {
Array<Expr> index;
CHECK_EQ(bounds.size(), args.size());
for (size_t i = 0; i < bounds.size(); ++i) {
index.push_back(args[i] - bounds[i]->min);
}
return index;
} else {
return args;
}
}
};
bool assign_or_check_(int* dst, int src) {
if (*dst <= 0) {
*dst = src;
return true;
}
if (*dst == src) {
return true;
}
return false;
}
bool supported_warp_tile_() {
if (warp_tile_.m == 16 &&
warp_tile_.n == 16 &&
warp_tile_.k == 16) {
return true;
}
if (warp_tile_.m == 8 &&
warp_tile_.n == 32 &&
warp_tile_.k == 16) {
return true;
}
if (warp_tile_.m == 32 &&
warp_tile_.n == 8 &&
warp_tile_.k == 16) {
return true;
}
return false;
}
std::unordered_map<TensorKey, BufferInfo> buf_map_;
std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
std::unordered_map<const Node*, std::string> storage_scope_;
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_set<std::string> frag_reg_;
std::unordered_map<std::string, Array<Expr>> strides_;
std::unordered_map<const Provide*, Expr> frag_load_;
std::unordered_map<const Provide*, Expr> frag_store_;
std::unordered_map<std::string, int> thread_extent_;
IndexVisitor index_visitor;
Tile warp_tile_;
Tile thread_tile_;
int warp_threads_y_{-1};
bool invalid_{false};
};
// ThreadIdxMutator does the thread index unification inside a warp
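// threadIdx.x is replaced by 0 and threadIdx.y by (threadIdx.y / warp_y) * warp_y,
// so all threads of a warp compute the same fragment address.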
class ThreadIdxMutator : public IRMutator {
public:
explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {}
Expr Mutate_(const Variable* op, const Expr& olde) final {
Expr expr = IRMutator::Mutate_(op, olde);
op = expr.as<Variable>();
if (op != nullptr) {
if (op->name_hint == "threadIdx.x") {
Expr zero = IntImm::make(Int(32), 0);
return zero;
}
if (op->name_hint == "threadIdx.y") {
Expr div = Div::make(expr, warp_y_);
Expr mul = Mul::make(div, warp_y_);
return mul;
}
}
return expr;
}
private:
Expr warp_y_;
};
// TensorCoreIRMutator mutates the AST for TensorCore CodeGen
// based on tensor core intrinsics
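// The matched Provide nodes are lowered to the tvm_fill_fragment,
// tvm_load_matrix_sync, tvm_mma_sync and tvm_store_matrix_sync intrinsics,
// with each fragment exposed to codegen through a buffer_bind_scope attribute.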
class TensorCoreIRMutator : public IRMutator {
public:
explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser,
const BufferAnalyser &buffer_analyser)
: matrix_abc_(schedule_analyser.matrix_abc_),
matrix_major_(schedule_analyser.matrix_major_),
mma_sync_(schedule_analyser.mma_sync_),
strides_(buffer_analyser.strides_),
frag_reg_(buffer_analyser.frag_reg_),
loop_scaling_(buffer_analyser.index_visitor.loop_scaling_),
frag_load_(buffer_analyser.frag_load_),
frag_store_(buffer_analyser.frag_store_),
warp_tile_(buffer_analyser.warp_tile_),
warp_threads_y_(buffer_analyser.warp_threads_y_) {}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
bounds_[key] = op->bounds;
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Realize>();
if (op != nullptr) {
if (!frag_reg_.count(key.GetName())) {
return stmt;
}
auto new_extents = get_tile_size_(simplify_name(key.GetName()));
Region new_bounds;
for (size_t i = 0; i < op->bounds.size() - 2; ++i) {
new_bounds.push_back(op->bounds[i]);
}
CHECK_GE(op->bounds.size(), 2)
<< "Less than 2 dimensions for matrix " << key.GetName();
new_bounds.push_back(Range::make_by_min_extent(
op->bounds[op->bounds.size() - 2]->min, new_extents[0]));
new_bounds.push_back(Range::make_by_min_extent(
op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
return Realize::make(op->func, op->value_index,
op->type, new_bounds,
op->condition, op->body);
}
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
if (op->attr_key == attr::realize_scope) {
auto node = op->node.as<OperationNode>();
if (node != nullptr) {
if (!frag_reg_.count(node->name)) {
return stmt;
}
auto it = matrix_abc_.find(simplify_name(node->name));
CHECK(it != matrix_abc_.end())
<< "Cannot find matrix info for " << node->name;
auto matrix_abc = "wmma." + it->second;
Stmt body = Mutate(op->body);
return AttrStmt::make(op->node,
op->attr_key,
matrix_abc,
body);
}
}
return stmt;
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
auto it = mma_sync_.find(op);
if (it != mma_sync_.end()) {
const auto &operands = it->second;
Expr a = operands[0];
auto ca = a.as<Call>();
Expr b = operands[1];
auto cb = b.as<Call>();
Expr c = operands[2];
auto cc = c.as<Call>();
NodePtr<BufferNode> buffer_node_a = make_node<BufferNode>();
NodePtr<BufferNode> buffer_node_b = make_node<BufferNode>();
NodePtr<BufferNode> buffer_node_c = make_node<BufferNode>();
auto mma_sync_call =
[&buffer_node_a, &buffer_node_b]
(const Buffer &buffer) {
Buffer buffer_a(buffer_node_a);
Buffer buffer_b(buffer_node_b);
return Evaluate::make(
Call::make(Handle(),
intrinsic::tvm_mma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
Call::Intrinsic));
};
auto call_add_c =
[this, &cc, &buffer_node_c, &mma_sync_call](const Buffer &buffer) {
return add_buffer_bind_scope_(cc, buffer_node_c,
TensorKey{cc->func, cc->value_index}, mma_sync_call, cc->type);
};
auto call_add_b =
[this, &cb, &buffer_node_b, &call_add_c](const Buffer &buffer) {
return add_buffer_bind_scope_(cb, buffer_node_b,
TensorKey{cb->func, cb->value_index}, call_add_c, cb->type);
};
return add_buffer_bind_scope_(ca, buffer_node_a,
TensorKey{ca->func, ca->value_index}, call_add_b, ca->type);
}
auto it2 = frag_load_.find(op);
if (it2 != frag_load_.end()) {
Expr dst = it2->second;
if (op->value.as<FloatImm>() != nullptr ||
op->value.as<IntImm>() != nullptr) {
auto call = dst.as<Call>();
auto fill_fragment_call =
[this, &op](const Buffer &buffer) {
return Evaluate::make(
Call::make(Handle(),
intrinsic::tvm_fill_fragment,
{buffer->data,
warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, op->value},
Call::Intrinsic));
};
NodePtr<BufferNode> buffer_node = make_node<BufferNode>();
return add_buffer_bind_scope_(call, buffer_node,
TensorKey{call->func, call->value_index},
fill_fragment_call, call->type);
}
const Call* value = op->value.as<Call>();
CHECK(value != nullptr)
<< "Can only load fragment from a buffer";
auto it = strides_.find(value->name);
CHECK(it != strides_.end())
<< "Cannot find stride for " << value->name;
auto strides = it->second;
CHECK_GE(strides.size(), 2);
Expr stride = strides[strides.size()-2];
// thread index unification inside a warp
Expr warp_y = IntImm::make(Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
Expr mutated_value = thread_idx_mutator.Mutate(op->value);
Expr src = Call::make(value->type,
"&",
{mutated_value},
Call::Extern);
auto call = dst.as<Call>();
Expr matrix_major;
auto iter2 = matrix_major_.find(simplify_name(call->name));
CHECK(iter2 != matrix_major_.end())
<< "Can not determine matrix major for " << call->name;
if (iter2->second == "col_major") {
matrix_major = StringImm::make("col_major");
} else if (iter2->second == "row_major") {
matrix_major = StringImm::make("row_major");
} else {
LOG(FATAL) << "invalid matrix major for " << call->name;
}
auto load_matrix_call =
[this, &src, &stride, &matrix_major](const Buffer &buffer) {
return Evaluate::make(
Call::make(Handle(),
intrinsic::tvm_load_matrix_sync,
{buffer->data,
warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, src, stride, matrix_major},
Call::Intrinsic));
};
NodePtr<BufferNode> buffer_node = make_node<BufferNode>();
return add_buffer_bind_scope_(call, buffer_node,
TensorKey{op->func, op->value_index},
load_matrix_call, call->type);
}
auto it3 = frag_store_.find(op);
if (it3 != frag_store_.end()) {
TensorKey key{op->func, op->value_index};
auto it = strides_.find(key.GetName());
CHECK(it != strides_.end())
<< "Cannot find stride for " << key.GetName();
auto strides = it->second;
CHECK_GE(strides.size(), 2);
Expr stride = strides[strides.size()-2];
Expr dst = it3->second;
// thread index unification inside a warp
Expr warp_y = IntImm::make(Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator.Mutate(dst);
dst = Call::make(Handle(),
"&",
{dst},
Call::Extern);
auto call = op->value.as<Call>();
auto store_matrix_call =
[this, &dst, &stride](const Buffer &buffer) {
return Evaluate::make(
Call::make(Handle(),
intrinsic::tvm_store_matrix_sync,
{buffer->data,
warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, dst, stride,
StringImm::make("col_major")},
Call::Intrinsic));
};
NodePtr<BufferNode> buffer_node = make_node<BufferNode>();
return add_buffer_bind_scope_(call, buffer_node,
TensorKey{call->func, call->value_index},
store_matrix_call, call->type);
}
return stmt;
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<For>();
if (op != nullptr) {
auto it = loop_scaling_.find(op->loop_var.get());
if (it != loop_scaling_.end()) {
int scale_factor = it->second;
int scaled_extent_value = 1;
if (const IntImm *ori_extent = op->extent.as<IntImm>()) {
int ori_extent_value = ori_extent->value;
scaled_extent_value = ori_extent_value / scale_factor;
}
Expr scaled_extent = make_const(op->extent.type(), scaled_extent_value);
stmt = For::make(op->loop_var, op->min, scaled_extent, op->for_type,
op->device_api, op->body);
}
}
return stmt;
}
private:
Array<Expr> get_tile_size_(const std::string &name) {
auto it = matrix_abc_.find(name);
auto it2 = matrix_major_.find(name);
CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end())
<< "Cannot find matrix info for " << name;
Expr size0 = make_const(Int(32), 16);
Expr size1 = make_const(Int(32), 16);
if (it->second == "matrix_a" && it2->second == "col_major") {
size0 = make_const(Int(32), warp_tile_.k);
size1 = make_const(Int(32), warp_tile_.m);
}
if (it->second == "matrix_a" && it2->second == "row_major") {
size0 = make_const(Int(32), warp_tile_.m);
size1 = make_const(Int(32), warp_tile_.k);
}
if (it->second == "matrix_b" && it2->second == "row_major") {
size0 = make_const(Int(32), warp_tile_.k);
size1 = make_const(Int(32), warp_tile_.n);
}
if (it->second == "matrix_b" && it2->second == "col_major") {
size0 = make_const(Int(32), warp_tile_.n);
size1 = make_const(Int(32), warp_tile_.k);
}
if (it->second == "matrix_c") {
size0 = make_const(Int(32), warp_tile_.n);
size1 = make_const(Int(32), warp_tile_.m);
}
Array<Expr> tile_size = {size0, size1};
return tile_size;
}
Stmt add_buffer_bind_scope_(const Call* call,
const NodePtr<BufferNode> &buffer_node, const TensorKey &key,
const std::function<Stmt(const Buffer &buffer)> &call_back,
DataType datatype) {
auto it = bounds_.find(key);
CHECK(it != bounds_.end());
Array<Expr> min_bound;
for (auto i : it->second) {
min_bound.push_back(i->min);
}
CHECK_GE(it->second.size(), 2);
Array<Expr> shape;
for (size_t i = 0; i < it->second.size() - 2; ++i) {
shape.push_back(it->second[i]->extent);
}
auto tile_size = get_tile_size_(simplify_name(call->name));
shape.push_back(tile_size[0]);
shape.push_back(tile_size[1]);
Array<Expr> strides;
for (size_t i = 1; i < shape.size(); ++i) {
Expr stride = IntImm::make(Int(32), 1);
for (size_t j = shape.size() - 1; j >= i; --j) {
stride = Mul::make(stride, shape[j]);
}
strides.push_back(stride);
}
strides.push_back(make_const(Int(32), 1));
Expr elem_offset = IntImm::make(Int(32), 0);
CHECK_EQ(call->args.size(), min_bound.size());
for (size_t i = 0; i < min_bound.size(); i++) {
elem_offset = Add::make(
elem_offset, Mul::make(
strides[i], Sub::make(call->args[i], min_bound[i])));
}
auto it2 = matrix_abc_.find(simplify_name(call->name));
CHECK(it2 != matrix_abc_.end())
<< "Cannot find matrix info for " << call->name;
buffer_node->data = Variable::make(Handle(), call->name);
buffer_node->name = call->name;
buffer_node->scope = "wmma." + it2->second;
buffer_node->dtype = datatype;
buffer_node->strides = strides;
buffer_node->shape = shape;
buffer_node->data_alignment = 1;
buffer_node->elem_offset = Simplify(elem_offset);
buffer_node->offset_factor = 1;
Buffer buffer(buffer_node);
NodePtr<TensorNode> tensor_node = make_node<TensorNode>();
tensor_node->value_index = key.value_index;
tensor_node->op = Downcast<Operation>(key.f);
tensor_node->shape = shape;
tensor_node->dtype = datatype;
Tensor tensor(tensor_node);
Array<Expr> args;
for (size_t i = 0; i < call->args.size(); ++i) {
args.push_back(call->args[i]);
args.push_back(shape[i]);
}
auto tuple = Call::make(Handle(),
intrinsic::tvm_tuple,
args,
Call::Intrinsic);
Array<NodeRef> node = {buffer, tensor};
return AttrStmt::make(node,
"buffer_bind_scope",
tuple,
call_back(buffer));
}
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
std::unordered_map<std::string, Array<Expr>> strides_;
std::unordered_set<std::string> frag_reg_;
std::unordered_map<const Variable*, unsigned> loop_scaling_;
std::unordered_map<const Provide*, Expr> frag_load_;
std::unordered_map<const Provide*, Expr> frag_store_;
std::unordered_map<TensorKey, Region> bounds_;
Tile warp_tile_;
int warp_threads_y_{-1};
};
Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer) {
// Check if current lower target is CUDA
auto target = tvm::Target::Current(true);
if (target.defined() && target->target_name != "cuda") {
return stmt;
}
// Check if current runtime support GPU CUDA
TVMContext ctx{kDLGPU, 0};
auto api = tvm::runtime::DeviceAPI::Get(ctx, true);
if (api == nullptr) {
return stmt;
}
MMAMatcher mma_matcher(extern_buffer);
mma_matcher.Visit(stmt);
if (!mma_matcher.Matched()) {
return stmt;
}
ScheduleAnalyser schedule_analyser(mma_matcher);
if (!schedule_analyser.MatrixIdentify(schedule)) {
return stmt;
}
BufferAnalyser buffer_analyser(extern_buffer,
schedule_analyser, mma_matcher);
buffer_analyser.Visit(stmt);
if (!buffer_analyser.QualifiedForTensorCore()) {
return stmt;
}
return TensorCoreIRMutator(schedule_analyser, buffer_analyser).Mutate(stmt);
}
} // namespace ir
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# 'License'); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import topi
import numpy as np
from tvm.contrib import nvcc
def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
A = tvm.placeholder((n, l), name='A', dtype='float16')
B = tvm.placeholder((l, m), name='B', dtype='float16')
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k].astype('float32') * B[k, j].astype('float32'), axis=k))
s = tvm.create_schedule(C.op)
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
AA = s.cache_read(A, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BB = s.cache_read(B, "shared", [C])
BL = s.cache_read(BB, "local", [C])
CL = s.cache_write(C, "local")
bx = 4
by = 32
step_k = 8
v = 4
TX = 8
TY = 1
tile_x = bx * TX
tile_y = by * TY
WX = min(warp_tile_m, tile_x)
tile_k = 16
vthread = 1
yo, ty = s[C].split(y, tile_y*vthread)
vy, ty = s[C].split(ty, tile_y)
ty, yi = s[C].split(ty, TY)
xo, xi = s[C].split(x, tile_x)
tz, xi = s[C].split(xi, WX)
tx, xi = s[C].split(xi, TX)
ko, ki = s[CL].split(k, step_k * tile_k)
kl, ki = s[CL].split(ki, tile_k)
s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
s[C].bind(yo, tvm.thread_axis("blockIdx.y"))
s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
s[C].bind(tz, tvm.thread_axis("threadIdx.z"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].bind(vy, tvm.thread_axis((0, vthread), "vthread", name="vy"))
s[CL].compute_at(s[C], tx)
yo, xo = CL.op.axis
s[CL].reorder(ko, kl, ki, yo, xo)
s[AA].compute_at(s[CL], ko)
xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v)
tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
tx, vec = s[AA].split(tx, factor=v)
fused = s[AA].fuse(s[AA].op.axis[0], xo)
_, ty = s[AA].split(fused, factor=by)
s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
s[AA].bind(tz, tvm.thread_axis("threadIdx.z"))
s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
s[AA].vectorize(vec)
s[BB].compute_at(s[CL], ko)
xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v)
tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
tx, vec = s[BB].split(tx, factor=v)
fused = s[BB].fuse(s[BB].op.axis[0], xo)
_, ty = s[BB].split(fused, factor=by)
s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
s[BB].bind(tz, tvm.thread_axis("threadIdx.z"))
s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
s[BB].vectorize(vec)
s[AL].compute_at(s[CL], kl)
s[BL].compute_at(s[CL], kl)
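# the 'tensor_core' pragma on the outer reduction loop triggers the
# RewriteForTensorCore ir pass to rewrite this schedule for TensorCore codegen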
s[CL].pragma(ko, 'tensor_core')
func = tvm.build(s, [A, B, C], 'cuda')
ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
b_np = np.random.uniform(size=(l, m)).astype(B.dtype)
c_np = np.zeros((n, m), dtype=np.float32)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
func(a, b, c)
evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
print('gemm m=%d n=%d k=%d: %f ms' % (m, n, l, evaluator(a, b, c).mean * 1e3))
c_np = np.dot(a_np, b_np)
np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3)
def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
A = tvm.placeholder((batch, n, l), name='A', dtype='float16')
B = tvm.placeholder((batch, l, m), name='B', dtype='float16')
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute((batch, n, m), lambda b, i, j: tvm.sum((A[b, i, k] * B[b, k, j]).astype('float32'), axis=k))
s = tvm.create_schedule(C.op)
z, y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
AA = s.cache_read(A, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BB = s.cache_read(B, "shared", [C])
BL = s.cache_read(BB, "local", [C])
CL = s.cache_write(C, "local")
bx = 2
by = 32
step_k = 8
v = 4
TX = 8
TY = 1
tile_x = bx * TX
tile_y = by * TY
WX = min(warp_tile_m, tile_x)
tile_k = 16
vthread = 1
yo, ty = s[C].split(y, tile_y*vthread)
vy, ty = s[C].split(ty, tile_y)
ty, yi = s[C].split(ty, TY)
xo, xi = s[C].split(x, tile_x)
tz, xi = s[C].split(xi, WX)
tx, xi = s[C].split(xi, TX)
ko, ki = s[CL].split(k, step_k * tile_k)
kl, ki = s[CL].split(ki, tile_k)
s[C].reorder(z, yo, xo, tz, ty, tx, yi, xi)
s[C].bind(z, tvm.thread_axis("blockIdx.z"))
s[C].bind(yo, tvm.thread_axis("blockIdx.y"))
s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
s[C].bind(tz, tvm.thread_axis("threadIdx.z"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].bind(vy, tvm.thread_axis((0, vthread), "vthread", name="vy"))
s[CL].compute_at(s[C], tx)
zo, yo, xo = CL.op.axis
s[CL].reorder(ko, kl, ki, zo, yo, xo)
s[AA].compute_at(s[CL], ko)
xo, xi = s[AA].split(s[AA].op.axis[2], factor=bx*v)
tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
tx, vec = s[AA].split(tx, factor=v)
fused = s[AA].fuse(s[AA].op.axis[1], xo)
_, ty = s[AA].split(fused, factor=by)
s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
s[AA].bind(tz, tvm.thread_axis("threadIdx.z"))
s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
s[AA].vectorize(vec)
s[BB].compute_at(s[CL], ko)
xo, xi = s[BB].split(s[BB].op.axis[2], factor=bx*v)
tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
tx, vec = s[BB].split(tx, factor=v)
fused = s[BB].fuse(s[BB].op.axis[1], xo)
_, ty = s[BB].split(fused, factor=by)
s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
s[BB].bind(tz, tvm.thread_axis("threadIdx.z"))
s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
s[BB].vectorize(vec)
s[AL].compute_at(s[CL], kl)
s[BL].compute_at(s[CL], kl)
s[CL].pragma(ko, 'tensor_core')
func = tvm.build(s, [A, B, C], 'cuda')
ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(batch, n, l)).astype(A.dtype)
b_np = np.random.uniform(size=(batch, l, m)).astype(B.dtype)
c_np = np.zeros((batch, n, m), dtype=np.float32)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((batch, n, m), dtype=C.dtype), ctx)
func(a, b, c)
evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
print('batch gemm m=%d n=%d k=%d batch=%d: %f ms' % (m, n, l, batch, evaluator(a, b, c).mean * 1e3))
for bs in range(batch):
c_np[bs, :, :] = np.dot(a_np[bs, :, :], b_np[bs, :, :])
np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3)
def test_tensor_core_matmul():
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
print("skip because gpu does not support tensor core")
return
tensor_core_matmul(16) #test with warp_tile 16x16x16
tensor_core_matmul(8) #test with warp_tile 8x32x16
tensor_core_matmul(32) #test with warp_tile 32x8x16
def test_tensor_core_batch_matmul():
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
print("skip because gpu does not support tensor core")
return
tensor_core_batch_matmul()
if __name__ == '__main__':
test_tensor_core_matmul()
test_tensor_core_batch_matmul()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _opt-matmul-auto-tensorcore:
How to optimize matmul with Auto TensorCore CodeGen
====================================================
**Author**: `Minmin Sun <https://github.com/minminsun>`_, \
`Lanbo Li <https://github.com/Orion34C>`_, \
`Chenfan Jia <https://github.com/jcf94>`_, \
`Jun Yang <https://github.com/yangjunpro>`_
In this tutorial, we will demonstrate how to write a high performance matmul
schedule on Volta/Turing GPUs with TVM Auto TensorCore CodeGen.
This is a transparent solution that generates TensorCore kernels
with most transformations done in ir passes.
Users can also write schedules with tensorization to generate TensorCore code.
Both solutions use the same tensorcore intrinsics.
Please refer to :ref:`opt-conv-tensorcore` tutorial for more details.
"""
################################################################
# Preparation and Algorithm
# --------------------------
# Two kinds of input data types are supported: float16 and int8.
# For float16, the accumulator is float32.
# For int8, the accumulator is int32.
# For data layouts, 'N' means non-transpose while 'T' means transpose.
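# For example, with the shapes used in the template below: layout 'NN' means
# A has shape (N, L) and B has shape (L, M), while layout 'NT' means
# A has shape (L, N) and B has shape (L, M).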
import logging
import sys
import numpy as np
import tvm
from tvm import autotvm
from tvm.contrib import nvcc
def matmul_nn(A, B, L, dtype='float16', layout='NN'):
k = tvm.reduce_axis((0, L), name='k')
if dtype == 'float16':
out_type = 'float'
elif dtype == 'int8':
out_type = 'int'
if (layout == 'NN'):
return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k))
if (layout == 'NT'):
return tvm.compute((N, M), lambda i, j: tvm.sum(A[k, i].astype(out_type) * B[k, j].astype(out_type), axis=k))
if (layout == 'TN'):
return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[j, k].astype(out_type), axis=k))
if (layout == 'TT'):
return tvm.compute((N, M), lambda i, j: tvm.sum(A[k, i].astype(out_type) * B[j, k].astype(out_type), axis=k))
###############################################################################
# Scheduling the Computation
# --------------------------
# This schedule is no different from a non-TensorCore matmul schedule on GPU.
# Please refer to the :ref:`opt-gemm` tutorial for the basics of optimizing a matmul schedule.
# When the "tensor_core" pragma is set, the "rewrite for tensorcore" ir pass
# will automatically transform the schedule for TensorCore codegen;
# otherwise normal CUDA code, with lower performance but equal functionality, will be generated.
#
# .. note::
#
# *Requirements of TensorCore*
#
# Note that in the following two cases, even though the "tensor_core" pragma is set, TVM will still fall back to normal CUDA codegen:
# (1) the m, n or k of the input matrices is not a multiple of 16;
# (2) the warp tile size is not 16x16x16 on CUDA 9, or not one of {16x16x16, 32x8x16, 8x32x16} on CUDA version >= 10.0.
#
# In this schedule, storage_align is used to reduce bank conflicts of shared memory. Please refer to this
# `doc <https://docs.tvm.ai/api/python/schedule.html#tvm.schedule.Stage.storage_align>`_
# for the usage of storage_align primitive. In short, we need to add an offset to some shared memory buffer
# to reduce bank conflicts.
# According to the `wmma doc <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#wmma-description>`_,
# the stride of load_matrix_sync must be a multiple of 16 bytes,
# so we choose 8 as offset for float16 and 16 as offset for int8.
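# (8 float16 elements and 16 int8 elements are both 16 bytes, so the padded
# stride remains a multiple of 16 bytes.)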
#
# We use AutoTVM to search for best configurations in this schedule.
@autotvm.template
def test_gemm(N, L, M, dtype, layout):
if (layout == "NN"):
shape_a = (N, L)
shape_b = (L, M)
elif (layout == "NT"):
shape_a = (L, N)
shape_b = (L, M)
elif (layout == "TN"):
shape_a = (N, L)
shape_b = (M, L)
elif (layout == "TT"):
shape_a = (L, N)
shape_b = (M, L)
else:
print("Unsupported layout:", layout)
sys.exit(1)
A = tvm.placeholder(shape_a, name='A', dtype=dtype)
B = tvm.placeholder(shape_b, name='B', dtype=dtype)
C = matmul_nn(A, B, L, dtype, layout)
s = tvm.create_schedule(C.op)
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
# storage_align params
factor = 16
offset = 8
if dtype == 'int8':
factor = 32
offset = 16
# create cache stages
AA = s.cache_read(A, "shared", [C])
if (layout == "NN" or layout == "TN"):
s[AA].storage_align(AA.op.axis[0], factor, offset)
AL = s.cache_read(AA, "local", [C])
BB = s.cache_read(B, "shared", [C])
if (layout == "TT" or layout == "NT"):
s[BB].storage_align(BB.op.axis[0], factor, offset)
BL = s.cache_read(BB, "local", [C])
CL = s.cache_write(C, "local")
#autotvm search space definition
cfg = autotvm.get_config()
cfg.define_knob("bx", [2, 4, 8])
cfg.define_knob("by", [16, 32, 64])
cfg.define_knob("step_k", [8, 16, 32])
cfg.define_knob("v", [4, 8])
by = cfg['by'].val
bx = cfg['bx'].val
step_k = cfg['step_k'].val
v = cfg['v'].val
# thread tile
TX = 8
TY = 1
# warp tile
warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0
warp_tile_k = 16 # it must be 16
# block tile
tile_x = bx * TX
tile_y = by * TY
yo, ty = s[C].split(y, tile_y)
ty, yi = s[C].split(ty, TY)
# schedule for C stage
xo, xi = s[C].split(x, tile_x)
WX = min(warp_tile_m, tile_x)
tz, xi = s[C].split(xi, WX)
tx, xi = s[C].split(xi, TX)
s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
s[C].bind(yo, tvm.thread_axis("blockIdx.y"))
s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
s[C].bind(tz, tvm.thread_axis("threadIdx.z"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
# schedule for CL stage
ko, ki = s[CL].split(k, step_k * warp_tile_k)
kl, ki = s[CL].split(ki, warp_tile_k)
s[CL].compute_at(s[C], tx)
yo, xo = CL.op.axis
s[CL].reorder(ko, kl, ki, yo, xo)
# schedule for AA stage
s[AA].compute_at(s[CL], ko)
xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v)
tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
tx, vec = s[AA].split(tx, factor=v)
fused = s[AA].fuse(s[AA].op.axis[0], xo)
_, ty = s[AA].split(fused, factor=by)
s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
s[AA].bind(tz, tvm.thread_axis("threadIdx.z"))
s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
# vectorization is very important for float16/int8 inputs
s[AA].vectorize(vec)
# schedule for BB stage
s[BB].compute_at(s[CL], ko)
xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v)
tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
tx, vec = s[BB].split(tx, factor=v)
fused = s[BB].fuse(s[BB].op.axis[0], xo)
_, ty = s[BB].split(fused, factor=by)
s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
s[BB].bind(tz, tvm.thread_axis("threadIdx.z"))
s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
s[BB].vectorize(vec)
s[AL].compute_at(s[CL], kl)
s[BL].compute_at(s[CL], kl)
# set the 'tensor_core' pragma for tensorcore codegen
s[CL].pragma(ko, 'tensor_core')
return s, [A, B, C]
###############################################################################
# AutoTune and Test
# --------------------
# Finally we use a tuner to tune the schedule, generate code with the best config,
# and run the kernel to compare against numpy and check that the results are correct.
# check whether the gpu has tensorcore
ctx = tvm.gpu()
if not nvcc.have_tensorcore(ctx.compute_version):
print('the gpu has no tensorcore, skipping...')
sys.exit(0)
M, N, L = 512, 32, 512
dtype = 'float16'
layout = 'NN'
if len(sys.argv) >= 4:
M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
if len(sys.argv) >= 5:
dtype = sys.argv[4]
if len(sys.argv) >= 6:
layout = sys.argv[5]
def tune_and_evaluate(M, N, L, dtype, layout):
task = autotvm.task.create(test_gemm, args=(N, L, M, dtype, layout), target='cuda')
print(task.config_space)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
measure_option = autotvm.measure_option(
builder='local',
runner=autotvm.LocalRunner(number=5))
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=1000,
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('matmul.log')])
dispatch_context = autotvm.apply_history_best("matmul.log")
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)
with autotvm.apply_history_best('matmul.log'):
with tvm.target.create("cuda"):
with tvm.build_config():
s, arg_bufs = test_gemm(N, L, M, dtype, layout)
print(tvm.lower(s, arg_bufs, simple_mode=True))
func = tvm.build(s, arg_bufs)
dev_module = func.imported_modules[0]
print(dev_module.get_source())
# check correctness
if (layout == "NN"):
shape_a = (N, L)
shape_b = (L, M)
elif (layout == "NT"):
shape_a = (L, N)
shape_b = (L, M)
elif (layout == "TN"):
shape_a = (N, L)
shape_b = (M, L)
elif (layout == "TT"):
shape_a = (L, N)
shape_b = (M, L)
a_np = None
b_np = None
c_np = None
c_np_type = None
if dtype == 'float16':
c_np_type = np.float32
a_np = np.random.uniform(size=shape_a).astype(np.float16)
b_np = np.random.uniform(size=shape_b).astype(np.float16)
if (layout == "NN"):
c_np = np.dot(a_np, b_np)
elif (layout == "NT"):
c_np = np.dot(a_np.T, b_np)
elif (layout == "TN"):
c_np = np.dot(a_np, b_np.T)
elif (layout == "TT"):
c_np = np.dot(a_np.T, b_np.T)
elif dtype == 'int8':
c_np_type = np.int32
a_np = np.random.randint(low=-128, high=127, size=shape_a).astype(np.int8)
b_np = np.random.randint(low=-128, high=127, size=shape_b).astype(np.int8)
if (layout == "NN"):
c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32))
elif (layout == "NT"):
c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32))
elif (layout == "TN"):
c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
elif (layout == "TT"):
c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
a_tvm = tvm.nd.array(a_np, ctx=ctx)
b_tvm = tvm.nd.array(b_np, ctx=ctx)
func(a_tvm, b_tvm, c_tvm)
tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3)
evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
print('Time cost of this operator: %f' % evaluator(a_tvm, b_tvm, c_tvm).mean)
# We do not run the tuning on our webpage server since it takes some time.
# Uncomment the following line to run it yourself.
# tune_and_evaluate(M, N, L, dtype, layout)
######################################################################
# Sample Output
# -------------
# .. code-block:: bash
#
# Best config:
# [('bx', 4), ('by', 32), ('step_k', 16), ('v', 8)],,None,40
# Finish loading 162 records
# produce compute {
# // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
# // attr [compute.local] storage_scope = "wmma.accumulator"
# allocate compute.local[float32 * 256]
# // attr [A.shared] storage_scope = "shared"
# allocate A.shared[float16 * 8448]
# // attr [B.shared] storage_scope = "shared"
# allocate B.shared[float16 * 8192]
# // attr [A.shared.local] storage_scope = "wmma.matrix_b"
# allocate A.shared.local[float16 * 256]
# // attr [B.shared.local] storage_scope = "wmma.matrix_a"
# allocate B.shared.local[float16 * 256]
# // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16
# // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2
# // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
# // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
# produce compute.local {
# for (j.c.init, 0, 1) {
# tvm_fill_fragment(compute.local, 16, 16, 16, 0, 0f)
# }
# // attr [iter_var(k.outer, )] pragma_tensor_core = 1
# for (k.outer, 0, 2) {
# produce A.shared {
# for (ax0.ax1.outer.fused.outer, 0, 8) {
# // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
# // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2
# // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
# A.shared[ramp((((((ax0.ax1.outer.fused.outer*1056) + (floordiv(threadIdx.y, 8)*264)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = A[ramp(((((((ax0.ax1.outer.fused.outer*2048) + (floordiv(threadIdx.y, 8)*512)) + (k.outer*256)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)]
# }
# }
# produce B.shared {
# for (ax0.ax1.outer.fused.outer, 0, 8) {
# // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
# // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2
# // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
# B.shared[ramp(((((ax0.ax1.outer.fused.outer*1024) + (threadIdx.y*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = B[ramp(((((((k.outer*131072) + (ax0.ax1.outer.fused.outer*16384)) + (threadIdx.y*512)) + (blockIdx.x*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)]
# }
# }
# for (k.inner.outer, 0, 16) {
# produce A.shared.local {
# for (ax1, 0, 1) {
# tvm_load_matrix_sync(A.shared.local, 16, 16, 16, 0, &(A.shared[(((threadIdx.y/16)*4224) + (k.inner.outer*16))]), 264, "col_major")
# }
# }
# produce B.shared.local {
# for (ax0, 0, 1) {
# for (ax1, 0, 1) {
# tvm_load_matrix_sync(B.shared.local, 16, 16, 16, 0, &(B.shared[((k.inner.outer*512) + (threadIdx.z*16))]), 32, "col_major")
# }
# }
# }
# for (k.inner.inner, 0, 1) {
# for (j.c, 0, 1) {
# tvm_mma_sync(compute.local, 0, B.shared.local, 0, A.shared.local, 0, compute.local, 0)
# }
# }
# }
# }
# }
# for (j.inner.inner.inner, 0, 1) {
# tvm_store_matrix_sync(compute.local, 16, 16, 16, 0, &(compute[((((threadIdx.y/16)*8192) + (blockIdx.x*32)) + (threadIdx.z*16))]), 512, "col_major")
# }
# }
#
# #include <cuda_fp16.h>
# __device__ half max(const half a, const half b)
# {
# return __hgt(__half(a), __half(b)) ? a : b;
# }
# __device__ half min(const half a, const half b)
# {
# return __hlt(__half(a), __half(b)) ? a : b;
# }
# __device__ half operator+(const volatile __half &a, const volatile __half &b)
# {
# return __hadd(a, b);
# }
# __device__ half operator<=(const volatile __half &a, const volatile __half &b)
# {
# return __hlt(a, b);
# }
# __device__ half operator*(const volatile __half &a, const volatile __half &b)
# {
# return __hmul(a, b);
# }
# #include <mma.h>
# extern "C" __global__ void default_function_kernel0( half* __restrict__ A, half* __restrict__ B, float* __restrict__ compute) {
# nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> compute_local[1];
# __shared__ half A_shared[8448];
# __shared__ half B_shared[8192];
# nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> A_shared_local[1];
# nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> B_shared_local[1];
# for (int j_c_init = 0; j_c_init < 1; ++j_c_init) {
# (void)nvcuda::wmma::fill_fragment(compute_local[0], 0.000000e+00f);
# }
# for (int k_outer = 0; k_outer < 2; ++k_outer) {
# __syncthreads();
# for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 8; ++ax0_ax1_outer_fused_outer) {
# ((__shared__ float4*)(A_shared + (((((ax0_ax1_outer_fused_outer * 1056) + ((((int)threadIdx.y) >> 3) * 264)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(A + ((((((ax0_ax1_outer_fused_outer * 2048) + ((((int)threadIdx.y) >> 3) * 512)) + (k_outer * 256)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0];
# }
# for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 8; ++ax0_ax1_outer_fused_outer1) {
# ((__shared__ float4*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 1024) + (((int)threadIdx.y) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(B + ((((((k_outer * 131072) + (ax0_ax1_outer_fused_outer1 * 16384)) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0];
# }
# __syncthreads();
# for (int k_inner_outer = 0; k_inner_outer < 16; ++k_inner_outer) {
# for (int ax1 = 0; ax1 < 1; ++ax1) {
# (void)nvcuda::wmma::load_matrix_sync(A_shared_local[0], &(A_shared[(((((int)threadIdx.y) / 16) * 4224) + (k_inner_outer * 16))]), 264);
# }
# for (int ax0 = 0; ax0 < 1; ++ax0) {
# for (int ax11 = 0; ax11 < 1; ++ax11) {
# (void)nvcuda::wmma::load_matrix_sync(B_shared_local[0], &(B_shared[((k_inner_outer * 512) + (((int)threadIdx.z) * 16))]), 32);
# }
# }
# for (int k_inner_inner = 0; k_inner_inner < 1; ++k_inner_inner) {
# for (int j_c = 0; j_c < 1; ++j_c) {
# (void)nvcuda::wmma::mma_sync(compute_local[0], B_shared_local[0], A_shared_local[0], compute_local[0]);
# }
# }
# }
# }
# for (int j_inner_inner_inner = 0; j_inner_inner_inner < 1; ++j_inner_inner_inner) {
# (void)nvcuda::wmma::store_matrix_sync(&(compute[((((((int)threadIdx.y) / 16) * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16))]), compute_local[0], 512, nvcuda::wmma::mem_col_major);
# }
# }
#
#
# Time cost of this operator: 0.000008
###############################################################################
# Summary
# --------------------------
# This tutorial demonstrates how to use TVM's Auto TensorCore CodeGen
# to generate TensorCore kernels.