Commit d4a46898 by xqdan Committed by Tianqi Chen

Support dump ir for each pass (#693) (#791)

* Support dump ir for each pass(#693)

* expose DumpIR

* fix comments

* fix comments
parent 079e2307
...@@ -5,18 +5,96 @@ LoweredFunc and compiled Module. ...@@ -5,18 +5,96 @@ LoweredFunc and compiled Module.
""" """
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import warnings import warnings
import types
from . import api from . import api
from . import tensor from . import tensor
from . import schedule from . import schedule
from . import expr from . import expr
from . import ir_pass from . import ir_pass
from . import stmt as _stmt
from . import container from . import container
from . import module from . import module
from . import codegen from . import codegen
from . import ndarray from . import ndarray
from . import target as _target from . import target as _target
class DumpIR(object):
"""Dump IR for each pass.
With it, you can dump ir just like gcc/llvm.
How to use:
-----------
.. code-block:: python
with tvm.build_config(dump_pass_ir=True)
run()
"""
scope_level = 0
def __init__(self):
self._pass_id = 0
self._recover_list = []
def decorate(self, func):
''' decorate the pass function'''
def dump(*args, **kwargs):
'''dump function'''
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
return retv
pname = str(self._pass_id) + "_" + func.func_name + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv
f.write(str(out))
if isinstance(retv, container.Array):
for x in retv:
out = x.body if isinstance(x, container.LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
return dump
def decorate_irpass(self):
'''decorate ir_pass and ScheduleOps'''
self._old_sgpass = schedule.ScheduleOps
schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
vset = vars(ir_pass)
k = v = 0
def recover():
vset[k] = v
for k, v in vset.items():
self._recover_list.append(recover)
vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v
def decorate_custompass(self):
''' decorate add_lower_pass pass in BuildConfig'''
cfg = BuildConfig.current
self._old_custom_pass = cfg.add_lower_pass
custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
pass_list = [(x[0], self.decorate(x[1])) for x in custom_pass]
BuildConfig.current.add_lower_pass = pass_list
def enter(self):
'''only decorate outermost nest'''
if DumpIR.scope_level > 0:
return
self.decorate_irpass()
self.decorate_custompass()
self._pass_id = 0
DumpIR.scope_level += 1
def exit(self):
'''recover outermost nest'''
if DumpIR.scope_level > 1:
return
# recover decorated functions
for f in self._recover_list:
f()
schedule.ScheduleOps = self._old_sgpass
BuildConfig.current.add_lower_pass = self._old_custom_pass
DumpIR.scope_level -= 1
class BuildConfig(object): class BuildConfig(object):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
...@@ -37,10 +115,12 @@ class BuildConfig(object): ...@@ -37,10 +115,12 @@ class BuildConfig(object):
"data_alignment": -1, "data_alignment": -1,
"restricted_func": True, "restricted_func": True,
"double_buffer_split_loop": 1, "double_buffer_split_loop": 1,
"add_lower_pass": None "add_lower_pass": None,
"dump_pass_ir": False
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._old_scope = None self._old_scope = None
self._dump_ir = DumpIR()
for k, _ in kwargs.items(): for k, _ in kwargs.items():
if k not in BuildConfig.defaults: if k not in BuildConfig.defaults:
raise ValueError( raise ValueError(
...@@ -59,10 +139,14 @@ class BuildConfig(object): ...@@ -59,10 +139,14 @@ class BuildConfig(object):
attr.update(self._attr) attr.update(self._attr)
self._attr = attr self._attr = attr
BuildConfig.current = self BuildConfig.current = self
if self.dump_pass_ir is True:
self._dump_ir.enter()
return self return self
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
assert self._old_scope assert self._old_scope
if self.dump_pass_ir is True:
self._dump_ir.exit()
BuildConfig.current = self._old_scope BuildConfig.current = self._old_scope
...@@ -115,6 +199,8 @@ def build_config(**kwargs): ...@@ -115,6 +199,8 @@ def build_config(**kwargs):
phase contains an integer on which optimization pass we apply the pass. phase contains an integer on which optimization pass we apply the pass.
Additional lowering passes to be applied before make_api. Additional lowering passes to be applied before make_api.
dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False
Returns Returns
------- -------
config: BuildConfig config: BuildConfig
...@@ -247,7 +333,6 @@ def lower(sch, ...@@ -247,7 +333,6 @@ def lower(sch,
return stmt return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
def build(sch, def build(sch,
args=None, args=None,
target=None, target=None,
......
import tvm import tvm
import os
def test_unroll_loop(): def test_unroll_loop():
dtype = 'int64' dtype = 'int64'
...@@ -24,4 +25,20 @@ def test_unroll_loop(): ...@@ -24,4 +25,20 @@ def test_unroll_loop():
if __name__ == "__main__": if __name__ == "__main__":
test_unroll_loop() with tvm.build_config(dump_pass_ir=True):
test_unroll_loop()
def end_with(*suffix):
ends = suffix
def run(s):
f = map(s.endswith, ends)
if True in f: return s
return run
file_list = os.listdir('./')
cc_file = end_with('.cc')
cc_file = filter(cc_file, file_list)
assert len(cc_file) == 3
for i in cc_file:
os.remove(i)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment