Commit 89a88c57 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] Start porting pass to the pass manager (#3191)

parent 7e648417
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/relay/adt.h> #include <tvm/relay/adt.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -84,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod, ...@@ -84,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod,
*/ */
TVM_DLL Kind KindCheck(const Type& t, const Module& mod); TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
/*! \brief Compare two expressions for structural equivalence. /*!
* \brief Compare two expressions for structural equivalence.
* *
* This comparison operator respects scoping and compares * This comparison operator respects scoping and compares
* expressions without regard to variable choice. * expressions without regard to variable choice.
...@@ -101,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod); ...@@ -101,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
*/ */
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
/*! \brief Compare two types for structural equivalence. /*!
* \brief Compare two types for structural equivalence.
* *
* This comparison operator respects scoping and compares * This comparison operator respects scoping and compares
* expressions without regard to variable choice. * expressions without regard to variable choice.
...@@ -119,7 +122,8 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); ...@@ -119,7 +122,8 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
*/ */
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
/*! \brief Add abstraction over a function /*!
* \brief Add abstraction over a function
* *
* For example: `square` is transformed to * For example: `square` is transformed to
* `fun x -> square x`. * `fun x -> square x`.
...@@ -135,7 +139,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); ...@@ -135,7 +139,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
*/ */
TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
/*! \brief Check that each Var is only bound once. /*!
* \brief Check that each Var is only bound once.
* *
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
* *
...@@ -148,7 +153,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); ...@@ -148,7 +153,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
*/ */
TVM_DLL bool WellFormed(const Expr& expr); TVM_DLL bool WellFormed(const Expr& expr);
/*! \brief Get all bound variables from expression expr. /*!
* \brief Get all bound variables from expression expr.
* *
* Bound variables are all variables that are declared in the expr. * Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it. * They only have meaning inside that expr, and can only be used in it.
...@@ -159,7 +165,8 @@ TVM_DLL bool WellFormed(const Expr& expr); ...@@ -159,7 +165,8 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/ */
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr); TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
/*! \brief Get all bound variables from pattern pat. /*!
* \brief Get all bound variables from pattern pat.
* *
* Bound variables are all variables that got bound by the pat. * Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it. * They only have meaning inside that expr, and can only be used in it.
...@@ -170,7 +177,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr); ...@@ -170,7 +177,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
*/ */
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat); TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
/*! \brief Get free type parameters from expression expr. /*!
* \brief Get free type parameters from expression expr.
* *
* Free variables are variables that are not bound by a * Free variables are variables that are not bound by a
* let or a function parameter in the context. * let or a function parameter in the context.
...@@ -181,7 +189,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat); ...@@ -181,7 +189,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
*/ */
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr); TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
/*! \brief Get all variables from expression expr. /*!
* \brief Get all variables from expression expr.
* *
* \param expr the expression. * \param expr the expression.
* *
...@@ -189,7 +198,8 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr); ...@@ -189,7 +198,8 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
*/ */
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr); TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
/*! \brief Get free TypeVars from expression expr. /*!
* \brief Get free TypeVars from expression expr.
* *
* Free type parameters are type parameters that are not bound by a function * Free type parameters are type parameters that are not bound by a function
* type in the context. * type in the context.
...@@ -201,7 +211,8 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr); ...@@ -201,7 +211,8 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
*/ */
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod); TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get free TypeVars from type t. /*!
* \brief Get free TypeVars from type t.
* *
* Free type parameters are type parameters that are not bound by a function * Free type parameters are type parameters that are not bound by a function
* type in the context. * type in the context.
...@@ -213,7 +224,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod); ...@@ -213,7 +224,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
*/ */
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod); TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
/*! \brief Get all bound type variables from expression expr. /*!
* \brief Get all bound type variables from expression expr.
* *
* Bound variables are all type variables that are declared in the expr. * Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it. * They only have meaning inside that expr, and can only be used in it.
...@@ -225,7 +237,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod); ...@@ -225,7 +237,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
*/ */
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod); TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get all bound type variables from type t. /*!
* \brief Get all bound type variables from type t.
* *
* Bound variables are all type variables that are declared in the type. * Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it. * They only have meaning inside that type, and can only be used in it.
...@@ -237,7 +250,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod); ...@@ -237,7 +250,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
*/ */
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod); TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
/*! \brief Get all type variables in expression expr. /*!
* \brief Get all type variables in expression expr.
* *
* \param expr the expression. * \param expr the expression.
* \param mod the module. * \param mod the module.
...@@ -246,7 +260,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod); ...@@ -246,7 +260,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
*/ */
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod); TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get all type variables in type t. /*!
* \brief Get all type variables in type t.
* *
* \param t the type. * \param t the type.
* \param mod the module. * \param mod the module.
...@@ -273,22 +288,27 @@ TVM_DLL Expr DeadCodeElimination(const Expr& e); ...@@ -273,22 +288,27 @@ TVM_DLL Expr DeadCodeElimination(const Expr& e);
/*! /*!
* \brief Fold constant expressions. * \brief Fold constant expressions.
*
* \param expr the expression to be optimized. * \param expr the expression to be optimized.
*
* \return The optimized expression. * \return The optimized expression.
*/ */
TVM_DLL Expr FoldConstant(const Expr& expr); TVM_DLL Expr FoldConstant(const Expr& expr);
/*! /*!
* \brief Fuse operations into expr into seperate functions. * \brief Fuse operations into expr into seperate functions.
*
* \param expr The expression. * \param expr The expression.
* \param fuse_opt_level Optimization level. * \param fuse_opt_level Optimization level.
* \param mod the module. * \param mod the module.
*
* \return The optimized expression. * \return The optimized expression.
*/ */
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
/*! /*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order. * \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression. * \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function. * rule function.
...@@ -298,84 +318,68 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); ...@@ -298,84 +318,68 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
* \return The rewritten expression. * \return The rewritten expression.
*/ */
TVM_DLL Expr ForwardRewrite(const Expr& expr, TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name, const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr, std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*! /*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order. * \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression. * \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators. * \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node. * \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when * \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers. * an Expr consumed by multiple callers.
*
* \return The rewritten expression. * \return The rewritten expression.
*/ */
TVM_DLL Expr ForwardRewrite(const Expr& expr, TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func, const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr, std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*! /*!
* \brief Rewrite the annotated program. * \brief Rewrite the annotated program.
*
* \param expr The expression. * \param expr The expression.
* \param fallback_device The fallback device which is the default device for * \param fallback_device The fallback device which is the default device for
* operators without annotation. * operators without annotation.
*
* \return The updated program. * \return The updated program.
*/ */
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
/*! /*!
* \brief Collect the device mapping information of each expression. * \brief Collect the device mapping information of each expression.
*
* \param expr The expression. * \param expr The expression.
*
* \return The device mapping. * \return The device mapping.
*/ */
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr); TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
/*! \brief A hashing structure in the style of std::hash. */ /*!
struct StructuralHash { * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};
/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
* *
* It will turn an expression that is in a graph form (with sharing implicit), * It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form). * to an expression with explicit sharing (A-Normal Form).
* *
* The scope of the root expression is the global scope. * The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope. * The scope of any non root expression is the least common ancestor of all it's scope.
* *
* Values are ordered by post-DFS order in each scope. * Values are ordered by post-DFS order in each scope.
* *
* \param e the expression to observably share * \param e the expression to observably share.
*
* \param mod The module used for referencing global functions, can be * \param mod The module used for referencing global functions, can be
* None. * None.
* *
* \return expression in A-Normal Form * \return expression in A-Normal Form.
*/ */
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
/*! \brief Remove let binding and directly share via pointer instead. /*!
* \brief Remove let binding and directly share via pointer instead.
* *
* It will remove all let binding, * It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference. * and turn all of the variable bound by let into direct pointer reference.
...@@ -386,18 +390,49 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); ...@@ -386,18 +390,49 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/ */
TVM_DLL Expr ToGraphNormalForm(const Expr& e); TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*! \brief Aggressive constant propagation/constant folding/inlining. /*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It will do as much computation in compile time as possible. * It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode. * As a side effect, code size will explode.
*
* \param e the expression,
*
* \return the optimized expression.
*/ */
Expr PartialEval(const Expr& e); TVM_DLL Expr PartialEval(const Expr& e);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};
namespace vm { namespace vm {
/*! \brief Compile a module, and construct the virtual machine. /*!
* \brief Compile a module, and construct the virtual machine.
* *
* \param mod The module to compile. * \param mod The module to compile.
*
* \return The constructed virtual machine. * \return The constructed virtual machine.
*/ */
runtime::vm::VirtualMachine CompileModule(const Module& mod); runtime::vm::VirtualMachine CompileModule(const Module& mod);
......
...@@ -61,6 +61,7 @@ ...@@ -61,6 +61,7 @@
#include <tvm/relay/error.h> #include <tvm/relay/error.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -198,7 +199,7 @@ class Pass; ...@@ -198,7 +199,7 @@ class Pass;
*/ */
class PassNode : public RelayNode { class PassNode : public RelayNode {
public: public:
/* /*!
* \brief Get the pass information/meta data. */ * \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0; virtual PassInfo Info() const = 0;
...@@ -300,11 +301,118 @@ Pass CreateModulePass( ...@@ -300,11 +301,118 @@ Pass CreateModulePass(
* *
* \return The created function pass. * \return The created function pass.
*/ */
Pass CreateFunctionPass( TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func, Function(Function, Module, PassContext)>& pass_func,
int opt_level, int opt_level,
const std::string& name, const std::string& name,
const tvm::Array<tvm::Expr>& required); const tvm::Array<tvm::Expr>& required);
/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \return the pass.
*/
TVM_DLL Pass DeadCodeElimination();
/*!
* \brief Fold constant expressions.
*
* \return The pass.
*/
TVM_DLL Pass FoldConstant();
/*!
* \brief Fuse operations into expr into seperate functions.
*
* \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context.
*
* \return The pass.
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)>
fmulti_ref_trigger = nullptr);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
* \brief Rewrite the annotated program.
*
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The pass.
*/
TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \return The pass.
*/
TVM_DLL Pass ToANormalForm();
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
*
* \return the expression in graph normal form.
*/
TVM_DLL Pass ToGraphNormalForm();
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \return the optimized expression.
*/
TVM_DLL Pass PartialEval();
} // namespace transform } // namespace transform
} // namespace relay } // namespace relay
......
...@@ -151,5 +151,17 @@ Expr DeadCodeElimination(const Expr& e) { ...@@ -151,5 +151,17 @@ Expr DeadCodeElimination(const Expr& e) {
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination); .set_body_typed(DeadCodeElimination);
namespace transform {
Pass DeadCodeElimination() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f));
};
return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -550,6 +550,18 @@ TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation") ...@@ -550,6 +550,18 @@ TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
.set_body_typed(CollectDeviceAnnotationOps); .set_body_typed(CollectDeviceAnnotationOps);
namespace transform {
Pass RewriteAnnotatedOps(int fallback_device) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
};
return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -215,5 +215,17 @@ Expr FoldConstant(const Expr& expr) { ...@@ -215,5 +215,17 @@ Expr FoldConstant(const Expr& expr) {
TVM_REGISTER_API("relay._ir_pass.FoldConstant") TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body_typed(FoldConstant); .set_body_typed(FoldConstant);
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 1, "fold_constant", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -206,6 +206,37 @@ Expr ForwardRewrite(const Expr& expr, ...@@ -206,6 +206,37 @@ Expr ForwardRewrite(const Expr& expr,
return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
} }
namespace transform {
using std::function;
Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
function<NodeRef(const Call&)> fcontext,
function<Expr(const Expr&)> fmulti_ref_trigger) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ForwardRewrite(f,
rewrite_map_attr_name,
fcontext,
fmulti_ref_trigger));
};
return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
}
Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
function<NodeRef(const Call&)> fcontext,
function<Expr(const Expr&)> fmulti_ref_trigger) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ForwardRewrite(f,
rewrite_func,
fcontext,
fmulti_ref_trigger));
};
return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -964,5 +964,19 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { ...@@ -964,5 +964,19 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
TVM_REGISTER_API("relay._ir_pass.FuseOps") TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body_typed(FuseOps); .set_body_typed(FuseOps);
namespace transform {
Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
return CreateFunctionPass(pass_func, 1, "fuse_ops", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -801,5 +801,17 @@ TVM_REGISTER_API("relay._ir_pass.partial_evaluate") ...@@ -801,5 +801,17 @@ TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
*ret = PartialEval(args[0]); *ret = PartialEval(args[0]);
}); });
namespace transform {
Pass PartialEval() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(PartialEval(f));
};
return CreateFunctionPass(pass_func, 1, "partial_eval", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -201,7 +201,7 @@ class FunctionPassNode : public PassNode { ...@@ -201,7 +201,7 @@ class FunctionPassNode : public PassNode {
* `pass_func` and let it run on a given module. The same `pass_func` will * `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module. * then be applied on each function in the module.
*/ */
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func; runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
FunctionPassNode() = default; FunctionPassNode() = default;
...@@ -225,7 +225,7 @@ class FunctionPassNode : public PassNode { ...@@ -225,7 +225,7 @@ class FunctionPassNode : public PassNode {
PassInfo Info() const { return pass_info; } PassInfo Info() const { return pass_info; }
TVM_DLL static FunctionPass make( TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func, runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
PassInfo pass_info); PassInfo pass_info);
static constexpr const char* _type_key = "relay.FunctionPass"; static constexpr const char* _type_key = "relay.FunctionPass";
...@@ -363,7 +363,7 @@ Module ModulePassNode::operator()(const Module& mod, ...@@ -363,7 +363,7 @@ Module ModulePassNode::operator()(const Module& mod,
} }
FunctionPass FunctionPassNode::make( FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func, runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
PassInfo pass_info) { PassInfo pass_info) {
auto n = make_node<FunctionPassNode>(); auto n = make_node<FunctionPassNode>();
n->pass_func = std::move(pass_func); n->pass_func = std::move(pass_func);
...@@ -383,8 +383,7 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -383,8 +383,7 @@ Module FunctionPassNode::operator()(const Module& mod,
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
for (const auto& it : mod->functions) { for (const auto& it : mod->functions) {
auto updated_func = auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx);
new_mod->Add(it.first, updated_func); new_mod->Add(it.first, updated_func);
} }
...@@ -501,7 +500,7 @@ Pass CreateModulePass( ...@@ -501,7 +500,7 @@ Pass CreateModulePass(
} }
Pass CreateFunctionPass( Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func, const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
int opt_level, int opt_level,
const std::string& name, const std::string& name,
const tvm::Array<tvm::Expr>& required) { const tvm::Array<tvm::Expr>& required) {
...@@ -589,7 +588,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -589,7 +588,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
tvm::IRPrinter* p) { tvm::IRPrinter* p) {
const PassInfoNode* seq_pn = node->Info().operator->(); const PassInfoNode* seq_pn = node->Info().operator->();
p->stream << "Run Sequential pass: " << seq_pn->name p->stream << "Run Sequential pass: " << seq_pn->name
<< " at the optimization level. " << seq_pn->opt_level; << " at the optimization level " << seq_pn->opt_level << ". ";
p->stream << "The passes will be executed are: ["; p->stream << "The passes will be executed are: [";
for (const auto& it : node->passes) { for (const auto& it : node->passes) {
const PassNode* pn = it.operator->(); const PassNode* pn = it.operator->();
......
...@@ -333,5 +333,17 @@ Expr ToANormalForm(const Expr& e, const Module& m) { ...@@ -333,5 +333,17 @@ Expr ToANormalForm(const Expr& e, const Module& m) {
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form") TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm)); .set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
namespace transform {
Pass ToANormalForm() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ToANormalForm(f, m));
};
return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -79,5 +79,17 @@ Expr ToGraphNormalForm(const Expr& e) { ...@@ -79,5 +79,17 @@ Expr ToGraphNormalForm(const Expr& e) {
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
.set_body_typed(ToGraphNormalForm); .set_body_typed(ToGraphNormalForm);
namespace transform {
Pass ToGraphNormalForm() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ToGraphNormalForm(f));
};
return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {});
}
} // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -204,7 +204,7 @@ def test_function_pass(): ...@@ -204,7 +204,7 @@ def test_function_pass():
pass_ctx = None pass_ctx = None
@_transform.function_pass(opt_level=opt_level, name=pass_name) @_transform.function_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx): def transform(expr, mod, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
def get_ref_log(): def get_ref_log():
...@@ -303,7 +303,7 @@ def test_sequential_pass(): ...@@ -303,7 +303,7 @@ def test_sequential_pass():
# Register a function pass. # Register a function pass.
@_transform.function_pass(opt_level=1) @_transform.function_pass(opt_level=1)
def func_transform(expr, ctx): def func_transform(expr, mod, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
function_pass = func_transform function_pass = func_transform
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment