Commit 5cedfba5 by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] Fix Partial Evaluator, Add stricter checking for CheckWellFormed (#3749)

* aot

* save

* save

* fix test

* remove vta changes

* lint
parent 465313c5
...@@ -216,7 +216,8 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { ...@@ -216,7 +216,8 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
} }
Clause ExprMutator::VisitClause(const Clause& c) { Clause ExprMutator::VisitClause(const Clause& c) {
return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs)); Pattern p = VisitPattern(c->lhs);
return ClauseNode::make(p, VisitExpr(c->rhs));
} }
Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
...@@ -395,7 +396,9 @@ class ExprBinder : public ExprMutator, PatternMutator { ...@@ -395,7 +396,9 @@ class ExprBinder : public ExprMutator, PatternMutator {
} }
Var VisitVar(const Var& v) final { Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v)); CHECK(!args_map_.count(v))
<< "Cannnot bind an internal pattern variable";
return v;
} }
private: private:
......
...@@ -44,6 +44,8 @@ Expr DeDup(const Expr& e) { ...@@ -44,6 +44,8 @@ Expr DeDup(const Expr& e) {
} }
Var Fresh(const Var& v) { Var Fresh(const Var& v) {
CHECK_EQ(rename_.count(v), 0);
CHECK_EQ(memo_.count(v), 0) << v.as<VarNode>();
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret; rename_[v] = ret;
return ret; return ret;
...@@ -84,18 +86,13 @@ Expr DeDup(const Expr& e) { ...@@ -84,18 +86,13 @@ Expr DeDup(const Expr& e) {
} }
Pattern VisitPattern(const Pattern& p) final { Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p); return PatternFunctor::VisitPattern(p);
} }
Pattern VisitPattern_(const PatternVarNode* op) final { Pattern VisitPattern_(const PatternVarNode* op) final {
return PatternVarNode::make(Fresh(op->var)); return PatternVarNode::make(Fresh(op->var));
} }
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
Type VisitType_(const TypeVarNode* op) final { Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op); TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
...@@ -109,9 +106,10 @@ Expr DeDup(const Expr& e) { ...@@ -109,9 +106,10 @@ Expr DeDup(const Expr& e) {
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_; std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_; std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
}; };
CHECK(WellFormed(e)) << AsText(e, false);
Expr ret = DeDupMutator().VisitExpr(e); Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size()); CHECK(WellFormed(ret));
CHECK_EQ(FreeVars(e).size(), FreeVars(ret).size());
return ret; return ret;
} }
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#define TVM_RELAY_PASS_LET_LIST_H_ #define TVM_RELAY_PASS_LET_LIST_H_
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
...@@ -63,6 +64,7 @@ class LetList { ...@@ -63,6 +64,7 @@ class LetList {
*/ */
Var Push(Var pv, Expr expr) { Var Push(Var pv, Expr expr) {
CHECK(!used_); CHECK(!used_);
CHECK(WellFormed(expr));
lets_.emplace_back(std::make_pair(pv, expr)); lets_.emplace_back(std::make_pair(pv, expr));
return pv; return pv;
} }
......
...@@ -396,6 +396,7 @@ class Environment { ...@@ -396,6 +396,7 @@ class Environment {
void Insert(const Var& v, const PStatic& ps) { void Insert(const Var& v, const PStatic& ps) {
CHECK(ps.defined()); CHECK(ps.defined());
CHECK_GT(env_.size(), 0);
CHECK_EQ(env_.back().locals.count(v), 0); CHECK_EQ(env_.back().locals.count(v), 0);
env_.back().locals[v] = ps; env_.back().locals[v] = ps;
} }
...@@ -604,10 +605,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -604,10 +605,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) { PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
if (auto* op = e.as<CallNode>()) { if (const CallNode* c = e.as<CallNode>()) {
if (op->op.same_as(WithFuncIdOp())) { if (c->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1); CHECK_EQ(c->args.size(), 1);
return VisitExpr(op->args[0], ll, name); return VisitExpr(c->args[0], ll, name);
} }
} }
PStatic ret = e.as<FunctionNode>() ? PStatic ret = e.as<FunctionNode>() ?
...@@ -801,34 +802,36 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -801,34 +802,36 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
LetList* ll) { LetList* ll) {
return env_.Extend<PStatic>([&]() { return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size()); CHECK_EQ(pv.size(), func->params.size());
if (var.as<VarNode>()) {
env_.Insert(Downcast<Var>(var), self);
}
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
for (const auto& p : free_vars) {
env_.Insert(p.first, p.second);
}
tvm::Map<TypeVar, Type> subst;
for (size_t i = 0; i < type_args.size(); ++i) {
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
std::vector<Fuel> args_fuel;
for (const auto& v : pv) {
args_fuel.push_back(GetFuel(v));
}
CHECK_GT(func_map_.count(func), 0); CHECK_GT(func_map_.count(func), 0);
FuncId fid = func_map_.at(func); FuncId fid = func_map_.at(func);
if (fuel_map_.count(fid) == 0) { if (fuel_map_.count(fid) == 0) {
fuel_map_.insert({fid, MkFTop()}); fuel_map_.insert({fid, MkFTop()});
} }
std::vector<Fuel> args_fuel;
for (const auto& v : pv) {
args_fuel.push_back(GetFuel(v));
}
auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
if (std::get<1>(meet_res)) { if (std::get<1>(meet_res)) {
FuelFrame tf(this, fid, std::get<0>(meet_res)); FuelFrame tf(this, fid, std::get<0>(meet_res));
Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
Function func = AsFunc(dedup_func);
if (var.as<VarNode>()) {
env_.Insert(Downcast<Var>(var), self);
}
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
for (const auto& p : free_vars) {
env_.Insert(p.first, p.second);
}
tvm::Map<TypeVar, Type> subst;
for (size_t i = 0; i < type_args.size(); ++i) {
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
} else { } else {
std::vector<Expr> dyn; std::vector<Expr> dyn;
...@@ -979,32 +982,37 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -979,32 +982,37 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitExpr_(const MatchNode* op, LetList* ll) final { PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
PStatic ps = VisitExpr(op->data, ll); PStatic ps = VisitExpr(op->data, ll);
return env_.Extend<PStatic>([&]() { return env_.Extend<PStatic>([&]() {
for (const Clause& c : op->clauses) { for (const Clause& c : op->clauses) {
switch (VisitPattern(c->lhs, ps)) { switch (VisitPattern(c->lhs, ps)) {
case MatchStatus::Match: case MatchStatus::Match:
return VisitExpr(c->rhs, ll); return VisitExpr(c->rhs, ll);
case MatchStatus::NoMatch: case MatchStatus::NoMatch:
continue; continue;
case MatchStatus::Unknown: case MatchStatus::Unknown:
return [&]() {
tvm::Array<Clause> clauses; tvm::Array<Clause> clauses;
for (const Clause& c : op->clauses) { for (const Clause& c : op->clauses) {
Expr expr = store_.Extend<Expr>([&]() { Expr expr = store_.Extend<Expr>([&]() {
return LetList::With([&](LetList* ll) { return LetList::With([&](LetList* ll) {
for (const Var& v : BoundVars(c->lhs)) { for (const Var& v : BoundVars(c->lhs)) {
env_.Insert(v, NoStatic(v)); env_.Insert(v, NoStatic(v));
} }
return VisitExpr(c->rhs, ll)->dynamic; return VisitExpr(c->rhs, ll)->dynamic;
});
}); });
});
clauses.push_back(ClauseNode::make(c->lhs, expr)); clauses.push_back(ClauseNode::make(c->lhs, expr));
} }
store_.Invalidate(); store_.Invalidate();
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete))); return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete)));
} }();
default:
LOG(FATAL) << "Unknown MatchStatus";
throw;
} }
LOG(FATAL) << "No case Match"; }
throw; LOG(FATAL) << "No case Match";
}); throw;
});
} }
MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final { MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
......
...@@ -438,7 +438,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) { ...@@ -438,7 +438,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
private: private:
const tvm::Map<TypeVar, Type>& subst_map_; const tvm::Map<TypeVar, Type>& subst_map_;
}; };
return TypeSubstMutator(subst_map).VisitExpr(expr); CHECK(WellFormed(expr));
auto ret = TypeSubstMutator(subst_map).VisitExpr(expr);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
CHECK(WellFormed(ret));
return ret;
} }
} // namespace relay } // namespace relay
......
...@@ -35,36 +35,84 @@ namespace relay { ...@@ -35,36 +35,84 @@ namespace relay {
class WellFormedChecker : private ExprVisitor, PatternVisitor { class WellFormedChecker : private ExprVisitor, PatternVisitor {
bool well_formed = true; bool well_formed = true;
std::unordered_set<Var, NodeHash, NodeEqual> s; std::vector<std::unordered_set<Var, NodeHash, NodeEqual>> scope;
std::unordered_set<Var, NodeHash, NodeEqual> current_bound;
std::unordered_set<Var, NodeHash, NodeEqual> total_bound;
std::unordered_set<Var, NodeHash, NodeEqual> free;
void Check(const Var& v) { struct Scope {
if (s.count(v) != 0) { WellFormedChecker* wfc;
explicit Scope(WellFormedChecker* wfc) : wfc(wfc) {
wfc->scope.push_back({});
}
~Scope() {
CHECK_GE(wfc->scope.size(), 0);
for (const Var& v : wfc->scope.back()) {
CHECK_GE(wfc->current_bound.count(v), 0);
wfc->current_bound.erase(v);
}
wfc->scope.pop_back();
}
};
void Bound(const Var& v) {
if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) {
well_formed = false; well_formed = false;
} }
s.insert(v); CHECK_GE(scope.size(), 0);
scope.back().insert(v);
current_bound.insert(v);
total_bound.insert(v);
}
void VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
if (current_bound.count(v) == 0) {
if (total_bound.count(v) != 0) {
well_formed = false;
} else {
free.insert(v);
}
}
} }
void VisitExpr_(const LetNode* l) final { void VisitExpr_(const LetNode* l) final {
Scope s(this);
// we do letrec only for FunctionNode, // we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it. // but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var); Bound(l->var);
CheckWellFormed(l->value); CheckWellFormed(l->value);
CheckWellFormed(l->body); CheckWellFormed(l->body);
} }
void VisitExpr_(const FunctionNode* f) final { void VisitExpr_(const FunctionNode* f) final {
Scope s(this);
for (const Var& param : f->params) { for (const Var& param : f->params) {
Check(param); Bound(param);
} }
CheckWellFormed(f->body); CheckWellFormed(f->body);
} }
void VisitClause(const Clause& c) final {
Scope s(this);
VisitPattern(c->lhs);
VisitExpr(c->rhs);
}
void VisitPattern(const Pattern& p) final { void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p); PatternVisitor::VisitPattern(p);
} }
void VisitVar(const Var& v) final { void VisitVar(const Var& v) final {
Check(v); Bound(v);
}
void VisitExpr(const Expr& e) final {
if (auto v = e.as<VarNode>()) {
VisitExpr_(v);
} else {
ExprVisitor::VisitExpr(e);
}
} }
public: public:
......
...@@ -27,27 +27,36 @@ def check_type_err(expr, msg): ...@@ -27,27 +27,36 @@ def check_type_err(expr, msg):
except tvm.TVMError as err: except tvm.TVMError as err:
assert msg in str(err) assert msg in str(err)
def test_wellformed():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
check_type_err(
f(x),
"Check failed: WellFormed")
def test_too_many_args(): def test_too_many_args():
x = relay.var('x', shape=(10, 10)) x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x) f = relay.Function([x], x)
y = relay.var('y', shape=(10, 10)) y = relay.var('y', shape=(10, 10))
check_type_err( check_type_err(
f(x, y), f(y, y),
"the function is provided too many arguments expected 1, found 2;") "the function is provided too many arguments expected 1, found 2;")
def test_too_few_args(): def test_too_few_args():
x = relay.var('x', shape=(10, 10)) x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10)) y = relay.var('y', shape=(10, 10))
z = relay.var('z', shape=(10, 10))
f = relay.Function([x, y], x) f = relay.Function([x, y], x)
check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;") check_type_err(f(z), "the function is provided too few arguments expected 2, found 1;")
def test_rel_fail(): def test_rel_fail():
x = relay.var('x', shape=(10, 10)) x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(11, 10)) y = relay.var('y', shape=(11, 10))
f = relay.Function([x, y], x + y) f = relay.Function([x, y], x + y)
check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);") check_type_err(f, "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
if __name__ == "__main__": if __name__ == "__main__":
test_wellformed()
test_too_many_args() test_too_many_args()
test_too_few_args() test_too_few_args()
test_rel_fail() test_rel_fail()
...@@ -323,7 +323,16 @@ def test_triangle_number(): ...@@ -323,7 +323,16 @@ def test_triangle_number():
assert_alpha_equal(dcpe(orig), const(55)) assert_alpha_equal(dcpe(orig), const(55))
def test_nat_update():
m = Module()
p = Prelude(m)
add_nat_definitions(p)
m = transform.ToANormalForm()(m)
transform.PartialEvaluate()(m)
if __name__ == '__main__': if __name__ == '__main__':
test_nat_update()
test_ref() test_ref()
test_tuple() test_tuple()
test_empty_ad() test_empty_ad()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment