Unverified Commit d3277874 by Tianqi Chen Committed by GitHub

[PYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)

parent 72f2aea2
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/support/with.h> #include <tvm/support/with.h>
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -225,8 +226,8 @@ class BuildConfigNode : public Object { ...@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */ /*! \brief Whether to partition const loop */
bool partition_const_loop = false; bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */ /*! \brief List of passes to be injected into the low-level pipeline. */
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass; std::vector<std::pair<int, transform::Pass>> add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */ /*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false; bool dump_pass_ir = false;
......
...@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs): ...@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel. """Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block. This pass will check memory usage and number of threads per block.
""" """
def verify_pass(stmt): def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(stmt, kwargs) valid = ir_pass.VerifyGPUCode(f.body, kwargs)
if not valid: if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel") raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt return f
return verify_pass return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
...@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds): ...@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func}) return tvm.IRModule({name: func})
def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is a temporary adapter before we fully
migrate to the new pass manager.
"""
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
stmt,
func.ret_type,
func.buffer_map,
func.attrs)
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
def lower(sch, def lower(sch,
args, args,
name="main", name="main",
...@@ -190,15 +171,15 @@ def lower(sch, ...@@ -190,15 +171,15 @@ def lower(sch,
else: else:
mod = sch mod = sch
pass_list = lower_phase0
# Phase 1 # Phase 1
pass_list = [ pass_list += [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
tvm.tir.transform.InjectPrefetch(), tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers), tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(), tvm.tir.transform.Simplify(),
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
] ]
pass_list += lower_phase1
# Phase 2 # Phase 2
if not simple_mode: if not simple_mode:
...@@ -214,8 +195,8 @@ def lower(sch, ...@@ -214,8 +195,8 @@ def lower(sch,
cfg.auto_unroll_max_depth, cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent, cfg.auto_unroll_max_extent,
cfg.unroll_explicit), cfg.unroll_explicit),
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
] ]
pass_list += lower_phase2
# Phase 3 # Phase 3
pass_list += [ pass_list += [
...@@ -225,7 +206,7 @@ def lower(sch, ...@@ -225,7 +206,7 @@ def lower(sch,
if not cfg.disable_select_rewriting: if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")] pass_list += lower_phase3
# Instrument BoundCheckers # Instrument BoundCheckers
if cfg.instrument_bound_checkers: if cfg.instrument_bound_checkers:
......
...@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc): ...@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs) _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
def with_body(self, new_body):
"""Create a new PrimFunc with the same signature but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return PrimFunc(
self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
...@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope") ...@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass") TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0]; BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass; std::vector<std::pair<int, transform::Pass>> add_lower_pass;
CHECK_EQ(args.size() % 2, 1); CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) { for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair( add_lower_pass.push_back(std::make_pair(
args[i].operator int(), args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc())); args[i + 1].operator transform::Pass()));
} }
cfg->add_lower_pass = add_lower_pass; cfg->add_lower_pass = add_lower_pass;
}); });
......
...@@ -51,11 +51,13 @@ def test_fold_const(): ...@@ -51,11 +51,13 @@ def test_fold_const():
z = relay.add(y, relay.const(c_data)) z = relay.add(y, relay.const(c_data))
return relay.Function([x], z) return relay.Function([x], z)
def fail(x): def FailPass():
def _transform(m, *args):
raise RuntimeError() raise RuntimeError()
return tvm.transform.module_pass(_transform, opt_level=0)
# the fold constant should work on any context. # the fold constant should work on any context.
with tvm.target.build_config(add_lower_pass=[(0, fail)]): with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
with tvm.target.create("cuda"): with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant()) zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType()) zexpected = run_opt_pass(expected(), transform.InferType())
......
...@@ -182,7 +182,7 @@ def test_cuda_shuffle(): ...@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch[c].bind(xo, thrx) sch[c].bind(xo, thrx)
sch[c].vectorize(xi) sch[c].vectorize(xi)
def my_vectorize(stmt): def MyVectorize():
def vectorizer(op): def vectorizer(op):
if op.for_type == tvm.tir.For.Vectorized: if op.for_type == tvm.tir.For.Vectorized:
four = tvm.tir.const(4, 'int32') four = tvm.tir.const(4, 'int32')
...@@ -198,9 +198,13 @@ def test_cuda_shuffle(): ...@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b = tvm.tir.Shuffle(bs, ids) new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None return None
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
module = tvm.build(sch, [a, b, c], target='cuda') module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32') a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32') b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
......
...@@ -671,8 +671,7 @@ def test_llvm_shuffle(): ...@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c = te.compute((8, ), lambda x: a[x] + b[7-x]) c = te.compute((8, ), lambda x: a[x] + b[7-x])
sch = te.create_schedule(c.op) sch = te.create_schedule(c.op)
def my_vectorize(stmt): def my_vectorize():
def vectorizer(op): def vectorizer(op):
store = op.body store = op.body
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8) idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
...@@ -684,9 +683,13 @@ def test_llvm_shuffle(): ...@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value = new_a + new_b value = new_a + new_b
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c]) module = tvm.build(sch, [a, b, c])
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32')) a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
......
...@@ -19,10 +19,10 @@ import tvm ...@@ -19,10 +19,10 @@ import tvm
from tvm import te from tvm import te
def get_verify_pass(valid, **kwargs): def get_verify_pass(valid, **kwargs):
def verify_pass(stmt): def _fverify(f, *_):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs) valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
return stmt return f
return verify_pass return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
def test_shared_memory(): def test_shared_memory():
def check_shared_memory(dtype): def check_shared_memory(dtype):
......
...@@ -117,19 +117,20 @@ def vectorize8(op): ...@@ -117,19 +117,20 @@ def vectorize8(op):
return body return body
return None return None
def vectorize(stmt): @tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
global loops global loops
tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8) tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
if not loops: if not loops:
return stmt return sf
# The last list argument indicates what kinds of nodes will be transformed. # The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8` # Thus, in this case only `For` nodes will call `vectorize8`
stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For']) return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
return stmt
##################################################################### #####################################################################
# Glue to Lowering # Glue to Lowering
......
...@@ -14,25 +14,22 @@ ...@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=unused-argument # pylint: disable=unused-argument, invalid-name
"""VTA specific builtin for runtime.""" """VTA specific builtin for runtime."""
import tvm import tvm
from . import ir_pass from . import transform
from .environment import get_env from .environment import get_env
def lift_coproc_scope(x): def EarlyRewrite():
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x
def early_rewrite(stmt):
"""Try to do storage rewrite in early pass.""" """Try to do storage rewrite in early pass."""
def _transform(mod, ctx):
try: try:
return tvm.tir.ir_pass.StorageRewrite(stmt) return tvm.tir.transform.StorageRewrite()(mod)
except tvm.error.TVMError: except tvm.error.TVMError:
return stmt return mod
return tvm.transform.module_pass(
_transform, opt_level=0, name="tir.vta.EarlyRewrite")
def build_config(debug_flag=0, **kwargs): def build_config(debug_flag=0, **kwargs):
...@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs): ...@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...) vta_module = tvm.build(s, ...)
""" """
env = get_env() env = get_env()
def add_debug(stmt):
@tvm.tir.transform.prim_func_pass(opt_level=0)
def add_debug(f, *_):
debug = tvm.tir.call_extern( debug = tvm.tir.call_extern(
"int32", "VTASetDebugMode", "int32", "VTASetDebugMode",
env.dev.command_handle, env.dev.command_handle,
debug_flag) debug_flag)
return tvm.tir.stmt_seq(debug, stmt) return f.with_body(tvm.tir.stmt_seq(debug, f.body))
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy), pass_list = [(0, transform.InjectConv2DTransposeSkip()),
(1, ir_pass.annotate_alu_coproc_scope), (1, transform.InjectDMAIntrin()),
(1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)), (1, transform.InjectSkipCopy()),
(1, lift_coproc_scope), (1, transform.AnnotateALUCoProcScope()),
(1, ir_pass.inject_coproc_sync), (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
(1, early_rewrite)] (1, transform.LiftAllocToScopeBegin()),
(1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
(1, transform.InjectCoProcSync()),
(1, EarlyRewrite())]
if debug_flag: if debug_flag:
pass_list.append((1, add_debug)) pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin)) pass_list.append((2, transform.InjectALUIntrin()))
pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo)) pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, transform.FoldUopLoop()))
pass_list.append((3, ir_pass.cpu_access_rewrite)) pass_list.append((3, transform.CPUAccessRewrite()))
return tvm.target.build_config(add_lower_pass=pass_list, **kwargs) return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Additional IR Pass for VTA""" """Additional Transformation Passes for VTA"""
# pylint: disable=len-as-condition, no-else-return # pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name
import tvm import tvm
from tvm import te from tvm import te
from topi import util from topi import util
...@@ -38,7 +38,7 @@ def _match_pragma(stmt, key): ...@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
(stmt.attr_key == "pragma_scope" and stmt.value.value == key)) (stmt.attr_key == "pragma_scope" and stmt.value.value == key))
def fold_uop_loop(stmt_in): def FoldUopLoop():
"""Detect and fold uop loop. """Detect and fold uop loop.
VTA supports the uop programming model VTA supports the uop programming model
...@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in): ...@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
This pass detects the loop structure This pass detects the loop structure
and extract that into uop loop AST. and extract that into uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Output statement. The pass
""" """
env = get_env()
def _fold_outermost_loop(body): def _fold_outermost_loop(body):
stmt = body stmt = body
if not isinstance(stmt, tvm.tir.For): if not isinstance(stmt, tvm.tir.For):
...@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in): ...@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
raise ValueError("Failed to fold the GEMM instructions..") raise ValueError("Failed to fold the GEMM instructions..")
def _do_fold(stmt): def _do_fold(stmt):
env = get_env()
if (stmt.attr_key == "coproc_uop_scope" and if (stmt.attr_key == "coproc_uop_scope" and
isinstance(stmt.value, tvm.tir.StringImm) and isinstance(stmt.value, tvm.tir.StringImm) and
stmt.value.value == env.dev.vta_push_uop.value): stmt.value.value == env.dev.vta_push_uop.value):
...@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in): ...@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
return tvm.tir.AttrStmt( return tvm.tir.AttrStmt(
stmt.node, stmt.attr_key, stmt.value, body) stmt.node, stmt.attr_key, stmt.value, body)
return None return None
out = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"]) def _ftransform(f, mod, ctx):
return out return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
def cpu_access_rewrite(stmt_in): def CPUAccessRewrite():
"""Detect CPU access to VTA buffer and get address correctly. """Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that do not VTA's buffer is an opaque handle that do not
...@@ -148,18 +146,14 @@ def cpu_access_rewrite(stmt_in): ...@@ -148,18 +146,14 @@ def cpu_access_rewrite(stmt_in):
This pass detects CPU access and rewrites it to use the pointer This pass detects CPU access and rewrites it to use the pointer
returned by VTABufferCPUPtr for CPU access. returned by VTABufferCPUPtr for CPU access.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
env = get_env() def _ftransform(f, mod, ctx):
rw_info = {} rw_info = {}
env = get_env()
def _post_order(op): def _post_order(op):
if isinstance(op, tvm.tir.Allocate): if isinstance(op, tvm.tir.Allocate):
buffer_var = op.buffer_var buffer_var = op.buffer_var
...@@ -191,30 +185,31 @@ def cpu_access_rewrite(stmt_in): ...@@ -191,30 +185,31 @@ def cpu_access_rewrite(stmt_in):
new_var = rw_info[buffer_var] new_var = rw_info[buffer_var]
return tvm.tir.Store(new_var, op.value, op.index) return tvm.tir.Store(new_var, op.value, op.index)
raise RuntimeError("not reached") raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform( stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items(): for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt( stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern( new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr", "handle", "VTABufferCPUPtr",
env.dev.command_handle, env.dev.command_handle,
buffer_var), stmt) buffer_var), stmt)
return stmt return f.with_body(stmt)
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite")
def lift_alloc_to_scope_begin(stmt_in): def LiftAllocToScopeBegin():
"""Lift allocate to beginning of the current scope. """Lift allocate to beginning of the current scope.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
def _ftransform(f, mod, ctx):
lift_stmt = [[]] lift_stmt = [[]]
def _merge_block(slist, body): def _merge_block(slist, body):
for op in slist: for op in slist:
...@@ -257,46 +252,46 @@ def lift_alloc_to_scope_begin(stmt_in): ...@@ -257,46 +252,46 @@ def lift_alloc_to_scope_begin(stmt_in):
if isinstance(op, tvm.tir.For): if isinstance(op, tvm.tir.For):
return _merge_block(lift_stmt.pop() + [op], op.body) return _merge_block(lift_stmt.pop() + [op], op.body)
raise RuntimeError("not reached") raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform( stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1 assert len(lift_stmt) == 1
return _merge_block(lift_stmt[0], stmt) return f.with_body(_merge_block(lift_stmt[0], stmt))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin")
def inject_skip_copy(stmt_in):
"""Pass to inject skip copy stmt, used for debug purpose.
Parameters def InjectSkipCopy():
---------- """Pass to inject skip copy stmt, used for debug purpose.
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
def _do_fold(stmt): def _do_fold(stmt):
if _match_pragma(stmt, "skip_dma_copy"): if _match_pragma(stmt, "skip_dma_copy"):
return tvm.tir.Evaluate(0) return tvm.tir.Evaluate(0)
return None return None
return tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
def inject_coproc_sync(stmt_in): return tvm.tir.transform.prim_func_pass(
"""Pass to inject skip copy stmt, used in debug. _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
Parameters
---------- def InjectCoProcSync():
stmt_in : Stmt """Pass to inject coproc sync
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
def _ftransform(f, *_):
success = [False] success = [False]
def _do_fold(stmt): def _do_fold(stmt):
if _match_pragma(stmt, "coproc_sync"): if _match_pragma(stmt, "coproc_sync"):
...@@ -311,26 +306,22 @@ def inject_coproc_sync(stmt_in): ...@@ -311,26 +306,22 @@ def inject_coproc_sync(stmt_in):
op.loop_var, op.min, 2, op.for_type, op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body) op.device_api, op.body)
return None return None
stmt = tvm.tir.ir_pass.IRTransform( return f.with_body(tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"]) f.body, None, _do_fold, ["AttrStmt"]))
stmt = tvm.tir.ir_pass.CoProcSync(stmt) return tvm.transform.Sequential(
return stmt [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
tvm.tir.transform.CoProcSync()],
opt_level=0, name="tir.vta.InjectCoProcSync")
def inject_dma_intrin(stmt_in): def InjectDMAIntrin():
"""Pass to inject DMA copy intrinsics. """Pass to inject DMA copy intrinsics.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
env = get_env()
idxd = tvm.tir.indexdiv idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod idxm = tvm.tir.indexmod
...@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in): ...@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
def _inject_copy(src, dst, pad_before, pad_after, pad_value): def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored... # FIXME: pad_value is ignored...
env = get_env()
_ = pad_value _ = pad_value
if dst.scope == "global": if dst.scope == "global":
# Store # Store
...@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in): ...@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
else: else:
raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy) return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
def _get_gemm_intrin_buffer(): def _get_gemm_intrin_buffer():
...@@ -619,19 +611,15 @@ def _get_gemm_intrin_buffer(): ...@@ -619,19 +611,15 @@ def _get_gemm_intrin_buffer():
return wgt_layout, inp_layout, out_layout return wgt_layout, inp_layout, out_layout
def inject_conv2d_transpose_skip(stmt_in): def InjectConv2DTransposeSkip():
"""Pass to skip 0-weights in conv2d transpose with stride > 1. """Pass to skip 0-weights in conv2d transpose with stride > 1.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
def _ftransform(func, mod, ctx):
env = get_env() env = get_env()
dwgt, dinp, dout = _get_gemm_intrin_buffer() dwgt, dinp, dout = _get_gemm_intrin_buffer()
...@@ -687,7 +675,8 @@ def inject_conv2d_transpose_skip(stmt_in): ...@@ -687,7 +675,8 @@ def inject_conv2d_transpose_skip(stmt_in):
irb = tvm.tir.ir_builder.create() irb = tvm.tir.ir_builder.create()
with irb.if_scope(condition): with irb.if_scope(condition):
dev = env.dev dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) irb.scope_attr(
dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 0, 0, 0,
...@@ -717,24 +706,22 @@ def inject_conv2d_transpose_skip(stmt_in): ...@@ -717,24 +706,22 @@ def inject_conv2d_transpose_skip(stmt_in):
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner return inner
return None return None
ret = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return ret
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
def annotate_alu_coproc_scope(stmt_in):
"""Pass to insert ALU instruction.
Parameters def AnnotateALUCoProcScope():
---------- """Pass to insert ALU instruction.
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
def _ftransform(func, mod, ctx):
env = get_env() env = get_env()
def _do_fold(stmt): def _do_fold(stmt):
if _match_pragma(stmt, "alu"): if _match_pragma(stmt, "alu"):
...@@ -749,25 +736,21 @@ def annotate_alu_coproc_scope(stmt_in): ...@@ -749,25 +736,21 @@ def annotate_alu_coproc_scope(stmt_in):
return tvm.tir.Evaluate(0) return tvm.tir.Evaluate(0)
return stmt return stmt
stmt_out = tvm.tir.ir_pass.IRTransform( return func.with_body(tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"]) func.body, None, _do_fold, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
return stmt_out _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
def inject_alu_intrin(stmt_in): def InjectALUIntrin():
"""Pass to inject ALU micro-ops. """Pass to inject ALU micro-ops.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
def _ftransform(func, mod, ctx):
env = get_env() env = get_env()
idxm = tvm.tir.indexmod idxm = tvm.tir.indexmod
analyzer = tvm.arith.Analyzer() analyzer = tvm.arith.Analyzer()
...@@ -972,24 +955,8 @@ def inject_alu_intrin(stmt_in): ...@@ -972,24 +955,8 @@ def inject_alu_intrin(stmt_in):
return irb.get() return irb.get()
return stmt return stmt
stmt_out = tvm.tir.ir_pass.IRTransform( return func.with_body(tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"]) func.body, None, _do_fold, ["AttrStmt"]))
return stmt_out
def debug_print(stmt):
"""A debug pass that print the stmt
Parameters
----------
stmt : Stmt
The input statement
Returns return tvm.tir.transform.prim_func_pass(
------- _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")
stmt : Stmt
The
"""
# pylint: disable=superfluous-parens
print(stmt)
return stmt
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment