Commit c88bda51 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] GNF (#2492)

parent 97bae615
......@@ -320,7 +320,18 @@ struct StructuralHash {
*
* \return expression in A-Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);
Expr ToANormalForm(const Expr& e, const Module& mod);
/*! \brief Remove let bindings and directly share via pointer instead.
 *
 * It will remove all let bindings,
 * and turn every variable bound by let into a direct pointer reference.
 *
 * \param e the expression.
 *
 * \return the expression in graph normal form.
 */
Expr ToGraphNormalForm(const Expr& e);
} // namespace relay
} // namespace tvm
......
......@@ -490,7 +490,7 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)
def to_anf(expr, mod=None):
def to_a_normal_form(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
......@@ -513,7 +513,21 @@ def to_anf(expr, mod=None):
expr: tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_anf(expr, mod)
return _ir_pass.to_a_normal_form(expr, mod)
def to_graph_normal_form(expr):
    """Turn an A Normal Form expression into a Graph Normal Form expression.

    Let bindings are removed, and each let-bound value is shared directly
    via pointer reference instead.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression, expected to be in A Normal Form.

    Returns
    -------
    expr : tvm.relay.Expr
        The output expression in Graph Normal Form.
    """
    return _ir_pass.to_graph_normal_form(expr)
def gradient(expr, mod=None):
......@@ -534,6 +548,7 @@ def gradient(expr, mod=None):
"""
return _ir_pass.first_order_gradient(expr, mod)
def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
......
......@@ -196,7 +196,7 @@ DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body)
return Creator(arena).Create(body);
}
Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
......@@ -258,11 +258,11 @@ bool IsPrimitiveFunction(const Expr& e) {
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public:
static Expr ToANF(const Expr& e,
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) {
static Expr ToANormalForm(const Expr& e,
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) {
Fill fi(m, dg, node_scope, gv);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
}
......@@ -396,7 +396,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
GlobalVar gv = GetRef<GlobalVar>(gvn);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv, Downcast<Function>(relay::ToANF(mod_->Lookup(gv), mod_, visited_)));
mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
}
return gv;
}
......@@ -423,7 +423,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
};
Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
/* When you lift a lambda, what is inside it is also lifted.
 *
 * So we must determine the scope of the lambda before determining the scope of its body.
......@@ -446,29 +446,29 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
* We do an additional pass to fill all the LetList and we are done.
*/
std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
return Fill::ToANF(e, m, dg, &node_scope, gv);
return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
}
Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
if (const auto* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params,
ToANFAux(f->body, m, gv),
ToANormalFormAux(f->body, m, gv),
f->ret_type,
f->type_params,
f->attrs);
} else {
return ToANFAux(e, m, gv);
return ToANormalFormAux(e, m, gv);
}
}
Expr ToANF(const Expr& e, const Module& m) {
Expr ToANormalForm(const Expr& e, const Module& m) {
std::set<GlobalVar> gv;
return ToANF(e, m, &gv);
return ToANormalForm(e, m, &gv);
}
TVM_REGISTER_API("relay._ir_pass.to_anf")
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ToANF(args[0], args[1]);
*ret = ToANormalForm(args[0], args[1]);
});
} // namespace relay
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file to_gnf.cc
*
* \brief Turn A normal form into graph normal form.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "let_list.h"
namespace tvm {
namespace relay {
class UseVarVisitor : public ExprVisitor {
public:
explicit UseVarVisitor(const Var& v) : v(v) { }
static bool UseVar(const Var& v, const Expr& e) {
UseVarVisitor uv(v);
uv(e);
return uv.use_var;
}
private:
bool use_var = false;
Var v;
void VisitExpr_(const VarNode* vn) override {
use_var = use_var || (v == GetRef<Var>(vn));
}
};
/*!
 * \brief Mutator that converts an expression into graph normal form by
 *  removing let bindings and sharing the bound values via direct pointers.
 */
class GNF : public ExprMutator {
 private:
  /*! \brief Maps each let-bound variable to its already-converted bound value. */
  std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map_;

  // A reference to a let-bound variable is replaced by the shared value itself;
  // variables with no recorded binding (e.g. function parameters) pass through.
  Expr VisitExpr_(const VarNode* vn) override {
    Var v = GetRef<Var>(vn);
    return var_map_.count(v) == 0 ? v : var_map_.at(v);
  }

  // True iff v occurs inside e (delegates to UseVarVisitor).
  static bool UseVar(const Var& v, const Expr& e) {
    return UseVarVisitor::UseVar(v, e);
  }

  // If the binding is self-referential (the value mentions its own variable),
  // keep a let around it so the recursion remains expressible; otherwise the
  // bare value can be shared directly.
  static Expr WrapRec(const Var& var, const Expr& val) {
    return UseVar(var, val) ? LetNode::make(var, val, var) : val;
  }

  // Record the converted value for this binding, then continue with the body;
  // the let node itself is dropped from the output expression.
  Expr VisitExpr_(const LetNode* ln) override {
    var_map_.insert(std::pair<Var, Expr>(ln->var, VisitExpr(WrapRec(ln->var, ln->value))));
    return VisitExpr(ln->body);
  }
};
/*! \brief Convert expression \p e into graph normal form by eliminating let bindings. */
Expr ToGraphNormalForm(const Expr& e) {
  GNF converter;
  return converter(e);
}
// Expose the pass to the frontend as relay._ir_pass.to_graph_normal_form.
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = ToGraphNormalForm(args[0]);
});
} // namespace relay
} // namespace tvm
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
......@@ -21,7 +21,7 @@ def test_explicit_bound():
z = op.add(y, y)
f = relay.Function([], op.add(z, z))
assert not "let" in f.astext() # assert the values are implicitly bounded
anf = to_anf(f)
anf = to_a_normal_form(f)
assert "let" in anf.astext() # assert the values are explicitly bounded
check_eval(f(), 8.0)
check_eval(anf(), 8.0)
......@@ -35,7 +35,7 @@ def test_order():
x = relay.const(1)
val = x + y * z
check_eval(val, 7.0)
anf = infer_type(to_anf(val))
anf = infer_type(to_a_normal_form(val))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
......@@ -54,7 +54,7 @@ def test_order():
def test_if():
cond = relay.const(True)
x = relay.If(cond, relay.const(2), relay.const(3))
anf = infer_type(to_anf(x))
anf = infer_type(to_a_normal_form(x))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
......@@ -96,7 +96,7 @@ def test_recursion():
mod[f] = value
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
old_f = mod[f]
f = to_anf(f, mod=mod)
f = to_a_normal_form(f, mod=mod)
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
......@@ -111,7 +111,7 @@ def test_ref():
body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
check_eval(body, 3)
check_eval(to_anf(body), 3)
check_eval(to_a_normal_form(body), 3)
# this is an example of using the adt value in python side
......@@ -135,7 +135,7 @@ def test_add():
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(intrp.evaluate(to_anf(add(s(z()), s(z())), mod))) == 2
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()
if __name__ == '__main__':
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
    """Evaluate `expr` applied to `args` on the interpreter and compare the
    numeric result against `expected_result` within relative tolerance `rtol`."""
    if mod is None:
        mod = relay.Module()
    executor = create_executor(mod=mod, ctx=tvm.context("llvm", 0), target="llvm")
    actual = executor.evaluate(expr)(*args)
    np.testing.assert_allclose(actual.asnumpy(), expected_result, rtol=rtol)
def test_implicit_share():
    """Converting to graph normal form should drop every let binding while
    preserving the evaluated result through implicit (pointer) sharing."""
    var_x = relay.Var('x')
    var_y = relay.Var('y')
    var_z = relay.Var('z')
    chain = relay.Let(var_z, op.add(var_y, var_y), op.add(var_z, var_z))
    chain = relay.Let(var_y, op.add(var_x, var_x), chain)
    func = relay.Function([], relay.Let(var_x, relay.const(1), chain))
    shared = to_graph_normal_form(func)
    assert "let" in func.astext()
    assert "let" not in shared.astext()
    check_eval(func, [], 8.0)
    check_eval(shared, [], 8.0)
def test_round_trip():
    """A round trip ANF -> GNF -> ANF must keep the evaluated result stable."""
    var_x = relay.Var('x')
    var_y = relay.Var('y')
    var_z = relay.Var('z')
    expr = relay.Let(var_z, op.add(var_y, var_y), op.add(var_z, var_z))
    expr = relay.Let(var_y, op.add(var_x, var_x), expr)
    func = relay.Function([], relay.Let(var_x, relay.const(1), expr))
    graph_form = to_graph_normal_form(func)
    anf_again = to_a_normal_form(graph_form)
    assert "let" in func.astext()
    assert "let" not in graph_form.astext()
    check_eval(func, [], 8.0)
    check_eval(graph_form, [], 8.0)
    check_eval(anf_again, [], 8.0)
# Run the GNF tests directly when this file is executed as a script.
if __name__ == '__main__':
    test_implicit_share()
    test_round_trip()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment