Commit 1598e329 by Wei Chen Committed by Haichen Shen

Add EtaExpand to transform API (#3406)

* Add EtaExpand to transform API

* Add test case
parent f7d15f64
...@@ -406,6 +406,15 @@ def ToANormalForm(): ...@@ -406,6 +406,15 @@ def ToANormalForm():
""" """
return _transform.ToANormalForm() return _transform.ToANormalForm()
def EtaExpand():
"""Add abstraction over a function
Returns
-------
ret: tvm.relay.Pass
The registered pass that eta expands an expression.
"""
return _transform.EtaExpand()
def ToGraphNormalForm(): def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression """Turn A Normal Form expression into Graph Normal Form expression
......
...@@ -67,5 +67,20 @@ Expr EtaExpand(const Expr& e, const Module& mod) { ...@@ -67,5 +67,20 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand); TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand);
namespace transform {
Pass EtaExpand() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(EtaExpand(f, m));
};
return CreateFunctionPass(pass_func, 1, "EtaExpand", {});
}
TVM_REGISTER_API("relay._transform.EtaExpand")
.set_body_typed(EtaExpand);
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -15,13 +15,20 @@ ...@@ -15,13 +15,20 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from tvm import relay from tvm import relay
import tvm.relay.module as _module
import tvm.relay.transform as _transform
def test_eta_expand_basic(): def test_eta_expand_basic():
mod = relay.Module()
x = relay.var('x', 'int32') x = relay.var('x', 'int32')
y = relay.var('y', 'int32')
orig = relay.Function([x], x) orig = relay.Function([x], x)
got = relay.ir_pass.eta_expand(orig, mod) mod = _module.Module.from_expr(orig)
seq = _transform.Sequential([_transform.EtaExpand()])
with _transform.PassContext(opt_level=3):
mod = seq(mod)
got = mod[mod.entry_func.name_hint]
y = relay.var('y', 'int32')
expected = relay.Function([y], orig(y)) expected = relay.Function([y], orig(y))
got = relay.ir_pass.infer_type(got, mod) got = relay.ir_pass.infer_type(got, mod)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment