Commit f2a6851a by Zhi Committed by Jared Roesch

[Relay][Pass] Only allow Module -> Module for opts managed by pass infra (#3430)

* [Relay][Pass] Only allow Module -> Module for opts managed by pass infra

* revert gradient pass
parent 6c81d784
...@@ -141,23 +141,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); ...@@ -141,23 +141,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
/*! /*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \param e The original function.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return the new function with abstraction
*/
TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
/*!
* \brief Check that each Var is only bound once. * \brief Check that each Var is only bound once.
* *
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
...@@ -288,24 +271,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod); ...@@ -288,24 +271,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
*/ */
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod); TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
/*! /*!
* \brief Fold constant expressions. * \brief Fold constant expressions.
* *
...@@ -388,38 +353,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr); ...@@ -388,38 +353,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr); TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
/*! /*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form.
*/
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
*
* \param e the expression.
*
* \return the expression in graph normal form.
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*!
* \brief Finds cases that the given match expression does not catch, if any. * \brief Finds cases that the given match expression does not catch, if any.
* *
* \param match the match expression to test * \param match the match expression to test
...@@ -432,28 +365,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); ...@@ -432,28 +365,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod); TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
/*! /*!
* \brief Aggressive constant propagation/constant folding/inlining. * \brief Bind the free variables to a Relay expression.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
/*
* \brief Bind function parameters or free variables.
* *
* Parameter binding can only happen if expr is a Function. * Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions. * binds cannot change internal arguments of internal functions.
* *
* \param expr The function to be binded. * \param expr The function to be binded.
* \param binds The map of arguments to * \param binds The map of arguments to
*
* \return The expression with all free vars bound.
*/ */
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map); TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
/*! \brief A hashing structure in the style of std::hash. */ /*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash { struct StructuralHash {
......
...@@ -541,6 +541,33 @@ TVM_DLL Pass AlterOpLayout(); ...@@ -541,6 +541,33 @@ TVM_DLL Pass AlterOpLayout();
*/ */
TVM_DLL Pass CanonicalizeCast(); TVM_DLL Pass CanonicalizeCast();
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \return The pass.
*/
TVM_DLL Pass EtaExpand();
/*!
* \brief This is a helper function that runs a some optimization passes on
* a certain expression and returns the optimized version. With the help of this
* function, users don't need to manually construct a module, then perform
* passes, and finally and extract the target function/expression from the
* returned module frequently.
*
* \param expr The expression to be optimized.
* \param passes The passses that will be applied on the given expression.
*
* \return The optimized expression.
*/
TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes);
} // namespace transform } // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -84,23 +84,6 @@ def backward_fold_scale_axis(expr): ...@@ -84,23 +84,6 @@ def backward_fold_scale_axis(expr):
""" """
return _ir_pass.backward_fold_scale_axis(expr) return _ir_pass.backward_fold_scale_axis(expr)
def eta_expand(expr, mod):
"""Add abstraction over a function.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
mod : tvm.relay.Module
The global module.
Returns
-------
expanded_expr : tvm.relay.Expr
The expression after eta expansion.
"""
return _ir_pass.eta_expand(expr, mod)
def forward_fold_scale_axis(expr): def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense. """Fold the scaling of axis into weights of conv2d/dense.
...@@ -318,25 +301,6 @@ def canonicalize_ops(expr): ...@@ -318,25 +301,6 @@ def canonicalize_ops(expr):
return _ir_pass.canonicalize_ops(expr) return _ir_pass.canonicalize_ops(expr)
def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
expr : tvm.relay.Expr
The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns
-------
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr, inline_once)
def alpha_equal(lhs, rhs): def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence). """Compare two Relay expr for structural equivalence (alpha equivalence).
...@@ -534,46 +498,6 @@ def collect_device_annotation_ops(expr): ...@@ -534,46 +498,6 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr) return _ir_pass.CollectDeviceAnnotationOps(expr)
def to_a_normal_form(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
def to_graph_normal_form(expr):
"""Turn A Normal Form expression into Graph Normal Form expression
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
def gradient(expr, mod=None, mode='higher_order'): def gradient(expr, mod=None, mode='higher_order'):
""" """
Transform the input function, Transform the input function,
...@@ -642,26 +566,6 @@ def eliminate_common_subexpr(expr, fskip=None): ...@@ -642,26 +566,6 @@ def eliminate_common_subexpr(expr, fskip=None):
return _ir_pass.eliminate_common_subexpr(expr, fskip) return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None): def unmatched_cases(match, mod=None):
""" """
Finds cases that the match expression does not catch, if any. Finds cases that the match expression does not catch, if any.
......
...@@ -302,15 +302,20 @@ def CanonicalizeOps(): ...@@ -302,15 +302,20 @@ def CanonicalizeOps():
return _transform.CanonicalizeOps() return _transform.CanonicalizeOps()
def DeadCodeElimination(): def DeadCodeElimination(inline_once=False):
""" Remove expressions which does not effect the program result (dead code). """Remove expressions which does not effect the program result (dead code).
Parameters
----------
inline_once: Optional[Bool]
Whether to inline binding that occurs only once.
Returns Returns
------- -------
ret: tvm.relay.Pass ret: tvm.relay.Pass
The registered pass that eliminates the dead code in a Relay program. The registered pass that eliminates the dead code in a Relay program.
""" """
return _transform.DeadCodeElimination() return _transform.DeadCodeElimination(inline_once)
def FoldConstant(): def FoldConstant():
...@@ -406,6 +411,7 @@ def ToANormalForm(): ...@@ -406,6 +411,7 @@ def ToANormalForm():
""" """
return _transform.ToANormalForm() return _transform.ToANormalForm()
def EtaExpand(): def EtaExpand():
"""Add abstraction over a function """Add abstraction over a function
...@@ -416,6 +422,7 @@ def EtaExpand(): ...@@ -416,6 +422,7 @@ def EtaExpand():
""" """
return _transform.EtaExpand() return _transform.EtaExpand()
def ToGraphNormalForm(): def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression """Turn A Normal Form expression into Graph Normal Form expression
...@@ -449,7 +456,7 @@ def PartialEvaluate(): ...@@ -449,7 +456,7 @@ def PartialEvaluate():
Returns Returns
------- -------
ret : tvm.relay.Pass ret: tvm.relay.Pass
The registered pass that performs partial evaluation on an expression. The registered pass that performs partial evaluation on an expression.
""" """
return _transform.PartialEvaluate() return _transform.PartialEvaluate()
...@@ -465,6 +472,31 @@ def CanonicalizeCast(): ...@@ -465,6 +472,31 @@ def CanonicalizeCast():
""" """
return _transform.CanonicalizeCast() return _transform.CanonicalizeCast()
def OptimizeOnExpr(expr, passes):
"""Perform optimization passes on an expressioin.
Parameters
----------
expr: tvm.relay.Expr
The expression for optimization.
passes: Union[Pass, List[Pass]]
The list of optimizations to be applied.
Returns
-------
ret: tvm.relay.Expr
The optimized expression.
"""
if isinstance(passes, Pass):
passes = [passes]
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a pass or a list of pass objects.")
return _transform.OptimizeOnExpr(expr, passes)
def _wrap_class_module_pass(pass_cls, pass_info): def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass""" """Wrap a python class as function pass"""
class PyModulePass(ModulePass): class PyModulePass(ModulePass):
......
...@@ -156,9 +156,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) { ...@@ -156,9 +156,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once); return CalcDep::Eliminate(e, inline_once);
} }
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination);
namespace transform { namespace transform {
Pass DeadCodeElimination(bool inline_once) { Pass DeadCodeElimination(bool inline_once) {
......
...@@ -1086,27 +1086,30 @@ Expr PostProcess(const Expr& e) { ...@@ -1086,27 +1086,30 @@ Expr PostProcess(const Expr& e) {
} // namespace partial_eval } // namespace partial_eval
Expr PartialEval(const Expr& e, const Module& m) { Module PartialEval(const Module& m) {
return TransformF([&](const Expr& e) { CHECK(m->entry_func.defined());
auto func = m->Lookup(m->entry_func);
Expr ret =
TransformF([&](const Expr& e) {
return LetList::With([&](LetList* ll) { return LetList::With([&](LetList* ll) {
relay::partial_eval::PartialEvaluator pe(FreeVars(e), m); relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
pe.InitializeFuncId(e); pe.InitializeFuncId(e);
return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic); return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
}); });
}, e); }, func);
CHECK(ret->is_type<FunctionNode>());
m->Update(m->entry_func, Downcast<Function>(ret));
return m;
} }
TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
.set_body_typed(PartialEval);
namespace transform { namespace transform {
Pass PartialEval() { Pass PartialEval() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Module m, PassContext pc) {
return Downcast<Function>(PartialEval(f, m)); return PartialEval(m);
}; };
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {}); return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
} }
TVM_REGISTER_API("relay._transform.PartialEvaluate") TVM_REGISTER_API("relay._transform.PartialEvaluate")
......
...@@ -573,6 +573,18 @@ class PassContext::Internal { ...@@ -573,6 +573,18 @@ class PassContext::Internal {
} }
}; };
Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes) {
auto mod = ModuleNode::FromExpr(expr);
Sequential seq(passes);
auto pass_ctx = PassContext::Create();
pass_ctx->opt_level = 3;
tvm::With<PassContext> ctx_scope(pass_ctx);
mod = seq(mod);
CHECK(mod.defined());
auto entry_func = mod->Lookup(mod->entry_func);
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
}
TVM_REGISTER_API("relay._transform.GetCurrentPassContext") TVM_REGISTER_API("relay._transform.GetCurrentPassContext")
.set_body_typed(PassContext::Current); .set_body_typed(PassContext::Current);
...@@ -582,6 +594,9 @@ TVM_REGISTER_API("relay._transform.EnterPassContext") ...@@ -582,6 +594,9 @@ TVM_REGISTER_API("relay._transform.EnterPassContext")
TVM_REGISTER_API("relay._transform.ExitPassContext") TVM_REGISTER_API("relay._transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope); .set_body_typed(PassContext::Internal::ExitScope);
TVM_REGISTER_API("relay._transform.OptimizeOnExpr")
.set_body_typed(OptimizeOnExpr);
} // namespace transform } // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h> #include <tvm/logging.h>
#include "let_list.h" #include "let_list.h"
#include "../../common/arena.h" #include "../../common/arena.h"
...@@ -35,10 +37,6 @@ ...@@ -35,10 +37,6 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
struct ScopeNode; struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>; using Scope = std::shared_ptr<ScopeNode>;
...@@ -104,29 +102,21 @@ bool IsPrimitiveFunction(const Expr& e) { ...@@ -104,29 +102,21 @@ bool IsPrimitiveFunction(const Expr& e) {
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public: public:
static Expr ToANormalForm(const Expr& e, static Expr ToANormalForm(const Expr& e,
const Module& m,
const DependencyGraph& dg, const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope, std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) {
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) { Fill fi(dg, node_scope);
Fill fi(m, dg, node_scope, gv);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
} }
private: private:
Module mod_;
const DependencyGraph& dg_; const DependencyGraph& dg_;
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_; std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo; std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
Fill(Module mod, Fill(const DependencyGraph& dg,
const DependencyGraph& dg, std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) :
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
mod_(mod),
dg_(dg), dg_(dg),
node_scope_(node_scope), node_scope_(node_scope) { }
visited_(visited) { }
Scope GetScope(const Expr& e) { Scope GetScope(const Expr& e) {
return node_scope_->at(dg_.expr_node.at(e)); return node_scope_->at(dg_.expr_node.at(e));
...@@ -246,10 +236,6 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -246,10 +236,6 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
GlobalVar gv = GetRef<GlobalVar>(gvn); GlobalVar gv = GetRef<GlobalVar>(gvn);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
}
return Atomic(gv, gv, v); return Atomic(gv, gv, v);
} }
...@@ -276,9 +262,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -276,9 +262,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
} }
}; };
Expr ToANormalFormAux(const Expr& e, Expr ToANormalFormAux(const Expr& e) {
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
/* When you lift a lambda, what is inside is also being lift. /* When you lift a lambda, what is inside is also being lift.
* *
* So we must determine the scope of the lambda before determining the scope of it's body. * So we must determine the scope of the lambda before determining the scope of it's body.
...@@ -301,46 +285,40 @@ Expr ToANormalFormAux(const Expr& e, ...@@ -301,46 +285,40 @@ Expr ToANormalFormAux(const Expr& e,
* We do an additional pass to fill all the LetList and we are done. * We do an additional pass to fill all the LetList and we are done.
*/ */
std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg); std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
return Fill::ToANormalForm(e, m, dg, &node_scope, gv); return Fill::ToANormalForm(e, dg, &node_scope);
} }
Expr ToANormalForm(const Expr& e, Module ToANormalForm(const Module& m) {
const Module& m, DLOG(INFO) << "ToANF:" << std::endl << m;
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
DLOG(INFO)
<< "ToANF:" << std::endl
<< AsText(e, false);
tvm::Map<GlobalVar, Function> updates;
auto funcs = m->functions;
for (const auto& it : funcs) {
Expr ret = Expr ret =
TransformF([&](const Expr& e) { TransformF([&](const Expr& e) {
return ToANormalFormAux(e, m, gv); return ToANormalFormAux(e);
}, e); }, it.second);
CHECK_EQ(FreeVars(ret).size(), 0); CHECK_EQ(FreeVars(ret).size(), 0);
updates.Set(it.first, Downcast<Function>(ret));
}
DLOG(INFO) for (auto pair : updates) {
<< "ToANF: transformed" << std::endl m->Add(pair.first, pair.second, true);
<< AsText(ret, false); }
return ret; DLOG(INFO) << "ToANF: transformed" << std::endl << m;
}
Expr ToANormalForm(const Expr& e, const Module& m) { return m;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
return ToANormalForm(e, m, &gv);
} }
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
namespace transform { namespace transform {
Pass ToANormalForm() { Pass ToANormalForm() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Module m, PassContext pc) {
return Downcast<Function>(ToANormalForm(f, m)); return ToANormalForm(m);
}; };
return CreateFunctionPass(pass_func, 1, "ToANormalForm", {}); return CreateModulePass(pass_func, 1, "ToANormalForm", {});
} }
TVM_REGISTER_API("relay._transform.ToANormalForm") TVM_REGISTER_API("relay._transform.ToANormalForm")
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
* *
* \brief Turn A normal form into graph normal form. * \brief Turn A normal form into graph normal form.
*/ */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include "let_list.h" #include "let_list.h"
namespace tvm { namespace tvm {
...@@ -76,9 +76,6 @@ Expr ToGraphNormalForm(const Expr& e) { ...@@ -76,9 +76,6 @@ Expr ToGraphNormalForm(const Expr& e) {
return GNF()(e); return GNF()(e);
} }
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
.set_body_typed(ToGraphNormalForm);
namespace transform { namespace transform {
Pass ToGraphNormalForm() { Pass ToGraphNormalForm() {
......
...@@ -18,20 +18,13 @@ from nose.tools import nottest ...@@ -18,20 +18,13 @@ from nose.tools import nottest
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal from tvm.relay import Function, transform
from tvm.relay.ir_pass import alpha_equal, graph_equal, free_vars
from tvm.relay.op import log, add, equal, subtract from tvm.relay.op import log, add, equal, subtract
class env: class env:
def __init__(self): def __init__(self):
self.a = relay.Var("a")
self.b = relay.Var("b")
self.c = relay.Var("c")
self.d = relay.Var("d")
self.e = relay.Var("e")
self.x = relay.Var("x")
self.y = relay.Var("y")
self.z = relay.Var("z")
self.shape = tvm.convert([1, 2, 3]) self.shape = tvm.convert([1, 2, 3])
self.tt = relay.TensorType(self.shape, "float32") self.tt = relay.TensorType(self.shape, "float32")
self.int32 = relay.TensorType([], "int32") self.int32 = relay.TensorType([], "int32")
...@@ -39,6 +32,14 @@ class env: ...@@ -39,6 +32,14 @@ class env:
self.one = relay.const(1.0) self.one = relay.const(1.0)
self.two = relay.const(2.0) self.two = relay.const(2.0)
self.three = relay.const(3.0) self.three = relay.const(3.0)
self.a = relay.Var("a", self.float32)
self.b = relay.Var("b", self.float32)
self.c = relay.Var("c", self.float32)
self.d = relay.Var("d", self.float32)
self.e = relay.Var("e", self.float32)
self.x = relay.Var("x", self.int32)
self.y = relay.Var("y", self.int32)
self.z = relay.Var("z", self.int32)
e = env() e = env()
...@@ -46,22 +47,27 @@ e = env() ...@@ -46,22 +47,27 @@ e = env()
def test_let(): def test_let():
orig = relay.Let(e.x, e.y, e.z) orig = relay.Let(e.x, e.y, e.z)
assert alpha_equal(dead_code_elimination(orig), e.z) orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
def test_used_let(): def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c) orig = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c)) orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
expected = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
@nottest @nottest
def test_inline(): def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), e.d) orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
def test_chain_unused_let(): def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
assert alpha_equal(dead_code_elimination(orig), e.e) orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
# make sure we dont infinite loop # make sure we dont infinite loop
...@@ -78,27 +84,39 @@ def test_recursion(): ...@@ -78,27 +84,39 @@ def test_recursion():
f(2, 10000); f(2, 10000);
""" """
f = relay.Var("f") f = relay.Var("f")
f1 = relay.Var("f1")
n = relay.Var("n", e.int32) n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32) data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, relay.const(0)), funcbody = relay.If(equal(n, relay.const(0)),
data, data,
relay.Call(f, [subtract(n, relay.const(1.0)), relay.Call(f1, [subtract(n, relay.const(1)),
log(data)])) log(data)]))
value = relay.Function([n, data], funcbody, e.float32, []) value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
assert alpha_equal(dead_code_elimination(orig), orig) dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three) orig = transform.OptimizeOnExpr(orig, transform.InferType())
assert graph_equal(dced, orig)
dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three),
transform.DeadCodeElimination())
assert alpha_equal(dced, e.three)
def test_op_let(): def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two)) dced = transform.OptimizeOnExpr(add(relay.Let(e.a, e.one, e.three), e.two),
transform.DeadCodeElimination())
assert alpha_equal(dced, add(e.three, e.two))
def test_tuple_get_item(): def test_tuple_get_item():
t = relay.Var('t') tt = relay.TupleType([e.float32, e.float32])
t = relay.Var('t', tt)
a = relay.Var('a')
g = relay.TupleGetItem(t, 0) g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g) dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination())
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g) assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -18,17 +18,13 @@ ...@@ -18,17 +18,13 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination from tvm.relay.ir_pass import alpha_equal, gradient
from tvm.relay.ir_pass import gradient
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay import create_executor from tvm.relay import op, create_executor, transform
from nose.tools import nottest
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
from tvm.relay import GlobalVar, Call, Type from tvm.relay import GlobalVar, Call
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr from tvm.relay.testing import add_nat_definitions, make_nat_expr
def check_eval(expr, expected_result, mod=None, rtol=1e-07): def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0) ctx = tvm.context("llvm", 0)
...@@ -38,8 +34,25 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): ...@@ -38,8 +34,25 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def dcpe(expr, mod=None): def tipe(expr):
return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True) return transform.OptimizeOnExpr(expr,
[transform.InferType(),
transform.PartialEvaluate(),
transform.InferType()])
def dcpe(expr, mod=None, grad=False):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
if grad:
expr = gradient(expr)
if mod:
assert isinstance(expr, Function)
mod[mod.entry_func] = expr
seq = transform.Sequential(passes)
mod = seq(mod)
return mod[mod.entry_func]
return transform.OptimizeOnExpr(expr, passes)
def test_tuple(): def test_tuple():
...@@ -47,24 +60,31 @@ def test_tuple(): ...@@ -47,24 +60,31 @@ def test_tuple():
x = Var("x", t) x = Var("x", t)
body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1) body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = Function([x], body, None, [t]) f = Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t])) expected = relay.Function([x], x, None, [t])
expected = transform.OptimizeOnExpr(expected, transform.InferType())
assert alpha_equal(dcpe(f), expected)
def test_const_inline(): def test_const_inline():
d = Var("d") t = relay.TensorType([], "float32")
d = Var("d", t)
double = Function([d], d + d) double = Function([d], d + d)
orig = double(const(4.0)) orig = double(const(4.0))
assert alpha_equal(dcpe(orig), const(8.0)) assert alpha_equal(dcpe(orig), const(8.0))
def test_ref(): def test_ref():
d = relay.Var("d") t = relay.TensorType([], "float32")
r = relay.Var("r") d = relay.Var("d", t)
r = relay.Var("r", relay.RefType(t))
x = relay.Var("x") x = relay.Var("x")
body = relay.RefRead(r) body = relay.RefRead(r)
body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body) body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
body = Let(r, RefCreate(d), body) body = Let(r, RefCreate(d), body)
square = Function([d], body) square = Function([d], body)
assert alpha_equal(dcpe(square), Function([d], d * d)) expected = transform.OptimizeOnExpr(Function([d], d * d),
transform.InferType())
assert alpha_equal(dcpe(square), expected)
def test_empty_ad(): def test_empty_ad():
...@@ -73,17 +93,19 @@ def test_empty_ad(): ...@@ -73,17 +93,19 @@ def test_empty_ad():
t = TensorType(shape, dtype) t = TensorType(shape, dtype)
d = Var("d", t) d = Var("d", t)
f = Function([d], d) f = Function([d], d)
g = dcpe(gradient(f)) g = dcpe(f, grad=True)
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
expected = transform.OptimizeOnExpr(expected, transform.InferType())
assert alpha_equal(g, expected) assert alpha_equal(g, expected)
def test_ad(): def test_ad():
shape = (10, 10) shape = (10, 10)
dtype = "float32" dtype = "float32"
t = TensorType(shape, dtype) t = TensorType(shape, dtype)
d = Var("d", t) d = Var("d", t)
f = Function([d], d * d) f = Function([d], d * d)
g = dcpe(gradient(f)) g = dcpe(f, grad=True)
m = d * d m = d * d
x = relay.Var("x") x = relay.Var("x")
o = op.ones_like(x) o = op.ones_like(x)
...@@ -92,6 +114,7 @@ def test_ad(): ...@@ -92,6 +114,7 @@ def test_ad():
body = Tuple([x, Tuple([grad])]) body = Tuple([x, Tuple([grad])])
body = relay.Let(x1, o, body) body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body)) expected = Function([d], relay.Let(x, m, body))
expected = transform.OptimizeOnExpr(expected, transform.InferType())
assert alpha_equal(g, expected) assert alpha_equal(g, expected)
...@@ -107,8 +130,7 @@ def test_if_ref(): ...@@ -107,8 +130,7 @@ def test_if_ref():
eff = Var("eff") eff = Var("eff")
body = Let(eff, body, RefRead(r)) body = Let(eff, body, RefRead(r))
f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body))) f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
f = infer_type(f) pe_f = tipe(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor() ex = create_executor()
f_res = ex.evaluate(f)(const(True)) f_res = ex.evaluate(f)(const(True))
pe_f_res = ex.evaluate(pe_f)(const(True)) pe_f_res = ex.evaluate(pe_f)(const(True))
...@@ -132,8 +154,7 @@ def test_function_invalidate(): ...@@ -132,8 +154,7 @@ def test_function_invalidate():
body = Let(fet, fetch, body) body = Let(fet, fetch, body)
body = Let(r, RefCreate(const(0)), body) body = Let(r, RefCreate(const(0)), body)
f = Function([d], body) f = Function([d], body)
f = infer_type(f) pe_f = tipe(f)
pe_f = infer_type(partial_evaluate(f))
ex = create_executor() ex = create_executor()
f_res = ex.evaluate(f)(const(True)) f_res = ex.evaluate(f)(const(True))
pe_f_res = ex.evaluate(pe_f)(const(True)) pe_f_res = ex.evaluate(pe_f)(const(True))
...@@ -144,35 +165,30 @@ def test_function_invalidate(): ...@@ -144,35 +165,30 @@ def test_function_invalidate():
def test_head_cons(): def test_head_cons():
mod = Module() mod = Module()
p = Prelude(mod) p = Prelude(mod)
def hd_impl(): hd = p.hd
a = TypeVar("a")
x = Var("x", p.l(a))
y = Var("y")
z = Var("z")
cons_case = Clause(PatternConstructor(p.cons,
[PatternVar(y),
PatternVar(z)]),
y)
y = Var("y")
z = Var("z")
return Function([x], Match(x, [cons_case]), a, [a])
t = TypeVar("t") t = TypeVar("t")
x = Var("x", t) x = Var("x", t)
hd = Var("hd") body = hd(p.cons(x, p.nil()))
body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
f = Function([x], body, None, [t]) f = Function([x], body, None, [t])
f = infer_type(f, mod=mod) res = dcpe(f, mod)
res = dcpe(f)
assert alpha_equal(res, Function([x], x, t, [t])) assert alpha_equal(res, Function([x], x, t, [t]))
def test_map(): def test_map():
mod = Module() mod = Module()
p = Prelude(mod) p = Prelude(mod)
f = Var("f") f = GlobalVar("f")
t = TypeVar("t")
a = Var("a", t)
mod[f] = Function([a], a, t, [t])
orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil())))) orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil()))) expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil())))
assert alpha_equal(dcpe(orig, mod=mod), expected) expected = Function([], expected)
mod[mod.entry_func] = expected
expected = mod[mod.entry_func]
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, expected.body)
def test_loop(): def test_loop():
...@@ -181,9 +197,12 @@ def test_loop(): ...@@ -181,9 +197,12 @@ def test_loop():
x = Var("x", t) x = Var("x", t)
loop = GlobalVar("loop") loop = GlobalVar("loop")
mod[loop] = Function([x], loop(x), t, [t]) mod[loop] = Function([x], loop(x), t, [t])
res = dcpe(loop(const(1)), mod=mod) expected = Call(loop, [const(1)])
expected = Call(loop, [const(1)], None, [None]) mod[mod.entry_func] = Function([], expected)
assert alpha_equal(res, expected) expected = mod[mod.entry_func].body
call = Function([], loop(const(1)))
res = dcpe(call, mod=mod)
assert alpha_equal(res.body, expected)
def test_swap_loop(): def test_swap_loop():
...@@ -196,8 +215,9 @@ def test_swap_loop(): ...@@ -196,8 +215,9 @@ def test_swap_loop():
loop = GlobalVar("loop") loop = GlobalVar("loop")
mod[loop] = Function([x, y], loop(y, x), nat) mod[loop] = Function([x, y], loop(y, x), nat)
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2)) prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
res = dcpe(prog, mod=mod) res = Function([], prog)
assert alpha_equal(prog, res) res = dcpe(res, mod=mod)
assert alpha_equal(prog, res.body)
def test_abs_diff(): def test_abs_diff():
...@@ -217,8 +237,9 @@ def test_abs_diff(): ...@@ -217,8 +237,9 @@ def test_abs_diff():
x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case])) x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case])) mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3)) orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 4)) assert alpha_equal(res.body, make_nat_expr(p, 4))
def test_match_nat_id(): def test_match_nat_id():
...@@ -233,8 +254,9 @@ def test_match_nat_id(): ...@@ -233,8 +254,9 @@ def test_match_nat_id():
s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y)) s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
mod[nat_id] = Function([x], Match(x, [z_case, s_case])) mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
orig = nat_id(make_nat_expr(p, 3)) orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3)) assert alpha_equal(res.body, make_nat_expr(p, 3))
def test_nat_id(): def test_nat_id():
...@@ -247,8 +269,9 @@ def test_nat_id(): ...@@ -247,8 +269,9 @@ def test_nat_id():
nat_id = GlobalVar("nat_id") nat_id = GlobalVar("nat_id")
mod[nat_id] = Function([x], x) mod[nat_id] = Function([x], x)
orig = nat_id(make_nat_expr(p, 3)) orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3)) assert alpha_equal(res.body, make_nat_expr(p, 3))
def test_global_match_nat_id(): def test_global_match_nat_id():
...@@ -260,8 +283,9 @@ def test_global_match_nat_id(): ...@@ -260,8 +283,9 @@ def test_global_match_nat_id():
z_case = Clause(PatternConstructor(p.z, []), p.z()) z_case = Clause(PatternConstructor(p.z, []), p.z())
s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x)) s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
orig = Match(make_nat_expr(p, 3), [z_case, s_case]) orig = Match(make_nat_expr(p, 3), [z_case, s_case])
orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3)) assert alpha_equal(res.body, make_nat_expr(p, 3))
def test_double(): def test_double():
...@@ -269,8 +293,9 @@ def test_double(): ...@@ -269,8 +293,9 @@ def test_double():
p = Prelude(mod) p = Prelude(mod)
add_nat_definitions(p) add_nat_definitions(p)
orig = p.double(make_nat_expr(p, 3)) orig = p.double(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 6)) assert alpha_equal(res.body, make_nat_expr(p, 6))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -17,9 +17,8 @@ ...@@ -17,9 +17,8 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature from tvm.relay.ir_pass import alpha_equal, detect_feature
from tvm.relay import op, create_executor from tvm.relay import op, create_executor, transform
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count from tvm.relay.testing import add_nat_definitions, count
from tvm.relay.feature import Feature from tvm.relay.feature import Feature
...@@ -39,7 +38,7 @@ def test_explicit_bound(): ...@@ -39,7 +38,7 @@ def test_explicit_bound():
z = op.add(y, y) z = op.add(y, y)
f = relay.Function([], op.add(z, z)) f = relay.Function([], op.add(z, z))
assert not Feature.fLet in detect_feature(f) assert not Feature.fLet in detect_feature(f)
anf = to_a_normal_form(f) anf = transform.OptimizeOnExpr(f, transform.ToANormalForm())
assert Feature.fLet in detect_feature(anf) assert Feature.fLet in detect_feature(anf)
check_eval(f(), 8.0) check_eval(f(), 8.0)
check_eval(anf(), 8.0) check_eval(anf(), 8.0)
...@@ -53,7 +52,8 @@ def test_order(): ...@@ -53,7 +52,8 @@ def test_order():
x = relay.const(1) x = relay.const(1)
val = x + y * z val = x + y * z
check_eval(val, 7.0) check_eval(val, 7.0)
anf = infer_type(to_a_normal_form(val)) anf = transform.OptimizeOnExpr(val, [transform.ToANormalForm(),
transform.InferType()])
a = relay.Var('a', relay.IncompleteType()) a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType())
...@@ -65,14 +65,16 @@ def test_order(): ...@@ -65,14 +65,16 @@ def test_order():
expected_output = relay.Let(c, z, expected_output) expected_output = relay.Let(c, z, expected_output)
expected_output = relay.Let(b, y, expected_output) expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output) expected_output = relay.Let(a, x, expected_output)
expected_output = infer_type(expected_output) expected_output = transform.OptimizeOnExpr(expected_output,
transform.InferType())
assert alpha_equal(anf, expected_output) assert alpha_equal(anf, expected_output)
def test_if(): def test_if():
cond = relay.const(True) cond = relay.const(True)
x = relay.If(cond, relay.const(2), relay.const(3)) x = relay.If(cond, relay.const(2), relay.const(3))
anf = infer_type(to_a_normal_form(x)) anf = transform.OptimizeOnExpr(x, [transform.ToANormalForm(),
transform.InferType()])
a = relay.Var('a', relay.IncompleteType()) a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType())
...@@ -82,7 +84,8 @@ def test_if(): ...@@ -82,7 +84,8 @@ def test_if():
expected_output = relay.If(c, true_branch, false_branch) expected_output = relay.If(c, true_branch, false_branch)
expected_output = relay.Let(d, expected_output, d) expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output) expected_output = relay.Let(c, cond, expected_output)
expected_output = infer_type(expected_output) expected_output = transform.OptimizeOnExpr(expected_output,
transform.InferType())
assert alpha_equal(anf, expected_output) assert alpha_equal(anf, expected_output)
...@@ -114,7 +117,8 @@ def test_recursion(): ...@@ -114,7 +117,8 @@ def test_recursion():
mod[f] = value mod[f] = value
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
old_f = mod[f] old_f = mod[f]
f = to_a_normal_form(f, mod=mod) mod = transform.ToANormalForm()(mod)
f = mod[f]
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
...@@ -129,7 +133,8 @@ def test_ref(): ...@@ -129,7 +133,8 @@ def test_ref():
body = relay.Let(iv, relay.RefRead(i), body) body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body) body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
check_eval(body, 3) check_eval(body, 3)
check_eval(to_a_normal_form(body), 3) opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm())
check_eval(opt_body, 3)
def test_nat_add(): def test_nat_add():
...@@ -144,7 +149,12 @@ def test_nat_add(): ...@@ -144,7 +149,12 @@ def test_nat_add():
intrp = create_executor(mod=mod, ctx=ctx, target="llvm") intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 expr = add(s(z()), s(z()))
f = relay.GlobalVar("f")
mod[f] = relay.Function([], expr)
mod = transform.ToANormalForm()(mod)
expr = mod["f"]
assert count(p, intrp.evaluate(expr.body)) == 2
assert Feature.fLet in detect_feature(mod[add]) assert Feature.fLet in detect_feature(mod[add])
...@@ -155,14 +165,16 @@ def test_let(): ...@@ -155,14 +165,16 @@ def test_let():
body = relay.Let(y, x, x + y) body = relay.Let(y, x, x + y)
body = relay.Let(x, d, body) body = relay.Let(x, d, body)
check_eval(body, 8) check_eval(body, 8)
check_eval(to_a_normal_form(body), 8) opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm())
check_eval(opt_body, 8)
def test_function(): def test_function():
x = relay.Var("x") t = relay.TensorType((), 'float32')
x = relay.Var("x", t)
f = relay.Function([x], x + x) f = relay.Function([x], x + x)
d = relay.const(4.0, 'float32') d = relay.const(4.0, 'float32')
anf_f = to_a_normal_form(f) anf_f = transform.OptimizeOnExpr(f, transform.ToANormalForm())
assert isinstance(anf_f, relay.Function) assert isinstance(anf_f, relay.Function)
check_eval(f(d), 8) check_eval(f(d), 8)
check_eval(anf_f(d), 8) check_eval(anf_f(d), 8)
......
...@@ -17,10 +17,9 @@ ...@@ -17,10 +17,9 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature from tvm.relay import op, create_executor, transform
from tvm.relay import op, create_executor from tvm.relay.ir_pass import detect_feature
from tvm.relay.feature import Feature from tvm.relay.feature import Feature
from tvm.relay.backend.interpreter import Value, TupleValue
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
...@@ -41,9 +40,9 @@ def test_implicit_share(): ...@@ -41,9 +40,9 @@ def test_implicit_share():
body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(z, op.add(y, y), op.add(z, z))
body = relay.Let(y, op.add(x, x), body) body = relay.Let(y, op.add(x, x), body)
f = relay.Function([], relay.Let(x, relay.const(1), body)) f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f) g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm())
assert "let" in f.astext() assert Feature.fLet in detect_feature(f)
assert not "let" in g.astext() assert not Feature.fLet in detect_feature(g)
check_eval(f, [], 8.0) check_eval(f, [], 8.0)
check_eval(g, [], 8.0) check_eval(g, [], 8.0)
...@@ -55,8 +54,8 @@ def test_round_trip(): ...@@ -55,8 +54,8 @@ def test_round_trip():
body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(z, op.add(y, y), op.add(z, z))
body = relay.Let(y, op.add(x, x), body) body = relay.Let(y, op.add(x, x), body)
f = relay.Function([], relay.Let(x, relay.const(1), body)) f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f) g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm())
h = to_a_normal_form(g) h = transform.OptimizeOnExpr(g, transform.ToANormalForm())
assert Feature.fLet in detect_feature(f) assert Feature.fLet in detect_feature(f)
assert not Feature.fLet in detect_feature(g) assert not Feature.fLet in detect_feature(g)
check_eval(f, [], 8.0) check_eval(f, [], 8.0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment