Commit d4a46898 by xqdan Committed by Tianqi Chen

Support dump ir for each pass (#693) (#791)

* Support dump ir for each pass(#693)

* expose DumpIR

* fix comments

* fix comments
parent 079e2307
......@@ -5,18 +5,96 @@ LoweredFunc and compiled Module.
"""
from __future__ import absolute_import as _abs
import warnings
import types
from . import api
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target
class DumpIR(object):
"""Dump IR for each pass.
With it, you can dump ir just like gcc/llvm.
How to use:
-----------
.. code-block:: python
with tvm.build_config(dump_pass_ir=True)
run()
"""
scope_level = 0
def __init__(self):
self._pass_id = 0
self._recover_list = []
def decorate(self, func):
''' decorate the pass function'''
def dump(*args, **kwargs):
'''dump function'''
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
return retv
pname = str(self._pass_id) + "_" + func.func_name + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv
f.write(str(out))
if isinstance(retv, container.Array):
for x in retv:
out = x.body if isinstance(x, container.LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
return dump
def decorate_irpass(self):
'''decorate ir_pass and ScheduleOps'''
self._old_sgpass = schedule.ScheduleOps
schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
vset = vars(ir_pass)
k = v = 0
def recover():
vset[k] = v
for k, v in vset.items():
self._recover_list.append(recover)
vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v
def decorate_custompass(self):
''' decorate add_lower_pass pass in BuildConfig'''
cfg = BuildConfig.current
self._old_custom_pass = cfg.add_lower_pass
custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
pass_list = [(x[0], self.decorate(x[1])) for x in custom_pass]
BuildConfig.current.add_lower_pass = pass_list
def enter(self):
'''only decorate outermost nest'''
if DumpIR.scope_level > 0:
return
self.decorate_irpass()
self.decorate_custompass()
self._pass_id = 0
DumpIR.scope_level += 1
def exit(self):
'''recover outermost nest'''
if DumpIR.scope_level > 1:
return
# recover decorated functions
for f in self._recover_list:
f()
schedule.ScheduleOps = self._old_sgpass
BuildConfig.current.add_lower_pass = self._old_custom_pass
DumpIR.scope_level -= 1
class BuildConfig(object):
"""Configuration scope to set a build config option.
......@@ -37,10 +115,12 @@ class BuildConfig(object):
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": 1,
"add_lower_pass": None
"add_lower_pass": None,
"dump_pass_ir": False
}
def __init__(self, **kwargs):
self._old_scope = None
self._dump_ir = DumpIR()
for k, _ in kwargs.items():
if k not in BuildConfig.defaults:
raise ValueError(
......@@ -59,10 +139,14 @@ class BuildConfig(object):
attr.update(self._attr)
self._attr = attr
BuildConfig.current = self
if self.dump_pass_ir is True:
self._dump_ir.enter()
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope
if self.dump_pass_ir is True:
self._dump_ir.exit()
BuildConfig.current = self._old_scope
......@@ -115,6 +199,8 @@ def build_config(**kwargs):
phase contains an integer on which optimization pass we apply the pass.
Additional lowering passes to be applied before make_api.
dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False
Returns
-------
config: BuildConfig
......@@ -247,7 +333,6 @@ def lower(sch,
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
def build(sch,
args=None,
target=None,
......
import tvm
import os
def test_unroll_loop():
dtype = 'int64'
......@@ -24,4 +25,20 @@ def test_unroll_loop():
if __name__ == "__main__":
test_unroll_loop()
with tvm.build_config(dump_pass_ir=True):
test_unroll_loop()
def end_with(*suffix):
ends = suffix
def run(s):
f = map(s.endswith, ends)
if True in f: return s
return run
file_list = os.listdir('./')
cc_file = end_with('.cc')
cc_file = filter(cc_file, file_list)
assert len(cc_file) == 3
for i in cc_file:
os.remove(i)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment