Commit 54dbcc28 by 雾雨魔理沙, committed by Wuwei Lin

[Relay] fix exponential blowup in interpreter (#3559)

parent 5bff6cce
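
What this commit fixes: the interpreter walked Relay expressions as trees, so a graph-shaped (DAG) expression whose subexpressions are shared got re-evaluated once per path, which is exponential in the depth of sharing. A minimal sketch of the worst case, assuming the Python API of this TVM version:

    from tvm import relay

    x = relay.var("x", shape=(), dtype="float32")
    body = x
    for _ in range(30):
        # each step references the previous node twice, forming a DAG
        body = relay.add(body, body)
    f = relay.Function([x], body)
    # Evaluated as a tree, `f` performs ~2^30 additions; after
    # transform.ToANormalForm() every shared node is let-bound and
    # computed exactly once, so evaluation is linear in the DAG size.

Accordingly, the commit (1) inserts ToANormalForm() into the interpreter's optimization pipeline, (2) adds a feature check so the interpreter rejects graph-shaped (fGraph) input, and (3) adjusts type unification so the let-bindings ANF introduces keep let-polymorphism type checking.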
@@ -81,13 +81,13 @@ class FeatureSet {
     return ret;
   }
   /*! \brief A set that contains all the Features. */
-  static FeatureSet AllFeature() {
+  static FeatureSet All() {
     FeatureSet fs;
     fs.bs_.flip();
     return fs;
   }
   /*! \brief The empty set. Contains no Feature. */
-  static FeatureSet NoFeature() {
+  static FeatureSet No() {
     FeatureSet fs;
     return fs;
   }

@@ -280,6 +280,7 @@ class Interpreter(Executor):
         """
         seq = transform.Sequential([transform.SimplifyInference(),
                                     transform.FuseOps(0),
+                                    transform.ToANormalForm(),
                                     transform.InferType()])
         return seq(self.mod)
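
A hedged sketch of what the new ToANormalForm() step does to a shared node (Module.from_expr and the pass-call style are assumptions about this version's Python API):

    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(), dtype="float32")
    y = relay.add(x, x)  # y is shared below
    f = relay.Function([x], relay.add(y, y))
    mod = transform.ToANormalForm()(relay.Module.from_expr(f))
    print(mod)  # the shared `y` is now a single let-binding, used twice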
@@ -29,6 +29,7 @@
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/debug.h>
+#include <tvm/relay/feature.h>
 #include "compile_engine.h"

 namespace tvm {
@@ -761,6 +762,8 @@ CreateInterpreter(
     Target target) {
   auto intrp = std::make_shared<Interpreter>(mod, context, target);
   auto packed = [intrp](Expr expr) {
+    auto f = DetectFeature(expr);
+    CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
     return intrp->Eval(expr);
   };
   return TypedPackedFunc<Value(Expr)>(packed);
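
The new CHECK is the C++ side of a guard that could equally be written in Python; a sketch assuming detect_feature and Feature from this version's tvm.relay:

    from tvm.relay.analysis import detect_feature
    from tvm.relay.feature import Feature

    def check_interpretable(expr):
        # the interpreter now requires tree/ANF input: no non-atomic
        # subexpression may be referenced more than once (fGraph)
        assert Feature.fGraph not in detect_feature(expr)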
@@ -120,7 +120,7 @@ class AlphaEqualHandler:
    * \return the comparison result.
    */
   bool TypeEqual(const Type& lhs, const Type& rhs) {
-    auto compute = [&](){
+    auto compute = [&]() {
       if (lhs.same_as(rhs)) return true;
       if (!lhs.defined() || !rhs.defined()) return false;
       return this->VisitType(lhs, rhs);

@@ -34,13 +34,15 @@ namespace relay {

 FeatureSet DetectFeature(const Expr& expr) {
   if (!expr.defined()) {
-    return FeatureSet::NoFeature();
+    return FeatureSet::No();
   }
   struct FeatureDetector : ExprVisitor {
     std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
-    FeatureSet fs = FeatureSet::NoFeature();
+    FeatureSet fs = FeatureSet::No();
+
     void VisitExpr(const Expr& expr) final {
       if (visited_.count(expr) == 0) {
+        visited_.insert(expr);
         ExprVisitor::VisitExpr(expr);
       } else {
         if (!IsAtomic(expr)) {
@@ -52,15 +54,20 @@ FeatureSet DetectFeature(const Expr& expr) {
     void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
       STMT                                                  \
       fs += f##CONSTRUCT_NAME;                              \
-      ExprVisitor::VisitExpr_(op);                          \
     }
-#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
+#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
+    ExprVisitor::VisitExpr_(op); \
+  })
     DETECT_DEFAULT_CONSTRUCT(Var)
     DETECT_DEFAULT_CONSTRUCT(GlobalVar)
     DETECT_DEFAULT_CONSTRUCT(Constant)
     DETECT_DEFAULT_CONSTRUCT(Tuple)
     DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
-    DETECT_DEFAULT_CONSTRUCT(Function)
+    DETECT_CONSTRUCT(Function, {
+      if (!op->IsPrimitive()) {
+        ExprVisitor::VisitExpr_(op);
+      }
+    })
     DETECT_DEFAULT_CONSTRUCT(Op)
     DETECT_DEFAULT_CONSTRUCT(Call)
     DETECT_CONSTRUCT(Let, {
@@ -69,6 +76,7 @@ FeatureSet DetectFeature(const Expr& expr) {
           fs += fLetRec;
         }
       }
+      ExprVisitor::VisitExpr_(op);
     })
     DETECT_DEFAULT_CONSTRUCT(If)
     DETECT_DEFAULT_CONSTRUCT(RefCreate)
@@ -83,7 +91,7 @@ FeatureSet DetectFeature(const Expr& expr) {
 }

 FeatureSet DetectFeature(const Module& mod) {
-  FeatureSet fs = FeatureSet::NoFeature();
+  FeatureSet fs = FeatureSet::No();
   if (mod.defined()) {
     for (const auto& f : mod->functions) {
       fs += DetectFeature(f.second);
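
For reference, a sketch of the condition the detector reports as fGraph, namely a second visit of a non-atomic node (detect_feature and Feature assumed as above):

    from tvm import relay
    from tvm.relay.analysis import detect_feature
    from tvm.relay.feature import Feature

    x = relay.var("x", shape=())
    shared = relay.add(x, x)          # a non-atomic node ...
    dag = relay.add(shared, shared)   # ... referenced twice -> fGraph
    assert Feature.fGraph in detect_feature(dag)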
@@ -139,19 +139,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   // Perform unification on two types and report the error at the expression
   // or the span of the expression.
   Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
-    // TODO(tqchen, jroesch): propagate span to solver
     try {
-      // instantiate higher-order func types when unifying because
-      // we only allow polymorphism at the top level
-      Type first = t1;
-      Type second = t2;
-      if (auto* ft1 = t1.as<FuncTypeNode>()) {
-        first = InstantiateFuncType(ft1);
-      }
-      if (auto* ft2 = t2.as<FuncTypeNode>()) {
-        second = InstantiateFuncType(ft2);
-      }
-      return solver_.Unify(first, second, expr);
+      return solver_.Unify(t1, t2, expr);
     } catch (const dmlc::Error &e) {
       this->ReportFatalError(
           expr,

@@ -289,30 +289,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     const auto* ftn = tn.as<FuncTypeNode>();
     if (!ftn
         || op->arg_types.size() != ftn->arg_types.size()
-        || op->type_params.size() != ftn->type_params.size()
         || op->type_constraints.size() != ftn->type_constraints.size()) {
       return Type(nullptr);
     }

+    // without loss of generality, suppose op->type_params.size() >= ftn->type_params.size().
+    if (op->type_params.size() < ftn->type_params.size()) {
+      return VisitType_(ftn, GetRef<FuncType>(op));
+    }
+
     // remap type vars so they match
     Map<TypeVar, Type> subst_map;
-    for (size_t i = 0; i < op->type_params.size(); i++) {
-      subst_map.Set(ftn->type_params[i], op->type_params[i]);
+    tvm::Array<TypeVar> ft_type_params;
+    for (size_t i = 0; i < ftn->type_params.size(); ++i) {
+      subst_map.Set(op->type_params[i], ftn->type_params[i]);
+      ft_type_params.push_back(op->type_params[i]);
+    }
+
+    for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
+      subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
     }

-    auto ft1 = GetRef<FuncType>(op);
-    auto ft2 = Downcast<FuncType>(Bind(GetRef<FuncType>(ftn), subst_map));
+    FuncType ft = FuncTypeNode::make(op->arg_types,
+                                     op->ret_type,
+                                     ft_type_params,
+                                     op->type_constraints);
+    auto ft1 = Downcast<FuncType>(Bind(ft, subst_map));
+    auto ft2 = GetRef<FuncType>(ftn);

     Type ret_type = Unify(ft1->ret_type, ft2->ret_type);

     std::vector<Type> arg_types;
-    for (size_t i = 0; i < ft1->arg_types.size(); i++) {
+    for (size_t i = 0; i < ft2->arg_types.size(); ++i) {
       Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]);
       arg_types.push_back(arg_type);
     }

     std::vector<TypeConstraint> type_constraints;
-    for (size_t i = 0; i < ft1->type_constraints.size(); i++) {
+    for (size_t i = 0; i < ft1->type_constraints.size(); ++i) {
       Type unified_constraint = Unify(ft1->type_constraints[i],
                                       ft2->type_constraints[i]);
       const auto* tcn = unified_constraint.as<TypeConstraintNode>();
@@ -321,7 +335,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       type_constraints.push_back(GetRef<TypeConstraint>(tcn));
     }

-    return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
+    return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints);
   }

   Type VisitType_(const RefTypeNode* op, const Type& tn) final {
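
The effect of the unifier change, sketched at the type level (constructor signatures assumed from this version's Python API): unifying a polymorphic function type against a monomorphic one no longer fails on the type_params.size() mismatch; each surplus type parameter is instantiated with a fresh IncompleteType, which then unifies with the concrete argument type.

    from tvm import relay

    a = relay.TypeVar("a")
    int32 = relay.TensorType((), "int32")
    poly = relay.FuncType([a], a, [a])     # fn<a>(a) -> a
    mono = relay.FuncType([int32], int32)  # fn(int32) -> int32
    # solver_.Unify(poly, mono) now succeeds, yielding fn(int32) -> int32.
    # This is what lets the single ANF binding `let id = fn<xt>(xt) ...`
    # be used at two different types (see test_let_polymorphism below).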
@@ -63,7 +63,8 @@ def test_ad():
         Feature.fLet,
         Feature.fRefCreate,
         Feature.fRefRead,
-        Feature.fRefWrite
+        Feature.fRefWrite,
+        Feature.fGraph
     ])

@@ -30,6 +30,20 @@ def rand(dtype='float32', *shape):
     return tvm.nd.array(np.random.rand(*shape).astype(dtype))

+
+def test_id():
+    x = relay.var("x", shape=[])
+    id = run_infer_type(relay.Function([x], x))
+    id_cps = run_infer_type(to_cps(id))
+
+
+def test_double():
+    t = relay.TypeVar("t")
+    x = relay.var("x", t)
+    f = relay.var("f", relay.FuncType([t], t))
+    double = run_infer_type(relay.Function([f, x], f(f(x)), t, [t]))
+    double_cps = run_infer_type(to_cps(double))
+
+
 # make sure cps works for recursion.
 def test_recursion():
     mod = relay.Module()

@@ -19,6 +19,7 @@
 """
 from tvm import relay
 from tvm.relay import op, transform, analysis
+from tvm.relay.analysis import assert_alpha_equal

 def run_infer_type(expr, mod=None):
@@ -349,6 +350,17 @@ def test_adt_match_type_annotations():
     assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))

+
+def test_let_polymorphism():
+    id = relay.Var("id")
+    xt = relay.TypeVar("xt")
+    x = relay.Var("x", xt)
+    body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
+    body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
+    body = run_infer_type(body)
+    int32 = relay.TensorType((), "int32")
+    assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
+
 if __name__ == "__main__":
     test_free_expr()
     test_dual_op()
@@ -366,3 +378,4 @@ if __name__ == "__main__":
     test_constructor_type()
     test_constructor_call()
     test_adt_match()
+    test_let_polymorphism()