Unverified Commit c69092ae by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Unified IR IRModule structure. (#4699)

This PR brings relay::Module as the unified IRModule structure.
IRModule will be used as the basic unit for transformations
through out the stack.

- Rename relay::Module -> IRModule
- Move relay/module.h -> ir/module.h
- ModuleNode::FromExpr -> IRModule::FromExpr
- FromText -> IRModule::FromText
parent bd17baa2
......@@ -353,7 +353,7 @@ registration.
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
// Create a module for optimization.
auto mod = relay::ModuleNode::FromExpr(fx);
auto mod = IRModule::FromExpr(fx);
// Create a sequential pass.
tvm::Array<relay::transform::Pass> pass_seqs{
......
......@@ -18,55 +18,41 @@
*/
/*!
* \file tvm/relay/module.h
* \brief The global environment: contains information needed to
* compile & optimize Relay programs.
* \file tvm/ir/module.h
* \brief IRModule that holds the functions and type definitions.
*/
#ifndef TVM_RELAY_MODULE_H_
#define TVM_RELAY_MODULE_H_
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#ifndef TVM_IR_MODULE_H_
#define TVM_IR_MODULE_H_
#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/adt.h>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
namespace tvm {
namespace relay {
struct Module;
/*! \brief The global environment of Relay programs.
*
* The global environment contains the global
* information needed to compile a Relay program.
*
* It contains all global functions, and configuration
* options.
class IRModule;
/*!
* \brief IRModule that holds functions and type definitions.
*
* Many operations require access to the global
* Module. We pass the Module by value
* in a functional style as an explicit argument,
* but we mutate the Module while optimizing
* Relay programs.
* IRModule is the basic unit for all IR transformations across the stack.
*
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* a Module while auto-tuning.
* Many operations require access to the global IRModule.
* We pass the IRModule by value in a functional style as an explicit argument,
* but we mutate the Module while optimizing programs.
* \sa IRModule
*/
class ModuleNode : public RelayNode {
class IRModuleNode : public Object {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
ModuleNode() {}
IRModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("functions", &functions);
......@@ -75,10 +61,6 @@ class ModuleNode : public RelayNode {
v->Visit("global_type_var_map_", &global_type_var_map_);
}
TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});
/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
......@@ -219,7 +201,7 @@ class ModuleNode : public RelayNode {
* functions in another environment.
* \param other The other environment.
*/
TVM_DLL void Update(const Module& other);
TVM_DLL void Update(const IRModule& other);
/*!
* \brief Import Relay code from the file at path.
......@@ -243,24 +225,8 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL std::unordered_set<std::string> Imports() const;
/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private:
/*! \brief Helper function for registering a typedef's constructors */
......@@ -285,27 +251,62 @@ class ModuleNode : public RelayNode {
importing is idempotent for each module.
*/
std::unordered_set<std::string> import_set_;
friend class IRModule;
};
struct Module : public ObjectRef {
Module() {}
explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {}
ModuleNode* operator->() const {
return static_cast<ModuleNode*>(get_mutable());
/*!
* \brief Managed reference class to IRModuleNode.
* \sa IRModuleNode
*/
class IRModule : public ObjectRef {
public:
/*!
* \brief constructor
* \param functions Functions in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
*/
TVM_DLL explicit IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
/*! \brief default constructor */
IRModule() {}
/*!
* \brief constructor
* \param n The object pointer.
*/
explicit IRModule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
IRModuleNode* operator->() const {
auto* ptr = get_mutable();
CHECK(ptr != nullptr);
return static_cast<IRModuleNode*>(ptr);
}
/*!
* \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
using ContainerType = ModuleNode;
/*!
* \brief Parse text format source file into an IRModule.
* \param text A string of Relay source code.
* \param source_path The path to the source file.
* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
};
/*! \brief Parse Relay source into a module.
* \param source A string of Relay source code.
* \param source_name The name of the source file.
* \return A Relay module.
*/
Module FromText(const std::string& source, const std::string& source_name);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_MODULE_H_
#endif // TVM_IR_MODULE_H_
......@@ -30,9 +30,8 @@
namespace tvm {
// TODO(tqchen): remove after migrate Module to ir.
namespace relay {
struct Module;
}
class IRModule;
/*!
* \brief reporter that reports back to the
......@@ -76,7 +75,7 @@ class TypeReporterNode : public Object {
* \brief Retrieve the current global module.
* \return The global module.
*/
TVM_DLL virtual relay::Module GetModule() = 0;
TVM_DLL virtual IRModule GetModule() = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) {}
......
......@@ -71,7 +71,7 @@ Stmt CanonicalSimplify(Stmt stmt,
* \return Canonicalized expression.
*/
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());
Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Deep compare lhs and rhs
......
......@@ -26,7 +26,7 @@
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
......@@ -49,7 +49,7 @@ namespace relay {
*
* \return The kind of the passed type.
*/
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
/*!
* \brief Check whether an expression is constant.
......@@ -188,7 +188,7 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
*
* \return List of free vars, in the PostDFS order visited by expr.
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod);
/*!
* \brief Get free TypeVars from type t.
......@@ -201,7 +201,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const IRModule& mod);
/*!
* \brief Get all bound type variables from expression expr.
......@@ -214,7 +214,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod);
/*!
* \brief Get all bound type variables from type t.
......@@ -227,7 +227,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const IRModule& mod);
/*!
* \brief Get all type variables in expression expr.
......@@ -237,7 +237,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
*
* \return List of type vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
/*!
* \brief Get all type variables in type t.
......@@ -247,7 +247,7 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
*
* \return List of type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);
/*!
* \brief Collect the device mapping information of each expression.
......@@ -277,7 +277,7 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
* \return Returns a list of cases (as patterns) that are not handled by the match
* expression.
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
......
......@@ -106,9 +106,6 @@ class Id : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};
struct Module;
} // namespace relay
} // namespace tvm
......
......@@ -24,13 +24,16 @@
#ifndef TVM_RELAY_ERROR_H_
#define TVM_RELAY_ERROR_H_
#include <tvm/ir/module.h>
#include <string>
#include <vector>
#include <sstream>
#include <unordered_map>
#include "./base.h"
#include "./expr.h"
#include "./module.h"
namespace tvm {
namespace relay {
......@@ -146,7 +149,7 @@ class ErrorReporter {
* \param module The module to report errors on.
* \param use_color Controls whether to colorize the output.
*/
void RenderErrors(const Module& module, bool use_color = true);
void RenderErrors(const IRModule& module, bool use_color = true);
inline bool AnyErrors() {
return errors_.size() != 0;
......
......@@ -26,6 +26,8 @@
#include <tvm/node/container.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
#include <bitset>
namespace tvm {
......@@ -141,7 +143,6 @@ class FeatureSet {
*/
FeatureSet DetectFeature(const RelayExpr& expr);
struct Module;
/*!
* \brief Calculate the feature of the program.
*
......@@ -149,7 +150,7 @@ struct Module;
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Module& mod);
FeatureSet DetectFeature(const IRModule& mod);
/*!
* \brief Calculate the feature of the program.
......@@ -159,7 +160,7 @@ FeatureSet DetectFeature(const Module& mod);
*
* \return The FeatureSet.
*/
inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) {
inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}
......
......@@ -35,7 +35,7 @@
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
......@@ -62,7 +62,7 @@ namespace relay {
* \return A function that takes in an expression and returns a value.
*/
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);
CreateInterpreter(IRModule mod, DLContext context, Target target);
/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;
......
......@@ -61,7 +61,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
......@@ -236,7 +236,7 @@ class PassNode : public RelayNode {
*
* \return The transformed module.
*/
Module operator()(const Module& mod) const {
IRModule operator()(const IRModule& mod) const {
return this->operator()(mod, PassContext::Current());
}
......@@ -248,8 +248,8 @@ class PassNode : public RelayNode {
*
* \return The transformed module.
*/
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;
virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) {}
......@@ -266,7 +266,7 @@ class Pass : public ObjectRef {
*
* \return The transformed module.
*/
Module operator()(const Module& mod) const {
IRModule operator()(const IRModule& mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
......@@ -279,8 +279,8 @@ class Pass : public ObjectRef {
*
* \return The transformed module.
*/
Module operator()(const Module& mod,
const PassContext& pass_ctx) const {
IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
......@@ -329,7 +329,7 @@ class Sequential : public Pass {
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
......@@ -345,7 +345,7 @@ Pass CreateModulePass(
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, Module, PassContext)>& pass_func,
Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
......@@ -624,7 +624,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
* \note this function mutates mod and is not thread-safe.
*/
TVM_DLL Function InferType(const Function& f,
const Module& mod,
const IRModule& mod,
const GlobalVar& var);
/*!
......@@ -689,7 +689,7 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
*
* \return the converted Function.
*/
TVM_DLL Function ToCPS(const Function& f, const Module& mod);
TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
/*!
* \brief Remove the continuation argument of a CPS function.
......
......@@ -21,29 +21,33 @@
* \file module.cc
* \brief The global module in Relay.
*/
#include <tvm/relay/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
// NOTE on dependencies on relay analysis.
// We calls into relay's analysis module to verify correctness
// when a relay function is presented.
// These dependency does not happen at the interface-level.
// And is only used to enhance developer experiences when relay
// functions are presented.
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <sstream>
#include <fstream>
#include <unordered_set>
namespace tvm {
namespace relay {
using tvm::NodePrinter;
using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports) {
auto n = make_object<ModuleNode>();
n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs);
IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<std::string> import_set) {
auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->type_definitions = std::move(type_definitions);
n->global_type_var_map_ = {};
n->global_var_map_ = {};
n->constructor_tag_map_ = {};
n->import_set_ = imports;
n->import_set_ = std::move(import_set);
for (const auto& kv : n->functions) {
// set global var map
......@@ -59,26 +63,25 @@ Module ModuleNode::make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second);
}
return Module(n);
data_ = std::move(n);
}
bool ModuleNode::ContainGlobalVar(const std::string& name) const {
bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const {
bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
return global_type_var_map_.find(name) != global_type_var_map_.end();
}
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Module";
return (*it).second;
}
tvm::Array<GlobalVar> ModuleNode::GetGlobalVars() const {
tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
std::vector<GlobalVar> global_vars;
for (const auto& pair : global_var_map_) {
global_vars.push_back(pair.second);
......@@ -86,7 +89,7 @@ tvm::Array<GlobalVar> ModuleNode::GetGlobalVars() const {
return tvm::Array<GlobalVar>(global_vars);
}
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
......@@ -94,7 +97,7 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}
tvm::Array<GlobalTypeVar> ModuleNode::GetGlobalTypeVars() const {
tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
std::vector<GlobalTypeVar> global_type_vars;
for (const auto& pair : global_type_var_map_) {
global_type_vars.push_back(pair.second);
......@@ -112,7 +115,7 @@ tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
}
// helper function to run type check
relay::Function RunTypeCheck(const Module& mod,
relay::Function RunTypeCheck(const IRModule& mod,
const GlobalVar& var,
relay::Function f) {
auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
......@@ -146,12 +149,12 @@ relay::Function RunTypeCheck(const Module& mod,
return checked_func;
}
void ModuleNode::Add(const GlobalVar& var,
const BaseFunc& f,
bool update) {
void IRModuleNode::Add(const GlobalVar& var,
const BaseFunc& f,
bool update) {
BaseFunc checked_func = f;
if (auto* ptr = f.as<relay::FunctionNode>()) {
checked_func = RunTypeCheck(GetRef<Module>(this),
checked_func = RunTypeCheck(GetRef<IRModule>(this),
var,
GetRef<relay::Function>(ptr));
}
......@@ -162,16 +165,16 @@ void ModuleNode::Add(const GlobalVar& var,
if (functions.find(var) != functions.end()) {
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var].as<FunctionNode>()->checked_type();
CHECK(AlphaEqual(type, old_type))
auto old_type = functions[var].as<relay::FunctionNode>()->checked_type();
CHECK(relay::AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
var->checked_type_ = type;
AddUnchecked(var, checked_func);
}
void ModuleNode::AddUnchecked(const GlobalVar& var,
const BaseFunc& func) {
void IRModuleNode::AddUnchecked(const GlobalVar& var,
const BaseFunc& func) {
this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint);
......@@ -185,7 +188,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
global_var_map_.Set(var->name_hint, var);
}
void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
// We hash the global type var name to use as a globally unique prefix for tags.
// The hash will be used as the most significant byte of the tag, with the index of
// the constructor in the less significant bytes
......@@ -197,19 +200,19 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
}
}
void ModuleNode::AddTypeDef(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
void IRModuleNode::AddTypeDef(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up
// a definition potentially
CHECK(relay::KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
CHECK(relay::KindCheck(type, GetRef<IRModule>(this)) == TypeKind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}
void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
const TypeData& type,
bool update) {
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
......@@ -220,55 +223,55 @@ void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
RegisterConstructors(var, type);
}
void ModuleNode::Update(const GlobalVar& var,
const BaseFunc& func) {
void IRModuleNode::Update(const GlobalVar& var,
const BaseFunc& func) {
this->Add(var, func, true);
}
void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
const TypeData& type) {
void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
const TypeData& type) {
this->AddTypeDef(var, type, true);
}
void ModuleNode::Remove(const GlobalVar& var) {
void IRModuleNode::Remove(const GlobalVar& var) {
auto functions_node = this->functions.CopyOnWrite();
functions_node->data.erase(var);
auto gvar_node = global_var_map_.CopyOnWrite();
gvar_node->data.erase(var->name_hint);
}
BaseFunc ModuleNode::Lookup(const GlobalVar& var) const {
BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
BaseFunc ModuleNode::Lookup(const std::string& name) const {
BaseFunc IRModuleNode::Lookup(const std::string& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
TypeData ModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
TypeData ModuleNode::LookupTypeDef(const std::string& name) const {
TypeData IRModuleNode::LookupTypeDef(const std::string& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupTypeDef(id);
}
Constructor ModuleNode::LookupTag(const int32_t tag) {
Constructor IRModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end())
<< "There is no constructor with the tag " << tag;
return (*it).second;
}
void ModuleNode::Update(const Module& mod) {
void IRModuleNode::Update(const IRModule& mod) {
// add functions and type defs. we add them unchecked first, so all definitions
// can reference each other, independent of the order in which they were defined.
for (auto pair : mod->functions) {
......@@ -285,11 +288,11 @@ void ModuleNode::Update(const Module& mod) {
}
}
Module ModuleNode::FromExpr(
IRModule IRModule::FromExpr(
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = ModuleNode::make(global_funcs, type_definitions);
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
if (auto* func_node = expr.as<relay::FunctionNode>()) {
func = GetRef<relay::Function>(func_node);
......@@ -303,7 +306,7 @@ Module ModuleNode::FromExpr(
return mod;
}
void ModuleNode::Import(const std::string& path) {
void IRModuleNode::Import(const std::string& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
......@@ -311,102 +314,102 @@ void ModuleNode::Import(const std::string& path) {
std::string file_contents {
std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>() };
auto mod_to_import = FromText(file_contents, path);
auto mod_to_import = IRModule::FromText(file_contents, path);
Update(mod_to_import);
}
}
void ModuleNode::ImportFromStd(const std::string& path) {
void IRModuleNode::ImportFromStd(const std::string& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
return this->Import(std_path + "/" + path);
}
std::unordered_set<std::string> ModuleNode::Imports() const {
std::unordered_set<std::string> IRModuleNode::Imports() const {
return this->import_set_;
}
Module FromText(const std::string& source, const std::string& source_name) {
IRModule IRModule::FromText(const std::string& text, const std::string& source_path) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
Module mod = (*f)(source, source_name);
IRModule mod = (*f)(text, source_path);
return mod;
}
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_NODE_TYPE(IRModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module")
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {});
return IRModule(funcs, types, {});
});
TVM_REGISTER_GLOBAL("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Module mod = args[0];
IRModule mod = args[0];
GlobalVar var = args[1];
ObjectRef val = args[2];
bool update = args[3];
CHECK(val->IsInstance<ExprNode>());
CHECK(val->IsInstance<RelayExprNode>());
if (val->IsInstance<relay::FunctionNode>()) {
mod->Add(var, Downcast<relay::Function>(val), update);
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_object<ModuleNode>(*mod.operator->()));
auto mod_copy = IRModule(make_object<IRModuleNode>(*mod.operator->()));
mod_copy = relay::transform::EtaExpand(
/* expand_constructor */ false,
/* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<relay::Function>(func), update);
} else {
auto func = FunctionNode::make({}, Downcast<relay::Expr>(val), Type(nullptr), {});
auto func = relay::FunctionNode::make({}, Downcast<RelayExpr>(val), Type(nullptr), {});
mod->Add(var, func, update);
}
*ret = mod;
});
TVM_REGISTER_GLOBAL("relay._module.Module_AddDef")
.set_body_method<Module>(&ModuleNode::AddTypeDef);
.set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar);
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVars")
.set_body_method<Module>(&ModuleNode::GetGlobalVars);
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVars")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVars);
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
TVM_REGISTER_GLOBAL("relay._module.Module_ContainGlobalVar")
.set_body_method<Module>(&ModuleNode::ContainGlobalVar);
.set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup")
.set_body_typed([](Module mod, GlobalVar var) {
.set_body_typed([](IRModule mod, GlobalVar var) {
return mod->Lookup(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str")
.set_body_typed([](Module mod, std::string var) {
.set_body_typed([](IRModule mod, std::string var) {
return mod->Lookup(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef")
.set_body_typed([](Module mod, GlobalTypeVar var) {
.set_body_typed([](IRModule mod, GlobalTypeVar var) {
return mod->LookupTypeDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str")
.set_body_typed([](Module mod, std::string var) {
.set_body_typed([](IRModule mod, std::string var) {
return mod->LookupTypeDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
.set_body_typed([](Module mod, int32_t tag) {
.set_body_typed([](IRModule mod, int32_t tag) {
return mod->LookupTag(tag);
});
......@@ -414,29 +417,28 @@ TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
.set_body_typed([](RelayExpr e,
tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
return IRModule::FromExpr(e, funcs, type_defs);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Update")
.set_body_typed([](Module mod, Module from) {
.set_body_typed([](IRModule mod, IRModule from) {
mod->Update(from);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Import")
.set_body_typed([](Module mod, std::string path) {
.set_body_typed([](IRModule mod, std::string path) {
mod->Import(path);
});
TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
.set_body_typed([](Module mod, std::string path) {
.set_body_typed([](IRModule mod, std::string path) {
mod->ImportFromStd(path);
});;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ModuleNode*>(ref.get());
p->stream << "ModuleNode( " << node->functions << ")";
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const IRModuleNode*>(ref.get());
p->stream << "IRModuleNode( " << node->functions << ")";
});
} // namespace relay
} // namespace tvm
......@@ -294,7 +294,7 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \return relay::Module The updated Relay module after optimization.
*/
relay::Module Optimize(
IRModule Optimize(
Function func,
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
......@@ -303,7 +303,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}
// Perform Module->Module optimizations.
relay::Module relay_module = relay::ModuleNode::FromExpr(func);
IRModule relay_module = IRModule::FromExpr(func);
Array<Pass> pass_seqs;
......@@ -408,8 +408,8 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \return updated_module The updated module after device annotation.
*/
relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module,
int fallback_device) {
IRModule RunDeviceAnnotationPass(const IRModule& relay_module,
int fallback_device) {
UpdateHeterogeneousInputs(fallback_device);
auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
auto updated_module = rewrite(relay_module);
......@@ -461,7 +461,7 @@ class RelayBuildModule : public runtime::ModuleNode {
Function func,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
// Optimize input Relay Function and returns Relay Module
relay::Module relay_module = Optimize(func, targets_, params);
IRModule relay_module = Optimize(func, targets_, params);
// Get the updated function.
func = Downcast<Function>(relay_module->Lookup("main"));
......
......@@ -613,7 +613,7 @@ class CompileEngineImpl : public CompileEngineNode {
}
Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, relay::Module> ext_mods;
std::unordered_map<std::string, IRModule> ext_mods;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
......@@ -623,7 +623,7 @@ class CompileEngineImpl : public CompileEngineNode {
const tvm::ir::StringImmNode* code_gen = compiler.as<tvm::ir::StringImmNode>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
ext_mods[code_gen->value] = IRModule({}, {});
}
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>();
......
......@@ -186,8 +186,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
if (ref->IsInstance<FunctionNode>()) {
GenCFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<relay::ModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref);
} else if (ref->IsInstance<IRModuleNode>()) {
IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) {
GenCFunc(Downcast<Function>(it.second));
}
......
......@@ -24,7 +24,7 @@
#include <dmlc/any.h>
#include <dmlc/json.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
......
......@@ -233,7 +233,7 @@ class Interpreter :
public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
Interpreter(Module mod, DLContext context, Target target)
Interpreter(IRModule mod, DLContext context, Target target)
: mod_(mod),
context_(context),
target_(target),
......@@ -761,7 +761,7 @@ class Interpreter :
private:
// Module
Module mod_;
IRModule mod_;
// For simplicity we only run the interpreter on a single context.
// Context to run the interpreter on.
DLContext context_;
......@@ -779,7 +779,7 @@ class Interpreter :
TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(
Module mod,
IRModule mod,
DLContext context,
Target target) {
if (mod.defined()) {
......
......@@ -752,7 +752,7 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
if (name == "lower") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
Module mod = args[0];
IRModule mod = args[0];
this->Lower(mod, args[1], args[2]);
});
} else if (name == "codegen") {
......@@ -813,7 +813,7 @@ relay::Function VMCompiler::BindParamsByName(
return ret;
}
void VMCompiler::Lower(Module mod,
void VMCompiler::Lower(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
......@@ -884,7 +884,7 @@ void VMCompiler::Lower(Module mod,
}
}
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
......
......@@ -62,7 +62,7 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
struct VMCompilerContext {
// The module context for the compilation
Module module;
IRModule module;
// Error reporter
ErrorReporter err_reporter;
// Map from a unique integer to ADT constructor tag
......@@ -107,7 +107,7 @@ class VMCompiler : public runtime::ModuleNode {
to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device.
*/
void Lower(Module mod,
void Lower(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host);
......@@ -125,7 +125,7 @@ class VMCompiler : public runtime::ModuleNode {
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);
Module OptimizeModule(const Module& mod, const TargetsMap& targets);
IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);
void PopulateGlobalMap();
......
......@@ -52,10 +52,10 @@ namespace vm {
* (fn(...) { ... })(...)
*/
struct PrimitiveInliner : ExprMutator {
Module module_;
IRModule module_;
std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_map;
explicit PrimitiveInliner(const Module& module) : module_(module) {}
explicit PrimitiveInliner(const IRModule& module) : module_(module) {}
Expr VisitExpr_(const LetNode* let_node) {
var_map.insert({let_node->var, VisitExpr(let_node->value)});
......@@ -106,7 +106,7 @@ struct PrimitiveInliner : ExprMutator {
}
}
Module Inline() {
IRModule Inline() {
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
......@@ -137,8 +137,8 @@ struct PrimitiveInliner : ExprMutator {
namespace transform {
Pass InlinePrimitives() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::vm::PrimitiveInliner(m).Inline();
};
auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {});
......
......@@ -60,7 +60,7 @@ Function MarkClosure(const Function& func) {
*/
class LambdaLifter : public ExprMutator {
public:
explicit LambdaLifter(const Module& module) : module_(module) {}
explicit LambdaLifter(const IRModule& module) : module_(module) {}
Expr VisitExpr_(const LetNode* let_node) final {
bool is_lambda = false;
......@@ -184,7 +184,7 @@ class LambdaLifter : public ExprMutator {
}
}
Module Lift() {
IRModule Lift() {
// There is an ordering bug here.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
......@@ -204,7 +204,7 @@ class LambdaLifter : public ExprMutator {
private:
std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> lambda_map_;
std::vector<Var> letrec_;
Module module_;
IRModule module_;
};
} // namespace vm
......@@ -212,8 +212,8 @@ class LambdaLifter : public ExprMutator {
namespace transform {
Pass LambdaLift() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::vm::LambdaLifter(m).Lift();
};
return CreateModulePass(pass_func, 1, "LambdaLift", {});
......
......@@ -40,7 +40,7 @@ namespace vm {
* \brief Detects all the functions that can be possibly called by entry function.
*/
struct CallTracer : ExprVisitor {
Module module_;
IRModule module_;
// Record the names of all encountered functions
std::unordered_set<std::string> called_funcs_;
......@@ -48,7 +48,7 @@ struct CallTracer : ExprVisitor {
// Record the expressions that are being visited
std::unordered_set<Expr, ObjectHash, ObjectEqual> visiting_;
explicit CallTracer(const Module& module)
explicit CallTracer(const IRModule& module)
: module_{module},
called_funcs_{},
visiting_{} {}
......@@ -99,7 +99,7 @@ struct CallTracer : ExprVisitor {
*
* \return The module with dead functions removed.
*/
Module RemoveUnusedFunctions(const Module& module,
IRModule RemoveUnusedFunctions(const IRModule& module,
Array<tvm::PrimExpr> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
......@@ -122,8 +122,8 @@ Module RemoveUnusedFunctions(const Module& module,
namespace transform {
Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
};
return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
......
......@@ -60,8 +60,8 @@ class AlphaEqualHandler:
if (!rhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<ModuleNode>()) {
auto rhsm = rhs.as<ModuleNode>();
if (const auto lhsm = lhs.as<IRModuleNode>()) {
auto rhsm = rhs.as<IRModuleNode>();
if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) {
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/error.h>
#include <string>
#include <vector>
......@@ -39,7 +39,7 @@ void RelayErrorStream::Raise() const {
template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, ObjectHash, ObjectEqual>;
void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
// First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) {
......
......@@ -33,7 +33,7 @@
#include <tvm/node/serialization.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "type_functor.h"
......@@ -242,8 +242,8 @@ class PrettyPrinter :
return PrintType(Downcast<Type>(node), meta);
} else if (node.as<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as<ModuleNode>()) {
return PrintMod(Downcast<Module>(node));
} else if (node.as<IRModuleNode>()) {
return PrintMod(Downcast<IRModule>(node));
} else {
Doc doc;
return doc << node;
......@@ -525,7 +525,7 @@ class PrettyPrinter :
}
}
Doc PrintMod(const Module& mod) {
Doc PrintMod(const IRModule& mod) {
Doc doc;
int counter = 0;
// type definitions
......
......@@ -24,6 +24,7 @@
#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#include <tvm/relay/error.h>
#include <vector>
#include <algorithm>
#include <limits>
......
......@@ -118,8 +118,8 @@ Expr AlterOpLayout(const Expr& expr) {
namespace transform {
Pass AlterOpLayout() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
};
return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
......
......@@ -129,8 +129,8 @@ Expr CanonicalizeCast(const Expr& e) {
namespace transform {
Pass CanonicalizeCast() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
......
......@@ -69,8 +69,8 @@ Expr CanonicalizeOps(const Expr& e) {
namespace transform {
Pass CanonicalizeOps() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeOps(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
......
......@@ -216,8 +216,8 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
namespace transform {
Pass CombineParallelConv2D(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
......
......@@ -76,8 +76,8 @@ Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
namespace transform {
Pass CombineParallelDense(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
......
......@@ -186,8 +186,8 @@ namespace transform {
Pass CombineParallelOpBatch(const std::string& op_name,
const std::string& batch_op_name,
uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelOpBatch(f,
op_name,
batch_op_name,
......
......@@ -128,8 +128,8 @@ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) {
namespace transform {
Pass ConvertLayout(const std::string& desired_layout) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
return CreateFunctionPass(
......
......@@ -140,8 +140,8 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) {
namespace transform {
Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f, inline_once));
};
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
......
......@@ -572,8 +572,8 @@ TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps")
namespace transform {
Pass RewriteAnnotatedOps(int fallback_device) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
};
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
......
......@@ -87,8 +87,8 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
namespace transform {
Pass EliminateCommonSubexpr(PackedFunc fskip) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
......
......@@ -57,7 +57,7 @@ class TypeVarReplacer : public TypeMutator {
*/
class EtaExpander : public ExprMutator {
public:
explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var)
explicit EtaExpander(const IRModule& mod, bool expand_constructor, bool expand_global_var)
: mod_(mod),
type_var_replacer_(TypeVarReplacer()),
expand_constructor_(expand_constructor),
......@@ -66,7 +66,7 @@ class EtaExpander : public ExprMutator {
<< "must expand at least one language feature";
}
Module Expand() {
IRModule Expand() {
for (GlobalVar global_var : mod_->GetGlobalVars()) {
const BaseFunc base_func = mod_->Lookup(global_var);
if (auto* n = base_func.as<FunctionNode>()) {
......@@ -147,7 +147,7 @@ class EtaExpander : public ExprMutator {
private:
/*! \brief reference to module being expanded */
const Module mod_;
const IRModule mod_;
/*! \brief type variable replacer */
TypeVarReplacer type_var_replacer_;
/*! \brief whether to expand constructor nodes */
......@@ -161,8 +161,8 @@ class EtaExpander : public ExprMutator {
namespace transform {
Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module mod, PassContext pc) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) {
return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand();
};
return CreateModulePass(pass_func, 1, "EtaExpand", {});
......
......@@ -25,7 +25,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include "pass_util.h"
namespace tvm {
......@@ -89,7 +89,7 @@ FeatureSet DetectFeature(const Expr& expr) {
return fd.fs;
}
FeatureSet DetectFeature(const Module& mod) {
FeatureSet DetectFeature(const IRModule& mod) {
FeatureSet fs = FeatureSet::No();
if (mod.defined()) {
for (const auto& f : mod->functions) {
......@@ -99,7 +99,7 @@ FeatureSet DetectFeature(const Module& mod) {
return fs;
}
Array<Integer> PyDetectFeature(const Expr& expr, const Module& mod) {
Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) {
FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
return static_cast<Array<Integer>>(fs);
}
......
......@@ -79,7 +79,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.check_constant")
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
public:
explicit ConstantFolder(FInterpreter executor, Module module)
explicit ConstantFolder(FInterpreter executor, IRModule module)
: executor_(executor),
module_(module),
shape_of_op_(Op::Get("shape_of")),
......@@ -168,7 +168,7 @@ class ConstantFolder : public ExprMutator {
// Internal constant checker
ConstantChecker checker_;
// Module
Module module_;
IRModule module_;
// Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_;
......@@ -209,7 +209,7 @@ class ConstantFolder : public ExprMutator {
// TODO(@jroesch): fix this
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
}
auto mod = ModuleNode::make(
auto mod = IRModule(
{},
module_->type_definitions,
module_->Imports());
......@@ -277,7 +277,7 @@ class ConstantFolder : public ExprMutator {
};
Expr FoldConstant(const Expr& expr, const Module& mod) {
Expr FoldConstant(const Expr& expr, const IRModule& mod) {
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
......@@ -292,8 +292,8 @@ Expr FoldConstant(const Expr& expr, const Module& mod) {
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f, m));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
......
......@@ -949,8 +949,8 @@ Expr BackwardFoldScaleAxis(const Expr& data) {
namespace transform {
Pass ForwardFoldScaleAxis() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
relay::fold_scale_axis::ForwardFoldScaleAxis(f));
};
......@@ -962,8 +962,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
.set_body_typed(ForwardFoldScaleAxis);
Pass BackwardFoldScaleAxis() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
relay::fold_scale_axis::BackwardFoldScaleAxis(f));
};
......
......@@ -970,15 +970,15 @@ class FuseMutator : private ExprMutator {
}
};
Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) {
return FuseMutator().Transform(expr, fuse_opt_level);
}
namespace transform {
Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
......
......@@ -67,7 +67,7 @@ Type WithGradientType(const Type&);
/*! return an expression that represent differentiation of e (according to WithGradientType).
* This version only work on first order code without control flow.
*/
Expr FirstOrderGradient(const Expr& e, const Module& mod);
Expr FirstOrderGradient(const Expr& e, const IRModule& mod);
Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking
......@@ -80,7 +80,7 @@ Type WithGradientType(const Type& t) {
}
//! \brief if the expression is a GlobalVar, transform to it's expression.
Expr DeGlobal(const Module& mod, const Expr& e) {
Expr DeGlobal(const IRModule& mod, const Expr& e) {
if (const auto* x = e.as<GlobalVarNode>()) {
BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
......@@ -222,7 +222,7 @@ Type GradRetType(const Function& f) {
return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}
Expr FirstOrderGradient(const Expr& re, const Module& mod) {
Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
// Currently we first remove any global functions for the first
// order case.
auto e = DeGlobal(mod, re);
......@@ -532,7 +532,7 @@ bool MissingGrad(const Expr& e) {
return false;
}
Expr Gradient(const Expr& re, const Module& mod) {
Expr Gradient(const Expr& re, const IRModule& mod) {
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
......
......@@ -41,10 +41,10 @@ namespace relay {
using namespace tvm::runtime;
struct KindChecker : TypeFunctor<Kind(const Type&)> {
const Module& mod;
const IRModule& mod;
ErrorReporter err_reporter;
explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {}
explicit KindChecker(const IRModule& mod) : mod(mod), err_reporter() {}
void ReportFatalError(const Error& err) {
this->err_reporter.Report(err);
......@@ -177,7 +177,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
}
};
Kind KindCheck(const Type& t, const Module& mod) {
Kind KindCheck(const Type& t, const IRModule& mod) {
KindChecker kc(mod);
return kc.Check(t);
}
......@@ -185,7 +185,7 @@ Kind KindCheck(const Type& t, const Module& mod) {
TVM_REGISTER_GLOBAL("relay._analysis.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = KindCheck(args[0], ModuleNode::make({}, {}));
*ret = KindCheck(args[0], IRModule({}, {}));
} else {
*ret = KindCheck(args[0], args[1]);
}
......
......@@ -98,8 +98,8 @@ Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
namespace transform {
Pass Legalize(const std::string& legalize_map_attr_name) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImmNode::make("InferType")});
......
......@@ -155,17 +155,17 @@ Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
const Pattern& cand,
const Module& mod);
const IRModule& mod);
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand,
const Module& mod);
const IRModule& mod);
// Expands all wildcards in the candidate pattern once
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
const Pattern& cand,
const Module& mod) {
const IRModule& mod) {
if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
} else {
......@@ -178,7 +178,7 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
const Pattern& cand,
const Module& mod) {
const IRModule& mod) {
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
// for a wildcard node, create constructor nodes with wildcards for all args.
......@@ -228,7 +228,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
// Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand,
const Module& mod) {
const IRModule& mod) {
// for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) {
Array<Pattern> args;
......@@ -271,7 +271,7 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
* \return Returns a list of cases that are not handled by the match
* expression.
*/
Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
/* algorithm:
* candidates = { Wildcard }
* while candidates not empty {
......@@ -328,10 +328,10 @@ Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
// expose for testing only
TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases")
.set_body_typed(
[](const Match& match, const Module& mod_ref) {
Module call_mod = mod_ref;
[](const Match& match, const IRModule& mod_ref) {
IRModule call_mod = mod_ref;
if (!call_mod.defined()) {
call_mod = ModuleNode::make({}, {});
call_mod = IRModule({}, {});
}
return UnmatchedCases(match, call_mod);
});
......
......@@ -569,7 +569,7 @@ FInterpreter CPUInterpreter() {
// in case we are already in a build context.
With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return CreateInterpreter(Module(nullptr), CPUContext(), target);
return CreateInterpreter(IRModule(nullptr), CPUContext(), target);
}
using FuncId = int;
......@@ -623,7 +623,7 @@ Function AsFunc(const Expr& e) {
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public:
PartialEvaluator(const Module& mod) : mod_(mod) { }
PartialEvaluator(const IRModule& mod) : mod_(mod) { }
PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
......@@ -954,7 +954,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
std::vector<transform::Pass> passes = {transform::FuseOps(0),
transform::InferType()};
auto mod = ModuleNode::FromExpr(expr);
auto mod = IRModule::FromExpr(expr);
auto seq = transform::Sequential(passes);
mod = seq(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main"));
......@@ -1184,7 +1184,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
private:
Environment env_;
Module mod_;
IRModule mod_;
std::unordered_map<GlobalVar, PStatic, ObjectHash, ObjectEqual> gv_map_;
/*! Termination checking is done as follows:
* We have finitely many FunctionIds.
......@@ -1255,7 +1255,7 @@ Expr PostProcess(const Expr& e) {
} // namespace partial_eval
Module PartialEval(const Module& m) {
IRModule PartialEval(const IRModule& m) {
relay::partial_eval::PartialEvaluator pe(m);
std::vector<GlobalVar> gvs;
for (const auto& p : m->functions) {
......@@ -1270,9 +1270,9 @@ Module PartialEval(const Module& m) {
namespace transform {
Pass PartialEval() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return PartialEval(m);
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::PartialEval(m);
};
return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
}
......
......@@ -98,7 +98,7 @@ class ModulePassNode : public PassNode {
* implement the algorithm in the `pass_func` and let it run on a module. It
* will then remove the dead code including the unused functions in the module.
*/
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func;
ModulePassNode() = default;
......@@ -114,7 +114,7 @@ class ModulePassNode : public PassNode {
*
* \return Return the updated module.
*/
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
......@@ -122,7 +122,7 @@ class ModulePassNode : public PassNode {
PassInfo Info() const override { return pass_info; }
TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.ModulePass";
......@@ -155,7 +155,7 @@ class FunctionPassNode : public PassNode {
* `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module.
*/
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;
FunctionPassNode() = default;
......@@ -171,7 +171,7 @@ class FunctionPassNode : public PassNode {
*
* \return Return the updated module.
*/
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
......@@ -179,7 +179,7 @@ class FunctionPassNode : public PassNode {
PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.FunctionPass";
......@@ -248,7 +248,7 @@ class SequentialNode : public PassNode {
* metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
* PassInfo, to store the relevant information including the parent passes.
*/
void ResolveDependency(const Module& mod);
void ResolveDependency(const IRModule& mod);
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
......@@ -261,7 +261,7 @@ class SequentialNode : public PassNode {
*
* \return Return the updated module.
*/
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
......@@ -278,7 +278,7 @@ PassInfo PassInfoNode::make(int opt_level,
}
ModulePass ModulePassNode::make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<ModulePassNode>();
n->pass_func = std::move(pass_func);
......@@ -287,7 +287,7 @@ ModulePass ModulePassNode::make(
}
// Module -> Module optimizations.
Module ModulePassNode::operator()(const Module& mod,
IRModule ModulePassNode::operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : "
......@@ -295,13 +295,13 @@ Module ModulePassNode::operator()(const Module& mod,
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined());
Module updated_mod = pass_func(mod, pass_ctx);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
return updated_mod;
}
FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<FunctionPassNode>();
n->pass_func = std::move(pass_func);
......@@ -310,7 +310,7 @@ FunctionPass FunctionPassNode::make(
}
// Perform Module -> Module optimizations at the Function level.
Module FunctionPassNode::operator()(const Module& mod,
IRModule FunctionPassNode::operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
......@@ -320,7 +320,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->opt_level;
// Execute the pass function and return a new module.
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
// only picks up relay::Function
......@@ -364,7 +364,7 @@ const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(get());
}
void SequentialNode::ResolveDependency(const Module& mod) {
void SequentialNode::ResolveDependency(const IRModule& mod) {
// TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass.
// 2. Only resolve the enabled passes.
......@@ -410,9 +410,9 @@ Pass GetPass(const std::string& pass_name) {
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
Module SequentialNode::operator()(const Module& module,
IRModule SequentialNode::operator()(const IRModule& module,
const PassContext& pass_ctx) const {
Module mod = module;
IRModule mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
......@@ -429,7 +429,7 @@ Module SequentialNode::operator()(const Module& module,
}
Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
......@@ -438,7 +438,7 @@ Pass CreateModulePass(
}
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
......@@ -479,7 +479,7 @@ TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass")
TVM_REGISTER_GLOBAL("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
Module mod = args[1];
IRModule mod = args[1];
*ret = pass(mod);
});
......
......@@ -32,8 +32,8 @@ namespace relay {
namespace transform {
Pass PrintIR(bool show_meta_data) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data);
return m;
};
......
......@@ -92,8 +92,8 @@ Pass QuantizeAnnotate() {
return e;
};
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
......
......@@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr")
});
Pass QuantizePartition() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
auto ret = Downcast<Function>(
ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
return ret;
......
......@@ -190,7 +190,7 @@ Expr QuantizeRealize(const Call& ref_call,
}
Expr FoldConstantOpt(const Expr& expr) {
auto mod = ModuleNode::FromExpr(expr);
auto mod = IRModule::FromExpr(expr);
mod = transform::FoldConstant()(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main"));
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
......@@ -522,8 +522,8 @@ RELAY_REGISTER_OP("annotation.cast_hint")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
Pass QuantizeRealizePass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
};
......
......@@ -183,8 +183,8 @@ Expr SimplifyInference(const Expr& e) {
namespace transform {
Pass SimplifyInference() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyInference(f));
};
return CreateFunctionPass(pass_func, 0, "SimplifyInference",
......
......@@ -291,7 +291,7 @@ Expr ToANormalFormAux(const Expr& e) {
return Fill::ToANormalForm(e, dg, &node_scope);
}
Module ToANormalForm(const Module& m) {
IRModule ToANormalForm(const IRModule& m) {
DLOG(INFO) << "ToANF:" << std::endl << m;
tvm::Map<GlobalVar, Function> updates;
......@@ -321,9 +321,9 @@ Module ToANormalForm(const Module& m) {
namespace transform {
Pass ToANormalForm() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return ToANormalForm(m);
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::ToANormalForm(m);
};
return CreateModulePass(pass_func, 1, "ToANormalForm", {});
}
......
......@@ -111,21 +111,27 @@ using VarMap = std::unordered_map<Var, Var, ObjectHash, ObjectEqual>;
*/
using MCont = std::function<Expr(const Expr&)>;
Function ToCPS(const Function& f, const Module& m, CPSMap* cm);
Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm);
Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const TypeVar& answer) {
std::function<Var(Var)> remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); };
Function ToCPS(const Function& f,
const IRModule& m,
CPSMap* cm,
VarMap* vm,
const TypeVar& answer) {
std::function<Var(Var)> remap = [&](const Var& v) {
return vm->count(v) == 0 ? v : vm->at(v);
};
auto function_type = Downcast<FuncType>(f->checked_type());
// Each MCont can be used at most once.
struct CPSFunctor : ExprFunctor<Expr(const Expr&, const MCont&)>, PatternMutator {
CPSFunctor(const std::function<Var(Var)>& remap,
const TypeVar& answer,
const Module& m,
const IRModule& m,
VarMap* vm,
CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { }
const std::function<Var(Var)>& remap;
TypeVar answer;
Module m;
IRModule m;
VarMap* vm;
CPSMap* cm;
......@@ -295,7 +301,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
f->attrs);
}
Function ToCPS(const Function& f, const Module& m, CPSMap* cm) {
Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
TypeVar answer = TypeVarNode::make("answer", kType);
VarMap var;
struct Remapper : ExprVisitor, PatternVisitor {
......@@ -325,7 +331,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm) {
return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs);
}
Function ToCPS(const Function& f, const Module& m) {
Function ToCPS(const Function& f, const IRModule& m) {
CPSMap cps;
return ToCPS(f, m, &cps);
}
......@@ -368,7 +374,7 @@ Function UnCPS(const Function& f) {
}
TVM_REGISTER_GLOBAL("relay._transform.to_cps")
.set_body_typed(static_cast<Function (*)(const Function&, const Module&)>(ToCPS));
.set_body_typed(static_cast<Function (*)(const Function&, const IRModule&)>(ToCPS));
TVM_REGISTER_GLOBAL("relay._transform.un_cps")
.set_body_typed(UnCPS);
......@@ -376,8 +382,8 @@ TVM_REGISTER_GLOBAL("relay._transform.un_cps")
namespace transform {
Pass ToCPS() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Function(ToCPS(f, m));
};
return CreateFunctionPass(pass_func, 1, "ToCPS", {});
......@@ -388,8 +394,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ToCPS")
Pass UnCPS() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Function(UnCPS(f));
};
return CreateFunctionPass(pass_func, 1, "UnCPS", {});
......
......@@ -79,8 +79,8 @@ Expr ToGraphNormalForm(const Expr& e) {
namespace transform {
Pass ToGraphNormalForm() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(ToGraphNormalForm(f));
};
return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {});
......
......@@ -105,7 +105,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
public:
// constructors
explicit TypeInferencer(Module mod, GlobalVar current_func)
explicit TypeInferencer(IRModule mod, GlobalVar current_func)
: mod_(mod), current_func_(current_func),
err_reporter(), solver_(current_func, mod, &this->err_reporter) {
CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
......@@ -118,7 +118,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// type resolver that maps back to type
class Resolver;
// internal environment
Module mod_;
IRModule mod_;
// The current function being type checked.
GlobalVar current_func_;
......@@ -798,7 +798,7 @@ void EnsureCheckedType(const Expr& e) {
AllCheckTypePopulated().VisitExpr(e);
}
Expr InferType(const Expr& expr, const Module& mod) {
Expr InferType(const Expr& expr, const IRModule& mod) {
auto main = mod->GetGlobalVar("main");
auto inferencer = TypeInferencer(mod, main);
auto e = inferencer.Infer(expr);
......@@ -811,7 +811,7 @@ Expr InferType(const Expr& expr, const Module& mod) {
}
Function InferType(const Function& func,
const Module& mod,
const IRModule& mod,
const GlobalVar& var) {
CHECK(mod.defined()) << "internal error: module must be set for type inference";
Function func_copy = Function(make_object<FunctionNode>(*func.operator->()));
......@@ -832,8 +832,8 @@ Function InferType(const Function& func,
namespace transform {
Pass InferType() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(InferType(f, m));
};
return CreateFunctionPass(pass_func, 0, "InferType", {});
......
......@@ -60,7 +60,7 @@ class TypeSolver::Reporter : public TypeReporterNode {
location = ref;
}
TVM_DLL Module GetModule() final {
TVM_DLL IRModule GetModule() final {
return this->solver_->module_;
}
......@@ -531,7 +531,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
// constructor
TypeSolver::TypeSolver(
const GlobalVar& current_func,
const Module& module,
const IRModule& module,
ErrorReporter* err_reporter)
: reporter_(make_object<Reporter>(this)),
current_func(current_func),
......@@ -661,7 +661,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
ErrorReporter *err_reporter = new ErrorReporter();
auto module = ModuleNode::make({}, {});
auto module = IRModule({}, {});
auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
......
......@@ -62,7 +62,7 @@ using common::LinkedList;
*/
class TypeSolver {
public:
TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter);
TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter);
~TypeSolver();
/*!
* \brief Add a type constraint to the solver.
......@@ -179,7 +179,7 @@ class TypeSolver {
/*! \brief Error reporting. */
ErrorReporter* err_reporter_;
/*! \brief The module. */
Module module_;
IRModule module_;
/*!
* \brief GetTypeNode that is corresponds to t.
......
......@@ -72,7 +72,7 @@ class TypeVarTVisitor : public TypeVisitor {
class TypeVarEVisitor : private ExprVisitor {
public:
explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {}
explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}
Array<TypeVar> CollectFree() {
Array<TypeVar> ret;
......@@ -156,7 +156,7 @@ class TypeVarEVisitor : private ExprVisitor {
private:
InsertionSet<TypeVar> type_vars_;
InsertionSet<TypeVar> bound_type_vars_;
const Module& mod_;
const IRModule& mod_;
};
class VarVisitor : protected ExprVisitor, protected PatternVisitor {
......@@ -234,27 +234,27 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
InsertionSet<Var> bound_vars_;
};
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod) {
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod) {
return TypeVarEVisitor(mod).Free(expr);
}
tvm::Array<TypeVar> FreeTypeVars(const Type& type, const Module& mod) {
tvm::Array<TypeVar> FreeTypeVars(const Type& type, const IRModule& mod) {
return TypeVarEVisitor(mod).Free(type);
}
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod) {
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod) {
return TypeVarEVisitor(mod).Bound(expr);
}
tvm::Array<TypeVar> BoundTypeVars(const Type& type, const Module& mod) {
tvm::Array<TypeVar> BoundTypeVars(const Type& type, const IRModule& mod) {
return TypeVarEVisitor(mod).Bound(type);
}
tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod) {
tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod) {
return TypeVarEVisitor(mod).All(expr);
}
tvm::Array<TypeVar> AllTypeVars(const Type& type, const Module& mod) {
tvm::Array<TypeVar> AllTypeVars(const Type& type, const IRModule& mod) {
return TypeVarEVisitor(mod).All(type);
}
......@@ -293,7 +293,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.all_vars")
TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
Module mod = args[1];
IRModule mod = args[1];
if (x.as<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x), mod);
} else {
......@@ -304,7 +304,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars")
TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
Module mod = args[1];
IRModule mod = args[1];
if (x.as<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x), mod);
} else {
......@@ -315,7 +315,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars")
TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
Module mod = args[1];
IRModule mod = args[1];
if (x.as<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x), mod);
} else {
......
......@@ -33,7 +33,7 @@ TEST(Relay, SelfReference) {
auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto mod = relay::ModuleNode::FromExpr(fx);
auto mod = IRModule::FromExpr(fx);
mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main");
......
......@@ -22,7 +22,7 @@
#include <tvm/build_module.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
......@@ -73,7 +73,7 @@ TEST(Relay, Sequential) {
relay::transform::AlterOpLayout()
};
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
auto mod = relay::ModuleNode::FromExpr(func);
auto mod = IRModule::FromExpr(func);
auto pass_ctx = relay::transform::PassContext::Create();
pass_ctx->opt_level = 3;
pass_ctx->fallback_device = 1;
......@@ -100,7 +100,7 @@ TEST(Relay, Sequential) {
relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {});
// Infer type for the expected function.
auto mod1 = relay::ModuleNode::FromExpr(expected_func);
auto mod1 = IRModule::FromExpr(expected_func);
mod1 = relay::transform::InferType()(mod1);
auto expected = mod1->Lookup("main");
CHECK(relay::AlphaEqual(f, expected));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment