Commit 2c5c4da6 by Logan Weber Committed by Haichen Shen

[Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter…

[Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter via eta expansion (#4218)

* Fix constructor pretty printing

* Make Module::HasDef name consistent with API

* Add VM constructor compilation via eta expansion

* Lint

* Fix CI

* Fix failing test

* Address comment

* Retrigger CI

* Retrigger CI
parent 3f6b3db8
......@@ -145,6 +145,13 @@ class ModuleNode : public RelayNode {
TVM_DLL bool ContainGlobalVar(const std::string& name) const;
/*!
* \brief Check if the global_type_var_map_ contains a global type variable.
* \param name The variable name.
* \returns true if contains, otherwise false.
*/
TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
......@@ -199,13 +206,6 @@ class ModuleNode : public RelayNode {
TVM_DLL TypeData LookupDef(const std::string& var) const;
/*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag.
* \param tag The tag for the constructor.
* \return The constructor object.
......
......@@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize"
TVM_DLL Pass CanonicalizeCast();
/*!
* \brief Add abstraction over a function
* \brief Add abstraction over a constructor or global variable bound to a function.
*
* For example: `square` is transformed to
* `fun x -> square x`.
* `fn (%x: int32) -> int32 { square(x) }`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \param expand_constructor Whether to expand constructors.
* \param expand_global_var Whether to expand global variables.
*
* \return The pass.
*/
TVM_DLL Pass EtaExpand();
TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
/*!
* \brief Print the IR for a module to help debugging.
......
......@@ -158,13 +158,9 @@ def @sum(%xs: List[Tensor[(), int32]]) {
/*
* Concatenates two lists.
*/
def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
let %updater = fn(%x: A, %xss: List[A]) -> List[A] {
Cons(%x, %xss)
};
@foldr(%updater, %ys, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldr(Cons, %ys, %xs)
@foldr(Cons, %ys, %xs)
}
/*
......@@ -199,12 +195,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
* Reverses a list.
*/
def @rev[A](%xs: List[A]) -> List[A] {
let %updater = fn(%xss: List[A], %x: A) -> List[A] {
Cons(%x, %xss)
};
@foldl(%updater, Nil, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldl(@flip(Cons), Nil, %xs)
@foldl(@flip(Cons), Nil, %xs)
}
/*
......
......@@ -529,15 +529,23 @@ def ToCPS(expr, mod=None):
return _transform.to_cps(expr, mod)
def EtaExpand(expand_constructor=False, expand_global_var=False):
    """Add abstraction over a constructor or global variable bound to a function.

    Eta expansion rewrites a bare reference such as ``Cons`` or ``@square``
    into an explicit wrapper function that forwards its arguments, e.g.
    ``fn (%x: int32) -> int32 { square(%x) }``.

    Parameters
    ----------
    expand_constructor : bool
        Whether to expand constructors.

    expand_global_var : bool
        Whether to expand global variables.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that eta expands an expression.
    """
    return _transform.EtaExpand(expand_constructor, expand_global_var)
def ToGraphNormalForm():
......@@ -959,6 +967,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
return create_function_pass(pass_func)
return create_function_pass
@function_pass(opt_level=1)
class ChangeBatch:
"""
......
......@@ -26,6 +26,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
......@@ -789,6 +790,16 @@ CreateInterpreter(
Module mod,
DLContext context,
Target target) {
if (mod.defined()) {
// eta expand to support constructors in argument position
transform::Sequential seq({
transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false)});
transform::PassContext pass_ctx = transform::PassContext::Current();
tvm::With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);
}
auto intrp = std::make_shared<Interpreter>(mod, context, target);
auto packed = [intrp](Expr expr) {
auto f = DetectFeature(expr);
......
......@@ -874,6 +874,10 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
pass_seqs.push_back(transform::Legalize());
}
// eta expand to support constructors in argument position
pass_seqs.push_back(transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false));
pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
......
......@@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) {
* We will lift a function out into a global which takes the set of the free
* vars and then return the new created function.
*/
struct LambdaLifter : ExprMutator {
Module module_;
class LambdaLifter : public ExprMutator {
public:
explicit LambdaLifter(const Module& module) : module_(module) {}
Expr VisitExpr_(const FunctionNode* func_node) final {
......@@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator {
// The "inner" function should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
if (free_vars.size() == 0 && free_type_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
} else {
lifted_func =
FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
......@@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator {
auto name = GenerateName(lifted_func);
auto global = GlobalVarNode::make(name);
// Add the lifted function to the module.
module_->Add(global, lifted_func);
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
// If an identical function already exists, use its global var.
global = module_->GetGlobalVar(name);
} else {
// Add the lifted function to the module.
module_->Add(global, lifted_func);
}
if (free_vars.size() == 0) {
return std::move(global);
......@@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator {
}
return module_;
}
private:
Module module_;
};
} // namespace vm
......
......@@ -69,7 +69,7 @@ class AlphaEqualHandler:
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->HasDef(p.first->var->name_hint) ||
if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false;
}
......
......@@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
// Membership probe for the global type variable map; unlike the Lookup*
// helpers this never CHECK-fails, so callers can test before resolving.
bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const {
  const auto it = global_type_var_map_.find(name);
  return it != global_type_var_map_.end();
}
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
......@@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id);
}
// True iff a global type definition with this name has been registered.
bool ModuleNode::HasDef(const std::string& name) const {
  return global_type_var_map_.find(name) != global_type_var_map_.end();
}
Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end())
......@@ -336,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add")
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
mod_copy = transform::EtaExpand()(mod_copy);
mod_copy = transform::EtaExpand(
/* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<Function>(func), update);
} else {
......
......@@ -669,7 +669,7 @@ class PrettyPrinter :
Doc VisitExpr_(const ConstructorNode* n) final {
Doc doc;
doc << n->name_hint;
if (n->inputs.size() != 0) {
if (in_adt_def_ && n->inputs.size() != 0) {
doc << "(";
std::vector<Doc> inputs;
for (Type input : n->inputs) {
......@@ -775,6 +775,7 @@ class PrettyPrinter :
}
Doc VisitType_(const TypeDataNode* node) final {
in_adt_def_ = true;
Doc doc;
doc << "type " << Print(node->header);
......@@ -802,6 +803,7 @@ class PrettyPrinter :
adt_body << ",";
}
doc << Brace(adt_body);
in_adt_def_ = false;
return doc;
}
......@@ -876,6 +878,8 @@ class PrettyPrinter :
TextMetaDataContext meta_;
/*! \brief counter of temporary variable */
size_t temp_var_counter_{0};
/*! \brief whether the printer is currently in an ADT definition */
bool in_adt_def_;
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
......
......@@ -20,57 +20,144 @@
/*!
* \file eta_expand.cc
*
* \brief Add abstraction over a function. For example, abs will become (fun x -> abs x).
* \brief Add an abstraction over constructors and/or global variables bound to a function.
*
*/
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr_functor.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
namespace eta_expand {
/*!
* \brief mutator to replace type variables with fresh ones, while maintaining alpha equality
*/
class TypeVarReplacer : public TypeMutator {
public:
TypeVarReplacer() : replace_map_({}) {}
Expr EtaExpand(const Expr& e, const Module& mod) {
tvm::Array<Var> original_params;
tvm::Array<Expr> params;
tvm::Array<Var> args;
tvm::Array<TypeVar> original_type_params;
Type ret_type;
if (e->IsInstance<GlobalVarNode>()) {
auto gvar_node = e.as<GlobalVarNode>();
auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node));
original_params = func->params;
original_type_params = func->type_params;
ret_type = func->ret_type;
} else {
CHECK(e->IsInstance<FunctionNode>());
auto func = GetRef<Function>(e.as<FunctionNode>());
original_params = func->params;
original_type_params = func->type_params;
ret_type = func->ret_type;
/*!
 * \brief Replace a type variable with a fresh one, reusing the same fresh
 * var for every occurrence of the same old var so alpha equality of the
 * surrounding type is preserved.
 */
Type VisitType_(const TypeVarNode* type_var_node) final {
  const auto type_var = GetRef<TypeVar>(type_var_node);
  // Single hash lookup instead of find + two operator[] probes.
  auto it = replace_map_.find(type_var);
  if (it == replace_map_.end()) {
    it = replace_map_.emplace(type_var, TypeVarNode::make("A", Kind::kType)).first;
  }
  return it->second;
}
for (size_t i = 0; i < original_params.size(); ++i) {
auto var = VarNode::make("a", original_params[i]->type_annotation);
params.push_back(var);
args.push_back(var);
private:
/*! \brief variable replacement map to remap old type vars to fresh ones */
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> replace_map_;
};
/*!
* \brief mutator to perform eta expansion on all functions in a module
*/
class EtaExpander : public ExprMutator {
public:
explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var)
: mod_(mod),
type_var_replacer_(TypeVarReplacer()),
expand_constructor_(expand_constructor),
expand_global_var_(expand_global_var) {
CHECK(expand_constructor || expand_global_var)
<< "must expand at least one language feature";
}
auto new_func =
FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params);
/*! \brief Run eta expansion over every global function, updating in place. */
Module Expand() {
  for (const auto& global_var : mod_->GetGlobalVars()) {
    Function original = mod_->Lookup(global_var);
    Function expanded = Downcast<Function>(VisitExpr(original));
    mod_->Update(global_var, expanded);
  }
  return mod_;
}
return std::move(new_func);
}
/*!
 * \brief Visit a call, expanding the arguments but never the operator when
 * it is a constructor — an applied constructor needs no eta expansion.
 */
Expr VisitExpr_(const CallNode* call) final {
  const bool op_is_applied_ctor = call->op.as<ConstructorNode>() != nullptr;
  const Expr new_op = op_is_applied_ctor ? call->op : VisitExpr(call->op);
  tvm::Array<Expr> new_args;
  for (const Expr& arg : call->args) {
    new_args.push_back(VisitExpr(arg));
  }
  return CallNode::make(new_op, new_args, call->attrs, call->type_args);
}
/*!
 * \brief Expand a bare (unapplied) constructor into a wrapper function.
 *
 * Fresh type vars replace the ADT's type vars so the wrapper is a
 * self-contained polymorphic function.
 */
Expr VisitExpr_(const ConstructorNode* cons_node) final {
  Constructor ctor = GetRef<Constructor>(cons_node);
  if (!expand_constructor_) {
    return std::move(ctor);
  }
  // NOTE: we only reach this case if the constructor is not being applied to any arguments
  tvm::Array<Expr> fresh_params;
  for (const auto& input_type : ctor->inputs) {
    fresh_params.push_back(
        VarNode::make("eta_expand_param", type_var_replacer_.VisitType(input_type)));
  }
  TypeData adt_def = mod_->LookupDef(ctor->belong_to);
  tvm::Array<Type> fresh_type_args;
  for (const auto& type_var : adt_def->type_vars) {
    fresh_type_args.push_back(type_var_replacer_.VisitType(type_var));
  }
  return FunctionNode::make(
      Downcast<tvm::Array<Var>>(fresh_params),
      CallNode::make(ctor, fresh_params, Attrs()),
      TypeCallNode::make(ctor->belong_to, fresh_type_args),
      Downcast<tvm::Array<TypeVar>>(fresh_type_args));
}
/*!
 * \brief Expand a bare global variable into a function that forwards its
 * arguments to the global.
 */
Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
  GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
  if (!expand_global_var_) {
    return std::move(gvar);
  }
  const auto func = mod_->Lookup(gvar);
  tvm::Array<Var> fresh_params;   // parameters of the wrapper function
  tvm::Array<Expr> forwarded;     // the same vars, forwarded as call args
  for (const auto& orig_param : func->params) {
    auto fresh = VarNode::make("eta_expand_param", orig_param->type_annotation);
    fresh_params.push_back(fresh);
    forwarded.push_back(fresh);
  }
  return FunctionNode::make(
      fresh_params,
      CallNode::make(gvar, forwarded),
      func->ret_type,
      func->type_params);
}
private:
/*! \brief reference to module being expanded */
const Module mod_;
/*! \brief type variable replacer */
TypeVarReplacer type_var_replacer_;
/*! \brief whether to expand constructor nodes */
bool expand_constructor_;
/*! \brief whether to expand global variable nodes */
bool expand_global_var_;
};
} // namespace eta_expand
namespace transform {
Pass EtaExpand() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(EtaExpand(f, m));
};
Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {});
return Sequential({expanded, InferType()});
// Module-level pass: eta expansion may rewrite every global definition, so
// it is registered as a module pass rather than a function pass.
Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
      [=](Module mod, PassContext pc) -> Module {
        eta_expand::EtaExpander expander(mod, expand_constructor, expand_global_var);
        return expander.Expand();
      };
  return CreateModulePass(pass_func, 1, "EtaExpand", {});
}
TVM_REGISTER_API("relay._transform.EtaExpand")
......
......@@ -653,7 +653,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
}
Expr VisitExpr_(const ConstructorNode* op) final {
return GetRef<Constructor>(op);
return AttachCheckedType(op);
}
Expr VisitExpr_(const MatchNode* op) final {
......
......@@ -218,6 +218,27 @@ def test_zeros():
x = relay.op.zeros([], "float32")
astext(x)
def test_unapplied_constructor():
    """Printer keeps constructor signatures only inside ADT definitions."""
    type_def_str = r"""
type List[A] {
Cons(A, List[A]),
Nil,
}
"""
    main_def_str = r"""
def @main[A]() -> fn (A, List[A]) -> List[A] {
Cons
}
"""
    mod = relay.fromtext(SEMVER + type_def_str + main_def_str)
    printed = str(mod)
    # Constructors carry their signature in the type definition but print as
    # a bare name when used as an expression; both forms must round-trip.
    assert type_def_str.strip() in printed
    assert main_def_str.strip() in printed
if __name__ == "__main__":
do_print[0] = True
test_lstm()
......@@ -239,3 +260,4 @@ if __name__ == "__main__":
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_unapplied_constructor()
......@@ -14,27 +14,70 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import numpy as np
import tvm
from tvm import relay
import tvm.relay.module as _module
import tvm.relay.transform as _transform
def test_eta_expand_basic():
x = relay.var('x', 'int32')
orig = relay.Function([x], x)
mod = _module.Module.from_expr(orig)
seq = _transform.Sequential([_transform.EtaExpand()])
def test_eta_expand_global_var():
    """EtaExpand wraps a returned global var in a forwarding function."""
    mod = relay.fromtext(r"""
v0.0.4
def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
%x
}
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
@aux
}
""")
    pipeline = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
    with _transform.PassContext(opt_level=3):
        mod = pipeline(mod)
    expected = relay.fromtext(r"""
v0.0.4
def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
%x
}
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
@aux(%x)
}
}
""")
    # Compare parsed graphs, not text, so formatting differences are ignored.
    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
got = mod["main"]
def test_eta_expand_constructor():
    """EtaExpand wraps a bare constructor in a polymorphic wrapper fn."""
    mod = relay.fromtext(r"""
v0.0.4
type List[A] {
Cons(A, List[A]),
Nil,
}
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
Cons
}
""")
    pipeline = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
    with _transform.PassContext(opt_level=3):
        mod = pipeline(mod)
    expected = relay.fromtext(r"""
v0.0.4
type List[A] {
Cons(A, List[A]),
Nil,
}
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
fn [A](%x: A, %xs: List[A]) -> List[A] {
Cons(%x, %xs)
}
}
""")
    # Compare parsed graphs, not text, so formatting differences are ignored.
    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
y = relay.var('y', 'int32')
expected = relay.Function([y], orig(y))
gv = relay.GlobalVar("gv")
mod[gv] = expected
mod = _transform.InferType()(mod)
expected = mod["gv"]
assert(relay.analysis.alpha_equal(got, expected))
if __name__ == "__main__":
test_eta_expand_basic()
if __name__ == '__main__':
test_eta_expand_global_var()
test_eta_expand_constructor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment