Commit 178f48c3 by Zhi Committed by Tianqi Chen

[Relay][Transform] Support Dumping IR to help debugging (#3493)

* [Relay][Transform] Support Dumping IR to help debugging

* debugprint->printir
parent a31dd162
......@@ -540,6 +540,13 @@ TVM_DLL Pass CanonicalizeCast();
*/
TVM_DLL Pass EtaExpand();
/*!
* \brief Print the IR for a module to help debugging.
*
* \return the pass.
*/
TVM_DLL Pass PrintIR();
} // namespace transform
/*!
......
......@@ -529,6 +529,18 @@ def CanonicalizeCast():
return _transform.CanonicalizeCast()
def PrintIR():
"""
Print the IR for a module to help debugging.
Returns
-------
ret : tvm.relay.Pass
The registered pass that prints the module IR.
"""
return _transform.PrintIR()
def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
*
* \file src/relay/pass/print_ir.cc
*
* \brief Print the module IR to help debugging.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
namespace transform {
Pass PrintIR() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m);
return m;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
}
TVM_REGISTER_API("relay._transform.PrintIR")
.set_body_typed(PrintIR);
} // namespace transform
} // namespace relay
} // namespace tvm
......@@ -504,6 +504,62 @@ def test_sequential_with_scoping():
assert analysis.alpha_equal(zz, zexpected)
def test_print_ir():
shape = (1, 2, 3)
tp = relay.TensorType(shape, "float32")
x = relay.var("x", tp)
y = relay.add(x, x)
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
seq = _transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
relay.transform.DeadCodeElimination()
])
def redirect_output(call):
"""Redirect the C++ logging info."""
import sys
import os
import threading
stderr_fileno = sys.stderr.fileno()
stderr_save = os.dup(stderr_fileno)
stderr_pipe = os.pipe()
os.dup2(stderr_pipe[1], stderr_fileno)
os.close(stderr_pipe[1])
output = ''
def record():
nonlocal output
while True:
data = os.read(stderr_pipe[0], 1024)
if not data:
break
output += data.decode("utf-8")
t = threading.Thread(target=record)
t.start()
call()
os.close(stderr_fileno)
t.join()
os.close(stderr_pipe[0])
os.dup2(stderr_save, stderr_fileno)
os.close(stderr_save)
return output
def run_pass():
mod = relay.Module({"main": func})
with relay.build_config(opt_level=3):
mod = seq(mod)
out = redirect_output(run_pass)
assert "Dumping the module IR" in out
assert "multiply" in out
if __name__ == "__main__":
test_function_class_pass()
test_module_class_pass()
......@@ -512,3 +568,4 @@ if __name__ == "__main__":
test_sequential_pass()
test_sequential_with_scoping()
test_pass_info()
test_print_ir()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment