Commit 4e2d707f by Jared Roesch Committed by Zhi

[Relay][Module] Refactor the way we interface between different modules of Relay. (#3906)

* Module refactor

* Add load module

* Add support for idempotent import

* Tweak load paths

* Move path around

* Expose C++ import functions in Python

* Fix import

* Add doc string

* Fix

* Fix lint

* Fix lint

* Fix test failure

* Add type solver

* Fix lint
parent c31e7771
...@@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node); ...@@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node);
std::string AsText(const NodeRef& node, std::string AsText(const NodeRef& node,
bool show_meta_data = true, bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr); runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_H_ #endif // TVM_RELAY_EXPR_H_
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -185,6 +186,23 @@ class ModuleNode : public RelayNode { ...@@ -185,6 +186,23 @@ class ModuleNode : public RelayNode {
*/ */
TVM_DLL void Update(const Module& other); TVM_DLL void Update(const Module& other);
/*!
* \brief Import Relay code from the file at path.
* \param path The path of the Relay code to import.
*
* \note The path resolution behavior is standard,
* if abosolute will be the absolute file, if
* relative it will be resovled against the current
* working directory.
*/
TVM_DLL void Import(const std::string& path);
/*!
* \brief Import Relay code from the file at path, relative to the standard library.
* \param path The path of the Relay code to import.
*/
TVM_DLL void ImportFromStd(const std::string& path);
/*! \brief Construct a module from a standalone expression. /*! \brief Construct a module from a standalone expression.
* *
* Allows one to optionally pass a global function map and * Allows one to optionally pass a global function map and
...@@ -222,6 +240,11 @@ class ModuleNode : public RelayNode { ...@@ -222,6 +240,11 @@ class ModuleNode : public RelayNode {
* for convenient access * for convenient access
*/ */
std::unordered_map<int32_t, Constructor> constructor_tag_map_; std::unordered_map<int32_t, Constructor> constructor_tag_map_;
/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
std::unordered_set<std::string> import_set_;
}; };
struct Module : public NodeRef { struct Module : public NodeRef {
...@@ -235,6 +258,12 @@ struct Module : public NodeRef { ...@@ -235,6 +258,12 @@ struct Module : public NodeRef {
using ContainerType = ModuleNode; using ContainerType = ModuleNode;
}; };
/*! \brief Parse Relay source into a module.
* \param source A string of Relay source code.
* \param source_name The name of the source file.
* \return A Relay module.
*/
Module FromText(const std::string& source, const std::string& source_name);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -410,6 +410,12 @@ class TypeReporterNode : public Node { ...@@ -410,6 +410,12 @@ class TypeReporterNode : public Node {
*/ */
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
/*!
* \brief Retrieve the current global module.
* \return The global module.
*/
TVM_DLL virtual Module GetModule() = 0;
// solver is not serializable. // solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {} void VisitAttrs(tvm::AttrVisitor* v) final {}
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name # pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler.""" """The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import from __future__ import absolute_import
import os
from sys import setrecursionlimit from sys import setrecursionlimit
from ..api import register_func from ..api import register_func
from . import base from . import base
......
...@@ -16,13 +16,22 @@ ...@@ -16,13 +16,22 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global module storing everything needed to interpret or compile a Relay program.""" """A global module storing everything needed to interpret or compile a Relay program."""
import os
from .base import register_relay_node, RelayNode from .base import register_relay_node, RelayNode
from .. import register_func
from .._ffi import base as _base from .._ffi import base as _base
from . import _make from . import _make
from . import _module from . import _module
from . import expr as _expr from . import expr as _expr
from . import ty as _ty from . import ty as _ty
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
@register_func("tvm.relay.std_path")
def _std_path():
global __STD_PATH__
return __STD_PATH__
@register_relay_node @register_relay_node
class Module(RelayNode): class Module(RelayNode):
"""The global Relay module containing collection of functions. """The global Relay module containing collection of functions.
...@@ -202,3 +211,9 @@ class Module(RelayNode): ...@@ -202,3 +211,9 @@ class Module(RelayNode):
funcs = functions if functions is not None else {} funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {} defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs) return _module.Module_FromExpr(expr, funcs, defs)
def _import(self, file_to_import):
return _module.Module_Import(self, file_to_import)
def import_from_std(self, file_to_import):
return _module.Module_ImportFromStd(self, file_to_import)
...@@ -16,14 +16,11 @@ ...@@ -16,14 +16,11 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions.""" """A prelude containing useful global functions and ADT definitions."""
import os
from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module from .module import Module
class Prelude: class Prelude:
...@@ -479,12 +476,10 @@ class Prelude: ...@@ -479,12 +476,10 @@ class Prelude:
Parses the portions of the Prelude written in Relay's text format and adds Parses the portions of the Prelude written in Relay's text format and adds
them to the module. them to the module.
""" """
prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly") # TODO(@jroesch): we should remove this helper when we port over prelude
with open(prelude_file) as prelude: self.mod.import_from_std("prelude.rly")
prelude = fromtext(prelude.read()) self.id = self.mod.get_global_var("id")
self.mod.update(prelude) self.compose = self.mod.get_global_var("compose")
self.id = self.mod.get_global_var("id")
self.compose = self.mod.get_global_var("compose")
def __init__(self, mod=None): def __init__(self, mod=None):
......
...@@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { ...@@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
} }
} }
TVM_REGISTER_API("relay._expr.Bind") TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef input = args[0]; NodeRef input = args[0];
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <sstream> #include <sstream>
#include <fstream>
#include <unordered_set>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, ...@@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
auto n = make_node<ModuleNode>(); auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs); n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs); n->type_definitions = std::move(global_type_defs);
n->global_type_var_map_ = {};
n->global_var_map_ = {};
n->constructor_tag_map_ = {};
for (const auto& kv : n->functions) { for (const auto& kv : n->functions) {
// set global var map // set global var map
...@@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, ...@@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
} }
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name); auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end()) CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module"; << "Cannot find global type var " << name << " in the Module";
...@@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { ...@@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
// set global type var map // set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint)) CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint; << "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var); global_type_var_map_.Set(var->var->name_hint, var);
RegisterConstructors(var, type); RegisterConstructors(var, type);
...@@ -241,6 +248,40 @@ Module ModuleNode::FromExpr( ...@@ -241,6 +248,40 @@ Module ModuleNode::FromExpr(
return mod; return mod;
} }
void ModuleNode::Import(const std::string& path) {
LOG(INFO) << "Importing: " << path;
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
std::fstream src_file(path, std::fstream::in);
std::string file_contents {
std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>() };
auto mod_to_import = FromText(file_contents, path);
for (auto func : mod_to_import->functions) {
this->Add(func.first, func.second, false);
}
for (auto type : mod_to_import->type_definitions) {
this->AddDef(type.first, type.second);
}
}
}
void ModuleNode::ImportFromStd(const std::string& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
return this->Import(std_path + "/" + path);
}
Module FromText(const std::string& source, const std::string& source_name) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
Module mod = (*f)(source, source_name);
return mod;
}
TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module") TVM_REGISTER_API("relay._make.Module")
...@@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update") ...@@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update")
mod->Update(from); mod->Update(from);
}); });
TVM_REGISTER_API("relay._module.Module_Import")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
mod->Import(path);
});
TVM_REGISTER_API("relay._module.Module_ImportFromStd")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
mod->ImportFromStd(path);
});;
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ModuleNode>( .set_dispatch<ModuleNode>(
[](const ModuleNode *node, tvm::IRPrinter *p) { [](const ModuleNode *node, tvm::IRPrinter *p) {
......
...@@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
explicit TypeInferencer(Module mod, GlobalVar current_func) explicit TypeInferencer(Module mod, GlobalVar current_func)
: mod_(mod), current_func_(current_func), : mod_(mod), current_func_(current_func),
err_reporter(), solver_(current_func, &this->err_reporter) { err_reporter(), solver_(current_func, mod, &this->err_reporter) {
CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
} }
// inference the type of expr. // inference the type of expr.
...@@ -790,36 +791,22 @@ void EnsureCheckedType(const Expr& e) { ...@@ -790,36 +791,22 @@ void EnsureCheckedType(const Expr& e) {
AllCheckTypePopulated().VisitExpr(e); AllCheckTypePopulated().VisitExpr(e);
} }
Expr InferType(const Expr& expr, const Module& mod_ref) { Expr InferType(const Expr& expr, const Module& mod) {
if (!mod_ref.defined()) { auto main = mod->GetGlobalVar("main");
Module mod = ModuleNode::FromExpr(expr); auto inferencer = TypeInferencer(mod, main);
// NB(@jroesch): By adding the expression to the module we will auto e = inferencer.Infer(expr);
// type check it anyway; afterwards we can just recover type CHECK(WellFormed(e));
// from the type-checked function to avoid doing unnecessary work. auto free_tvars = FreeTypeVars(e, mod);
CHECK(free_tvars.size() == 0)
Function func = mod->Lookup("main"); << "Found unbound type variables in " << e << ": " << free_tvars;
EnsureCheckedType(e);
// FromExpr wraps a naked expression as a function, we will unbox return e;
// it here.
if (expr.as<FunctionNode>()) {
return std::move(func);
} else {
return func->body;
}
} else {
auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr);
CHECK(WellFormed(e));
auto free_tvars = FreeTypeVars(e, mod_ref);
CHECK(free_tvars.size() == 0)
<< "Found unbound type variables in " << e << ": " << free_tvars;
EnsureCheckedType(e);
return e;
}
} }
Function InferType(const Function& func, Function InferType(const Function& func,
const Module& mod, const Module& mod,
const GlobalVar& var) { const GlobalVar& var) {
CHECK(mod.defined()) << "internal error: module must be set for type inference";
Function func_copy = Function(make_node<FunctionNode>(*func.operator->())); Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation(); func_copy->checked_type_ = func_copy->func_type_annotation();
mod->AddUnchecked(var, func_copy); mod->AddUnchecked(var, func_copy);
......
...@@ -61,6 +61,10 @@ class TypeSolver::Reporter : public TypeReporterNode { ...@@ -61,6 +61,10 @@ class TypeSolver::Reporter : public TypeReporterNode {
location = ref; location = ref;
} }
TVM_DLL Module GetModule() final {
return this->solver_->module_;
}
private: private:
/*! \brief The location to report unification errors at. */ /*! \brief The location to report unification errors at. */
mutable NodeRef location; mutable NodeRef location;
...@@ -526,10 +530,13 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> { ...@@ -526,10 +530,13 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
}; };
// constructor // constructor
TypeSolver::TypeSolver(const GlobalVar &current_func, ErrorReporter* err_reporter) TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module,
: reporter_(make_node<Reporter>(this)), ErrorReporter* err_reporter)
current_func(current_func), : reporter_(make_node<Reporter>(this)),
err_reporter_(err_reporter) { current_func(current_func),
err_reporter_(err_reporter),
module_(module) {
CHECK(module_.defined()) << "internal error: module must be defined";
} }
// destructor // destructor
...@@ -653,18 +660,22 @@ TVM_REGISTER_API("relay._analysis._test_type_solver") ...@@ -653,18 +660,22 @@ TVM_REGISTER_API("relay._analysis._test_type_solver")
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::TypedPackedFunc; using runtime::TypedPackedFunc;
ErrorReporter *err_reporter = new ErrorReporter(); ErrorReporter *err_reporter = new ErrorReporter();
auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), err_reporter); auto module = ModuleNode::make({}, {});
auto dummy_fn_name = GlobalVarNode::make("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
auto mod = [solver, err_reporter](std::string name) -> PackedFunc { auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
if (name == "Solve") { if (name == "Solve") {
return TypedPackedFunc<bool()>([solver]() { return TypedPackedFunc<bool()>([solver]() {
return solver->Solve(); return solver->Solve();
}); });
} else if (name == "Unify") { } else if (name == "Unify") {
return TypedPackedFunc<Type(Type, Type)>([solver, err_reporter](Type lhs, Type rhs) { return TypedPackedFunc<Type(Type, Type)>(
[module, solver, err_reporter](Type lhs, Type rhs) {
auto res = solver->Unify(lhs, rhs, lhs); auto res = solver->Unify(lhs, rhs, lhs);
if (err_reporter->AnyErrors()) { if (err_reporter->AnyErrors()) {
err_reporter->RenderErrors(ModuleNode::make({}, {}), true); err_reporter->RenderErrors(module, true);
} }
return res; return res;
}); });
......
...@@ -63,7 +63,7 @@ using common::LinkedList; ...@@ -63,7 +63,7 @@ using common::LinkedList;
*/ */
class TypeSolver { class TypeSolver {
public: public:
TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter); TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter);
~TypeSolver(); ~TypeSolver();
/*! /*!
* \brief Add a type constraint to the solver. * \brief Add a type constraint to the solver.
...@@ -179,6 +179,8 @@ class TypeSolver { ...@@ -179,6 +179,8 @@ class TypeSolver {
GlobalVar current_func; GlobalVar current_func;
/*! \brief Error reporting. */ /*! \brief Error reporting. */
ErrorReporter* err_reporter_; ErrorReporter* err_reporter_;
/*! \brief The module. */
Module module_;
/*! /*!
* \brief GetTypeNode that is corresponds to t. * \brief GetTypeNode that is corresponds to t.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment