Commit f2a6851a by Zhi, committed by Jared Roesch

[Relay][Pass] Only allow Module -> Module for opts managed by pass infra (#3430)

* [Relay][Pass] Only allow Module -> Module for opts managed by pass infra

* revert gradient pass
parent 6c81d784
......@@ -141,23 +141,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \param e The original function.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return the new function with abstraction
*/
TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
/*!
* \brief Check that each Var is only bound once.
*
 * For example, the expression `let x = 1 in let x = 2 in 3` binds x twice.
......@@ -288,24 +271,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*! \brief Remove expressions which do not affect the program result.
*
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
 * \param inline_once whether or not to inline bindings that are used only once.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
/*!
* \brief Fold constant expressions.
*
......@@ -388,38 +353,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
 * It will turn an expression that is in graph form (with implicit sharing)
 * into an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
 * The scope of any non-root expression is the least common ancestor of all its scopes.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form.
*/
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
/*!
* \brief Remove let binding and directly share via pointer instead.
*
 * It will remove all let bindings,
 * and turn all variables bound by let into direct pointer references.
*
* \param e the expression.
*
* \return the expression in graph normal form.
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);
/*!
* \brief Finds cases that the given match expression does not catch, if any.
*
* \param match the match expression to test
......@@ -432,28 +365,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
 * It will do as much computation at compile time as possible.
 * It has two benefits: removing runtime overhead and allowing more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
/*
* \brief Bind function parameters or free variables.
* \brief Bind the free variables to a Relay expression.
*
* Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions.
*
 * \param expr The expression whose free variables are to be bound.
 * \param binds The map from variables to the expressions they bind to.
*
* \return The expression with all free vars bound.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
......
......@@ -541,6 +541,33 @@ TVM_DLL Pass AlterOpLayout();
*/
TVM_DLL Pass CanonicalizeCast();
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \return The pass.
*/
TVM_DLL Pass EtaExpand();
/*!
 * \brief A helper function that runs optimization passes on a given expression
 * and returns the optimized version. With this helper, users don't need to
 * manually construct a module, run the passes, and then extract the target
 * function/expression from the returned module.
*
* \param expr The expression to be optimized.
 * \param passes The passes that will be applied to the given expression.
*
* \return The optimized expression.
*/
TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes);
} // namespace transform
} // namespace relay
} // namespace tvm
......
......@@ -84,23 +84,6 @@ def backward_fold_scale_axis(expr):
"""
return _ir_pass.backward_fold_scale_axis(expr)
def eta_expand(expr, mod):
"""Add abstraction over a function.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
mod : tvm.relay.Module
The global module.
Returns
-------
expanded_expr : tvm.relay.Expr
The expression after eta expansion.
"""
return _ir_pass.eta_expand(expr, mod)
def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
......@@ -318,25 +301,6 @@ def canonicalize_ops(expr):
return _ir_pass.canonicalize_ops(expr)
def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
expr : tvm.relay.Expr
The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns
-------
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr, inline_once)
def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
......@@ -534,46 +498,6 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)
def to_a_normal_form(expr, mod=None):
"""
Turn a Graph Normal Form expression into an A Normal Form expression.
The scope of the root expression is the global scope.
The scope of any non-root expression is the least common ancestor of all its scopes.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
def to_graph_normal_form(expr):
"""Turn A Normal Form expression into Graph Normal Form expression
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
......@@ -642,26 +566,6 @@ def eliminate_common_subexpr(expr, fskip=None):
return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None):
"""
Finds cases that the match expression does not catch, if any.
......
......@@ -302,15 +302,20 @@ def CanonicalizeOps():
return _transform.CanonicalizeOps()
def DeadCodeElimination():
""" Remove expressions which does not effect the program result (dead code).
def DeadCodeElimination(inline_once=False):
"""Remove expressions which does not effect the program result (dead code).
Parameters
----------
inline_once: Optional[Bool]
Whether to inline binding that occurs only once.
Returns
-------
ret: tvm.relay.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return _transform.DeadCodeElimination()
return _transform.DeadCodeElimination(inline_once)
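A minimal usage sketch (not part of this change's test suite; the names below are illustrative), combining the new `inline_once` argument with the `OptimizeOnExpr` helper added further down in this file:

```python
from tvm import relay
from tvm.relay import transform

# let a = 1.0 in 2.0  -- the binding `a` is unused, so DCE drops it; with
# inline_once=True a binding used exactly once would be inlined instead.
a = relay.var("a", shape=(), dtype="float32")
body = relay.Let(a, relay.const(1.0), relay.const(2.0))
dced = transform.OptimizeOnExpr(body, transform.DeadCodeElimination(inline_once=True))
```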
def FoldConstant():
......@@ -406,6 +411,7 @@ def ToANormalForm():
"""
return _transform.ToANormalForm()
def EtaExpand():
"""Add abstraction over a function
......@@ -416,6 +422,7 @@ def EtaExpand():
"""
return _transform.EtaExpand()
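As a rough sketch of how the registered pass is applied under the new Module -> Module convention (the identity function here is only a placeholder):

```python
from tvm import relay
from tvm.relay import transform

mod = relay.Module()
x = relay.var("x", shape=(), dtype="float32")
mod[mod.entry_func] = relay.Function([x], x)
# EtaExpand is managed by the pass infra, so it runs Module -> Module.
mod = transform.EtaExpand()(mod)
```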
def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression
......@@ -449,7 +456,7 @@ def PartialEvaluate():
Returns
-------
ret : tvm.relay.Pass
ret: tvm.relay.Pass
The registered pass that performs partial evaluation on an expression.
"""
return _transform.PartialEvaluate()
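Since partial evaluation is now a module-to-module pass, a hedged sketch of the new calling convention (the module contents are made up for illustration):

```python
from tvm import relay
from tvm.relay import transform

mod = relay.Module()
# Entry function whose static fragment should be evaluated at compile time.
mod[mod.entry_func] = relay.Function([], relay.const(1.0) + relay.const(2.0))
mod = transform.PartialEvaluate()(mod)
evaluated = mod[mod.entry_func]
```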
......@@ -465,6 +472,31 @@ def CanonicalizeCast():
"""
return _transform.CanonicalizeCast()
def OptimizeOnExpr(expr, passes):
"""Perform optimization passes on an expressioin.
Parameters
----------
expr: tvm.relay.Expr
The expression for optimization.
passes: Union[Pass, List[Pass]]
The pass or list of passes to be applied.
Returns
-------
ret: tvm.relay.Expr
The optimized expression.
"""
if isinstance(passes, Pass):
passes = [passes]
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a pass or a list of pass objects.")
return _transform.OptimizeOnExpr(expr, passes)
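A short sketch of the intended usage with a list of passes (the expression is a stand-in; the updated tests below exercise the helper on real cases):

```python
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(10,), dtype="float32")
func = relay.Function([x], x + x)
opt = transform.OptimizeOnExpr(func, [transform.InferType(),
                                      transform.ToANormalForm()])
```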
def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyModulePass(ModulePass):
......
......@@ -156,9 +156,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once);
}
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination);
namespace transform {
Pass DeadCodeElimination(bool inline_once) {
......
......@@ -1086,27 +1086,30 @@ Expr PostProcess(const Expr& e) {
} // namespace partial_eval
Expr PartialEval(const Expr& e, const Module& m) {
return TransformF([&](const Expr& e) {
Module PartialEval(const Module& m) {
CHECK(m->entry_func.defined());
auto func = m->Lookup(m->entry_func);
Expr ret =
TransformF([&](const Expr& e) {
return LetList::With([&](LetList* ll) {
relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
pe.InitializeFuncId(e);
return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
});
}, e);
relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
pe.InitializeFuncId(e);
return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
});
}, func);
CHECK(ret->is_type<FunctionNode>());
m->Update(m->entry_func, Downcast<Function>(ret));
return m;
}
TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
.set_body_typed(PartialEval);
namespace transform {
Pass PartialEval() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(PartialEval(f, m));
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return PartialEval(m);
};
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
}
TVM_REGISTER_API("relay._transform.PartialEvaluate")
......
......@@ -573,6 +573,18 @@ class PassContext::Internal {
}
};
Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes) {
// Wrap the expression in a fresh module so that Module -> Module passes can run on it.
auto mod = ModuleNode::FromExpr(expr);
Sequential seq(passes);
// Execute the sequence at the highest optimization level within a scoped pass context.
auto pass_ctx = PassContext::Create();
pass_ctx->opt_level = 3;
tvm::With<PassContext> ctx_scope(pass_ctx);
mod = seq(mod);
CHECK(mod.defined());
auto entry_func = mod->Lookup(mod->entry_func);
// If the input was not already a function, return only the body of the optimized entry function.
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
}
TVM_REGISTER_API("relay._transform.GetCurrentPassContext")
.set_body_typed(PassContext::Current);
......@@ -582,6 +594,9 @@ TVM_REGISTER_API("relay._transform.EnterPassContext")
TVM_REGISTER_API("relay._transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope);
TVM_REGISTER_API("relay._transform.OptimizeOnExpr")
.set_body_typed(OptimizeOnExpr);
} // namespace transform
} // namespace relay
} // namespace tvm
......@@ -26,6 +26,8 @@
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include "let_list.h"
#include "../../common/arena.h"
......@@ -35,10 +37,6 @@
namespace tvm {
namespace relay {
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
......@@ -104,29 +102,21 @@ bool IsPrimitiveFunction(const Expr& e) {
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public:
static Expr ToANormalForm(const Expr& e,
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
Fill fi(m, dg, node_scope, gv);
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) {
Fill fi(dg, node_scope);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
}
private:
Module mod_;
const DependencyGraph& dg_;
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
Fill(Module mod,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
mod_(mod),
Fill(const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope) :
dg_(dg),
node_scope_(node_scope),
visited_(visited) { }
node_scope_(node_scope) { }
Scope GetScope(const Expr& e) {
return node_scope_->at(dg_.expr_node.at(e));
......@@ -246,10 +236,6 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
GlobalVar gv = GetRef<GlobalVar>(gvn);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
}
return Atomic(gv, gv, v);
}
......@@ -276,9 +262,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
};
Expr ToANormalFormAux(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
Expr ToANormalFormAux(const Expr& e) {
/* When you lift a lambda, what is inside is also lifted.
*
* So we must determine the scope of the lambda before determining the scope of its body.
......@@ -301,46 +285,40 @@ Expr ToANormalFormAux(const Expr& e,
* We do an additional pass to fill all the LetList and we are done.
*/
std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
return Fill::ToANormalForm(e, dg, &node_scope);
}
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
DLOG(INFO)
<< "ToANF:" << std::endl
<< AsText(e, false);
Expr ret =
TransformF([&](const Expr& e) {
return ToANormalFormAux(e, m, gv);
}, e);
CHECK_EQ(FreeVars(ret).size(), 0);
Module ToANormalForm(const Module& m) {
DLOG(INFO) << "ToANF:" << std::endl << m;
tvm::Map<GlobalVar, Function> updates;
auto funcs = m->functions;
for (const auto& it : funcs) {
Expr ret =
TransformF([&](const Expr& e) {
return ToANormalFormAux(e);
}, it.second);
CHECK_EQ(FreeVars(ret).size(), 0);
updates.Set(it.first, Downcast<Function>(ret));
}
DLOG(INFO)
<< "ToANF: transformed" << std::endl
<< AsText(ret, false);
for (auto pair : updates) {
m->Add(pair.first, pair.second, true);
}
return ret;
}
DLOG(INFO) << "ToANF: transformed" << std::endl << m;
Expr ToANormalForm(const Expr& e, const Module& m) {
std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
return ToANormalForm(e, m, &gv);
return m;
}
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
namespace transform {
Pass ToANormalForm() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(ToANormalForm(f, m));
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return ToANormalForm(m);
};
return CreateFunctionPass(pass_func, 1, "ToANormalForm", {});
return CreateModulePass(pass_func, 1, "ToANormalForm", {});
}
TVM_REGISTER_API("relay._transform.ToANormalForm")
......
......@@ -24,8 +24,8 @@
*
* \brief Turn A normal form into graph normal form.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include "let_list.h"
namespace tvm {
......@@ -76,9 +76,6 @@ Expr ToGraphNormalForm(const Expr& e) {
return GNF()(e);
}
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
.set_body_typed(ToGraphNormalForm);
namespace transform {
Pass ToGraphNormalForm() {
......
......@@ -18,20 +18,13 @@ from nose.tools import nottest
import tvm
from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from tvm.relay import Function, transform
from tvm.relay.ir_pass import alpha_equal, graph_equal, free_vars
from tvm.relay.op import log, add, equal, subtract
class env:
def __init__(self):
self.a = relay.Var("a")
self.b = relay.Var("b")
self.c = relay.Var("c")
self.d = relay.Var("d")
self.e = relay.Var("e")
self.x = relay.Var("x")
self.y = relay.Var("y")
self.z = relay.Var("z")
self.shape = tvm.convert([1, 2, 3])
self.tt = relay.TensorType(self.shape, "float32")
self.int32 = relay.TensorType([], "int32")
......@@ -39,6 +32,14 @@ class env:
self.one = relay.const(1.0)
self.two = relay.const(2.0)
self.three = relay.const(3.0)
self.a = relay.Var("a", self.float32)
self.b = relay.Var("b", self.float32)
self.c = relay.Var("c", self.float32)
self.d = relay.Var("d", self.float32)
self.e = relay.Var("e", self.float32)
self.x = relay.Var("x", self.int32)
self.y = relay.Var("y", self.int32)
self.z = relay.Var("z", self.int32)
e = env()
......@@ -46,22 +47,27 @@ e = env()
def test_let():
orig = relay.Let(e.x, e.y, e.z)
assert alpha_equal(dead_code_elimination(orig), e.z)
orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c))
orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
expected = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
@nottest
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
assert alpha_equal(dead_code_elimination(orig), e.d)
orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
assert alpha_equal(dead_code_elimination(orig), e.e)
orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
# make sure we don't loop infinitely
......@@ -78,27 +84,39 @@ def test_recursion():
f(2, 10000);
"""
f = relay.Var("f")
f1 = relay.Var("f1")
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, relay.const(0)),
data,
relay.Call(f, [subtract(n, relay.const(1.0)),
relay.Call(f1, [subtract(n, relay.const(1)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
orig = transform.OptimizeOnExpr(orig, transform.InferType())
assert graph_equal(dced, orig)
dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three),
transform.DeadCodeElimination())
assert alpha_equal(dced, e.three)
def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
dced = transform.OptimizeOnExpr(add(relay.Let(e.a, e.one, e.three), e.two),
transform.DeadCodeElimination())
assert alpha_equal(dced, add(e.three, e.two))
def test_tuple_get_item():
t = relay.Var('t')
tt = relay.TupleType([e.float32, e.float32])
t = relay.Var('t', tt)
a = relay.Var('a')
g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g)
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g)
dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
if __name__ == "__main__":
......
......@@ -18,17 +18,13 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
from tvm.relay.ir_pass import gradient
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.ir_pass import alpha_equal, gradient
from tvm.relay.prelude import Prelude
from tvm.relay import create_executor
from nose.tools import nottest
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
from tvm.relay import GlobalVar, Call, Type
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
from tvm.relay import GlobalVar, Call
from tvm.relay.testing import add_nat_definitions, make_nat_expr
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
......@@ -38,8 +34,25 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def dcpe(expr, mod=None):
return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
def tipe(expr):
return transform.OptimizeOnExpr(expr,
[transform.InferType(),
transform.PartialEvaluate(),
transform.InferType()])
def dcpe(expr, mod=None, grad=False):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
if grad:
expr = gradient(expr)
if mod:
assert isinstance(expr, Function)
mod[mod.entry_func] = expr
seq = transform.Sequential(passes)
mod = seq(mod)
return mod[mod.entry_func]
return transform.OptimizeOnExpr(expr, passes)
def test_tuple():
......@@ -47,24 +60,31 @@ def test_tuple():
x = Var("x", t)
body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
expected = relay.Function([x], x, None, [t])
expected = transform.OptimizeOnExpr(expected, transform.InferType())
assert alpha_equal(dcpe(f), expected)
def test_const_inline():
d = Var("d")
t = relay.TensorType([], "float32")
d = Var("d", t)
double = Function([d], d + d)
orig = double(const(4.0))
assert alpha_equal(dcpe(orig), const(8.0))
def test_ref():
d = relay.Var("d")
r = relay.Var("r")
t = relay.TensorType([], "float32")
d = relay.Var("d", t)
r = relay.Var("r", relay.RefType(t))
x = relay.Var("x")
body = relay.RefRead(r)
body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
body = Let(r, RefCreate(d), body)
square = Function([d], body)
assert alpha_equal(dcpe(square), Function([d], d * d))
expected = transform.OptimizeOnExpr(Function([d], d * d),
transform.InferType())
assert alpha_equal(dcpe(square), expected)
def test_empty_ad():
......@@ -73,17 +93,19 @@ def test_empty_ad():
t = TensorType(shape, dtype)
d = Var("d", t)
f = Function([d], d)
g = dcpe(gradient(f))
g = dcpe(f, grad=True)
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
expected = transform.OptimizeOnExpr(expected, transform.InferType())
assert alpha_equal(g, expected)
def test_ad():
shape = (10, 10)
dtype = "float32"
t = TensorType(shape, dtype)
d = Var("d", t)
f = Function([d], d * d)
g = dcpe(gradient(f))
g = dcpe(f, grad=True)
m = d * d
x = relay.Var("x")
o = op.ones_like(x)
......@@ -92,6 +114,7 @@ def test_ad():
body = Tuple([x, Tuple([grad])])
body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body))
expected = transform.OptimizeOnExpr(expected, transform.InferType())
assert alpha_equal(g, expected)
......@@ -107,8 +130,7 @@ def test_if_ref():
eff = Var("eff")
body = Let(eff, body, RefRead(r))
f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
pe_f = tipe(f)
ex = create_executor()
f_res = ex.evaluate(f)(const(True))
pe_f_res = ex.evaluate(pe_f)(const(True))
......@@ -132,8 +154,7 @@ def test_function_invalidate():
body = Let(fet, fetch, body)
body = Let(r, RefCreate(const(0)), body)
f = Function([d], body)
f = infer_type(f)
pe_f = infer_type(partial_evaluate(f))
pe_f = tipe(f)
ex = create_executor()
f_res = ex.evaluate(f)(const(True))
pe_f_res = ex.evaluate(pe_f)(const(True))
......@@ -144,35 +165,30 @@ def test_function_invalidate():
def test_head_cons():
mod = Module()
p = Prelude(mod)
def hd_impl():
a = TypeVar("a")
x = Var("x", p.l(a))
y = Var("y")
z = Var("z")
cons_case = Clause(PatternConstructor(p.cons,
[PatternVar(y),
PatternVar(z)]),
y)
y = Var("y")
z = Var("z")
return Function([x], Match(x, [cons_case]), a, [a])
hd = p.hd
t = TypeVar("t")
x = Var("x", t)
hd = Var("hd")
body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
body = hd(p.cons(x, p.nil()))
f = Function([x], body, None, [t])
f = infer_type(f, mod=mod)
res = dcpe(f)
res = dcpe(f, mod)
assert alpha_equal(res, Function([x], x, t, [t]))
def test_map():
mod = Module()
p = Prelude(mod)
f = Var("f")
f = GlobalVar("f")
t = TypeVar("t")
a = Var("a", t)
mod[f] = Function([a], a, t, [t])
orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil())))
assert alpha_equal(dcpe(orig, mod=mod), expected)
expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil())))
expected = Function([], expected)
mod[mod.entry_func] = expected
expected = mod[mod.entry_func]
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, expected.body)
def test_loop():
......@@ -181,9 +197,12 @@ def test_loop():
x = Var("x", t)
loop = GlobalVar("loop")
mod[loop] = Function([x], loop(x), t, [t])
res = dcpe(loop(const(1)), mod=mod)
expected = Call(loop, [const(1)], None, [None])
assert alpha_equal(res, expected)
expected = Call(loop, [const(1)])
mod[mod.entry_func] = Function([], expected)
expected = mod[mod.entry_func].body
call = Function([], loop(const(1)))
res = dcpe(call, mod=mod)
assert alpha_equal(res.body, expected)
def test_swap_loop():
......@@ -196,8 +215,9 @@ def test_swap_loop():
loop = GlobalVar("loop")
mod[loop] = Function([x, y], loop(y, x), nat)
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
res = dcpe(prog, mod=mod)
assert alpha_equal(prog, res)
res = Function([], prog)
res = dcpe(res, mod=mod)
assert alpha_equal(prog, res.body)
def test_abs_diff():
......@@ -217,8 +237,9 @@ def test_abs_diff():
x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 4))
assert alpha_equal(res.body, make_nat_expr(p, 4))
def test_match_nat_id():
......@@ -233,8 +254,9 @@ def test_match_nat_id():
s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3))
assert alpha_equal(res.body, make_nat_expr(p, 3))
def test_nat_id():
......@@ -247,8 +269,9 @@ def test_nat_id():
nat_id = GlobalVar("nat_id")
mod[nat_id] = Function([x], x)
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3))
assert alpha_equal(res.body, make_nat_expr(p, 3))
def test_global_match_nat_id():
......@@ -260,8 +283,9 @@ def test_global_match_nat_id():
z_case = Clause(PatternConstructor(p.z, []), p.z())
s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
orig = Match(make_nat_expr(p, 3), [z_case, s_case])
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3))
assert alpha_equal(res.body, make_nat_expr(p, 3))
def test_double():
......@@ -269,8 +293,9 @@ def test_double():
p = Prelude(mod)
add_nat_definitions(p)
orig = p.double(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 6))
assert alpha_equal(res.body, make_nat_expr(p, 6))
if __name__ == '__main__':
......
......@@ -17,9 +17,8 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.ir_pass import alpha_equal, detect_feature
from tvm.relay import op, create_executor, transform
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count
from tvm.relay.feature import Feature
......@@ -39,7 +38,7 @@ def test_explicit_bound():
z = op.add(y, y)
f = relay.Function([], op.add(z, z))
assert not Feature.fLet in detect_feature(f)
anf = to_a_normal_form(f)
anf = transform.OptimizeOnExpr(f, transform.ToANormalForm())
assert Feature.fLet in detect_feature(anf)
check_eval(f(), 8.0)
check_eval(anf(), 8.0)
......@@ -53,7 +52,8 @@ def test_order():
x = relay.const(1)
val = x + y * z
check_eval(val, 7.0)
anf = infer_type(to_a_normal_form(val))
anf = transform.OptimizeOnExpr(val, [transform.ToANormalForm(),
transform.InferType()])
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
......@@ -65,14 +65,16 @@ def test_order():
expected_output = relay.Let(c, z, expected_output)
expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output)
expected_output = infer_type(expected_output)
expected_output = transform.OptimizeOnExpr(expected_output,
transform.InferType())
assert alpha_equal(anf, expected_output)
def test_if():
cond = relay.const(True)
x = relay.If(cond, relay.const(2), relay.const(3))
anf = infer_type(to_a_normal_form(x))
anf = transform.OptimizeOnExpr(x, [transform.ToANormalForm(),
transform.InferType()])
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
......@@ -82,7 +84,8 @@ def test_if():
expected_output = relay.If(c, true_branch, false_branch)
expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output)
expected_output = infer_type(expected_output)
expected_output = transform.OptimizeOnExpr(expected_output,
transform.InferType())
assert alpha_equal(anf, expected_output)
......@@ -114,7 +117,8 @@ def test_recursion():
mod[f] = value
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
old_f = mod[f]
f = to_a_normal_form(f, mod=mod)
mod = transform.ToANormalForm()(mod)
f = mod[f]
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
......@@ -129,7 +133,8 @@ def test_ref():
body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
check_eval(body, 3)
check_eval(to_a_normal_form(body), 3)
opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm())
check_eval(opt_body, 3)
def test_nat_add():
......@@ -144,7 +149,12 @@ def test_nat_add():
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
expr = add(s(z()), s(z()))
f = relay.GlobalVar("f")
mod[f] = relay.Function([], expr)
mod = transform.ToANormalForm()(mod)
expr = mod["f"]
assert count(p, intrp.evaluate(expr.body)) == 2
assert Feature.fLet in detect_feature(mod[add])
......@@ -155,14 +165,16 @@ def test_let():
body = relay.Let(y, x, x + y)
body = relay.Let(x, d, body)
check_eval(body, 8)
check_eval(to_a_normal_form(body), 8)
opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm())
check_eval(opt_body, 8)
def test_function():
x = relay.Var("x")
t = relay.TensorType((), 'float32')
x = relay.Var("x", t)
f = relay.Function([x], x + x)
d = relay.const(4.0, 'float32')
anf_f = to_a_normal_form(f)
anf_f = transform.OptimizeOnExpr(f, transform.ToANormalForm())
assert isinstance(anf_f, relay.Function)
check_eval(f(d), 8)
check_eval(anf_f(d), 8)
......
......@@ -17,10 +17,9 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature
from tvm.relay import op, create_executor
from tvm.relay import op, create_executor, transform
from tvm.relay.ir_pass import detect_feature
from tvm.relay.feature import Feature
from tvm.relay.backend.interpreter import Value, TupleValue
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
......@@ -41,9 +40,9 @@ def test_implicit_share():
body = relay.Let(z, op.add(y, y), op.add(z, z))
body = relay.Let(y, op.add(x, x), body)
f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f)
assert "let" in f.astext()
assert not "let" in g.astext()
g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm())
assert Feature.fLet in detect_feature(f)
assert not Feature.fLet in detect_feature(g)
check_eval(f, [], 8.0)
check_eval(g, [], 8.0)
......@@ -55,8 +54,8 @@ def test_round_trip():
body = relay.Let(z, op.add(y, y), op.add(z, z))
body = relay.Let(y, op.add(x, x), body)
f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f)
h = to_a_normal_form(g)
g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm())
h = transform.OptimizeOnExpr(g, transform.ToANormalForm())
assert Feature.fLet in detect_feature(f)
assert not Feature.fLet in detect_feature(g)
check_eval(f, [], 8.0)
......