Unverified commit 4c0a53dc by Tianqi Chen, committed by GitHub

[TIR][REFACTOR] RewriteForTensorCore -> te/schedule (#5379)

* [TIR][REFACTOR] RewriteForTensorCore -> te/schedule

RewriteForTensorCore depends on schedule information, which makes it different
from a typical pass (a pass should get all of its information from the input TIR).

As a result, we refactor it as a SchedulePostProc step for now.
We should revisit it later as we introduce more support for tensor core patterns in the TIR.

* Fix VTA to fit the new IR Pattern
parent 22db299b
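For orientation, here is a minimal sketch of the user-facing entry point this refactor affects. The element-wise schedule below is hypothetical; the point is that `tvm.lower` now produces its IRModule through `form_irmodule`, which performs the TensorCore rewrite as a schedule post-processing step (see the hunks below) instead of calling the old `ir_pass.RewriteForTensorCore`.

```python
# Hypothetical element-wise example; after this change, tvm.lower() builds an
# IRModule via form_irmodule(), which now performs the TensorCore rewrite as a
# schedule post-processing step instead of calling ir_pass.RewriteForTensorCore.
import tvm
from tvm import te

n = 1024
A = te.placeholder((n, n), name="A")
B = te.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")
s = te.create_schedule(B.op)

mod = tvm.lower(s, [A, B], name="main")
print(mod)
```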
@@ -35,6 +35,23 @@ namespace tvm {
 namespace te {
 /*!
+ * \brief To automatically inline the element-wise operations.
+ *
+ * \param sch The schedule to be inlined.
+ */
+void AutoInlineElemWise(Schedule sch);
+/*!
+ * \brief To automatically inline operations with injective writes
+ *   (i.e. writes without reduction or sequential loops). Note
+ *   that in this case, guarantees about contiguity, transpose, stride,
+ *   alignment and memory footprint in general do not hold.
+ *
+ * \param sch The schedule to be inlined.
+ */
+TVM_DLL void AutoInlineInjective(Schedule sch);
+/*!
  * \brief Infer the bound of all iteration variables related to the schedule.
  *
  * \param sch The root schedule to infer all the bounds.
@@ -55,6 +72,21 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
  */
 Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
+/*!
+ * \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
+ *
+ * \param stmt The stmt to be transformed.
+ * \param schedule The original schedule.
+ * \param extern_buffer Map that specifies the external
+ *        buffer assignment of inputs and outputs.
+ * \return Transformed stmt.
+ */
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
+    Schedule schedule,
+    Map<Tensor, Buffer> extern_buffer);
 /*!
  * \brief Postprocessing the Stmt generated by ScheduleOps to create
  *   a PrimFunc that can then be used for further TIR optimizations.
@@ -75,23 +107,6 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
                                     Stmt body,
                                     Optional<Map<Tensor, Buffer>> bindings);
-/*!
- * \brief To automatically inline the element-wise operations.
- *
- * \param sch The schedule to be inlined.
- */
-void AutoInlineElemWise(Schedule sch);
-/*!
- * \brief To automatically inline operations with injective writes
- *   (i.e. writes without reduction or sequential loops). Note
- *   that in this case, guarantees about contiguity, transpose, stride,
- *   alignment and memory footprint in general do not hold.
- *
- * \param sch The schedule to be inlined.
- */
-TVM_DLL void AutoInlineInjective(Schedule sch);
 }  // namespace te
 }  // namespace tvm
 #endif  // TVM_TE_SCHEDULE_PASS_H_
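The `AutoInlineElemWise`/`AutoInlineInjective` declarations above are only moved within this header. As a reminder of what they do, a small sketch, assuming the usual `schedule.*` FFI bindings are surfaced under `tvm.te.schedule`:

```python
# Sketch assuming the AutoInlineInjective binding is exposed through
# tvm.te.schedule (the "schedule.*" packed-function namespace).
import tvm
from tvm import te

A = te.placeholder((64,), name="A")
B = te.compute((64,), lambda i: A[i] + 1.0, name="B")   # injective stage
C = te.compute((64,), lambda i: B[i] * 2.0, name="C")

s = te.create_schedule(C.op)
tvm.te.schedule.AutoInlineInjective(s)   # inline B into C
print(tvm.lower(s, [A, C], name="main"))
```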
@@ -165,19 +165,6 @@ Stmt Inline(Stmt stmt,
             PrimExpr body);
 /*!
- * \brief Try to modify the AST to support TensorCore
- *
- * \param stmt The stmt to be transformed.
- * \param schedule The original schedule.
- * \param extern_buffer Map that specifies the external
- *        buffer assignment of inputs and outputs.
- * \return Transformed stmt.
- */
-Stmt RewriteForTensorCore(Stmt stmt,
-                          te::Schedule schedule,
-                          Map<te::Tensor, Buffer> extern_buffer);
-/*!
  * \brief Verify if there is any argument bound to compact buffer.
  *
  * \param stmt The stmt to be verified.
...
@@ -84,7 +84,7 @@ def get_binds(args, compact=False, binds=None):
     return binds, arg_list
 
-def form_body(sch):
+def form_irmodule(sch, args, name, binds):
     """According to the given schedule, form a function.
 
     Parameters
@@ -92,15 +92,35 @@ def form_body(sch):
     sch : tvm.te.schedule.Schedule
         The given schedule used to form the raw body
 
+    args : list of Buffer or Tensor or Var
+        The argument lists to the function.
+
+    name : str
+        The name of the result function.
+
+    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
+        The binds information
+
     Returns
     -------
     The body formed according to the given schedule
     """
     # normalize schedule first
+    cfg = BuildConfig.current()
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
-    return stmt
+
+    compact = ir_pass.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, binds)
+
+    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
+    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
+
+    func = func.with_attr("global_symbol", name)
+    if cfg.restricted_func:
+        func = func.with_attr("tir.noalias", True)
+    return tvm.IRModule({name: func})
 
 def _wrap_as_prim_func_pass(flist, name):
@@ -166,24 +186,13 @@ def lower(sch,
     # Phase 0
     if isinstance(sch, schedule.Schedule):
-        stmt = form_body(sch)
-    for f in lower_phase0:
-        stmt = f(stmt)
-    compact = ir_pass.VerifyCompactBuffer(stmt)
-    binds, arg_list = get_binds(args, compact, binds)
-    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
-    # Start the new style pass manager.
-    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-    func = func.with_attr("global_symbol", name)
-    if cfg.restricted_func:
-        func = func.with_attr("tir.noalias", True)
-    mod = tvm.IRModule({name: func})
+        mod = form_irmodule(sch, args, name, binds)
+    else:
+        mod = sch
 
     # Phase 1
     pass_list = [
+        _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
         tvm.tir.transform.InjectPrefetch(),
         tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
         tvm.tir.transform.NarrowDataType(32),
...
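A hedged, step-by-step sketch of what the new `form_irmodule` does, mirroring the calls in the hunk above; `get_binds` is an internal helper of `tvm.driver.build_module`, so this is illustrative rather than public API, and the placeholder compute is hypothetical:

```python
import tvm
from tvm import te
from tvm.te import schedule
from tvm.tir import ir_pass
from tvm.driver.build_module import get_binds  # internal helper

A = te.placeholder((16, 16), name="A")
B = te.compute((16, 16), lambda i, j: A[i, j] + 1.0, name="B")

sch = te.create_schedule(B.op).normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)

compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds([A, B], compact)

# The TensorCore rewrite now runs as a schedule post-processing step.
stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
func = func.with_attr("global_symbol", "main")
print(tvm.IRModule({"main": func}))
```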
@@ -30,7 +30,7 @@ HalideIR.
 # 2. Support multi-level HalideIR
 import inspect
 import tvm._ffi
-from tvm.driver.build_module import form_body
+import tvm.te.schedule
 from tvm._ffi.base import decorate
 from .module import HybridModule
@@ -87,8 +87,10 @@ def build(sch, inputs, outputs, name="hybrid_func"):
     The built result is wrapped in a HybridModule.
     The usage of HybridModule is roughly the same as normal TVM-built modules.
     """
-    stmt = form_body(sch)
+    sch = sch.normalize()
+    bounds = tvm.te.schedule.InferBound(sch)
+    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
 
     src = _Dump(stmt, inputs, outputs, name)
     return HybridModule(src, name)
...
@@ -18,9 +18,12 @@
  */
 /*!
- * \file tensor_core.cc
+ * \file schedule_postproc_rewrite_for_tensor_core.cc
+ *
+ * \brief Rewrite the Stmt generated by ScheduleOps
+ *  to accommodate TensorCore.
  */
-// IR Passes for TensorCore CodeGen
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/te/operation.h>
@@ -32,12 +35,11 @@
 #include <tvm/target/target.h>
 #include <tvm/runtime/device_api.h>
 #include <unordered_map>
-#include "ir_util.h"
 #include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 
 namespace tvm {
-namespace tir {
+namespace te {
 using namespace te;
 using runtime::StorageRank;
@@ -86,10 +88,10 @@ class MMAMatcher: public StmtVisitor {
   }
 
   void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::pragma_tensor_core) {
+    if (op->attr_key == tir::attr::pragma_tensor_core) {
       tensor_core_on_ = true;
       StmtVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
     } else {
@@ -414,7 +416,7 @@ class BufferAnalyser : public StmtExprVisitor {
   }
 
   void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent) {
+    if (op->attr_key == tir::attr::thread_extent) {
       if (const IntImmNode* value = op->value.as<IntImmNode>()) {
         thread_extent_.insert(
             std::make_pair(
@@ -422,10 +424,10 @@ class BufferAnalyser : public StmtExprVisitor {
                 value->value));
       }
       StmtExprVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::buffer_dim_align) {
+    } else if (op->attr_key == tir::attr::buffer_dim_align) {
       te::Tensor tensor = Downcast<te::Tensor>(op->node);
       const CallNode* tuple = op->value.as<CallNode>();
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
@@ -850,7 +852,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    if (op->attr_key == attr::realize_scope) {
+    if (op->attr_key == tir::attr::realize_scope) {
       auto node = op->node.as<te::OperationNode>();
       if (node != nullptr) {
         if (!frag_reg_.count(node->name)) {
@@ -1186,7 +1188,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
   int warp_threads_y_{-1};
 };
 
-Stmt RewriteForTensorCore(Stmt stmt,
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
     Schedule schedule,
     Map<Tensor, Buffer> extern_buffer) {
   // Check if current lower target is CUDA
@@ -1223,5 +1226,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
   return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
 }
 
-}  // namespace tir
+TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
+.set_body_typed([](Stmt stmt,
+                   Schedule schedule,
+                   Map<te::Tensor, Buffer> extern_buffer) {
+  return SchedulePostProcRewriteForTensorCore(
+      stmt, schedule, extern_buffer);
+});
+
+}  // namespace te
 }  // namespace tvm
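The `TVM_REGISTER_GLOBAL` call above is what makes the relocated pass reachable from the Python driver. A small sketch, assuming the `schedule.*` globals are surfaced by `tvm.te.schedule`, as the driver code above relies on:

```python
import tvm

# Fetch the packed function registered above.
f = tvm.get_global_func("schedule.SchedulePostProcRewriteForTensorCore")

# Equivalent to the call made inside form_irmodule in the driver:
#   stmt = tvm.te.schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
print(f)
```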
@@ -75,14 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
     }
   });
 
-TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
-.set_body_typed
-  ([](const Stmt& stmt,
-      const te::Schedule& schedule,
-      const Map<te::Tensor, Buffer>& extern_buffer) {
-      return RewriteForTensorCore(stmt, schedule, extern_buffer);
-  });
-
 TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
...
@@ -638,7 +638,7 @@ def inject_conv2d_transpose_skip(stmt_in):
     selects = []
 
     def _find_basics(op):
-        if isinstance(op, tvm.tir.Call):
+        if isinstance(op, tvm.tir.BufferLoad):
             calls.append(op)
         elif isinstance(op, tvm.tir.Select):
             selects.append(op)
@@ -664,18 +664,18 @@ def inject_conv2d_transpose_skip(stmt_in):
             body = op.body.body
             while isinstance(body, tvm.tir.IfThenElse):
                 body = body.then_case
-            args = body.args
-            res_tensor = body.func.output(0)
+            args = body.indices
+            res_buffer = body.buffer
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
-                [dout, res_tensor], 'buffer_bind_scope',
+                [dout, res_buffer], 'buffer_bind_scope',
                 tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
             return inner
         else:
             conv_call, data_call, kernel_call = calls[-3:]
-            pad_data_tensor = data_call.func.output(0)
-            kernel_tensor = kernel_call.func.output(0)
-            res_tensor = conv_call.func.output(0)
+            pad_data_tensor = data_call.buffer
+            kernel_tensor = kernel_call.buffer
+            res_tensor = conv_call.buffer
             if selects:
                 condition = selects[0].condition
@@ -696,19 +696,19 @@ def inject_conv2d_transpose_skip(stmt_in):
                                     0, 0, 0))
             inner = irb.get()
-            args = conv_call.args
+            args = conv_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                    1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
                 [dout, res_tensor], 'buffer_bind_scope',
                 tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-            args = kernel_call.args
+            args = kernel_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                    1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
                 [dwgt, kernel_tensor], 'buffer_bind_scope',
                 tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-            args = data_call.args
+            args = data_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
                    1, 0, 1, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
...
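The VTA changes above track the switch from `Call`-based tensor reads to `tvm.tir.BufferLoad` nodes. A tiny hedged sketch of the node shape the pass now matches; the buffer name and shape are hypothetical:

```python
import tvm
from tvm import tir

# Hypothetical buffer; the VTA pass now receives BufferLoad nodes like this
# and reads .buffer / .indices instead of Call's .func / .args.
buf = tir.decl_buffer((4, 4), "float32", name="data")
load = tir.BufferLoad(buf, [tir.IntImm("int32", 0), tir.IntImm("int32", 1)])

assert load.buffer.same_as(buf)
print(load.indices)
```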