Commit 12aca82e by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] A Normal Form Canonicalization (#2251)

parent 911c3a36
...@@ -296,6 +296,26 @@ struct StructuralHash { ...@@ -296,6 +296,26 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const; size_t operator()(const Expr& expr) const;
}; };
/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit): ...@@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
---------- ----------
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
fvisit : function fvisit : function
The visitor function to be applied. The visitor function to be applied.
""" """
...@@ -35,7 +36,6 @@ def infer_type(expr, mod=None): ...@@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module] mod: Optional[tvm.relay.Module]
The global module. The global module.
Returns Returns
------- -------
checked_expr : tvm.relay.Expr checked_expr : tvm.relay.Expr
...@@ -112,11 +112,11 @@ def check_kind(t, mod=None): ...@@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters Parameters
---------- ----------
t: tvm.relay.Type t : tvm.relay.Type
The type to check The type to check
mod: tvm.relay.Module, optional mod : Optional[tvm.relay.Module]
The global module The global module.
Returns Returns
------- -------
...@@ -480,8 +480,35 @@ def collect_device_annotation_ops(expr): ...@@ -480,8 +480,35 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr) return _ir_pass.CollectDeviceAnnotationOps(expr)
def to_anf(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_anf(expr, mod)
def gradient(expr, mod=None): def gradient(expr, mod=None):
""". """
Transform a function to return original result paired with gradient of input.
Parameters Parameters
---------- ----------
...@@ -489,11 +516,10 @@ def gradient(expr, mod=None): ...@@ -489,11 +516,10 @@ def gradient(expr, mod=None):
The input expression, which is a Function or a GlobalVar. The input expression, which is a Function or a GlobalVar.
mod : Optional[tvm.relay.Module] mod : Optional[tvm.relay.Module]
The global module.
Returns Returns
------- -------
ret : tvm.relay.Expr expr : tvm.relay.Expr
A function that calculate the original result paired with gradient. The output expression.
""" """
return _ir_pass.first_order_gradient(expr, mod) return _ir_pass.first_order_gradient(expr, mod)
...@@ -36,6 +36,7 @@ class LetList { ...@@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr. * \return a Var that hold the inserted expr.
*/ */
Var Push(Var pv, Expr expr) { Var Push(Var pv, Expr expr) {
CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr)); lets_.emplace_back(std::make_pair(pv, expr));
return pv; return pv;
} }
...@@ -71,11 +72,13 @@ class LetList { ...@@ -71,11 +72,13 @@ class LetList {
* *
* \return the wrapped expr. * \return the wrapped expr.
*/ */
Expr Get(const Expr& body) const { Expr Get(const Expr& body) {
CHECK(!used_);
Expr ret = body; Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret); ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
} }
used_ = true;
return ret; return ret;
} }
...@@ -108,6 +111,7 @@ class LetList { ...@@ -108,6 +111,7 @@ class LetList {
private: private:
std::vector<std::pair<Var, Expr> > lets_; std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
}; };
} // namespace relay } // namespace relay
......
...@@ -62,9 +62,9 @@ def test_recursion(): ...@@ -62,9 +62,9 @@ def test_recursion():
relay.Call(f, [subtract(n, relay.const(1.0)), relay.Call(f, [subtract(n, relay.const(1.0)),
log(data)])) log(data)]))
value = relay.Function([n, data], funcbody, e.float32, []) value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
assert alpha_equal(dead_code_elimination(orig), orig) assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three) assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
def test_op_let(): def test_op_let():
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
result = intrp.evaluate(expr)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def test_explicit_bound():
x = relay.const(1)
y = op.add(x, x)
z = op.add(y, y)
f = relay.Function([], op.add(z, z))
assert not "let" in f.astext() # assert the values are implicitly bounded
anf = to_anf(f)
assert "let" in anf.astext() # assert the values are explicitly bounded
check_eval(f(), 8.0)
check_eval(anf(), 8.0)
# test that the construction order does not matter,
# and is instead ordered by the scope and by post-dfs ordering.
def test_order():
z = relay.const(3)
y = relay.const(2)
x = relay.const(1)
val = x + y * z
check_eval(val, 7.0)
anf = infer_type(to_anf(val))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
d = relay.Var('d', relay.IncompleteType())
e = relay.Var('e', relay.IncompleteType())
expected_output = e
expected_output = relay.Let(e, a + d, expected_output)
expected_output = relay.Let(d, b * c, expected_output)
expected_output = relay.Let(c, z, expected_output)
expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output)
expected_output = infer_type(expected_output)
assert alpha_equal(anf, expected_output)
def test_if():
cond = relay.const(True)
x = relay.If(cond, relay.const(2), relay.const(3))
anf = infer_type(to_anf(x))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
d = relay.Var('d', relay.IncompleteType())
true_branch = relay.Let(a, relay.const(2), a)
false_branch = relay.Let(b, relay.const(3), b)
expected_output = relay.If(c, true_branch, false_branch)
expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output)
expected_output = infer_type(expected_output)
assert alpha_equal(anf, expected_output)
# make sure we dont infinite loop.
# it is too large so we wont check for the exact program.
def test_recursion():
"""
Program:
let sum_twice(n: i32) -> i32 = {
m = (n * 2)
if (n == 0) {
return m;
} else {
return m + sum(n - 1);
}
}
sum_twice(5);
"""
return # cannot be run as fuse_ops need to recursively visit
mod = relay.Module()
i64 = relay.TensorType((), 'int64')
f = relay.GlobalVar("f")
n = relay.Var("n", i64)
m = n * relay.const(2, 'int64')
funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
m,
m + f(n - relay.const(1, 'int64')))
value = relay.Function([n], funcbody, i64, [])
mod[f] = value
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
old_f = mod[f]
f = to_anf(f, mod=mod)
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
if __name__ == '__main__':
test_explicit_bound()
test_order()
test_if()
test_recursion()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment