Unverified Commit d3277874 by Tianqi Chen Committed by GitHub

[PYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)

parent 72f2aea2
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/support/with.h> #include <tvm/support/with.h>
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -225,8 +226,8 @@ class BuildConfigNode : public Object { ...@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */ /*! \brief Whether to partition const loop */
bool partition_const_loop = false; bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */ /*! \brief List of passes to be injected into the low-level pipeline. */
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass; std::vector<std::pair<int, transform::Pass>> add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */ /*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false; bool dump_pass_ir = false;
......
...@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs): ...@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel. """Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block. This pass will check memory usage and number of threads per block.
""" """
def verify_pass(stmt): def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(stmt, kwargs) valid = ir_pass.VerifyGPUCode(f.body, kwargs)
if not valid: if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel") raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt return f
return verify_pass return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
...@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds): ...@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func}) return tvm.IRModule({name: func})
def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is an temporary adapter before we fully
migrate to the new pass manager.
"""
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
stmt,
func.ret_type,
func.buffer_map,
func.attrs)
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
def lower(sch, def lower(sch,
args, args,
name="main", name="main",
...@@ -190,15 +171,15 @@ def lower(sch, ...@@ -190,15 +171,15 @@ def lower(sch,
else: else:
mod = sch mod = sch
pass_list = lower_phase0
# Phase 1 # Phase 1
pass_list = [ pass_list += [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
tvm.tir.transform.InjectPrefetch(), tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers), tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(), tvm.tir.transform.Simplify(),
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
] ]
pass_list += lower_phase1
# Phase 2 # Phase 2
if not simple_mode: if not simple_mode:
...@@ -214,8 +195,8 @@ def lower(sch, ...@@ -214,8 +195,8 @@ def lower(sch,
cfg.auto_unroll_max_depth, cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent, cfg.auto_unroll_max_extent,
cfg.unroll_explicit), cfg.unroll_explicit),
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
] ]
pass_list += lower_phase2
# Phase 3 # Phase 3
pass_list += [ pass_list += [
...@@ -225,7 +206,7 @@ def lower(sch, ...@@ -225,7 +206,7 @@ def lower(sch,
if not cfg.disable_select_rewriting: if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")] pass_list += lower_phase3
# Instrument BoundCheckers # Instrument BoundCheckers
if cfg.instrument_bound_checkers: if cfg.instrument_bound_checkers:
......
...@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc): ...@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs) _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
def with_body(self, new_body):
"""Create a new PrimFunc with the same signature but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return PrimFunc(
self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
...@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope") ...@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass") TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0]; BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass; std::vector<std::pair<int, transform::Pass>> add_lower_pass;
CHECK_EQ(args.size() % 2, 1); CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) { for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair( add_lower_pass.push_back(std::make_pair(
args[i].operator int(), args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc())); args[i + 1].operator transform::Pass()));
} }
cfg->add_lower_pass = add_lower_pass; cfg->add_lower_pass = add_lower_pass;
}); });
......
...@@ -51,11 +51,13 @@ def test_fold_const(): ...@@ -51,11 +51,13 @@ def test_fold_const():
z = relay.add(y, relay.const(c_data)) z = relay.add(y, relay.const(c_data))
return relay.Function([x], z) return relay.Function([x], z)
def fail(x): def FailPass():
raise RuntimeError() def _transform(m, *args):
raise RuntimeError()
return tvm.transform.module_pass(_transform, opt_level=0)
# the fold constant should work on any context. # the fold constant should work on any context.
with tvm.target.build_config(add_lower_pass=[(0, fail)]): with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
with tvm.target.create("cuda"): with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant()) zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType()) zexpected = run_opt_pass(expected(), transform.InferType())
......
...@@ -182,7 +182,7 @@ def test_cuda_shuffle(): ...@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch[c].bind(xo, thrx) sch[c].bind(xo, thrx)
sch[c].vectorize(xi) sch[c].vectorize(xi)
def my_vectorize(stmt): def MyVectorize():
def vectorizer(op): def vectorizer(op):
if op.for_type == tvm.tir.For.Vectorized: if op.for_type == tvm.tir.For.Vectorized:
four = tvm.tir.const(4, 'int32') four = tvm.tir.const(4, 'int32')
...@@ -198,9 +198,13 @@ def test_cuda_shuffle(): ...@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b = tvm.tir.Shuffle(bs, ids) new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None return None
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
module = tvm.build(sch, [a, b, c], target='cuda') module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32') a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32') b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
......
...@@ -671,8 +671,7 @@ def test_llvm_shuffle(): ...@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c = te.compute((8, ), lambda x: a[x] + b[7-x]) c = te.compute((8, ), lambda x: a[x] + b[7-x])
sch = te.create_schedule(c.op) sch = te.create_schedule(c.op)
def my_vectorize(stmt): def my_vectorize():
def vectorizer(op): def vectorizer(op):
store = op.body store = op.body
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8) idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
...@@ -684,9 +683,13 @@ def test_llvm_shuffle(): ...@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value = new_a + new_b value = new_a + new_b
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c]) module = tvm.build(sch, [a, b, c])
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32')) a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
......
...@@ -19,10 +19,10 @@ import tvm ...@@ -19,10 +19,10 @@ import tvm
from tvm import te from tvm import te
def get_verify_pass(valid, **kwargs): def get_verify_pass(valid, **kwargs):
def verify_pass(stmt): def _fverify(f, *_):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs) valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
return stmt return f
return verify_pass return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
def test_shared_memory(): def test_shared_memory():
def check_shared_memory(dtype): def check_shared_memory(dtype):
......
...@@ -117,19 +117,20 @@ def vectorize8(op): ...@@ -117,19 +117,20 @@ def vectorize8(op):
return body return body
return None return None
def vectorize(stmt): @tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
global loops global loops
tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8) tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
if not loops: if not loops:
return stmt return f
# The last list argument indicates what kinds of nodes will be transformed. # The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8` # Thus, in this case only `For` nodes will call `vectorize8`
stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For']) return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
return stmt
##################################################################### #####################################################################
# Glue to Lowering # Glue to Lowering
......
...@@ -14,25 +14,22 @@ ...@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=unused-argument # pylint: disable=unused-argument, invalid-name
"""VTA specific builtin for runtime.""" """VTA specific builtin for runtime."""
import tvm import tvm
from . import ir_pass from . import transform
from .environment import get_env from .environment import get_env
def lift_coproc_scope(x): def EarlyRewrite():
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x
def early_rewrite(stmt):
"""Try to do storage rewrite in early pass.""" """Try to do storage rewrite in early pass."""
try: def _transform(mod, ctx):
return tvm.tir.ir_pass.StorageRewrite(stmt) try:
except tvm.error.TVMError: return tvm.tir.transform.StorageRewrite()(mod)
return stmt except tvm.error.TVMError:
return mod
return tvm.transform.module_pass(
_transform, opt_level=0, name="tir.vta.EarlyRewrite")
def build_config(debug_flag=0, **kwargs): def build_config(debug_flag=0, **kwargs):
...@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs): ...@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...) vta_module = tvm.build(s, ...)
""" """
env = get_env() env = get_env()
def add_debug(stmt):
@tvm.tir.transform.prim_func_pass(opt_level=0)
def add_debug(f, *_):
debug = tvm.tir.call_extern( debug = tvm.tir.call_extern(
"int32", "VTASetDebugMode", "int32", "VTASetDebugMode",
env.dev.command_handle, env.dev.command_handle,
debug_flag) debug_flag)
return tvm.tir.stmt_seq(debug, stmt) return f.with_body(tvm.tir.stmt_seq(debug, f.body))
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy), pass_list = [(0, transform.InjectConv2DTransposeSkip()),
(1, ir_pass.annotate_alu_coproc_scope), (1, transform.InjectDMAIntrin()),
(1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)), (1, transform.InjectSkipCopy()),
(1, lift_coproc_scope), (1, transform.AnnotateALUCoProcScope()),
(1, ir_pass.inject_coproc_sync), (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
(1, early_rewrite)] (1, transform.LiftAllocToScopeBegin()),
(1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
(1, transform.InjectCoProcSync()),
(1, EarlyRewrite())]
if debug_flag: if debug_flag:
pass_list.append((1, add_debug)) pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin)) pass_list.append((2, transform.InjectALUIntrin()))
pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo)) pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, transform.FoldUopLoop()))
pass_list.append((3, ir_pass.cpu_access_rewrite)) pass_list.append((3, transform.CPUAccessRewrite()))
return tvm.target.build_config(add_lower_pass=pass_list, **kwargs) return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Additional IR Pass for VTA""" """Additional Transformation Passes for VTA"""
# pylint: disable=len-as-condition, no-else-return # pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name
import tvm import tvm
from tvm import te from tvm import te
from topi import util from topi import util
...@@ -38,7 +38,7 @@ def _match_pragma(stmt, key): ...@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
(stmt.attr_key == "pragma_scope" and stmt.value.value == key)) (stmt.attr_key == "pragma_scope" and stmt.value.value == key))
def fold_uop_loop(stmt_in): def FoldUopLoop():
"""Detect and fold uop loop. """Detect and fold uop loop.
VTA supports uop programming model VTA supports uop programming model
...@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in): ...@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
This pass detects the loop structure This pass detects the loop structure
and extracts it into a uop loop AST. and extracts it into a uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Output statement. The pass
""" """
env = get_env()
def _fold_outermost_loop(body): def _fold_outermost_loop(body):
stmt = body stmt = body
if not isinstance(stmt, tvm.tir.For): if not isinstance(stmt, tvm.tir.For):
...@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in): ...@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
raise ValueError("Failed to fold the GEMM instructions..") raise ValueError("Failed to fold the GEMM instructions..")
def _do_fold(stmt): def _do_fold(stmt):
env = get_env()
if (stmt.attr_key == "coproc_uop_scope" and if (stmt.attr_key == "coproc_uop_scope" and
isinstance(stmt.value, tvm.tir.StringImm) and isinstance(stmt.value, tvm.tir.StringImm) and
stmt.value.value == env.dev.vta_push_uop.value): stmt.value.value == env.dev.vta_push_uop.value):
...@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in): ...@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
return tvm.tir.AttrStmt( return tvm.tir.AttrStmt(
stmt.node, stmt.attr_key, stmt.value, body) stmt.node, stmt.attr_key, stmt.value, body)
return None return None
out = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"]) def _ftransform(f, mod, ctx):
return out return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
def cpu_access_rewrite(stmt_in): def CPUAccessRewrite():
"""Detect CPU access to VTA buffer and get address correctly. """Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that does not VTA's buffer is an opaque handle that does not
...@@ -148,189 +146,182 @@ def cpu_access_rewrite(stmt_in): ...@@ -148,189 +146,182 @@ def cpu_access_rewrite(stmt_in):
This pass detects CPU access and rewrites it to use the pointer This pass detects CPU access and rewrites it to use the pointer
returned by VTABufferCPUPtr for CPU access. returned by VTABufferCPUPtr for CPU access.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
env = get_env() def _ftransform(f, mod, ctx):
rw_info = {} rw_info = {}
def _post_order(op): env = get_env()
if isinstance(op, tvm.tir.Allocate): def _post_order(op):
buffer_var = op.buffer_var if isinstance(op, tvm.tir.Allocate):
if not buffer_var in rw_info: buffer_var = op.buffer_var
return None if not buffer_var in rw_info:
new_var = rw_info[buffer_var] return None
let_stmt = tvm.tir.LetStmt( new_var = rw_info[buffer_var]
let_stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), op.body)
alloc = tvm.tir.Allocate(
buffer_var, op.dtype, op.extents,
op.condition, let_stmt)
del rw_info[buffer_var]
return alloc
if isinstance(op, tvm.tir.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Load(op.dtype, new_var, op.index)
if isinstance(op, tvm.tir.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Store(new_var, op.value, op.index)
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern( new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr", "handle", "VTABufferCPUPtr",
env.dev.command_handle, env.dev.command_handle,
buffer_var), op.body) buffer_var), stmt)
alloc = tvm.tir.Allocate( return f.with_body(stmt)
buffer_var, op.dtype, op.extents, return tvm.tir.transform.prim_func_pass(
op.condition, let_stmt) _ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite")
del rw_info[buffer_var]
return alloc
if isinstance(op, tvm.tir.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Load(op.dtype, new_var, op.index)
if isinstance(op, tvm.tir.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Store(new_var, op.value, op.index)
raise RuntimeError("not reached")
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), stmt)
return stmt
def lift_alloc_to_scope_begin(stmt_in): def LiftAllocToScopeBegin():
"""Lift allocate to beginning of the current scope. """Lift allocate to beginning of the current scope.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
lift_stmt = [[]] def _ftransform(f, mod, ctx):
def _merge_block(slist, body): lift_stmt = [[]]
for op in slist: def _merge_block(slist, body):
if op.body == body: for op in slist:
body = op if op.body == body:
elif isinstance(op, tvm.tir.Allocate): body = op
body = tvm.tir.Allocate( elif isinstance(op, tvm.tir.Allocate):
op.buffer_var, op.dtype, body = tvm.tir.Allocate(
op.extents, op.condition, body) op.buffer_var, op.dtype,
elif isinstance(op, tvm.tir.AttrStmt): op.extents, op.condition, body)
body = tvm.tir.AttrStmt( elif isinstance(op, tvm.tir.AttrStmt):
op.node, op.attr_key, op.value, body) body = tvm.tir.AttrStmt(
elif isinstance(op, tvm.tir.For): op.node, op.attr_key, op.value, body)
body = tvm.tir.For( elif isinstance(op, tvm.tir.For):
op.loop_var, op.min, op.extent, op.for_type, body = tvm.tir.For(
op.device_api, body) op.loop_var, op.min, op.extent, op.for_type,
else: op.device_api, body)
raise RuntimeError("unexpected op") else:
del slist[:] raise RuntimeError("unexpected op")
return body del slist[:]
return body
def _pre_order(op):
if isinstance(op, tvm.tir.For): def _pre_order(op):
lift_stmt.append([]) if isinstance(op, tvm.tir.For):
elif isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "virtual_thread":
lift_stmt.append([]) lift_stmt.append([])
elif isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "virtual_thread":
lift_stmt.append([])
def _post_order(op): def _post_order(op):
if isinstance(op, tvm.tir.Allocate): if isinstance(op, tvm.tir.Allocate):
lift_stmt[-1].append(op)
return op.body
if isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "storage_scope":
lift_stmt[-1].append(op) lift_stmt[-1].append(op)
return op.body return op.body
if op.attr_key == "virtual_thread": if isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "storage_scope":
lift_stmt[-1].append(op)
return op.body
if op.attr_key == "virtual_thread":
return _merge_block(lift_stmt.pop() + [op], op.body)
return op
if isinstance(op, tvm.tir.For):
return _merge_block(lift_stmt.pop() + [op], op.body) return _merge_block(lift_stmt.pop() + [op], op.body)
return op raise RuntimeError("not reached")
if isinstance(op, tvm.tir.For): stmt_in = f.body
return _merge_block(lift_stmt.pop() + [op], op.body) stmt = tvm.tir.ir_pass.IRTransform(
raise RuntimeError("not reached") stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
stmt = tvm.tir.ir_pass.IRTransform( assert len(lift_stmt) == 1
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) return f.with_body(_merge_block(lift_stmt[0], stmt))
assert len(lift_stmt) == 1
return _merge_block(lift_stmt[0], stmt)
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin")
def inject_skip_copy(stmt_in):
"""Pass to inject skip copy stmt, used for debug purpose.
Parameters def InjectSkipCopy():
---------- """Pass to inject skip copy stmt, used for debug purpose.
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
def _do_fold(stmt): def _do_fold(stmt):
if _match_pragma(stmt, "skip_dma_copy"): if _match_pragma(stmt, "skip_dma_copy"):
return tvm.tir.Evaluate(0) return tvm.tir.Evaluate(0)
return None return None
return tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
def inject_coproc_sync(stmt_in): return tvm.tir.transform.prim_func_pass(
"""Pass to inject skip copy stmt, used in debug. _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
Parameters
---------- def InjectCoProcSync():
stmt_in : Stmt """Pass to inject coproc sync.
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
success = [False] def _ftransform(f, *_):
def _do_fold(stmt): success = [False]
if _match_pragma(stmt, "coproc_sync"): def _do_fold(stmt):
success[0] = True if _match_pragma(stmt, "coproc_sync"):
sync = tvm.tir.Call( success[0] = True
"int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0) sync = tvm.tir.Call(
return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
if _match_pragma(stmt, "trim_loop"): return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
op = stmt.body if _match_pragma(stmt, "trim_loop"):
assert isinstance(op, tvm.tir.For) op = stmt.body
return tvm.tir.For( assert isinstance(op, tvm.tir.For)
op.loop_var, op.min, 2, op.for_type, return tvm.tir.For(
op.device_api, op.body) op.loop_var, op.min, 2, op.for_type,
return None op.device_api, op.body)
stmt = tvm.tir.ir_pass.IRTransform( return None
stmt_in, None, _do_fold, ["AttrStmt"]) return f.with_body(tvm.tir.ir_pass.IRTransform(
stmt = tvm.tir.ir_pass.CoProcSync(stmt) f.body, None, _do_fold, ["AttrStmt"]))
return stmt return tvm.transform.Sequential(
[tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
tvm.tir.transform.CoProcSync()],
def inject_dma_intrin(stmt_in): opt_level=0, name="tir.vta.InjectCoProcSync")
def InjectDMAIntrin():
"""Pass to inject DMA copy intrinsics. """Pass to inject DMA copy intrinsics.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
env = get_env()
idxd = tvm.tir.indexdiv idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod idxm = tvm.tir.indexmod
...@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in): ...@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
def _inject_copy(src, dst, pad_before, pad_after, pad_value): def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored... # FIXME: pad_value is ignored...
env = get_env()
_ = pad_value _ = pad_value
if dst.scope == "global": if dst.scope == "global":
# Store # Store
...@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in): ...@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
else: else:
raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy) return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
def _get_gemm_intrin_buffer(): def _get_gemm_intrin_buffer():
...@@ -619,377 +611,352 @@ def _get_gemm_intrin_buffer(): ...@@ -619,377 +611,352 @@ def _get_gemm_intrin_buffer():
return wgt_layout, inp_layout, out_layout return wgt_layout, inp_layout, out_layout
def inject_conv2d_transpose_skip(stmt_in): def InjectConv2DTransposeSkip():
"""Pass to skip 0-weights in conv2d transpose with stride > 1. """Pass to skip 0-weights in conv2d transpose with stride > 1.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
env = get_env() def _ftransform(func, mod, ctx):
dwgt, dinp, dout = _get_gemm_intrin_buffer() env = get_env()
dwgt, dinp, dout = _get_gemm_intrin_buffer()
calls = []
selects = [] calls = []
selects = []
def _find_basics(op):
if isinstance(op, tvm.tir.BufferLoad): def _find_basics(op):
calls.append(op) if isinstance(op, tvm.tir.BufferLoad):
elif isinstance(op, tvm.tir.Select): calls.append(op)
selects.append(op) elif isinstance(op, tvm.tir.Select):
selects.append(op)
def _do_fold(op):
if _match_pragma(op, "conv2d_transpose_gemm"): def _do_fold(op):
is_init = ".init" in str(op) if _match_pragma(op, "conv2d_transpose_gemm"):
tvm.tir.ir_pass.PostOrderVisit(op, _find_basics) is_init = ".init" in str(op)
tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
if is_init:
# create inner most block if is_init:
irb = tvm.tir.ir_builder.create() # create inner most block
dev = env.dev irb = tvm.tir.ir_builder.create()
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 1,
dout.access_ptr("rw", "int32"),
0, 0,
0, 0, 0))
inner = irb.get()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
body = op.body.body
while isinstance(body, tvm.tir.IfThenElse):
body = body.then_case
args = body.indices
res_buffer = body.buffer
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_buffer], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
else:
conv_call, data_call, kernel_call = calls[-3:]
pad_data_tensor = data_call.buffer
kernel_tensor = kernel_call.buffer
res_tensor = conv_call.buffer
if selects:
condition = selects[0].condition
else:
condition = tvm.tir.const(1, 'int')
# create inner most block
irb = tvm.tir.ir_builder.create()
with irb.if_scope(condition):
dev = env.dev dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 0, 0, 1,
dout.access_ptr("rw", "int32"), dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"), 0, 0,
dwgt.access_ptr("r", "int32"),
0, 0, 0)) 0, 0, 0))
inner = irb.get() inner = irb.get()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
args = conv_call.indices body = op.body.body
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], while isinstance(body, tvm.tir.IfThenElse):
1, 0, 1, 0, env.BLOCK_OUT) body = body.then_case
inner = tvm.tir.AttrStmt( args = body.indices
[dout, res_tensor], 'buffer_bind_scope', res_buffer = body.buffer
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
args = kernel_call.indices inner = tvm.tir.AttrStmt(
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], [dout, res_buffer], 'buffer_bind_scope',
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
inner = tvm.tir.AttrStmt( return inner
[dwgt, kernel_tensor], 'buffer_bind_scope', else:
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) conv_call, data_call, kernel_call = calls[-3:]
args = data_call.indices pad_data_tensor = data_call.buffer
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], kernel_tensor = kernel_call.buffer
1, 0, 1, 0, env.BLOCK_IN) res_tensor = conv_call.buffer
inner = tvm.tir.AttrStmt(
[dinp, pad_data_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
return None
ret = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return ret
def annotate_alu_coproc_scope(stmt_in): if selects:
condition = selects[0].condition
else:
condition = tvm.tir.const(1, 'int')
# create inner most block
irb = tvm.tir.ir_builder.create()
with irb.if_scope(condition):
dev = env.dev
irb.scope_attr(
dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0, 0))
inner = irb.get()
args = conv_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = kernel_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = data_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dinp, pad_data_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
return None
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
def AnnotateALUCoProcScope():
"""Pass to insert ALU instruction. """Pass to insert ALU instruction.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
env = get_env() def _ftransform(func, mod, ctx):
def _do_fold(stmt): env = get_env()
if _match_pragma(stmt, "alu"): def _do_fold(stmt):
irb = tvm.tir.ir_builder.create() if _match_pragma(stmt, "alu"):
irb.scope_attr(env.dev.vta_axis, "coproc_scope", irb = tvm.tir.ir_builder.create()
env.dev.get_task_qid(env.dev.QID_COMPUTE)) irb.scope_attr(env.dev.vta_axis, "coproc_scope",
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope", env.dev.get_task_qid(env.dev.QID_COMPUTE))
tvm.tir.StringImm("VTAPushALUOp")) irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
irb.emit(stmt) tvm.tir.StringImm("VTAPushALUOp"))
return irb.get() irb.emit(stmt)
if _match_pragma(stmt, "skip_alu"): return irb.get()
return tvm.tir.Evaluate(0) if _match_pragma(stmt, "skip_alu"):
return stmt return tvm.tir.Evaluate(0)
return stmt
stmt_out = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"]) return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, None, _do_fold, ["AttrStmt"]))
return stmt_out return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
def inject_alu_intrin(stmt_in):
def InjectALUIntrin():
"""Pass to inject ALU micro-ops. """Pass to inject ALU micro-ops.
Parameters
----------
stmt_in : Stmt
Input statement
Returns Returns
------- -------
stmt_out : Stmt fpass : tvm.transform.Pass
Transformed statement The pass
""" """
env = get_env() def _ftransform(func, mod, ctx):
idxm = tvm.tir.indexmod env = get_env()
analyzer = tvm.arith.Analyzer() idxm = tvm.tir.indexmod
analyzer = tvm.arith.Analyzer()
def _do_fold(stmt): def _do_fold(stmt):
def _equal(x, y): def _equal(x, y):
return tvm.ir.structural_equal(analyzer.simplify(x - y), 0) return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
def _flatten_loop(src_coeff, dst_coeff, extents): def _flatten_loop(src_coeff, dst_coeff, extents):
src_coeff = list(src_coeff) src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff) dst_coeff = list(dst_coeff)
extents = list(extents) extents = list(extents)
rev_src_coeff = [src_coeff.pop()] rev_src_coeff = [src_coeff.pop()]
rev_dst_coeff = [dst_coeff.pop()] rev_dst_coeff = [dst_coeff.pop()]
rev_extents = [] rev_extents = []
assert src_coeff assert src_coeff
vsrc = src_coeff.pop() vsrc = src_coeff.pop()
vdst = dst_coeff.pop() vdst = dst_coeff.pop()
vext = extents.pop() vext = extents.pop()
while src_coeff: while src_coeff:
next_src = src_coeff.pop() next_src = src_coeff.pop()
next_dst = dst_coeff.pop() next_dst = dst_coeff.pop()
next_ext = extents.pop() next_ext = extents.pop()
if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
vext = analyzer.simplify(vext * next_ext) vext = analyzer.simplify(vext * next_ext)
else: else:
rev_src_coeff.append(vsrc) rev_src_coeff.append(vsrc)
rev_dst_coeff.append(vdst) rev_dst_coeff.append(vdst)
rev_extents.append(vext) rev_extents.append(vext)
vsrc = next_src vsrc = next_src
vdst = next_dst vdst = next_dst
vext = next_ext vext = next_ext
rev_src_coeff.append(vsrc) rev_src_coeff.append(vsrc)
rev_dst_coeff.append(vdst) rev_dst_coeff.append(vdst)
rev_extents.append(vext) rev_extents.append(vext)
rev_src_coeff.reverse() rev_src_coeff.reverse()
rev_dst_coeff.reverse() rev_dst_coeff.reverse()
rev_extents.reverse() rev_extents.reverse()
return rev_src_coeff, rev_dst_coeff, rev_extents return rev_src_coeff, rev_dst_coeff, rev_extents
if _match_pragma(stmt, "alu"): if _match_pragma(stmt, "alu"):
# Get to the innermost loop body # Get to the innermost loop body
loop_body = stmt.body loop_body = stmt.body
nest_size = 0 nest_size = 0
while isinstance(loop_body, tvm.tir.For): while isinstance(loop_body, tvm.tir.For):
loop_body = loop_body.body loop_body = loop_body.body
nest_size += 1 nest_size += 1
# Get the src/dst arguments # Get the src/dst arguments
dst_var = loop_body.buffer_var dst_var = loop_body.buffer_var
dst_idx = loop_body.index dst_idx = loop_body.index
# Derive loop variables and extents # Derive loop variables and extents
tmp_body = stmt.body tmp_body = stmt.body
indices = [] indices = []
extents = [] extents = []
for _ in range(nest_size): for _ in range(nest_size):
indices.append(tmp_body.loop_var) indices.append(tmp_body.loop_var)
extents.append(tmp_body.extent) extents.append(tmp_body.extent)
tmp_body = tmp_body.body tmp_body = tmp_body.body
# Derive opcode # Derive opcode
if isinstance(loop_body.value, tvm.tir.Add): if isinstance(loop_body.value, tvm.tir.Add):
alu_opcode = env.dev.ALU_OPCODE_ADD alu_opcode = env.dev.ALU_OPCODE_ADD
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Sub): elif isinstance(loop_body.value, tvm.tir.Sub):
alu_opcode = env.dev.ALU_OPCODE_SUB alu_opcode = env.dev.ALU_OPCODE_SUB
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Mul): elif isinstance(loop_body.value, tvm.tir.Mul):
alu_opcode = env.dev.ALU_OPCODE_MUL alu_opcode = env.dev.ALU_OPCODE_MUL
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Min): elif isinstance(loop_body.value, tvm.tir.Min):
alu_opcode = env.dev.ALU_OPCODE_MIN alu_opcode = env.dev.ALU_OPCODE_MIN
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Max): elif isinstance(loop_body.value, tvm.tir.Max):
alu_opcode = env.dev.ALU_OPCODE_MAX alu_opcode = env.dev.ALU_OPCODE_MAX
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Call): elif isinstance(loop_body.value, tvm.tir.Call):
if loop_body.value.name == 'shift_left': if loop_body.value.name == 'shift_left':
alu_opcode = env.dev.ALU_OPCODE_SHR alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0] lhs = loop_body.value.args[0]
rhs = analyzer.simplify(-loop_body.value.args[1]) rhs = analyzer.simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right': elif loop_body.value.name == 'shift_right':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
else:
raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name))
elif isinstance(loop_body.value, tvm.tir.Load):
alu_opcode = env.dev.ALU_OPCODE_SHR alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0] lhs = loop_body.value
rhs = loop_body.value.args[1] rhs = tvm.tir.const(0, "int32")
else: else:
raise RuntimeError( raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name)) "Expression not recognized %s, %s, %s" % (
elif isinstance(loop_body.value, tvm.tir.Load): type(loop_body.value), str(loop_body.value), str(stmt)))
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value # Derive array index coefficients
rhs = tvm.tir.const(0, "int32") dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
else: # Check if lhs/rhs is immediate
raise RuntimeError( use_imm = False
"Expression not recognized %s, %s, %s" % ( imm_val = None
type(loop_body.value), str(loop_body.value), str(stmt))) if isinstance(rhs, tvm.tir.IntImm):
assert lhs.buffer_var.same_as(dst_var)
# Derive array index coefficients src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices) use_imm = True
# Check if lhs/rhs is immediate imm_val = rhs
use_imm = False if isinstance(lhs, tvm.tir.IntImm):
imm_val = None assert rhs.buffer_var.same_as(dst_var)
if isinstance(rhs, tvm.tir.IntImm): src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
assert lhs.buffer_var.same_as(dst_var) use_imm = True
src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) imm_val = lhs
use_imm = True if imm_val is None:
imm_val = rhs imm_val = 0
if isinstance(lhs, tvm.tir.IntImm): assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
assert rhs.buffer_var.same_as(dst_var) src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
use_imm = True # Determine which side has the same coefficients
imm_val = lhs lhs_equal = True
if imm_val is None: rhs_equal = True
imm_val = 0 for i, coef in enumerate(dst_coeff):
assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) lhs_equal = False
src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
# Determine which side has the same coefficients rhs_equal = False
lhs_equal = True # Make sure at least one of the source is identical to the
rhs_equal = True # destination (in-place computation)
for i, coef in enumerate(dst_coeff): assert lhs_equal or rhs_equal
if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]): # Assign the source coefficients
lhs_equal = False if lhs_equal:
if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]): src_coeff = src_rhs_coeff
rhs_equal = False else:
# Make sure at least one of the source is identical to the src_coeff = src_lhs_coeff
# destination (in-place computation)
assert lhs_equal or rhs_equal # Ensure that we have the proper tensor dimensions in the
# Assign the source coefficients # innermost loop (pattern match)
if lhs_equal: src_coeff = list(src_coeff)
src_coeff = src_rhs_coeff dst_coeff = list(dst_coeff)
extents = list(extents)
assert len(src_coeff) > 1
assert len(dst_coeff) > 1
assert len(extents) != 0
assert tvm.ir.structural_equal(
analyzer.simplify(
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(
analyzer.simplify(
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(src_coeff[-2], 1)
assert tvm.ir.structural_equal(dst_coeff[-2], 1)
if env.BATCH > 1:
assert len(src_coeff) > 2
assert len(dst_coeff) > 2
assert len(extents) > 1
assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
# Apply tensorization of the loop coefficients
src_offset = src_coeff[-1]
dst_offset = dst_coeff[-1]
if env.BATCH == 1:
src_coeff = src_coeff[:-2]
dst_coeff = dst_coeff[:-2]
extents = extents[:-1]
else: else:
src_coeff = src_lhs_coeff src_coeff = src_coeff[:-3]
dst_coeff = dst_coeff[:-3]
# Ensure that we have the proper tensor dimensions in the extents = extents[:-2]
# innermost loop (pattern match) src_coeff.append(src_offset)
src_coeff = list(src_coeff) dst_coeff.append(dst_offset)
dst_coeff = list(dst_coeff) src_coeff = [
extents = list(extents) analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
assert len(src_coeff) > 1 dst_coeff = [
assert len(dst_coeff) > 1 analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
assert len(extents) != 0
assert tvm.ir.structural_equal( # Flatten the outer loops
analyzer.simplify( if extents:
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
assert tvm.ir.structural_equal(
analyzer.simplify( # Insert ALU micro-ops
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) irb = tvm.tir.ir_builder.create()
assert tvm.ir.structural_equal(src_coeff[-2], 1) for idx, extent in enumerate(extents):
assert tvm.ir.structural_equal(dst_coeff[-2], 1) irb.emit(tvm.tir.call_extern(
if env.BATCH > 1: "int32", "VTAUopLoopBegin",
assert len(src_coeff) > 2 extent, dst_coeff[idx], src_coeff[idx], 0))
assert len(dst_coeff) > 2 use_imm = int(use_imm)
assert len(extents) > 1
assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
# Apply tensorization of the loop coefficients
src_offset = src_coeff[-1]
dst_offset = dst_coeff[-1]
if env.BATCH == 1:
src_coeff = src_coeff[:-2]
dst_coeff = dst_coeff[:-2]
extents = extents[:-1]
else:
src_coeff = src_coeff[:-3]
dst_coeff = dst_coeff[:-3]
extents = extents[:-2]
src_coeff.append(src_offset)
dst_coeff.append(dst_offset)
src_coeff = [
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
dst_coeff = [
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops
if extents:
src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
# Insert ALU micro-ops
irb = tvm.tir.ir_builder.create()
for idx, extent in enumerate(extents):
irb.emit(tvm.tir.call_extern(
"int32", "VTAUopLoopBegin",
extent, dst_coeff[idx], src_coeff[idx], 0))
use_imm = int(use_imm)
irb.emit(tvm.tir.call_extern(
"int32", "VTAUopPush",
1, 0,
dst_coeff[len(dst_coeff)-1],
src_coeff[len(src_coeff)-1],
0,
alu_opcode, use_imm, imm_val))
for extent in extents:
irb.emit(tvm.tir.call_extern( irb.emit(tvm.tir.call_extern(
"int32", "VTAUopLoopEnd")) "int32", "VTAUopPush",
return irb.get() 1, 0,
return stmt dst_coeff[len(dst_coeff)-1],
src_coeff[len(src_coeff)-1],
stmt_out = tvm.tir.ir_pass.IRTransform( 0,
stmt_in, None, _do_fold, ["AttrStmt"]) alu_opcode, use_imm, imm_val))
return stmt_out for extent in extents:
irb.emit(tvm.tir.call_extern(
"int32", "VTAUopLoopEnd"))
def debug_print(stmt): return irb.get()
"""A debug pass that print the stmt return stmt
Parameters return func.with_body(tvm.tir.ir_pass.IRTransform(
---------- func.body, None, _do_fold, ["AttrStmt"]))
stmt : Stmt
The input statement return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")
Returns
-------
stmt : Stmt
The
"""
# pylint: disable=superfluous-parens
print(stmt)
return stmt
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment