Commit 814554e0 by 雾雨魔理沙 Committed by Tianqi Chen

init (#3571)

quickfix
parent 5c410037
...@@ -248,6 +248,30 @@ class Sequential(Pass): ...@@ -248,6 +248,30 @@ class Sequential(Pass):
passes, opt_level, name, required) passes, opt_level, name, required)
def infer_type(expr, mod=None):
"""Infer the type of an expr.
Adding Function into a Module will change it's binding,
and some passes need type inference to work without binding modification.
However, InferType() work by putting stuff into a Module, thus changing all the binding.
This is an escape patch that allow type inference without binding changing.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The input module
Returns
-------
ret : tvm.relay.Expr
The output expression.
"""
return _transform.infer_type(expr, mod)
def InferType(): def InferType():
"""Infer the type of an expr. """Infer the type of an expr.
......
...@@ -824,6 +824,9 @@ Function InferType(const Function& func, ...@@ -824,6 +824,9 @@ Function InferType(const Function& func,
return Downcast<Function>(func_ret); return Downcast<Function>(func_ret);
} }
TVM_REGISTER_API("relay._transform.infer_type")
.set_body_typed<Expr(Expr, Module)>([](Expr l, Module r) { return InferType(l, r); });
namespace transform { namespace transform {
Pass InferType() { Pass InferType() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment