Unverified Commit d3277874 by Tianqi Chen Committed by GitHub

[PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)

parent 72f2aea2
......@@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string>
#include <vector>
......@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
/*! \brief List of passes to be injected into the low-level pipeline. */
std::vector<std::pair<int, transform::Pass>> add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
......
......@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
"""
def verify_pass(stmt):
valid = ir_pass.VerifyGPUCode(stmt, kwargs)
def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(f.body, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt
return verify_pass
return f
return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
......@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func})
def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is an temporary adapter before we fully
migrate to the new pass manager.
"""
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
stmt,
func.ret_type,
func.buffer_map,
func.attrs)
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
def lower(sch,
args,
name="main",
......@@ -190,15 +171,15 @@ def lower(sch,
else:
mod = sch
pass_list = lower_phase0
# Phase 1
pass_list = [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
]
pass_list += lower_phase1
# Phase 2
if not simple_mode:
......@@ -214,8 +195,8 @@ def lower(sch,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit),
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
]
pass_list += lower_phase2
# Phase 3
pass_list += [
......@@ -225,7 +206,7 @@ def lower(sch,
if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
pass_list += lower_phase3
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
......
......@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
def with_body(self, new_body):
"""Create a new PrimFunc with the same set signatures but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return PrimFunc(
self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
......@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
std::vector<std::pair<int, transform::Pass>> add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair(
args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc()));
args[i + 1].operator transform::Pass()));
}
cfg->add_lower_pass = add_lower_pass;
});
......
......@@ -51,11 +51,13 @@ def test_fold_const():
z = relay.add(y, relay.const(c_data))
return relay.Function([x], z)
def fail(x):
def FailPass():
def _transform(m, *args):
raise RuntimeError()
return tvm.transform.module_pass(_transform, opt_level=0)
# the fold constant should work on any context.
with tvm.target.build_config(add_lower_pass=[(0, fail)]):
with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
......
......@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch[c].bind(xo, thrx)
sch[c].vectorize(xi)
def my_vectorize(stmt):
def MyVectorize():
def vectorizer(op):
if op.for_type == tvm.tir.For.Vectorized:
four = tvm.tir.const(4, 'int32')
......@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
......
......@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c = te.compute((8, ), lambda x: a[x] + b[7-x])
sch = te.create_schedule(c.op)
def my_vectorize(stmt):
def my_vectorize():
def vectorizer(op):
store = op.body
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
......@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value = new_a + new_b
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c])
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
......
......@@ -19,10 +19,10 @@ import tvm
from tvm import te
def get_verify_pass(valid, **kwargs):
def verify_pass(stmt):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs)
return stmt
return verify_pass
def _fverify(f, *_):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
return f
return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
def test_shared_memory():
def check_shared_memory(dtype):
......
......@@ -117,19 +117,20 @@ def vectorize8(op):
return body
return None
def vectorize(stmt):
@tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
global loops
tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8)
tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
if not loops:
return stmt
return sf
# The last list arugment indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For'])
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
return stmt
#####################################################################
# Glue to Lowering
......
......@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
# pylint: disable=unused-argument, invalid-name
"""VTA specific buildin for runtime."""
import tvm
from . import ir_pass
from . import transform
from .environment import get_env
def lift_coproc_scope(x):
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x
def early_rewrite(stmt):
def EarlyRewrite():
"""Try to do storage rewrite in early pass."""
def _transform(mod, ctx):
try:
return tvm.tir.ir_pass.StorageRewrite(stmt)
return tvm.tir.transform.StorageRewrite()(mod)
except tvm.error.TVMError:
return stmt
return mod
return tvm.transform.module_pass(
_transform, opt_level=0, name="tir.vta.EarlyRewrite")
def build_config(debug_flag=0, **kwargs):
......@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
"""
env = get_env()
def add_debug(stmt):
@tvm.tir.transform.prim_func_pass(opt_level=0)
def add_debug(f, *_):
debug = tvm.tir.call_extern(
"int32", "VTASetDebugMode",
env.dev.command_handle,
debug_flag)
return tvm.tir.stmt_seq(debug, stmt)
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
(1, ir_pass.annotate_alu_coproc_scope),
(1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
(1, lift_coproc_scope),
(1, ir_pass.inject_coproc_sync),
(1, early_rewrite)]
return f.with_body(tvm.tir.stmt_seq(debug, f.body))
pass_list = [(0, transform.InjectConv2DTransposeSkip()),
(1, transform.InjectDMAIntrin()),
(1, transform.InjectSkipCopy()),
(1, transform.AnnotateALUCoProcScope()),
(1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
(1, transform.LiftAllocToScopeBegin()),
(1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
(1, transform.InjectCoProcSync()),
(1, EarlyRewrite())]
if debug_flag:
pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo))
pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite))
pass_list.append((2, transform.InjectALUIntrin()))
pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
pass_list.append((3, transform.FoldUopLoop()))
pass_list.append((3, transform.CPUAccessRewrite()))
return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
......
......@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Additional IR Pass for VTA"""
# pylint: disable=len-as-condition, no-else-return
"""Additional Transformation Passes. for VTA"""
# pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name
import tvm
from tvm import te
from topi import util
......@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
(stmt.attr_key == "pragma_scope" and stmt.value.value == key))
def fold_uop_loop(stmt_in):
def FoldUopLoop():
"""Detect and fold uop loop.
VTA support uop programming model
......@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
This pass detect the loop structure
and extract that into uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Output statement.
fpass : tvm.transform.Pass
The pass
"""
env = get_env()
def _fold_outermost_loop(body):
stmt = body
if not isinstance(stmt, tvm.tir.For):
......@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
raise ValueError("Failed to fold the GEMM instructions..")
def _do_fold(stmt):
env = get_env()
if (stmt.attr_key == "coproc_uop_scope" and
isinstance(stmt.value, tvm.tir.StringImm) and
stmt.value.value == env.dev.vta_push_uop.value):
......@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
return tvm.tir.AttrStmt(
stmt.node, stmt.attr_key, stmt.value, body)
return None
out = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return out
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
def cpu_access_rewrite(stmt_in):
def CPUAccessRewrite():
"""Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that do not
......@@ -148,18 +146,14 @@ def cpu_access_rewrite(stmt_in):
This pass detect CPU access and rewrite to use pointer
returned VTABufferCPUPtr for CPU access.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
env = get_env()
def _ftransform(f, mod, ctx):
rw_info = {}
env = get_env()
def _post_order(op):
if isinstance(op, tvm.tir.Allocate):
buffer_var = op.buffer_var
......@@ -191,30 +185,31 @@ def cpu_access_rewrite(stmt_in):
new_var = rw_info[buffer_var]
return tvm.tir.Store(new_var, op.value, op.index)
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), stmt)
return stmt
return f.with_body(stmt)
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite")
def lift_alloc_to_scope_begin(stmt_in):
def LiftAllocToScopeBegin():
"""Lift allocate to beginning of the current scope.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
def _ftransform(f, mod, ctx):
lift_stmt = [[]]
def _merge_block(slist, body):
for op in slist:
......@@ -257,46 +252,46 @@ def lift_alloc_to_scope_begin(stmt_in):
if isinstance(op, tvm.tir.For):
return _merge_block(lift_stmt.pop() + [op], op.body)
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1
return _merge_block(lift_stmt[0], stmt)
return f.with_body(_merge_block(lift_stmt[0], stmt))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin")
def inject_skip_copy(stmt_in):
"""Pass to inject skip copy stmt, used for debug purpose.
Parameters
----------
stmt_in : Stmt
Input statement
def InjectSkipCopy():
"""Pass to inject skip copy stmt, used for debug purpose.
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
def _do_fold(stmt):
if _match_pragma(stmt, "skip_dma_copy"):
return tvm.tir.Evaluate(0)
return None
return tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
def inject_coproc_sync(stmt_in):
"""Pass to inject skip copy stmt, used in debug.
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
Parameters
----------
stmt_in : Stmt
Input statement
def InjectCoProcSync():
"""Pass inject coproc sync
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
def _ftransform(f, *_):
success = [False]
def _do_fold(stmt):
if _match_pragma(stmt, "coproc_sync"):
......@@ -311,26 +306,22 @@ def inject_coproc_sync(stmt_in):
op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body)
return None
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
stmt = tvm.tir.ir_pass.CoProcSync(stmt)
return stmt
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, None, _do_fold, ["AttrStmt"]))
return tvm.transform.Sequential(
[tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
tvm.tir.transform.CoProcSync()],
opt_level=0, name="tir.vta.InjectCoProcSync")
def inject_dma_intrin(stmt_in):
def InjectDMAIntrin():
"""Pass to inject DMA copy intrinsics.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
env = get_env()
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod
......@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored...
env = get_env()
_ = pad_value
if dst.scope == "global":
# Store
......@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
else:
raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy)
return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
def _get_gemm_intrin_buffer():
......@@ -619,19 +611,15 @@ def _get_gemm_intrin_buffer():
return wgt_layout, inp_layout, out_layout
def inject_conv2d_transpose_skip(stmt_in):
def InjectConv2DTransposeSkip():
"""Pass to skip 0-weights in conv2d transpose with stride > 1.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
def _ftransform(func, mod, ctx):
env = get_env()
dwgt, dinp, dout = _get_gemm_intrin_buffer()
......@@ -687,7 +675,8 @@ def inject_conv2d_transpose_skip(stmt_in):
irb = tvm.tir.ir_builder.create()
with irb.if_scope(condition):
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(
dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 0,
......@@ -717,24 +706,22 @@ def inject_conv2d_transpose_skip(stmt_in):
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
return None
ret = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return ret
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
def annotate_alu_coproc_scope(stmt_in):
"""Pass to insert ALU instruction.
Parameters
----------
stmt_in : Stmt
Input statement
def AnnotateALUCoProcScope():
"""Pass to insert ALU instruction.
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
def _ftransform(func, mod, ctx):
env = get_env()
def _do_fold(stmt):
if _match_pragma(stmt, "alu"):
......@@ -749,25 +736,21 @@ def annotate_alu_coproc_scope(stmt_in):
return tvm.tir.Evaluate(0)
return stmt
stmt_out = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, None, _do_fold, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
def inject_alu_intrin(stmt_in):
def InjectALUIntrin():
"""Pass to inject ALU micro-ops.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
"""
def _ftransform(func, mod, ctx):
env = get_env()
idxm = tvm.tir.indexmod
analyzer = tvm.arith.Analyzer()
......@@ -972,24 +955,8 @@ def inject_alu_intrin(stmt_in):
return irb.get()
return stmt
stmt_out = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out
def debug_print(stmt):
"""A debug pass that print the stmt
Parameters
----------
stmt : Stmt
The input statement
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, None, _do_fold, ["AttrStmt"]))
Returns
-------
stmt : Stmt
The
"""
# pylint: disable=superfluous-parens
print(stmt)
return stmt
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment