Commit 4e2d707f by Jared Roesch Committed by Zhi

[Relay][Module] Refactor the way we interface between different modules of Relay. (#3906)

* Module refactor

* Add load module

* Add support for idempotent import

* Tweak load paths

* Move path around

* Expose C++ import functions in Python

* Fix import

* Add doc string

* Fix

* Fix lint

* Fix lint

* Fix test failure

* Add type solver

* Fix lint
parent c31e7771
......@@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node);
std::string AsText(const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
......@@ -33,6 +33,7 @@
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
namespace tvm {
namespace relay {
......@@ -185,6 +186,23 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL void Update(const Module& other);
/*!
* \brief Import Relay code from the file at path.
* \param path The path of the Relay code to import.
*
* \note The path resolution behavior is standard,
* if abosolute will be the absolute file, if
* relative it will be resovled against the current
* working directory.
*/
TVM_DLL void Import(const std::string& path);
/*!
* \brief Import Relay code from the file at path, relative to the standard library.
* \param path The path of the Relay code to import.
*/
TVM_DLL void ImportFromStd(const std::string& path);
/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
......@@ -222,6 +240,11 @@ class ModuleNode : public RelayNode {
* for convenient access
*/
std::unordered_map<int32_t, Constructor> constructor_tag_map_;
/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
std::unordered_set<std::string> import_set_;
};
struct Module : public NodeRef {
......@@ -235,6 +258,12 @@ struct Module : public NodeRef {
using ContainerType = ModuleNode;
};
/*! \brief Parse Relay source into a module.
* \param source A string of Relay source code.
* \param source_name The name of the source file.
* \return A Relay module.
*/
Module FromText(const std::string& source, const std::string& source_name);
} // namespace relay
} // namespace tvm
......
......@@ -410,6 +410,12 @@ class TypeReporterNode : public Node {
*/
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
/*!
* \brief Retrieve the current global module.
* \return The global module.
*/
TVM_DLL virtual Module GetModule() = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
......
......@@ -17,6 +17,7 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
import os
from sys import setrecursionlimit
from ..api import register_func
from . import base
......
......@@ -16,13 +16,22 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global module storing everything needed to interpret or compile a Relay program."""
import os
from .base import register_relay_node, RelayNode
from .. import register_func
from .._ffi import base as _base
from . import _make
from . import _module
from . import expr as _expr
from . import ty as _ty
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
@register_func("tvm.relay.std_path")
def _std_path():
global __STD_PATH__
return __STD_PATH__
@register_relay_node
class Module(RelayNode):
"""The global Relay module containing collection of functions.
......@@ -202,3 +211,9 @@ class Module(RelayNode):
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs)
def _import(self, file_to_import):
return _module.Module_Import(self, file_to_import)
def import_from_std(self, file_to_import):
return _module.Module_ImportFromStd(self, file_to_import)
......@@ -16,14 +16,11 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
import os
from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module
class Prelude:
......@@ -479,10 +476,8 @@ class Prelude:
Parses the portions of the Prelude written in Relay's text format and adds
them to the module.
"""
prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly")
with open(prelude_file) as prelude:
prelude = fromtext(prelude.read())
self.mod.update(prelude)
# TODO(@jroesch): we should remove this helper when we port over prelude
self.mod.import_from_std("prelude.rly")
self.id = self.mod.get_global_var("id")
self.compose = self.mod.get_global_var("compose")
......
......@@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
}
}
TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef input = args[0];
......
......@@ -26,6 +26,8 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <sstream>
#include <fstream>
#include <unordered_set>
namespace tvm {
namespace relay {
......@@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs);
n->global_type_var_map_ = {};
n->global_var_map_ = {};
n->constructor_tag_map_ = {};
for (const auto& kv : n->functions) {
// set global var map
......@@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
}
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
......@@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
RegisterConstructors(var, type);
......@@ -241,6 +248,40 @@ Module ModuleNode::FromExpr(
return mod;
}
void ModuleNode::Import(const std::string& path) {
LOG(INFO) << "Importing: " << path;
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
std::fstream src_file(path, std::fstream::in);
std::string file_contents {
std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>() };
auto mod_to_import = FromText(file_contents, path);
for (auto func : mod_to_import->functions) {
this->Add(func.first, func.second, false);
}
for (auto type : mod_to_import->type_definitions) {
this->AddDef(type.first, type.second);
}
}
}
void ModuleNode::ImportFromStd(const std::string& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
return this->Import(std_path + "/" + path);
}
Module FromText(const std::string& source, const std::string& source_name) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
Module mod = (*f)(source, source_name);
return mod;
}
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module")
......@@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update")
mod->Update(from);
});
TVM_REGISTER_API("relay._module.Module_Import")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
mod->Import(path);
});
TVM_REGISTER_API("relay._module.Module_ImportFromStd")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
mod->ImportFromStd(path);
});;
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ModuleNode>(
[](const ModuleNode *node, tvm::IRPrinter *p) {
......
......@@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
explicit TypeInferencer(Module mod, GlobalVar current_func)
: mod_(mod), current_func_(current_func),
err_reporter(), solver_(current_func, &this->err_reporter) {
err_reporter(), solver_(current_func, mod, &this->err_reporter) {
CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
}
// inference the type of expr.
......@@ -790,36 +791,22 @@ void EnsureCheckedType(const Expr& e) {
AllCheckTypePopulated().VisitExpr(e);
}
Expr InferType(const Expr& expr, const Module& mod_ref) {
if (!mod_ref.defined()) {
Module mod = ModuleNode::FromExpr(expr);
// NB(@jroesch): By adding the expression to the module we will
// type check it anyway; afterwards we can just recover type
// from the type-checked function to avoid doing unnecessary work.
Function func = mod->Lookup("main");
// FromExpr wraps a naked expression as a function, we will unbox
// it here.
if (expr.as<FunctionNode>()) {
return std::move(func);
} else {
return func->body;
}
} else {
auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr);
Expr InferType(const Expr& expr, const Module& mod) {
auto main = mod->GetGlobalVar("main");
auto inferencer = TypeInferencer(mod, main);
auto e = inferencer.Infer(expr);
CHECK(WellFormed(e));
auto free_tvars = FreeTypeVars(e, mod_ref);
auto free_tvars = FreeTypeVars(e, mod);
CHECK(free_tvars.size() == 0)
<< "Found unbound type variables in " << e << ": " << free_tvars;
EnsureCheckedType(e);
return e;
}
}
Function InferType(const Function& func,
const Module& mod,
const GlobalVar& var) {
CHECK(mod.defined()) << "internal error: module must be set for type inference";
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
func_copy->checked_type_ = func_copy->func_type_annotation();
mod->AddUnchecked(var, func_copy);
......
......@@ -61,6 +61,10 @@ class TypeSolver::Reporter : public TypeReporterNode {
location = ref;
}
TVM_DLL Module GetModule() final {
return this->solver_->module_;
}
private:
/*! \brief The location to report unification errors at. */
mutable NodeRef location;
......@@ -526,10 +530,13 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
};
// constructor
TypeSolver::TypeSolver(const GlobalVar &current_func, ErrorReporter* err_reporter)
TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module,
ErrorReporter* err_reporter)
: reporter_(make_node<Reporter>(this)),
current_func(current_func),
err_reporter_(err_reporter) {
err_reporter_(err_reporter),
module_(module) {
CHECK(module_.defined()) << "internal error: module must be defined";
}
// destructor
......@@ -653,18 +660,22 @@ TVM_REGISTER_API("relay._analysis._test_type_solver")
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
ErrorReporter *err_reporter = new ErrorReporter();
auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), err_reporter);
auto module = ModuleNode::make({}, {});
auto dummy_fn_name = GlobalVarNode::make("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
auto mod = [solver, err_reporter](std::string name) -> PackedFunc {
auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
if (name == "Solve") {
return TypedPackedFunc<bool()>([solver]() {
return solver->Solve();
});
} else if (name == "Unify") {
return TypedPackedFunc<Type(Type, Type)>([solver, err_reporter](Type lhs, Type rhs) {
return TypedPackedFunc<Type(Type, Type)>(
[module, solver, err_reporter](Type lhs, Type rhs) {
auto res = solver->Unify(lhs, rhs, lhs);
if (err_reporter->AnyErrors()) {
err_reporter->RenderErrors(ModuleNode::make({}, {}), true);
err_reporter->RenderErrors(module, true);
}
return res;
});
......
......@@ -63,7 +63,7 @@ using common::LinkedList;
*/
class TypeSolver {
public:
TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter);
TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter);
~TypeSolver();
/*!
* \brief Add a type constraint to the solver.
......@@ -179,6 +179,8 @@ class TypeSolver {
GlobalVar current_func;
/*! \brief Error reporting. */
ErrorReporter* err_reporter_;
/*! \brief The module. */
Module module_;
/*!
* \brief GetTypeNode that is corresponds to t.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment