Commit c8654e2a by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] Partial Evaluator do concatenate, and has better termination checker for scalar. (#3703)

* save

lint some

lint

lint

add charrnn

save

save

save

remove debug

remove debug

remove space

refactor

save

rewrite dce

* reset files

* join -> meet

* lint

* address review comment

* wordsmith
parent ee74d00e
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""The scope builder interface """
"""The scope builder interface."""
from __future__ import absolute_import from __future__ import absolute_import
from . import expr as _expr from . import expr as _expr
......
...@@ -419,8 +419,8 @@ class AlphaEqualHandler: ...@@ -419,8 +419,8 @@ class AlphaEqualHandler:
bool VisitExpr_(const LetNode* lhs, const Expr& other) final { bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
if (const LetNode* rhs = other.as<LetNode>()) { if (const LetNode* rhs = other.as<LetNode>()) {
if (!ExprEqual(lhs->value, rhs->value)) return false;
if (!MergeVarDecl(lhs->var, rhs->var)) return false; if (!MergeVarDecl(lhs->var, rhs->var)) return false;
if (!ExprEqual(lhs->value, rhs->value)) return false;
return ExprEqual(lhs->body, rhs->body); return ExprEqual(lhs->body, rhs->body);
} else { } else {
return false; return false;
......
...@@ -36,94 +36,34 @@ ...@@ -36,94 +36,34 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// calculate the dependency graph from expression template<typename X>
class CalcDep : private ExprVisitor { using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
public: using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e);
}
class CalcDep;
class FindDef : private ExprVisitor {
private: private:
template<typename X>
using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
VarMap<Expr> expr_map_; VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool count_ = true;
VarSet dead_worklist_;
VarSet current_letrec_;
void LetRec(const std::function<void()>& func, const Var& v) {
current_letrec_.insert(v);
func();
current_letrec_.erase(v);
}
void VisitExpr_(const LetNode* l) final { void VisitExpr_(const LetNode* l) final {
if (count_) {
CHECK_EQ(expr_map_.count(l->var), 0); CHECK_EQ(expr_map_.count(l->var), 0);
CHECK_EQ(use_map_.count(l->var), 0);
expr_map_[l->var] = l->value; expr_map_[l->var] = l->value;
use_map_[l->var] = 0; VisitExpr(l->value);
dead_worklist_.insert(l->var);
LetRec([&]() { VisitExpr(l->value); }, l->var);
}
VisitExpr(l->body); VisitExpr(l->body);
} }
void VisitExpr(const Expr& e) final { friend CalcDep;
ExprFunctor<void(const Expr&)>::VisitExpr(e); };
}
void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
if (expr_map_.count(var) == 0) {
return;
}
if (current_letrec_.count(var) == 0) {
if (count_) {
use_map_[var] += 1;
dead_worklist_.erase(var);
} else {
CHECK_GT(use_map_[var], 0) << var;
use_map_[var] -= 1;
if (use_map_[var] == 0) {
dead_worklist_.insert(var);
}
}
} else {
letrec_set_.insert(var);
}
}
void Calculate(const Expr& v) {
VisitExpr(v);
count_ = false;
while (!dead_worklist_.empty()) {
Var dead = *(dead_worklist_.begin());
dead_worklist_.erase(dead);
CHECK_EQ(use_map_[dead], 0);
if (expr_map_.count(dead) > 0) {
LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead);
}
}
}
class Eliminator : private ExprMutator { class Eliminator : private ExprMutator {
private: private:
VarMap<Expr> expr_map_; VarMap<Expr> expr_map_;
VarMap<size_t> use_map_; VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_; bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map, explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map, const VarMap<size_t>& use_map,
const VarSet& letrec_set,
bool inline_once) : bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { } expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { }
friend CalcDep; friend CalcDep;
bool HasLet(const Var& v) { bool HasLet(const Var& v) {
...@@ -131,7 +71,7 @@ class CalcDep : private ExprVisitor { ...@@ -131,7 +71,7 @@ class CalcDep : private ExprVisitor {
case 0: case 0:
return false; return false;
case 1: case 1:
return letrec_set_.count(v) > 0 || !inline_once_; return !inline_once_;
default: default:
return true; return true;
} }
...@@ -150,7 +90,40 @@ class CalcDep : private ExprVisitor { ...@@ -150,7 +90,40 @@ class CalcDep : private ExprVisitor {
return VisitExpr(op->body); return VisitExpr(op->body);
} }
} }
}; };
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e, bool inline_once) {
FindDef fd;
fd(e);
CalcDep cd(fd.expr_map_);
cd(e);
Eliminator el(fd.expr_map_, cd.use_map_, inline_once);
return el(e);
}
private:
explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
void VisitExpr(const Expr& e) final {
return ExprFunctor<void(const Expr& e)>::VisitExpr(e);
}
void VisitExpr_(const LetNode* l) final {
VisitExpr(l->body);
}
void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
++use_map_[var];
if (use_map_[var] == 1 && expr_map_.count(var) > 0) {
VisitExpr(expr_map_[var]);
}
}
}; };
Expr DeadCodeElimination(const Expr& e, bool inline_once) { Expr DeadCodeElimination(const Expr& e, bool inline_once) {
......
...@@ -68,7 +68,7 @@ class GNF : public ExprMutator { ...@@ -68,7 +68,7 @@ class GNF : public ExprMutator {
} }
Expr VisitExpr_(const LetNode* ln) override { Expr VisitExpr_(const LetNode* ln) override {
var_map_.insert(std::pair<Var, Expr>(ln->var, VisitExpr(WrapRec(ln->var, ln->value)))); var_map_.insert(std::pair<Var, Expr>(ln->var, WrapRec(ln->var, VisitExpr(ln->value))));
return VisitExpr(ln->body); return VisitExpr(ln->body);
} }
}; };
......
...@@ -19,7 +19,7 @@ from nose.tools import nottest ...@@ -19,7 +19,7 @@ from nose.tools import nottest
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import Function, transform from tvm.relay import Function, transform
from tvm.relay.analysis import alpha_equal, graph_equal, free_vars from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal
from tvm.relay.op import log, add, equal, subtract from tvm.relay.op import log, add, equal, subtract
...@@ -65,11 +65,10 @@ def test_used_let(): ...@@ -65,11 +65,10 @@ def test_used_let():
expected = relay.Let(e.c, e.one, e.c + e.c) expected = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(Function([e.c], orig), Function([e.c], expected)) assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
@nottest
def test_inline(): def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
orig = run_opt_pass(orig, transform.DeadCodeElimination()) orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d)) assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
def test_chain_unused_let(): def test_chain_unused_let():
...@@ -78,6 +77,17 @@ def test_chain_unused_let(): ...@@ -78,6 +77,17 @@ def test_chain_unused_let():
assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e)) assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
def use_f(func):
f = relay.Var("f")
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, relay.const(0)),
data,
relay.Call(f, [subtract(n, relay.const(1)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
return relay.Let(f, value, func(f))
# make sure we dont infinite loop # make sure we dont infinite loop
def test_recursion(): def test_recursion():
""" """
...@@ -91,21 +101,15 @@ def test_recursion(): ...@@ -91,21 +101,15 @@ def test_recursion():
} }
f(2, 10000); f(2, 10000);
""" """
f = relay.Var("f") orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
f1 = relay.Var("f1")
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, relay.const(0)),
data,
relay.Call(f1, [subtract(n, relay.const(1)),
log(data)]))
value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
dced = run_opt_pass(orig, transform.DeadCodeElimination()) dced = run_opt_pass(orig, transform.DeadCodeElimination())
orig = run_opt_pass(orig, transform.InferType()) orig = run_opt_pass(orig, transform.InferType())
assert graph_equal(dced, orig) assert_alpha_equal(dced, orig)
dced = run_opt_pass(relay.Let(f, value, e.three),
transform.DeadCodeElimination()) def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
dced_f = lambda f: x
dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
assert alpha_equal(dced, e.three) assert alpha_equal(dced, e.three)
...@@ -133,5 +137,6 @@ if __name__ == "__main__": ...@@ -133,5 +137,6 @@ if __name__ == "__main__":
test_inline() test_inline()
test_chain_unused_let() test_chain_unused_let()
test_recursion() test_recursion()
test_recursion_dead()
test_op_let() test_op_let()
test_tuple_get_item() test_tuple_get_item()
...@@ -123,7 +123,7 @@ def test_ad(): ...@@ -123,7 +123,7 @@ def test_ad():
body = relay.Let(x1, o, body) body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body)) expected = Function([d], relay.Let(x, m, body))
expected = run_opt_pass(expected, transform.InferType()) expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(g, expected) assert_alpha_equal(g, expected)
def test_if_ref(): def test_if_ref():
...@@ -311,7 +311,16 @@ def test_concat(): ...@@ -311,7 +311,16 @@ def test_concat():
x = Var("x", t) x = Var("x", t)
y = Var("x", t) y = Var("x", t)
orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0))) orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
assert_alpha_equal(orig, dcpe(orig)) assert_alpha_equal(dcpe(orig), orig)
def test_triangle():
t = relay.TensorType([], "int32")
x = Var("x", t)
f_var = Var("f")
f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - const(1))))
orig = run_infer_type(Let(f_var, f, f_var(const(10))))
assert_alpha_equal(dcpe(orig), const(55))
if __name__ == '__main__': if __name__ == '__main__':
...@@ -332,3 +341,4 @@ if __name__ == '__main__': ...@@ -332,3 +341,4 @@ if __name__ == '__main__':
test_global_match_nat_id() test_global_match_nat_id()
test_match_nat_id() test_match_nat_id()
test_concat() test_concat()
test_triangle()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment