Commit 54dbcc28 by 雾雨魔理沙 Committed by Wuwei Lin

[Relay] fix exponential blowup in interpreter (#3559)

parent 5bff6cce
......@@ -81,13 +81,13 @@ class FeatureSet {
return ret;
}
/*! \brief A set that contain all the Feature. */
static FeatureSet AllFeature() {
static FeatureSet All() {
FeatureSet fs;
fs.bs_.flip();
return fs;
}
/*! \brief The empty set. Contain no Feature. */
static FeatureSet NoFeature() {
static FeatureSet No() {
FeatureSet fs;
return fs;
}
......
......@@ -280,6 +280,7 @@ class Interpreter(Executor):
"""
seq = transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType()])
return seq(self.mod)
......
......@@ -29,6 +29,7 @@
#include <tvm/relay/interpreter.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
#include "compile_engine.h"
namespace tvm {
......@@ -761,6 +762,8 @@ CreateInterpreter(
Target target) {
auto intrp = std::make_shared<Interpreter>(mod, context, target);
auto packed = [intrp](Expr expr) {
auto f = DetectFeature(expr);
CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
return intrp->Eval(expr);
};
return TypedPackedFunc<Value(Expr)>(packed);
......
......@@ -120,7 +120,7 @@ class AlphaEqualHandler:
* \return the comparison result.
*/
bool TypeEqual(const Type& lhs, const Type& rhs) {
auto compute = [&](){
auto compute = [&]() {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
......
......@@ -34,13 +34,15 @@ namespace relay {
FeatureSet DetectFeature(const Expr& expr) {
if (!expr.defined()) {
return FeatureSet::NoFeature();
return FeatureSet::No();
}
struct FeatureDetector : ExprVisitor {
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
FeatureSet fs = FeatureSet::NoFeature();
FeatureSet fs = FeatureSet::No();
void VisitExpr(const Expr& expr) final {
if (visited_.count(expr) == 0) {
visited_.insert(expr);
ExprVisitor::VisitExpr(expr);
} else {
if (!IsAtomic(expr)) {
......@@ -52,15 +54,20 @@ FeatureSet DetectFeature(const Expr& expr) {
void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \
STMT \
fs += f##CONSTRUCT_NAME; \
ExprVisitor::VisitExpr_(op); \
}
#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {})
#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \
ExprVisitor::VisitExpr_(op); \
})
DETECT_DEFAULT_CONSTRUCT(Var)
DETECT_DEFAULT_CONSTRUCT(GlobalVar)
DETECT_DEFAULT_CONSTRUCT(Constant)
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
DETECT_DEFAULT_CONSTRUCT(Function)
DETECT_CONSTRUCT(Function, {
if (!op->IsPrimitive()) {
ExprVisitor::VisitExpr_(op);
}
})
DETECT_DEFAULT_CONSTRUCT(Op)
DETECT_DEFAULT_CONSTRUCT(Call)
DETECT_CONSTRUCT(Let, {
......@@ -69,6 +76,7 @@ FeatureSet DetectFeature(const Expr& expr) {
fs += fLetRec;
}
}
ExprVisitor::VisitExpr_(op);
})
DETECT_DEFAULT_CONSTRUCT(If)
DETECT_DEFAULT_CONSTRUCT(RefCreate)
......@@ -83,7 +91,7 @@ FeatureSet DetectFeature(const Expr& expr) {
}
FeatureSet DetectFeature(const Module& mod) {
FeatureSet fs = FeatureSet::NoFeature();
FeatureSet fs = FeatureSet::No();
if (mod.defined()) {
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
......
......@@ -139,19 +139,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// Perform unification on two types and report the error at the expression
// or the span of the expression.
Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) {
// TODO(tqchen, jroesch): propagate span to solver
try {
// instantiate higher-order func types when unifying because
// we only allow polymorphism at the top level
Type first = t1;
Type second = t2;
if (auto* ft1 = t1.as<FuncTypeNode>()) {
first = InstantiateFuncType(ft1);
}
if (auto* ft2 = t2.as<FuncTypeNode>()) {
second = InstantiateFuncType(ft2);
}
return solver_.Unify(first, second, expr);
return solver_.Unify(t1, t2, expr);
} catch (const dmlc::Error &e) {
this->ReportFatalError(
expr,
......
......@@ -289,30 +289,44 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
const auto* ftn = tn.as<FuncTypeNode>();
if (!ftn
|| op->arg_types.size() != ftn->arg_types.size()
|| op->type_params.size() != ftn->type_params.size()
|| op->type_constraints.size() != ftn->type_constraints.size()) {
return Type(nullptr);
}
// without loss of generality, suppose op->type_params.size() >= ftn->type_params.size().
if (op->type_params.size() < ftn->type_params.size()) {
return VisitType_(ftn, GetRef<FuncType>(op));
}
// remap type vars so they match
Map<TypeVar, Type> subst_map;
for (size_t i = 0; i < op->type_params.size(); i++) {
subst_map.Set(ftn->type_params[i], op->type_params[i]);
tvm::Array<TypeVar> ft_type_params;
for (size_t i = 0; i < ftn->type_params.size(); ++i) {
subst_map.Set(op->type_params[i], ftn->type_params[i]);
ft_type_params.push_back(op->type_params[i]);
}
for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
}
auto ft1 = GetRef<FuncType>(op);
auto ft2 = Downcast<FuncType>(Bind(GetRef<FuncType>(ftn), subst_map));
FuncType ft = FuncTypeNode::make(op->arg_types,
op->ret_type,
ft_type_params,
op->type_constraints);
auto ft1 = Downcast<FuncType>(Bind(ft, subst_map));
auto ft2 = GetRef<FuncType>(ftn);
Type ret_type = Unify(ft1->ret_type, ft2->ret_type);
std::vector<Type> arg_types;
for (size_t i = 0; i < ft1->arg_types.size(); i++) {
for (size_t i = 0; i < ft2->arg_types.size(); ++i) {
Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]);
arg_types.push_back(arg_type);
}
std::vector<TypeConstraint> type_constraints;
for (size_t i = 0; i < ft1->type_constraints.size(); i++) {
for (size_t i = 0; i < ft1->type_constraints.size(); ++i) {
Type unified_constraint = Unify(ft1->type_constraints[i],
ft2->type_constraints[i]);
const auto* tcn = unified_constraint.as<TypeConstraintNode>();
......@@ -321,7 +335,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
type_constraints.push_back(GetRef<TypeConstraint>(tcn));
}
return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints);
}
Type VisitType_(const RefTypeNode* op, const Type& tn) final {
......
......@@ -63,7 +63,8 @@ def test_ad():
Feature.fLet,
Feature.fRefCreate,
Feature.fRefRead,
Feature.fRefWrite
Feature.fRefWrite,
Feature.fGraph
])
......
......@@ -30,6 +30,20 @@ def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id():
x = relay.var("x", shape=[])
id = run_infer_type(relay.Function([x], x))
id_cps = run_infer_type(to_cps(id))
def test_double():
t = relay.TypeVar("t")
x = relay.var("x", t)
f = relay.var("f", relay.FuncType([t], t))
double = run_infer_type(relay.Function([f, x], f(f(x)), t, [t]))
double_cps = run_infer_type(to_cps(double))
# make sure cps work for recursion.
def test_recursion():
mod = relay.Module()
......
......@@ -19,6 +19,7 @@
"""
from tvm import relay
from tvm.relay import op, transform, analysis
from tvm.relay.analysis import assert_alpha_equal
def run_infer_type(expr, mod=None):
......@@ -349,6 +350,17 @@ def test_adt_match_type_annotations():
assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))
def test_let_polymorphism():
id = relay.Var("id")
xt = relay.TypeVar("xt")
x = relay.Var("x", xt)
body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
body = run_infer_type(body)
int32 = relay.TensorType((), "int32")
assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
if __name__ == "__main__":
test_free_expr()
test_dual_op()
......@@ -366,3 +378,4 @@ if __name__ == "__main__":
test_constructor_type()
test_constructor_call()
test_adt_match()
test_let_polymorphism()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment