Commit e0a20ad4 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay] Expand type unification and other utilities (#2189)

parent dfb101a0
...@@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2); ...@@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/ */
bool WellFormed(const Expr& expr); bool WellFormed(const Expr& expr);
/*! \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> BoundVars(const Expr& expr);
/*! \brief Get free type parameters from expression expr. /*! \brief Get free type parameters from expression expr.
* *
* Free variables are variables that are not bound by a * Free variables are variables that are not bound by a
...@@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr); ...@@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr);
*/ */
tvm::Array<Var> FreeVars(const Expr& expr); tvm::Array<Var> FreeVars(const Expr& expr);
/*! \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> AllVars(const Expr& expr);
/*! \brief Get free TypeVars from expression expr. /*! \brief Get free TypeVars from expression expr.
* *
* Free type parameters are type parameters that are not bound by a function * Free type parameters are type parameters that are not bound by a function
...@@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr); ...@@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr);
*/ */
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
/*! \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param t the type.
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> FreeTypeVars(const Type& t);
/*! \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);
/*! \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
*
* \param t the type
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> BoundTypeVars(const Type& t);
/*! \brief Get all type variables in expression expr.
*
* \param expr the expression.
*
* \return List of type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> AllTypeVars(const Expr& expr);
/*! \brief Get all type variables in type t.
*
* \param t the type.
*
* \return List of type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> AllTypeVars(const Type& t);
/*! \brief Remove expressions which does not effect the program result. /*! \brief Remove expressions which does not effect the program result.
* *
* It will remove let bindings which are not referenced, and branches that will * It will remove let bindings which are not referenced, and branches that will
......
...@@ -158,6 +158,38 @@ def free_vars(expr): ...@@ -158,6 +158,38 @@ def free_vars(expr):
return _ir_pass.free_vars(expr) return _ir_pass.free_vars(expr)
def bound_vars(expr):
"""Get bound vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of bound variables in post-DFS order.
"""
return _ir_pass.bound_vars(expr)
def all_vars(expr):
"""Get all vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of all variables in post-DFS order.
"""
return _ir_pass.all_vars(expr)
def free_type_vars(expr): def free_type_vars(expr):
"""Get free type variables from expression/type e """Get free type variables from expression/type e
...@@ -168,12 +200,44 @@ def free_type_vars(expr): ...@@ -168,12 +200,44 @@ def free_type_vars(expr):
Returns Returns
------- -------
free : List[tvm.relay.TypeParam] free : List[tvm.relay.TypeVar]
The list of free type variables The list of free type variables in post-DFS order
""" """
return _ir_pass.free_type_vars(expr) return _ir_pass.free_type_vars(expr)
def bound_type_vars(expr):
"""Get bound type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order
"""
return _ir_pass.bound_type_vars(expr)
def all_type_vars(expr):
"""Get all type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order
"""
return _ir_pass.all_type_vars(expr)
def simplify_inference(expr): def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase. """ Simplify the data-flow graph for inference phase.
......
...@@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array<Type>& types, ...@@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array<Type>& types,
return true; return true;
} }
bool MakeTupleRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(static_cast<size_t>(num_inputs + 1), types.size());
for (int i = 0; i < num_inputs; ++i) {
if (types[i].as<IncompleteTypeNode>()) return false;
}
Array<Type> fields;
for (int i = 0; i < num_inputs; ++i) {
fields.push_back(types[i]);
}
reporter->Assign(types[num_inputs], TupleTypeNode::make(fields));
return true;
}
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>( .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel); TupleGetItemRel);
TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
MakeTupleRel);
struct ResolvedTypeInfo { struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {} : checked_type(checked_type), type_args(type_args) {}
...@@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// type inferencer will populate it up // type inferencer will populate it up
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_; std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
// used to ensure we don't have free type vars hanging around
// (a temporary measure until we have proper generalization implemented)
Map<TypeVar, Type> instantiation_map_;
// The solver used by the inferencer. // The solver used by the inferencer.
TypeSolver solver_; TypeSolver solver_;
// relation function // relation function
...@@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return Type(); return Type();
} }
} }
// Substitutes every type var in t with a corresponding incomplete type.
// This is a temporary measure to ensure type vars behave until
// generalization is properly implemented.
Type Instantiate(const Type &t) {
if (!t.defined()) {
return t;
}
auto* ft = t.as<FuncTypeNode>();
if (ft == nullptr) {
return Bind(t, instantiation_map_);
}
for (auto type_param : ft->type_params) {
instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType));
}
Type ret_type = ft->ret_type;
if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
}
auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints);
return Bind(strip_tvs, instantiation_map_);
}
// Lazily get type for expr // Lazily get type for expr
// will call visit to deduce it if it is not in the type_map_ // will call visit to deduce it if it is not in the type_map_
Type GetType(const Expr &expr) { Type GetType(const Expr &expr) {
...@@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (it != type_map_.end() && it->second.checked_type.defined()) { if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second.checked_type; return it->second.checked_type;
} }
Type ret = this->VisitExpr(expr); Type ret = Instantiate(this->VisitExpr(expr));
ResolvedTypeInfo& rti = type_map_[expr]; ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret; rti.checked_type = ret;
return ret; return ret;
...@@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const TupleNode* op) final { Type VisitExpr_(const TupleNode* op) final {
if (!make_tuple_rel_.defined()) {
make_tuple_rel_ = TypeRelationFn(
EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_);
}
Array<Type> types; Array<Type> types;
for (Expr field : op->fields) { for (Expr field : op->fields) {
types.push_back(GetType(field)); types.push_back(GetType(field));
} }
Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); return TupleTypeNode::make(types);
types.push_back(rtype);
solver_.AddConstraint(TypeRelationNode::make(
make_tuple_rel_, types, op->fields.size(), Attrs()));
return rtype;
} }
Type VisitExpr_(const TupleGetItemNode* op) final { Type VisitExpr_(const TupleGetItemNode* op) final {
...@@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const LetNode* op) final { Type VisitExpr_(const LetNode* op) final {
// if the definition is a function literal, permit recursion
bool is_functional_literal = op->value.as<FunctionNode>() != nullptr;
if (is_functional_literal) {
type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
}
Type vtype = GetType(op->value); Type vtype = GetType(op->value);
if (op->var->type_annotation.defined()) { if (op->var->type_annotation.defined()) {
vtype = Unify(vtype, op->var->type_annotation, op->span); vtype = Unify(vtype, op->var->type_annotation, op->span);
} }
CHECK(!type_map_.count(op->var)); CHECK(is_functional_literal || !type_map_.count(op->var));
// NOTE: no scoping is necessary because var are unique in program // NOTE: no scoping is necessary because var are unique in program
type_map_[op->var].checked_type = vtype; type_map_[op->var].checked_type = vtype;
return GetType(op->body); return GetType(op->body);
...@@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return rtype; return rtype;
} }
// instantiate the function type with fresh // substitute the type args in the function type
FuncType Instantiate(const FuncTypeNode* fn_ty, Array<Type>* ty_args) { FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array<Type>& ty_args) {
tvm::Map<TypeVar, Type> subst_map; tvm::Map<TypeVar, Type> subst_map;
// Build a subsitituion map up from the function type and type arguments. // Build a subsitituion map up from the function type and type arguments.
// Eventually allow the type vars to be passed in. // Eventually allow the type vars to be passed in.
for (auto ty_param : fn_ty->type_params) { for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); subst_map.Set(fn_ty->type_params[i], ty_args[i]);
subst_map.Set(ty_param, fresh);
ty_args->push_back(fresh);
} }
Type ret_type = fn_ty->ret_type; Type ret_type = fn_ty->ret_type;
...@@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type GeneralCall(const CallNode* call, Array<Type> arg_types) { Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
Type ftype = GetType(call->op); Type ftype = GetType(call->op);
auto* fn_ty_node = ftype.as<FuncTypeNode>(); auto* fn_ty_node = ftype.as<FuncTypeNode>();
auto* inc_ty_node = ftype.as<IncompleteTypeNode>();
CHECK(fn_ty_node != nullptr) CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr)
<< "only expressions with function types can be called, found " << "only expressions with function types can be called, found "
<< ftype << " at " << call->span; << ftype << " at " << call->span;
Array<Type> type_args; // incomplete type => it must be a function taking the arg types
FuncType fn_ty = Instantiate(fn_ty_node, &type_args); // with an unknown return type
if (inc_ty_node != nullptr) {
Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {});
Type unified = this->Unify(ftype, func_type, call->span);
fn_ty_node = unified.as<FuncTypeNode>();
}
Array<Type> type_args = call->type_args;
if (type_args.size() == 0) {
for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) {
type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType));
}
}
CHECK(type_args.size() == fn_ty_node->type_params.size())
<< "Incorrect number of type args in " << call->span << ": "
<< "Expected " << fn_ty_node->type_params.size()
<< "but got " << type_args.size();
FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
AddTypeArgs(GetRef<Call>(call), type_args); AddTypeArgs(GetRef<Call>(call), type_args);
...@@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const FunctionNode* f) final { Type VisitExpr_(const FunctionNode* f) final {
solver_.Solve();
Array<Type> arg_types;
for (auto param : f->params) { for (auto param : f->params) {
GetType(param); arg_types.push_back(GetType(param));
} }
Type rtype = GetType(f->body); Type rtype = GetType(f->body);
// Run solver using the currently known information if (f->ret_type.defined()) {
solver_.Solve(); rtype = this->Unify(f->ret_type, rtype, f->span);
// Trying to resolve }
Array<Type> arg_types; auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
for (size_t i = 0; i < f->params.size(); ++i) { return solver_.Resolve(ret);
Type atype = solver_.Resolve(GetType(f->params[i]));
CHECK(atype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << i
<< "-th parameter of function at" << f->span;
arg_types.push_back(atype);
}
rtype = solver_.Resolve(rtype);
CHECK(rtype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve return type of function at" << f->span;
// do not support constraint lifting for now.
return FuncTypeNode::make(arg_types, rtype, f->type_params, {});
} }
}; };
...@@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) { ...@@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) {
GetType(expr); GetType(expr);
// Step 1: Solve the constraints. // Step 1: Solve the constraints.
solver_.Solve(); solver_.Solve();
// Step 2: Attach resolved types to checked_type field. // Step 2: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
CHECK(WellFormed(resolved_expr)); CHECK(WellFormed(resolved_expr));
......
...@@ -18,6 +18,7 @@ namespace relay { ...@@ -18,6 +18,7 @@ namespace relay {
using common::LinkNode; using common::LinkNode;
using common::LinkedList; using common::LinkedList;
/*! /*!
* \brief Interface of type solver used in type inference. * \brief Interface of type solver used in type inference.
* *
...@@ -65,6 +66,11 @@ class TypeSolver { ...@@ -65,6 +66,11 @@ class TypeSolver {
Type Unify(const Type& lhs, const Type& rhs); Type Unify(const Type& lhs, const Type& rhs);
private: private:
class OccursChecker;
class Unifier;
class Resolver;
class Propagator;
class Merger;
class Reporter; class Reporter;
struct TypeNode; struct TypeNode;
struct RelationNode; struct RelationNode;
...@@ -77,15 +83,15 @@ class TypeSolver { ...@@ -77,15 +83,15 @@ class TypeSolver {
* that can unifies the same types to the name resolved_type. * that can unifies the same types to the name resolved_type.
* *
* It also contains collection of links to related Relations, * It also contains collection of links to related Relations,
* which is stored in rel_list. * which is stored in rel_set.
*/ */
struct TypeNode { struct TypeNode {
/*! \brief The final resolved type */ /*! \brief The final resolved type */
Type resolved_type; Type resolved_type;
/*! \brief type node in the union find algorithm */ /*! \brief type node in the union find algorithm */
TypeNode* parent{nullptr}; TypeNode* parent{nullptr};
/*! \brief list of relations that is related to this type node */ /*! \brief set of relations that is related to this type node */
LinkedList<RelationNode*> rel_list; std::unordered_set<RelationNode*> rel_set;
/*! /*!
* \brief Find the root type node, perform path compression * \brief Find the root type node, perform path compression
* \return The root type node. * \return The root type node.
...@@ -125,7 +131,7 @@ class TypeSolver { ...@@ -125,7 +131,7 @@ class TypeSolver {
size_t num_resolved_rels_{0}; size_t num_resolved_rels_{0};
/*! \brief map from type node to types. */ /*! \brief map from type node to types. */
std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_; std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
/*! \breif Internal queue to update the relation */ /*! \brief Internal queue to update the relation */
std::queue<RelationNode*> update_queue_; std::queue<RelationNode*> update_queue_;
/*! \brief allocator of all the internal node obhect*/ /*! \brief allocator of all the internal node obhect*/
common::Arena arena_; common::Arena arena_;
...@@ -163,22 +169,7 @@ class TypeSolver { ...@@ -163,22 +169,7 @@ class TypeSolver {
* \param src The source operand * \param src The source operand
* \param dst The dst operand. * \param dst The dst operand.
*/ */
void MergeFromTo(TypeNode* src, TypeNode* dst) { void MergeFromTo(TypeNode* src, TypeNode* dst);
if (src == dst) return;
src->parent = dst;
// move the link to the to dst
for (auto* rlink = src->rel_list.head; rlink != nullptr;) {
// store next pointer first before rlink get moved
auto* next = rlink->next;
// if the relation is not yet resolved
// send the relation to the new
if (!rlink->value->resolved) {
this->AddToQueue(rlink->value);
dst->rel_list.Push(rlink);
}
rlink = next;
}
}
}; };
} // namespace relay } // namespace relay
......
...@@ -12,105 +12,211 @@ ...@@ -12,105 +12,211 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// FreeTypeVar template<typename T>
class FreeTypeVarTVisitor : public TypeVisitor { struct InsertionSet {
std::unordered_set<T, NodeHash, NodeEqual> set;
std::vector<T> data;
void Insert(const T& t) {
if (set.count(t) == 0) {
set.insert(t);
data.push_back(t);
}
}
};
class TypeVarTVisitor : public TypeVisitor {
public: public:
FreeTypeVarTVisitor( TypeVarTVisitor(
Array<TypeVar>* free_vars, InsertionSet<TypeVar>* type_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) InsertionSet<TypeVar>* bound_type_vars)
: free_vars_(free_vars), bound_vars_(bound_vars) { } : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { }
void VisitType_(const TypeVarNode* tp) final { void VisitType_(const TypeVarNode* tp) final {
TypeVar var = GetRef<TypeVar>(tp); TypeVar var = GetRef<TypeVar>(tp);
if (bound_vars_->count(var) == 0) { type_vars_->Insert(var);
free_vars_->push_back(var);
}
} }
void VisitType_(const FuncTypeNode* f) final { void VisitType_(const FuncTypeNode* f) final {
for (auto type_param : f->type_params) { for (auto type_param : f->type_params) {
bound_vars_->insert(type_param); type_vars_->Insert(type_param);
bound_type_vars_->Insert(type_param);
} }
TypeVisitor::VisitType_(f); TypeVisitor::VisitType_(f);
} }
private: private:
Array<TypeVar>* free_vars_; InsertionSet<TypeVar>* type_vars_;
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars_; InsertionSet<TypeVar>* bound_type_vars_;
}; };
class FreeTypeVarEVisitor : private ExprVisitor { class TypeVarEVisitor : private ExprVisitor {
public: public:
Array<TypeVar> Find(const Expr& expr) { Array<TypeVar> CollectFree() {
this->VisitExpr(expr); Array<TypeVar> ret;
return free_vars_; for (const auto& v : type_vars_.data) {
if (bound_type_vars_.set.count(v) == 0) {
ret.push_back(v);
}
}
return ret;
}
Array<TypeVar> CollectBound() {
Array<TypeVar> ret;
for (const auto& v : bound_type_vars_.data) {
ret.push_back(v);
}
return ret;
}
Array<TypeVar> CollectAll() {
Array<TypeVar> ret;
for (const auto& v : type_vars_.data) {
ret.push_back(v);
}
return ret;
}
Array<TypeVar> Free(const Expr& expr) {
VisitExpr(expr);
return CollectFree();
} }
Array<TypeVar> Find(const Type& type) { Array<TypeVar> Free(const Type& type) {
this->VisitType(type); VisitType(type);
return free_vars_; return CollectFree();
}
Array<TypeVar> Bound(const Expr& expr) {
VisitExpr(expr);
return CollectBound();
}
Array<TypeVar> Bound(const Type& type) {
VisitType(type);
return CollectBound();
}
Array<TypeVar> All(const Expr& expr) {
VisitExpr(expr);
return CollectAll();
}
Array<TypeVar> All(const Type& type) {
VisitType(type);
return CollectAll();
} }
void VisitExpr_(const FunctionNode* f) final { void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) { for (const auto& tp : f->type_params) {
bound_vars_.insert(tp); type_vars_.Insert(tp);
bound_type_vars_.Insert(tp);
} }
ExprVisitor::VisitExpr_(f); ExprVisitor::VisitExpr_(f);
} }
void VisitType(const Type& t) final { void VisitType(const Type& t) final {
FreeTypeVarTVisitor(&free_vars_, &bound_vars_) TypeVarTVisitor(&type_vars_, &bound_type_vars_)
.VisitType(t); .VisitType(t);
} }
private: private:
// The result list InsertionSet<TypeVar> type_vars_;
Array<TypeVar> free_vars_; InsertionSet<TypeVar> bound_type_vars_;
std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_vars_;
}; };
class FreeVarVisitor : protected ExprVisitor { class VarVisitor : protected ExprVisitor {
public: public:
Array<Var> Find(const Expr& expr) { Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr); this->VisitExpr(expr);
return free_vars_; Array<Var> ret;
for (const auto& v : vars_.data) {
if (bound_vars_.set.count(v) == 0) {
ret.push_back(v);
}
}
return ret;
} }
void VisitExpr_(const VarNode* var) final { Array<Var> Bound(const Expr& expr) {
if (bound_vars_.count(var) == 0) { this->VisitExpr(expr);
free_vars_.push_back(GetRef<Var>(var)); Array<Var> ret;
for (const auto& v : bound_vars_.data) {
ret.push_back(v);
}
return ret;
}
Array<Var> All(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
for (const auto& v : vars_.data) {
ret.push_back(v);
} }
return ret;
}
void MarkBounded(const Var& v) {
bound_vars_.Insert(v);
vars_.Insert(v);
}
void VisitExpr_(const VarNode* var) final {
vars_.Insert(GetRef<Var>(var));
} }
void VisitExpr_(const FunctionNode* op) final { void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) { for (const auto& param : op->params) {
bound_vars_.insert(param.operator->()); MarkBounded(param);
} }
VisitExpr(op->body); VisitExpr(op->body);
} }
void VisitExpr_(const LetNode* op) final { void VisitExpr_(const LetNode* op) final {
bound_vars_.insert(op->var.operator->()); MarkBounded(op->var);
VisitExpr(op->value); VisitExpr(op->value);
VisitExpr(op->body); VisitExpr(op->body);
} }
private: private:
// The result list InsertionSet<Var> vars_;
Array<Var> free_vars_; InsertionSet<Var> bound_vars_;
std::unordered_set<const VarNode*> bound_vars_;
}; };
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) { tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) {
return FreeTypeVarEVisitor().Find(expr); return TypeVarEVisitor().Free(expr);
} }
tvm::Array<TypeVar> FreeTypeVars(const Type& type) { tvm::Array<TypeVar> FreeTypeVars(const Type& type) {
return FreeTypeVarEVisitor().Find(type); return TypeVarEVisitor().Free(type);
}
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr) {
return TypeVarEVisitor().Bound(expr);
}
tvm::Array<TypeVar> BoundTypeVars(const Type& type) {
return TypeVarEVisitor().Bound(type);
}
tvm::Array<TypeVar> AllTypeVars(const Expr& expr) {
return TypeVarEVisitor().All(expr);
}
tvm::Array<TypeVar> AllTypeVars(const Type& type) {
return TypeVarEVisitor().All(type);
} }
tvm::Array<Var> FreeVars(const Expr& expr) { tvm::Array<Var> FreeVars(const Expr& expr) {
return FreeVarVisitor().Find(expr); return VarVisitor().Free(expr);
}
tvm::Array<Var> BoundVars(const Expr& expr) {
return VarVisitor().Bound(expr);
}
tvm::Array<Var> AllVars(const Expr& expr) {
return VarVisitor().All(expr);
} }
TVM_REGISTER_API("relay._ir_pass.free_vars") TVM_REGISTER_API("relay._ir_pass.free_vars")
...@@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") ...@@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars")
*ret = FreeVars(args[0]); *ret = FreeVars(args[0]);
}); });
TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BoundVars(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.all_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AllVars(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.free_type_vars") TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
if (x.as<TypeNode>()) { if (x.as_derived<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x)); *ret = FreeTypeVars(Downcast<Type>(x));
} else { } else {
*ret = FreeTypeVars(Downcast<Expr>(x)); *ret = FreeTypeVars(Downcast<Expr>(x));
} }
}); });
TVM_REGISTER_API("relay._ir_pass.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
if (x.as_derived<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x));
} else {
*ret = BoundTypeVars(Downcast<Expr>(x));
}
});
TVM_REGISTER_API("relay._ir_pass.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
if (x.as_derived<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x));
} else {
*ret = AllTypeVars(Downcast<Expr>(x));
}
});
/*! /*!
* \brief Get reference counter of each internal ExprNode in body. * \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression. * \param body The body expression.
......
...@@ -6,13 +6,17 @@ ...@@ -6,13 +6,17 @@
TEST(Relay, SelfReference) { TEST(Relay, SelfReference) {
using namespace tvm; using namespace tvm;
auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); auto tensor_type = relay::TensorTypeNode::make({}, ::tvm::Bool());
auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); auto x = relay::VarNode::make("x", relay::Type());
auto x = relay::VarNode::make("x", type_a); auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x }); auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{})); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{}));
CHECK_EQ(type_fx->checked_type(), type_a);
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(AlphaEqual(type_fx->checked_type(), expected));
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
......
import tvm
from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars
def test_free_vars():
ty = relay.TensorType([], "int32")
x = relay.Var("x", ty)
fvx = free_vars(x)
assert len(fvx) == 1
assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10))
let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0
def test_tuple():
t = relay.Var('t')
fv = free_vars(relay.Tuple([t, t]))
assert len(fv) == 1
assert fv[0] == t
fv = free_vars(relay.TupleGetItem(t, 123))
assert len(fv) == 1
assert fv[0] == t
def test_free_type_vars():
tp = relay.TypeVar("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
y = relay.Var("y")
let = relay.Let(x, y, x)
fvl = free_vars(let)
assert len(fvl) == 1
assert fvl[0] == y
ftvl = free_type_vars(let)
assert len(ftvl) == 1
assert ftvl[0] == tp
import tvm
from tvm import relay
from tvm.relay.ir_pass import (free_vars, free_type_vars,
bound_vars, bound_type_vars,
all_vars, all_type_vars)
def assert_vars_match(actual, expected):
assert len(actual) == len(expected)
for i in range(len(actual)):
assert actual[i] == expected[i]
def test_free_vars():
ty = relay.TensorType([], "int32")
x = relay.Var("x", ty)
fvx = free_vars(x)
assert len(fvx) == 1
assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10))
let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0
def test_free_vars_tuple():
t = relay.Var('t')
fv = free_vars(relay.Tuple([t, t]))
assert len(fv) == 1
assert fv[0] == t
fv = free_vars(relay.TupleGetItem(t, 123))
assert len(fv) == 1
assert fv[0] == t
def test_free_type_vars():
tp = relay.TypeVar("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
y = relay.Var("y")
let = relay.Let(x, y, x)
fvl = free_vars(let)
assert len(fvl) == 1
assert fvl[0] == y
ftvl = free_type_vars(let)
assert len(ftvl) == 1
assert ftvl[0] == tp
def test_bound_vars():
x = relay.Var("x")
y = relay.Var("y")
z = relay.Var("z")
a = relay.Var("a")
f1 = relay.Function([x, y, z], relay.Let(a, x, relay.Tuple([])))
assert_vars_match(bound_vars(f1), [x, y, z, a])
tup = relay.Tuple([x, y, z, a])
assert len(bound_vars(tup)) == 0
f2 = relay.Function([x, y], relay.Tuple([x, y, z, a]))
assert_vars_match(bound_vars(f2), [x, y])
def test_bound_type_vars():
a = relay.TypeVar("a")
b = relay.TypeVar("b")
c = relay.TypeVar("c")
ft1 = relay.FuncType([a], b, [a, b])
bound_ft1 = bound_type_vars(ft1)
assert_vars_match(bound_type_vars(ft1), [a, b])
ft2 = relay.FuncType([], c, [a])
assert_vars_match(bound_type_vars(ft2), [a])
tup_ty = relay.TupleType([a, b, c])
assert len(bound_type_vars(tup_ty)) == 0
f1 = relay.Function([], relay.Tuple([]), type_params=[a, b])
assert_vars_match(bound_type_vars(f1), [a, b])
f2 = relay.Function([], relay.Tuple([]), c)
assert len(bound_type_vars(f2)) == 0
x = relay.Var("x", a)
let1 = relay.Let(x, relay.Tuple([]), x)
assert len(bound_type_vars(let1)) == 0
let2 = relay.Let(x, relay.Function([], relay.Tuple([]), type_params=[b, c]), x)
assert_vars_match(bound_type_vars(let2), [b, c])
def test_all_vars():
x = relay.Var("x")
y = relay.Var("y")
z = relay.Var("z")
f1 = relay.Function([x, y], z)
assert_vars_match(all_vars(f1), [x, y, z])
f2 = relay.Function([x], relay.Let(y, relay.Tuple([]), z))
assert_vars_match(all_vars(f2), [x, y, z])
f3 = relay.Function([x], relay.Tuple([y, z]))
assert_vars_match(all_vars(f3), [x, y, z])
tup = relay.Tuple([x, y, z])
assert_vars_match(all_vars(tup), [x, y, z])
def test_all_type_vars():
a = relay.TypeVar("a")
b = relay.TypeVar("b")
c = relay.TypeVar("c")
ft1 = relay.FuncType([b], c, [a])
assert_vars_match(all_type_vars(ft1), [a, b, c])
ft2 = relay.FuncType([], relay.TupleType([a, b, c]), [])
assert_vars_match(all_type_vars(ft2), [a, b, c])
w = relay.Var("w")
x = relay.Var("x", a)
y = relay.Var("y", b)
z = relay.Var("z", c)
f1 = relay.Function([x], y, b, [a])
assert_vars_match(all_type_vars(f1), [a, b])
f2 = relay.Function([x], relay.Let(y, x, z))
assert_vars_match(all_type_vars(f2), [a, b, c])
f3 = relay.Function([], relay.Tuple([x, y, z]), ret_type=relay.TupleType([a, b, c]))
assert_vars_match(all_type_vars(f3), [a, b, c])
f4 = relay.Function([w], relay.Tuple([]), type_params=[a, b, c])
assert_vars_match(all_type_vars(f4), [a, b, c])
f5 = relay.Function([w], w)
assert len(all_type_vars(f5)) == 0
...@@ -23,7 +23,7 @@ def test_monomorphic_let(): ...@@ -23,7 +23,7 @@ def test_monomorphic_let():
x = sb.let('x', relay.const(1.0, "float64")) x = sb.let('x', relay.const(1.0, "float64"))
sb.ret(x) sb.ret(x)
xchecked = relay.ir_pass.infer_type(sb.get()) xchecked = relay.ir_pass.infer_type(sb.get())
assert xchecked.checked_type == relay.scalar_type("float64") assert xchecked.checked_type == relay.scalar_type("float64" )
def test_single_op(): def test_single_op():
...@@ -41,14 +41,15 @@ def test_add_broadcast_op(): ...@@ -41,14 +41,15 @@ def test_add_broadcast_op():
return x + y; return x + y;
} }
""" """
pass x = relay.var('x', shape=(10, 4))
# x = relay.var('x', shape=(10, 4)) y = relay.var('y', shape=(5, 10, 1))
# y = relay.var('y', shape=(5, 10, 1)) z = x + y
# z = x + y func = relay.Function([x, y], z)
# func = relay.Function([x, y], z) t1 = relay.TensorType((10, 4), 'float32')
# ttype = relay.TensorType((5, 5, 5), 'float32') t2 = relay.TensorType((5, 10, 1), 'float32')
# expected_ty = relay.FuncType([ttype, ttype], ttype) t3 = relay.TensorType((5, 10, 4), 'float32')
# assert_has_type(func.to_func(), expected_ty) expected_ty = relay.FuncType([t1, t2], t3)
assert_has_type(func, expected_ty)
def test_dual_op(): def test_dual_op():
...@@ -110,24 +111,17 @@ def test_recursion(): ...@@ -110,24 +111,17 @@ def test_recursion():
assert "%3 = @f(%1, %2)" in mod.astext() assert "%3 = @f(%1, %2)" in mod.astext()
assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
# This currently fails and should pass under the type system.
#
# This test is to illustrate problem with our weak form of
# unification.
#
def test_incomplete_call(): def test_incomplete_call():
sb = ScopeBuilder() tt = relay.scalar_type('int32')
x = relay.var('x', dtype='int32') x = relay.var('x', tt)
f = relay.var('f') f = relay.var('f')
func = relay.Function([x, f], relay.Call(f, [x])) func = relay.Function([x, f], relay.Call(f, [x]), tt)
ft = relay.ir_pass.infer_type(func)
f_type = relay.FuncType([tt], tt)
assert ft.checked_type == relay.FuncType([tt, f_type], tt)
try:
relay.ir_pass.infer_type(func)
assert False
except tvm.TVMError as e:
assert True
def test_tuple(): def test_tuple():
tp = relay.TensorType((10,)) tp = relay.TensorType((10,))
...@@ -136,6 +130,7 @@ def test_tuple(): ...@@ -136,6 +130,7 @@ def test_tuple():
assert (relay.ir_pass.infer_type(res).checked_type == assert (relay.ir_pass.infer_type(res).checked_type ==
relay.TupleType([tp, tp])) relay.TupleType([tp, tp]))
def test_free_expr(): def test_free_expr():
x = relay.var("x", "float32") x = relay.var("x", "float32")
y = relay.add(x, x) y = relay.add(x, x)
...@@ -161,38 +156,26 @@ def test_type_args(): ...@@ -161,38 +156,26 @@ def test_type_args():
assert sh2[1].value == 10 assert sh2[1].value == 10
def test_self_reference(): def test_global_var_recursion():
"""
Program:
def f(x) {
return x;
}
"""
a = relay.TypeVar("a")
x = relay.var("x", a)
sb = relay.ScopeBuilder()
f = relay.Function([x], x)
fx = relay.Call(f, [x])
assert relay.ir_pass.infer_type(x).checked_type == a
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
assert relay.ir_pass.infer_type(fx).checked_type == a
def test_global_var_cow_issue():
mod = relay.Module({}) mod = relay.Module({})
gv = relay.GlobalVar("foo") gv = relay.GlobalVar("foo")
x = relay.var('x', shape=[]) x = relay.var('x', shape=[])
func = relay.Function([x], relay.Call(gv, [x]), tt = relay.scalar_type('float32')
relay.TensorType([], 'float32'))
func = relay.Function([x], relay.Call(gv, [x]), tt)
mod[gv] = func mod[gv] = func
ft = relay.ir_pass.infer_type(gv, mod)
assert mod[ft].checked_type == relay.FuncType([tt], tt)
def test_equal(): def test_equal():
i = relay.var('i', shape=[], dtype='int32') i = relay.var('i', shape=[], dtype='int32')
eq = op.equal(i, relay.const(0, dtype='int32')) eq = op.equal(i, relay.const(0, dtype='int32'))
# This should fail .... func = relay.Function([i], eq)
func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32')) ft = relay.ir_pass.infer_type(func)
assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool'))
if __name__ == "__main__": if __name__ == "__main__":
...@@ -204,8 +187,12 @@ if __name__ == "__main__": ...@@ -204,8 +187,12 @@ if __name__ == "__main__":
test_decl() test_decl()
test_recursion() test_recursion()
test_tuple() test_tuple()
test_generalized_tuple()
test_incomplete_call() test_incomplete_call()
test_generalized_call()
test_call_with_type_args()
test_free_expr() test_free_expr()
test_type_args() test_type_args()
test_self_reference() test_self_reference()
test_global_var_cow_issue() test_global_var_recursion()
test_equal()
import tvm import tvm
from tvm import relay from tvm import relay
from nose.tools import raises
def make_rel(name, args, num_inputs=None, attrs=None): def make_rel(name, args, num_inputs=None, attrs=None):
...@@ -48,7 +49,170 @@ def test_backward_solving(): ...@@ -48,7 +49,170 @@ def test_backward_solving():
assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32") assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32")
def test_unify_tuple():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.TensorType((10, 20), "float32")
tup1 = relay.ty.TupleType([t1, t2])
tup2 = relay.ty.TupleType([t3, t3])
unified = solver.Unify(tup1, tup2)
assert unified == tup2
def test_unify_functype():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
unit = relay.ty.TupleType([])
tensor1 = relay.ty.TensorType((10, 20), "float32")
tensor2 = relay.ty.TensorType((10,), "float32")
ft1 = relay.ty.FuncType([t1, t2], t3)
ft2 = relay.ty.FuncType([tensor1, tensor2], unit)
unified = solver.Unify(ft1, ft2)
assert unified == ft2
def test_recursive_unify():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
tensor1 = relay.ty.TensorType((10, 10, 20), "float32")
tensor2 = relay.ty.TensorType((10, 20), "float32")
tensor3 = relay.ty.TensorType((10,), "float32")
tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2])
tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2])
ft1 = relay.ty.FuncType([tup1, t3], t3)
ft2 = relay.ty.FuncType([tup2, tensor3], tensor3)
unified = solver.Unify(ft1, ft2)
assert unified == ft2
def test_unify_vars_under_tuples():
solver = make_solver()
t1 = relay.ty.IncompleteType()
tup1 = relay.ty.TupleType([t1, t1])
unified = solver.Unify(tup1, tup1)
assert unified == tup1
t2 = relay.ty.IncompleteType()
tup2 = relay.ty.TupleType([t2, t2])
tup3 = relay.ty.TupleType([t1, t2])
tup4 = relay.ty.TupleType([t2, t1])
unified = solver.Unify(tup3, tup4)
assert (unified == tup1 or unified == tup2)
def test_binding_over_typevars():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
a = relay.ty.TypeVar('a')
b = relay.ty.TypeVar('b')
c = relay.ty.TypeVar('c')
d = relay.ty.TypeVar('d')
ft1 = relay.ty.FuncType([t1], t2, [c, d])
ft2 = relay.ty.FuncType([a], b, [a, b])
unified = solver.Unify(ft1, ft2)
assert (unified == solver.Resolve(ft1))
def test_recursive_backward_solving():
solver = make_solver()
tensor1 = relay.ty.TensorType((10, 20), "float32")
tensor2 = relay.ty.TensorType((10, 1, 1), "float32")
tensor3 = relay.ty.TensorType((10,), "float32")
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
tup1 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3])
tup2 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t3])
solver.gen_type("Identity", [tup1], out=tup2)
assert solver.Solve()
assert solver.Resolve(tup2) == tup1
def test_backward_solving_after_child_update():
solver = make_solver()
tensor1 = relay.ty.TensorType((10, 20), "float32")
tensor2 = relay.ty.TensorType((10, 1, 1), "float32")
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.IncompleteType()
tup1 = relay.ty.TupleType([t1, t2])
tup2 = relay.ty.TupleType([t1, t3])
tup_concrete = relay.ty.TupleType([tensor1, tensor2])
t4 = solver.gen_type("Identity", [tup1])
t5 = solver.gen_type("Identity", [tup2])
solver.gen_type("Identity", [t4], out=t5)
assert solver.Solve()
assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2
assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2
assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2
# updating the variables *inside* tup1 and tup2 should update t4 and t5
solver.gen_type("Identity", [t1], out=tensor1)
solver.gen_type("Identity", [t2], out=tensor2)
assert solver.Solve()
assert solver.Resolve(t4) == tup_concrete
assert solver.Resolve(t5) == tup_concrete
@raises(tvm._ffi.base.TVMError)
def test_incompatible_tuple_unification():
solver = make_solver()
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
tensor1 = relay.ty.TensorType((1, 2, 3), "float32")
tensor2 = relay.ty.TensorType((2, 3), "float32")
tensor3 = relay.ty.TensorType((3,), "float32")
tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t1]), t2])
tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3])
solver.Unify(tup1, tup2)
@raises(tvm._ffi.base.TVMError)
def test_bad_recursive_unification():
solver = make_solver()
t1 = relay.ty.IncompleteType()
solver.Unify(t1, relay.ty.TupleType([t1, t1]))
if __name__ == "__main__": if __name__ == "__main__":
test_bcast() test_bcast()
test_backward_solving() test_backward_solving()
test_unify_tuple()
test_unify_functype()
test_recursive_unify()
test_unify_vars_under_tuples()
test_recursive_backward_solving()
test_backward_solving_after_child_update()
test_incompatible_tuple_unification()
test_bad_recursive_unification()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment