Commit 12aca82e by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] A Normal Form Canonicalization (#2251)

parent 911c3a36
......@@ -296,6 +296,26 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};
/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
* The scope of any non root expression is the least common ancestor of all its scopes.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);
} // namespace relay
} // namespace tvm
......
......@@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
......@@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
......@@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters
----------
t: tvm.relay.Type
t : tvm.relay.Type
The type to check
mod: tvm.relay.Module, optional
The global module
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
......@@ -480,8 +480,35 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)
def to_anf(expr, mod=None):
    """
    Turn a Graph Normal Form expression into an A Normal Form expression.

    The scope of the root expression is the global scope.

    The scope of any non root expression is the least common ancestor of
    all its scopes.

    Values are ordered by post-DFS order in each scope.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    mod : Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    expr : tvm.relay.Expr
        The output expression in A Normal Form.
    """
    return _ir_pass.to_anf(expr, mod)
def gradient(expr, mod=None):
""".
"""
Transform a function to return original result paired with gradient of input.
Parameters
----------
......@@ -489,11 +516,10 @@ def gradient(expr, mod=None):
The input expression, which is a Function or a GlobalVar.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
ret : tvm.relay.Expr
A function that calculate the original result paired with gradient.
expr : tvm.relay.Expr
The output expression.
"""
return _ir_pass.first_order_gradient(expr, mod)
......@@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
  // A LetList that has already been materialized by Get() must not grow.
  CHECK(!used_);
  // Construct the (var, expr) pair in place.
  lets_.emplace_back(pv, expr);
  return pv;
}
......@@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
// NOTE(review): the scrape retained both the pre-change `const` signature and
// the post-change one; only the non-const version (which sets used_) is kept.
Expr Get(const Expr& body) {
  CHECK(!used_);
  Expr ret = body;
  // Wrap from the last binding outward so earlier bindings end up outermost.
  for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
    ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
  }
  // Consume the LetList: Push/Get after this point would trip the CHECK.
  used_ = true;
  return ret;
}
......@@ -108,6 +111,7 @@ class LetList {
private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
};
} // namespace relay
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file to_anf.cc
*
* \brief Turn implicit sharing into observable sharing.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
#include "../../common/arena.h"
namespace tvm {
namespace relay {
using common::LinkNode;
using common::LinkedList;
/* DependencyGraph tracks the inputs and outputs of each Expr.
 * Additionally, dummy nodes are created for constructs that open a new
 * scope (function bodies, let bindings, if branches), so scoping can be
 * modeled with ordinary graph dependencies.
 * It allows us to traverse the graph in reverse (consumer-first) order.
 */
class DependencyGraph {
 public:
  /*! \brief A node in the graph. */
  struct Node {
    // true if this node opens a new scope (dummy nodes for bodies/branches)
    bool new_scope = false;
    // the nodes this node depends on
    LinkedList<Node*> input;
    // the nodes that depend on this node
    LinkedList<Node*> output;
  };

  /*! \brief The node map that maps each sub-expression to its graph node. */
  std::unordered_map<Expr, Node*, NodeHash, NodeEqual> expr_node;

  /*! \brief All the nodes in post DFS order (dummy scope nodes included). */
  std::vector<Node*> post_dfs_order;

  /*!
   * \brief create a dependency graph.
   * \param arena The arena used for data allocation.
   * \param body The body of the expression to create a graph.
   */
  static DependencyGraph Create(common::Arena* arena, const Expr& body);

 private:
  class Creator;
};
// Creator of DependencyGraph.
// Walks the expression once, building a node per sub-expression plus dummy
// nodes for every construct that opens a new scope, and records post-DFS order.
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
 public:
  explicit Creator(common::Arena* arena)
    : arena_(arena) {}

  DependencyGraph Create(const Expr& body) {
    this->VisitExpr(body);
    return std::move(graph_);
  }

 private:
  /*! \brief allocator of all the internal node object */
  common::Arena* arena_;
  // The output.
  DependencyGraph graph_;

  // Record that `parent` depends on the expression `child`.
  // Visits `child` first so its graph node is guaranteed to exist.
  void Depend(DependencyGraph::Node* parent, const Expr& child) {
    VisitExpr(child);
    CHECK_NE(graph_.expr_node.count(child), 0);
    Depend(parent, graph_.expr_node[child]);
  }

  // Add a bidirectional edge: child becomes an input of parent and
  // parent an output of child. Link cells are arena-allocated.
  void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
    auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
    parent_link->value = parent;
    child->output.Push(parent_link);
    auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
    child_link->value = child;
    parent->input.Push(child_link);
  }

  // Guards against revisiting shared sub-expressions.
  std::unordered_set<Expr, NodeHash, NodeEqual> visited_;

  DependencyGraph::Node* NewNode(bool new_scope) {
    auto* ret = arena_->make<DependencyGraph::Node>();
    ret->new_scope = new_scope;
    return ret;
  }

  // Ensure a node exists for e, recurse into its children once, and append
  // e's node to post_dfs_order after all of its children.
  void VisitExpr(const Expr& e) final {
    if (visited_.count(e) == 0) {
      if (graph_.expr_node.count(e) == 0) {
        graph_.expr_node[e] = NewNode(false);
      }
      visited_.insert(e);
      ExprFunctor<void(const Expr&)>::VisitExpr(e);
      graph_.post_dfs_order.push_back(graph_.expr_node[e]);
    }
  }

  void VisitExpr_(const CallNode* c) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
    Depend(n, c->op);
    for (const auto& a : c->args) {
      Depend(n, a);
    }
  }

  void VisitExpr_(const TupleNode* t) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
    for (const auto& a : t->fields) {
      Depend(n, a);
    }
  }

  void VisitExpr_(const TupleGetItemNode* t) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
    Depend(n, t->tuple);
  }

  // The branches are wrapped in dummy new-scope nodes t/f, so values used
  // only inside a branch get that branch's scope. The dummy nodes are
  // appended to post_dfs_order after their branch contents; the If node
  // itself is appended afterwards by VisitExpr.
  void VisitExpr_(const IfNode* i) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
    DependencyGraph::Node* t = NewNode(true);
    DependencyGraph::Node* f = NewNode(true);
    Depend(n, i->cond);
    Depend(n, t);
    Depend(n, f);
    Depend(t, i->true_branch);
    Depend(f, i->false_branch);
    graph_.post_dfs_order.push_back(f);
    graph_.post_dfs_order.push_back(t);
  }

  // A function body lives in its own dummy new-scope node.
  void VisitExpr_(const FunctionNode* f) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
    DependencyGraph::Node* b = NewNode(true);
    Depend(n, b);
    Depend(b, f->body);
    graph_.post_dfs_order.push_back(b);
  }

  // Both the bound value and the body of a let share one dummy scope node.
  void VisitExpr_(const LetNode* l) final {
    DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
    DependencyGraph::Node* b = NewNode(true);
    Depend(n, b);
    Depend(b, l->value);
    Depend(b, l->body);
    graph_.post_dfs_order.push_back(b);
  }

  // Leaves: no dependencies to record (nodes are still created by VisitExpr).
  void VisitExpr_(const VarNode* v) final { }
  void VisitExpr_(const GlobalVarNode* v) final { }
  void VisitExpr_(const ConstantNode* c) final { }
  void VisitExpr_(const OpNode* o) final { }
};
// Build the dependency graph of `body`, allocating all nodes from `arena`.
DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
  Creator creator(arena);
  return creator.Create(body);
}
Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;

/* A scope in the ANF output; scopes form a tree through `parent`.
 *
 * Invariant: level is 0 exactly when parent is null;
 * otherwise level == parent->level + 1. The level field lets LCA
 * walk two scopes upward to their common ancestor.
 *
 * Each scope carries the LetList collecting the bindings placed in it.
 */
struct ScopeNode {
  size_t level = 0;
  Scope parent = nullptr;
  std::shared_ptr<LetList> ll = std::make_shared<LetList>();
  ScopeNode() { }
  explicit ScopeNode(const Scope& parent) : level(parent->level + 1), parent(parent) { }
};
// Create a fresh scope one level below s (used for function bodies,
// if branches, and let bodies).
Scope ChildScope(const Scope& s) {
  return std::make_shared<ScopeNode>(s);
}
// Least common ancestor of two scopes in the scope tree.
// Walks the deeper side up until both pointers meet; relies on the
// ScopeNode invariant that level decreases by one per parent step.
Scope LCA(Scope lhs, Scope rhs) {
  while (lhs != rhs) {
    if (lhs->level == rhs->level) {
      // Same depth but different scopes: step both up together.
      lhs = lhs->parent;
      rhs = rhs->parent;
    } else if (lhs->level > rhs->level) {
      lhs = lhs->parent;
    } else {
      rhs = rhs->parent;
    }
  }
  return lhs;
}
/* Assign a scope to every node of the dependency graph.
 *
 * Nodes are processed in reverse post-DFS order, so every consumer (output)
 * of a node already has a scope when the node itself is processed. A node's
 * scope is the least common ancestor of all its consumers' scopes, or the
 * global scope if it has no consumer (the root). Nodes that open a new
 * scope get a child scope of that result instead.
 */
std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
  std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
  Scope global_scope = std::make_shared<ScopeNode>();
  for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
    DependencyGraph::Node* n = *it;
    auto iit = n->output.head;
    Scope s;
    if (iit == nullptr) {
      // No consumers: this is the root expression; it lives in global scope.
      s = global_scope;
    } else {
      // Fold LCA over the scopes of every consumer.
      s = expr_scope.at(iit->value);
      iit = iit->next;
      for (; iit != nullptr; iit = iit->next) {
        s = LCA(s, expr_scope.at(iit->value));
      }
    }
    expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
  }
  return expr_scope;
}
// A function marked primitive is treated as opaque: ANF conversion leaves
// its body untouched (see the FunctionNode case in Fill).
bool IsPrimitiveFunction(const Expr& e) {
  if (!e.as<FunctionNode>()) {
    return false;
  }
  return Downcast<Function>(e)->IsPrimitive();
}
/* Second pass of ANF conversion: walk the expression and push a let-binding
 * for every compound sub-expression into the LetList of its assigned scope,
 * returning the variable that names it. Results are memoized so shared
 * sub-expressions are bound exactly once.
 */
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
 public:
  /*!
   * \brief Convert e to ANF using precomputed scope information.
   * \param e the expression to convert.
   * \param m the module, updated in place for referenced global functions.
   * \param dg the dependency graph of e.
   * \param node_scope the scope assigned to each graph node by CalcScope.
   * \param gv the set of GlobalVars already converted (recursion guard).
   * \return e in A-Normal Form.
   */
  static Expr ToANF(const Expr& e,
                    const Module& m,
                    const DependencyGraph& dg,
                    std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
                    std::set<GlobalVar>* gv) {
    Fill fi(m, dg, node_scope, gv);
    return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
  }

 private:
  Module mod_;
  const DependencyGraph& dg_;
  std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
  std::set<GlobalVar>* visited_;
  // Cache of converted sub-expressions: sharing in the input graph becomes
  // a single let-binding reused everywhere in the output.
  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;

  Fill(Module mod,
       const DependencyGraph& dg,
       std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
       std::set<GlobalVar>* visited) :
    mod_(mod),
    dg_(dg),
    node_scope_(node_scope),
    visited_(visited) { }

  // Scope assigned to expression e by CalcScope.
  Scope GetScope(const Expr& e) {
    return node_scope_->at(dg_.expr_node.at(e));
  }

  // Scope attached to e's i-th input edge; used to reach the dummy child
  // scopes created for function bodies, let bodies and if branches.
  Scope GetSubScope(const Expr& e, size_t i) {
    DependencyGraph::Node* n = dg_.expr_node.at(e);
    auto h = n->input.head;
    while (i != 0) {
      CHECK(h);
      --i;
      h = h->next;
    }
    CHECK(h);
    return node_scope_->at(h->value);
  }

  // Memoizing dispatch; v is the variable the result should be bound to.
  // If e was already converted, the earlier binding (and variable) wins
  // and v is ignored.
  Expr VisitExpr(const Expr& e, const Var& v) final {
    if (memo.count(e) == 0) {
      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
    }
    return memo.at(e);
  }

  // Convenience overload: bind to a fresh variable "x" of yet-unknown type.
  Expr VisitExpr(const Expr& e) {
    Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType));
    return this->VisitExpr(e, v);
  }

  // Push the binding v = now into orig's scope and return the variable,
  // which atomically stands for orig from here on.
  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
    return GetScope(orig)->ll->Push(v, now);
  }

  Expr VisitExpr_(const CallNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    std::vector<Expr> args;
    // Atomize every argument (and the callee) before rebuilding the call.
    for (const auto& a : c->args) {
      args.push_back(VisitExpr(a));
    }
    return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v);
  }

  Expr VisitExpr_(const TupleNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    std::vector<Expr> fields;
    for (const auto& a : t->fields) {
      fields.push_back(VisitExpr(a));
    }
    return Compound(e, TupleNode::make(fields), v);
  }

  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v);
  }

  // Each branch is materialized from its own sub-scope (input edges 1 and 2;
  // edge 0 is the condition), so branch-local values stay inside the branch.
  Expr VisitExpr_(const IfNode* i, const Var& v) final {
    Expr e = GetRef<Expr>(i);
    Expr ret = IfNode::make(VisitExpr(i->cond),
                            GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
                            GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
    return Compound(e, ret, v);
  }

  // Primitive functions are kept opaque; otherwise the body is normalized
  // inside the function's own sub-scope (its single input edge).
  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
    Expr e = GetRef<Expr>(f);
    Expr ret;
    if (IsPrimitiveFunction(e)) {
      ret = e;
    } else {
      ret = FunctionNode::make(f->params,
                               GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
                               f->ret_type,
                               f->type_params,
                               f->attrs);
    }
    return Compound(e, ret, v);
  }

  // Reuse the let's own variable as the binding target for its value,
  // then normalize the body within the let's sub-scope.
  Expr VisitExpr_(const LetNode* l, const Var& v) final {
    Expr e = GetRef<Expr>(l);
    VisitExpr(l->value, l->var);
    Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body));
    return Compound(e, ret, v);
  }

  // Constants are also let-bound so their sharing becomes explicit.
  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    return Compound(e, e, v);
  }

  // Variables are already atomic; return them unchanged.
  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
    return GetRef<Expr>(vn);
  }

  // Convert the referenced global function once, updating the module in
  // place; visited_ prevents infinite recursion on (mutually) recursive
  // global functions.
  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
    GlobalVar gv = GetRef<GlobalVar>(gvn);
    if (visited_->count(gv) == 0) {
      visited_->insert(gv);
      mod_->Update(gv, Downcast<Function>(relay::ToANF(mod_->Lookup(gv), mod_, visited_)));
    }
    return gv;
  }

  // Operators are atomic as well.
  Expr VisitExpr_(const OpNode* op, const Var& v) final {
    return GetRef<Expr>(op);
  }
};
Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
  /* When a lambda is lifted, everything inside it is lifted as well.
   *
   * So we must determine the scope of the lambda before determining the scope of its body.
   *
   * To make this more principled,
   * we always determine the scope of a parent before determining the scope of its children.
   *
   * So we first calculate all the dependencies between nodes.
   */
  common::Arena arena;
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* In order to model the new subscopes created by lambda, if-else and pattern matching,
   * we also assign scopes to edges as well.
   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
   *
   * So, the scope of the whole expr is global.
   * The scope of any subexpr is the lowest common ancestor of all its incoming edges.
   *
   * Every scope additionally contains a LetList which collects all values of that scope.
   * We do an additional pass to fill all the LetLists and we are done.
   */
  std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
  return Fill::ToANF(e, m, dg, &node_scope, gv);
}
/* Convert e to ANF. When e is itself a function, keep the function's
 * signature intact and normalize only the body, so the result is still a
 * function of the same type.
 */
Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
  const FunctionNode* func = e.as<FunctionNode>();
  if (func == nullptr) {
    return ToANFAux(e, m, gv);
  }
  return FunctionNode::make(func->params,
                            ToANFAux(func->body, m, gv),
                            func->ret_type,
                            func->type_params,
                            func->attrs);
}
// Public entry point: convert with a fresh set tracking which GlobalVars
// have already been rewritten.
Expr ToANF(const Expr& e, const Module& m) {
  std::set<GlobalVar> converted;
  return ToANF(e, m, &converted);
}
// Expose the two-argument ToANF overload to the frontend as
// relay._ir_pass.to_anf(expr, mod).
TVM_REGISTER_API("relay._ir_pass.to_anf")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = ToANF(args[0], args[1]);
  });
} // namespace relay
} // namespace tvm
......@@ -62,9 +62,9 @@ def test_recursion():
relay.Call(f, [subtract(n, relay.const(1.0)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
def test_op_let():
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
    """Evaluate `expr` on the LLVM interpreter and compare the numeric
    result against `expected_result` within tolerance `rtol`."""
    device = tvm.context("llvm", 0)
    executor = create_executor(mod=mod, ctx=device, target="llvm")
    actual = executor.evaluate(expr)
    np.testing.assert_allclose(actual.asnumpy(), expected_result, rtol=rtol)
def test_explicit_bound():
    """to_anf must turn the implicit sharing of y and z into explicit let
    bindings, without changing the evaluated result (1+1=2, 2+2=4, 4+4=8)."""
    x = relay.const(1)
    y = op.add(x, x)
    z = op.add(y, y)
    f = relay.Function([], op.add(z, z))
    assert "let" not in f.astext()  # assert the values are implicitly bound
    anf = to_anf(f)
    assert "let" in anf.astext()  # assert the values are explicitly bound
    check_eval(f(), 8.0)
    check_eval(anf(), 8.0)
# test that the construction order does not matter,
# and is instead ordered by the scope and by post-dfs ordering.
def test_order():
    """z, y, x are deliberately constructed in reverse of their use order;
    the ANF output must still bind them in post-DFS order of the
    expression x + y * z (i.e. x, y, z, y*z, then the sum)."""
    z = relay.const(3)
    y = relay.const(2)
    x = relay.const(1)
    val = x + y * z
    check_eval(val, 7.0)
    anf = infer_type(to_anf(val))
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    e = relay.Var('e', relay.IncompleteType())
    # Expected ANF built inside-out:
    #   let a = x; let b = y; let c = z; let d = b * c; let e = a + d; e
    expected_output = e
    expected_output = relay.Let(e, a + d, expected_output)
    expected_output = relay.Let(d, b * c, expected_output)
    expected_output = relay.Let(c, z, expected_output)
    expected_output = relay.Let(b, y, expected_output)
    expected_output = relay.Let(a, x, expected_output)
    expected_output = infer_type(expected_output)
    assert alpha_equal(anf, expected_output)
def test_if():
    """Each branch of an If gets its own scope: a constant used only inside
    a branch is let-bound inside that branch, not hoisted above the If."""
    cond = relay.const(True)
    x = relay.If(cond, relay.const(2), relay.const(3))
    anf = infer_type(to_anf(x))
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    # Expected: let c = cond; let d = if c then (let a = 2; a)
    #                                  else (let b = 3; b); d
    true_branch = relay.Let(a, relay.const(2), a)
    false_branch = relay.Let(b, relay.const(3), b)
    expected_output = relay.If(c, true_branch, false_branch)
    expected_output = relay.Let(d, expected_output, d)
    expected_output = relay.Let(c, cond, expected_output)
    expected_output = infer_type(expected_output)
    assert alpha_equal(anf, expected_output)
# make sure we dont infinite loop.
# it is too large so we wont check for the exact program.
def test_recursion():
    """
    Program:
       let sum_twice(n: i32) -> i32 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + sum_twice(n - 1);
          }
       }
       sum_twice(5);
    """
    # NOTE: intentionally disabled; the code below is kept as documentation
    # of the intended test and as dead code until fuse_ops can recurse.
    return # cannot be run as fuse_ops need to recursively visit
    mod = relay.Module()
    i64 = relay.TensorType((), 'int64')
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    m = n * relay.const(2, 'int64')
    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
                        m,
                        m + f(n - relay.const(1, 'int64')))
    value = relay.Function([n], funcbody, i64, [])
    mod[f] = value
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
    old_f = mod[f]
    # to_anf of a recursive GlobalVar must terminate (visited-set guard).
    f = to_anf(f, mod=mod)
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
if __name__ == '__main__':
    # Run the full ANF test suite when invoked directly.
    test_explicit_bound()
    test_order()
    test_if()
    test_recursion()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment