Commit 7beafddd by 雾雨魔理沙 Committed by Tianqi Chen

[RELAY] IR Wellform Checker (#1748)

parent e22ac6b3
...@@ -12,21 +12,19 @@ ...@@ -12,21 +12,19 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
struct Error : dmlc::Error { struct Error : public dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {} explicit Error(const std::string &msg) : dmlc::Error(msg) {}
}; };
struct InternalError : Error { struct InternalError : public Error {
explicit InternalError(const std::string &msg) : Error(msg) {} explicit InternalError(const std::string &msg) : Error(msg) {}
}; };
// TODO(@jroesch): we should change spanned errors to report struct FatalTypeError : public Error {
// errors against the Environment, inverting control to error definition. explicit FatalTypeError(const std::string &s) : Error(s) {}
struct FatalTypeError : dmlc::Error {
explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {}
}; };
struct TypecheckerError : public dmlc::Error { struct TypecheckerError : public Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {} explicit TypecheckerError(const std::string &msg) : Error(msg) {}
}; };
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <string> #include <string>
#include <functional>
#include "./base.h" #include "./base.h"
#include "./type.h" #include "./type.h"
......
...@@ -80,6 +80,18 @@ bool AlphaEqual(const Expr& e1, const Expr& e2); ...@@ -80,6 +80,18 @@ bool AlphaEqual(const Expr& e1, const Expr& e2);
*/ */
bool AlphaEqual(const Type& t1, const Type& t2); bool AlphaEqual(const Type& t1, const Type& t2);
/*! brief Check that each Var is only bind once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed.
*
* \param e the expression to check.
*
* \return true iff all Var in e is bind at most once.
*/
bool WellFormed(const Expr & e);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_H_ #endif // TVM_RELAY_PASS_H_
...@@ -4,3 +4,4 @@ from . import ir ...@@ -4,3 +4,4 @@ from . import ir
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
def well_formed(expr: ir.Expr) -> bool: ...
\ No newline at end of file
...@@ -10,3 +10,5 @@ from . import _ir_pass ...@@ -10,3 +10,5 @@ from . import _ir_pass
# Expose checking expression, should rename to infer_type. # Expose checking expression, should rename to infer_type.
# pylint: disable=invalid-name # pylint: disable=invalid-name
check_expr = _ir_pass.check_expr check_expr = _ir_pass.check_expr
well_formed = _ir_pass.well_formed
/*!
* Copyright (c) 2018 by Contributors
* \file well_formed.cc
* \brief check that expression is well formed.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <unordered_set>
namespace tvm {
namespace relay {
struct NotWellFormed { };
//! brief make sure each Var is bind at most once.
class WellFormedChecker : private ExprVisitor {
bool well_formed = true;
std::unordered_set<Var, NodeHash, NodeEqual> s;
void Check(const Var & v) {
if (s.count(v) != 0) {
well_formed = false;
}
s.insert(v);
}
void VisitExpr_(const LetNode * l) final {
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var);
CheckWellFormed(l->value);
CheckWellFormed(l->body);
}
void VisitExpr_(const FunctionNode * f) final {
for (const Param & p : f->params) {
Check(p->var);
}
CheckWellFormed(f->body);
}
public:
bool CheckWellFormed(const Expr & e) {
this->VisitExpr(e);
return well_formed;
}
};
bool WellFormed(const Expr & e) {
return WellFormedChecker().CheckWellFormed(e);
}
TVM_REGISTER_API("relay._ir_pass.well_formed")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Expr e = args[0];
*ret = WellFormed(e);
});
} // namespace relay
} // namespace tvm
...@@ -24,4 +24,3 @@ def test_op_level1(): ...@@ -24,4 +24,3 @@ def test_op_level1():
if __name__ == "__main__": if __name__ == "__main__":
test_op_attr() test_op_attr()
test_op_level1() test_op_level1()
import tvm
from tvm import relay
from tvm.relay.ir_pass import well_formed
def test_well_formed():
x = relay.Var("x")
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
let = relay.Let(x, v, x, ty)
assert well_formed(let)
assert not well_formed(relay.Let(x, v, let, ty))
f = relay.Function([relay.Param(x, ty)], ty, x)
assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f,
relay.Let(relay.Var("z"), f, v, ty), ty))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment