Unverified Commit c69092ae by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Unified IR IRModule structure. (#4699)

This PR brings relay::Module as the unified IRModule structure.
IRModule will be used as the basic unit for transformations
through out the stack.

- Rename relay::Module -> IRModule
- Move relay/module.h -> ir/module.h
- ModuleNode::FromExpr -> IRModule::FromExpr
- FromText -> IRModule::FromText
parent bd17baa2
...@@ -353,7 +353,7 @@ registration. ...@@ -353,7 +353,7 @@ registration.
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {}); auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
// Create a module for optimization. // Create a module for optimization.
auto mod = relay::ModuleNode::FromExpr(fx); auto mod = IRModule::FromExpr(fx);
// Create a sequential pass. // Create a sequential pass.
tvm::Array<relay::transform::Pass> pass_seqs{ tvm::Array<relay::transform::Pass> pass_seqs{
......
...@@ -18,55 +18,41 @@ ...@@ -18,55 +18,41 @@
*/ */
/*! /*!
* \file tvm/relay/module.h * \file tvm/ir/module.h
* \brief The global environment: contains information needed to * \brief IRModule that holds the functions and type definitions.
* compile & optimize Relay programs.
*/ */
#ifndef TVM_RELAY_MODULE_H_ #ifndef TVM_IR_MODULE_H_
#define TVM_RELAY_MODULE_H_ #define TVM_IR_MODULE_H_
#include <tvm/relay/error.h> #include <tvm/ir/type.h>
#include <tvm/relay/expr.h> #include <tvm/ir/expr.h>
#include <tvm/relay/adt.h> #include <tvm/ir/adt.h>
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
namespace tvm { namespace tvm {
namespace relay { class IRModule;
/*!
struct Module; * \brief IRModule that holds functions and type definitions.
/*! \brief The global environment of Relay programs.
*
* The global environment contains the global
* information needed to compile a Relay program.
*
* It contains all global functions, and configuration
* options.
* *
* Many operations require access to the global * IRModule is the basic unit for all IR transformations across the stack.
* Module. We pass the Module by value
* in a functional style as an explicit argument,
* but we mutate the Module while optimizing
* Relay programs.
* *
* The functional style allows users to construct custom * Many operations require access to the global IRModule.
* environments easily, for example each thread can store * We pass the IRModule by value in a functional style as an explicit argument,
* a Module while auto-tuning. * but we mutate the Module while optimizing programs.
* \sa IRModule
*/ */
class IRModuleNode : public Object {
class ModuleNode : public RelayNode {
public: public:
/*! \brief A map from ids to all global functions. */ /*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, BaseFunc> functions; tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */ /*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions; tvm::Map<GlobalTypeVar, TypeData> type_definitions;
ModuleNode() {} IRModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("functions", &functions); v->Visit("functions", &functions);
...@@ -75,10 +61,6 @@ class ModuleNode : public RelayNode { ...@@ -75,10 +61,6 @@ class ModuleNode : public RelayNode {
v->Visit("global_type_var_map_", &global_type_var_map_); v->Visit("global_type_var_map_", &global_type_var_map_);
} }
TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});
/*! /*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
* \param var The var of the global function. * \param var The var of the global function.
...@@ -219,7 +201,7 @@ class ModuleNode : public RelayNode { ...@@ -219,7 +201,7 @@ class ModuleNode : public RelayNode {
* functions in another environment. * functions in another environment.
* \param other The other environment. * \param other The other environment.
*/ */
TVM_DLL void Update(const Module& other); TVM_DLL void Update(const IRModule& other);
/*! /*!
* \brief Import Relay code from the file at path. * \brief Import Relay code from the file at path.
...@@ -243,24 +225,8 @@ class ModuleNode : public RelayNode { ...@@ -243,24 +225,8 @@ class ModuleNode : public RelayNode {
*/ */
TVM_DLL std::unordered_set<std::string> Imports() const; TVM_DLL std::unordered_set<std::string> Imports() const;
/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module"; static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private: private:
/*! \brief Helper function for registering a typedef's constructors */ /*! \brief Helper function for registering a typedef's constructors */
...@@ -285,27 +251,62 @@ class ModuleNode : public RelayNode { ...@@ -285,27 +251,62 @@ class ModuleNode : public RelayNode {
importing is idempotent for each module. importing is idempotent for each module.
*/ */
std::unordered_set<std::string> import_set_; std::unordered_set<std::string> import_set_;
friend class IRModule;
}; };
struct Module : public ObjectRef { /*!
Module() {} * \brief Managed reference class to IRModuleNode.
explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {} * \sa IRModuleNode
*/
ModuleNode* operator->() const { class IRModule : public ObjectRef {
return static_cast<ModuleNode*>(get_mutable()); public:
/*!
* \brief constructor
* \param functions Functions in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
*/
TVM_DLL explicit IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
/*! \brief default constructor */
IRModule() {}
/*!
* \brief constructor
* \param n The object pointer.
*/
explicit IRModule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
IRModuleNode* operator->() const {
auto* ptr = get_mutable();
CHECK(ptr != nullptr);
return static_cast<IRModuleNode*>(ptr);
} }
/*!
* \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
using ContainerType = ModuleNode; /*!
* \brief Parse text format source file into an IRModule.
* \param text A string of Relay source code.
* \param source_path The path to the source file.
* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
}; };
/*! \brief Parse Relay source into a module.
* \param source A string of Relay source code.
* \param source_name The name of the source file.
* \return A Relay module.
*/
Module FromText(const std::string& source, const std::string& source_name);
} // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_IR_MODULE_H_
#endif // TVM_RELAY_MODULE_H_
...@@ -30,9 +30,8 @@ ...@@ -30,9 +30,8 @@
namespace tvm { namespace tvm {
// TODO(tqchen): remove after migrate Module to ir. // TODO(tqchen): remove after migrate Module to ir.
namespace relay { class IRModule;
struct Module;
}
/*! /*!
* \brief reporter that reports back to the * \brief reporter that reports back to the
...@@ -76,7 +75,7 @@ class TypeReporterNode : public Object { ...@@ -76,7 +75,7 @@ class TypeReporterNode : public Object {
* \brief Retrieve the current global module. * \brief Retrieve the current global module.
* \return The global module. * \return The global module.
*/ */
TVM_DLL virtual relay::Module GetModule() = 0; TVM_DLL virtual IRModule GetModule() = 0;
// solver is not serializable. // solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) {} void VisitAttrs(tvm::AttrVisitor* v) {}
......
...@@ -71,7 +71,7 @@ Stmt CanonicalSimplify(Stmt stmt, ...@@ -71,7 +71,7 @@ Stmt CanonicalSimplify(Stmt stmt,
* \return Canonicalized expression. * \return Canonicalized expression.
*/ */
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr, TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>()); Map<Var, Range> vrange = Map<Var, Range>());
/*! /*!
* \brief Deep compare lhs and rhs * \brief Deep compare lhs and rhs
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/relay/adt.h> #include <tvm/relay/adt.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <string> #include <string>
...@@ -49,7 +49,7 @@ namespace relay { ...@@ -49,7 +49,7 @@ namespace relay {
* *
* \return The kind of the passed type. * \return The kind of the passed type.
*/ */
TVM_DLL Kind KindCheck(const Type& t, const Module& mod); TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
/*! /*!
* \brief Check whether an expression is constant. * \brief Check whether an expression is constant.
...@@ -188,7 +188,7 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr); ...@@ -188,7 +188,7 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
* *
* \return List of free vars, in the PostDFS order visited by expr. * \return List of free vars, in the PostDFS order visited by expr.
*/ */
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod); TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod);
/*! /*!
* \brief Get free TypeVars from type t. * \brief Get free TypeVars from type t.
...@@ -201,7 +201,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod); ...@@ -201,7 +201,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
* *
* \return List of free type vars, in the PostDFS order visited by type. * \return List of free type vars, in the PostDFS order visited by type.
*/ */
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod); TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const IRModule& mod);
/*! /*!
* \brief Get all bound type variables from expression expr. * \brief Get all bound type variables from expression expr.
...@@ -214,7 +214,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod); ...@@ -214,7 +214,7 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
* *
* \return List of bound type vars, in the PostDFS order in the expression. * \return List of bound type vars, in the PostDFS order in the expression.
*/ */
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod); TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod);
/*! /*!
* \brief Get all bound type variables from type t. * \brief Get all bound type variables from type t.
...@@ -227,7 +227,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod); ...@@ -227,7 +227,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
* *
* \return List of bound type vars, in the PostDFS order visited by type. * \return List of bound type vars, in the PostDFS order visited by type.
*/ */
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod); TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const IRModule& mod);
/*! /*!
* \brief Get all type variables in expression expr. * \brief Get all type variables in expression expr.
...@@ -237,7 +237,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod); ...@@ -237,7 +237,7 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
* *
* \return List of type vars, in the PostDFS order in the expression. * \return List of type vars, in the PostDFS order in the expression.
*/ */
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod); TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
/*! /*!
* \brief Get all type variables in type t. * \brief Get all type variables in type t.
...@@ -247,7 +247,7 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod); ...@@ -247,7 +247,7 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
* *
* \return List of type vars, in the PostDFS order visited by type. * \return List of type vars, in the PostDFS order visited by type.
*/ */
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod); TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);
/*! /*!
* \brief Collect the device mapping information of each expression. * \brief Collect the device mapping information of each expression.
...@@ -277,7 +277,7 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr); ...@@ -277,7 +277,7 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
* \return Returns a list of cases (as patterns) that are not handled by the match * \return Returns a list of cases (as patterns) that are not handled by the match
* expression. * expression.
*/ */
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod); TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
/*! \brief A hashing structure in the style of std::hash. */ /*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash { struct StructuralHash {
......
...@@ -106,9 +106,6 @@ class Id : public ObjectRef { ...@@ -106,9 +106,6 @@ class Id : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
}; };
struct Module;
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -24,13 +24,16 @@ ...@@ -24,13 +24,16 @@
#ifndef TVM_RELAY_ERROR_H_ #ifndef TVM_RELAY_ERROR_H_
#define TVM_RELAY_ERROR_H_ #define TVM_RELAY_ERROR_H_
#include <tvm/ir/module.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <unordered_map> #include <unordered_map>
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./module.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -146,7 +149,7 @@ class ErrorReporter { ...@@ -146,7 +149,7 @@ class ErrorReporter {
* \param module The module to report errors on. * \param module The module to report errors on.
* \param use_color Controls whether to colorize the output. * \param use_color Controls whether to colorize the output.
*/ */
void RenderErrors(const Module& module, bool use_color = true); void RenderErrors(const IRModule& module, bool use_color = true);
inline bool AnyErrors() { inline bool AnyErrors() {
return errors_.size() != 0; return errors_.size() != 0;
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
#include <bitset> #include <bitset>
namespace tvm { namespace tvm {
...@@ -141,7 +143,6 @@ class FeatureSet { ...@@ -141,7 +143,6 @@ class FeatureSet {
*/ */
FeatureSet DetectFeature(const RelayExpr& expr); FeatureSet DetectFeature(const RelayExpr& expr);
struct Module;
/*! /*!
* \brief Calculate the feature of the program. * \brief Calculate the feature of the program.
* *
...@@ -149,7 +150,7 @@ struct Module; ...@@ -149,7 +150,7 @@ struct Module;
* *
* \return The FeatureSet. * \return The FeatureSet.
*/ */
FeatureSet DetectFeature(const Module& mod); FeatureSet DetectFeature(const IRModule& mod);
/*! /*!
* \brief Calculate the feature of the program. * \brief Calculate the feature of the program.
...@@ -159,7 +160,7 @@ FeatureSet DetectFeature(const Module& mod); ...@@ -159,7 +160,7 @@ FeatureSet DetectFeature(const Module& mod);
* *
* \return The FeatureSet. * \return The FeatureSet.
*/ */
inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) { inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod); return DetectFeature(expr) + DetectFeature(mod);
} }
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
#define TVM_RELAY_INTERPRETER_H_ #define TVM_RELAY_INTERPRETER_H_
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
...@@ -62,7 +62,7 @@ namespace relay { ...@@ -62,7 +62,7 @@ namespace relay {
* \return A function that takes in an expression and returns a value. * \return A function that takes in an expression and returns a value.
*/ */
runtime::TypedPackedFunc<ObjectRef(Expr)> runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target); CreateInterpreter(IRModule mod, DLContext context, Target target);
/*! \brief A Relay closure, i.e a scope and a function. */ /*! \brief A Relay closure, i.e a scope and a function. */
class Closure; class Closure;
......
...@@ -61,7 +61,7 @@ ...@@ -61,7 +61,7 @@
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h> #include <tvm/relay/error.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <string> #include <string>
...@@ -236,7 +236,7 @@ class PassNode : public RelayNode { ...@@ -236,7 +236,7 @@ class PassNode : public RelayNode {
* *
* \return The transformed module. * \return The transformed module.
*/ */
Module operator()(const Module& mod) const { IRModule operator()(const IRModule& mod) const {
return this->operator()(mod, PassContext::Current()); return this->operator()(mod, PassContext::Current());
} }
...@@ -248,8 +248,8 @@ class PassNode : public RelayNode { ...@@ -248,8 +248,8 @@ class PassNode : public RelayNode {
* *
* \return The transformed module. * \return The transformed module.
*/ */
virtual Module operator()(const Module& mod, virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0; const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) {} void VisitAttrs(tvm::AttrVisitor* v) {}
...@@ -266,7 +266,7 @@ class Pass : public ObjectRef { ...@@ -266,7 +266,7 @@ class Pass : public ObjectRef {
* *
* \return The transformed module. * \return The transformed module.
*/ */
Module operator()(const Module& mod) const { IRModule operator()(const IRModule& mod) const {
const PassNode* node = operator->(); const PassNode* node = operator->();
CHECK(node != nullptr); CHECK(node != nullptr);
return node->operator()(mod); return node->operator()(mod);
...@@ -279,8 +279,8 @@ class Pass : public ObjectRef { ...@@ -279,8 +279,8 @@ class Pass : public ObjectRef {
* *
* \return The transformed module. * \return The transformed module.
*/ */
Module operator()(const Module& mod, IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassNode* node = operator->(); const PassNode* node = operator->();
CHECK(node != nullptr); CHECK(node != nullptr);
return node->operator()(mod, pass_ctx); return node->operator()(mod, pass_ctx);
...@@ -329,7 +329,7 @@ class Sequential : public Pass { ...@@ -329,7 +329,7 @@ class Sequential : public Pass {
* \return The created module pass. * \return The created module pass.
*/ */
Pass CreateModulePass( Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func, const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, int opt_level,
const std::string& name, const std::string& name,
const tvm::Array<tvm::PrimExpr>& required); const tvm::Array<tvm::PrimExpr>& required);
...@@ -345,7 +345,7 @@ Pass CreateModulePass( ...@@ -345,7 +345,7 @@ Pass CreateModulePass(
* \return The created function pass. * \return The created function pass.
*/ */
TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, Module, PassContext)>& pass_func, Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, int opt_level,
const std::string& name, const std::string& name,
const tvm::Array<tvm::PrimExpr>& required); const tvm::Array<tvm::PrimExpr>& required);
...@@ -624,7 +624,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds); ...@@ -624,7 +624,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
* \note this function mutates mod and is not thread-safe. * \note this function mutates mod and is not thread-safe.
*/ */
TVM_DLL Function InferType(const Function& f, TVM_DLL Function InferType(const Function& f,
const Module& mod, const IRModule& mod,
const GlobalVar& var); const GlobalVar& var);
/*! /*!
...@@ -689,7 +689,7 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); ...@@ -689,7 +689,7 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
* *
* \return the converted Function. * \return the converted Function.
*/ */
TVM_DLL Function ToCPS(const Function& f, const Module& mod); TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
/*! /*!
* \brief Remove the continuation argument of a CPS function. * \brief Remove the continuation argument of a CPS function.
......
...@@ -294,7 +294,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -294,7 +294,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* *
* \return relay::Module The updated Relay module after optimization. * \return relay::Module The updated Relay module after optimization.
*/ */
relay::Module Optimize( IRModule Optimize(
Function func, Function func,
const TargetsMap& targets, const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) { const std::unordered_map<std::string, runtime::NDArray>& params) {
...@@ -303,7 +303,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -303,7 +303,7 @@ class RelayBuildModule : public runtime::ModuleNode {
} }
// Perform Module->Module optimizations. // Perform Module->Module optimizations.
relay::Module relay_module = relay::ModuleNode::FromExpr(func); IRModule relay_module = IRModule::FromExpr(func);
Array<Pass> pass_seqs; Array<Pass> pass_seqs;
...@@ -408,8 +408,8 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -408,8 +408,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* *
* \return updated_module The updated module after device annotation. * \return updated_module The updated module after device annotation.
*/ */
relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module, IRModule RunDeviceAnnotationPass(const IRModule& relay_module,
int fallback_device) { int fallback_device) {
UpdateHeterogeneousInputs(fallback_device); UpdateHeterogeneousInputs(fallback_device);
auto rewrite = transform::RewriteAnnotatedOps(fallback_device); auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
auto updated_module = rewrite(relay_module); auto updated_module = rewrite(relay_module);
...@@ -461,7 +461,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -461,7 +461,7 @@ class RelayBuildModule : public runtime::ModuleNode {
Function func, Function func,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) { const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
// Optimize input Relay Function and returns Relay Module // Optimize input Relay Function and returns Relay Module
relay::Module relay_module = Optimize(func, targets_, params); IRModule relay_module = Optimize(func, targets_, params);
// Get the updated function. // Get the updated function.
func = Downcast<Function>(relay_module->Lookup("main")); func = Downcast<Function>(relay_module->Lookup("main"));
......
...@@ -613,7 +613,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -613,7 +613,7 @@ class CompileEngineImpl : public CompileEngineNode {
} }
Array<tvm::runtime::Module> LowerExternalFunctions() { Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, relay::Module> ext_mods; std::unordered_map<std::string, IRModule> ext_mods;
std::vector<CCacheKey> cached_ext_funcs; std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) { for (const auto& it : cache_) {
auto src_func = it.first->source_func; auto src_func = it.first->source_func;
...@@ -623,7 +623,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -623,7 +623,7 @@ class CompileEngineImpl : public CompileEngineNode {
const tvm::ir::StringImmNode* code_gen = compiler.as<tvm::ir::StringImmNode>(); const tvm::ir::StringImmNode* code_gen = compiler.as<tvm::ir::StringImmNode>();
CHECK(code_gen) << "No external codegen is set"; CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) { if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {}); ext_mods[code_gen->value] = IRModule({}, {});
} }
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>(); const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>();
......
...@@ -186,8 +186,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase { ...@@ -186,8 +186,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
if (ref->IsInstance<FunctionNode>()) { if (ref->IsInstance<FunctionNode>()) {
GenCFunc(Downcast<Function>(ref)); GenCFunc(Downcast<Function>(ref));
} else if (ref->IsInstance<relay::ModuleNode>()) { } else if (ref->IsInstance<IRModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref); IRModule mod = Downcast<IRModule>(ref);
for (const auto& it : mod->functions) { for (const auto& it : mod->functions) {
GenCFunc(Downcast<Function>(it.second)); GenCFunc(Downcast<Function>(it.second));
} }
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <dmlc/any.h> #include <dmlc/any.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
......
...@@ -233,7 +233,7 @@ class Interpreter : ...@@ -233,7 +233,7 @@ class Interpreter :
public ExprFunctor<ObjectRef(const Expr& n)>, public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> { PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public: public:
Interpreter(Module mod, DLContext context, Target target) Interpreter(IRModule mod, DLContext context, Target target)
: mod_(mod), : mod_(mod),
context_(context), context_(context),
target_(target), target_(target),
...@@ -761,7 +761,7 @@ class Interpreter : ...@@ -761,7 +761,7 @@ class Interpreter :
private: private:
// Module // Module
Module mod_; IRModule mod_;
// For simplicity we only run the interpreter on a single context. // For simplicity we only run the interpreter on a single context.
// Context to run the interpreter on. // Context to run the interpreter on.
DLContext context_; DLContext context_;
...@@ -779,7 +779,7 @@ class Interpreter : ...@@ -779,7 +779,7 @@ class Interpreter :
TypedPackedFunc<ObjectRef(Expr)> TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter( CreateInterpreter(
Module mod, IRModule mod,
DLContext context, DLContext context,
Target target) { Target target) {
if (mod.defined()) { if (mod.defined()) {
......
...@@ -752,7 +752,7 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, ...@@ -752,7 +752,7 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
if (name == "lower") { if (name == "lower") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3); CHECK_EQ(args.num_args, 3);
Module mod = args[0]; IRModule mod = args[0];
this->Lower(mod, args[1], args[2]); this->Lower(mod, args[1], args[2]);
}); });
} else if (name == "codegen") { } else if (name == "codegen") {
...@@ -813,7 +813,7 @@ relay::Function VMCompiler::BindParamsByName( ...@@ -813,7 +813,7 @@ relay::Function VMCompiler::BindParamsByName(
return ret; return ret;
} }
void VMCompiler::Lower(Module mod, void VMCompiler::Lower(IRModule mod,
const TargetsMap& targets, const TargetsMap& targets,
const tvm::Target& target_host) { const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1) CHECK_EQ(targets.size(), 1)
...@@ -884,7 +884,7 @@ void VMCompiler::Lower(Module mod, ...@@ -884,7 +884,7 @@ void VMCompiler::Lower(Module mod,
} }
} }
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs; Array<Pass> pass_seqs;
Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}}; Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
......
...@@ -62,7 +62,7 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>; ...@@ -62,7 +62,7 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
struct VMCompilerContext { struct VMCompilerContext {
// The module context for the compilation // The module context for the compilation
Module module; IRModule module;
// Error reporter // Error reporter
ErrorReporter err_reporter; ErrorReporter err_reporter;
// Map from a unique integer to ADT constructor tag // Map from a unique integer to ADT constructor tag
...@@ -107,7 +107,7 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -107,7 +107,7 @@ class VMCompiler : public runtime::ModuleNode {
to target mapping. For homogeneous compilation, it is a build target. to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device. * \param target_host Host compilation target, if target is device.
*/ */
void Lower(Module mod, void Lower(IRModule mod,
const TargetsMap& targets, const TargetsMap& targets,
const tvm::Target& target_host); const tvm::Target& target_host);
...@@ -125,7 +125,7 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -125,7 +125,7 @@ class VMCompiler : public runtime::ModuleNode {
relay::Function func, relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params); const std::unordered_map<std::string, runtime::NDArray>& params);
Module OptimizeModule(const Module& mod, const TargetsMap& targets); IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);
void PopulateGlobalMap(); void PopulateGlobalMap();
......
...@@ -52,10 +52,10 @@ namespace vm { ...@@ -52,10 +52,10 @@ namespace vm {
* (fn(...) { ... })(...) * (fn(...) { ... })(...)
*/ */
struct PrimitiveInliner : ExprMutator { struct PrimitiveInliner : ExprMutator {
Module module_; IRModule module_;
std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_map; std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_map;
explicit PrimitiveInliner(const Module& module) : module_(module) {} explicit PrimitiveInliner(const IRModule& module) : module_(module) {}
Expr VisitExpr_(const LetNode* let_node) { Expr VisitExpr_(const LetNode* let_node) {
var_map.insert({let_node->var, VisitExpr(let_node->value)}); var_map.insert({let_node->var, VisitExpr(let_node->value)});
...@@ -106,7 +106,7 @@ struct PrimitiveInliner : ExprMutator { ...@@ -106,7 +106,7 @@ struct PrimitiveInliner : ExprMutator {
} }
} }
Module Inline() { IRModule Inline() {
auto gvar_funcs = module_->functions; auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) { for (auto pair : gvar_funcs) {
auto global = pair.first; auto global = pair.first;
...@@ -137,8 +137,8 @@ struct PrimitiveInliner : ExprMutator { ...@@ -137,8 +137,8 @@ struct PrimitiveInliner : ExprMutator {
namespace transform { namespace transform {
Pass InlinePrimitives() { Pass InlinePrimitives() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](Module m, PassContext pc) { [=](IRModule m, PassContext pc) {
return relay::vm::PrimitiveInliner(m).Inline(); return relay::vm::PrimitiveInliner(m).Inline();
}; };
auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {});
......
...@@ -60,7 +60,7 @@ Function MarkClosure(const Function& func) { ...@@ -60,7 +60,7 @@ Function MarkClosure(const Function& func) {
*/ */
class LambdaLifter : public ExprMutator { class LambdaLifter : public ExprMutator {
public: public:
explicit LambdaLifter(const Module& module) : module_(module) {} explicit LambdaLifter(const IRModule& module) : module_(module) {}
Expr VisitExpr_(const LetNode* let_node) final { Expr VisitExpr_(const LetNode* let_node) final {
bool is_lambda = false; bool is_lambda = false;
...@@ -184,7 +184,7 @@ class LambdaLifter : public ExprMutator { ...@@ -184,7 +184,7 @@ class LambdaLifter : public ExprMutator {
} }
} }
Module Lift() { IRModule Lift() {
// There is an ordering bug here. // There is an ordering bug here.
auto glob_funcs = module_->functions; auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) { for (auto pair : glob_funcs) {
...@@ -204,7 +204,7 @@ class LambdaLifter : public ExprMutator { ...@@ -204,7 +204,7 @@ class LambdaLifter : public ExprMutator {
private: private:
std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> lambda_map_; std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> lambda_map_;
std::vector<Var> letrec_; std::vector<Var> letrec_;
Module module_; IRModule module_;
}; };
} // namespace vm } // namespace vm
...@@ -212,8 +212,8 @@ class LambdaLifter : public ExprMutator { ...@@ -212,8 +212,8 @@ class LambdaLifter : public ExprMutator {
namespace transform { namespace transform {
Pass LambdaLift() { Pass LambdaLift() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](Module m, PassContext pc) { [=](IRModule m, PassContext pc) {
return relay::vm::LambdaLifter(m).Lift(); return relay::vm::LambdaLifter(m).Lift();
}; };
return CreateModulePass(pass_func, 1, "LambdaLift", {}); return CreateModulePass(pass_func, 1, "LambdaLift", {});
......
...@@ -40,7 +40,7 @@ namespace vm { ...@@ -40,7 +40,7 @@ namespace vm {
* \brief Detects all the functions that can be possibly called by entry function. * \brief Detects all the functions that can be possibly called by entry function.
*/ */
struct CallTracer : ExprVisitor { struct CallTracer : ExprVisitor {
Module module_; IRModule module_;
// Record the names of all encountered functions // Record the names of all encountered functions
std::unordered_set<std::string> called_funcs_; std::unordered_set<std::string> called_funcs_;
...@@ -48,7 +48,7 @@ struct CallTracer : ExprVisitor { ...@@ -48,7 +48,7 @@ struct CallTracer : ExprVisitor {
// Record the expressions that are being visited // Record the expressions that are being visited
std::unordered_set<Expr, ObjectHash, ObjectEqual> visiting_; std::unordered_set<Expr, ObjectHash, ObjectEqual> visiting_;
explicit CallTracer(const Module& module) explicit CallTracer(const IRModule& module)
: module_{module}, : module_{module},
called_funcs_{}, called_funcs_{},
visiting_{} {} visiting_{} {}
...@@ -99,7 +99,7 @@ struct CallTracer : ExprVisitor { ...@@ -99,7 +99,7 @@ struct CallTracer : ExprVisitor {
* *
* \return The module with dead functions removed. * \return The module with dead functions removed.
*/ */
Module RemoveUnusedFunctions(const Module& module, IRModule RemoveUnusedFunctions(const IRModule& module,
Array<tvm::PrimExpr> entry_funcs) { Array<tvm::PrimExpr> entry_funcs) {
std::unordered_set<std::string> called_funcs{}; std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) { for (auto entry : entry_funcs) {
...@@ -122,8 +122,8 @@ Module RemoveUnusedFunctions(const Module& module, ...@@ -122,8 +122,8 @@ Module RemoveUnusedFunctions(const Module& module,
namespace transform { namespace transform {
Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) { Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](Module m, PassContext pc) { [=](IRModule m, PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions); return relay::vm::RemoveUnusedFunctions(m, entry_functions);
}; };
return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
......
...@@ -60,8 +60,8 @@ class AlphaEqualHandler: ...@@ -60,8 +60,8 @@ class AlphaEqualHandler:
if (!rhs->IsInstance<ExprNode>()) return false; if (!rhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs)); return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
} }
if (const auto lhsm = lhs.as<ModuleNode>()) { if (const auto lhsm = lhs.as<IRModuleNode>()) {
auto rhsm = rhs.as<ModuleNode>(); auto rhsm = rhs.as<IRModuleNode>();
if (!rhsm) return false; if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false; if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) { for (const auto& p : lhsm->functions) {
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/error.h> #include <tvm/relay/error.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -39,7 +39,7 @@ void RelayErrorStream::Raise() const { ...@@ -39,7 +39,7 @@ void RelayErrorStream::Raise() const {
template<typename T, typename U> template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, ObjectHash, ObjectEqual>; using NodeMap = std::unordered_map<T, U, ObjectHash, ObjectEqual>;
void ErrorReporter::RenderErrors(const Module& module, bool use_color) { void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
// First we pick an error reporting strategy for each error. // First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported. // TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) { for (auto err : this->errors_) {
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
#include <tvm/node/serialization.h> #include <tvm/node/serialization.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include "doc.h" #include "doc.h"
#include "type_functor.h" #include "type_functor.h"
...@@ -242,8 +242,8 @@ class PrettyPrinter : ...@@ -242,8 +242,8 @@ class PrettyPrinter :
return PrintType(Downcast<Type>(node), meta); return PrintType(Downcast<Type>(node), meta);
} else if (node.as<PatternNode>()) { } else if (node.as<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta); return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as<ModuleNode>()) { } else if (node.as<IRModuleNode>()) {
return PrintMod(Downcast<Module>(node)); return PrintMod(Downcast<IRModule>(node));
} else { } else {
Doc doc; Doc doc;
return doc << node; return doc << node;
...@@ -525,7 +525,7 @@ class PrettyPrinter : ...@@ -525,7 +525,7 @@ class PrettyPrinter :
} }
} }
Doc PrintMod(const Module& mod) { Doc PrintMod(const IRModule& mod) {
Doc doc; Doc doc;
int counter = 0; int counter = 0;
// type definitions // type definitions
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#include <tvm/relay/error.h>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
......
...@@ -118,8 +118,8 @@ Expr AlterOpLayout(const Expr& expr) { ...@@ -118,8 +118,8 @@ Expr AlterOpLayout(const Expr& expr) {
namespace transform { namespace transform {
Pass AlterOpLayout() { Pass AlterOpLayout() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f)); return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
}; };
return CreateFunctionPass(pass_func, 3, "AlterOpLayout", return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
......
...@@ -129,8 +129,8 @@ Expr CanonicalizeCast(const Expr& e) { ...@@ -129,8 +129,8 @@ Expr CanonicalizeCast(const Expr& e) {
namespace transform { namespace transform {
Pass CanonicalizeCast() { Pass CanonicalizeCast() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f)); return Downcast<Function>(CanonicalizeCast(f));
}; };
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
......
...@@ -69,8 +69,8 @@ Expr CanonicalizeOps(const Expr& e) { ...@@ -69,8 +69,8 @@ Expr CanonicalizeOps(const Expr& e) {
namespace transform { namespace transform {
Pass CanonicalizeOps() { Pass CanonicalizeOps() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeOps(f)); return Downcast<Function>(CanonicalizeOps(f));
}; };
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
......
...@@ -216,8 +216,8 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { ...@@ -216,8 +216,8 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
namespace transform { namespace transform {
Pass CombineParallelConv2D(uint64_t min_num_branches) { Pass CombineParallelConv2D(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches)); return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
}; };
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
......
...@@ -76,8 +76,8 @@ Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) { ...@@ -76,8 +76,8 @@ Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
namespace transform { namespace transform {
Pass CombineParallelDense(uint64_t min_num_branches) { Pass CombineParallelDense(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches)); return Downcast<Function>(CombineParallelDense(f, min_num_branches));
}; };
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
......
...@@ -186,8 +186,8 @@ namespace transform { ...@@ -186,8 +186,8 @@ namespace transform {
Pass CombineParallelOpBatch(const std::string& op_name, Pass CombineParallelOpBatch(const std::string& op_name,
const std::string& batch_op_name, const std::string& batch_op_name,
uint64_t min_num_branches) { uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelOpBatch(f, return Downcast<Function>(CombineParallelOpBatch(f,
op_name, op_name,
batch_op_name, batch_op_name,
......
...@@ -128,8 +128,8 @@ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) { ...@@ -128,8 +128,8 @@ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) {
namespace transform { namespace transform {
Pass ConvertLayout(const std::string& desired_layout) { Pass ConvertLayout(const std::string& desired_layout) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout)); return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
}; };
return CreateFunctionPass( return CreateFunctionPass(
......
...@@ -140,8 +140,8 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) { ...@@ -140,8 +140,8 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) {
namespace transform { namespace transform {
Pass DeadCodeElimination(bool inline_once) { Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f, inline_once)); return Downcast<Function>(DeadCodeElimination(f, inline_once));
}; };
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
......
...@@ -572,8 +572,8 @@ TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps") ...@@ -572,8 +572,8 @@ TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps")
namespace transform { namespace transform {
Pass RewriteAnnotatedOps(int fallback_device) { Pass RewriteAnnotatedOps(int fallback_device) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device)); return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
}; };
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
......
...@@ -87,8 +87,8 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { ...@@ -87,8 +87,8 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
namespace transform { namespace transform {
Pass EliminateCommonSubexpr(PackedFunc fskip) { Pass EliminateCommonSubexpr(PackedFunc fskip) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip)); return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
}; };
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
......
...@@ -57,7 +57,7 @@ class TypeVarReplacer : public TypeMutator { ...@@ -57,7 +57,7 @@ class TypeVarReplacer : public TypeMutator {
*/ */
class EtaExpander : public ExprMutator { class EtaExpander : public ExprMutator {
public: public:
explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var) explicit EtaExpander(const IRModule& mod, bool expand_constructor, bool expand_global_var)
: mod_(mod), : mod_(mod),
type_var_replacer_(TypeVarReplacer()), type_var_replacer_(TypeVarReplacer()),
expand_constructor_(expand_constructor), expand_constructor_(expand_constructor),
...@@ -66,7 +66,7 @@ class EtaExpander : public ExprMutator { ...@@ -66,7 +66,7 @@ class EtaExpander : public ExprMutator {
<< "must expand at least one language feature"; << "must expand at least one language feature";
} }
Module Expand() { IRModule Expand() {
for (GlobalVar global_var : mod_->GetGlobalVars()) { for (GlobalVar global_var : mod_->GetGlobalVars()) {
const BaseFunc base_func = mod_->Lookup(global_var); const BaseFunc base_func = mod_->Lookup(global_var);
if (auto* n = base_func.as<FunctionNode>()) { if (auto* n = base_func.as<FunctionNode>()) {
...@@ -147,7 +147,7 @@ class EtaExpander : public ExprMutator { ...@@ -147,7 +147,7 @@ class EtaExpander : public ExprMutator {
private: private:
/*! \brief reference to module being expanded */ /*! \brief reference to module being expanded */
const Module mod_; const IRModule mod_;
/*! \brief type variable replacer */ /*! \brief type variable replacer */
TypeVarReplacer type_var_replacer_; TypeVarReplacer type_var_replacer_;
/*! \brief whether to expand constructor nodes */ /*! \brief whether to expand constructor nodes */
...@@ -161,8 +161,8 @@ class EtaExpander : public ExprMutator { ...@@ -161,8 +161,8 @@ class EtaExpander : public ExprMutator {
namespace transform { namespace transform {
Pass EtaExpand(bool expand_constructor, bool expand_global_var) { Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](Module mod, PassContext pc) { [=](IRModule mod, PassContext pc) {
return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand();
}; };
return CreateModulePass(pass_func, 1, "EtaExpand", {}); return CreateModulePass(pass_func, 1, "EtaExpand", {});
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include "pass_util.h" #include "pass_util.h"
namespace tvm { namespace tvm {
...@@ -89,7 +89,7 @@ FeatureSet DetectFeature(const Expr& expr) { ...@@ -89,7 +89,7 @@ FeatureSet DetectFeature(const Expr& expr) {
return fd.fs; return fd.fs;
} }
FeatureSet DetectFeature(const Module& mod) { FeatureSet DetectFeature(const IRModule& mod) {
FeatureSet fs = FeatureSet::No(); FeatureSet fs = FeatureSet::No();
if (mod.defined()) { if (mod.defined()) {
for (const auto& f : mod->functions) { for (const auto& f : mod->functions) {
...@@ -99,7 +99,7 @@ FeatureSet DetectFeature(const Module& mod) { ...@@ -99,7 +99,7 @@ FeatureSet DetectFeature(const Module& mod) {
return fs; return fs;
} }
Array<Integer> PyDetectFeature(const Expr& expr, const Module& mod) { Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) {
FeatureSet fs = DetectFeature(expr) + DetectFeature(mod); FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
return static_cast<Array<Integer>>(fs); return static_cast<Array<Integer>>(fs);
} }
......
...@@ -79,7 +79,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.check_constant") ...@@ -79,7 +79,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.check_constant")
// or make a more powerful partial evaluator. // or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator { class ConstantFolder : public ExprMutator {
public: public:
explicit ConstantFolder(FInterpreter executor, Module module) explicit ConstantFolder(FInterpreter executor, IRModule module)
: executor_(executor), : executor_(executor),
module_(module), module_(module),
shape_of_op_(Op::Get("shape_of")), shape_of_op_(Op::Get("shape_of")),
...@@ -168,7 +168,7 @@ class ConstantFolder : public ExprMutator { ...@@ -168,7 +168,7 @@ class ConstantFolder : public ExprMutator {
// Internal constant checker // Internal constant checker
ConstantChecker checker_; ConstantChecker checker_;
// Module // Module
Module module_; IRModule module_;
// Cache the following ops for equivalence checking in this pass. // Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_; const Op& shape_of_op_;
...@@ -209,7 +209,7 @@ class ConstantFolder : public ExprMutator { ...@@ -209,7 +209,7 @@ class ConstantFolder : public ExprMutator {
// TODO(@jroesch): fix this // TODO(@jroesch): fix this
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {}); func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
} }
auto mod = ModuleNode::make( auto mod = IRModule(
{}, {},
module_->type_definitions, module_->type_definitions,
module_->Imports()); module_->Imports());
...@@ -277,7 +277,7 @@ class ConstantFolder : public ExprMutator { ...@@ -277,7 +277,7 @@ class ConstantFolder : public ExprMutator {
}; };
Expr FoldConstant(const Expr& expr, const Module& mod) { Expr FoldConstant(const Expr& expr, const IRModule& mod) {
DLContext ctx; DLContext ctx;
ctx.device_type = kDLCPU; ctx.device_type = kDLCPU;
ctx.device_id = 0; ctx.device_id = 0;
...@@ -292,8 +292,8 @@ Expr FoldConstant(const Expr& expr, const Module& mod) { ...@@ -292,8 +292,8 @@ Expr FoldConstant(const Expr& expr, const Module& mod) {
namespace transform { namespace transform {
Pass FoldConstant() { Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f, m)); return Downcast<Function>(FoldConstant(f, m));
}; };
return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
......
...@@ -949,8 +949,8 @@ Expr BackwardFoldScaleAxis(const Expr& data) { ...@@ -949,8 +949,8 @@ Expr BackwardFoldScaleAxis(const Expr& data) {
namespace transform { namespace transform {
Pass ForwardFoldScaleAxis() { Pass ForwardFoldScaleAxis() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>( return Downcast<Function>(
relay::fold_scale_axis::ForwardFoldScaleAxis(f)); relay::fold_scale_axis::ForwardFoldScaleAxis(f));
}; };
...@@ -962,8 +962,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") ...@@ -962,8 +962,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
.set_body_typed(ForwardFoldScaleAxis); .set_body_typed(ForwardFoldScaleAxis);
Pass BackwardFoldScaleAxis() { Pass BackwardFoldScaleAxis() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>( return Downcast<Function>(
relay::fold_scale_axis::BackwardFoldScaleAxis(f)); relay::fold_scale_axis::BackwardFoldScaleAxis(f));
}; };
......
...@@ -970,15 +970,15 @@ class FuseMutator : private ExprMutator { ...@@ -970,15 +970,15 @@ class FuseMutator : private ExprMutator {
} }
}; };
Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) {
return FuseMutator().Transform(expr, fuse_opt_level); return FuseMutator().Transform(expr, fuse_opt_level);
} }
namespace transform { namespace transform {
Pass FuseOps(int fuse_opt_level) { Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m)); return Downcast<Function>(FuseOps(f, opt_level, m));
}; };
......
...@@ -67,7 +67,7 @@ Type WithGradientType(const Type&); ...@@ -67,7 +67,7 @@ Type WithGradientType(const Type&);
/*! return an expression that represent differentiation of e (according to WithGradientType). /*! return an expression that represent differentiation of e (according to WithGradientType).
* This version only work on first order code without control flow. * This version only work on first order code without control flow.
*/ */
Expr FirstOrderGradient(const Expr& e, const Module& mod); Expr FirstOrderGradient(const Expr& e, const IRModule& mod);
Type WithGradientType(const Type& t) { Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking // TODO(M.K.): stricter checking
...@@ -80,7 +80,7 @@ Type WithGradientType(const Type& t) { ...@@ -80,7 +80,7 @@ Type WithGradientType(const Type& t) {
} }
//! \brief if the expression is a GlobalVar, transform to it's expression. //! \brief if the expression is a GlobalVar, transform to it's expression.
Expr DeGlobal(const Module& mod, const Expr& e) { Expr DeGlobal(const IRModule& mod, const Expr& e) {
if (const auto* x = e.as<GlobalVarNode>()) { if (const auto* x = e.as<GlobalVarNode>()) {
BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x)); BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) { if (auto* n = base_func.as<FunctionNode>()) {
...@@ -222,7 +222,7 @@ Type GradRetType(const Function& f) { ...@@ -222,7 +222,7 @@ Type GradRetType(const Function& f) {
return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
} }
Expr FirstOrderGradient(const Expr& re, const Module& mod) { Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
// Currently we first remove any global functions for the first // Currently we first remove any global functions for the first
// order case. // order case.
auto e = DeGlobal(mod, re); auto e = DeGlobal(mod, re);
...@@ -532,7 +532,7 @@ bool MissingGrad(const Expr& e) { ...@@ -532,7 +532,7 @@ bool MissingGrad(const Expr& e) {
return false; return false;
} }
Expr Gradient(const Expr& re, const Module& mod) { Expr Gradient(const Expr& re, const IRModule& mod) {
auto e = DeGlobal(mod, re); auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>(); auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function"; CHECK(f) << "input need to be a function";
......
...@@ -41,10 +41,10 @@ namespace relay { ...@@ -41,10 +41,10 @@ namespace relay {
using namespace tvm::runtime; using namespace tvm::runtime;
struct KindChecker : TypeFunctor<Kind(const Type&)> { struct KindChecker : TypeFunctor<Kind(const Type&)> {
const Module& mod; const IRModule& mod;
ErrorReporter err_reporter; ErrorReporter err_reporter;
explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {} explicit KindChecker(const IRModule& mod) : mod(mod), err_reporter() {}
void ReportFatalError(const Error& err) { void ReportFatalError(const Error& err) {
this->err_reporter.Report(err); this->err_reporter.Report(err);
...@@ -177,7 +177,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> { ...@@ -177,7 +177,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
} }
}; };
Kind KindCheck(const Type& t, const Module& mod) { Kind KindCheck(const Type& t, const IRModule& mod) {
KindChecker kc(mod); KindChecker kc(mod);
return kc.Check(t); return kc.Check(t);
} }
...@@ -185,7 +185,7 @@ Kind KindCheck(const Type& t, const Module& mod) { ...@@ -185,7 +185,7 @@ Kind KindCheck(const Type& t, const Module& mod) {
TVM_REGISTER_GLOBAL("relay._analysis.check_kind") TVM_REGISTER_GLOBAL("relay._analysis.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { if (args.size() == 1) {
*ret = KindCheck(args[0], ModuleNode::make({}, {})); *ret = KindCheck(args[0], IRModule({}, {}));
} else { } else {
*ret = KindCheck(args[0], args[1]); *ret = KindCheck(args[0], args[1]);
} }
......
...@@ -98,8 +98,8 @@ Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) { ...@@ -98,8 +98,8 @@ Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
namespace transform { namespace transform {
Pass Legalize(const std::string& legalize_map_attr_name) { Pass Legalize(const std::string& legalize_map_attr_name) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name)); return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
}; };
return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImmNode::make("InferType")}); return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImmNode::make("InferType")});
......
...@@ -155,17 +155,17 @@ Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) { ...@@ -155,17 +155,17 @@ Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
const Pattern& cand, const Pattern& cand,
const Module& mod); const IRModule& mod);
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand, const Pattern& cand,
const Module& mod); const IRModule& mod);
// Expands all wildcards in the candidate pattern once // Expands all wildcards in the candidate pattern once
// Returns a list of all possible expansions. // Returns a list of all possible expansions.
Array<Pattern> ExpandWildcards(const Pattern& clause_pat, Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
const Pattern& cand, const Pattern& cand,
const Module& mod) { const IRModule& mod) {
if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) { if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod); return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
} else { } else {
...@@ -178,7 +178,7 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat, ...@@ -178,7 +178,7 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
// Returns a list of all possible expansions. // Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
const Pattern& cand, const Pattern& cand,
const Module& mod) { const IRModule& mod) {
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to); auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
// for a wildcard node, create constructor nodes with wildcards for all args. // for a wildcard node, create constructor nodes with wildcards for all args.
...@@ -228,7 +228,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, ...@@ -228,7 +228,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
// Returns a list of all possible expansions. // Returns a list of all possible expansions.
Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
const Pattern& cand, const Pattern& cand,
const Module& mod) { const IRModule& mod) {
// for a wildcard node, create constructor nodes with wildcards for all args. // for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) { if (cand.as<PatternWildcardNode>()) {
Array<Pattern> args; Array<Pattern> args;
...@@ -271,7 +271,7 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple, ...@@ -271,7 +271,7 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
* \return Returns a list of cases that are not handled by the match * \return Returns a list of cases that are not handled by the match
* expression. * expression.
*/ */
Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) { Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
/* algorithm: /* algorithm:
* candidates = { Wildcard } * candidates = { Wildcard }
* while candidates not empty { * while candidates not empty {
...@@ -328,10 +328,10 @@ Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) { ...@@ -328,10 +328,10 @@ Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
// expose for testing only // expose for testing only
TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases")
.set_body_typed( .set_body_typed(
[](const Match& match, const Module& mod_ref) { [](const Match& match, const IRModule& mod_ref) {
Module call_mod = mod_ref; IRModule call_mod = mod_ref;
if (!call_mod.defined()) { if (!call_mod.defined()) {
call_mod = ModuleNode::make({}, {}); call_mod = IRModule({}, {});
} }
return UnmatchedCases(match, call_mod); return UnmatchedCases(match, call_mod);
}); });
......
...@@ -569,7 +569,7 @@ FInterpreter CPUInterpreter() { ...@@ -569,7 +569,7 @@ FInterpreter CPUInterpreter() {
// in case we are already in a build context. // in case we are already in a build context.
With<BuildConfig> fresh_build_ctx(BuildConfig::Create()); With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return CreateInterpreter(Module(nullptr), CPUContext(), target); return CreateInterpreter(IRModule(nullptr), CPUContext(), target);
} }
using FuncId = int; using FuncId = int;
...@@ -623,7 +623,7 @@ Function AsFunc(const Expr& e) { ...@@ -623,7 +623,7 @@ Function AsFunc(const Expr& e) {
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>, class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> { public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public: public:
PartialEvaluator(const Module& mod) : mod_(mod) { } PartialEvaluator(const IRModule& mod) : mod_(mod) { }
PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll); PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
...@@ -954,7 +954,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -954,7 +954,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic ConstEvaluate(const Expr& expr, LetList* ll) { PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
std::vector<transform::Pass> passes = {transform::FuseOps(0), std::vector<transform::Pass> passes = {transform::FuseOps(0),
transform::InferType()}; transform::InferType()};
auto mod = ModuleNode::FromExpr(expr); auto mod = IRModule::FromExpr(expr);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main")); auto entry_func = Downcast<Function>(mod->Lookup("main"));
...@@ -1184,7 +1184,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -1184,7 +1184,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
private: private:
Environment env_; Environment env_;
Module mod_; IRModule mod_;
std::unordered_map<GlobalVar, PStatic, ObjectHash, ObjectEqual> gv_map_; std::unordered_map<GlobalVar, PStatic, ObjectHash, ObjectEqual> gv_map_;
/*! Termination checking is done as follows: /*! Termination checking is done as follows:
* We have finitely many FunctionIds. * We have finitely many FunctionIds.
...@@ -1255,7 +1255,7 @@ Expr PostProcess(const Expr& e) { ...@@ -1255,7 +1255,7 @@ Expr PostProcess(const Expr& e) {
} // namespace partial_eval } // namespace partial_eval
Module PartialEval(const Module& m) { IRModule PartialEval(const IRModule& m) {
relay::partial_eval::PartialEvaluator pe(m); relay::partial_eval::PartialEvaluator pe(m);
std::vector<GlobalVar> gvs; std::vector<GlobalVar> gvs;
for (const auto& p : m->functions) { for (const auto& p : m->functions) {
...@@ -1270,9 +1270,9 @@ Module PartialEval(const Module& m) { ...@@ -1270,9 +1270,9 @@ Module PartialEval(const Module& m) {
namespace transform { namespace transform {
Pass PartialEval() { Pass PartialEval() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](Module m, PassContext pc) { [=](IRModule m, PassContext pc) {
return PartialEval(m); return relay::PartialEval(m);
}; };
return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
} }
......
...@@ -98,7 +98,7 @@ class ModulePassNode : public PassNode { ...@@ -98,7 +98,7 @@ class ModulePassNode : public PassNode {
* implement the algorithm in the `pass_func` and let it run on a module. It * implement the algorithm in the `pass_func` and let it run on a module. It
* will then remove the dead code including the unused functions in the module. * will then remove the dead code including the unused functions in the module.
*/ */
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func; runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func;
ModulePassNode() = default; ModulePassNode() = default;
...@@ -114,7 +114,7 @@ class ModulePassNode : public PassNode { ...@@ -114,7 +114,7 @@ class ModulePassNode : public PassNode {
* *
* \return Return the updated module. * \return Return the updated module.
*/ */
Module operator()(const Module& mod, const PassContext& pass_ctx) const final; IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
...@@ -122,7 +122,7 @@ class ModulePassNode : public PassNode { ...@@ -122,7 +122,7 @@ class ModulePassNode : public PassNode {
PassInfo Info() const override { return pass_info; } PassInfo Info() const override { return pass_info; }
TVM_DLL static ModulePass make( TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func, runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info); PassInfo pass_info);
static constexpr const char* _type_key = "relay.ModulePass"; static constexpr const char* _type_key = "relay.ModulePass";
...@@ -155,7 +155,7 @@ class FunctionPassNode : public PassNode { ...@@ -155,7 +155,7 @@ class FunctionPassNode : public PassNode {
* `pass_func` and let it run on a given module. The same `pass_func` will * `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module. * then be applied on each function in the module.
*/ */
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func; runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;
FunctionPassNode() = default; FunctionPassNode() = default;
...@@ -171,7 +171,7 @@ class FunctionPassNode : public PassNode { ...@@ -171,7 +171,7 @@ class FunctionPassNode : public PassNode {
* *
* \return Return the updated module. * \return Return the updated module.
*/ */
Module operator()(const Module& mod, const PassContext& pass_ctx) const final; IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
...@@ -179,7 +179,7 @@ class FunctionPassNode : public PassNode { ...@@ -179,7 +179,7 @@ class FunctionPassNode : public PassNode {
PassInfo Info() const override { return pass_info; } PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make( TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func, runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info); PassInfo pass_info);
static constexpr const char* _type_key = "relay.FunctionPass"; static constexpr const char* _type_key = "relay.FunctionPass";
...@@ -248,7 +248,7 @@ class SequentialNode : public PassNode { ...@@ -248,7 +248,7 @@ class SequentialNode : public PassNode {
* metadata, i.e. required_passes. Likely, we can have a data structure, i.e. * metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
* PassInfo, to store the relevant information including the parent passes. * PassInfo, to store the relevant information including the parent passes.
*/ */
void ResolveDependency(const Module& mod); void ResolveDependency(const IRModule& mod);
/*! /*!
* \brief Perform optimizations on a series of passes. The aforementioned * \brief Perform optimizations on a series of passes. The aforementioned
...@@ -261,7 +261,7 @@ class SequentialNode : public PassNode { ...@@ -261,7 +261,7 @@ class SequentialNode : public PassNode {
* *
* \return Return the updated module. * \return Return the updated module.
*/ */
Module operator()(const Module& mod, const PassContext& pass_ctx) const final; IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "relay.Sequential"; static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
...@@ -278,7 +278,7 @@ PassInfo PassInfoNode::make(int opt_level, ...@@ -278,7 +278,7 @@ PassInfo PassInfoNode::make(int opt_level,
} }
ModulePass ModulePassNode::make( ModulePass ModulePassNode::make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func, runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) { PassInfo pass_info) {
auto n = make_object<ModulePassNode>(); auto n = make_object<ModulePassNode>();
n->pass_func = std::move(pass_func); n->pass_func = std::move(pass_func);
...@@ -287,7 +287,7 @@ ModulePass ModulePassNode::make( ...@@ -287,7 +287,7 @@ ModulePass ModulePassNode::make(
} }
// Module -> Module optimizations. // Module -> Module optimizations.
Module ModulePassNode::operator()(const Module& mod, IRModule ModulePassNode::operator()(const IRModule& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info(); const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : " DLOG(INFO) << "Executing module pass : "
...@@ -295,13 +295,13 @@ Module ModulePassNode::operator()(const Module& mod, ...@@ -295,13 +295,13 @@ Module ModulePassNode::operator()(const Module& mod,
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
CHECK(mod.defined()); CHECK(mod.defined());
Module updated_mod = pass_func(mod, pass_ctx); IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined()); CHECK(updated_mod.defined());
return updated_mod; return updated_mod;
} }
FunctionPass FunctionPassNode::make( FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func, runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info) { PassInfo pass_info) {
auto n = make_object<FunctionPassNode>(); auto n = make_object<FunctionPassNode>();
n->pass_func = std::move(pass_func); n->pass_func = std::move(pass_func);
...@@ -310,7 +310,7 @@ FunctionPass FunctionPassNode::make( ...@@ -310,7 +310,7 @@ FunctionPass FunctionPassNode::make(
} }
// Perform Module -> Module optimizations at the Function level. // Perform Module -> Module optimizations at the Function level.
Module FunctionPassNode::operator()(const Module& mod, IRModule FunctionPassNode::operator()(const IRModule& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info(); const PassInfo& pass_info = Info();
CHECK(mod.defined()); CHECK(mod.defined());
...@@ -320,7 +320,7 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -320,7 +320,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->opt_level; << pass_info->opt_level;
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports()); IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) { for (const auto& it : updated_mod->functions) {
// only picks up relay::Function // only picks up relay::Function
...@@ -364,7 +364,7 @@ const SequentialNode* Sequential::operator->() const { ...@@ -364,7 +364,7 @@ const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(get()); return static_cast<const SequentialNode*>(get());
} }
void SequentialNode::ResolveDependency(const Module& mod) { void SequentialNode::ResolveDependency(const IRModule& mod) {
// TODO(zhiics) Implement it. // TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass. // 1. Consider the required passes for each pass.
// 2. Only resolve the enabled passes. // 2. Only resolve the enabled passes.
...@@ -410,9 +410,9 @@ Pass GetPass(const std::string& pass_name) { ...@@ -410,9 +410,9 @@ Pass GetPass(const std::string& pass_name) {
// TODO(zhiics): we currenlty only sequentially execute each pass in // TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase // a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future. // ordering problem needs to be handled in the future.
Module SequentialNode::operator()(const Module& module, IRModule SequentialNode::operator()(const IRModule& module,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
Module mod = module; IRModule mod = module;
for (const Pass& pass : passes) { for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization."; CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info(); const PassInfo& pass_info = pass->Info();
...@@ -429,7 +429,7 @@ Module SequentialNode::operator()(const Module& module, ...@@ -429,7 +429,7 @@ Module SequentialNode::operator()(const Module& module,
} }
Pass CreateModulePass( Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func, const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, int opt_level,
const std::string& name, const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) { const tvm::Array<tvm::PrimExpr>& required) {
...@@ -438,7 +438,7 @@ Pass CreateModulePass( ...@@ -438,7 +438,7 @@ Pass CreateModulePass(
} }
Pass CreateFunctionPass( Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func, const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, int opt_level,
const std::string& name, const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) { const tvm::Array<tvm::PrimExpr>& required) {
...@@ -479,7 +479,7 @@ TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass") ...@@ -479,7 +479,7 @@ TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass")
TVM_REGISTER_GLOBAL("relay._transform.RunPass") TVM_REGISTER_GLOBAL("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
Module mod = args[1]; IRModule mod = args[1];
*ret = pass(mod); *ret = pass(mod);
}); });
......
...@@ -32,8 +32,8 @@ namespace relay { ...@@ -32,8 +32,8 @@ namespace relay {
namespace transform { namespace transform {
Pass PrintIR(bool show_meta_data) { Pass PrintIR(bool show_meta_data) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](Module m, PassContext pc) { [=](IRModule m, PassContext pc) {
LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data); LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data);
return m; return m;
}; };
......
...@@ -92,8 +92,8 @@ Pass QuantizeAnnotate() { ...@@ -92,8 +92,8 @@ Pass QuantizeAnnotate() {
return e; return e;
}; };
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
auto new_params = func->params; auto new_params = func->params;
for (const auto& x : FreeVars(func)) { for (const auto& x : FreeVars(func)) {
......
...@@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr") ...@@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr")
}); });
Pass QuantizePartition() { Pass QuantizePartition() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
auto ret = Downcast<Function>( auto ret = Downcast<Function>(
ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
return ret; return ret;
......
...@@ -190,7 +190,7 @@ Expr QuantizeRealize(const Call& ref_call, ...@@ -190,7 +190,7 @@ Expr QuantizeRealize(const Call& ref_call,
} }
Expr FoldConstantOpt(const Expr& expr) { Expr FoldConstantOpt(const Expr& expr) {
auto mod = ModuleNode::FromExpr(expr); auto mod = IRModule::FromExpr(expr);
mod = transform::FoldConstant()(mod); mod = transform::FoldConstant()(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main")); auto entry_func = Downcast<Function>(mod->Lookup("main"));
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
...@@ -522,8 +522,8 @@ RELAY_REGISTER_OP("annotation.cast_hint") ...@@ -522,8 +522,8 @@ RELAY_REGISTER_OP("annotation.cast_hint")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize); .set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
Pass QuantizeRealizePass() { Pass QuantizeRealizePass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>( return Downcast<Function>(
ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
}; };
......
...@@ -183,8 +183,8 @@ Expr SimplifyInference(const Expr& e) { ...@@ -183,8 +183,8 @@ Expr SimplifyInference(const Expr& e) {
namespace transform { namespace transform {
Pass SimplifyInference() { Pass SimplifyInference() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyInference(f)); return Downcast<Function>(SimplifyInference(f));
}; };
return CreateFunctionPass(pass_func, 0, "SimplifyInference", return CreateFunctionPass(pass_func, 0, "SimplifyInference",
......
...@@ -291,7 +291,7 @@ Expr ToANormalFormAux(const Expr& e) { ...@@ -291,7 +291,7 @@ Expr ToANormalFormAux(const Expr& e) {
return Fill::ToANormalForm(e, dg, &node_scope); return Fill::ToANormalForm(e, dg, &node_scope);
} }
Module ToANormalForm(const Module& m) { IRModule ToANormalForm(const IRModule& m) {
DLOG(INFO) << "ToANF:" << std::endl << m; DLOG(INFO) << "ToANF:" << std::endl << m;
tvm::Map<GlobalVar, Function> updates; tvm::Map<GlobalVar, Function> updates;
...@@ -321,9 +321,9 @@ Module ToANormalForm(const Module& m) { ...@@ -321,9 +321,9 @@ Module ToANormalForm(const Module& m) {
namespace transform { namespace transform {
Pass ToANormalForm() { Pass ToANormalForm() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](Module m, PassContext pc) { [=](IRModule m, PassContext pc) {
return ToANormalForm(m); return relay::ToANormalForm(m);
}; };
return CreateModulePass(pass_func, 1, "ToANormalForm", {}); return CreateModulePass(pass_func, 1, "ToANormalForm", {});
} }
......
...@@ -111,21 +111,27 @@ using VarMap = std::unordered_map<Var, Var, ObjectHash, ObjectEqual>; ...@@ -111,21 +111,27 @@ using VarMap = std::unordered_map<Var, Var, ObjectHash, ObjectEqual>;
*/ */
using MCont = std::function<Expr(const Expr&)>; using MCont = std::function<Expr(const Expr&)>;
Function ToCPS(const Function& f, const Module& m, CPSMap* cm); Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm);
Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const TypeVar& answer) { Function ToCPS(const Function& f,
std::function<Var(Var)> remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); }; const IRModule& m,
CPSMap* cm,
VarMap* vm,
const TypeVar& answer) {
std::function<Var(Var)> remap = [&](const Var& v) {
return vm->count(v) == 0 ? v : vm->at(v);
};
auto function_type = Downcast<FuncType>(f->checked_type()); auto function_type = Downcast<FuncType>(f->checked_type());
// Each MCont can be used at most once. // Each MCont can be used at most once.
struct CPSFunctor : ExprFunctor<Expr(const Expr&, const MCont&)>, PatternMutator { struct CPSFunctor : ExprFunctor<Expr(const Expr&, const MCont&)>, PatternMutator {
CPSFunctor(const std::function<Var(Var)>& remap, CPSFunctor(const std::function<Var(Var)>& remap,
const TypeVar& answer, const TypeVar& answer,
const Module& m, const IRModule& m,
VarMap* vm, VarMap* vm,
CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { } CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { }
const std::function<Var(Var)>& remap; const std::function<Var(Var)>& remap;
TypeVar answer; TypeVar answer;
Module m; IRModule m;
VarMap* vm; VarMap* vm;
CPSMap* cm; CPSMap* cm;
...@@ -295,7 +301,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const ...@@ -295,7 +301,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
f->attrs); f->attrs);
} }
Function ToCPS(const Function& f, const Module& m, CPSMap* cm) { Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
TypeVar answer = TypeVarNode::make("answer", kType); TypeVar answer = TypeVarNode::make("answer", kType);
VarMap var; VarMap var;
struct Remapper : ExprVisitor, PatternVisitor { struct Remapper : ExprVisitor, PatternVisitor {
...@@ -325,7 +331,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm) { ...@@ -325,7 +331,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm) {
return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs);
} }
Function ToCPS(const Function& f, const Module& m) { Function ToCPS(const Function& f, const IRModule& m) {
CPSMap cps; CPSMap cps;
return ToCPS(f, m, &cps); return ToCPS(f, m, &cps);
} }
...@@ -368,7 +374,7 @@ Function UnCPS(const Function& f) { ...@@ -368,7 +374,7 @@ Function UnCPS(const Function& f) {
} }
TVM_REGISTER_GLOBAL("relay._transform.to_cps") TVM_REGISTER_GLOBAL("relay._transform.to_cps")
.set_body_typed(static_cast<Function (*)(const Function&, const Module&)>(ToCPS)); .set_body_typed(static_cast<Function (*)(const Function&, const IRModule&)>(ToCPS));
TVM_REGISTER_GLOBAL("relay._transform.un_cps") TVM_REGISTER_GLOBAL("relay._transform.un_cps")
.set_body_typed(UnCPS); .set_body_typed(UnCPS);
...@@ -376,8 +382,8 @@ TVM_REGISTER_GLOBAL("relay._transform.un_cps") ...@@ -376,8 +382,8 @@ TVM_REGISTER_GLOBAL("relay._transform.un_cps")
namespace transform { namespace transform {
Pass ToCPS() { Pass ToCPS() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Function(ToCPS(f, m)); return Function(ToCPS(f, m));
}; };
return CreateFunctionPass(pass_func, 1, "ToCPS", {}); return CreateFunctionPass(pass_func, 1, "ToCPS", {});
...@@ -388,8 +394,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ToCPS") ...@@ -388,8 +394,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ToCPS")
Pass UnCPS() { Pass UnCPS() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Function(UnCPS(f)); return Function(UnCPS(f));
}; };
return CreateFunctionPass(pass_func, 1, "UnCPS", {}); return CreateFunctionPass(pass_func, 1, "UnCPS", {});
......
...@@ -79,8 +79,8 @@ Expr ToGraphNormalForm(const Expr& e) { ...@@ -79,8 +79,8 @@ Expr ToGraphNormalForm(const Expr& e) {
namespace transform { namespace transform {
Pass ToGraphNormalForm() { Pass ToGraphNormalForm() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(ToGraphNormalForm(f)); return Downcast<Function>(ToGraphNormalForm(f));
}; };
return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {});
......
...@@ -105,7 +105,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -105,7 +105,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
public: public:
// constructors // constructors
explicit TypeInferencer(Module mod, GlobalVar current_func) explicit TypeInferencer(IRModule mod, GlobalVar current_func)
: mod_(mod), current_func_(current_func), : mod_(mod), current_func_(current_func),
err_reporter(), solver_(current_func, mod, &this->err_reporter) { err_reporter(), solver_(current_func, mod, &this->err_reporter) {
CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
...@@ -118,7 +118,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -118,7 +118,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// type resolver that maps back to type // type resolver that maps back to type
class Resolver; class Resolver;
// internal environment // internal environment
Module mod_; IRModule mod_;
// The current function being type checked. // The current function being type checked.
GlobalVar current_func_; GlobalVar current_func_;
...@@ -798,7 +798,7 @@ void EnsureCheckedType(const Expr& e) { ...@@ -798,7 +798,7 @@ void EnsureCheckedType(const Expr& e) {
AllCheckTypePopulated().VisitExpr(e); AllCheckTypePopulated().VisitExpr(e);
} }
Expr InferType(const Expr& expr, const Module& mod) { Expr InferType(const Expr& expr, const IRModule& mod) {
auto main = mod->GetGlobalVar("main"); auto main = mod->GetGlobalVar("main");
auto inferencer = TypeInferencer(mod, main); auto inferencer = TypeInferencer(mod, main);
auto e = inferencer.Infer(expr); auto e = inferencer.Infer(expr);
...@@ -811,7 +811,7 @@ Expr InferType(const Expr& expr, const Module& mod) { ...@@ -811,7 +811,7 @@ Expr InferType(const Expr& expr, const Module& mod) {
} }
Function InferType(const Function& func, Function InferType(const Function& func,
const Module& mod, const IRModule& mod,
const GlobalVar& var) { const GlobalVar& var) {
CHECK(mod.defined()) << "internal error: module must be set for type inference"; CHECK(mod.defined()) << "internal error: module must be set for type inference";
Function func_copy = Function(make_object<FunctionNode>(*func.operator->())); Function func_copy = Function(make_object<FunctionNode>(*func.operator->()));
...@@ -832,8 +832,8 @@ Function InferType(const Function& func, ...@@ -832,8 +832,8 @@ Function InferType(const Function& func,
namespace transform { namespace transform {
Pass InferType() { Pass InferType() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(InferType(f, m)); return Downcast<Function>(InferType(f, m));
}; };
return CreateFunctionPass(pass_func, 0, "InferType", {}); return CreateFunctionPass(pass_func, 0, "InferType", {});
......
...@@ -60,7 +60,7 @@ class TypeSolver::Reporter : public TypeReporterNode { ...@@ -60,7 +60,7 @@ class TypeSolver::Reporter : public TypeReporterNode {
location = ref; location = ref;
} }
TVM_DLL Module GetModule() final { TVM_DLL IRModule GetModule() final {
return this->solver_->module_; return this->solver_->module_;
} }
...@@ -531,7 +531,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> { ...@@ -531,7 +531,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
// constructor // constructor
TypeSolver::TypeSolver( TypeSolver::TypeSolver(
const GlobalVar& current_func, const GlobalVar& current_func,
const Module& module, const IRModule& module,
ErrorReporter* err_reporter) ErrorReporter* err_reporter)
: reporter_(make_object<Reporter>(this)), : reporter_(make_object<Reporter>(this)),
current_func(current_func), current_func(current_func),
...@@ -661,7 +661,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") ...@@ -661,7 +661,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::TypedPackedFunc; using runtime::TypedPackedFunc;
ErrorReporter *err_reporter = new ErrorReporter(); ErrorReporter *err_reporter = new ErrorReporter();
auto module = ModuleNode::make({}, {}); auto module = IRModule({}, {});
auto dummy_fn_name = GlobalVar("test"); auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {})); module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter); auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
......
...@@ -62,7 +62,7 @@ using common::LinkedList; ...@@ -62,7 +62,7 @@ using common::LinkedList;
*/ */
class TypeSolver { class TypeSolver {
public: public:
TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter); TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter);
~TypeSolver(); ~TypeSolver();
/*! /*!
* \brief Add a type constraint to the solver. * \brief Add a type constraint to the solver.
...@@ -179,7 +179,7 @@ class TypeSolver { ...@@ -179,7 +179,7 @@ class TypeSolver {
/*! \brief Error reporting. */ /*! \brief Error reporting. */
ErrorReporter* err_reporter_; ErrorReporter* err_reporter_;
/*! \brief The module. */ /*! \brief The module. */
Module module_; IRModule module_;
/*! /*!
* \brief GetTypeNode that is corresponds to t. * \brief GetTypeNode that is corresponds to t.
......
...@@ -72,7 +72,7 @@ class TypeVarTVisitor : public TypeVisitor { ...@@ -72,7 +72,7 @@ class TypeVarTVisitor : public TypeVisitor {
class TypeVarEVisitor : private ExprVisitor { class TypeVarEVisitor : private ExprVisitor {
public: public:
explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {} explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}
Array<TypeVar> CollectFree() { Array<TypeVar> CollectFree() {
Array<TypeVar> ret; Array<TypeVar> ret;
...@@ -156,7 +156,7 @@ class TypeVarEVisitor : private ExprVisitor { ...@@ -156,7 +156,7 @@ class TypeVarEVisitor : private ExprVisitor {
private: private:
InsertionSet<TypeVar> type_vars_; InsertionSet<TypeVar> type_vars_;
InsertionSet<TypeVar> bound_type_vars_; InsertionSet<TypeVar> bound_type_vars_;
const Module& mod_; const IRModule& mod_;
}; };
class VarVisitor : protected ExprVisitor, protected PatternVisitor { class VarVisitor : protected ExprVisitor, protected PatternVisitor {
...@@ -234,27 +234,27 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { ...@@ -234,27 +234,27 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
InsertionSet<Var> bound_vars_; InsertionSet<Var> bound_vars_;
}; };
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod) { tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod) {
return TypeVarEVisitor(mod).Free(expr); return TypeVarEVisitor(mod).Free(expr);
} }
tvm::Array<TypeVar> FreeTypeVars(const Type& type, const Module& mod) { tvm::Array<TypeVar> FreeTypeVars(const Type& type, const IRModule& mod) {
return TypeVarEVisitor(mod).Free(type); return TypeVarEVisitor(mod).Free(type);
} }
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod) { tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod) {
return TypeVarEVisitor(mod).Bound(expr); return TypeVarEVisitor(mod).Bound(expr);
} }
tvm::Array<TypeVar> BoundTypeVars(const Type& type, const Module& mod) { tvm::Array<TypeVar> BoundTypeVars(const Type& type, const IRModule& mod) {
return TypeVarEVisitor(mod).Bound(type); return TypeVarEVisitor(mod).Bound(type);
} }
tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod) { tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod) {
return TypeVarEVisitor(mod).All(expr); return TypeVarEVisitor(mod).All(expr);
} }
tvm::Array<TypeVar> AllTypeVars(const Type& type, const Module& mod) { tvm::Array<TypeVar> AllTypeVars(const Type& type, const IRModule& mod) {
return TypeVarEVisitor(mod).All(type); return TypeVarEVisitor(mod).All(type);
} }
...@@ -293,7 +293,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.all_vars") ...@@ -293,7 +293,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.all_vars")
TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0]; ObjectRef x = args[0];
Module mod = args[1]; IRModule mod = args[1];
if (x.as<TypeNode>()) { if (x.as<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x), mod); *ret = FreeTypeVars(Downcast<Type>(x), mod);
} else { } else {
...@@ -304,7 +304,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") ...@@ -304,7 +304,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars")
TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0]; ObjectRef x = args[0];
Module mod = args[1]; IRModule mod = args[1];
if (x.as<TypeNode>()) { if (x.as<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x), mod); *ret = BoundTypeVars(Downcast<Type>(x), mod);
} else { } else {
...@@ -315,7 +315,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") ...@@ -315,7 +315,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars")
TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars") TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0]; ObjectRef x = args[0];
Module mod = args[1]; IRModule mod = args[1];
if (x.as<TypeNode>()) { if (x.as<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x), mod); *ret = AllTypeVars(Downcast<Type>(x), mod);
} else { } else {
......
...@@ -33,7 +33,7 @@ TEST(Relay, SelfReference) { ...@@ -33,7 +33,7 @@ TEST(Relay, SelfReference) {
auto y = relay::VarNode::make("y", tensor_type); auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y }); auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {}); auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto mod = relay::ModuleNode::FromExpr(fx); auto mod = IRModule::FromExpr(fx);
mod = relay::transform::InferType()(mod); mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main"); auto type_fx = mod->Lookup("main");
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
...@@ -73,7 +73,7 @@ TEST(Relay, Sequential) { ...@@ -73,7 +73,7 @@ TEST(Relay, Sequential) {
relay::transform::AlterOpLayout() relay::transform::AlterOpLayout()
}; };
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
auto mod = relay::ModuleNode::FromExpr(func); auto mod = IRModule::FromExpr(func);
auto pass_ctx = relay::transform::PassContext::Create(); auto pass_ctx = relay::transform::PassContext::Create();
pass_ctx->opt_level = 3; pass_ctx->opt_level = 3;
pass_ctx->fallback_device = 1; pass_ctx->fallback_device = 1;
...@@ -100,7 +100,7 @@ TEST(Relay, Sequential) { ...@@ -100,7 +100,7 @@ TEST(Relay, Sequential) {
relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {});
// Infer type for the expected function. // Infer type for the expected function.
auto mod1 = relay::ModuleNode::FromExpr(expected_func); auto mod1 = IRModule::FromExpr(expected_func);
mod1 = relay::transform::InferType()(mod1); mod1 = relay::transform::InferType()(mod1);
auto expected = mod1->Lookup("main"); auto expected = mod1->Lookup("main");
CHECK(relay::AlphaEqual(f, expected)); CHECK(relay::AlphaEqual(f, expected));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment