Commit 1598e329 by Wei Chen Committed by Haichen Shen

Add EtaExpand to transform API (#3406)

* Add EtaExpand to transform API

* Add test case
parent f7d15f64
......@@ -406,6 +406,15 @@ def ToANormalForm():
"""
return _transform.ToANormalForm()
def EtaExpand():
"""Add abstraction over a function
Returns
-------
ret: tvm.relay.Pass
The registered pass that eta expands an expression.
"""
return _transform.EtaExpand()
def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression
......
......@@ -67,5 +67,20 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand);
namespace transform {
Pass EtaExpand() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(EtaExpand(f, m));
};
return CreateFunctionPass(pass_func, 1, "EtaExpand", {});
}
TVM_REGISTER_API("relay._transform.EtaExpand")
.set_body_typed(EtaExpand);
} // namespace transform
} // namespace relay
} // namespace tvm
......@@ -15,13 +15,20 @@
# specific language governing permissions and limitations
# under the License.
from tvm import relay
import tvm.relay.module as _module
import tvm.relay.transform as _transform
def test_eta_expand_basic():
mod = relay.Module()
x = relay.var('x', 'int32')
y = relay.var('y', 'int32')
orig = relay.Function([x], x)
got = relay.ir_pass.eta_expand(orig, mod)
mod = _module.Module.from_expr(orig)
seq = _transform.Sequential([_transform.EtaExpand()])
with _transform.PassContext(opt_level=3):
mod = seq(mod)
got = mod[mod.entry_func.name_hint]
y = relay.var('y', 'int32')
expected = relay.Function([y], orig(y))
got = relay.ir_pass.infer_type(got, mod)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment