Commit 6783d373 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Unifier hotfix (#2437)

parent 76188a43
...@@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2); ...@@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/ */
bool WellFormed(const Expr& expr); bool WellFormed(const Expr& expr);
/*! \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> BoundVars(const Expr& expr);
/*! \brief Get free type parameters from expression expr. /*! \brief Get free type parameters from expression expr.
* *
* Free variables are variables that are not bound by a * Free variables are variables that are not bound by a
...@@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr); ...@@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr);
*/ */
tvm::Array<Var> FreeVars(const Expr& expr); tvm::Array<Var> FreeVars(const Expr& expr);
/*! \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> AllVars(const Expr& expr);
/*! \brief Get free TypeVars from expression expr. /*! \brief Get free TypeVars from expression expr.
* *
* Free type parameters are type parameters that are not bound by a function * Free type parameters are type parameters that are not bound by a function
...@@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr); ...@@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr);
*/ */
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
/*! \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param t the type.
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> FreeTypeVars(const Type& t);
/*! \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);
/*! \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
*
* \param t the type
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> BoundTypeVars(const Type& t);
/*! \brief Get all type variables in expression expr.
*
* \param expr the expression.
*
* \return List of type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> AllTypeVars(const Expr& expr);
/*! \brief Get all type variables in type t.
*
* \param t the type.
*
* \return List of type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> AllTypeVars(const Type& t);
/*! \brief Remove expressions which does not effect the program result. /*! \brief Remove expressions which does not effect the program result.
* *
* It will remove let bindings which are not referenced, and branches that will * It will remove let bindings which are not referenced, and branches that will
......
...@@ -158,6 +158,38 @@ def free_vars(expr): ...@@ -158,6 +158,38 @@ def free_vars(expr):
return _ir_pass.free_vars(expr) return _ir_pass.free_vars(expr)
def bound_vars(expr):
"""Get bound vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of bound variables in post-DFS order.
"""
return _ir_pass.bound_vars(expr)
def all_vars(expr):
"""Get all vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of all variables in post-DFS order.
"""
return _ir_pass.all_vars(expr)
def free_type_vars(expr): def free_type_vars(expr):
"""Get free type variables from expression/type e """Get free type variables from expression/type e
...@@ -168,12 +200,44 @@ def free_type_vars(expr): ...@@ -168,12 +200,44 @@ def free_type_vars(expr):
Returns Returns
------- -------
free : List[tvm.relay.TypeParam] free : List[tvm.relay.TypeVar]
The list of free type variables The list of free type variables in post-DFS order
""" """
return _ir_pass.free_type_vars(expr) return _ir_pass.free_type_vars(expr)
def bound_type_vars(expr):
"""Get bound type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order
"""
return _ir_pass.bound_type_vars(expr)
def all_type_vars(expr):
"""Get all type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order
"""
return _ir_pass.all_type_vars(expr)
def simplify_inference(expr): def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase. """ Simplify the data-flow graph for inference phase.
......
...@@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { ...@@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
}); });
return Pair(res.foward, grad); return Pair(res.foward, grad);
}); });
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
Type ret_type = Type();
std::vector<Type> vt; std::vector<Type> vt;
bool missing = !f->ret_type.defined();
for (const auto& p : f->params) { for (const auto& p : f->params) {
if (missing || !p->type_annotation.defined()) {
missing = true;
break;
}
vt.push_back(p->type_annotation); vt.push_back(p->type_annotation);
} }
return FunctionNode::make(f->params,
body, if (!missing) {
TupleTypeNode::make({f->ret_type, TupleTypeNode::make({})}), ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
{}); }
return FunctionNode::make(f->params, body, ret_type, {});
} }
TVM_REGISTER_API("relay._ir_pass.first_order_gradient") TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
......
...@@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array<Type>& types, ...@@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array<Type>& types,
return true; return true;
} }
bool MakeTupleRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(static_cast<size_t>(num_inputs + 1), types.size());
for (int i = 0; i < num_inputs; ++i) {
if (types[i].as<IncompleteTypeNode>()) return false;
}
Array<Type> fields;
for (int i = 0; i < num_inputs; ++i) {
fields.push_back(types[i]);
}
reporter->Assign(types[num_inputs], TupleTypeNode::make(fields));
return true;
}
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>( .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel); TupleGetItemRel);
TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
MakeTupleRel);
struct ResolvedTypeInfo { struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {} : checked_type(checked_type), type_args(type_args) {}
...@@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// type inferencer will populate it up // type inferencer will populate it up
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_; std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
// used to ensure we don't have free type vars hanging around
// (a temporary measure until we have proper generalization implemented)
Map<TypeVar, Type> instantiation_map_;
// The solver used by the inferencer. // The solver used by the inferencer.
TypeSolver solver_; TypeSolver solver_;
// relation function // relation function
...@@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return Type(); return Type();
} }
} }
// Substitutes every type var in t with a corresponding incomplete type.
// This is a temporary measure to ensure type vars behave until
// generalization is properly implemented.
Type Instantiate(const Type &t) {
if (!t.defined()) {
return t;
}
auto* ft = t.as<FuncTypeNode>();
if (ft == nullptr) {
return Bind(t, instantiation_map_);
}
for (auto type_param : ft->type_params) {
instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType));
}
Type ret_type = ft->ret_type;
if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
}
auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints);
return Bind(strip_tvs, instantiation_map_);
}
// Lazily get type for expr // Lazily get type for expr
// will call visit to deduce it if it is not in the type_map_ // will call visit to deduce it if it is not in the type_map_
Type GetType(const Expr &expr) { Type GetType(const Expr &expr) {
...@@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (it != type_map_.end() && it->second.checked_type.defined()) { if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second.checked_type; return it->second.checked_type;
} }
Type ret = this->VisitExpr(expr); Type ret = Instantiate(this->VisitExpr(expr));
ResolvedTypeInfo& rti = type_map_[expr]; ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret; rti.checked_type = ret;
return ret; return ret;
...@@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const TupleNode* op) final { Type VisitExpr_(const TupleNode* op) final {
if (!make_tuple_rel_.defined()) {
make_tuple_rel_ = TypeRelationFn(
EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_);
}
Array<Type> types; Array<Type> types;
for (Expr field : op->fields) { for (Expr field : op->fields) {
types.push_back(GetType(field)); types.push_back(GetType(field));
} }
Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); return TupleTypeNode::make(types);
types.push_back(rtype);
solver_.AddConstraint(TypeRelationNode::make(
make_tuple_rel_, types, op->fields.size(), Attrs()));
return rtype;
} }
Type VisitExpr_(const TupleGetItemNode* op) final { Type VisitExpr_(const TupleGetItemNode* op) final {
...@@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const LetNode* op) final { Type VisitExpr_(const LetNode* op) final {
// if the definition is a function literal, permit recursion
bool is_functional_literal = op->value.as<FunctionNode>() != nullptr;
if (is_functional_literal) {
type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
}
Type vtype = GetType(op->value); Type vtype = GetType(op->value);
if (op->var->type_annotation.defined()) { if (op->var->type_annotation.defined()) {
vtype = Unify(vtype, op->var->type_annotation, op->span); vtype = Unify(vtype, op->var->type_annotation, op->span);
} }
CHECK(!type_map_.count(op->var)); CHECK(is_functional_literal || !type_map_.count(op->var));
// NOTE: no scoping is necessary because var are unique in program // NOTE: no scoping is necessary because var are unique in program
type_map_[op->var].checked_type = vtype; type_map_[op->var].checked_type = vtype;
return GetType(op->body); return GetType(op->body);
...@@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return rtype; return rtype;
} }
// instantiate the function type with fresh // substitute the type args in the function type
FuncType Instantiate(const FuncTypeNode* fn_ty, Array<Type>* ty_args) { FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array<Type>& ty_args) {
tvm::Map<TypeVar, Type> subst_map; tvm::Map<TypeVar, Type> subst_map;
// Build a subsitituion map up from the function type and type arguments. // Build a subsitituion map up from the function type and type arguments.
// Eventually allow the type vars to be passed in. // Eventually allow the type vars to be passed in.
for (auto ty_param : fn_ty->type_params) { for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); subst_map.Set(fn_ty->type_params[i], ty_args[i]);
subst_map.Set(ty_param, fresh);
ty_args->push_back(fresh);
} }
Type ret_type = fn_ty->ret_type; Type ret_type = fn_ty->ret_type;
...@@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type GeneralCall(const CallNode* call, Array<Type> arg_types) { Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
Type ftype = GetType(call->op); Type ftype = GetType(call->op);
auto* fn_ty_node = ftype.as<FuncTypeNode>(); auto* fn_ty_node = ftype.as<FuncTypeNode>();
auto* inc_ty_node = ftype.as<IncompleteTypeNode>();
CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr)
<< "only expressions with function types can be called, found "
<< ftype << " at " << call->span;
// incomplete type => it must be a function taking the arg types
// with an unknown return type
if (inc_ty_node != nullptr) {
Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {});
Type unified = this->Unify(ftype, func_type, call->span);
fn_ty_node = unified.as<FuncTypeNode>();
}
CHECK(fn_ty_node != nullptr) Array<Type> type_args = call->type_args;
<< "only expressions with function types can be called, found " if (type_args.size() == 0) {
<< ftype << " at " << call->span; for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) {
type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType));
Array<Type> type_args; }
FuncType fn_ty = Instantiate(fn_ty_node, &type_args); }
CHECK(type_args.size() == fn_ty_node->type_params.size())
<< "Incorrect number of type args in " << call->span << ": "
<< "Expected " << fn_ty_node->type_params.size()
<< "but got " << type_args.size();
FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
AddTypeArgs(GetRef<Call>(call), type_args); AddTypeArgs(GetRef<Call>(call), type_args);
...@@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const FunctionNode* f) final { Type VisitExpr_(const FunctionNode* f) final {
solver_.Solve();
Array<Type> arg_types;
for (auto param : f->params) { for (auto param : f->params) {
GetType(param); arg_types.push_back(GetType(param));
} }
Type rtype = GetType(f->body); Type rtype = GetType(f->body);
// Run solver using the currently known information if (f->ret_type.defined()) {
solver_.Solve(); rtype = this->Unify(f->ret_type, rtype, f->span);
// Trying to resolve
Array<Type> arg_types;
for (size_t i = 0; i < f->params.size(); ++i) {
Type atype = solver_.Resolve(GetType(f->params[i]));
CHECK(atype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << i
<< "-th parameter of function at" << f->span;
arg_types.push_back(atype);
} }
rtype = solver_.Resolve(rtype); auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
CHECK(rtype.as<IncompleteTypeNode>() == nullptr) return solver_.Resolve(ret);
<< "Cannot resolve return type of function at" << f->span;
// do not support constraint lifting for now.
return FuncTypeNode::make(arg_types, rtype, f->type_params, {});
} }
}; };
...@@ -380,7 +396,7 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -380,7 +396,7 @@ class TypeInferencer::Resolver : public ExprMutator {
public: public:
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap, Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap,
TypeSolver* solver) TypeSolver* solver)
: tmap_(tmap), solver_(solver) { : tmap_(tmap), solver_(solver) {
} }
Expr VisitExpr_(const VarNode* op) final { Expr VisitExpr_(const VarNode* op) final {
...@@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) { ...@@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) {
GetType(expr); GetType(expr);
// Step 1: Solve the constraints. // Step 1: Solve the constraints.
solver_.Solve(); solver_.Solve();
// Step 2: Attach resolved types to checked_type field. // Step 2: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
CHECK(WellFormed(resolved_expr)); CHECK(WellFormed(resolved_expr));
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <string> #include <string>
#include "type_solver.h" #include "type_solver.h"
#include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -38,9 +39,298 @@ class TypeSolver::Reporter : public TypeReporterNode { ...@@ -38,9 +39,298 @@ class TypeSolver::Reporter : public TypeReporterNode {
TypeSolver* solver_; TypeSolver* solver_;
}; };
class TypeSolver::OccursChecker : public TypeVisitor {
public:
explicit OccursChecker(TypeSolver* solver, TypeNode* var)
: solver_(solver), var_(var), found_(false) {}
bool Check(const Type& t) {
VisitType(t);
return found_;
}
void VisitType_(const IncompleteTypeNode* op) override {
IncompleteType t = GetRef<IncompleteType>(op);
TypeNode* node = solver_->GetTypeNode(t);
found_ = found_ || (var_->FindRoot() == node->FindRoot());
}
private:
TypeSolver* solver_;
TypeNode* var_;
bool found_;
};
class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
public:
explicit Unifier(TypeSolver* solver) : solver_(solver) {}
Type Unify(const Type& src, const Type& dst) {
// Known limitation
// - handle shape pattern matching
TypeNode* lhs = solver_->GetTypeNode(dst);
TypeNode* rhs = solver_->GetTypeNode(src);
// do occur check so we don't create self-referencing structure
if (lhs->FindRoot() == rhs->FindRoot()) {
return lhs->resolved_type;
}
if (lhs->resolved_type.as<IncompleteTypeNode>()) {
CHECK(!CheckOccurs(lhs, rhs->resolved_type))
<< "Incomplete type " << lhs->resolved_type << " occurs in "
<< rhs->resolved_type << ", cannot unify";
solver_->MergeFromTo(lhs, rhs);
return rhs->resolved_type;
} else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
CHECK(!CheckOccurs(rhs, lhs->resolved_type))
<< "Incomplete type " << rhs->resolved_type << " occurs in "
<< lhs->resolved_type << ", cannot unify";
solver_->MergeFromTo(rhs, lhs);
return lhs->resolved_type;
} else {
Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
CHECK(resolved.defined())
<< "Unable to unify parent types: "
<< lhs->resolved_type << " and " << rhs->resolved_type;
TypeNode* top = solver_->GetTypeNode(resolved);
solver_->MergeFromTo(lhs, top);
solver_->MergeFromTo(rhs, top);
return resolved;
}
}
// Checks whether lhs (taken to be a type var) occurs in t, meaning
// there is a recursive equality constraint, which should be rejected.
// N.b.: A tautology like ?a = ?a is okay and should be checked for
// *before* calling this method
bool CheckOccurs(TypeNode* lhs, const Type& t) {
OccursChecker rc(solver_, lhs);
return rc.Check(t);
}
// default: unify only if alpha-equal
Type VisitTypeDefault_(const Node* op, const Type& tn) override {
NodeRef nr = GetRef<NodeRef>(op);
Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
if (!AlphaEqual(t1, tn)) {
return Type(nullptr);
}
return t1;
}
Type VisitType_(const TupleTypeNode* op, const Type& tn) override {
const auto* ttn = tn.as<TupleTypeNode>();
if (!ttn || op->fields.size() != ttn->fields.size()) {
return Type(nullptr);
}
TupleType tt1 = GetRef<TupleType>(op);
TupleType tt2 = GetRef<TupleType>(ttn);
std::vector<Type> new_fields;
for (size_t i = 0; i < tt1->fields.size(); i++) {
Type field = Unify(tt1->fields[i], tt2->fields[i]);
new_fields.push_back(field);
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const FuncTypeNode* op, const Type& tn) override {
const auto* ftn = tn.as<FuncTypeNode>();
if (!ftn
|| op->arg_types.size() != ftn->arg_types.size()
|| op->type_params.size() != ftn->type_params.size()
|| op->type_constraints.size() != ftn->type_constraints.size()) {
return Type(nullptr);
}
// remap type vars so they match
Map<TypeVar, Type> subst_map;
for (size_t i = 0; i < op->type_params.size(); i++) {
subst_map.Set(ftn->type_params[i], op->type_params[i]);
}
auto ft1 = GetRef<FuncType>(op);
auto ft2 = Downcast<FuncType>(Bind(GetRef<FuncType>(ftn), subst_map));
Type ret_type = Unify(ft1->ret_type, ft2->ret_type);
std::vector<Type> arg_types;
for (size_t i = 0; i < ft1->arg_types.size(); i++) {
Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]);
arg_types.push_back(arg_type);
}
std::vector<TypeConstraint> type_constraints;
for (size_t i = 0; i < ft1->type_constraints.size(); i++) {
Type unified_constraint = Unify(ft1->type_constraints[i],
ft2->type_constraints[i]);
const auto* tcn = unified_constraint.as<TypeConstraintNode>();
CHECK(tcn) << "Two type constraints unified into a non-constraint?"
<< ft1->type_constraints[i] << " and " << ft2->type_constraints[i];
type_constraints.push_back(GetRef<TypeConstraint>(tcn));
}
return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
}
private:
TypeSolver* solver_;
};
class TypeSolver::Resolver : public TypeMutator {
public:
explicit Resolver(TypeSolver* solver) : solver_(solver) {}
Type Resolve(const Type& t) {
if (!t.defined()) {
return t;
}
return VisitType(t);
}
Type VisitType_(const IncompleteTypeNode* op) override {
auto* node = solver_->GetTypeNode(GetRef<IncompleteType>(op));
return node->resolved_type;
}
private:
TypeSolver* solver_;
};
// It ends up being more compact to simply have TypeFunctor<void(const Type&) than
// a TypeVisitor because we can use the default case to dispense with
// most of the overrides.
class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> {
public:
explicit Propagator(TypeSolver* solver, const std::unordered_set<RelationNode*>* rels)
: solver_(solver), rels_(rels) {}
// adds the relation node to t and all child types of t
void Propagate(const Type& t) {
VisitType(t);
}
void UpdateRelSet(const Type& t) {
TypeNode* tnode = solver_->GetTypeNode(t);
for (auto* rel : *rels_) {
tnode->rel_set.insert(rel);
}
}
void VisitTypeDefault_(const Node* op) override {
NodeRef nr = GetRef<NodeRef>(op);
Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
UpdateRelSet(t);
}
void VisitType_(const TupleTypeNode* op) override {
TupleType tt = GetRef<TupleType>(op);
UpdateRelSet(tt);
for (const Type& t : tt->fields) {
Propagate(t);
}
}
void VisitType_(const FuncTypeNode* op) override {
FuncType ft = GetRef<FuncType>(op);
UpdateRelSet(ft);
Propagate(ft->ret_type);
for (auto arg_type : ft->arg_types) {
Propagate(arg_type);
}
for (auto type_param : ft->type_params) {
Propagate(type_param);
}
for (auto type_cs : ft->type_constraints) {
Propagate(type_cs);
}
}
private:
TypeSolver* solver_;
const std::unordered_set<RelationNode*>* rels_;
};
// similarly, we use TypeFunctor<void(const Type&)> so we can use
// the default visitor case to avoid more overrides
class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
public:
explicit Merger(TypeSolver* solver) : solver_(solver) {}
// Merges src node to dst, ensures *all* type relations of all
// child nodes of src are transferred to dst.
void Merge(TypeNode* src, TypeNode* dst) {
if (src == dst) return;
dst_ = dst;
VisitType(src->resolved_type);
// set parent at the end so later calls to GetTypeNode go back to src
src->parent = dst;
// now propagate relations to child nodes, since change to
// a child node should update parent too
Propagator prop(solver_, &dst->rel_set);
prop.Propagate(dst->resolved_type);
}
// Transfers any relations linked to t to the stored dst.
// Any unresolved relations are added back to the queue, since
// there is now new information
void TransferLinks(const Type& t) {
TypeNode* src = solver_->GetTypeNode(t);
if (src == dst_) return;
for (auto* rel : src->rel_set) {
// if the relation is not yet resolved, add to queue
if (!rel->resolved) {
solver_->AddToQueue(rel);
dst_->rel_set.insert(rel);
}
}
}
void VisitTypeDefault_(const Node* op) override {
NodeRef nr = GetRef<NodeRef>(op);
Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
TransferLinks(t);
}
void VisitType_(const TupleTypeNode* ttn) override {
auto tup = GetRef<TupleType>(ttn);
TransferLinks(tup);
for (auto field : tup->fields) {
VisitType(field);
}
}
void VisitType_(const FuncTypeNode* ftn) override {
auto func = GetRef<FuncType>(ftn);
TransferLinks(func);
VisitType(func->ret_type);
for (auto arg : func->arg_types) {
VisitType(arg);
}
for (auto param : func->type_params) {
VisitType(param);
}
for (auto constraint : func->type_constraints) {
VisitType(constraint);
}
}
private:
TypeSolver* solver_;
TypeNode* dst_;
};
// constructor // constructor
TypeSolver::TypeSolver() TypeSolver::TypeSolver()
: reporter_(make_node<Reporter>(this)) { : reporter_(make_node<Reporter>(this)) {
} }
// destructor // destructor
...@@ -54,31 +344,16 @@ TypeSolver::~TypeSolver() { ...@@ -54,31 +344,16 @@ TypeSolver::~TypeSolver() {
} }
} }
// merge src type node to dst
void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
Merger merger(this);
merger.Merge(src, dst);
}
// Add equality constraint // Add equality constraint
Type TypeSolver::Unify(const Type& dst, const Type& src) { Type TypeSolver::Unify(const Type& dst, const Type& src) {
// Known limitation Unifier unifier(this);
// - handle composite types whose component can be unknown. return unifier.Unify(dst, src);
// - handle shape pattern matching
TypeNode* lhs = GetTypeNode(dst);
TypeNode* rhs = GetTypeNode(src);
// do occur check so we don't create self-referencing structure
if (lhs->FindRoot() == rhs->FindRoot()) {
return lhs->resolved_type;
}
if (lhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(lhs, rhs);
return rhs->resolved_type;
} else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(rhs, lhs);
return lhs->resolved_type;
} else {
lhs->parent = rhs;
CHECK(AlphaEqual(lhs->resolved_type, rhs->resolved_type))
<< "Incompatible parent types in UF:"
<< lhs->resolved_type << " and " << rhs->resolved_type;
return rhs->resolved_type;
}
} }
// Add type constraint to the solver. // Add type constraint to the solver.
...@@ -96,9 +371,9 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { ...@@ -96,9 +371,9 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
tlink->value = tnode; tlink->value = tnode;
rnode->type_list.Push(tlink); rnode->type_list.Push(tlink);
// insert type->relation node // insert type->relation node
LinkNode<RelationNode*>* rlink = arena_.make<LinkNode<RelationNode*> >(); std::unordered_set<RelationNode*> singleton { rnode };
rlink->value = rnode; Propagator prop(this, &singleton);
tnode->rel_list.Push(rlink); prop.Propagate(tnode->resolved_type);
} }
// add the relation to the working queue. // add the relation to the working queue.
this->AddToQueue(rnode); this->AddToQueue(rnode);
...@@ -110,12 +385,10 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { ...@@ -110,12 +385,10 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
// Resolve a type in the solver context. // Resolve a type in the solver context.
Type TypeSolver::Resolve(const Type& type) { Type TypeSolver::Resolve(const Type& type) {
Resolver resolver(this);
auto it = tmap_.find(type); auto it = tmap_.find(type);
if (it != tmap_.end()) { Type t = (it != tmap_.end()) ? it->second->FindRoot()->resolved_type : type;
return it->second->FindRoot()->resolved_type; return resolver.Resolve(t);
} else {
return type;
}
} }
bool TypeSolver::Solve() { bool TypeSolver::Solve() {
...@@ -128,7 +401,7 @@ bool TypeSolver::Solve() { ...@@ -128,7 +401,7 @@ bool TypeSolver::Solve() {
// update the relation with given evidence. // update the relation with given evidence.
Array<Type> args; Array<Type> args;
for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) {
args.push_back(tlink->value->FindRoot()->resolved_type); args.push_back(Resolve(tlink->value->FindRoot()->resolved_type));
CHECK_LE(args.size(), rel->args.size()); CHECK_LE(args.size(), rel->args.size());
} }
// call the function // call the function
...@@ -161,8 +434,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") ...@@ -161,8 +434,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
return solver->Solve(); return solver->Solve();
}); });
} else if (name == "Unify") { } else if (name == "Unify") {
return TypedPackedFunc<void(Type, Type)>([solver](Type lhs, Type rhs) { return TypedPackedFunc<Type(Type, Type)>([solver](Type lhs, Type rhs) {
solver->Unify(lhs, rhs); return solver->Unify(lhs, rhs);
}); });
} else if (name == "Resolve") { } else if (name == "Resolve") {
return TypedPackedFunc<Type(Type)>([solver](Type t) { return TypedPackedFunc<Type(Type)>([solver](Type t) {
......
...@@ -18,6 +18,7 @@ namespace relay { ...@@ -18,6 +18,7 @@ namespace relay {
using common::LinkNode; using common::LinkNode;
using common::LinkedList; using common::LinkedList;
/*! /*!
* \brief Interface of type solver used in type inference. * \brief Interface of type solver used in type inference.
* *
...@@ -65,6 +66,11 @@ class TypeSolver { ...@@ -65,6 +66,11 @@ class TypeSolver {
Type Unify(const Type& lhs, const Type& rhs); Type Unify(const Type& lhs, const Type& rhs);
private: private:
class OccursChecker;
class Unifier;
class Resolver;
class Propagator;
class Merger;
class Reporter; class Reporter;
struct TypeNode; struct TypeNode;
struct RelationNode; struct RelationNode;
...@@ -77,15 +83,15 @@ class TypeSolver { ...@@ -77,15 +83,15 @@ class TypeSolver {
* that can unifies the same types to the name resolved_type. * that can unifies the same types to the name resolved_type.
* *
* It also contains collection of links to related Relations, * It also contains collection of links to related Relations,
* which is stored in rel_list. * which is stored in rel_set.
*/ */
struct TypeNode { struct TypeNode {
/*! \brief The final resolved type */ /*! \brief The final resolved type */
Type resolved_type; Type resolved_type;
/*! \brief type node in the union find algorithm */ /*! \brief type node in the union find algorithm */
TypeNode* parent{nullptr}; TypeNode* parent{nullptr};
/*! \brief list of relations that is related to this type node */ /*! \brief set of relations that is related to this type node */
LinkedList<RelationNode*> rel_list; std::unordered_set<RelationNode*> rel_set;
/*! /*!
* \brief Find the root type node, perform path compression * \brief Find the root type node, perform path compression
* \return The root type node. * \return The root type node.
...@@ -125,7 +131,7 @@ class TypeSolver { ...@@ -125,7 +131,7 @@ class TypeSolver {
size_t num_resolved_rels_{0}; size_t num_resolved_rels_{0};
/*! \brief map from type node to types. */ /*! \brief map from type node to types. */
std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_; std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
/*! \breif Internal queue to update the relation */ /*! \brief Internal queue to update the relation */
std::queue<RelationNode*> update_queue_; std::queue<RelationNode*> update_queue_;
/*! \brief allocator of all the internal node obhect*/ /*! \brief allocator of all the internal node obhect*/
common::Arena arena_; common::Arena arena_;
...@@ -163,22 +169,7 @@ class TypeSolver { ...@@ -163,22 +169,7 @@ class TypeSolver {
* \param src The source operand * \param src The source operand
* \param dst The dst operand. * \param dst The dst operand.
*/ */
void MergeFromTo(TypeNode* src, TypeNode* dst) { void MergeFromTo(TypeNode* src, TypeNode* dst);
if (src == dst) return;
src->parent = dst;
// move the link to the to dst
for (auto* rlink = src->rel_list.head; rlink != nullptr;) {
// store next pointer first before rlink get moved
auto* next = rlink->next;
// if the relation is not yet resolved
// send the relation to the new
if (!rlink->value->resolved) {
this->AddToQueue(rlink->value);
dst->rel_list.Push(rlink);
}
rlink = next;
}
}
}; };
} // namespace relay } // namespace relay
......
...@@ -12,105 +12,211 @@ ...@@ -12,105 +12,211 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// FreeTypeVar template<typename T>
class FreeTypeVarTVisitor : public TypeVisitor { struct InsertionSet {
std::unordered_set<T, NodeHash, NodeEqual> set;
std::vector<T> data;
void Insert(const T& t) {
if (set.count(t) == 0) {
set.insert(t);
data.push_back(t);
}
}
};
class TypeVarTVisitor : public TypeVisitor {
public: public:
FreeTypeVarTVisitor( TypeVarTVisitor(
Array<TypeVar>* free_vars, InsertionSet<TypeVar>* type_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) InsertionSet<TypeVar>* bound_type_vars)
: free_vars_(free_vars), bound_vars_(bound_vars) { } : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { }
void VisitType_(const TypeVarNode* tp) final { void VisitType_(const TypeVarNode* tp) final {
TypeVar var = GetRef<TypeVar>(tp); TypeVar var = GetRef<TypeVar>(tp);
if (bound_vars_->count(var) == 0) { type_vars_->Insert(var);
free_vars_->push_back(var);
}
} }
void VisitType_(const FuncTypeNode* f) final { void VisitType_(const FuncTypeNode* f) final {
for (auto type_param : f->type_params) { for (auto type_param : f->type_params) {
bound_vars_->insert(type_param); type_vars_->Insert(type_param);
bound_type_vars_->Insert(type_param);
} }
TypeVisitor::VisitType_(f); TypeVisitor::VisitType_(f);
} }
private: private:
Array<TypeVar>* free_vars_; InsertionSet<TypeVar>* type_vars_;
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars_; InsertionSet<TypeVar>* bound_type_vars_;
}; };
class FreeTypeVarEVisitor : private ExprVisitor { class TypeVarEVisitor : private ExprVisitor {
public: public:
Array<TypeVar> Find(const Expr& expr) { Array<TypeVar> CollectFree() {
this->VisitExpr(expr); Array<TypeVar> ret;
return free_vars_; for (const auto& v : type_vars_.data) {
if (bound_type_vars_.set.count(v) == 0) {
ret.push_back(v);
}
}
return ret;
}
Array<TypeVar> CollectBound() {
Array<TypeVar> ret;
for (const auto& v : bound_type_vars_.data) {
ret.push_back(v);
}
return ret;
}
Array<TypeVar> CollectAll() {
Array<TypeVar> ret;
for (const auto& v : type_vars_.data) {
ret.push_back(v);
}
return ret;
} }
Array<TypeVar> Find(const Type& type) { Array<TypeVar> Free(const Expr& expr) {
this->VisitType(type); VisitExpr(expr);
return free_vars_; return CollectFree();
}
Array<TypeVar> Free(const Type& type) {
VisitType(type);
return CollectFree();
}
Array<TypeVar> Bound(const Expr& expr) {
VisitExpr(expr);
return CollectBound();
}
Array<TypeVar> Bound(const Type& type) {
VisitType(type);
return CollectBound();
}
Array<TypeVar> All(const Expr& expr) {
VisitExpr(expr);
return CollectAll();
}
Array<TypeVar> All(const Type& type) {
VisitType(type);
return CollectAll();
} }
void VisitExpr_(const FunctionNode* f) final { void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) { for (const auto& tp : f->type_params) {
bound_vars_.insert(tp); type_vars_.Insert(tp);
bound_type_vars_.Insert(tp);
} }
ExprVisitor::VisitExpr_(f); ExprVisitor::VisitExpr_(f);
} }
void VisitType(const Type& t) final { void VisitType(const Type& t) final {
FreeTypeVarTVisitor(&free_vars_, &bound_vars_) TypeVarTVisitor(&type_vars_, &bound_type_vars_)
.VisitType(t); .VisitType(t);
} }
private: private:
// The result list InsertionSet<TypeVar> type_vars_;
Array<TypeVar> free_vars_; InsertionSet<TypeVar> bound_type_vars_;
std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_vars_;
}; };
class FreeVarVisitor : protected ExprVisitor { class VarVisitor : protected ExprVisitor {
public: public:
Array<Var> Find(const Expr& expr) { Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr); this->VisitExpr(expr);
return free_vars_; Array<Var> ret;
for (const auto& v : vars_.data) {
if (bound_vars_.set.count(v) == 0) {
ret.push_back(v);
}
}
return ret;
} }
void VisitExpr_(const VarNode* var) final { Array<Var> Bound(const Expr& expr) {
if (bound_vars_.count(var) == 0) { this->VisitExpr(expr);
free_vars_.push_back(GetRef<Var>(var)); Array<Var> ret;
for (const auto& v : bound_vars_.data) {
ret.push_back(v);
} }
return ret;
}
Array<Var> All(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
for (const auto& v : vars_.data) {
ret.push_back(v);
}
return ret;
}
void MarkBounded(const Var& v) {
bound_vars_.Insert(v);
vars_.Insert(v);
}
void VisitExpr_(const VarNode* var) final {
vars_.Insert(GetRef<Var>(var));
} }
void VisitExpr_(const FunctionNode* op) final { void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) { for (const auto& param : op->params) {
bound_vars_.insert(param.operator->()); MarkBounded(param);
} }
VisitExpr(op->body); VisitExpr(op->body);
} }
void VisitExpr_(const LetNode* op) final { void VisitExpr_(const LetNode* op) final {
bound_vars_.insert(op->var.operator->()); MarkBounded(op->var);
VisitExpr(op->value); VisitExpr(op->value);
VisitExpr(op->body); VisitExpr(op->body);
} }
private: private:
// The result list InsertionSet<Var> vars_;
Array<Var> free_vars_; InsertionSet<Var> bound_vars_;
std::unordered_set<const VarNode*> bound_vars_;
}; };
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) { tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) {
return FreeTypeVarEVisitor().Find(expr); return TypeVarEVisitor().Free(expr);
} }
tvm::Array<TypeVar> FreeTypeVars(const Type& type) { tvm::Array<TypeVar> FreeTypeVars(const Type& type) {
return FreeTypeVarEVisitor().Find(type); return TypeVarEVisitor().Free(type);
}
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr) {
return TypeVarEVisitor().Bound(expr);
}
tvm::Array<TypeVar> BoundTypeVars(const Type& type) {
return TypeVarEVisitor().Bound(type);
}
tvm::Array<TypeVar> AllTypeVars(const Expr& expr) {
return TypeVarEVisitor().All(expr);
}
tvm::Array<TypeVar> AllTypeVars(const Type& type) {
return TypeVarEVisitor().All(type);
} }
tvm::Array<Var> FreeVars(const Expr& expr) { tvm::Array<Var> FreeVars(const Expr& expr) {
return FreeVarVisitor().Find(expr); return VarVisitor().Free(expr);
}
tvm::Array<Var> BoundVars(const Expr& expr) {
return VarVisitor().Bound(expr);
}
tvm::Array<Var> AllVars(const Expr& expr) {
return VarVisitor().All(expr);
} }
TVM_REGISTER_API("relay._ir_pass.free_vars") TVM_REGISTER_API("relay._ir_pass.free_vars")
...@@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") ...@@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
*ret = FreeVars(args[0]); *ret = FreeVars(args[0]);
}); });
TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BoundVars(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.all_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AllVars(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.free_type_vars") TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
if (x.as<TypeNode>()) { if (x.as_derived<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x)); *ret = FreeTypeVars(Downcast<Type>(x));
} else { } else {
*ret = FreeTypeVars(Downcast<Expr>(x)); *ret = FreeTypeVars(Downcast<Expr>(x));
} }
}); });
TVM_REGISTER_API("relay._ir_pass.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
if (x.as_derived<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x));
} else {
*ret = BoundTypeVars(Downcast<Expr>(x));
}
});
TVM_REGISTER_API("relay._ir_pass.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
if (x.as_derived<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x));
} else {
*ret = AllTypeVars(Downcast<Expr>(x));
}
});
/*! /*!
* \brief Get reference counter of each internal ExprNode in body. * \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression. * \param body The body expression.
......
...@@ -6,13 +6,17 @@ ...@@ -6,13 +6,17 @@
TEST(Relay, SelfReference) { TEST(Relay, SelfReference) {
using namespace tvm; using namespace tvm;
auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); auto tensor_type = relay::TensorTypeNode::make({}, ::tvm::Bool());
auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); auto x = relay::VarNode::make("x", relay::Type());
auto x = relay::VarNode::make("x", type_a); auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x }); auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{})); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{}));
CHECK_EQ(type_fx->checked_type(), type_a);
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(AlphaEqual(type_fx->checked_type(), expected));
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
......
import tvm
from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars
def test_free_vars():
ty = relay.TensorType([], "int32")
x = relay.Var("x", ty)
fvx = free_vars(x)
assert len(fvx) == 1
assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10))
let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0
def test_tuple():
t = relay.Var('t')
fv = free_vars(relay.Tuple([t, t]))
assert len(fv) == 1
assert fv[0] == t
fv = free_vars(relay.TupleGetItem(t, 123))
assert len(fv) == 1
assert fv[0] == t
def test_free_type_vars():
tp = relay.TypeVar("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
y = relay.Var("y")
let = relay.Let(x, y, x)
fvl = free_vars(let)
assert len(fvl) == 1
assert fvl[0] == y
ftvl = free_type_vars(let)
assert len(ftvl) == 1
assert ftvl[0] == tp
import tvm
from tvm import relay
from tvm.relay.ir_pass import (free_vars, free_type_vars,
bound_vars, bound_type_vars,
all_vars, all_type_vars)
def assert_vars_match(actual, expected):
assert len(actual) == len(expected)
for i in range(len(actual)):
assert actual[i] == expected[i]
def test_free_vars():
ty = relay.TensorType([], "int32")
x = relay.Var("x", ty)
fvx = free_vars(x)
assert len(fvx) == 1
assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10))
let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0
def test_free_vars_tuple():
t = relay.Var('t')
fv = free_vars(relay.Tuple([t, t]))
assert len(fv) == 1
assert fv[0] == t
fv = free_vars(relay.TupleGetItem(t, 123))
assert len(fv) == 1
assert fv[0] == t
def test_free_type_vars():
tp = relay.TypeVar("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
y = relay.Var("y")
let = relay.Let(x, y, x)
fvl = free_vars(let)
assert len(fvl) == 1
assert fvl[0] == y
ftvl = free_type_vars(let)
assert len(ftvl) == 1
assert ftvl[0] == tp
def test_bound_vars():
x = relay.Var("x")
y = relay.Var("y")
z = relay.Var("z")
a = relay.Var("a")
f1 = relay.Function([x, y, z], relay.Let(a, x, relay.Tuple([])))
assert_vars_match(bound_vars(f1), [x, y, z, a])
tup = relay.Tuple([x, y, z, a])
assert len(bound_vars(tup)) == 0
f2 = relay.Function([x, y], relay.Tuple([x, y, z, a]))
assert_vars_match(bound_vars(f2), [x, y])
def test_bound_type_vars():
a = relay.TypeVar("a")
b = relay.TypeVar("b")
c = relay.TypeVar("c")
ft1 = relay.FuncType([a], b, [a, b])
bound_ft1 = bound_type_vars(ft1)
assert_vars_match(bound_type_vars(ft1), [a, b])
ft2 = relay.FuncType([], c, [a])
assert_vars_match(bound_type_vars(ft2), [a])
tup_ty = relay.TupleType([a, b, c])
assert len(bound_type_vars(tup_ty)) == 0
f1 = relay.Function([], relay.Tuple([]), type_params=[a, b])
assert_vars_match(bound_type_vars(f1), [a, b])
f2 = relay.Function([], relay.Tuple([]), c)
assert len(bound_type_vars(f2)) == 0
x = relay.Var("x", a)
let1 = relay.Let(x, relay.Tuple([]), x)
assert len(bound_type_vars(let1)) == 0
let2 = relay.Let(x, relay.Function([], relay.Tuple([]), type_params=[b, c]), x)
assert_vars_match(bound_type_vars(let2), [b, c])
def test_all_vars():
x = relay.Var("x")
y = relay.Var("y")
z = relay.Var("z")
f1 = relay.Function([x, y], z)
assert_vars_match(all_vars(f1), [x, y, z])
f2 = relay.Function([x], relay.Let(y, relay.Tuple([]), z))
assert_vars_match(all_vars(f2), [x, y, z])
f3 = relay.Function([x], relay.Tuple([y, z]))
assert_vars_match(all_vars(f3), [x, y, z])
tup = relay.Tuple([x, y, z])
assert_vars_match(all_vars(tup), [x, y, z])
def test_all_type_vars():
a = relay.TypeVar("a")
b = relay.TypeVar("b")
c = relay.TypeVar("c")
ft1 = relay.FuncType([b], c, [a])
assert_vars_match(all_type_vars(ft1), [a, b, c])
ft2 = relay.FuncType([], relay.TupleType([a, b, c]), [])
assert_vars_match(all_type_vars(ft2), [a, b, c])
w = relay.Var("w")
x = relay.Var("x", a)
y = relay.Var("y", b)
z = relay.Var("z", c)
f1 = relay.Function([x], y, b, [a])
assert_vars_match(all_type_vars(f1), [a, b])
f2 = relay.Function([x], relay.Let(y, x, z))
assert_vars_match(all_type_vars(f2), [a, b, c])
f3 = relay.Function([], relay.Tuple([x, y, z]), ret_type=relay.TupleType([a, b, c]))
assert_vars_match(all_type_vars(f3), [a, b, c])
f4 = relay.Function([w], relay.Tuple([]), type_params=[a, b, c])
assert_vars_match(all_type_vars(f4), [a, b, c])
f5 = relay.Function([w], w)
assert len(all_type_vars(f5)) == 0
...@@ -23,7 +23,7 @@ def test_monomorphic_let(): ...@@ -23,7 +23,7 @@ def test_monomorphic_let():
x = sb.let('x', relay.const(1.0, "float64")) x = sb.let('x', relay.const(1.0, "float64"))
sb.ret(x) sb.ret(x)
xchecked = relay.ir_pass.infer_type(sb.get()) xchecked = relay.ir_pass.infer_type(sb.get())
assert xchecked.checked_type == relay.scalar_type("float64") assert xchecked.checked_type == relay.scalar_type("float64" )
def test_single_op(): def test_single_op():
...@@ -41,14 +41,15 @@ def test_add_broadcast_op(): ...@@ -41,14 +41,15 @@ def test_add_broadcast_op():
return x + y; return x + y;
} }
""" """
pass x = relay.var('x', shape=(10, 4))
# x = relay.var('x', shape=(10, 4)) y = relay.var('y', shape=(5, 10, 1))
# y = relay.var('y', shape=(5, 10, 1)) z = x + y
# z = x + y func = relay.Function([x, y], z)
# func = relay.Function([x, y], z) t1 = relay.TensorType((10, 4), 'float32')
# ttype = relay.TensorType((5, 5, 5), 'float32') t2 = relay.TensorType((5, 10, 1), 'float32')
# expected_ty = relay.FuncType([ttype, ttype], ttype) t3 = relay.TensorType((5, 10, 4), 'float32')
# assert_has_type(func.to_func(), expected_ty) expected_ty = relay.FuncType([t1, t2], t3)
assert_has_type(func, expected_ty)
def test_dual_op(): def test_dual_op():
...@@ -110,24 +111,17 @@ def test_recursion(): ...@@ -110,24 +111,17 @@ def test_recursion():
assert "%3 = @f(%1, %2)" in mod.astext() assert "%3 = @f(%1, %2)" in mod.astext()
assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
# This currently fails and should pass under the type system.
#
# This test is to illustrate problem with our weak form of
# unification.
#
def test_incomplete_call(): def test_incomplete_call():
sb = ScopeBuilder() tt = relay.scalar_type('int32')
x = relay.var('x', dtype='int32') x = relay.var('x', tt)
f = relay.var('f') f = relay.var('f')
func = relay.Function([x, f], relay.Call(f, [x])) func = relay.Function([x, f], relay.Call(f, [x]), tt)
ft = relay.ir_pass.infer_type(func)
f_type = relay.FuncType([tt], tt)
assert ft.checked_type == relay.FuncType([tt, f_type], tt)
try:
relay.ir_pass.infer_type(func)
assert False
except tvm.TVMError as e:
assert True
def test_tuple(): def test_tuple():
tp = relay.TensorType((10,)) tp = relay.TensorType((10,))
...@@ -136,6 +130,7 @@ def test_tuple(): ...@@ -136,6 +130,7 @@ def test_tuple():
assert (relay.ir_pass.infer_type(res).checked_type == assert (relay.ir_pass.infer_type(res).checked_type ==
relay.TupleType([tp, tp])) relay.TupleType([tp, tp]))
def test_free_expr(): def test_free_expr():
x = relay.var("x", "float32") x = relay.var("x", "float32")
y = relay.add(x, x) y = relay.add(x, x)
...@@ -161,38 +156,26 @@ def test_type_args(): ...@@ -161,38 +156,26 @@ def test_type_args():
assert sh2[1].value == 10 assert sh2[1].value == 10
def test_self_reference(): def test_global_var_recursion():
"""
Program:
def f(x) {
return x;
}
"""
a = relay.TypeVar("a")
x = relay.var("x", a)
sb = relay.ScopeBuilder()
f = relay.Function([x], x)
fx = relay.Call(f, [x])
assert relay.ir_pass.infer_type(x).checked_type == a
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
assert relay.ir_pass.infer_type(fx).checked_type == a
def test_global_var_cow_issue():
mod = relay.Module({}) mod = relay.Module({})
gv = relay.GlobalVar("foo") gv = relay.GlobalVar("foo")
x = relay.var('x', shape=[]) x = relay.var('x', shape=[])
func = relay.Function([x], relay.Call(gv, [x]), tt = relay.scalar_type('float32')
relay.TensorType([], 'float32'))
func = relay.Function([x], relay.Call(gv, [x]), tt)
mod[gv] = func mod[gv] = func
ft = relay.ir_pass.infer_type(gv, mod)
assert mod[ft].checked_type == relay.FuncType([tt], tt)
def test_equal(): def test_equal():
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
eq = op.equal(i, relay.const(0, dtype='int32')) eq = op.equal(i, relay.const(0, dtype='int32'))
# This should fail .... func = relay.Function([i], eq)
func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32')) ft = relay.ir_pass.infer_type(func)
assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool'))
if __name__ == "__main__": if __name__ == "__main__":
...@@ -204,8 +187,12 @@ if __name__ == "__main__": ...@@ -204,8 +187,12 @@ if __name__ == "__main__":
test_decl() test_decl()
test_recursion() test_recursion()
test_tuple() test_tuple()
test_generalized_tuple()
test_incomplete_call() test_incomplete_call()
test_generalized_call()
test_call_with_type_args()
test_free_expr() test_free_expr()
test_type_args() test_type_args()
test_self_reference() test_self_reference()
test_global_var_cow_issue() test_global_var_recursion()
test_equal()
import tvm import tvm
from tvm import relay from tvm import relay
from nose.tools import raises
def make_rel(name, args, num_inputs=None, attrs=None): def make_rel(name, args, num_inputs=None, attrs=None):
...@@ -48,7 +49,170 @@ def test_backward_solving(): ...@@ -48,7 +49,170 @@ def test_backward_solving():
assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32") assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32")
def test_unify_tuple():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.TensorType((10, 20), "float32")
tup1 = relay.ty.TupleType([t1, t2])
tup2 = relay.ty.TupleType([t3, t3])
unified = solver.Unify(tup1, tup2)
assert unified == tup2
def test_unify_functype():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
unit = relay.ty.TupleType([])
tensor1 = relay.ty.TensorType((10, 20), "float32")
tensor2 = relay.ty.TensorType((10,), "float32")
ft1 = relay.ty.FuncType([t1, t2], t3)
ft2 = relay.ty.FuncType([tensor1, tensor2], unit)
unified = solver.Unify(ft1, ft2)
assert unified == ft2
def test_recursive_unify():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
tensor1 = relay.ty.TensorType((10, 10, 20), "float32")
tensor2 = relay.ty.TensorType((10, 20), "float32")
tensor3 = relay.ty.TensorType((10,), "float32")
tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2])
tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2])
ft1 = relay.ty.FuncType([tup1, t3], t3)
ft2 = relay.ty.FuncType([tup2, tensor3], tensor3)
unified = solver.Unify(ft1, ft2)
assert unified == ft2
def test_unify_vars_under_tuples():
solver = make_solver()
t1 = relay.ty.IncompleteType()
tup1 = relay.ty.TupleType([t1, t1])
unified = solver.Unify(tup1, tup1)
assert unified == tup1
t2 = relay.ty.IncompleteType()
tup2 = relay.ty.TupleType([t2, t2])
tup3 = relay.ty.TupleType([t1, t2])
tup4 = relay.ty.TupleType([t2, t1])
unified = solver.Unify(tup3, tup4)
assert (unified == tup1 or unified == tup2)
def test_binding_over_typevars():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
a = relay.ty.TypeVar('a')
b = relay.ty.TypeVar('b')
c = relay.ty.TypeVar('c')
d = relay.ty.TypeVar('d')
ft1 = relay.ty.FuncType([t1], t2, [c, d])
ft2 = relay.ty.FuncType([a], b, [a, b])
unified = solver.Unify(ft1, ft2)
assert (unified == solver.Resolve(ft1))
def test_recursive_backward_solving():
solver = make_solver()
tensor1 = relay.ty.TensorType((10, 20), "float32")
tensor2 = relay.ty.TensorType((10, 1, 1), "float32")
tensor3 = relay.ty.TensorType((10,), "float32")
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
tup1 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3])
tup2 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t3])
solver.gen_type("Identity", [tup1], out=tup2)
assert solver.Solve()
assert solver.Resolve(tup2) == tup1
def test_backward_solving_after_child_update():
solver = make_solver()
tensor1 = relay.ty.TensorType((10, 20), "float32")
tensor2 = relay.ty.TensorType((10, 1, 1), "float32")
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
tup1 = relay.ty.TupleType([t1, t2])
tup2 = relay.ty.TupleType([t1, t3])
tup_concrete = relay.ty.TupleType([tensor1, tensor2])
t4 = solver.gen_type("Identity", [tup1])
t5 = solver.gen_type("Identity", [tup2])
solver.gen_type("Identity", [t4], out=t5)
assert solver.Solve()
assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2
assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2
assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2
# updating the variables *inside* tup1 and tup2 should update t4 and t5
solver.gen_type("Identity", [t1], out=tensor1)
solver.gen_type("Identity", [t2], out=tensor2)
assert solver.Solve()
assert solver.Resolve(t4) == tup_concrete
assert solver.Resolve(t5) == tup_concrete
@raises(tvm._ffi.base.TVMError)
def test_incompatible_tuple_unification():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
tensor1 = relay.ty.TensorType((1, 2, 3), "float32")
tensor2 = relay.ty.TensorType((2, 3), "float32")
tensor3 = relay.ty.TensorType((3,), "float32")
tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t1]), t2])
tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3])
solver.Unify(tup1, tup2)
@raises(tvm._ffi.base.TVMError)
def test_bad_recursive_unification():
solver = make_solver()
t1 = relay.ty.IncompleteType()
solver.Unify(t1, relay.ty.TupleType([t1, t1]))
if __name__ == "__main__": if __name__ == "__main__":
test_bcast() test_bcast()
test_backward_solving() test_backward_solving()
test_unify_tuple()
test_unify_functype()
test_recursive_unify()
test_unify_vars_under_tuples()
test_recursive_backward_solving()
test_backward_solving_after_child_update()
test_incompatible_tuple_unification()
test_bad_recursive_unification()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment