Unverified Commit d3277874 by Tianqi Chen Committed by GitHub

[PYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)

parent 72f2aea2
......@@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string>
#include <vector>
......@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
/*! \brief List of passes to be injected into the low-level pipeline. */
std::vector<std::pair<int, transform::Pass>> add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
......@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
def verify_pass(stmt):
valid = ir_pass.VerifyGPUCode(stmt, kwargs)
def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(f.body, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt
return verify_pass
return f
return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
......@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func})
def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is a temporary adapter before we fully
migrate to the new pass manager.
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
def lower(sch,
......@@ -190,15 +171,15 @@ def lower(sch,
mod = sch
pass_list = lower_phase0
# Phase 1
pass_list = [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
pass_list += [
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
pass_list += lower_phase1
# Phase 2
if not simple_mode:
......@@ -214,8 +195,8 @@ def lower(sch,
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
pass_list += lower_phase2
# Phase 3
pass_list += [
......@@ -225,7 +206,7 @@ def lower(sch,
if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
pass_list += lower_phase3
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
......@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
def with_body(self, new_body):
"""Create a new PrimFunc with the same signature but a new body.
new_body : Stmt
The new body.
new_func : PrimFunc
The created new function.
return PrimFunc(
self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
......@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
std::vector<std::pair<int, transform::Pass>> add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc()));
args[i + 1].operator transform::Pass()));
cfg->add_lower_pass = add_lower_pass;
......@@ -51,11 +51,13 @@ def test_fold_const():
z = relay.add(y, relay.const(c_data))
return relay.Function([x], z)
def fail(x):
raise RuntimeError()
def FailPass():
def _transform(m, *args):
raise RuntimeError()
return tvm.transform.module_pass(_transform, opt_level=0)
# the fold constant should work on any context.
with tvm.target.build_config(add_lower_pass=[(0, fail)]):
with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
......@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch[c].bind(xo, thrx)
def my_vectorize(stmt):
def MyVectorize():
def vectorizer(op):
if op.for_type == tvm.tir.For.Vectorized:
four = tvm.tir.const(4, 'int32')
......@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
......@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c = te.compute((8, ), lambda x: a[x] + b[7-x])
sch = te.create_schedule(c.op)
def my_vectorize(stmt):
def my_vectorize():
def vectorizer(op):
store = op.body
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
......@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value = new_a + new_b
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c])
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
......@@ -19,10 +19,10 @@ import tvm
from tvm import te
def get_verify_pass(valid, **kwargs):
def verify_pass(stmt):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs)
return stmt
return verify_pass
def _fverify(f, *_):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
return f
return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
def test_shared_memory():
def check_shared_memory(dtype):
......@@ -117,19 +117,20 @@ def vectorize8(op):
return body
return None
def vectorize(stmt):
def vectorize(f, mod, ctx):
global loops
tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8)
tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
if not loops:
return stmt
return sf
# The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For'])
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
return stmt
# Glue to Lowering
......@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
# pylint: disable=unused-argument, invalid-name
"""VTA specific builtin for runtime."""
import tvm
from . import ir_pass
from . import transform
from .environment import get_env
def lift_coproc_scope(x):
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x
def early_rewrite(stmt):
def EarlyRewrite():
"""Try to do storage rewrite in early pass."""
return tvm.tir.ir_pass.StorageRewrite(stmt)
except tvm.error.TVMError:
return stmt
def _transform(mod, ctx):
return tvm.tir.transform.StorageRewrite()(mod)
except tvm.error.TVMError:
return mod
return tvm.transform.module_pass(
_transform, opt_level=0, name="tir.vta.EarlyRewrite")
def build_config(debug_flag=0, **kwargs):
......@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
env = get_env()
def add_debug(stmt):
def add_debug(f, *_):
debug = tvm.tir.call_extern(
"int32", "VTASetDebugMode",
return tvm.tir.stmt_seq(debug, stmt)
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
(1, ir_pass.annotate_alu_coproc_scope),
(1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
(1, lift_coproc_scope),
(1, ir_pass.inject_coproc_sync),
(1, early_rewrite)]
return f.with_body(tvm.tir.stmt_seq(debug, f.body))
pass_list = [(0, transform.InjectConv2DTransposeSkip()),
(1, transform.InjectDMAIntrin()),
(1, transform.InjectSkipCopy()),
(1, transform.AnnotateALUCoProcScope()),
(1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
(1, transform.LiftAllocToScopeBegin()),
(1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
(1, transform.InjectCoProcSync()),
(1, EarlyRewrite())]
if debug_flag:
pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo))
pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite))
pass_list.append((2, transform.InjectALUIntrin()))
pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
pass_list.append((3, transform.FoldUopLoop()))
pass_list.append((3, transform.CPUAccessRewrite()))
return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
......@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Additional IR Pass for VTA"""
# pylint: disable=len-as-condition, no-else-return
"""Additional Transformation Passes for VTA."""
# pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name
import tvm
from tvm import te
from topi import util
......@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
(stmt.attr_key == "pragma_scope" and stmt.value.value == key))
def fold_uop_loop(stmt_in):
def FoldUopLoop():
"""Detect and fold uop loop.
VTA supports the uop programming model
......@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
This pass detects the loop structure
and extract that into uop loop AST.
stmt_in : Stmt
Input statement
stmt_out : Stmt
Output statement.
fpass : tvm.transform.Pass
The pass
env = get_env()
def _fold_outermost_loop(body):
stmt = body
if not isinstance(stmt, tvm.tir.For):
......@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
raise ValueError("Failed to fold the GEMM instructions..")
def _do_fold(stmt):
env = get_env()
if (stmt.attr_key == "coproc_uop_scope" and
isinstance(stmt.value, tvm.tir.StringImm) and
stmt.value.value == env.dev.vta_push_uop.value):
......@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
return tvm.tir.AttrStmt(
stmt.node, stmt.attr_key, stmt.value, body)
return None
out = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return out
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
def cpu_access_rewrite(stmt_in):
def CPUAccessRewrite():
"""Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that does not
......@@ -148,189 +146,182 @@ def cpu_access_rewrite(stmt_in):
This pass detects CPU access and rewrites it to use the pointer
returned by VTABufferCPUPtr for CPU access.
stmt_in : Stmt
Input statement
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
env = get_env()
rw_info = {}
def _post_order(op):
if isinstance(op, tvm.tir.Allocate):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
return None
new_var = rw_info[buffer_var]
let_stmt = tvm.tir.LetStmt(
def _ftransform(f, mod, ctx):
rw_info = {}
env = get_env()
def _post_order(op):
if isinstance(op, tvm.tir.Allocate):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
return None
new_var = rw_info[buffer_var]
let_stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr",
buffer_var), op.body)
alloc = tvm.tir.Allocate(
buffer_var, op.dtype, op.extents,
op.condition, let_stmt)
del rw_info[buffer_var]
return alloc
if isinstance(op, tvm.tir.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Load(op.dtype, new_var, op.index)
if isinstance(op, tvm.tir.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Store(new_var, op.value, op.index)
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr",
buffer_var), op.body)
alloc = tvm.tir.Allocate(
buffer_var, op.dtype, op.extents,
op.condition, let_stmt)
del rw_info[buffer_var]
return alloc
if isinstance(op, tvm.tir.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Load(op.dtype, new_var, op.index)
if isinstance(op, tvm.tir.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = te.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.tir.Store(new_var, op.value, op.index)
raise RuntimeError("not reached")
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt(
new_var, tvm.tir.call_extern(
"handle", "VTABufferCPUPtr",
buffer_var), stmt)
return stmt
buffer_var), stmt)
return f.with_body(stmt)
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite")
def lift_alloc_to_scope_begin(stmt_in):
def LiftAllocToScopeBegin():
"""Lift allocate to beginning of the current scope.
stmt_in : Stmt
Input statement
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
lift_stmt = [[]]
def _merge_block(slist, body):
for op in slist:
if op.body == body:
body = op
elif isinstance(op, tvm.tir.Allocate):
body = tvm.tir.Allocate(
op.buffer_var, op.dtype,
op.extents, op.condition, body)
elif isinstance(op, tvm.tir.AttrStmt):
body = tvm.tir.AttrStmt(
op.node, op.attr_key, op.value, body)
elif isinstance(op, tvm.tir.For):
body = tvm.tir.For(
op.loop_var, op.min, op.extent, op.for_type,
op.device_api, body)
raise RuntimeError("unexpected op")
del slist[:]
return body
def _pre_order(op):
if isinstance(op, tvm.tir.For):
elif isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "virtual_thread":
def _ftransform(f, mod, ctx):
lift_stmt = [[]]
def _merge_block(slist, body):
for op in slist:
if op.body == body:
body = op
elif isinstance(op, tvm.tir.Allocate):
body = tvm.tir.Allocate(
op.buffer_var, op.dtype,
op.extents, op.condition, body)
elif isinstance(op, tvm.tir.AttrStmt):
body = tvm.tir.AttrStmt(
op.node, op.attr_key, op.value, body)
elif isinstance(op, tvm.tir.For):
body = tvm.tir.For(
op.loop_var, op.min, op.extent, op.for_type,
op.device_api, body)
raise RuntimeError("unexpected op")
del slist[:]
return body
def _pre_order(op):
if isinstance(op, tvm.tir.For):
elif isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "virtual_thread":
def _post_order(op):
if isinstance(op, tvm.tir.Allocate):
return op.body
if isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "storage_scope":
def _post_order(op):
if isinstance(op, tvm.tir.Allocate):
return op.body
if op.attr_key == "virtual_thread":
if isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "storage_scope":
return op.body
if op.attr_key == "virtual_thread":
return _merge_block(lift_stmt.pop() + [op], op.body)
return op
if isinstance(op, tvm.tir.For):
return _merge_block(lift_stmt.pop() + [op], op.body)
return op
if isinstance(op, tvm.tir.For):
return _merge_block(lift_stmt.pop() + [op], op.body)
raise RuntimeError("not reached")
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1
return _merge_block(lift_stmt[0], stmt)
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1
return f.with_body(_merge_block(lift_stmt[0], stmt))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin")
def inject_skip_copy(stmt_in):
"""Pass to inject skip copy stmt, used for debug purpose.
stmt_in : Stmt
Input statement
def InjectSkipCopy():
"""Pass to inject skip copy stmt, used for debug purpose.
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
def _do_fold(stmt):
if _match_pragma(stmt, "skip_dma_copy"):
return tvm.tir.Evaluate(0)
return None
return tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, _do_fold, None, ["AttrStmt"]))
def inject_coproc_sync(stmt_in):
"""Pass to inject skip copy stmt, used in debug.
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
stmt_in : Stmt
Input statement
def InjectCoProcSync():
"""Pass to inject coproc sync.
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
success = [False]
def _do_fold(stmt):
if _match_pragma(stmt, "coproc_sync"):
success[0] = True
sync = tvm.tir.Call(
"int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
if _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.tir.For)
return tvm.tir.For(
op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body)
return None
stmt = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
stmt = tvm.tir.ir_pass.CoProcSync(stmt)
return stmt
def inject_dma_intrin(stmt_in):
def _ftransform(f, *_):
success = [False]
def _do_fold(stmt):
if _match_pragma(stmt, "coproc_sync"):
success[0] = True
sync = tvm.tir.Call(
"int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
if _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.tir.For)
return tvm.tir.For(
op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body)
return None
return f.with_body(tvm.tir.ir_pass.IRTransform(
f.body, None, _do_fold, ["AttrStmt"]))
return tvm.transform.Sequential(
[tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
opt_level=0, name="tir.vta.InjectCoProcSync")
def InjectDMAIntrin():
"""Pass to inject DMA copy intrinsics.
stmt_in : Stmt
Input statement
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
env = get_env()
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod
......@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored...
env = get_env()
_ = pad_value
if dst.scope == "global":
# Store
......@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy)
return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
def _get_gemm_intrin_buffer():
......@@ -619,377 +611,352 @@ def _get_gemm_intrin_buffer():
return wgt_layout, inp_layout, out_layout
def inject_conv2d_transpose_skip(stmt_in):
def InjectConv2DTransposeSkip():
"""Pass to skip 0-weights in conv2d transpose with stride > 1.
stmt_in : Stmt
Input statement
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
env = get_env()
dwgt, dinp, dout = _get_gemm_intrin_buffer()
calls = []
selects = []
def _find_basics(op):
if isinstance(op, tvm.tir.BufferLoad):
elif isinstance(op, tvm.tir.Select):
def _do_fold(op):
if _match_pragma(op, "conv2d_transpose_gemm"):
is_init = ".init" in str(op)
tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
if is_init:
# create inner most block
irb = tvm.tir.ir_builder.create()
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 1,
dout.access_ptr("rw", "int32"),
0, 0,
0, 0, 0))
inner = irb.get()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
body = op.body.body
while isinstance(body, tvm.tir.IfThenElse):
body = body.then_case
args = body.indices
res_buffer = body.buffer
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_buffer], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
conv_call, data_call, kernel_call = calls[-3:]
pad_data_tensor = data_call.buffer
kernel_tensor = kernel_call.buffer
res_tensor = conv_call.buffer
if selects:
condition = selects[0].condition
condition = tvm.tir.const(1, 'int')
# create inner most block
irb = tvm.tir.ir_builder.create()
with irb.if_scope(condition):
def _ftransform(func, mod, ctx):
env = get_env()
dwgt, dinp, dout = _get_gemm_intrin_buffer()
calls = []
selects = []
def _find_basics(op):
if isinstance(op, tvm.tir.BufferLoad):
elif isinstance(op, tvm.tir.Select):
def _do_fold(op):
if _match_pragma(op, "conv2d_transpose_gemm"):
is_init = ".init" in str(op)
tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
if is_init:
# create inner most block
irb = tvm.tir.ir_builder.create()
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 0,
0, 1,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0,
0, 0, 0))
inner = irb.get()
args = conv_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = kernel_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = data_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dinp, pad_data_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
return None
ret = tvm.tir.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return ret
inner = irb.get()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
body = op.body.body
while isinstance(body, tvm.tir.IfThenElse):
body = body.then_case
args = body.indices
res_buffer = body.buffer
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_buffer], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
conv_call, data_call, kernel_call = calls[-3:]
pad_data_tensor = data_call.buffer
kernel_tensor = kernel_call.buffer
res_tensor = conv_call.buffer
def annotate_alu_coproc_scope(stmt_in):
if selects:
condition = selects[0].condition
condition = tvm.tir.const(1, 'int')
# create inner most block
irb = tvm.tir.ir_builder.create()
with irb.if_scope(condition):
dev = env.dev
dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0, 0))
inner = irb.get()
args = conv_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = kernel_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = data_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dinp, pad_data_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
return None
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
def AnnotateALUCoProcScope():
"""Pass to insert ALU instruction.
stmt_in : Stmt
Input statement
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
env = get_env()
def _do_fold(stmt):
if _match_pragma(stmt, "alu"):
irb = tvm.tir.ir_builder.create()
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
return irb.get()
if _match_pragma(stmt, "skip_alu"):
return tvm.tir.Evaluate(0)
return stmt
stmt_out = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out
def inject_alu_intrin(stmt_in):
def _ftransform(func, mod, ctx):
env = get_env()
def _do_fold(stmt):
if _match_pragma(stmt, "alu"):
irb = tvm.tir.ir_builder.create()
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
return irb.get()
if _match_pragma(stmt, "skip_alu"):
return tvm.tir.Evaluate(0)
return stmt
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, None, _do_fold, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
def InjectALUIntrin():
"""Pass to inject ALU micro-ops.
stmt_in : Stmt
Input statement
stmt_out : Stmt
Transformed statement
fpass : tvm.transform.Pass
The pass
env = get_env()
idxm = tvm.tir.indexmod
analyzer = tvm.arith.Analyzer()
def _ftransform(func, mod, ctx):
env = get_env()
idxm = tvm.tir.indexmod
analyzer = tvm.arith.Analyzer()
def _do_fold(stmt):
def _equal(x, y):
return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
def _flatten_loop(src_coeff, dst_coeff, extents):
src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff)
extents = list(extents)
rev_src_coeff = [src_coeff.pop()]
rev_dst_coeff = [dst_coeff.pop()]
rev_extents = []
assert src_coeff
vsrc = src_coeff.pop()
vdst = dst_coeff.pop()
vext = extents.pop()
while src_coeff:
next_src = src_coeff.pop()
next_dst = dst_coeff.pop()
next_ext = extents.pop()
if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
vext = analyzer.simplify(vext * next_ext)
vsrc = next_src
vdst = next_dst
vext = next_ext
return rev_src_coeff, rev_dst_coeff, rev_extents
if _match_pragma(stmt, "alu"):
# Get to the innermost loop body
loop_body = stmt.body
nest_size = 0
while isinstance(loop_body, tvm.tir.For):
loop_body = loop_body.body
nest_size += 1
# Get the src/dst arguments
dst_var = loop_body.buffer_var
dst_idx = loop_body.index
# Derive loop variables and extents
tmp_body = stmt.body
indices = []
extents = []
for _ in range(nest_size):
tmp_body = tmp_body.body
# Derive opcode
if isinstance(loop_body.value, tvm.tir.Add):
alu_opcode = env.dev.ALU_OPCODE_ADD
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Sub):
alu_opcode = env.dev.ALU_OPCODE_SUB
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Mul):
alu_opcode = env.dev.ALU_OPCODE_MUL
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Min):
alu_opcode = env.dev.ALU_OPCODE_MIN
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Max):
alu_opcode = env.dev.ALU_OPCODE_MAX
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Call):
if loop_body.value.name == 'shift_left':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = analyzer.simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right':
def _do_fold(stmt):
def _equal(x, y):
return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
def _flatten_loop(src_coeff, dst_coeff, extents):
src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff)
extents = list(extents)
rev_src_coeff = [src_coeff.pop()]
rev_dst_coeff = [dst_coeff.pop()]
rev_extents = []
assert src_coeff
vsrc = src_coeff.pop()
vdst = dst_coeff.pop()
vext = extents.pop()
while src_coeff:
next_src = src_coeff.pop()
next_dst = dst_coeff.pop()
next_ext = extents.pop()
if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
vext = analyzer.simplify(vext * next_ext)
vsrc = next_src
vdst = next_dst
vext = next_ext
return rev_src_coeff, rev_dst_coeff, rev_extents
if _match_pragma(stmt, "alu"):
# Get to the innermost loop body
loop_body = stmt.body
nest_size = 0
while isinstance(loop_body, tvm.tir.For):
loop_body = loop_body.body
nest_size += 1
# Get the src/dst arguments
dst_var = loop_body.buffer_var
dst_idx = loop_body.index
# Derive loop variables and extents
tmp_body = stmt.body
indices = []
extents = []
for _ in range(nest_size):
tmp_body = tmp_body.body
# Derive opcode
if isinstance(loop_body.value, tvm.tir.Add):
alu_opcode = env.dev.ALU_OPCODE_ADD
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Sub):
alu_opcode = env.dev.ALU_OPCODE_SUB
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Mul):
alu_opcode = env.dev.ALU_OPCODE_MUL
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Min):
alu_opcode = env.dev.ALU_OPCODE_MIN
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Max):
alu_opcode = env.dev.ALU_OPCODE_MAX
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Call):
if loop_body.value.name == 'shift_left':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = analyzer.simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name))
elif isinstance(loop_body.value, tvm.tir.Load):
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
lhs = loop_body.value
rhs = tvm.tir.const(0, "int32")
raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name))
elif isinstance(loop_body.value, tvm.tir.Load):
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value
rhs = tvm.tir.const(0, "int32")
raise RuntimeError(
"Expression not recognized %s, %s, %s" % (
type(loop_body.value), str(loop_body.value), str(stmt)))
# Derive array index coefficients
dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
# Check if lhs/rhs is immediate
use_imm = False
imm_val = None
if isinstance(rhs, tvm.tir.IntImm):
assert lhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
use_imm = True
imm_val = rhs
if isinstance(lhs, tvm.tir.IntImm):
assert rhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
use_imm = True
imm_val = lhs
if imm_val is None:
imm_val = 0
assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
# Determine which side has the same coefficients
lhs_equal = True
rhs_equal = True
for i, coef in enumerate(dst_coeff):
if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
lhs_equal = False
if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
rhs_equal = False
# Make sure at least one of the source is identical to the
# destination (in-place computation)
assert lhs_equal or rhs_equal
# Assign the source coefficients
if lhs_equal:
src_coeff = src_rhs_coeff
"Expression not recognized %s, %s, %s" % (
type(loop_body.value), str(loop_body.value), str(stmt)))
# Derive array index coefficients
dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
# Check if lhs/rhs is immediate
use_imm = False
imm_val = None
if isinstance(rhs, tvm.tir.IntImm):
assert lhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
use_imm = True
imm_val = rhs
if isinstance(lhs, tvm.tir.IntImm):
assert rhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
use_imm = True
imm_val = lhs
if imm_val is None:
imm_val = 0
assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
# Determine which side has the same coefficients
lhs_equal = True
rhs_equal = True
for i, coef in enumerate(dst_coeff):
if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
lhs_equal = False
if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
rhs_equal = False
# Make sure at least one of the source is identical to the
# destination (in-place computation)
assert lhs_equal or rhs_equal
# Assign the source coefficients
if lhs_equal:
src_coeff = src_rhs_coeff
src_coeff = src_lhs_coeff
# Ensure that we have the proper tensor dimensions in the
# innermost loop (pattern match)
src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff)
extents = list(extents)
assert len(src_coeff) > 1
assert len(dst_coeff) > 1
assert len(extents) != 0
assert tvm.ir.structural_equal(
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(src_coeff[-2], 1)
assert tvm.ir.structural_equal(dst_coeff[-2], 1)
if env.BATCH > 1:
assert len(src_coeff) > 2
assert len(dst_coeff) > 2
assert len(extents) > 1
assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
# Apply tensorization of the loop coefficients
src_offset = src_coeff[-1]
dst_offset = dst_coeff[-1]
if env.BATCH == 1:
src_coeff = src_coeff[:-2]
dst_coeff = dst_coeff[:-2]
extents = extents[:-1]
src_coeff = src_lhs_coeff
# Ensure that we have the proper tensor dimensions in the
# innermost loop (pattern match)
src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff)
extents = list(extents)
assert len(src_coeff) > 1
assert len(dst_coeff) > 1
assert len(extents) != 0
assert tvm.ir.structural_equal(
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir.structural_equal(src_coeff[-2], 1)
assert tvm.ir.structural_equal(dst_coeff[-2], 1)
if env.BATCH > 1:
assert len(src_coeff) > 2
assert len(dst_coeff) > 2
assert len(extents) > 1
assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
# Apply tensorization of the loop coefficients
src_offset = src_coeff[-1]
dst_offset = dst_coeff[-1]
if env.BATCH == 1:
src_coeff = src_coeff[:-2]
dst_coeff = dst_coeff[:-2]
extents = extents[:-1]
src_coeff = src_coeff[:-3]
dst_coeff = dst_coeff[:-3]
extents = extents[:-2]
src_coeff = [
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
dst_coeff = [
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops
if extents:
src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
# Insert ALU micro-ops
irb = tvm.tir.ir_builder.create()
for idx, extent in enumerate(extents):
"int32", "VTAUopLoopBegin",
extent, dst_coeff[idx], src_coeff[idx], 0))
use_imm = int(use_imm)
"int32", "VTAUopPush",
1, 0,
alu_opcode, use_imm, imm_val))
for extent in extents:
src_coeff = src_coeff[:-3]
dst_coeff = dst_coeff[:-3]
extents = extents[:-2]
src_coeff = [
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
dst_coeff = [
analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops
if extents:
src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
# Insert ALU micro-ops
irb = tvm.tir.ir_builder.create()
for idx, extent in enumerate(extents):
"int32", "VTAUopLoopBegin",
extent, dst_coeff[idx], src_coeff[idx], 0))
use_imm = int(use_imm)
"int32", "VTAUopLoopEnd"))
return irb.get()
return stmt
stmt_out = tvm.tir.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out
def debug_print(stmt):
"""A debug pass that print the stmt
stmt : Stmt
The input statement
stmt : Stmt
# pylint: disable=superfluous-parens
return stmt
"int32", "VTAUopPush",
1, 0,
alu_opcode, use_imm, imm_val))
for extent in extents:
"int32", "VTAUopLoopEnd"))
return irb.get()
return stmt
return func.with_body(tvm.tir.ir_pass.IRTransform(
func.body, None, _do_fold, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment