Unverified Commit d3277874 by Tianqi Chen Committed by GitHub

[PYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)

parent 72f2aea2
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/support/with.h> #include <tvm/support/with.h>
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -225,8 +226,8 @@ class BuildConfigNode : public Object { ...@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */ /*! \brief Whether to partition const loop */
bool partition_const_loop = false; bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */ /*! \brief List of passes to be injected into the low-level pipeline. */
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass; std::vector<std::pair<int, transform::Pass>> add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */ /*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false; bool dump_pass_ir = false;
......
...@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs): ...@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel. """Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block. This pass will check memory usage and number of threads per block.
""" """
def verify_pass(stmt): def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(stmt, kwargs) valid = ir_pass.VerifyGPUCode(f.body, kwargs)
if not valid: if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel") raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt return f
return verify_pass return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
...@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds): ...@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func}) return tvm.IRModule({name: func})
def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is a temporary adapter before we fully
migrate to the new pass manager.
"""
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
stmt,
func.ret_type,
func.buffer_map,
func.attrs)
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
def lower(sch, def lower(sch,
args, args,
name="main", name="main",
...@@ -190,15 +171,15 @@ def lower(sch, ...@@ -190,15 +171,15 @@ def lower(sch,
else: else:
mod = sch mod = sch
pass_list = lower_phase0
# Phase 1 # Phase 1
pass_list = [ pass_list += [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
tvm.tir.transform.InjectPrefetch(), tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers), tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(), tvm.tir.transform.Simplify(),
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
] ]
pass_list += lower_phase1
# Phase 2 # Phase 2
if not simple_mode: if not simple_mode:
...@@ -214,8 +195,8 @@ def lower(sch, ...@@ -214,8 +195,8 @@ def lower(sch,
cfg.auto_unroll_max_depth, cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent, cfg.auto_unroll_max_extent,
cfg.unroll_explicit), cfg.unroll_explicit),
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
] ]
pass_list += lower_phase2
# Phase 3 # Phase 3
pass_list += [ pass_list += [
...@@ -225,7 +206,7 @@ def lower(sch, ...@@ -225,7 +206,7 @@ def lower(sch,
if not cfg.disable_select_rewriting: if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")] pass_list += lower_phase3
# Instrument BoundCheckers # Instrument BoundCheckers
if cfg.instrument_bound_checkers: if cfg.instrument_bound_checkers:
......
...@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc): ...@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs) _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
def with_body(self, new_body):
"""Create a new PrimFunc with the same set signatures but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return PrimFunc(
self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
...@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope") ...@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass") TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0]; BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass; std::vector<std::pair<int, transform::Pass>> add_lower_pass;
CHECK_EQ(args.size() % 2, 1); CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) { for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair( add_lower_pass.push_back(std::make_pair(
args[i].operator int(), args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc())); args[i + 1].operator transform::Pass()));
} }
cfg->add_lower_pass = add_lower_pass; cfg->add_lower_pass = add_lower_pass;
}); });
......
...@@ -51,11 +51,13 @@ def test_fold_const(): ...@@ -51,11 +51,13 @@ def test_fold_const():
z = relay.add(y, relay.const(c_data)) z = relay.add(y, relay.const(c_data))
return relay.Function([x], z) return relay.Function([x], z)
def fail(x): def FailPass():
raise RuntimeError() def _transform(m, *args):
raise RuntimeError()
return tvm.transform.module_pass(_transform, opt_level=0)
# the fold constant should work on any context. # the fold constant should work on any context.
with tvm.target.build_config(add_lower_pass=[(0, fail)]): with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
with tvm.target.create("cuda"): with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant()) zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType()) zexpected = run_opt_pass(expected(), transform.InferType())
......
...@@ -182,7 +182,7 @@ def test_cuda_shuffle(): ...@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch[c].bind(xo, thrx) sch[c].bind(xo, thrx)
sch[c].vectorize(xi) sch[c].vectorize(xi)
def my_vectorize(stmt): def MyVectorize():
def vectorizer(op): def vectorizer(op):
if op.for_type == tvm.tir.For.Vectorized: if op.for_type == tvm.tir.For.Vectorized:
four = tvm.tir.const(4, 'int32') four = tvm.tir.const(4, 'int32')
...@@ -198,9 +198,13 @@ def test_cuda_shuffle(): ...@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b = tvm.tir.Shuffle(bs, ids) new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None return None
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
module = tvm.build(sch, [a, b, c], target='cuda') module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32') a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32') b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
......
...@@ -671,8 +671,7 @@ def test_llvm_shuffle(): ...@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c = te.compute((8, ), lambda x: a[x] + b[7-x]) c = te.compute((8, ), lambda x: a[x] + b[7-x])
sch = te.create_schedule(c.op) sch = te.create_schedule(c.op)
def my_vectorize(stmt): def my_vectorize():
def vectorizer(op): def vectorizer(op):
store = op.body store = op.body
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8) idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
...@@ -684,9 +683,13 @@ def test_llvm_shuffle(): ...@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value = new_a + new_b value = new_a + new_b
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c]) module = tvm.build(sch, [a, b, c])
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32')) a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
......
...@@ -19,10 +19,10 @@ import tvm ...@@ -19,10 +19,10 @@ import tvm
from tvm import te from tvm import te
def get_verify_pass(valid, **kwargs): def get_verify_pass(valid, **kwargs):
def verify_pass(stmt): def _fverify(f, *_):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs) valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
return stmt return f
return verify_pass return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
def test_shared_memory(): def test_shared_memory():
def check_shared_memory(dtype): def check_shared_memory(dtype):
......
...@@ -117,19 +117,20 @@ def vectorize8(op): ...@@ -117,19 +117,20 @@ def vectorize8(op):
return body return body
return None return None
def vectorize(stmt): @tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
global loops global loops
tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8) tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
if not loops: if not loops:
return stmt return f
# The last list argument indicates what kinds of nodes will be transformed. # The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8` # Thus, in this case only `For` nodes will call `vectorize8`
stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For']) return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
return stmt
##################################################################### #####################################################################
# Glue to Lowering # Glue to Lowering
......
...@@ -14,25 +14,22 @@ ...@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=unused-argument # pylint: disable=unused-argument, invalid-name
"""VTA specific buildin for runtime.""" """VTA specific buildin for runtime."""
import tvm import tvm
from . import ir_pass from . import transform
from .environment import get_env from .environment import get_env
def lift_coproc_scope(x): def EarlyRewrite():
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x
def early_rewrite(stmt):
"""Try to do storage rewrite in early pass.""" """Try to do storage rewrite in early pass."""
try: def _transform(mod, ctx):
return tvm.tir.ir_pass.StorageRewrite(stmt) try:
except tvm.error.TVMError: return tvm.tir.transform.StorageRewrite()(mod)
return stmt except tvm.error.TVMError:
return mod
return tvm.transform.module_pass(
_transform, opt_level=0, name="tir.vta.EarlyRewrite")
def build_config(debug_flag=0, **kwargs): def build_config(debug_flag=0, **kwargs):
...@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs): ...@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...) vta_module = tvm.build(s, ...)
""" """
env = get_env() env = get_env()
def add_debug(stmt):
@tvm.tir.transform.prim_func_pass(opt_level=0)
def add_debug(f, *_):
debug = tvm.tir.call_extern( debug = tvm.tir.call_extern(
"int32", "VTASetDebugMode", "int32", "VTASetDebugMode",
env.dev.command_handle, env.dev.command_handle,
debug_flag) debug_flag)
return tvm.tir.stmt_seq(debug, stmt) return f.with_body(tvm.tir.stmt_seq(debug, f.body))
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy), pass_list = [(0, transform.InjectConv2DTransposeSkip()),
(1, ir_pass.annotate_alu_coproc_scope), (1, transform.InjectDMAIntrin()),
(1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)), (1, transform.InjectSkipCopy()),
(1, lift_coproc_scope), (1, transform.AnnotateALUCoProcScope()),
(1, ir_pass.inject_coproc_sync), (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
(1, early_rewrite)] (1, transform.LiftAllocToScopeBegin()),
(1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
(1, transform.InjectCoProcSync()),
(1, EarlyRewrite())]
if debug_flag: if debug_flag:
pass_list.append((1, add_debug)) pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin)) pass_list.append((2, transform.InjectALUIntrin()))
pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo)) pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, transform.FoldUopLoop()))
pass_list.append((3, ir_pass.cpu_access_rewrite)) pass_list.append((3, transform.CPUAccessRewrite()))
return tvm.target.build_config(add_lower_pass=pass_list, **kwargs) return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment