Unverified Commit 4c0a53dc by Tianqi Chen, committed by GitHub

[TIR][REFACTOR] RewriteForTensorCore -> te/schedule (#5379)

* [TIR][REFACTOR] RewriteForTensorCore -> te/schedule

RewriteForTensorCore depends on the schedule information, which makes it differ
from a typical pass (which should get all of its information from the input TIR).

As a result, we refactor it as a SchedulePostProc step for now; a sketch of the
resulting API change follows the commit metadata below.
We should revisit it later as we introduce more support for TensorCore patterns in the TIR.

* Fix VTA to fit the new IR Pattern
parent 22db299b
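A minimal sketch of the renamed entry point, paraphrasing the build_module.py hunk below (the helper name apply_tensor_core_rewrite is hypothetical and not part of the commit):

import tvm
from tvm.te import schedule

def apply_tensor_core_rewrite(stmt, sch, binds):
    # Previously exposed as tvm.tir.ir_pass.RewriteForTensorCore(stmt, sch, binds);
    # now a schedule post-processing step that still has access to the schedule `sch`.
    return schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)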
......@@ -35,6 +35,23 @@ namespace tvm {
namespace te {
/*!
* \brief To automatically inline the element-wise operations.
*
* \param sch The schedule to be inlined.
*/
void AutoInlineElemWise(Schedule sch);
/*!
* \brief To automatically inline operations with injective writes
* (i.e. writes without reduction or sequential loops). Note
* that in this case, guarantees about contiguity, transpose, stride,
* alignment and memory footprint in general do not hold.
*
* \param sch The schedule to be inlined.
*/
TVM_DLL void AutoInlineInjective(Schedule sch);
/*!
* \brief Infer the bound of all iteration variables related to the schedule.
*
* \param sch The root schedule to infer all the bounds.
......@@ -55,6 +72,21 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
/*!
* \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
*
* \param stmt The stmt to be transformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifying the external
* buffer assignment of inputs and outputs.
* \return Transformed stmt.
*/
Stmt SchedulePostProcRewriteForTensorCore(
Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer);
/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
......@@ -75,23 +107,6 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> bindings);
/*!
* \brief To automatically inline the element-wise operations.
*
* \param sch The schedule to be inlined.
*/
void AutoInlineElemWise(Schedule sch);
/*!
* \brief To automatically inline operations with injective writes
* (i.e. writes without reduction or sequential loops). Note
* that in this case, guarantees about contiguity, transpose, stride,
* alignment and memory footprint in general do not hold.
*
* \param sch The schedule to be inlined.
*/
TVM_DLL void AutoInlineInjective(Schedule sch);
} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_PASS_H_
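As a usage sketch of the auto-inline helpers declared above, assuming their Python bindings in tvm.te.schedule (the tensors here are illustrative only):

import tvm
from tvm import te

A = te.placeholder((64,), name="A")
B = te.compute((64,), lambda i: A[i] * 2.0, name="B")  # injective stage
C = te.compute((64,), lambda i: B[i] + 1.0, name="C")

s = te.create_schedule(C.op)
te.schedule.AutoInlineInjective(s)  # inline the injective stage B into its consumer C
print(tvm.lower(s, [A, C], simple_mode=True))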
......@@ -165,19 +165,6 @@ Stmt Inline(Stmt stmt,
PrimExpr body);
/*!
* \brief Try to modify the AST to support TensorCore
*
* \param stmt The stmt to be transformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifying the external
* buffer assignment of inputs and outputs.
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
te::Schedule schedule,
Map<te::Tensor, Buffer> extern_buffer);
/*!
* \brief Verify if there is any argument bound to compact buffer.
*
* \param stmt The stmt to be verified.
......
......@@ -84,23 +84,43 @@ def get_binds(args, compact=False, binds=None):
return binds, arg_list
def form_body(sch):
def form_irmodule(sch, args, name, binds):
"""According to the given schedule, form a function.
Parameters
----------
sch : tvm.te.schedule.Schedule
The given scheduler to form the raw body
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str
The name of the result function.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
The binds information
Returns
-------
The IRModule formed according to the given schedule
"""
# normalize schedule first
cfg = BuildConfig.current()
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
return stmt
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
func = func.with_attr("global_symbol", name)
if cfg.restricted_func:
func = func.with_attr("tir.noalias", True)
return tvm.IRModule({name: func})
def _wrap_as_prim_func_pass(flist, name):
......@@ -166,24 +186,13 @@ def lower(sch,
# Phase 0
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)
for f in lower_phase0:
stmt = f(stmt)
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
# Start the new style pass manager.
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
func = func.with_attr("global_symbol", name)
if cfg.restricted_func:
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
mod = form_irmodule(sch, args, name, binds)
else:
mod = sch
# Phase 1
pass_list = [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32),
......
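A hedged sketch of the refactored lowering flow in isolation: form_irmodule is an internal helper of tvm.driver.build_module (ordinary code keeps going through tvm.lower / tvm.build), and the workload below is made up for illustration:

import tvm
from tvm import te
from tvm.driver import build_module

A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

# normalize -> InferBound -> ScheduleOps -> TensorCore post-proc -> PrimFunc -> IRModule,
# i.e. the steps of form_irmodule shown above.
mod = build_module.form_irmodule(s, [A, B], "add_one", binds=None)
print(mod)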
......@@ -30,7 +30,7 @@ HalideIR.
# 2. Support multi-level HalideIR
import inspect
import tvm._ffi
from tvm.driver.build_module import form_body
import tvm.te.schedule
from tvm._ffi.base import decorate
from .module import HybridModule
......@@ -87,8 +87,10 @@ def build(sch, inputs, outputs, name="hybrid_func"):
The built result is wrapped in a HybridModule.
The usage of HybridModule is roughly the same as normal TVM-built modules.
"""
sch = sch.normalize()
bounds = tvm.te.schedule.InferBound(sch)
stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
stmt = form_body(sch)
src = _Dump(stmt, inputs, outputs, name)
return HybridModule(src, name)
......
......@@ -18,9 +18,12 @@
*/
/*!
* \file tensor_core.cc
* \file schedule_postproc_rewrite_for_tensor_core.cc
*
* \brief Rewrite the Stmt generated by ScheduleOps
* to accommodate TensorCore.
*/
// IR Passes for TensorCore CodeGen
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
......@@ -32,12 +35,11 @@
#include <tvm/target/target.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace tir {
namespace te {
using namespace te;
using runtime::StorageRank;
......@@ -86,10 +88,10 @@ class MMAMatcher: public StmtVisitor {
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::pragma_tensor_core) {
if (op->attr_key == tir::attr::pragma_tensor_core) {
tensor_core_on_ = true;
StmtVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
} else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else {
......@@ -414,7 +416,7 @@ class BufferAnalyser : public StmtExprVisitor {
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
if (op->attr_key == tir::attr::thread_extent) {
if (const IntImmNode* value = op->value.as<IntImmNode>()) {
thread_extent_.insert(
std::make_pair(
......@@ -422,10 +424,10 @@ class BufferAnalyser : public StmtExprVisitor {
value->value));
}
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
} else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else if (op->attr_key == attr::buffer_dim_align) {
} else if (op->attr_key == tir::attr::buffer_dim_align) {
te::Tensor tensor = Downcast<te::Tensor>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
......@@ -850,7 +852,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
if (op->attr_key == attr::realize_scope) {
if (op->attr_key == tir::attr::realize_scope) {
auto node = op->node.as<te::OperationNode>();
if (node != nullptr) {
if (!frag_reg_.count(node->name)) {
......@@ -1186,9 +1188,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
int warp_threads_y_{-1};
};
Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer) {
Stmt SchedulePostProcRewriteForTensorCore(
Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer) {
// Check if current lower target is CUDA
auto target = tvm::Target::Current(true);
if (target.defined() && target->target_name != "cuda") {
......@@ -1223,5 +1226,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
}
} // namespace tir
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
.set_body_typed([](Stmt stmt,
Schedule schedule,
Map<te::Tensor, Buffer> extern_buffer) {
return SchedulePostProcRewriteForTensorCore(
stmt, schedule, extern_buffer);
});
} // namespace te
} // namespace tvm
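The new registration above can be checked from Python by resolving the packed function under the name passed to TVM_REGISTER_GLOBAL; note that, per the target check in the function body, the rewrite returns the input Stmt unchanged unless the current lowering target is CUDA. A hedged sketch:

import tvm

# Resolve the packed function registered by the C++ code above; the same function
# backs tvm.te.schedule.SchedulePostProcRewriteForTensorCore used by form_irmodule.
f = tvm.get_global_func("schedule.SchedulePostProcRewriteForTensorCore")
assert f is not None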
......@@ -75,14 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
}
});
TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed
([](const Stmt& stmt,
const te::Schedule& schedule,
const Map<te::Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
......
......@@ -638,7 +638,7 @@ def inject_conv2d_transpose_skip(stmt_in):
selects = []
def _find_basics(op):
if isinstance(op, tvm.tir.Call):
if isinstance(op, tvm.tir.BufferLoad):
calls.append(op)
elif isinstance(op, tvm.tir.Select):
selects.append(op)
......@@ -664,18 +664,18 @@ def inject_conv2d_transpose_skip(stmt_in):
body = op.body.body
while isinstance(body, tvm.tir.IfThenElse):
body = body.then_case
args = body.args
res_tensor = body.func.output(0)
args = body.indices
res_buffer = body.buffer
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
[dout, res_buffer], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
else:
conv_call, data_call, kernel_call = calls[-3:]
pad_data_tensor = data_call.func.output(0)
kernel_tensor = kernel_call.func.output(0)
res_tensor = conv_call.func.output(0)
pad_data_tensor = data_call.buffer
kernel_tensor = kernel_call.buffer
res_tensor = conv_call.buffer
if selects:
condition = selects[0].condition
......@@ -696,19 +696,19 @@ def inject_conv2d_transpose_skip(stmt_in):
0, 0, 0))
inner = irb.get()
args = conv_call.args
args = conv_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = kernel_call.args
args = kernel_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = data_call.args
args = data_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
......
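The VTA test update above tracks the change that tensor accesses now appear as tvm.tir.BufferLoad nodes carrying .buffer and .indices, rather than tvm.tir.Call nodes accessed through .func.output(0) and .args. A hedged sketch of the accessor mapping (unpack_access is a hypothetical helper, not part of the test):

import tvm

def unpack_access(op):
    # Old pattern (before this commit): op was a tvm.tir.Call and the test used
    #   tensor = op.func.output(0); indices = op.args
    # New pattern: op is a tvm.tir.BufferLoad carrying the buffer and indices directly.
    assert isinstance(op, tvm.tir.BufferLoad)
    return op.buffer, op.indices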