Commit c8654e2a by 雾雨魔理沙, committed by Thierry Moreau

[Relay] Partial evaluator handles concatenate, and has a better termination checker for scalars. (#3703)

* save; lint some; lint; lint; add charrnn; save; save; save; remove debug; remove debug; remove space; refactor; save; rewrite dce
* reset files
* join -> meet
* lint
* address review comment
* wordsmith
parent ee74d00e
@@ -14,7 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""The scope builder interface """
+"""The scope builder interface."""
+
 from __future__ import absolute_import
 from . import expr as _expr
...
@@ -419,8 +419,8 @@ class AlphaEqualHandler:
 bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
   if (const LetNode* rhs = other.as<LetNode>()) {
+    if (!ExprEqual(lhs->value, rhs->value)) return false;
     if (!MergeVarDecl(lhs->var, rhs->var)) return false;
-    if (!ExprEqual(lhs->value, rhs->value)) return false;
     return ExprEqual(lhs->body, rhs->body);
   } else {
     return false;
...
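Why the reordering above matters: with MergeVarDecl run first, the two let-binders were already identified while their bound values were compared, so a value that refers to its own binder could compare equal across distinct binders. Comparing the values first checks them under the outer variable mapping only. A minimal sketch of the observable difference (hypothetical standalone snippet; it assumes the Python alpha_equal wrapper used by the tests further down):

import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal

x = relay.Var("x")
y = relay.Var("y")
a = relay.Let(x, x, x)   # let x = x in x
b = relay.Let(y, y, y)   # let y = y in y
# With the old order the binders were merged before the values were
# compared, so a and b came out alpha-equal; with the new order the
# self-referential values are compared first, as unrelated free vars,
# so this should now print False.
print(alpha_equal(a, b))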
@@ -36,121 +36,94 @@
 namespace tvm {
 namespace relay {
 
+template<typename X>
+using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
+using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
+
+class CalcDep;
+class FindDef : private ExprVisitor {
+ private:
+  VarMap<Expr> expr_map_;
+
+  void VisitExpr_(const LetNode* l) final {
+    CHECK_EQ(expr_map_.count(l->var), 0);
+    expr_map_[l->var] = l->value;
+    VisitExpr(l->value);
+    VisitExpr(l->body);
+  }
+
+  friend CalcDep;
+};
+
+class Eliminator : private ExprMutator {
+ private:
+  VarMap<Expr> expr_map_;
+  VarMap<size_t> use_map_;
+  bool inline_once_;
+  explicit Eliminator(const VarMap<Expr>& expr_map,
+                      const VarMap<size_t>& use_map,
+                      bool inline_once) :
+    expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { }
+  friend CalcDep;
+
+  bool HasLet(const Var& v) {
+    switch (use_map_[v]) {
+    case 0:
+      return false;
+    case 1:
+      return !inline_once_;
+    default:
+      return true;
+    }
+  }
+
+  Expr VisitExpr_(const VarNode* op) final {
+    Var v = GetRef<Var>(op);
+    return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    Var v = op->var;
+    if (HasLet(v)) {
+      return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
+    } else {
+      return VisitExpr(op->body);
+    }
+  }
+};
+
 // calculate the dependency graph from expression
 class CalcDep : private ExprVisitor {
  public:
   static Expr Eliminate(const Expr& e, bool inline_once) {
-    CalcDep cd;
-    cd.Calculate(e);
-    Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
+    FindDef fd;
+    fd(e);
+    CalcDep cd(fd.expr_map_);
+    cd(e);
+    Eliminator el(fd.expr_map_, cd.use_map_, inline_once);
     return el(e);
   }
 
  private:
-  template<typename X>
-  using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
-  using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
+  explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
   VarMap<Expr> expr_map_;
   VarMap<size_t> use_map_;
-  VarSet letrec_set_;
-  bool count_ = true;
-  VarSet dead_worklist_;
-  VarSet current_letrec_;
 
-  void LetRec(const std::function<void()>& func, const Var& v) {
-    current_letrec_.insert(v);
-    func();
-    current_letrec_.erase(v);
+  void VisitExpr(const Expr& e) final {
+    return ExprFunctor<void(const Expr& e)>::VisitExpr(e);
   }
 
   void VisitExpr_(const LetNode* l) final {
-    if (count_) {
-      CHECK_EQ(expr_map_.count(l->var), 0);
-      CHECK_EQ(use_map_.count(l->var), 0);
-      expr_map_[l->var] = l->value;
-      use_map_[l->var] = 0;
-      dead_worklist_.insert(l->var);
-      LetRec([&]() { VisitExpr(l->value); }, l->var);
-    }
     VisitExpr(l->body);
   }
 
-  void VisitExpr(const Expr& e) final {
-    ExprFunctor<void(const Expr&)>::VisitExpr(e);
-  }
-
   void VisitExpr_(const VarNode* v) final {
     Var var = GetRef<Var>(v);
-    if (expr_map_.count(var) == 0) {
-      return;
-    }
-    if (current_letrec_.count(var) == 0) {
-      if (count_) {
-        use_map_[var] += 1;
-        dead_worklist_.erase(var);
-      } else {
-        CHECK_GT(use_map_[var], 0) << var;
-        use_map_[var] -= 1;
-        if (use_map_[var] == 0) {
-          dead_worklist_.insert(var);
-        }
-      }
-    } else {
-      letrec_set_.insert(var);
+    ++use_map_[var];
+    if (use_map_[var] == 1 && expr_map_.count(var) > 0) {
+      VisitExpr(expr_map_[var]);
     }
   }
-
-  void Calculate(const Expr& v) {
-    VisitExpr(v);
-    count_ = false;
-    while (!dead_worklist_.empty()) {
-      Var dead = *(dead_worklist_.begin());
-      dead_worklist_.erase(dead);
-      CHECK_EQ(use_map_[dead], 0);
-      if (expr_map_.count(dead) > 0) {
-        LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead);
-      }
-    }
-  }
-
-  class Eliminator : private ExprMutator {
-   private:
-    VarMap<Expr> expr_map_;
-    VarMap<size_t> use_map_;
-    VarSet letrec_set_;
-    bool inline_once_;
-    explicit Eliminator(const VarMap<Expr>& expr_map,
-                        const VarMap<size_t>& use_map,
-                        const VarSet& letrec_set,
-                        bool inline_once) :
-      expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
-    friend CalcDep;
-
-    bool HasLet(const Var& v) {
-      switch (use_map_[v]) {
-      case 0:
-        return false;
-      case 1:
-        return letrec_set_.count(v) > 0 || !inline_once_;
-      default:
-        return true;
-      }
-    }
-
-    Expr VisitExpr_(const VarNode* op) final {
-      Var v = GetRef<Var>(op);
-      return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
-    }
-
-    Expr VisitExpr_(const LetNode* op) final {
-      Var v = op->var;
-      if (HasLet(v)) {
-        return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
-      } else {
-        return VisitExpr(op->body);
-      }
-    }
-  };
 };
 
 Expr DeadCodeElimination(const Expr& e, bool inline_once) {
...
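The rewrite above splits dead-code elimination into three single-purpose passes: FindDef records every let binding, CalcDep counts variable uses and visits a binding's definition only on the first actual use of its variable (an unused, possibly recursive, definition is simply never entered, which replaces the old letrec_set_/dead_worklist_ bookkeeping), and Eliminator drops bindings with zero uses and, when inline_once is set, inlines bindings used exactly once. A toy model of the same idea on a miniature let-language, not the TVM implementation:

# Toy sketch of the FindDef / CalcDep / Eliminator pipeline. An expression
# is a variable name (str), a let ("let", var, value, body), or a tagged
# tuple of sub-expressions such as ("call", f, x).
def find_def(e, defs):
    if isinstance(e, tuple) and e[0] == "let":
        _, v, val, body = e
        defs[v] = val
        find_def(val, defs)
        find_def(body, defs)
    elif isinstance(e, tuple):
        for x in e[1:]:
            find_def(x, defs)

def calc_dep(e, defs, uses):
    if isinstance(e, str):
        uses[e] = uses.get(e, 0) + 1
        # visit a definition only on its first use: dead recursive
        # bindings are never entered, so counting always terminates
        if uses[e] == 1 and e in defs:
            calc_dep(defs[e], defs, uses)
    elif isinstance(e, tuple) and e[0] == "let":
        calc_dep(e[3], defs, uses)   # only the body; values count on use
    elif isinstance(e, tuple):
        for x in e[1:]:
            calc_dep(x, defs, uses)

def eliminate(e, defs, uses, inline_once):
    # mirrors HasLet: keep the let unless the var is dead, or is used
    # exactly once and inline_once is enabled
    has_let = lambda v: uses.get(v, 0) > 1 or (uses.get(v, 0) == 1 and not inline_once)
    if isinstance(e, str):
        if e in defs and not has_let(e) and uses.get(e, 0) == 1:
            return eliminate(defs[e], defs, uses, inline_once)
        return e
    if isinstance(e, tuple) and e[0] == "let":
        _, v, val, body = e
        if has_let(v):
            return ("let", v, eliminate(val, defs, uses, inline_once),
                    eliminate(body, defs, uses, inline_once))
        return eliminate(body, defs, uses, inline_once)
    if isinstance(e, tuple):
        return (e[0],) + tuple(eliminate(x, defs, uses, inline_once) for x in e[1:])
    return e

# let a = b in (let c = d in c)  ->  d   (a is dead, c is inlined once)
expr = ("let", "a", "b", ("let", "c", "d", "c"))
defs, uses = {}, {}
find_def(expr, defs)
calc_dep(expr, defs, uses)
assert eliminate(expr, defs, uses, inline_once=True) == "d"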
@@ -68,7 +68,7 @@ class GNF : public ExprMutator {
   }
 
   Expr VisitExpr_(const LetNode* ln) override {
-    var_map_.insert(std::pair<Var, Expr>(ln->var, VisitExpr(WrapRec(ln->var, ln->value))));
+    var_map_.insert(std::pair<Var, Expr>(ln->var, WrapRec(ln->var, VisitExpr(ln->value))));
     return VisitExpr(ln->body);
   }
 };
...
@@ -19,7 +19,7 @@ from nose.tools import nottest
 import tvm
 from tvm import relay
 from tvm.relay import Function, transform
-from tvm.relay.analysis import alpha_equal, graph_equal, free_vars
+from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal
 from tvm.relay.op import log, add, equal, subtract
@@ -65,11 +65,10 @@ def test_used_let():
     expected = relay.Let(e.c, e.one, e.c + e.c)
     assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
 
-@nottest
 def test_inline():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
+    orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
+    assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
 
 def test_chain_unused_let():
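For context, DeadCodeElimination(True) in the updated test enables the pass's inline-once mode (the inline_once_ flag in the C++ above): a let whose variable is used exactly once is inlined instead of kept. A usage sketch, assuming relay.Module.from_expr and the "main" entry behave as in the usual run_opt_pass test helper:

import tvm
from tvm import relay
from tvm.relay import transform

a, b, c, d = [relay.Var(n) for n in "abcd"]
# let a = b in (let c = d in c): a is dead, c is used exactly once
expr = relay.Let(a, b, relay.Let(c, d, c))
func = relay.Function(relay.analysis.free_vars(expr), expr)
mod = relay.Module.from_expr(func)
mod = transform.DeadCodeElimination(True)(mod)
# the function body should reduce to just d, as test_inline asserts
print(mod["main"])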
@@ -78,6 +77,17 @@ def test_chain_unused_let():
     assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
 
+def use_f(func):
+    f = relay.Var("f")
+    n = relay.Var("n", e.int32)
+    data = relay.Var("data", e.float32)
+    funcbody = relay.If(equal(n, relay.const(0)),
+                        data,
+                        relay.Call(f, [subtract(n, relay.const(1)),
+                                       log(data)]))
+    value = relay.Function([n, data], funcbody, e.float32, [])
+    return relay.Let(f, value, func(f))
+
 # make sure we dont infinite loop
 def test_recursion():
     """
@@ -91,21 +101,15 @@ def test_recursion():
     }
     f(2, 10000);
     """
-    f = relay.Var("f")
-    f1 = relay.Var("f1")
-    n = relay.Var("n", e.int32)
-    data = relay.Var("data", e.float32)
-    funcbody = relay.If(equal(n, relay.const(0)),
-                        data,
-                        relay.Call(f1, [subtract(n, relay.const(1)),
-                                        log(data)]))
-    value = relay.Function([n, data], funcbody, e.float32, [])
-    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
+    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
     dced = run_opt_pass(orig, transform.DeadCodeElimination())
     orig = run_opt_pass(orig, transform.InferType())
-    assert graph_equal(dced, orig)
-    dced = run_opt_pass(relay.Let(f, value, e.three),
-                        transform.DeadCodeElimination())
+    assert_alpha_equal(dced, orig)
+
+def test_recursion_dead():
+    x = relay.Let(e.a, e.one, e.three)
+    dced_f = lambda f: x
+    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
     assert alpha_equal(dced, e.three)
@@ -133,5 +137,6 @@ if __name__ == "__main__":
     test_inline()
     test_chain_unused_let()
     test_recursion()
+    test_recursion_dead()
     test_op_let()
     test_tuple_get_item()
@@ -123,7 +123,7 @@ def test_ad():
     body = relay.Let(x1, o, body)
     expected = Function([d], relay.Let(x, m, body))
     expected = run_opt_pass(expected, transform.InferType())
-    assert alpha_equal(g, expected)
+    assert_alpha_equal(g, expected)
 
 def test_if_ref():
@@ -311,7 +311,16 @@ def test_concat():
     x = Var("x", t)
     y = Var("x", t)
     orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
-    assert_alpha_equal(orig, dcpe(orig))
+    assert_alpha_equal(dcpe(orig), orig)
+
+def test_triangle():
+    t = relay.TensorType([], "int32")
+    x = Var("x", t)
+    f_var = Var("f")
+    f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - const(1))))
+    orig = run_infer_type(Let(f_var, f, f_var(const(10))))
+    assert_alpha_equal(dcpe(orig), const(55))
 
 if __name__ == '__main__':
@@ -332,3 +341,4 @@ if __name__ == '__main__':
     test_global_match_nat_id()
     test_match_nat_id()
     test_concat()
+    test_triangle()
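The new test_triangle exercises the improved scalar termination checking named in the title: the recursive call's argument is a compile-time scalar constant that strictly decreases toward the base case, so the partial evaluator can safely unfold the recursion completely and fold the whole program to a constant, 10 + 9 + ... + 1 = 55. A plain-Python rendering of the function it unfolds (illustration only, not the TVM API):

def triangle(x):
    # mirrors f in test_triangle: if x == 0 then 0 else x + f(x - 1)
    return 0 if x == 0 else x + triangle(x - 1)

assert triangle(10) == 55  # the constant that dcpe folds the program to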