Commit 2c5c4da6 by Logan Weber Committed by Haichen Shen

[Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter…

[Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter via eta expansion (#4218)

* Fix constructor pretty printing

* Make Module::HasDef name consistent with API

* Add VM constructor compilation via eta expansion

* Lint

* Fix CI

* Fix failing test

* Address comment

* Retrigger CI

* Retrigger CI
parent 3f6b3db8
...@@ -145,6 +145,13 @@ class ModuleNode : public RelayNode { ...@@ -145,6 +145,13 @@ class ModuleNode : public RelayNode {
TVM_DLL bool ContainGlobalVar(const std::string& name) const; TVM_DLL bool ContainGlobalVar(const std::string& name) const;
/*! /*!
* \brief Check if the global_type_var_map_ contains a global type variable.
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
/*!
* \brief Lookup a global function by its variable. * \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable. * \param str The unique string specifying the global variable.
* \returns The global variable. * \returns The global variable.
...@@ -199,13 +206,6 @@ class ModuleNode : public RelayNode { ...@@ -199,13 +206,6 @@ class ModuleNode : public RelayNode {
TVM_DLL TypeData LookupDef(const std::string& var) const; TVM_DLL TypeData LookupDef(const std::string& var) const;
/*! /*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag. * \brief Look up a constructor by its tag.
* \param tag The tag for the constructor. * \param tag The tag for the constructor.
* \return The constructor object. * \return The constructor object.
......
...@@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize" ...@@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize"
TVM_DLL Pass CanonicalizeCast(); TVM_DLL Pass CanonicalizeCast();
/*! /*!
* \brief Add abstraction over a function * \brief Add abstraction over a constructor or global variable bound to a function.
* *
* For example: `square` is transformed to * For example: `square` is transformed to
* `fun x -> square x`. * `fn (%x: int32) -> int32 { square(x) }`.
* *
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details. * for more details.
* *
* \param expand_constructor Whether to expand constructors.
* \param expand_global_var Whether to expand global variables.
*
* \return The pass. * \return The pass.
*/ */
TVM_DLL Pass EtaExpand(); TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
/*! /*!
* \brief Print the IR for a module to help debugging. * \brief Print the IR for a module to help debugging.
......
...@@ -158,13 +158,9 @@ def @sum(%xs: List[Tensor[(), int32]]) { ...@@ -158,13 +158,9 @@ def @sum(%xs: List[Tensor[(), int32]]) {
/* /*
* Concatenates two lists. * Concatenates two lists.
*/ */
def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] { def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
let %updater = fn(%x: A, %xss: List[A]) -> List[A] { @foldr(Cons, %ys, %xs)
Cons(%x, %xss)
};
@foldr(%updater, %ys, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldr(Cons, %ys, %xs)
} }
/* /*
...@@ -199,12 +195,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] { ...@@ -199,12 +195,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
* Reverses a list. * Reverses a list.
*/ */
def @rev[A](%xs: List[A]) -> List[A] { def @rev[A](%xs: List[A]) -> List[A] {
let %updater = fn(%xss: List[A], %x: A) -> List[A] { @foldl(@flip(Cons), Nil, %xs)
Cons(%x, %xss)
};
@foldl(%updater, Nil, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldl(@flip(Cons), Nil, %xs)
} }
/* /*
......
...@@ -529,15 +529,23 @@ def ToCPS(expr, mod=None): ...@@ -529,15 +529,23 @@ def ToCPS(expr, mod=None):
return _transform.to_cps(expr, mod) return _transform.to_cps(expr, mod)
def EtaExpand(): def EtaExpand(expand_constructor=False, expand_global_var=False):
"""Add abstraction over a function """Add abstraction over a constructor or global variable bound to a function
Parameters
----------
expand_constructor: bool
Whether to expand constructors.
expand_global_var: bool
Whether to expand global variables.
Returns Returns
------- -------
ret: tvm.relay.Pass ret: tvm.relay.Pass
The registered pass that eta expands an expression. The registered pass that eta expands an expression.
""" """
return _transform.EtaExpand() return _transform.EtaExpand(expand_constructor, expand_global_var)
def ToGraphNormalForm(): def ToGraphNormalForm():
...@@ -959,6 +967,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None): ...@@ -959,6 +967,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
return create_function_pass(pass_func) return create_function_pass(pass_func)
return create_function_pass return create_function_pass
@function_pass(opt_level=1) @function_pass(opt_level=1)
class ChangeBatch: class ChangeBatch:
""" """
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h> #include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h> #include <tvm/relay/feature.h>
...@@ -789,6 +790,16 @@ CreateInterpreter( ...@@ -789,6 +790,16 @@ CreateInterpreter(
Module mod, Module mod,
DLContext context, DLContext context,
Target target) { Target target) {
if (mod.defined()) {
// eta expand to support constructors in argument position
transform::Sequential seq({
transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false)});
transform::PassContext pass_ctx = transform::PassContext::Current();
tvm::With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);
}
auto intrp = std::make_shared<Interpreter>(mod, context, target); auto intrp = std::make_shared<Interpreter>(mod, context, target);
auto packed = [intrp](Expr expr) { auto packed = [intrp](Expr expr) {
auto f = DetectFeature(expr); auto f = DetectFeature(expr);
......
...@@ -874,6 +874,10 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) ...@@ -874,6 +874,10 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
pass_seqs.push_back(transform::Legalize()); pass_seqs.push_back(transform::Legalize());
} }
// eta expand to support constructors in argument position
pass_seqs.push_back(transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false));
pass_seqs.push_back(transform::SimplifyInference()); pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0]; Expr expr = args[0];
......
...@@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) { ...@@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) {
* We will lift a function out into a global which takes the set of the free * We will lift a function out into a global which takes the set of the free
* vars and then return the new created function. * vars and then return the new created function.
*/ */
struct LambdaLifter : ExprMutator { class LambdaLifter : public ExprMutator {
Module module_; public:
explicit LambdaLifter(const Module& module) : module_(module) {} explicit LambdaLifter(const Module& module) : module_(module) {}
Expr VisitExpr_(const FunctionNode* func_node) final { Expr VisitExpr_(const FunctionNode* func_node) final {
...@@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator { ...@@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator {
// The "inner" function should be used to generate the // The "inner" function should be used to generate the
// code for the closure. // code for the closure.
Function lifted_func; Function lifted_func;
if (free_vars.size() == 0) { if (free_vars.size() == 0 && free_type_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars); lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
} else { } else {
lifted_func = lifted_func =
FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars); FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
...@@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator { ...@@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator {
auto name = GenerateName(lifted_func); auto name = GenerateName(lifted_func);
auto global = GlobalVarNode::make(name); auto global = GlobalVarNode::make(name);
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
// If an identical function already exists, use its global var.
global = module_->GetGlobalVar(name);
} else {
// Add the lifted function to the module. // Add the lifted function to the module.
module_->Add(global, lifted_func); module_->Add(global, lifted_func);
}
if (free_vars.size() == 0) { if (free_vars.size() == 0) {
return std::move(global); return std::move(global);
...@@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator { ...@@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator {
} }
return module_; return module_;
} }
private:
Module module_;
}; };
} // namespace vm } // namespace vm
......
...@@ -69,7 +69,7 @@ class AlphaEqualHandler: ...@@ -69,7 +69,7 @@ class AlphaEqualHandler:
} }
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) { for (const auto& p : lhsm->type_definitions) {
if (!rhsm->HasDef(p.first->var->name_hint) || if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) { !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false; return false;
} }
......
...@@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const { ...@@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end(); return global_var_map_.find(name) != global_var_map_.end();
} }
bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const {
return global_type_var_map_.find(name) != global_type_var_map_.end();
}
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name); auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end()) CHECK(it != global_var_map_.end())
...@@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const { ...@@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id); return this->LookupDef(id);
} }
bool ModuleNode::HasDef(const std::string& name) const {
auto it = global_type_var_map_.find(name);
return it != global_type_var_map_.end();
}
Constructor ModuleNode::LookupTag(const int32_t tag) { Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag); auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end()) CHECK(it != constructor_tag_map_.end())
...@@ -336,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add") ...@@ -336,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add")
} else if (val->IsInstance<GlobalVarNode>()) { } else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val); GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->())); auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
mod_copy = transform::EtaExpand()(mod_copy); mod_copy = transform::EtaExpand(
/* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint); auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<Function>(func), update); mod->Add(var, Downcast<Function>(func), update);
} else { } else {
......
...@@ -669,7 +669,7 @@ class PrettyPrinter : ...@@ -669,7 +669,7 @@ class PrettyPrinter :
Doc VisitExpr_(const ConstructorNode* n) final { Doc VisitExpr_(const ConstructorNode* n) final {
Doc doc; Doc doc;
doc << n->name_hint; doc << n->name_hint;
if (n->inputs.size() != 0) { if (in_adt_def_ && n->inputs.size() != 0) {
doc << "("; doc << "(";
std::vector<Doc> inputs; std::vector<Doc> inputs;
for (Type input : n->inputs) { for (Type input : n->inputs) {
...@@ -775,6 +775,7 @@ class PrettyPrinter : ...@@ -775,6 +775,7 @@ class PrettyPrinter :
} }
Doc VisitType_(const TypeDataNode* node) final { Doc VisitType_(const TypeDataNode* node) final {
in_adt_def_ = true;
Doc doc; Doc doc;
doc << "type " << Print(node->header); doc << "type " << Print(node->header);
...@@ -802,6 +803,7 @@ class PrettyPrinter : ...@@ -802,6 +803,7 @@ class PrettyPrinter :
adt_body << ","; adt_body << ",";
} }
doc << Brace(adt_body); doc << Brace(adt_body);
in_adt_def_ = false;
return doc; return doc;
} }
...@@ -876,6 +878,8 @@ class PrettyPrinter : ...@@ -876,6 +878,8 @@ class PrettyPrinter :
TextMetaDataContext meta_; TextMetaDataContext meta_;
/*! \brief counter of temporary variable */ /*! \brief counter of temporary variable */
size_t temp_var_counter_{0}; size_t temp_var_counter_{0};
/*! \brief whether the printer is currently in an ADT definition */
bool in_adt_def_;
/*! \brief arena for dependency graph */ /*! \brief arena for dependency graph */
common::Arena arena_; common::Arena arena_;
/*! \brief dependency graph of the expr */ /*! \brief dependency graph of the expr */
......
...@@ -20,57 +20,144 @@ ...@@ -20,57 +20,144 @@
/*! /*!
* \file eta_expand.cc * \file eta_expand.cc
* *
* \brief Add abstraction over a function. For example, abs will become (fun x -> abs x). * \brief Add an abstraction over constructors and/or global variables bound to a function.
* *
*/ */
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr_functor.h>
#include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace eta_expand {
/*!
* \brief mutator to replace type variables with fresh ones, while maintaining alpha equality
*/
class TypeVarReplacer : public TypeMutator {
public:
TypeVarReplacer() : replace_map_({}) {}
Type VisitType_(const TypeVarNode* type_var_node) final {
const auto type_var = GetRef<TypeVar>(type_var_node);
if (replace_map_.find(type_var) == replace_map_.end()) {
replace_map_[type_var] = TypeVarNode::make("A", Kind::kType);
}
return replace_map_[type_var];
}
private:
/*! \brief variable replacement map to remap old type vars to fresh ones */
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> replace_map_;
};
/*!
* \brief mutator to perform eta expansion on all functions in a module
*/
class EtaExpander : public ExprMutator {
public:
explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var)
: mod_(mod),
type_var_replacer_(TypeVarReplacer()),
expand_constructor_(expand_constructor),
expand_global_var_(expand_global_var) {
CHECK(expand_constructor || expand_global_var)
<< "must expand at least one language feature";
}
Module Expand() {
for (GlobalVar global_var : mod_->GetGlobalVars()) {
const Function func = mod_->Lookup(global_var);
const Function new_func = Downcast<Function>(VisitExpr(func));
mod_->Update(global_var, new_func);
}
return mod_;
}
Expr EtaExpand(const Expr& e, const Module& mod) { Expr VisitExpr_(const CallNode* call) final {
tvm::Array<Var> original_params; // we don't need to expand constructors when they are being called, so we
// prevent them being visited here
Expr new_op = call->op;
if (!call->op.as<ConstructorNode>()) {
new_op = VisitExpr(new_op);
}
tvm::Array<Expr> new_args;
for (const auto& arg : call->args) {
new_args.push_back(VisitExpr(arg));
}
return CallNode::make(new_op, new_args, call->attrs, call->type_args);
}
Expr VisitExpr_(const ConstructorNode* cons_node) final {
Constructor cons = GetRef<Constructor>(cons_node);
if (!expand_constructor_) {
return std::move(cons);
}
// NOTE: we only reach this case if the constructor is not being applied to any arguments
tvm::Array<Expr> params; tvm::Array<Expr> params;
tvm::Array<Var> args; for (const auto& type : cons->inputs) {
tvm::Array<TypeVar> original_type_params; Type param_type = type_var_replacer_.VisitType(type);
Type ret_type; params.push_back(VarNode::make("eta_expand_param", param_type));
}
if (e->IsInstance<GlobalVarNode>()) { tvm::Array<Type> type_params;
auto gvar_node = e.as<GlobalVarNode>(); TypeData adt_def = mod_->LookupDef(cons->belong_to);
auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node)); for (const auto& type_var : adt_def->type_vars) {
original_params = func->params; type_params.push_back(type_var_replacer_.VisitType(type_var));
original_type_params = func->type_params; }
ret_type = func->ret_type; Expr body = CallNode::make(cons, params, Attrs());
} else { Type ret_type = TypeCallNode::make(cons->belong_to, type_params);
CHECK(e->IsInstance<FunctionNode>());
auto func = GetRef<Function>(e.as<FunctionNode>()); return FunctionNode::make(
original_params = func->params; Downcast<tvm::Array<Var>>(params),
original_type_params = func->type_params; body,
ret_type = func->ret_type; ret_type,
Downcast<tvm::Array<TypeVar>>(type_params));
}
Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
if (!expand_global_var_) {
return std::move(gvar);
} }
for (size_t i = 0; i < original_params.size(); ++i) { const auto func = mod_->Lookup(gvar);
auto var = VarNode::make("a", original_params[i]->type_annotation); tvm::Array<Expr> params;
tvm::Array<Var> args;
for (size_t i = 0; i < func->params.size(); ++i) {
auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
params.push_back(var); params.push_back(var);
args.push_back(var); args.push_back(var);
} }
auto new_func = return FunctionNode::make(
FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params); args,
CallNode::make(gvar, params),
func->ret_type,
func->type_params);
}
private:
/*! \brief reference to module being expanded */
const Module mod_;
/*! \brief type variable replacer */
TypeVarReplacer type_var_replacer_;
/*! \brief whether to expand constructor nodes */
bool expand_constructor_;
/*! \brief whether to expand global variable nodes */
bool expand_global_var_;
};
return std::move(new_func); } // namespace eta_expand
}
namespace transform { namespace transform {
Pass EtaExpand() { Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Module mod, PassContext pc) {
return Downcast<Function>(EtaExpand(f, m)); return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand();
}; };
Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {}); return CreateModulePass(pass_func, 1, "EtaExpand", {});
return Sequential({expanded, InferType()});
} }
TVM_REGISTER_API("relay._transform.EtaExpand") TVM_REGISTER_API("relay._transform.EtaExpand")
......
...@@ -653,7 +653,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { ...@@ -653,7 +653,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
} }
Expr VisitExpr_(const ConstructorNode* op) final { Expr VisitExpr_(const ConstructorNode* op) final {
return GetRef<Constructor>(op); return AttachCheckedType(op);
} }
Expr VisitExpr_(const MatchNode* op) final { Expr VisitExpr_(const MatchNode* op) final {
......
...@@ -218,6 +218,27 @@ def test_zeros(): ...@@ -218,6 +218,27 @@ def test_zeros():
x = relay.op.zeros([], "float32") x = relay.op.zeros([], "float32")
astext(x) astext(x)
def test_unapplied_constructor():
type_def_str = r"""
type List[A] {
Cons(A, List[A]),
Nil,
}
"""
main_def_str = r"""
def @main[A]() -> fn (A, List[A]) -> List[A] {
Cons
}
"""
mod = relay.fromtext(SEMVER + type_def_str + main_def_str)
mod_str = str(mod)
# ensure constructors are printed correctly in type definitions (with their
# signature) and as exprs (without their signature)
assert type_def_str.strip() in mod_str
assert main_def_str.strip() in mod_str
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
test_lstm() test_lstm()
...@@ -239,3 +260,4 @@ if __name__ == "__main__": ...@@ -239,3 +260,4 @@ if __name__ == "__main__":
test_let_if_scope() test_let_if_scope()
test_variable_name() test_variable_name()
test_call_node_order() test_call_node_order()
test_unapplied_constructor()
...@@ -14,27 +14,70 @@ ...@@ -14,27 +14,70 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import os
import numpy as np
import tvm
from tvm import relay from tvm import relay
import tvm.relay.module as _module
import tvm.relay.transform as _transform import tvm.relay.transform as _transform
def test_eta_expand_basic(): def test_eta_expand_global_var():
x = relay.var('x', 'int32') mod = relay.fromtext(r"""
orig = relay.Function([x], x) v0.0.4
mod = _module.Module.from_expr(orig) def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
seq = _transform.Sequential([_transform.EtaExpand()]) %x
}
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
@aux
}
""")
seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
with _transform.PassContext(opt_level=3): with _transform.PassContext(opt_level=3):
mod = seq(mod) mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
%x
}
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
@aux(%x)
}
}
""")
relay.analysis.assert_graph_equal(mod['main'], expected['main'])
got = mod["main"] def test_eta_expand_constructor():
mod = relay.fromtext(r"""
v0.0.4
type List[A] {
Cons(A, List[A]),
Nil,
}
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
Cons
}
""")
seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
with _transform.PassContext(opt_level=3):
mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
type List[A] {
Cons(A, List[A]),
Nil,
}
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
fn [A](%x: A, %xs: List[A]) -> List[A] {
Cons(%x, %xs)
}
}
""")
relay.analysis.assert_graph_equal(mod['main'], expected['main'])
y = relay.var('y', 'int32')
expected = relay.Function([y], orig(y))
gv = relay.GlobalVar("gv")
mod[gv] = expected
mod = _transform.InferType()(mod)
expected = mod["gv"]
assert(relay.analysis.alpha_equal(got, expected))
if __name__ == "__main__": if __name__ == '__main__':
test_eta_expand_basic() test_eta_expand_global_var()
test_eta_expand_constructor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment