Commit 0138997f by Tianqi Chen Committed by GitHub

[BUILD] Allow inject custom pass via phase (#408)

parent f73c461f
...@@ -62,6 +62,7 @@ class BuildConfig(object): ...@@ -62,6 +62,7 @@ class BuildConfig(object):
assert self._old_scope assert self._old_scope
BuildConfig.current = self._old_scope BuildConfig.current = self._old_scope
BuildConfig.current = BuildConfig() BuildConfig.current = BuildConfig()
def build_config(**kwargs): def build_config(**kwargs):
...@@ -102,7 +103,8 @@ def build_config(**kwargs): ...@@ -102,7 +103,8 @@ def build_config(**kwargs):
Whether split the loop containing double buffer so Whether split the loop containing double buffer so
that the buffer fetching won't contain condition. that the buffer fetching won't contain condition.
add_lower_pass: list of function(Stmt->Stmt), default=None add_lower_pass: list of tuiple (phase, function(Stmt->Stmt)), default=None
phase contains an integer on which optimization pass we apply the pass.
Additional lowering passes to be applied before make_api. Additional lowering passes to be applied before make_api.
Returns Returns
...@@ -193,11 +195,19 @@ def lower(sch, ...@@ -193,11 +195,19 @@ def lower(sch,
""" """
binds, arg_list = get_binds(args, binds) binds, arg_list = get_binds(args, binds)
cfg = BuildConfig.current cfg = BuildConfig.current
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] > 1]
# normalize schedule first # normalize schedule first
sch = sch.normalize() sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt) stmt = ir_pass.InjectPrefetch(stmt)
for f in lower_phase0:
stmt = f(stmt)
# Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64) stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt) stmt = ir_pass.CanonicalSimplify(stmt)
if not simple_mode: if not simple_mode:
...@@ -211,13 +221,15 @@ def lower(sch, ...@@ -211,13 +221,15 @@ def lower(sch,
cfg.auto_unroll_max_step, cfg.auto_unroll_max_step,
cfg.auto_unroll_min_depth, cfg.auto_unroll_min_depth,
cfg.unroll_explicit) cfg.unroll_explicit)
if cfg.add_lower_pass: for f in lower_phase1:
for f in cfg.add_lower_pass: stmt = f(stmt)
stmt = f(stmt) # Phase 2
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt) stmt = ir_pass.RemoveNoOp(stmt)
stmt = ir_pass.RewriteUnsafeSelect(stmt) stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase2:
stmt = f(stmt)
if simple_mode: if simple_mode:
return stmt return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment