Commit 12aca82e by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] A Normal Form Canonicalization (#2251)

parent 911c3a36
......@@ -296,6 +296,26 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};
/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It turns an expression that is in graph form (with implicit sharing)
* into an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
* The scope of any non-root expression is the least common ancestor of the scopes of all its uses.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);
} // namespace relay
} // namespace tvm
......
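For context, a minimal sketch of what the new pass does, written against the Python binding (to_anf) added further down in this commit; the variable names are illustrative only, not part of the commit:

import tvm
from tvm import relay
from tvm.relay import op
from tvm.relay.ir_pass import to_anf

x = relay.const(1)
y = op.add(x, x)                      # y is shared implicitly (graph form)
f = relay.Function([], op.add(y, y))  # sharing is still implicit here
anf = to_anf(f)                       # each value is now bound once by a let
print(anf.astext())                   # the text form contains explicit lets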
......@@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
......@@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
......@@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters
----------
-t: tvm.relay.Type
+t : tvm.relay.Type
The type to check
-mod: tvm.relay.Module, optional
-The global module
+mod : Optional[tvm.relay.Module]
+The global module.
Returns
-------
......@@ -480,8 +480,35 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)
def to_anf(expr, mod=None):
"""
Turn a Graph Normal Form expression into an A Normal Form expression.
The scope of the root expression is the global scope.
The scope of any non-root expression is the least common ancestor of the scopes of all its uses.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_anf(expr, mod)
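As a hedged illustration of the scope rule in the docstring (a hypothetical snippet, not part of the commit): a value used in both branches of an If has the If's enclosing scope as the least common ancestor of its uses, so its let binding lands above the If, while branch-local values stay inside their branch:

from tvm import relay
from tvm.relay import op
from tvm.relay.ir_pass import to_anf

cond = relay.Var('cond', relay.TensorType((), 'bool'))
shared = op.add(relay.const(1), relay.const(2))   # used in both branches
body = relay.If(cond,
                op.add(shared, relay.const(10)),  # true branch only
                op.add(shared, relay.const(20)))  # false branch only
print(to_anf(relay.Function([cond], body)).astext())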
def gradient(expr, mod=None):
""".
"""
Transform a function to return original result paired with gradient of input.
Parameters
----------
......@@ -489,11 +516,10 @@ def gradient(expr, mod=None):
The input expression, which is a Function or a GlobalVar.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
-ret : tvm.relay.Expr
-A function that calculate the original result paired with gradient.
+expr : tvm.relay.Expr
+The output expression.
"""
return _ir_pass.first_order_gradient(expr, mod)
......@@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
+CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
......@@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
-Expr Get(const Expr& body) const {
+Expr Get(const Expr& body) {
+CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
+used_ = true;
return ret;
}
......@@ -108,6 +111,7 @@ class LetList {
private:
std::vector<std::pair<Var, Expr> > lets_;
+bool used_ = false;
};
} // namespace relay
......
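The new used_ flag makes a LetList single-use: once Get has wrapped the body in the recorded bindings, any further Push or Get trips the CHECK. A rough Python analogue of Get's wrapping loop (a sketch; wrap_lets is a hypothetical helper, not part of the commit):

from tvm import relay

def wrap_lets(lets, body):
    # Fold the recorded (var, expr) pairs over the body in reverse,
    # so the first binding pushed becomes the outermost Let.
    ret = body
    for var, expr in reversed(lets):
        ret = relay.Let(var, expr, ret)
    return ret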
......@@ -62,9 +62,9 @@ def test_recursion():
relay.Call(f, [subtract(n, relay.const(1.0)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
-orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
+orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
assert alpha_equal(dead_code_elimination(orig), orig)
-assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
+assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
def test_op_let():
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
result = intrp.evaluate(expr)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def test_explicit_bound():
x = relay.const(1)
y = op.add(x, x)
z = op.add(y, y)
f = relay.Function([], op.add(z, z))
assert not "let" in f.astext() # assert the values are implicitly bounded
anf = to_anf(f)
assert "let" in anf.astext() # assert the values are explicitly bounded
check_eval(f(), 8.0)
check_eval(anf(), 8.0)
# test that the construction order does not matter:
# values are ordered by scope and by post-DFS ordering.
def test_order():
z = relay.const(3)
y = relay.const(2)
x = relay.const(1)
val = x + y * z
check_eval(val, 7.0)
anf = infer_type(to_anf(val))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
d = relay.Var('d', relay.IncompleteType())
e = relay.Var('e', relay.IncompleteType())
expected_output = e
expected_output = relay.Let(e, a + d, expected_output)
expected_output = relay.Let(d, b * c, expected_output)
expected_output = relay.Let(c, z, expected_output)
expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output)
expected_output = infer_type(expected_output)
assert alpha_equal(anf, expected_output)
def test_if():
cond = relay.const(True)
x = relay.If(cond, relay.const(2), relay.const(3))
anf = infer_type(to_anf(x))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
d = relay.Var('d', relay.IncompleteType())
true_branch = relay.Let(a, relay.const(2), a)
false_branch = relay.Let(b, relay.const(3), b)
expected_output = relay.If(c, true_branch, false_branch)
expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output)
expected_output = infer_type(expected_output)
assert alpha_equal(anf, expected_output)
# make sure we don't loop infinitely.
# the result is too large, so we won't check for the exact program.
def test_recursion():
"""
Program:
let sum_twice(n: i32) -> i32 = {
m = (n * 2)
if (n == 0) {
return m;
} else {
return m + sum_twice(n - 1);
}
}
sum_twice(5);
"""
return  # cannot be run as fuse_ops needs to recursively visit
mod = relay.Module()
i64 = relay.TensorType((), 'int64')
f = relay.GlobalVar("f")
n = relay.Var("n", i64)
m = n * relay.const(2, 'int64')
funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
m,
m + f(n - relay.const(1, 'int64')))
value = relay.Function([n], funcbody, i64, [])
mod[f] = value
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
old_f = mod[f]
f = to_anf(f, mod=mod)
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
if __name__ == '__main__':
test_explicit_bound()
test_order()
test_if()
test_recursion()