Unverified Commit 9319b6f9 by Tianqi Chen Committed by GitHub

[RELAY] Refactor type inference to use type solver (#1779)

parent 6d037db7
...@@ -190,6 +190,13 @@ build* ...@@ -190,6 +190,13 @@ build*
# Jetbrain # Jetbrain
.idea .idea
.ipython
.jupyter
.nv
.pylint.d
.python_history
.pytest_cache
.local
# tmp file # tmp file
.nfs* .nfs*
...@@ -13,10 +13,9 @@ ...@@ -13,10 +13,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../attrs.h" #include "base.h"
#include "./base.h" #include "expr.h"
#include "./expr.h" #include "type.h"
#include "./type.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -45,7 +44,7 @@ class OpNode : public relay::ExprNode { ...@@ -45,7 +44,7 @@ class OpNode : public relay::ExprNode {
Array<AttrFieldInfo> arguments; Array<AttrFieldInfo> arguments;
/*! /*!
* \brief The type key of the attribute field * \brief The type key of the attribute field
* This can be empty, in which case it defaults to * This can be empty, in which case it defaults to anything.
*/ */
std::string attrs_type_key; std::string attrs_type_key;
/*! /*!
...@@ -156,11 +155,13 @@ class OpRegistry { ...@@ -156,11 +155,13 @@ class OpRegistry {
*/ */
inline OpRegistry& add_type_rel( inline OpRegistry& add_type_rel(
const std::string& rel_name, const std::string& rel_name,
std::function<Array<Type>(const Array<Type>&, int)> type_rel_func); runtime::TypedPackedFunc<bool(const Array<Type>&,
int,
const Attrs&,
const TypeReporter&)> type_rel_func);
/*! /*!
* \brief Set the type key of attributes. * \brief Set the type key of attributes.
* \param type_key The type of of the attrs field.x * \param type_key The type of of the attrs field.
* \return reference to self. * \return reference to self.
*/ */
inline OpRegistry& set_attrs_type_key(const std::string& type_key); inline OpRegistry& set_attrs_type_key(const std::string& type_key);
...@@ -348,23 +349,25 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, ...@@ -348,23 +349,25 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name,
inline OpRegistry& OpRegistry::add_type_rel( inline OpRegistry& OpRegistry::add_type_rel(
const std::string& rel_name, const std::string& rel_name,
std::function<Array<Type>(const Array<Type>&, int)> type_rel_func) { runtime::TypedPackedFunc<bool(const Array<Type>&,
int,
const Attrs&,
const TypeReporter&)> type_rel_func) {
auto func_name = std::string("tvm.relay.type_relation.") + rel_name; auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
TypeRelationFn env_type_rel_func;
TypedEnvFunc<Array<Type>(const Array<Type>&, int)> env_type_rel_func;
if (runtime::Registry::Get(func_name)) { if (runtime::Registry::Get(func_name)) {
auto env_func = EnvFunc::Get(func_name); auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func; env_type_rel_func = env_func;
} else { } else {
runtime::Registry::Register(func_name) runtime::Registry::Register(func_name)
.set_body_typed<Array<Type>(const Array<Type>&, int)>(type_rel_func); .set_body(type_rel_func.packed());
auto env_func = EnvFunc::Get(func_name); auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func; env_type_rel_func = env_func;
} }
std::vector<TypeParam> type_params; Array<TypeParam> type_params;
std::vector<Type> arg_types; Array<Type> arg_types;
// Add inputs. // Add inputs.
std::string input_name_prefix = "in"; std::string input_name_prefix = "in";
...@@ -375,15 +378,27 @@ inline OpRegistry& OpRegistry::add_type_rel( ...@@ -375,15 +378,27 @@ inline OpRegistry& OpRegistry::add_type_rel(
arg_types.push_back(param); arg_types.push_back(param);
} }
auto ty_call_args = Array<Type>(arg_types); Array<Type> ty_call_args = arg_types;
// Add output type. // Add output type.
auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType); auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
type_params.push_back(out_param); type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param); ty_call_args.push_back(out_param);
// The attributes of primitive op is nullptr
//
// The attributes of primitive operator can vary at the call site.
// The type of sum is also dependent on Attrs being passed.
// So puting nullptr in the Attrs means that the operator is polymorphic on Attrs.
//
// A common example is sum(x, axis), where the choice of axis
// can affect the type of the function.
TypeConstraint type_rel = TypeConstraint type_rel =
TypeRelationNode::make(rel_name, env_type_rel_func, ty_call_args); TypeRelationNode::make(env_type_rel_func,
ty_call_args,
arg_types.size(),
Attrs());
auto func_type = auto func_type =
FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
......
...@@ -26,7 +26,7 @@ namespace relay { ...@@ -26,7 +26,7 @@ namespace relay {
* \return A type checked expression with its checked_type field populated. * \return A type checked expression with its checked_type field populated.
*/ */
Expr InferType(const Environment& env, const Expr& e); Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& v, const Function& e); Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
/*! /*!
* \brief Check that types are well formed by applying "kinding rules". * \brief Check that types are well formed by applying "kinding rules".
...@@ -69,7 +69,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2); ...@@ -69,7 +69,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2);
* *
* For example: `forall s, Tensor[f32, s]` is equal to * For example: `forall s, Tensor[f32, s]` is equal to
* `forall w, Tensor[f32, w]`. * `forall w, Tensor[f32, w]`.
* *
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details. * for more details.
* *
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
#include <tvm/node/node.h> #include <tvm/node/node.h>
#include <string> #include <string>
#include "./base.h" #include "base.h"
#include "../attrs.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -116,10 +117,10 @@ class TypeParamNode : public TypeNode { ...@@ -116,10 +117,10 @@ class TypeParamNode : public TypeNode {
/*! \brief possible kinds of TypeParam */ /*! \brief possible kinds of TypeParam */
enum Kind : int { enum Kind : int {
/*! \brief template variable in shape expression */ /*! \brief template variable in shape expression */
kShapeVar = 0, kType = 0,
kShape = 1, kShapeVar = 1,
kBaseType = 2, kBaseType = 2,
kType = 3 kShape = 3
}; };
/*! /*!
* \brief The variable itself is only meaningful when * \brief The variable itself is only meaningful when
...@@ -144,6 +145,33 @@ class TypeParamNode : public TypeNode { ...@@ -144,6 +145,33 @@ class TypeParamNode : public TypeNode {
RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type); RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
/*! /*!
* \brief IncompleteType.
* This is intermediate values that is used during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeParam represents the input to the graph.
*/
class IncompleteType;
/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeParamNode::Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("kind", &kind);
}
TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
/*!
* \brief Potential Constraints in the type. * \brief Potential Constraints in the type.
* \note This is reserved for future use. * \note This is reserved for future use.
*/ */
...@@ -190,7 +218,8 @@ class FuncTypeNode : public TypeNode { ...@@ -190,7 +218,8 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
TVM_DLL static FuncType make(tvm::Array<Type> arg_types, Type ret_type, TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params, tvm::Array<TypeParam> type_params,
tvm::Array<TypeConstraint> type_constraints); tvm::Array<TypeConstraint> type_constraints);
...@@ -200,11 +229,102 @@ class FuncTypeNode : public TypeNode { ...@@ -200,11 +229,102 @@ class FuncTypeNode : public TypeNode {
RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
/*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TypeTuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
class TypeReporter;
/*!
* \brief reporter that reports back to the
* type resolution information.
*/
class TypeReporterNode : public Node {
public:
/*!
* \brief Create a type equality constraint.
*
* The "assign direction" acts as a hint to the solver
* showing that it is more likely to resolve dst by src.
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief assert shape expression equals each other.
* \param lhs The left operand.
* \param rhs The right operand.
*/
TVM_DLL virtual void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
};
/*!
* \brief Container class of TypeReporter.
* \sa TypeReporterNode
*/
class TypeReporter : public NodeRef {
public:
TypeReporter() {}
explicit TypeReporter(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {
}
TypeReporterNode* operator->() const {
return static_cast<TypeReporterNode*>(node_.get());
}
using ContainerType = TypeReporterNode;
};
/*!
* \brief User defined type constraint function.
*
* If the input type information can be used to fully decide
* the IncompleteTypes, then the function should call
* reporter.Assign to report the new types, and return true.
* Otherwise, the function should return false.
*
* \param args The arguments to the relation.
* The types are stored in the form of
* [input_type_0, input_type_1, ... input_type_n,
* output_type_0, output_type_1, ... output_type_m]
*
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved.
* true if this relation has been resolved.
*/
using TypeRelationFn = using TypeRelationFn =
TypedEnvFunc<Array<Type>(const Array<Type>&, int)>; TypedEnvFunc<bool(const Array<Type>& args,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter)>;
/*! /*!
* \brief Opaque type relation, is an input-output relation on types. * \brief User defined type relation, is an input-output relation on types.
*/ */
class TypeRelation; class TypeRelation;
/*! /*!
...@@ -214,24 +334,30 @@ class TypeRelation; ...@@ -214,24 +334,30 @@ class TypeRelation;
*/ */
class TypeRelationNode : public TypeConstraintNode { class TypeRelationNode : public TypeConstraintNode {
public: public:
/*! \brief The name of the function */
std::string name;
/*! /*!
* \brief The function on input and output variables which * \brief The function on input and output variables which
* this is not directly serializable, * this is not directly serializable,
* need to be looked-up in the environment. * need to be looked-up in the environment.
*/ */
TypeRelationFn func_; TypeRelationFn func;
/*! \brief The type arguments to the type function. */ /*! \brief The type arguments to the type function. */
tvm::Array<Type> args; tvm::Array<Type> args;
/*! \brief Number of inputs arguments */
int num_inputs;
/*! \brief Attributes to the relation function */
Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
v->Visit("attrs", &attrs);
} }
TVM_DLL static TypeRelation make(std::string name, TypeRelationFn func_, Array<Type> args); TVM_DLL static TypeRelation make(TypeRelationFn func,
Array<Type> args,
int num_args,
Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation"; static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode); TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
...@@ -239,30 +365,6 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -239,30 +365,6 @@ class TypeRelationNode : public TypeConstraintNode {
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
/*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TypeTuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
// The following fields contains advanced typing // The following fields contains advanced typing
// Only keep the class name and reserved for future usage. // Only keep the class name and reserved for future usage.
class GenericTensorType; class GenericTensorType;
......
...@@ -175,7 +175,17 @@ class TypedPackedFunc<R(Args...)> { ...@@ -175,7 +175,17 @@ class TypedPackedFunc<R(Args...)> {
* *
* \param packed The packed function * \param packed The packed function
*/ */
inline explicit TypedPackedFunc(PackedFunc packed); inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
/*!
* \brief constructor from TVMRetValue
* \param value The TVMRetValue
*/
inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*)
/*!
* \brief constructor from TVMArgValue
* \param value The TVMArgValue
*/
inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*)
/*! /*!
* \brief construct from a lambda function with the same signature. * \brief construct from a lambda function with the same signature.
* *
...@@ -196,7 +206,7 @@ class TypedPackedFunc<R(Args...)> { ...@@ -196,7 +206,7 @@ class TypedPackedFunc<R(Args...)> {
std::is_convertible<FLambda, std::is_convertible<FLambda,
std::function<R(Args...)> std::function<R(Args...)>
>::value>::type> >::value>::type>
explicit TypedPackedFunc(const FLambda& typed_lambda) { TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda); this->AssignTypedLambda(typed_lambda);
} }
/*! /*!
...@@ -1144,6 +1154,14 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) ...@@ -1144,6 +1154,14 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
: packed_(packed) {} : packed_(packed) {}
template<typename R, typename ...Args> template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMRetValue& value)
: packed_(value.operator PackedFunc()) {}
template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value)
: packed_(value.operator PackedFunc()) {}
template<typename R, typename ...Args>
template<typename FType> template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) { inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
......
...@@ -2,16 +2,18 @@ ...@@ -2,16 +2,18 @@
"""The expression nodes of Relay.""" """The expression nodes of Relay."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import NodeBase, register_relay_node from .base import NodeBase, register_relay_node
from ._ir_pass import _get_checked_type
from . import _make from . import _make
from .. import convert from .. import convert
class Expr(NodeBase): class Expr(NodeBase):
"""The base type for all Relay expressions.""" """The base type for all Relay expressions."""
def checked_type(self): def checked_type(self):
return _get_checked_type(self) ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
" the checked_type for this node")
return ret
def __call__(self, *args): def __call__(self, *args):
converted_args = [] converted_args = []
......
...@@ -52,11 +52,10 @@ class Kind(IntEnum): ...@@ -52,11 +52,10 @@ class Kind(IntEnum):
with. For example one's of kind BaseType can only be `float32`, `int32`, with. For example one's of kind BaseType can only be `float32`, `int32`,
and so on. and so on.
""" """
ShapeVar = 0 Type = 0
Shape = 1 ShapeVar = 1
BaseType = 2 BaseType = 2
Type = 3 Shape = 3
@register_relay_node @register_relay_node
class TypeParam(Type): class TypeParam(Type):
...@@ -68,7 +67,7 @@ class TypeParam(Type): ...@@ -68,7 +67,7 @@ class TypeParam(Type):
functions which are generic over types. functions which are generic over types.
""" """
def __init__(self, var, kind): def __init__(self, var, kind=Kind.Type):
"""Construct a TypeParam. """Construct a TypeParam.
Parameters Parameters
...@@ -76,7 +75,7 @@ class TypeParam(Type): ...@@ -76,7 +75,7 @@ class TypeParam(Type):
var: tvm.expr.Var var: tvm.expr.Var
The tvm.Var which backs the type parameter. The tvm.Var which backs the type parameter.
kind: Kind kind: Kind, optional
The kind of the type parameter. The kind of the type parameter.
Returns Returns
...@@ -130,8 +129,7 @@ class FuncType(Type): ...@@ -130,8 +129,7 @@ class FuncType(Type):
arg_types, arg_types,
ret_type, ret_type,
type_params, type_params,
type_constraints type_constraints):
):
"""Construct a function type. """Construct a function type.
Parameters Parameters
...@@ -153,6 +151,29 @@ class FuncType(Type): ...@@ -153,6 +151,29 @@ class FuncType(Type):
@register_relay_node @register_relay_node
class IncompleteType(Type): class IncompleteType(Type):
"""An incomplete type.""" """An incomplete type."""
def __init__(self, kind=Kind.Type):
def __init__(self, kind):
self.__init_handle_by_constructor__(_make.IncompleteType, kind) self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
/*!
* Copyright 2018 by Contributors
*
* \file arena.h
* \brief Arena allocator that allocates
* memory chunks and frees them all during destruction time.
*/
#ifndef TVM_COMMON_ARENA_H_
#define TVM_COMMON_ARENA_H_
#include <type_traits>
namespace tvm {
namespace common {
const constexpr int kArenaPageSize = 16 << 10;
/*!
* \brief Arena allocator that allocates memory from continuous
* chunk and frees them all only during destruction.
*/
class Arena {
public:
Arena() {
// eagerly allocate the first page.
head_ = reinterpret_cast<PageHeader*>(new Page());
head_->next = nullptr;
head_->ptr = sizeof(PageHeader);
}
~Arena() {
// delete all the allocated pages.
while (head_ != nullptr) {
Page* page = reinterpret_cast<Page*>(head_);
head_ = head_->next;
delete page;
}
}
/*!
* \brief Allocate a space from Arena for type T
* \param T the data type to be allocated
*/
template<typename T>
T* Alloc() {
return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
}
private:
// page size 16 KB
// The page data type;
using Page = std::aligned_storage<kArenaPageSize, 1024>::type;
/*! \brief Page header */
struct PageHeader {
/*! \brief points to the next page */
PageHeader* next;
/*! \brief memory allocator ptr inside page */
size_t ptr;
};
/* \brief The page header */
PageHeader* head_{nullptr};
/*!
* \brief Align ptr by upper bound.
* \param ptr The pointer value.
* \param align The alignment requirement.
*/
size_t UpperAlign(size_t ptr, size_t align) {
return ptr + (align - (ptr % align)) % align;
}
/*!
* \brief Internal aligned alloc function.
* \param size The size of the memory.
* \param align The alignment requirement.
*/
void* Alloc(size_t size, size_t align) {
size_t ptr = UpperAlign(head_->ptr, align);
if (ptr + size <= kArenaPageSize) {
head_->ptr = ptr + size;
return reinterpret_cast<char*>(head_) + ptr;
} else {
PageHeader* new_head = reinterpret_cast<PageHeader*>(new Page());
new_head->next = head_;
ptr = UpperAlign(sizeof(PageHeader), align);
CHECK_LE(ptr + size, kArenaPageSize);
new_head->ptr = ptr + size;
head_ = new_head;
return reinterpret_cast<char*>(head_) + ptr;
}
}
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_ARENA_H_
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include <tvm/relay/environment.h> #include <tvm/relay/environment.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <sstream> #include <sstream>
#include "./../pass/resolve.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -49,7 +48,7 @@ void EnvironmentNode::Add(const GlobalVar &var, ...@@ -49,7 +48,7 @@ void EnvironmentNode::Add(const GlobalVar &var,
auto checked_func = GetRef<Function>(func_node); auto checked_func = GetRef<Function>(func_node);
auto type = checked_func->checked_type(); auto type = checked_func->checked_type();
CHECK(IsFullyResolved(type)); CHECK(type.as<IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) { if (functions.find(var) != functions.end()) {
if (!update) { if (!update) {
...@@ -68,7 +67,7 @@ void EnvironmentNode::Add(const GlobalVar &var, ...@@ -68,7 +67,7 @@ void EnvironmentNode::Add(const GlobalVar &var,
this->functions.Set(var, checked_func); this->functions.Set(var, checked_func);
} }
} else { } else {
throw Error("internal error: unknown item type, unreachable code"); LOG(FATAL) << "internal error: unknown item type, unreachable code";
} }
} }
......
...@@ -55,7 +55,27 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -55,7 +55,27 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->kind << ")"; << node->kind << ")";
}); });
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, Type ret_type, IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
auto n = make_node<IncompleteTypeNode>();
n->kind = std::move(kind);
return IncompleteType(n);
}
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
*ret = IncompleteTypeNode::make(static_cast<TypeParamNode::Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>(
[](const IncompleteTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params, tvm::Array<TypeParam> type_params,
tvm::Array<TypeConstraint> type_constraints) { tvm::Array<TypeConstraint> type_constraints) {
NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>(); NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
...@@ -79,24 +99,28 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -79,24 +99,28 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->type_constraints << ")"; << node->type_constraints << ")";
}); });
TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array<Type> args) { TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>(); NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>();
n->name = std::move(name); n->func = std::move(func);
n->func_ = std::move(func);
n->args = std::move(args); n->args = std::move(args);
n->num_inputs = num_inputs;
n->attrs = std::move(attrs);
return TypeRelation(n); return TypeRelation(n);
} }
TVM_REGISTER_API("relay._make.TypeRelation") TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2]); *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, .set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, tvm::IRPrinter *p) {
tvm::IRPrinter *p) { p->stream << "TypeRelationNode("
p->stream << "TypeRelationNode(" << node->name << ", " << node->args << node->func->name
<< ")"; << ", " << node->args << ")";
}); });
TupleType TupleTypeNode::make(Array<Type> fields) { TupleType TupleTypeNode::make(Array<Type> fields) {
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include <tvm/relay/logging.h> #include <tvm/relay/logging.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <numeric> #include <numeric>
#include "../pass/incomplete_type.h"
#include "./type_relations.h" #include "./type_relations.h"
namespace tvm { namespace tvm {
...@@ -30,18 +29,19 @@ int ToInt(const tvm::Expr& e) { ...@@ -30,18 +29,19 @@ int ToInt(const tvm::Expr& e) {
return imm->value; return imm->value;
} }
Array<Type> IdentityRel(const Array<Type>& types, int num_args) { bool IdentityRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2); int num_inputs,
auto t1 = ToTensorType(types[0]); const Attrs& attrs,
if (t1 && types[1].as<IncompleteTypeNode>()) { const TypeReporter& reporter) {
return {t1, t1}; for (size_t i = 1; i < types.size(); ++i) {
} else { reporter->Assign(types[i], types[0]);
return types;
} }
return true;
} }
static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, Type ConcreteBroadcast(const TensorType& t1,
DataType output_dtype) { const TensorType& t2,
DataType output_dtype) {
RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2
<< std::endl; << std::endl;
auto sh1 = t1->shape; auto sh1 = t1->shape;
...@@ -73,7 +73,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, ...@@ -73,7 +73,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2,
Array<ShapeExpr> smaller; Array<ShapeExpr> smaller;
for (int i = 0; i < (full_len - suffix_len); i++) { for (int i = 0; i < (full_len - suffix_len); i++) {
smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); smaller.push_back(make_const(tvm::Int(64), 1));
} }
if (sh1.size() < sh2.size()) { if (sh1.size() < sh2.size()) {
...@@ -93,46 +93,52 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, ...@@ -93,46 +93,52 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2,
CHECK_EQ(larger.size(), smaller.size()); CHECK_EQ(larger.size(), smaller.size());
Array<HalideIR::Expr> out_shape; Array<ShapeExpr> out_shape;
for (size_t i = 0; i < smaller.size(); i++) { for (size_t i = 0; i < smaller.size(); i++) {
auto left = smaller[i].as<tvm::ir::IntImm>(); auto left = smaller[i].as<tvm::ir::IntImm>();
auto right = larger[i].as<tvm::ir::IntImm>(); auto right = larger[i].as<tvm::ir::IntImm>();
CHECK(left); CHECK(left);
CHECK(right); CHECK(right);
int64_t dim = std::max(left->value, right->value); int64_t dim = std::max(left->value, right->value);
out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); out_shape.push_back(make_const(tvm::Int(64), dim));
} }
return TensorTypeNode::make(out_shape, output_dtype); return TensorTypeNode::make(out_shape, output_dtype);
} }
} }
Array<Type> BroadcastRel(const Array<Type>& types, int num_args) { bool BroadcastRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1] RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1]
<< "Out: " << types[2] << std::endl; << "Out: " << types[2] << std::endl;
if (auto t1 = ToTensorType(types[0])) { if (auto t0 = ToTensorType(types[0])) {
if (auto t2 = ToTensorType(types[1])) { if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t1->dtype, t2->dtype); CHECK_EQ(t0->dtype, t1->dtype);
return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; reporter->Assign(types[2], ConcreteBroadcast(t0, t1, t0->dtype));
return true;
} }
} }
return false;
return types;
} }
/* A relation which specifies broadcasting rules for operations which bool BroadcastCompRel(const Array<Type>& types,
compute boolean results. int num_inputs,
*/ const Attrs& attrs,
Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
if (auto t1 = ToTensorType(types[0])) { RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1]
if (auto t2 = ToTensorType(types[1])) { << "Out: " << types[2] << std::endl;
return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())}; if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool()));
return true;
} }
} }
return false;
return types;
} }
/*! \brief Handle concrete concat case from known input to output. */ /*! \brief Handle concrete concat case from known input to output. */
...@@ -175,10 +181,10 @@ inline Type ConcreteConcatRel(const Type& input_type) { ...@@ -175,10 +181,10 @@ inline Type ConcreteConcatRel(const Type& input_type) {
auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0); auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0);
Array<tvm::Expr> out_shape = { tvm::ir::IntImm::make(HalideIR::Int(64), out_axis_dim) }; Array<tvm::Expr> out_shape = { make_const(Int(64), out_axis_dim) };
for (auto dim : dims) { for (auto dim : dims) {
out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); out_shape.push_back(make_const(Int(64), dim));
} }
return TensorTypeNode::make(out_shape, dtype); return TensorTypeNode::make(out_shape, dtype);
...@@ -188,19 +194,18 @@ inline Type ConcreteConcatRel(const Type& input_type) { ...@@ -188,19 +194,18 @@ inline Type ConcreteConcatRel(const Type& input_type) {
} }
} }
Array<Type> ConcatRel(const Array<Type>& types, int num_args) { bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 2);
if (types[0].as<TupleTypeNode>()) {
if (types[0].as<IncompleteTypeNode>() && types[1].as<IncompleteTypeNode>()) { reporter->Assign(types[1], ConcreteConcatRel(types[0]));
return types; return true;
} else if (types[1].as<IncompleteTypeNode>()) {
return { types[0], ConcreteConcatRel(types[0]) };
} else {
throw TypeRelationError(
"can not deduce relationship between the " \
"type of concat's input and output");
} }
return false;
} }
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -24,42 +24,72 @@ struct TypeRelationError : Error { ...@@ -24,42 +24,72 @@ struct TypeRelationError : Error {
: Error(msg) {} : Error(msg) {}
}; };
/*! \brief The identity type relation maps a single input variable /*!
* to the output variable. * \brief The identity type relation, all the types are equal.
* *
* \param types The input and output types to the relation. * \param types The input and output types to the relation.
* \param num_args The number of input arguments. * \param num_inputs The number of input arguments.
* \return The (potentially partial) solution to the relation. * \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> IdentityRel(const Array<Type>& types, int num_args); bool IdentityRel(const Array<Type>& types,
/*! \brief The broadcast type relation, implements the broadcasting int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type. * rule over the two input types producing the broadcasted type.
* *
* \param types The input and output types to the relation. * \param types The input and output types to the relation.
* \param num_args The number of input arguments. * \param num_inputs The number of input arguments.
* \return The (potentially partial) solution to the relation. * \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> BroadcastRel(const Array<Type>& types, int num_args); bool BroadcastRel(const Array<Type>& types,
/*! \brief The broadcast type relation, implements the broadcasting int num_inputs,
* rule over the two input types producing the broadcasted type. const Attrs& attrs,
const TypeReporter& reporter);
/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
* *
* This differs from BroadcastRel in the return dtype, * This differs from BroadcastRel in the return dtype,
* it instead returns bool, for use in comparsion operators * it instead returns bool(uint8), for use in comparsion operators
* such as equal, not_equal, lt, and so on. * such as equal, not_equal, lt, and so on.
* *
* \param types The input and output types to the relation. * \param types The input and output types to the relation.
* \param num_args The number of input arguments. * \param num_inputs The number of input arguments.
* \return The (potentially partial) solution to the relation. * \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args); bool BroadcastCompRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
/*! \brief The concat relation. /*!
* \brief The The concat relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
* *
* This relation takes a single input which must be a single tensor * This differs from BroadcastRel in the return dtype,
* or an arbitrary sized tuple. It combines these input dimensions * it instead returns bool(uint8), for use in comparsion operators
* together to produce the output example. * such as equal, not_equal, lt, and so on.
*
* \param types The input and output types to the relation.
* \param num_inputs The number of input arguments.
* \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> ConcatRel(const Array<Type>& types, int num_args); bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
/*!
* Copyright (c) 2018 by Contributors
* \file incomplete_type.h
* \brief A way to defined arbitrary function signature with dispatch on types.
*/
#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*!
* \brief Represents a portion of an incomplete type.
*/
class IncompleteType;
/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeParamNode::Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); }
TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
/*!
* Copyright (c) 2018 by Contributors
* \file resolve.cc
* \brief Resolve incomplete types to complete types.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include "./resolve.h"
#include "./type_visitor.h"
namespace tvm {
namespace relay {
struct ResolveTypeType : TypeMutator {
const TypeUnifier &unifier;
explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {}
Type VisitType(const Type &t) override {
if (!t.defined()) {
auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
unifier->Insert(inc_ty);
return inc_ty;
} else {
return TypeMutator::VisitType(t);
}
}
Type VisitType_(const IncompleteTypeNode *op) override {
return unifier->Subst(GetRef<IncompleteType>(op));
}
};
struct ResolveTypeExpr : ExprMutator {
const TypeUnifier &unifier;
explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {}
Expr Mutate(const Expr &e) {
// NB: a bit tricky here.
//
// We want to store resolved type without having
// to re-typecheck the entire term.
//
// Since we know that e : T[...] under some holes
// then it is the case that if we resolve types
// present in e, then we can type it under T
// with the wholes filled in.
//
// We will visit e like normal building a new
// term, then resolve e's old type and write
// it back into the new node.
auto new_e = ExprMutator::Mutate(e);
CHECK(e->checked_type_.defined());
auto resolved_cty = VisitType(e->checked_type_);
new_e->checked_type_ = resolved_cty;
return new_e;
}
Type VisitType(const Type &t) {
return ResolveTypeType(unifier).VisitType(t);
}
};
Type Resolve(const TypeUnifier &unifier, const Type &ty) {
CHECK(ty.defined());
return ResolveTypeType(unifier).VisitType(ty);
}
Expr Resolve(const TypeUnifier &unifier, const Expr &expr) {
return ResolveTypeExpr(unifier).Mutate(expr);
}
struct FullyResolved : TypeVisitor<> {
bool incomplete;
FullyResolved() : incomplete(true) {}
void VisitType(const Type &t) override {
if (!t.defined()) {
incomplete = true;
} else {
return TypeVisitor<>::VisitType(t);
}
}
void VisitType_(const IncompleteTypeNode *ty_var) override {
incomplete = false;
}
};
bool IsFullyResolved(const Type &t) {
auto fr = FullyResolved();
fr.VisitType(t);
return fr.incomplete;
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/resolve.h
* \brief Resolve incomplete types to complete types.
*/
#ifndef TVM_RELAY_PASS_RESOLVE_H_
#define TVM_RELAY_PASS_RESOLVE_H_
#include <tvm/relay/expr.h>
#include <string>
#include "./unifier.h"
namespace tvm {
namespace relay {
/*! \brief Resolve a type containing incomplete types.
*
* This pass replaces incomplete types with their representative, and
* converts types which are not defined into fresh variables.
*
* \param unifier The unifier containing the unification data.
* \param ty The type to resolve.
* \returns The resolved type.
*/
Type Resolve(const TypeUnifier& unifier, const Type& ty);
/*! \brief Resolve an expression containing incomplete types.
*
* This pass replaces incomplete types with their representative, and
* converts types which are not defined into fresh variables.
*
* \param unifier The unifier containing the unification data.
* \param ty The expression to resolve.
* \returns The resolved expression.
*/
Expr Resolve(const TypeUnifier& unifier, const Expr& expr);
/*! \brief Check if all types have been filled in.
* \param t The type.
* \returns True if the type is resolved, false otherwise.
*/
bool IsFullyResolved(const Type& t);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_RESOLVE_H_
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include "./incomplete_type.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -22,607 +22,360 @@ ...@@ -22,607 +22,360 @@
#include <tvm/relay/error.h> #include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include "./incomplete_type.h" #include "type_solver.h"
#include "./resolve.h" #include "type_subst.h"
#include "./type_subst.h"
#include "./type_visitor.h"
#include "./unifier.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
//
using namespace tvm::runtime; // The inference algorithm can roughly be devided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
// // We declare this for forward compatibility. // - solver.AddConstraint and solver.Unify are called to populate the necessary constraints
struct ConstraintData {}; // - Solve the constraints (solver_.Solve)
// - Recreate expression with the resolved checked_type (Resolver.VisitExpr)
/*! \brief A more efficient representation of the type relation //
* data needed for type checking. class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
*/ public:
struct TypeRelationData : ConstraintData { // constructors
std::string name; TypeInferencer()
std::vector<Type> args; : env_(EnvironmentNode::make({})) {
TypeRelationFn func;
Span span;
explicit TypeRelationData(const TypeRelation& ty_rel)
: TypeRelationData(ty_rel->args, ty_rel->func_, ty_rel->span) {}
TypeRelationData(const Array<Type>& args, const TypeRelationFn& func, const Span& sp)
: func(func), span(sp) {
for (auto arg : args) {
this->args.push_back(arg);
}
} }
explicit TypeInferencer(Environment env)
TypeRelation ToTypeRel() const { : env_(env) {
Array<Type> args = Array<Type>(this->args.begin(), this->args.end());
return TypeRelationNode::make(
this->name, this->func, args);
} }
};
struct TypeContext {
std::unordered_map<Var, Type, NodeHash> var_map;
std::vector<std::vector<TypeRelationData> > constraints;
TypeContext() { constraints.push_back({}); } // inference the type of expr.
Expr Infer(Expr expr);
void Insert(const Var& id, const Type& t) { var_map[id] = t; } private:
// type resolver that maps back to type
void AddConstraint(const TypeConstraint& constraint) { class Resolver;
constraints.back().push_back(TypeRelationData(Downcast<TypeRelation>(constraint))); // internal environment
Environment env_;
// map from expression to checked type
// type inferencer will populate it up
std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
// The solver used by the inferencer.
TypeSolver solver_;
// Unify two types
Type Unify(const Type& t1, const Type& t2, const Span& span) {
// TODO(tqchen, jroesch): propagate span to solver
try {
return solver_.Unify(t1, t2);
} catch (const dmlc::Error &e) {
LOG(FATAL)
<< "Error unifying `"
<< t1
<< "` and `"
<< t2
<< "`: " << e.what();
return Type();
}
} }
// Lazily get type for expr
Type Lookup(const Var& var) { // will call visit to deduce it if it is not in the type_map_
auto type = var_map.find(var); Type GetType(const Expr &expr) {
if (type != var_map.end()) { auto it = type_map_.find(expr);
return (*type).second; if (it != type_map_.end()) {
} else { return it->second;
throw FatalTypeError(std::string("undeclared local variable: ") + var->name_hint);
} }
Type ret = this->VisitExpr(expr);
type_map_[expr] = ret;
return ret;
} }
struct Scope { // Visitor logics
TypeContext& tc; Type VisitExpr_(const VarNode* op) final {
explicit Scope(TypeContext& tc) : tc(tc) { tc.constraints.push_back({}); } // The type of Var can already been lookedup in type_map_;
~Scope() { tc.constraints.pop_back(); } LOG(FATAL) << "Cannot find binding for var " << GetRef<Var>(op);
}; return Type();
};
struct CheckedExpr {
Expr expr;
Type type;
CheckedExpr(Expr e, Type t) : expr(e), type(t) {}
CheckedExpr() {}
};
enum SolverResult : int;
class TypeInferencer : private ExprFunctor<CheckedExpr(const Expr&)> {
private:
TypeContext context;
public:
Environment env;
TypeUnifier unifier;
template <typename T>
T WithScope(const std::function<T()>& f) {
TypeContext::Scope fr(context);
return f();
} }
TypeInferencer(); Type VisitExpr_(const ParamNode* op) final {
TypeInferencer(Environment env, TypeUnifier unifier) // directly handled by Funtion
: env(env), unifier(unifier) {} LOG(FATAL) << "not reached";
explicit TypeInferencer(Environment env); return Type();
CheckedExpr Infer(const Expr &expr);
FuncType Instantiate(FuncType fn_ty, tvm::Array<Type> &ty_args);
Type Normalize(const Type& t);
void ReportError(const std::string& msg, Span sp);
[[noreturn]] void FatalError(const std::string& msg, Span sp);
Type Unify(const Type &t1, const Type& t2, Span sp);
Type Resolve(const Type &t);
Expr Resolve(const Expr &e);
/*! \brief Attempt to solve a single relation. */
void Solve(TypeRelationData& ty_rel);
/*! \brief Attempt to solve all pending relations.
*
* If the solver
*/
SolverResult Solve(std::vector<TypeRelationData>& rels);
/*! \brief Check that all relations hold. */
bool RelationsHold(bool scope_only = false);
/*! \brief Visit a function node, extra flag controls behavior. */
CheckedExpr VisitFunction(const Function& f, bool generalize);
private:
CheckedExpr VisitExpr_(const VarNode* op) override;
CheckedExpr VisitExpr_(const GlobalVarNode* op) override;
CheckedExpr VisitExpr_(const ConstantNode* op) override;
CheckedExpr VisitExpr_(const TupleNode* op) override;
CheckedExpr VisitExpr_(const ParamNode* op) override;
CheckedExpr VisitExpr_(const FunctionNode* op) override;
CheckedExpr VisitExpr_(const CallNode* op) override;
CheckedExpr VisitExpr_(const LetNode* op) override;
CheckedExpr VisitExpr_(const IfNode* op) override;
CheckedExpr VisitExpr_(const OpNode* op) override;
};
TypeInferencer::TypeInferencer() {
this->env = EnvironmentNode::make({});
this->unifier = TypeUnifierNode::make(UnionFindNode::make({}));
}
TypeInferencer::TypeInferencer(Environment env) : env(env) {
this->unifier = TypeUnifierNode::make(UnionFindNode::make({}));
}
CheckedExpr TypeInferencer::Infer(const Expr& expr) {
RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl;
CheckedExpr checked_expr = this->VisitExpr(expr);
RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type
<< std::endl;
Type final_type = checked_expr.type;
RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type
<< std::endl;
checked_expr.expr->checked_type_ = final_type;
return checked_expr;
}
CheckedExpr TypeInferencer::VisitExpr_(const VarNode* op) {
auto var = GetRef<Var>(op);
return {var, this->context.Lookup(var)};
}
CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode* op) {
GlobalVar var = GetRef<GlobalVar>(op);
Expr e = this->env->Lookup(var);
return {var, e->checked_type()};
}
CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode* const_node) {
return {GetRef<Constant>(const_node), const_node->tensor_type()};
}
CheckedExpr TypeInferencer::VisitExpr_(const TupleNode* op) {
Tuple pl = GetRef<Tuple>(op);
std::vector<Expr> field_exprs;
std::vector<Type> field_types;
for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) {
auto checked_field = Infer(*field);
field_exprs.push_back(checked_field.expr);
field_types.push_back(checked_field.type);
} }
return {TupleNode::make(field_exprs), TupleTypeNode::make(field_types)}; Type VisitExpr_(const GlobalVarNode* op) final {
} GlobalVar var = GetRef<GlobalVar>(op);
Expr e = env_->Lookup(var);
CheckedExpr TypeInferencer::VisitExpr_(const ParamNode* param) { return e->checked_type();
// We should trigger error here and move param code direclty into function }
// checking.
auto rtype = this->Resolve(param->type);
// This is a special case ... not sure if there is a better way
// to handle this.
param->var->checked_type_ = rtype;
return {ParamNode::make(param->var, rtype), rtype};
}
CheckedExpr TypeInferencer::VisitFunction(const Function& f, bool generalize) {
// First we add the parameters to the context allowing us to check their
// types.
// TODO(@jroesch): support polymorphism
std::vector<Type> param_types; Type VisitExpr_(const ConstantNode* op) final {
std::vector<Param> params; return op->tensor_type();
}
return this->WithScope<CheckedExpr>([&]() -> CheckedExpr { Type VisitExpr_(const TupleNode* op) final {
for (auto param : f->params) { // TODO(tqchen, jroesch)
CheckedExpr checked_param = this->Infer(param); // tuple should be a constraint in the type solver
Type arg_type; // to handle cases where the field type is not known.
param_types.push_back(checked_param.type); Array<Type> fields;
params.push_back(GetRef<Param>(checked_param.expr.as<ParamNode>())); for (Expr field : op->fields) {
this->context.Insert(param->var, checked_param.type); fields.push_back(GetType(field));
} }
return TupleTypeNode::make(fields);
}
auto checked_body = this->Infer(f->body); Type VisitExpr_(const OpNode* op) final {
auto inferred_rtype = checked_body.type; return op->op_type;
auto annotated_rtype = Resolve(f->ret_type); }
auto unified_rtype = this->Unify(inferred_rtype, annotated_rtype, f->span);
CHECK(RelationsHold(true));
Array<TypeConstraint> cs;
for (auto cons : this->context.constraints.back()) { Type VisitExpr_(const LetNode* op) final {
cs.push_back(cons.ToTypeRel()); Type vtype = GetType(op->value);
if (op->value_type.defined()) {
vtype = Unify(vtype, op->value_type, op->span);
} }
CHECK(!type_map_.count(op->var));
return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), // NOTE: no scoping is necessary becase var are unique in program
FuncTypeNode::make(param_types, unified_rtype, {}, cs)}; type_map_[op->var] = vtype;
}); return GetType(op->body);
}
CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode* op) {
return this->VisitFunction(GetRef<Function>(op), false);
}
FuncType TypeInferencer::Instantiate(FuncType fn_ty,
tvm::Array<Type>& ty_args) {
tvm::Map<TypeParam, Type> subst_map;
// Build a subsitituion map up from the function type and type arguments.
// Eventually allow the type vars to be passed in.
for (auto ty_param : fn_ty->type_params) {
IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind);
this->unifier->Insert(fresh);
ty_args.push_back(fresh);
subst_map.Set(ty_param, fresh);
} }
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, Type VisitExpr_(const IfNode* op) final {
fn_ty->type_constraints); // Ensure the type of the guard is of Tensor[Bool, ()],
inst_ty = TypeSubst(inst_ty, subst_map); // that is a rank-0 boolean tensor.
Type cond_type = this->GetType(op->cond);
CHECK(KindCheck(this->env, inst_ty)); this->Unify(cond_type,
TensorTypeNode::Scalar(tvm::Bool()),
return GetRef<FuncType>(inst_ty.as<FuncTypeNode>()); op->cond->span);
} Type checked_true = this->GetType(op->true_branch);
Type checked_false = this->GetType(op->false_branch);
CheckedExpr TypeInferencer::VisitExpr_(const CallNode* op) { return this->Unify(checked_true, checked_false, op->span);
Call c = GetRef<Call>(op);
auto checked_op = this->Infer(c->op);
RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl
<< "fn_ty=" << checked_op.type << std::endl;
auto fn_ty_node = checked_op.type.as<FuncTypeNode>();
if (!fn_ty_node) {
this->FatalError("only expressions with function types can be called",
c->op->span);
} }
// We now have a function type. // Handle special case basic primitive operator,
FuncType fn_ty = GetRef<FuncType>(fn_ty_node); // if successful return the return type
Type PrimitiveCall(const FuncTypeNode* op,
tvm::Array<Type> ty_args; Array<Type> arg_types,
if (ty_args.size() != 0) { const Attrs& attrs) {
throw Error("found manually suplied type args, not supported"); if (op->type_params.size() != arg_types.size() + 1) return Type();
if (op->type_constraints.size() != 1) return Type();
const TypeRelationNode* rel = op->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return Type();
// validate if the type parameter matches up
for (size_t i = 0; i < op->type_params.size(); ++i) {
if (!op->type_params[i].same_as(rel->args[i])) return Type();
}
Type rtype = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
arg_types.push_back(rtype);
// we can do simple replacement here
solver_.AddConstraint(TypeRelationNode::make(
rel->func, arg_types, arg_types.size() - 1, attrs));
return rtype;
} }
fn_ty = Instantiate(fn_ty, ty_args); // instantiate the function type with fresh
FuncType Instantiate(const FuncTypeNode* fn_ty, Array<Type>* ty_args) {
tvm::Map<TypeParam, Type> subst_map;
std::vector<Type> arg_types; // Build a subsitituion map up from the function type and type arguments.
std::vector<Expr> checked_args; // Eventually allow the type vars to be passed in.
for (auto ty_param : fn_ty->type_params) {
for (auto arg : c->args) { IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind);
auto checked_arg = this->Infer(arg); subst_map.Set(ty_param, fresh);
arg_types.push_back(checked_arg.type); ty_args->push_back(fresh);
checked_args.push_back(checked_arg.expr); }
Type ret_type = fn_ty->ret_type;
// If the function type is incomplete, place a new IncompleteType
// This relax the fn_ty to inputs -> Any
// The type checking can still pass when there are additional constraints on the type
// This is a temporary work around to check recursive functions whose
// return type is not yet known.
if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
}
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {},
fn_ty->type_constraints);
inst_ty = TypeSubst(inst_ty, subst_map);
return Downcast<FuncType>(inst_ty);
} }
auto type_arity = fn_ty->arg_types.size(); // Handle general call node.
auto number_of_args = arg_types.size(); Type GeneralCall(const CallNode* op, Array<Type> arg_types) {
Type ftype = GetType(op->op);
auto* fn_ty_node = ftype.as<FuncTypeNode>();
CHECK(fn_ty_node != nullptr)
<< "only expressions with function types can be called, at "
<< op->span;
Array<Type> type_args;
FuncType fn_ty = Instantiate(fn_ty_node, &type_args);
size_t type_arity = fn_ty->arg_types.size();
size_t number_of_args = arg_types.size();
if (type_arity != number_of_args) {
if (type_arity < number_of_args) {
LOG(FATAL) << "the function is provided too many arguments " << op->span;
} else {
LOG(FATAL) << "the function is provided too few arguments" << op->span;
}
}
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span);
}
if (type_arity != number_of_args) { for (auto cs : fn_ty->type_constraints) {
if (type_arity < number_of_args) { solver_.AddConstraint(cs);
this->FatalError("the function is provided too many arguments", c->span);
} else {
this->FatalError("the function is provided too few arguments", c->span);
} }
return fn_ty->ret_type;
} }
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { Type VisitExpr_(const CallNode* op) final {
this->Unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); // Fast path: well-formed primitive op
Array<Type> arg_types;
for (Expr arg : op->args) {
arg_types.push_back(GetType(arg));
}
if (const OpNode* opnode = op->op.as<OpNode>()) {
Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
arg_types,
op->attrs);
if (rtype.defined()) return rtype;
}
return GeneralCall(op, arg_types);
} }
// After we unify the arguments we should know more about the type Type VisitExpr_(const FunctionNode* f) final {
// arguments, let's run a quick pass over them to find new for (auto param : f->params) {
// representatives. type_map_[param->var] = param->type;
type_map_[param] = param->type;
for (size_t i = 0; i < ty_args.size(); i++) { }
ty_args.Set(i, this->unifier->Subst(ty_args[i])); Type rtype = GetType(f->body);
// Run solver using the currently known information
solver_.Solve();
// Trying to resolve
Array<Type> arg_types;
for (size_t i = 0; i < f->params.size(); ++i) {
Param param = f->params[i];
Type atype = solver_.Resolve(param->type);
CHECK(atype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << i
<< "-th parameter of function at" << f->span;
arg_types.push_back(atype);
}
rtype = solver_.Resolve(rtype);
CHECK(rtype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve return type of function at" << f->span;
// do not support constraint lifting for now.
return FuncTypeNode::make(arg_types, rtype, f->type_params, {});
} }
};
// Add type constraints from the function types. class TypeInferencer::Resolver : public ExprMutator {
for (auto cs : fn_ty->type_constraints) { public:
context.AddConstraint(cs); Resolver(const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap,
TypeSolver* solver)
: tmap_(tmap), solver_(solver) {
} }
auto new_call = Expr VisitExpr_(const VarNode* op) final {
CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); return AttachCheckedType(op);
return {new_call, fn_ty->ret_type};
}
CheckedExpr TypeInferencer::VisitExpr_(const LetNode* op) {
Let let = GetRef<Let>(op);
CheckedExpr checked_value;
Type annotated_ty = Resolve(let->value_type);
// If we are let-defining a function, we want to be able to
// recursively name the function in order to support recursive
// local definitions.
if (let->value.as<FunctionNode>()) {
context.Insert(let->var, annotated_ty);
checked_value = Infer(let->value);
} else {
checked_value = Infer(let->value);
} }
Type unified_ty = this->Unify(checked_value.type, annotated_ty, let->span); Expr VisitExpr_(const ConstantNode* op) final {
return AttachCheckedType(op);
// Update type context with unified type now that we have
// solved this equation.
context.Insert(let->var, unified_ty);
auto checked_body = Infer(let->body);
auto checked_let = LetNode::make(let->var, checked_value.expr,
checked_body.expr, let->value_type);
return {checked_let, checked_body.type};
}
CheckedExpr TypeInferencer::VisitExpr_(const IfNode* op) {
If ifn = GetRef<If>(op);
// Ensure the type of the guard is of Tensor[Bool, ()],
// that is a rank-0 boolean tensor.
auto checked_cond = this->Infer(ifn->cond);
auto cond_type = checked_cond.type;
this->Unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()),
ifn->cond->span);
auto checked_true = this->Infer(ifn->true_branch);
auto checked_false = this->Infer(ifn->false_branch);
auto unified_type =
this->Unify(checked_true.type, checked_false.type, ifn->span);
auto checked_if =
IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr);
return {checked_if, unified_type};
}
CheckedExpr TypeInferencer::VisitExpr_(const OpNode* op_node) {
auto op = GetRef<Op>(op_node);
return {op, op->op_type};
}
Type TypeInferencer::Resolve(const Type &t) {
if (t.defined()) {
return ::tvm::relay::Resolve(this->unifier, t);
} else {
return IncompleteTypeNode::make(TypeParamNode::Kind::kType);
} }
}
Expr TypeInferencer::Resolve(const Expr &e) { Expr VisitExpr_(const GlobalVarNode* op) final {
CHECK(e.defined()); return AttachCheckedType(op);
return ::tvm::relay::Resolve(this->unifier, e); }
}
void TypeInferencer::Solve(TypeRelationData & ty_rel) { Expr VisitExpr_(const OpNode* op) final {
Array<Type> normalized_args; return ExprMutator::VisitExpr_(op);
}
for (auto arg : ty_rel.args) { Expr VisitExpr_(const TupleNode* op) final {
normalized_args.push_back(Resolve(arg)); return AttachCheckedType(op);
} }
auto new_args = ty_rel.func(normalized_args, ty_rel.args.size()); Expr VisitExpr_(const ParamNode* op) final {
return ExprMutator::VisitExpr_(op);
}
CHECK(new_args.size() == normalized_args.size()); Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<Type> final_args; return AttachCheckedType(op);
}
for (size_t i = 0; i < new_args.size(); i++) { Expr VisitExpr_(const CallNode* op) final {
ty_rel.args[i] = Unify(normalized_args[i], new_args[i], ty_rel.span); return AttachCheckedType(op);
} }
}
int NumSolvedVars(const Array<Type>& vars) { Expr VisitExpr_(const LetNode* op) final {
int num = 0; return AttachCheckedType(op);
for (auto var : vars) {
if (!var.as<IncompleteTypeNode>()) {
num += 1;
}
} }
return num;
}
enum SolverResult : int { Expr VisitExpr_(const IfNode* op) final {
Failed = -1, return AttachCheckedType(op);
Progress = 0, }
Done = 1,
};
SolverResult TypeInferencer::Solve(std::vector<TypeRelationData>& rels) { // attach checked type to the mutated node.
// We start in the done state with zero progress. template<typename T>
SolverResult status = SolverResult::Done; Expr AttachCheckedType(const T* op) {
int progress = 0; auto it = tmap_.find(GetRef<Expr>(op));
CHECK(it != tmap_.end());
do { Type checked_type = solver_->Resolve(it->second);
// Upon rentering the loop we reset the state. CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
status = SolverResult::Done; << "Cannot resolve type of " << GetRef<Expr>(op)
progress = 0; << " at " << op->span;
Expr new_e = ExprMutator::VisitExpr_(op);
std::vector<int> complete; if (!checked_type.same_as(new_e->checked_type_)) {
// Copy on write optimization
int i = 0; // If new_e is an old expression,
// We will now process each relation in order. // we make a copy mutating an existing reference.
for (TypeRelationData& ty_rel : rels) { if (!new_e.node_.unique()) {
int arity = ty_rel.args.size(); new_e = Expr(make_node<T>(*new_e.as<T>()));
int pre_solved = NumSolvedVars(ty_rel.args);
RELAY_LOG(INFO) << "TypeInferencer::Solve: "
<< "TypeRelation= "
<< ", Arity=" << arity << ", Solved=" << pre_solved
<< std::endl;
// If the relation is already solved then we will make no progress but try
// to set the status to done.
if (pre_solved == arity) {
status = static_cast<SolverResult>((status && SolverResult::Done));
complete.push_back(i);
// If there are unsolved variables we will try to solve some.
} else if (pre_solved < arity) {
Solve(ty_rel);
int post_solved = NumSolvedVars(ty_rel.args);
// If we solved any variables we will try to downgrade status to
// progress update the type relation, and then bump the progress counter
// by one.
if (post_solved > pre_solved) {
status =
static_cast<SolverResult>((status && SolverResult::Progress));
progress += 1;
}
} }
i++; new_e->checked_type_ = checked_type;
} }
return new_e;
// If we made no progress and we aren't finished, then the state should be
// downgraded to fail, then we should exit the loop.
if (progress == 0 && status != SolverResult::Done) {
status = SolverResult::Failed;
break;
}
// Remove the satisfied relations.
for (auto i : complete) {
if (rels.size() > 1) {
rels[i] = rels.back();
rels.pop_back();
} else {
rels.pop_back();
}
}
std::reverse(rels.begin(), rels.end());
} while (status == SolverResult::Progress);
return status;
}
bool TypeInferencer::RelationsHold(bool scope_only) {
// If we are only checking the top scope,
// slice out the constraints.
//
// Otherwise we use all of them.
std::vector<std::vector<TypeRelationData> > constraints;
if (scope_only) {
constraints = {context.constraints[0]};
} else {
constraints = context.constraints;
} }
RELAY_LOG(INFO) << "TypeInferencer::RelationsHold: scope_only= " << scope_only Type VisitType(const Type &t) final {
<< std::endl; return solver_->Resolve(t);
bool all_hold = true;
for (auto ty_rels : context.constraints) {
auto status = Solve(ty_rels);
RELAY_LOG(INFO) << "status= " << status << std::endl;
if (status == SolverResult::Failed || status == SolverResult::Progress) {
all_hold = false;
} else if (status == SolverResult::Done) {
continue;
} else {
throw InternalError("found invalid value for SolverResult");
}
} }
return all_hold; private:
const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_;
};
Expr TypeInferencer::Infer(Expr expr) {
// step 0: populate the constraints
GetType(expr);
// step 1: solve the constraints
solver_.Solve();
// step 2: attach resolved types to checked_type field
return Resolver(type_map_, &solver_).VisitExpr(expr);
} }
Expr InferType(const Environment& env, const Expr& e) { Expr InferType(const Environment& env, const Expr& e) {
TypeInferencer ti(env); return TypeInferencer(env).Infer(e);
auto checked_expr = ti.Infer(e);
CHECK(ti.RelationsHold());
return ti.Resolve(checked_expr.expr);
} }
Expr InferType(const Environment& env, const GlobalVar& var, Expr InferType(const Environment& env,
const GlobalVar& var,
const Function& func) { const Function& func) {
TypeInferencer ti(env); Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body, func_copy->checked_type_ = func_copy->fn_type();
func->type_params);
func_copy->checked_type_ = ti.Resolve(func_copy->fn_type());
env->functions.Set(var, func_copy); env->functions.Set(var, func_copy);
auto checked_expr = ti.Infer(func); Expr func_ret = TypeInferencer(env).Infer(func_copy);
CHECK(ti.RelationsHold());
auto map_node = env->functions.CopyOnWrite(); auto map_node = env->functions.CopyOnWrite();
map_node->data.erase(var.node_); map_node->data.erase(var.node_);
return ti.Resolve(checked_expr.expr); return func_ret;
}
void TypeInferencer::FatalError(const std::string& msg, Span sp) {
throw FatalTypeError(
"internal error: this exception should"
"be handled and errors reported with Environment::display_errors\n" +
msg);
}
Type TypeInferencer::Unify(const Type& t1, const Type& t2, Span sp) {
try {
return this->unifier->Unify(t1, t2);
} catch (const dmlc::Error &e) {
std::stringstream ss;
ss << "Error unifying `";
ss << t1;
ss << "` and `";
ss << t2;
ss << "`: " << e.what();
this->FatalError(ss.str(), sp);
}
} }
TVM_REGISTER_API("relay._ir_pass.check_expr") TVM_REGISTER_API("relay._ir_pass.check_expr")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Environment env = args[0]; Environment env = args[0];
Expr e = args[1]; Expr e = args[1];
*ret = InferType(env, e); *ret = InferType(env, e);
}); });
// TODO(@jroesch): put in a better namespace.
TVM_REGISTER_API("relay._ir_pass._get_checked_type")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expr e = args[0];
*ret = e->checked_type();
});
/* Incomplete Type */
IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
auto n = make_node<IncompleteTypeNode>();
n->kind = std::move(kind);
return IncompleteType(n);
}
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
*ret = IncompleteTypeNode::make(static_cast<TypeParamNode::Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const IncompleteTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file type_solver.cc
* \brief Type solver implementations.
*/
#include <string>
#include "type_solver.h"
namespace tvm {
namespace relay {
class TypeSolver::Reporter : public TypeReporterNode {
public:
explicit Reporter(TypeSolver* solver)
: solver_(solver) {}
void Assign(const Type& dst, const Type& src) final {
solver_->Unify(dst, src);
}
void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) final {
// TODO(tqchen)
}
private:
TypeSolver* solver_;
};
// constructor
TypeSolver::TypeSolver()
: reporter_(make_node<Reporter>(this)) {
}
// destructor
TypeSolver::~TypeSolver() {
// call destructor of all non-POD arena object
for (TypeNode* ptr : type_nodes_) {
ptr->~TypeNode();
}
for (RelationNode* ptr : rel_nodes_) {
ptr->~RelationNode();
}
}
// Add equality constraint
Type TypeSolver::Unify(const Type& dst, const Type& src) {
// Known limitation
// - handle composite types whose component can be unknown.
// - handle shape pattern matching
TypeNode* lhs = GetTypeNode(dst);
TypeNode* rhs = GetTypeNode(src);
if (lhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(lhs, rhs);
return rhs->resolved_type;
} else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(rhs, lhs);
return lhs->resolved_type;
} else {
lhs->parent = rhs;
CHECK(AlphaEqual(lhs->resolved_type, rhs->resolved_type))
<< "Incompatible parent types in UF:"
<< lhs->resolved_type << " and " << rhs->resolved_type;
return rhs->resolved_type;
}
}
// Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
if (auto *op = constraint.as<TypeRelationNode>()) {
// create a new relation node.
RelationNode* rnode = make<RelationNode>();
rnode->rel = GetRef<TypeRelation>(op);
rel_nodes_.push_back(rnode);
// populate the type information.
for (size_t i = 0; i < op->args.size(); ++i) {
// insert link to the type list
LinkNode<TypeNode*>* tlink = make<LinkNode<TypeNode*> >();
TypeNode* tnode = GetTypeNode(op->args[i]);
tlink->value = tnode;
rnode->type_list.Push(tlink);
// insert type->relation node
LinkNode<RelationNode*>* rlink = make<LinkNode<RelationNode*> >();
rlink->value = rnode;
tnode->rel_list.Push(rlink);
}
// add the relation to the working queue.
this->AddToQueue(rnode);
} else {
LOG(FATAL) << "Do not know how to handle constraint type"
<< constraint->type_key();
}
}
// Resolve a type in the solver context.
Type TypeSolver::Resolve(const Type& type) {
auto it = tmap_.find(type);
if (it != tmap_.end()) {
return it->second->FindRoot()->resolved_type;
} else {
return type;
}
}
bool TypeSolver::Solve() {
// update until queue is empty
while (!update_queue_.empty()) {
RelationNode* rnode = update_queue_.front();
const auto& rel = rnode->rel;
update_queue_.pop();
CHECK(!rnode->resolved);
// update the relation with given evidence.
Array<Type> args;
for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) {
args.push_back(tlink->value->FindRoot()->resolved_type);
CHECK_LE(args.size(), rel->args.size());
}
// call the function
bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);
// mark inqueue as false after the function call
// so that rnode itself won't get enqueued again.
rnode->inqueue = false;
if (resolved) {
++num_resolved_rels_;
}
rnode->resolved = resolved;
}
// This criterion is not necessarily right for all the possible cases
// TODO(tqchen): We should also count the number of in-complete types.
return num_resolved_rels_ == rel_nodes_.size();
}
// Expose type solver only for debugging purposes.
TVM_REGISTER_API("relay._ir_pass._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto solver = std::make_shared<TypeSolver>();
auto mod = [solver](std::string name) -> PackedFunc {
if (name == "Solve") {
return TypedPackedFunc<bool()>([solver]() {
return solver->Solve();
});
} else if (name == "Unify") {
return TypedPackedFunc<void(Type, Type)>([solver](Type lhs, Type rhs) {
solver->Unify(lhs, rhs);
});
} else if (name == "Resolve") {
return TypedPackedFunc<Type(Type)>([solver](Type t) {
return solver->Resolve(t);
});
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
return solver->AddConstraint(c);
});
} else {
return PackedFunc();
}
};
*ret = runtime::TypedPackedFunc<runtime::PackedFunc(std::string)>(mod);
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file type_solver.h
* \brief Solver logic for type inference.
*/
#ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_
#define TVM_RELAY_PASS_TYPE_SOLVER_H_
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
#include <vector>
#include <queue>
#include "../../common/arena.h"
namespace tvm {
namespace relay {
/*!
* \brief Interface of type solver used in type inference.
*
* TypeSolver works on a list of constraints among incomplete types.
* The user will populate the constraints by AddConstraint and Assign.
* Then we can call Solve to trying to resolve the unknown.
*
* This can be viewed as "type program(computational graph)" of types, where
* the type constraint are operators of the graph and the incomplete
* types are intermediate value of the graph.
* If all the input types are concretely known, we should be able to
* just run a forward pass on the "type program" to get all the types.
*
* The list of constraints representation means we are storing it as a bipartite
* graph instead of a DAG. This is because some constraints might go both direction.
* TypeSolver could take advantage of bidirectional constraints to deduce input
* value given output ones. Never-the-less, we should keep in mind that
* there is a "forward direction" that the TypeSolver should take advantage of.
*/
class TypeSolver {
public:
TypeSolver();
~TypeSolver();
/*!
* \brief Add a type constraint to the solver.
* \param constraint The constraint to be added.
*/
void AddConstraint(const TypeConstraint& constraint);
/*!
* \brief Resolve type to the solution type in the solver.
* \param type The type to be resolved.
* \return The resolved type.
*/
Type Resolve(const Type& type);
/*!
* \brief Start to solve the types using the current known information.
* \return Whether all the incomplete types has been fully resolved.
*/
bool Solve();
/*!
* \brief Unify lhs and rhs.
* \param lhs The left operand.
* \param rhs The right operand
*/
Type Unify(const Type& lhs, const Type& rhs);
private:
class Reporter;
struct TypeNode;
struct RelationNode;
// Internally the solver maintains a bipartite graph of Relation and Types.
// All the object in the structure is managed by a arena allocator
// which releases the memory upon distruction of the type solver.
/*!
* \brief Link list node
* \tparam T the content data type
*/
template<typename T>
struct LinkNode {
/*! \brief The content value */
T value;
/*! \brief pointer to the next location */
LinkNode<T>* next{nullptr};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
*/
template<typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
/*! \brief Tail pointer */
LinkNode<T>* tail{nullptr};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void Push(LinkNode<T>* node) {
node->next = nullptr;
if (this->tail != nullptr) {
this->tail->next = node;
this->tail = node;
} else {
head = tail = node;
}
}
};
/*!
* \brief type node struct
* TypeNode implements a union-find data structure(via parent)
* that can unifies the same types to the name resolved_type.
*
* It also contains collection of links to related Relations,
* which is stored in rel_list.
*/
struct TypeNode {
/*! \brief The final resolved type */
Type resolved_type;
/*! \brief type node in the union find algorithm */
TypeNode* parent{nullptr};
/*! \brief list of relations that is related to this type node */
LinkedList<RelationNode*> rel_list;
/*!
* \brief Find the root type node, perform path compression
* \return The root type node.
*/
TypeNode* FindRoot() {
// fast path
if (this->parent == nullptr) return this;
// slow path with path compression.
TypeNode* root = this;
while (root->parent != nullptr) {
root = root->parent;
}
for (TypeNode* p = this; p != root;) {
TypeNode* parent = p->parent;
p->parent = root;
p = parent;
}
return root;
}
};
/*! \brief relation node */
struct RelationNode {
/*! \brief Whether the relation is in the queue to be solved */
bool inqueue{false};
/*! \brief Whether the relation is resolved */
bool resolved{false};
/*! \brief The corresponding type relation */
TypeRelation rel;
/*! \brief list types to this relation */
LinkedList<TypeNode*> type_list;
};
/*! \brief List of all allocated type nodes */
std::vector<TypeNode*> type_nodes_;
/*! \brief List of all allocated relation nodes */
std::vector<RelationNode*> rel_nodes_;
/*! \brief Number of resolved relations */
size_t num_resolved_rels_{0};
/*! \brief map from type node to types. */
std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
/*! \breif Internal queue to update the relation */
std::queue<RelationNode*> update_queue_;
/*! \brief allocator of all the internal node obhect*/
common::Arena arena_;
/*! \brief Reporter that reports back to self */
TypeReporter reporter_;
/*!
* \brief Create function to create a new node ptr via arena
* \tparam The type parameter
* \return The node pointer.
*/
template<typename T>
T* make() {
T* ptr = arena_.Alloc<T>();
// call constructor
new (ptr) T();
return ptr;
}
/*!
* \brief GetTypeNode that is corresponds to t.
* if it do not exist, create a new one.
* \return The type node.
*/
TypeNode* GetTypeNode(const Type& t) {
auto it = tmap_.find(t);
if (it != tmap_.end()) {
return it->second->FindRoot();
} else {
TypeNode* n = make<TypeNode>();
type_nodes_.push_back(n);
n->resolved_type = t;
tmap_[t] = n;
return n;
}
}
/*!
* \brief Add relation node rel to the update queue
* \param rel The relation node
*/
void AddToQueue(RelationNode* rel) {
if (rel->inqueue) return;
CHECK(!rel->resolved);
rel->inqueue = true;
update_queue_.push(rel);
}
/*!
* \brief Merge rhs type node to lhs
* \param src The source operand
* \param dst The dst operand.
*/
void MergeFromTo(TypeNode* src, TypeNode* dst) {
if (src == dst) return;
src->parent = dst;
// move the link to the to dst
for (auto* rlink = src->rel_list.head; rlink != nullptr;) {
// store next pointer first before rlink get moved
auto* next = rlink->next;
// if the relation is not yet resolved
// send the relation to the new
if (!rlink->value->resolved) {
this->AddToQueue(rlink->value);
dst->rel_list.Push(rlink);
}
rlink = next;
}
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_SOLVER_H_
...@@ -108,11 +108,14 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> { ...@@ -108,11 +108,14 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
for (const Type& t : type_rel->args) { for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t)); new_args.push_back(this->VisitType(t));
} }
return TypeRelationNode::make(type_rel->name, type_rel->func_, new_args); return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
} }
Type VisitType_(const IncompleteTypeNode* op) override { Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<IncompleteType>(op); return GetRef<Type>(op);
} }
}; };
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/src/relay/pass/unifier.cc
* \brief The type unifier which solves a system of equations between
* incomplete types.
*/
#include "./unifier.h"
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/type.h>
#include "./type_subst.h"
#include "./type_visitor.h"
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
UnionFind UnionFindNode::make(tvm::Map<IncompleteType, Type> uf_map) {
auto n = make_node<UnionFindNode>();
n->uf_map = uf_map;
return UnionFind(n);
}
void UnionFindNode::Insert(const IncompleteType& v) { this->uf_map.Set(v, v); }
void UnionFindNode::debug() {
for (const auto& entry : this->uf_map) {
RELAY_LOG(INFO) << entry.first << " = " << entry.second << std::endl;
}
}
void UnionFindNode::AssertAlphaEqual(const Type& l, const Type& r) {
if (!AlphaEqual(l, r)) {
std::stringstream ss;
ss << "Incompatible parent types in UF:" << l << " and " << r;
throw UnionFindError(ss.str());
}
}
void UnionFindNode::Unify(const IncompleteType& v1, const Type& t) {
RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << ", t=" << t
<< std::endl;
auto parent1 = this->Find(v1);
// if t is a type var, then unify parents
const IncompleteTypeNode *tvn2 = t.as<IncompleteTypeNode>();
if (tvn2) {
auto v2 = GetRef<IncompleteType>(tvn2);
auto parent2 = this->Find(v2);
// if parents are exactly equal, then we're done
if (parent1 == parent2) {
return;
}
// if first parent is a type var, then can just set its union find map to
// second parent
if (const IncompleteTypeNode *pvn1 = parent1.as<IncompleteTypeNode>()) {
auto pv1 = GetRef<IncompleteType>(pvn1);
this->uf_map.Set(pv1, parent2);
return;
}
// if second parent is a type var but first isn't, can set second type var
if (const IncompleteTypeNode *pvn2 = parent2.as<IncompleteTypeNode>()) {
auto pv2 = GetRef<IncompleteType>(pvn2);
this->uf_map.Set(pv2, parent1);
return;
}
// if both parents are not type vars themselves, check alpha-equality
AssertAlphaEqual(parent1, parent2);
return;
}
// if t is not a type var, then unify with v1's parent if parent is a type
// var; else, check alpha-equality for compatibility
if (const IncompleteTypeNode *pvn1 = parent1.as<IncompleteTypeNode>()) {
auto pv1 = GetRef<IncompleteType>(pvn1);
this->uf_map.Set(pv1, t);
return;
}
AssertAlphaEqual(parent1, t);
}
Type UnionFindNode::Find(const IncompleteType& v) {
// The node has no mapping, so its representative is just itself.
if (this->uf_map.find(v) == this->uf_map.end()) {
return v;
}
Type parent = this->uf_map.at(v);
if (v == parent) {
return v;
}
// if parent is not a type var, then it must be the representative type
const IncompleteTypeNode *rep = parent.as<IncompleteTypeNode>();
if (!rep) {
return parent;
}
// otherwise, recurse and perform path compression
IncompleteType pv = GetRef<IncompleteType>(rep);
Type higher_up = this->Find(pv);
this->uf_map.Set(v, higher_up);
return higher_up;
}
TVM_REGISTER_API("relay._make.UnionFind")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() == 0) {
*ret = UnionFindNode::make({});
} else {
*ret = UnionFindNode::make(args[0]);
}
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<UnionFindNode>([](const UnionFindNode *node,
tvm::IRPrinter *p) {
p->stream << "UnionFindNode(" << node->uf_map << ")";
});
TypeUnifier TypeUnifierNode::make(UnionFind union_find) {
auto n = make_node<TypeUnifierNode>();
n->union_find = union_find;
return TypeUnifier(n);
}
void TypeUnifierNode::Insert(const IncompleteType& v) {
this->union_find->Insert(v);
}
Type TypeUnifierNode::Unify(const Type& t1, const Type& t2) {
RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2
<< std::endl;
Type unified = this->VisitType(t1, t2);
// TODO(@jroesch): Restore this code when we finish kind checker.
// if (!check_kind(unified)) {
// throw UnificationError("Invalid kinds in unified type");
// }
return unified;
}
struct IncompleteTypeSubst : TypeMutator {
const TypeUnifierNode *unifier;
IncompleteTypeSubst(const TypeUnifierNode *unifier) : unifier(unifier) {}
// type var: look it up in the type map and recurse
Type VisitType_(const IncompleteTypeNode* op) override {
auto tv = GetRef<IncompleteType>(op);
auto parent = unifier->union_find->Find(tv);
if (parent == tv) {
return tv;
}
return this->VisitType(parent);
}
};
Type TypeUnifierNode::Subst(const Type& t) {
IncompleteTypeSubst tvsubst(this);
// normalize first so substitutions in quantifiers will be correct
Type ret = tvsubst.VisitType(t);
// TODO(@jroesch): Restore this code when we finish kind checker.
// if (!check_kind(ret)) {
// std::stringstream ss;
// ss << "Invalid Kinds in substituted type!";
// ss << t << std::endl;
// ss << ret << std::endl;
// throw SubstitutionError(ss.str());
// }
return ret;
}
Type TypeUnifierNode::VisitType(const Type& t1, const Type t2) {
// When the right hand size is a type variable immediately unify.
if (const IncompleteTypeNode *tvn2 = t2.as<IncompleteTypeNode>()) {
return this->UnifyWithIncompleteType(t1, GetRef<IncompleteType>(tvn2));
} else {
return TypeFunctor<Type(const Type &t1, const Type t2)>::VisitType(t1, t2);
}
}
Type TypeUnifierNode::UnifyWithIncompleteType(const Type& t1,
const IncompleteType tv2) {
RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2
<< std::endl;
// Fix unify to return new representative
this->union_find->Unify(tv2, t1);
auto rep = this->union_find->Find(tv2);
RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl;
return rep;
}
Type TypeUnifierNode::VisitType_(const IncompleteTypeNode* t1, const Type rt2) {
IncompleteType tv1 = GetRef<IncompleteType>(t1);
RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2
<< std::endl;
this->union_find->Unify(tv1, rt2);
auto rep = this->union_find->Find(tv1);
RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode rep=" << rep << std::endl;
return rep;
}
Type TypeUnifierNode::VisitType_(const TypeParamNode* t1, const Type rt2) {
TypeParam ti1 = GetRef<TypeParam>(t1);
if (const TypeParamNode *tin2 = rt2.as<TypeParamNode>()) {
TypeParam ti2 = GetRef<TypeParam>(tin2);
if (ti1 != ti2) {
throw UnificationError("Attempting to unify non-matching TypeParams");
}
return ti1;
}
throw UnificationError("Unable to unify TypeParamNode");
}
Type TypeUnifierNode::VisitType_(const FuncTypeNode* t1, const Type rt2) {
FuncType ft1 = GetRef<FuncType>(t1);
if (const FuncTypeNode *tan2 = rt2.as<FuncTypeNode>()) {
FuncType ft2 = GetRef<FuncType>(tan2);
if (ft1->type_params.size() != ft2->type_params.size()) {
throw UnificationError(
"unable to unify functions with differing number of type parameters");
}
tvm::Map<TypeParam, Type> subst_map;
for (size_t i = 0; i < ft1->arg_types.size(); i++) {
subst_map.Set(ft1->type_params[i], ft2->type_params[i]);
}
ft1 = Downcast<FuncType>(TypeSubst(ft1, subst_map));
if (ft1->arg_types.size() != ft2->arg_types.size()) {
throw UnificationError("unable to unify functions of different arities");
}
tvm::Array<Type> unified_args;
for (size_t i = 0; i < ft1->arg_types.size(); i++) {
unified_args.push_back(
this->VisitType(ft1->arg_types[i], ft2->arg_types[i]));
}
Type unified_ret_type = this->VisitType(ft1->ret_type, ft2->ret_type);
return FuncTypeNode::make(unified_args, unified_ret_type, {}, {});
}
throw UnificationError("unable to unify function types");
}
Type TypeUnifierNode::VisitType_(const TensorTypeNode* t1, const Type rt2) {
TensorType tt1 = GetRef<TensorType>(t1);
if (const TensorTypeNode *ttn2 = rt2.as<TensorTypeNode>()) {
TensorType tt2 = GetRef<TensorType>(ttn2);
if (!AlphaEqual(tt1, tt2)) {
throw UnificationError("dtypes do not match");
}
RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape
<< " s2= " << tt2->shape << std::endl;
if (tt1->shape.size() != tt2->shape.size()) {
throw UnificationError("shapes are not of the same length");
}
for (size_t i = 0U; i < tt1->shape.size(); i++) {
if (!tt1->shape[i].same_as(tt2->shape[i])) {
throw UnificationError("shapes do not match at index");
}
}
return rt2;
}
throw UnificationError("Cannot unify TensorTypeNode");
}
Type TypeUnifierNode::VisitType_(const TupleTypeNode* t1, const Type rt2) {
TupleType pt1 = GetRef<TupleType>(t1);
if (const TupleTypeNode *ptn2 = rt2.as<TupleTypeNode>()) {
TupleType pt2 = GetRef<TupleType>(ptn2);
std::vector<Type> unified_fields;
if (pt1->fields.size() != pt2->fields.size()) {
throw UnificationError("Product types are of different dimensions");
}
for (size_t i = 0U; i < pt1->fields.size(); i++) {
Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]);
unified_fields.push_back(unified);
}
return TupleTypeNode::make(unified_fields);
}
throw UnificationError("Cannot unify TupleTypeNode");
}
Type TypeUnifierNode::VisitType_(const TypeRelationNode* tr1, const Type t2) {
throw InternalError("Cannot unify different type relations");
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file include/tvm/relay/pass/unifier.h
* \brief The type unifier which solves a system of equations between
* incomplete types.
*/
#ifndef TVM_RELAY_PASS_UNIFIER_H_
#define TVM_RELAY_PASS_UNIFIER_H_
#include <tvm/relay/expr.h>
#include <string>
#include "./type_functor.h"
namespace tvm {
namespace relay {
struct UnionFindError : dmlc::Error {
explicit UnionFindError(const std::string& msg) : Error(msg) {}
};
struct UnificationError : dmlc::Error {
explicit UnificationError(const std::string& msg) : Error(msg) {}
};
struct SubstitutionError : dmlc::Error {
explicit SubstitutionError(const std::string& msg) : Error(msg) {}
};
/*! \brief A union-find data structure for the type-checker */
class UnionFind;
class UnionFindNode : public Node {
public:
/*! \brief The inernal map from incomplete types to their representatives. */
tvm::Map<IncompleteType, Type> uf_map;
UnionFindNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf_map", &uf_map); }
TVM_DLL static UnionFind make(tvm::Map<IncompleteType, Type> uf_map);
/*! \brief Insert it into the union find.
* \param it The type to add to the union find.
*/
void Insert(const IncompleteType& it);
/*! \brief Union operation, combine two equivalence classes.
* \param it The incomplete type to unify.
* \param ty The other type.
*/
void Unify(const IncompleteType& it, const Type& t);
/*! \brief Find operation, returns the representative of the argument.
* \param it The element to lookup.
*/
Type Find(const IncompleteType& it);
void debug();
void AssertAlphaEqual(const Type& l, const Type& r);
static constexpr const char* _type_key = "relay.UnionFind";
TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node);
};
class UnionFind : public NodeRef {
public:
UnionFind() {}
explicit UnionFind(NodePtr<tvm::Node> p) : NodeRef(p) {}
// The union find structure is mutable so we do not use the standard macros
// and expose the pointer via `->`.
UnionFindNode* operator->() const {
return static_cast<UnionFindNode*>(node_.get());
}
using ContainerType = UnionFindNode;
};
class TypeUnifier;
class TypeUnifierNode : public Node,
private TypeFunctor<Type(const Type&, const Type)> {
public:
UnionFind union_find;
TypeUnifierNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("union_find", &union_find); }
TVM_DLL static TypeUnifier make(UnionFind uf);
/*! \brief Introduces a new type var into the unifier */
void Insert(const IncompleteType& v);
/*! \brief Unifies two types if possible, throws a unification error if it
* cannot */
Type Unify(const Type& t1, const Type& t2);
/*! \brief Attempts to substitute all type vars in t with concrete types,
* throws substitution error if it cannot concretize*/
Type Subst(const Type& t);
// /*! \brief Checks the kinds in the given type */
// Type CheckKinds(const Type& t);
static constexpr const char* _type_key = "relay.TypeUnifier";
TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node);
private:
/*! \brief Unify incomplete type with another type. */
Type UnifyWithIncompleteType(const Type& t1, const IncompleteType tvn2);
/*! \brief Implements unification between two types with incomplete portions.
*/
Type VisitType(const Type& t1, const Type t2) override;
// Visitor Cases
Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override;
Type VisitType_(const TensorTypeNode* t1, const Type t2) override;
Type VisitType_(const TypeParamNode* t1, const Type t2) override;
Type VisitType_(const FuncTypeNode* t1, const Type t2) override;
Type VisitType_(const TupleTypeNode* t1, const Type t2) override;
Type VisitType_(const TypeRelationNode* s1, const Type t2) override;
};
class TypeUnifier : public NodeRef {
public:
TypeUnifier() {}
explicit TypeUnifier(NodePtr<tvm::Node> p) : NodeRef(p) {}
// no const so that unifier can be mutable as a member of typechecker
inline TypeUnifierNode* operator->() const {
return static_cast<TypeUnifierNode*>(node_.get());
}
using ContainerType = TypeUnifierNode;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_UNIFIER_H_
...@@ -27,8 +27,8 @@ def test_tensor_type(): ...@@ -27,8 +27,8 @@ def test_tensor_type():
def test_type_param(): def test_type_param():
tp = relay.TypeParam('name', relay.Kind.Shape) tp = relay.TypeParam('name', relay.Kind.Type)
assert tp.kind == relay.Kind.Shape assert tp.kind == relay.Kind.Type
# assert tp.span # TODO allow us to set span # assert tp.span # TODO allow us to set span
str(tp) str(tp)
......
...@@ -76,10 +76,10 @@ def test_add_broadcast_op(): ...@@ -76,10 +76,10 @@ def test_add_broadcast_op():
assert_has_type(func.to_func(), expected_ty) assert_has_type(func.to_func(), expected_ty)
def test_dual_op(): def test_dual_op():
"""Program: """Program:
fn (x : Tensor[f32, (10, 10)]) { fn (x : Tensor[f32, (10, 10)]) {
let t1 = log(x); let t1 = log(x);
let t2 = add(t1, x); let t2 = add(t1, x);
return t1; return t1;
} }
""" """
...@@ -93,8 +93,8 @@ def test_dual_op(): ...@@ -93,8 +93,8 @@ def test_dual_op():
def test_decl(): def test_decl():
"""Program: """Program:
def f(x : Tensor[f32, (10, 10)]) { def f(x : Tensor[f32, (10, 10)]) {
let lx = log(x); let lx = log(x);
return lx; return lx;
} }
...@@ -125,7 +125,7 @@ def test_recursion(): ...@@ -125,7 +125,7 @@ def test_recursion():
n = b.param('n', ty='int32') n = b.param('n', ty='int32')
data = b.param('data', ty='float32') data = b.param('data', ty='float32')
with b.decl(f, n, data): with b.decl(f, n, data):
with b.if_scope(equal(n, convert(0.0))): with b.if_scope(equal(n, convert(0))):
b.ret(f(subtract(n, convert(1)), log(data))) b.ret(f(subtract(n, convert(1)), log(data)))
with b.else_scope(): with b.else_scope():
b.ret(data) b.ret(data)
...@@ -152,11 +152,12 @@ def test_concat(): ...@@ -152,11 +152,12 @@ def test_concat():
assert_decl_has_type(ib.env, try_concat2, fn_ty) assert_decl_has_type(ib.env, try_concat2, fn_ty)
if __name__ == "__main__": if __name__ == "__main__":
# test_monomorphic_let() test_recursion()
# test_single_op()
# test_add_op() test_monomorphic_let()
# test_add_broadcast_op() test_single_op()
# test_dual_op() test_add_op()
# test_decl() test_add_broadcast_op()
# test_recursion() test_dual_op()
test_decl()
test_concat() test_concat()
import tvm
from tvm import relay
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
def make_rel(name, args, num_inputs=None, attrs=None):
func = tvm.get_env_func("tvm.relay.type_relation." + name)
if num_inputs is None:
num_inputs = len(args) - 1
return relay.ty.TypeRelation(func, args, num_inputs, attrs)
def make_solver():
solver = relay._ir_pass._test_type_solver()
solver.Solve = solver("Solve")
solver.Unify = solver("Unify")
solver.Resolve = solver("Resolve")
solver.AddConstraint = solver("AddConstraint")
def gen_type(name, args, out=None):
out = out if out else relay.ty.IncompleteType()
solver.AddConstraint(make_rel(name, args + [out]))
return out
solver.gen_type = gen_type
return solver
def test_bcast():
solver = make_solver()
t0 = relay.ty.TensorType((10, 20), "float32")
t1 = relay.ty.TensorType((10, 1), "float32")
tc = relay.ty.TensorType((10, 1, 1), "float32")
t2 = solver.gen_type("Broadcast", [t0, t1])
t3 = solver.gen_type("Identity", [t2])
t4 = solver.gen_type("Broadcast", [t3, tc])
assert solver.Solve()
assert solver.Resolve(t2) == relay.ty.TensorType((10, 20), "float32")
assert solver.Resolve(t4) == relay.ty.TensorType((10, 10, 20), "float32")
def test_backward_solving():
solver = make_solver()
t0 = relay.ty.TensorType((10, 20), "float32")
tc = relay.ty.TensorType((10, 1, 1), "float32")
t1 = relay.ty.IncompleteType()
t3 = solver.gen_type("Broadcast", [t0, t1])
t2 = solver.gen_type("Identity", [t1], out=tc)
assert solver.Solve()
assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32")
if __name__ == "__main__":
test_bcast()
test_backward_solving()
*.pb
*.mlmodel
*.ttf
*.txt
*synset*txt
*.cfg
ssd_model
*.names
*.jpg
*.pbtxt
*.weights
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment