Unverified Commit 9319b6f9 by Tianqi Chen Committed by GitHub

[RELAY] Refactor type inference to use type solver (#1779)

parent 6d037db7
...@@ -190,6 +190,13 @@ build* ...@@ -190,6 +190,13 @@ build*
# Jetbrain # Jetbrain
.idea .idea
.ipython
.jupyter
.nv
.pylint.d
.python_history
.pytest_cache
.local
# tmp file # tmp file
.nfs* .nfs*
...@@ -13,10 +13,9 @@ ...@@ -13,10 +13,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../attrs.h" #include "base.h"
#include "./base.h" #include "expr.h"
#include "./expr.h" #include "type.h"
#include "./type.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -45,7 +44,7 @@ class OpNode : public relay::ExprNode { ...@@ -45,7 +44,7 @@ class OpNode : public relay::ExprNode {
Array<AttrFieldInfo> arguments; Array<AttrFieldInfo> arguments;
/*! /*!
* \brief The type key of the attribute field * \brief The type key of the attribute field
* This can be empty, in which case it defaults to * This can be empty, in which case it defaults to anything.
*/ */
std::string attrs_type_key; std::string attrs_type_key;
/*! /*!
...@@ -156,11 +155,13 @@ class OpRegistry { ...@@ -156,11 +155,13 @@ class OpRegistry {
*/ */
inline OpRegistry& add_type_rel( inline OpRegistry& add_type_rel(
const std::string& rel_name, const std::string& rel_name,
std::function<Array<Type>(const Array<Type>&, int)> type_rel_func); runtime::TypedPackedFunc<bool(const Array<Type>&,
int,
const Attrs&,
const TypeReporter&)> type_rel_func);
/*! /*!
* \brief Set the type key of attributes. * \brief Set the type key of attributes.
* \param type_key The type of of the attrs field.x * \param type_key The type of of the attrs field.
* \return reference to self. * \return reference to self.
*/ */
inline OpRegistry& set_attrs_type_key(const std::string& type_key); inline OpRegistry& set_attrs_type_key(const std::string& type_key);
...@@ -348,23 +349,25 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, ...@@ -348,23 +349,25 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name,
inline OpRegistry& OpRegistry::add_type_rel( inline OpRegistry& OpRegistry::add_type_rel(
const std::string& rel_name, const std::string& rel_name,
std::function<Array<Type>(const Array<Type>&, int)> type_rel_func) { runtime::TypedPackedFunc<bool(const Array<Type>&,
int,
const Attrs&,
const TypeReporter&)> type_rel_func) {
auto func_name = std::string("tvm.relay.type_relation.") + rel_name; auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
TypeRelationFn env_type_rel_func;
TypedEnvFunc<Array<Type>(const Array<Type>&, int)> env_type_rel_func;
if (runtime::Registry::Get(func_name)) { if (runtime::Registry::Get(func_name)) {
auto env_func = EnvFunc::Get(func_name); auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func; env_type_rel_func = env_func;
} else { } else {
runtime::Registry::Register(func_name) runtime::Registry::Register(func_name)
.set_body_typed<Array<Type>(const Array<Type>&, int)>(type_rel_func); .set_body(type_rel_func.packed());
auto env_func = EnvFunc::Get(func_name); auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func; env_type_rel_func = env_func;
} }
std::vector<TypeParam> type_params; Array<TypeParam> type_params;
std::vector<Type> arg_types; Array<Type> arg_types;
// Add inputs. // Add inputs.
std::string input_name_prefix = "in"; std::string input_name_prefix = "in";
...@@ -375,15 +378,27 @@ inline OpRegistry& OpRegistry::add_type_rel( ...@@ -375,15 +378,27 @@ inline OpRegistry& OpRegistry::add_type_rel(
arg_types.push_back(param); arg_types.push_back(param);
} }
auto ty_call_args = Array<Type>(arg_types); Array<Type> ty_call_args = arg_types;
// Add output type. // Add output type.
auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType); auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
type_params.push_back(out_param); type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param); ty_call_args.push_back(out_param);
// The attributes of primitive op is nullptr
//
// The attributes of primitive operator can vary at the call site.
// The type of sum is also dependent on Attrs being passed.
// So puting nullptr in the Attrs means that the operator is polymorphic on Attrs.
//
// A common example is sum(x, axis), where the choice of axis
// can affect the type of the function.
TypeConstraint type_rel = TypeConstraint type_rel =
TypeRelationNode::make(rel_name, env_type_rel_func, ty_call_args); TypeRelationNode::make(env_type_rel_func,
ty_call_args,
arg_types.size(),
Attrs());
auto func_type = auto func_type =
FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
......
...@@ -26,7 +26,7 @@ namespace relay { ...@@ -26,7 +26,7 @@ namespace relay {
* \return A type checked expression with its checked_type field populated. * \return A type checked expression with its checked_type field populated.
*/ */
Expr InferType(const Environment& env, const Expr& e); Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& v, const Function& e); Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
/*! /*!
* \brief Check that types are well formed by applying "kinding rules". * \brief Check that types are well formed by applying "kinding rules".
...@@ -69,7 +69,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2); ...@@ -69,7 +69,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2);
* *
* For example: `forall s, Tensor[f32, s]` is equal to * For example: `forall s, Tensor[f32, s]` is equal to
* `forall w, Tensor[f32, w]`. * `forall w, Tensor[f32, w]`.
* *
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details. * for more details.
* *
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
#include <tvm/node/node.h> #include <tvm/node/node.h>
#include <string> #include <string>
#include "./base.h" #include "base.h"
#include "../attrs.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -116,10 +117,10 @@ class TypeParamNode : public TypeNode { ...@@ -116,10 +117,10 @@ class TypeParamNode : public TypeNode {
/*! \brief possible kinds of TypeParam */ /*! \brief possible kinds of TypeParam */
enum Kind : int { enum Kind : int {
/*! \brief template variable in shape expression */ /*! \brief template variable in shape expression */
kShapeVar = 0, kType = 0,
kShape = 1, kShapeVar = 1,
kBaseType = 2, kBaseType = 2,
kType = 3 kShape = 3
}; };
/*! /*!
* \brief The variable itself is only meaningful when * \brief The variable itself is only meaningful when
...@@ -144,6 +145,33 @@ class TypeParamNode : public TypeNode { ...@@ -144,6 +145,33 @@ class TypeParamNode : public TypeNode {
RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type); RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
/*! /*!
* \brief IncompleteType.
* This is intermediate values that is used during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeParam represents the input to the graph.
*/
class IncompleteType;
/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeParamNode::Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("kind", &kind);
}
TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
/*!
* \brief Potential Constraints in the type. * \brief Potential Constraints in the type.
* \note This is reserved for future use. * \note This is reserved for future use.
*/ */
...@@ -190,7 +218,8 @@ class FuncTypeNode : public TypeNode { ...@@ -190,7 +218,8 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
TVM_DLL static FuncType make(tvm::Array<Type> arg_types, Type ret_type, TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params, tvm::Array<TypeParam> type_params,
tvm::Array<TypeConstraint> type_constraints); tvm::Array<TypeConstraint> type_constraints);
...@@ -200,11 +229,102 @@ class FuncTypeNode : public TypeNode { ...@@ -200,11 +229,102 @@ class FuncTypeNode : public TypeNode {
RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
/*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TypeTuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
class TypeReporter;
/*!
* \brief reporter that reports back to the
* type resolution information.
*/
class TypeReporterNode : public Node {
public:
/*!
* \brief Create a type equality constraint.
*
* The "assign direction" acts as a hint to the solver
* showing that it is more likely to resolve dst by src.
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief assert shape expression equals each other.
* \param lhs The left operand.
* \param rhs The right operand.
*/
TVM_DLL virtual void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
};
/*!
* \brief Container class of TypeReporter.
* \sa TypeReporterNode
*/
class TypeReporter : public NodeRef {
public:
TypeReporter() {}
explicit TypeReporter(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {
}
TypeReporterNode* operator->() const {
return static_cast<TypeReporterNode*>(node_.get());
}
using ContainerType = TypeReporterNode;
};
/*!
* \brief User defined type constraint function.
*
* If the input type information can be used to fully decide
* the IncompleteTypes, then the function should call
* reporter.Assign to report the new types, and return true.
* Otherwise, the function should return false.
*
* \param args The arguments to the relation.
* The types are stored in the form of
* [input_type_0, input_type_1, ... input_type_n,
* output_type_0, output_type_1, ... output_type_m]
*
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved.
* true if this relation has been resolved.
*/
using TypeRelationFn = using TypeRelationFn =
TypedEnvFunc<Array<Type>(const Array<Type>&, int)>; TypedEnvFunc<bool(const Array<Type>& args,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter)>;
/*! /*!
* \brief Opaque type relation, is an input-output relation on types. * \brief User defined type relation, is an input-output relation on types.
*/ */
class TypeRelation; class TypeRelation;
/*! /*!
...@@ -214,24 +334,30 @@ class TypeRelation; ...@@ -214,24 +334,30 @@ class TypeRelation;
*/ */
class TypeRelationNode : public TypeConstraintNode { class TypeRelationNode : public TypeConstraintNode {
public: public:
/*! \brief The name of the function */
std::string name;
/*! /*!
* \brief The function on input and output variables which * \brief The function on input and output variables which
* this is not directly serializable, * this is not directly serializable,
* need to be looked-up in the environment. * need to be looked-up in the environment.
*/ */
TypeRelationFn func_; TypeRelationFn func;
/*! \brief The type arguments to the type function. */ /*! \brief The type arguments to the type function. */
tvm::Array<Type> args; tvm::Array<Type> args;
/*! \brief Number of inputs arguments */
int num_inputs;
/*! \brief Attributes to the relation function */
Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
v->Visit("attrs", &attrs);
} }
TVM_DLL static TypeRelation make(std::string name, TypeRelationFn func_, Array<Type> args); TVM_DLL static TypeRelation make(TypeRelationFn func,
Array<Type> args,
int num_args,
Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation"; static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode); TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
...@@ -239,30 +365,6 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -239,30 +365,6 @@ class TypeRelationNode : public TypeConstraintNode {
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
/*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TypeTuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
// The following fields contains advanced typing // The following fields contains advanced typing
// Only keep the class name and reserved for future usage. // Only keep the class name and reserved for future usage.
class GenericTensorType; class GenericTensorType;
......
...@@ -175,7 +175,17 @@ class TypedPackedFunc<R(Args...)> { ...@@ -175,7 +175,17 @@ class TypedPackedFunc<R(Args...)> {
* *
* \param packed The packed function * \param packed The packed function
*/ */
inline explicit TypedPackedFunc(PackedFunc packed); inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
/*!
* \brief constructor from TVMRetValue
* \param value The TVMRetValue
*/
inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*)
/*!
* \brief constructor from TVMArgValue
* \param value The TVMArgValue
*/
inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*)
/*! /*!
* \brief construct from a lambda function with the same signature. * \brief construct from a lambda function with the same signature.
* *
...@@ -196,7 +206,7 @@ class TypedPackedFunc<R(Args...)> { ...@@ -196,7 +206,7 @@ class TypedPackedFunc<R(Args...)> {
std::is_convertible<FLambda, std::is_convertible<FLambda,
std::function<R(Args...)> std::function<R(Args...)>
>::value>::type> >::value>::type>
explicit TypedPackedFunc(const FLambda& typed_lambda) { TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda); this->AssignTypedLambda(typed_lambda);
} }
/*! /*!
...@@ -1144,6 +1154,14 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) ...@@ -1144,6 +1154,14 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
: packed_(packed) {} : packed_(packed) {}
template<typename R, typename ...Args> template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMRetValue& value)
: packed_(value.operator PackedFunc()) {}
template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value)
: packed_(value.operator PackedFunc()) {}
template<typename R, typename ...Args>
template<typename FType> template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) { inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
......
...@@ -2,16 +2,18 @@ ...@@ -2,16 +2,18 @@
"""The expression nodes of Relay.""" """The expression nodes of Relay."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import NodeBase, register_relay_node from .base import NodeBase, register_relay_node
from ._ir_pass import _get_checked_type
from . import _make from . import _make
from .. import convert from .. import convert
class Expr(NodeBase): class Expr(NodeBase):
"""The base type for all Relay expressions.""" """The base type for all Relay expressions."""
def checked_type(self): def checked_type(self):
return _get_checked_type(self) ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
" the checked_type for this node")
return ret
def __call__(self, *args): def __call__(self, *args):
converted_args = [] converted_args = []
......
...@@ -52,11 +52,10 @@ class Kind(IntEnum): ...@@ -52,11 +52,10 @@ class Kind(IntEnum):
with. For example one's of kind BaseType can only be `float32`, `int32`, with. For example one's of kind BaseType can only be `float32`, `int32`,
and so on. and so on.
""" """
ShapeVar = 0 Type = 0
Shape = 1 ShapeVar = 1
BaseType = 2 BaseType = 2
Type = 3 Shape = 3
@register_relay_node @register_relay_node
class TypeParam(Type): class TypeParam(Type):
...@@ -68,7 +67,7 @@ class TypeParam(Type): ...@@ -68,7 +67,7 @@ class TypeParam(Type):
functions which are generic over types. functions which are generic over types.
""" """
def __init__(self, var, kind): def __init__(self, var, kind=Kind.Type):
"""Construct a TypeParam. """Construct a TypeParam.
Parameters Parameters
...@@ -76,7 +75,7 @@ class TypeParam(Type): ...@@ -76,7 +75,7 @@ class TypeParam(Type):
var: tvm.expr.Var var: tvm.expr.Var
The tvm.Var which backs the type parameter. The tvm.Var which backs the type parameter.
kind: Kind kind: Kind, optional
The kind of the type parameter. The kind of the type parameter.
Returns Returns
...@@ -130,8 +129,7 @@ class FuncType(Type): ...@@ -130,8 +129,7 @@ class FuncType(Type):
arg_types, arg_types,
ret_type, ret_type,
type_params, type_params,
type_constraints type_constraints):
):
"""Construct a function type. """Construct a function type.
Parameters Parameters
...@@ -153,6 +151,29 @@ class FuncType(Type): ...@@ -153,6 +151,29 @@ class FuncType(Type):
@register_relay_node @register_relay_node
class IncompleteType(Type): class IncompleteType(Type):
"""An incomplete type.""" """An incomplete type."""
def __init__(self, kind=Kind.Type):
def __init__(self, kind):
self.__init_handle_by_constructor__(_make.IncompleteType, kind) self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)
/*!
* Copyright 2018 by Contributors
*
* \file arena.h
* \brief Arena allocator that allocates
* memory chunks and frees them all during destruction time.
*/
#ifndef TVM_COMMON_ARENA_H_
#define TVM_COMMON_ARENA_H_
#include <type_traits>
namespace tvm {
namespace common {
const constexpr int kArenaPageSize = 16 << 10;
/*!
* \brief Arena allocator that allocates memory from continuous
* chunk and frees them all only during destruction.
*/
class Arena {
public:
Arena() {
// eagerly allocate the first page.
head_ = reinterpret_cast<PageHeader*>(new Page());
head_->next = nullptr;
head_->ptr = sizeof(PageHeader);
}
~Arena() {
// delete all the allocated pages.
while (head_ != nullptr) {
Page* page = reinterpret_cast<Page*>(head_);
head_ = head_->next;
delete page;
}
}
/*!
* \brief Allocate a space from Arena for type T
* \param T the data type to be allocated
*/
template<typename T>
T* Alloc() {
return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
}
private:
// page size 16 KB
// The page data type;
using Page = std::aligned_storage<kArenaPageSize, 1024>::type;
/*! \brief Page header */
struct PageHeader {
/*! \brief points to the next page */
PageHeader* next;
/*! \brief memory allocator ptr inside page */
size_t ptr;
};
/* \brief The page header */
PageHeader* head_{nullptr};
/*!
* \brief Align ptr by upper bound.
* \param ptr The pointer value.
* \param align The alignment requirement.
*/
size_t UpperAlign(size_t ptr, size_t align) {
return ptr + (align - (ptr % align)) % align;
}
/*!
* \brief Internal aligned alloc function.
* \param size The size of the memory.
* \param align The alignment requirement.
*/
void* Alloc(size_t size, size_t align) {
size_t ptr = UpperAlign(head_->ptr, align);
if (ptr + size <= kArenaPageSize) {
head_->ptr = ptr + size;
return reinterpret_cast<char*>(head_) + ptr;
} else {
PageHeader* new_head = reinterpret_cast<PageHeader*>(new Page());
new_head->next = head_;
ptr = UpperAlign(sizeof(PageHeader), align);
CHECK_LE(ptr + size, kArenaPageSize);
new_head->ptr = ptr + size;
head_ = new_head;
return reinterpret_cast<char*>(head_) + ptr;
}
}
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_ARENA_H_
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include <tvm/relay/environment.h> #include <tvm/relay/environment.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <sstream> #include <sstream>
#include "./../pass/resolve.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -49,7 +48,7 @@ void EnvironmentNode::Add(const GlobalVar &var, ...@@ -49,7 +48,7 @@ void EnvironmentNode::Add(const GlobalVar &var,
auto checked_func = GetRef<Function>(func_node); auto checked_func = GetRef<Function>(func_node);
auto type = checked_func->checked_type(); auto type = checked_func->checked_type();
CHECK(IsFullyResolved(type)); CHECK(type.as<IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) { if (functions.find(var) != functions.end()) {
if (!update) { if (!update) {
...@@ -68,7 +67,7 @@ void EnvironmentNode::Add(const GlobalVar &var, ...@@ -68,7 +67,7 @@ void EnvironmentNode::Add(const GlobalVar &var,
this->functions.Set(var, checked_func); this->functions.Set(var, checked_func);
} }
} else { } else {
throw Error("internal error: unknown item type, unreachable code"); LOG(FATAL) << "internal error: unknown item type, unreachable code";
} }
} }
......
...@@ -55,7 +55,27 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -55,7 +55,27 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->kind << ")"; << node->kind << ")";
}); });
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, Type ret_type, IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
auto n = make_node<IncompleteTypeNode>();
n->kind = std::move(kind);
return IncompleteType(n);
}
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
*ret = IncompleteTypeNode::make(static_cast<TypeParamNode::Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>(
[](const IncompleteTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params, tvm::Array<TypeParam> type_params,
tvm::Array<TypeConstraint> type_constraints) { tvm::Array<TypeConstraint> type_constraints) {
NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>(); NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
...@@ -79,24 +99,28 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -79,24 +99,28 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->type_constraints << ")"; << node->type_constraints << ")";
}); });
TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array<Type> args) { TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>(); NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>();
n->name = std::move(name); n->func = std::move(func);
n->func_ = std::move(func);
n->args = std::move(args); n->args = std::move(args);
n->num_inputs = num_inputs;
n->attrs = std::move(attrs);
return TypeRelation(n); return TypeRelation(n);
} }
TVM_REGISTER_API("relay._make.TypeRelation") TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2]); *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, .set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, tvm::IRPrinter *p) {
tvm::IRPrinter *p) { p->stream << "TypeRelationNode("
p->stream << "TypeRelationNode(" << node->name << ", " << node->args << node->func->name
<< ")"; << ", " << node->args << ")";
}); });
TupleType TupleTypeNode::make(Array<Type> fields) { TupleType TupleTypeNode::make(Array<Type> fields) {
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include <tvm/relay/logging.h> #include <tvm/relay/logging.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <numeric> #include <numeric>
#include "../pass/incomplete_type.h"
#include "./type_relations.h" #include "./type_relations.h"
namespace tvm { namespace tvm {
...@@ -30,18 +29,19 @@ int ToInt(const tvm::Expr& e) { ...@@ -30,18 +29,19 @@ int ToInt(const tvm::Expr& e) {
return imm->value; return imm->value;
} }
Array<Type> IdentityRel(const Array<Type>& types, int num_args) { bool IdentityRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2); int num_inputs,
auto t1 = ToTensorType(types[0]); const Attrs& attrs,
if (t1 && types[1].as<IncompleteTypeNode>()) { const TypeReporter& reporter) {
return {t1, t1}; for (size_t i = 1; i < types.size(); ++i) {
} else { reporter->Assign(types[i], types[0]);
return types;
} }
return true;
} }
static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, Type ConcreteBroadcast(const TensorType& t1,
DataType output_dtype) { const TensorType& t2,
DataType output_dtype) {
RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2
<< std::endl; << std::endl;
auto sh1 = t1->shape; auto sh1 = t1->shape;
...@@ -73,7 +73,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, ...@@ -73,7 +73,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2,
Array<ShapeExpr> smaller; Array<ShapeExpr> smaller;
for (int i = 0; i < (full_len - suffix_len); i++) { for (int i = 0; i < (full_len - suffix_len); i++) {
smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); smaller.push_back(make_const(tvm::Int(64), 1));
} }
if (sh1.size() < sh2.size()) { if (sh1.size() < sh2.size()) {
...@@ -93,46 +93,52 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, ...@@ -93,46 +93,52 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2,
CHECK_EQ(larger.size(), smaller.size()); CHECK_EQ(larger.size(), smaller.size());
Array<HalideIR::Expr> out_shape; Array<ShapeExpr> out_shape;
for (size_t i = 0; i < smaller.size(); i++) { for (size_t i = 0; i < smaller.size(); i++) {
auto left = smaller[i].as<tvm::ir::IntImm>(); auto left = smaller[i].as<tvm::ir::IntImm>();
auto right = larger[i].as<tvm::ir::IntImm>(); auto right = larger[i].as<tvm::ir::IntImm>();
CHECK(left); CHECK(left);
CHECK(right); CHECK(right);
int64_t dim = std::max(left->value, right->value); int64_t dim = std::max(left->value, right->value);
out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); out_shape.push_back(make_const(tvm::Int(64), dim));
} }
return TensorTypeNode::make(out_shape, output_dtype); return TensorTypeNode::make(out_shape, output_dtype);
} }
} }
Array<Type> BroadcastRel(const Array<Type>& types, int num_args) { bool BroadcastRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1] RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1]
<< "Out: " << types[2] << std::endl; << "Out: " << types[2] << std::endl;
if (auto t1 = ToTensorType(types[0])) { if (auto t0 = ToTensorType(types[0])) {
if (auto t2 = ToTensorType(types[1])) { if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t1->dtype, t2->dtype); CHECK_EQ(t0->dtype, t1->dtype);
return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; reporter->Assign(types[2], ConcreteBroadcast(t0, t1, t0->dtype));
return true;
} }
} }
return false;
return types;
} }
/* A relation which specifies broadcasting rules for operations which bool BroadcastCompRel(const Array<Type>& types,
compute boolean results. int num_inputs,
*/ const Attrs& attrs,
Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
if (auto t1 = ToTensorType(types[0])) { RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1]
if (auto t2 = ToTensorType(types[1])) { << "Out: " << types[2] << std::endl;
return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())}; if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::Bool()));
return true;
} }
} }
return false;
return types;
} }
/*! \brief Handle concrete concat case from known input to output. */ /*! \brief Handle concrete concat case from known input to output. */
...@@ -175,10 +181,10 @@ inline Type ConcreteConcatRel(const Type& input_type) { ...@@ -175,10 +181,10 @@ inline Type ConcreteConcatRel(const Type& input_type) {
auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0); auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0);
Array<tvm::Expr> out_shape = { tvm::ir::IntImm::make(HalideIR::Int(64), out_axis_dim) }; Array<tvm::Expr> out_shape = { make_const(Int(64), out_axis_dim) };
for (auto dim : dims) { for (auto dim : dims) {
out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); out_shape.push_back(make_const(Int(64), dim));
} }
return TensorTypeNode::make(out_shape, dtype); return TensorTypeNode::make(out_shape, dtype);
...@@ -188,19 +194,18 @@ inline Type ConcreteConcatRel(const Type& input_type) { ...@@ -188,19 +194,18 @@ inline Type ConcreteConcatRel(const Type& input_type) {
} }
} }
Array<Type> ConcatRel(const Array<Type>& types, int num_args) { bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 2);
if (types[0].as<TupleTypeNode>()) {
if (types[0].as<IncompleteTypeNode>() && types[1].as<IncompleteTypeNode>()) { reporter->Assign(types[1], ConcreteConcatRel(types[0]));
return types; return true;
} else if (types[1].as<IncompleteTypeNode>()) {
return { types[0], ConcreteConcatRel(types[0]) };
} else {
throw TypeRelationError(
"can not deduce relationship between the " \
"type of concat's input and output");
} }
return false;
} }
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -24,42 +24,72 @@ struct TypeRelationError : Error { ...@@ -24,42 +24,72 @@ struct TypeRelationError : Error {
: Error(msg) {} : Error(msg) {}
}; };
/*! \brief The identity type relation maps a single input variable /*!
* to the output variable. * \brief The identity type relation, all the types are equal.
* *
* \param types The input and output types to the relation. * \param types The input and output types to the relation.
* \param num_args The number of input arguments. * \param num_inputs The number of input arguments.
* \return The (potentially partial) solution to the relation. * \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> IdentityRel(const Array<Type>& types, int num_args); bool IdentityRel(const Array<Type>& types,
/*! \brief The broadcast type relation, implements the broadcasting int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type. * rule over the two input types producing the broadcasted type.
* *
* \param types The input and output types to the relation. * \param types The input and output types to the relation.
* \param num_args The number of input arguments. * \param num_inputs The number of input arguments.
* \return The (potentially partial) solution to the relation. * \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> BroadcastRel(const Array<Type>& types, int num_args); bool BroadcastRel(const Array<Type>& types,
/*! \brief The broadcast type relation, implements the broadcasting int num_inputs,
* rule over the two input types producing the broadcasted type. const Attrs& attrs,
const TypeReporter& reporter);
/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
* *
* This differs from BroadcastRel in the return dtype, * This differs from BroadcastRel in the return dtype,
* it instead returns bool, for use in comparsion operators * it instead returns bool(uint8), for use in comparsion operators
* such as equal, not_equal, lt, and so on. * such as equal, not_equal, lt, and so on.
* *
* \param types The input and output types to the relation. * \param types The input and output types to the relation.
* \param num_args The number of input arguments. * \param num_inputs The number of input arguments.
* \return The (potentially partial) solution to the relation. * \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args); bool BroadcastCompRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
/*! \brief The concat relation. /*!
* \brief The The concat relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
* *
* This relation takes a single input which must be a single tensor * This differs from BroadcastRel in the return dtype,
* or an arbitrary sized tuple. It combines these input dimensions * it instead returns bool(uint8), for use in comparsion operators
* together to produce the output example. * such as equal, not_equal, lt, and so on.
*
* \param types The input and output types to the relation.
* \param num_inputs The number of input arguments.
* \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/ */
Array<Type> ConcatRel(const Array<Type>& types, int num_args); bool ConcatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
/*!
* Copyright (c) 2018 by Contributors
* \file incomplete_type.h
* \brief A way to defined arbitrary function signature with dispatch on types.
*/
#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*!
* \brief Represents a portion of an incomplete type.
*/
class IncompleteType;
/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeParamNode::Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); }
TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
/*!
* Copyright (c) 2018 by Contributors
* \file resolve.cc
* \brief Resolve incomplete types to complete types.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include "./resolve.h"
#include "./type_visitor.h"
namespace tvm {
namespace relay {
struct ResolveTypeType : TypeMutator {
const TypeUnifier &unifier;
explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {}
Type VisitType(const Type &t) override {
if (!t.defined()) {
auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
unifier->Insert(inc_ty);
return inc_ty;
} else {
return TypeMutator::VisitType(t);
}
}
Type VisitType_(const IncompleteTypeNode *op) override {
return unifier->Subst(GetRef<IncompleteType>(op));
}
};
struct ResolveTypeExpr : ExprMutator {
const TypeUnifier &unifier;
explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {}
Expr Mutate(const Expr &e) {
// NB: a bit tricky here.
//
// We want to store resolved type without having
// to re-typecheck the entire term.
//
// Since we know that e : T[...] under some holes
// then it is the case that if we resolve types
// present in e, then we can type it under T
// with the wholes filled in.
//
// We will visit e like normal building a new
// term, then resolve e's old type and write
// it back into the new node.
auto new_e = ExprMutator::Mutate(e);
CHECK(e->checked_type_.defined());
auto resolved_cty = VisitType(e->checked_type_);
new_e->checked_type_ = resolved_cty;
return new_e;
}
Type VisitType(const Type &t) {
return ResolveTypeType(unifier).VisitType(t);
}
};
Type Resolve(const TypeUnifier &unifier, const Type &ty) {
CHECK(ty.defined());
return ResolveTypeType(unifier).VisitType(ty);
}
Expr Resolve(const TypeUnifier &unifier, const Expr &expr) {
return ResolveTypeExpr(unifier).Mutate(expr);
}
struct FullyResolved : TypeVisitor<> {
bool incomplete;
FullyResolved() : incomplete(true) {}
void VisitType(const Type &t) override {
if (!t.defined()) {
incomplete = true;
} else {
return TypeVisitor<>::VisitType(t);
}
}
void VisitType_(const IncompleteTypeNode *ty_var) override {
incomplete = false;
}
};
bool IsFullyResolved(const Type &t) {
auto fr = FullyResolved();
fr.VisitType(t);
return fr.incomplete;
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/resolve.h
* \brief Resolve incomplete types to complete types.
*/
#ifndef TVM_RELAY_PASS_RESOLVE_H_
#define TVM_RELAY_PASS_RESOLVE_H_
#include <tvm/relay/expr.h>
#include <string>
#include "./unifier.h"
namespace tvm {
namespace relay {
/*! \brief Resolve a type containing incomplete types.
*
* This pass replaces incomplete types with their representative, and
* converts types which are not defined into fresh variables.
*
* \param unifier The unifier containing the unification data.
* \param ty The type to resolve.
* \returns The resolved type.
*/
Type Resolve(const TypeUnifier& unifier, const Type& ty);
/*! \brief Resolve an expression containing incomplete types.
*
* This pass replaces incomplete types with their representative, and
* converts types which are not defined into fresh variables.
*
* \param unifier The unifier containing the unification data.
* \param ty The expression to resolve.
* \returns The resolved expression.
*/
Expr Resolve(const TypeUnifier& unifier, const Expr& expr);
/*! \brief Check if all types have been filled in.
* \param t The type.
* \returns True if the type is resolved, false otherwise.
*/
bool IsFullyResolved(const Type& t);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_RESOLVE_H_
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include "./incomplete_type.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_solver.cc
* \brief Type solver implementations.
*/
#include <string>
#include "type_solver.h"
namespace tvm {
namespace relay {
class TypeSolver::Reporter : public TypeReporterNode {
public:
explicit Reporter(TypeSolver* solver)
: solver_(solver) {}
void Assign(const Type& dst, const Type& src) final {
solver_->Unify(dst, src);
}
void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) final {
// TODO(tqchen)
}
private:
TypeSolver* solver_;
};
// constructor
TypeSolver::TypeSolver()
: reporter_(make_node<Reporter>(this)) {
}
// destructor
TypeSolver::~TypeSolver() {
// call destructor of all non-POD arena object
for (TypeNode* ptr : type_nodes_) {
ptr->~TypeNode();
}
for (RelationNode* ptr : rel_nodes_) {
ptr->~RelationNode();
}
}
// Add equality constraint
Type TypeSolver::Unify(const Type& dst, const Type& src) {
// Known limitation
// - handle composite types whose component can be unknown.
// - handle shape pattern matching
TypeNode* lhs = GetTypeNode(dst);
TypeNode* rhs = GetTypeNode(src);
if (lhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(lhs, rhs);
return rhs->resolved_type;
} else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(rhs, lhs);
return lhs->resolved_type;
} else {
lhs->parent = rhs;
CHECK(AlphaEqual(lhs->resolved_type, rhs->resolved_type))
<< "Incompatible parent types in UF:"
<< lhs->resolved_type << " and " << rhs->resolved_type;
return rhs->resolved_type;
}
}
// Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
if (auto *op = constraint.as<TypeRelationNode>()) {
// create a new relation node.
RelationNode* rnode = make<RelationNode>();
rnode->rel = GetRef<TypeRelation>(op);
rel_nodes_.push_back(rnode);
// populate the type information.
for (size_t i = 0; i < op->args.size(); ++i) {
// insert link to the type list
LinkNode<TypeNode*>* tlink = make<LinkNode<TypeNode*> >();
TypeNode* tnode = GetTypeNode(op->args[i]);
tlink->value = tnode;
rnode->type_list.Push(tlink);
// insert type->relation node
LinkNode<RelationNode*>* rlink = make<LinkNode<RelationNode*> >();
rlink->value = rnode;
tnode->rel_list.Push(rlink);
}
// add the relation to the working queue.
this->AddToQueue(rnode);
} else {
LOG(FATAL) << "Do not know how to handle constraint type"
<< constraint->type_key();
}
}
// Resolve a type in the solver context.
Type TypeSolver::Resolve(const Type& type) {
auto it = tmap_.find(type);
if (it != tmap_.end()) {
return it->second->FindRoot()->resolved_type;
} else {
return type;
}
}
bool TypeSolver::Solve() {
// update until queue is empty
while (!update_queue_.empty()) {
RelationNode* rnode = update_queue_.front();
const auto& rel = rnode->rel;
update_queue_.pop();
CHECK(!rnode->resolved);
// update the relation with given evidence.
Array<Type> args;
for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) {
args.push_back(tlink->value->FindRoot()->resolved_type);
CHECK_LE(args.size(), rel->args.size());
}
// call the function
bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);
// mark inqueue as false after the function call
// so that rnode itself won't get enqueued again.
rnode->inqueue = false;
if (resolved) {
++num_resolved_rels_;
}
rnode->resolved = resolved;
}
// This criterion is not necessarily right for all the possible cases
// TODO(tqchen): We should also count the number of in-complete types.
return num_resolved_rels_ == rel_nodes_.size();
}
// Expose type solver only for debugging purposes.
TVM_REGISTER_API("relay._ir_pass._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto solver = std::make_shared<TypeSolver>();
auto mod = [solver](std::string name) -> PackedFunc {
if (name == "Solve") {
return TypedPackedFunc<bool()>([solver]() {
return solver->Solve();
});
} else if (name == "Unify") {
return TypedPackedFunc<void(Type, Type)>([solver](Type lhs, Type rhs) {
solver->Unify(lhs, rhs);
});
} else if (name == "Resolve") {
return TypedPackedFunc<Type(Type)>([solver](Type t) {
return solver->Resolve(t);
});
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
return solver->AddConstraint(c);
});
} else {
return PackedFunc();
}
};
*ret = runtime::TypedPackedFunc<runtime::PackedFunc(std::string)>(mod);
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file type_solver.h
* \brief Solver logic for type inference.
*/
#ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_
#define TVM_RELAY_PASS_TYPE_SOLVER_H_
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
#include <vector>
#include <queue>
#include "../../common/arena.h"
namespace tvm {
namespace relay {
/*!
* \brief Interface of type solver used in type inference.
*
* TypeSolver works on a list of constraints among incomplete types.
* The user will populate the constraints by AddConstraint and Assign.
* Then we can call Solve to trying to resolve the unknown.
*
* This can be viewed as "type program(computational graph)" of types, where
* the type constraint are operators of the graph and the incomplete
* types are intermediate value of the graph.
* If all the input types are concretely known, we should be able to
* just run a forward pass on the "type program" to get all the types.
*
* The list of constraints representation means we are storing it as a bipartite
* graph instead of a DAG. This is because some constraints might go both direction.
* TypeSolver could take advantage of bidirectional constraints to deduce input
* value given output ones. Never-the-less, we should keep in mind that
* there is a "forward direction" that the TypeSolver should take advantage of.
*/
class TypeSolver {
public:
TypeSolver();
~TypeSolver();
/*!
* \brief Add a type constraint to the solver.
* \param constraint The constraint to be added.
*/
void AddConstraint(const TypeConstraint& constraint);
/*!
* \brief Resolve type to the solution type in the solver.
* \param type The type to be resolved.
* \return The resolved type.
*/
Type Resolve(const Type& type);
/*!
* \brief Start to solve the types using the current known information.
* \return Whether all the incomplete types has been fully resolved.
*/
bool Solve();
/*!
* \brief Unify lhs and rhs.
* \param lhs The left operand.
* \param rhs The right operand
*/
Type Unify(const Type& lhs, const Type& rhs);
private:
class Reporter;
struct TypeNode;
struct RelationNode;
// Internally the solver maintains a bipartite graph of Relation and Types.
// All the object in the structure is managed by a arena allocator
// which releases the memory upon distruction of the type solver.
/*!
* \brief Link list node
* \tparam T the content data type
*/
template<typename T>
struct LinkNode {
/*! \brief The content value */
T value;
/*! \brief pointer to the next location */
LinkNode<T>* next{nullptr};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
*/
template<typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
/*! \brief Tail pointer */
LinkNode<T>* tail{nullptr};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void Push(LinkNode<T>* node) {
node->next = nullptr;
if (this->tail != nullptr) {
this->tail->next = node;
this->tail = node;
} else {
head = tail = node;
}
}
};
/*!
* \brief type node struct
* TypeNode implements a union-find data structure(via parent)
* that can unifies the same types to the name resolved_type.
*
* It also contains collection of links to related Relations,
* which is stored in rel_list.
*/
struct TypeNode {
/*! \brief The final resolved type */
Type resolved_type;
/*! \brief type node in the union find algorithm */
TypeNode* parent{nullptr};
/*! \brief list of relations that is related to this type node */
LinkedList<RelationNode*> rel_list;
/*!
* \brief Find the root type node, perform path compression
* \return The root type node.
*/
TypeNode* FindRoot() {
// fast path
if (this->parent == nullptr) return this;
// slow path with path compression.
TypeNode* root = this;
while (root->parent != nullptr) {
root = root->parent;
}
for (TypeNode* p = this; p != root;) {
TypeNode* parent = p->parent;
p->parent = root;
p = parent;
}
return root;
}
};
/*! \brief relation node */
struct RelationNode {
/*! \brief Whether the relation is in the queue to be solved */
bool inqueue{false};
/*! \brief Whether the relation is resolved */
bool resolved{false};
/*! \brief The corresponding type relation */
TypeRelation rel;
/*! \brief list types to this relation */
LinkedList<TypeNode*> type_list;
};
/*! \brief List of all allocated type nodes */
std::vector<TypeNode*> type_nodes_;
/*! \brief List of all allocated relation nodes */
std::vector<RelationNode*> rel_nodes_;
/*! \brief Number of resolved relations */
size_t num_resolved_rels_{0};
/*! \brief map from type node to types. */
std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
/*! \breif Internal queue to update the relation */
std::queue<RelationNode*> update_queue_;
/*! \brief allocator of all the internal node obhect*/
common::Arena arena_;
/*! \brief Reporter that reports back to self */
TypeReporter reporter_;
/*!
* \brief Create function to create a new node ptr via arena
* \tparam The type parameter
* \return The node pointer.
*/
template<typename T>
T* make() {
T* ptr = arena_.Alloc<T>();
// call constructor
new (ptr) T();
return ptr;
}
/*!
* \brief GetTypeNode that is corresponds to t.
* if it do not exist, create a new one.
* \return The type node.
*/
TypeNode* GetTypeNode(const Type& t) {
auto it = tmap_.find(t);
if (it != tmap_.end()) {
return it->second->FindRoot();
} else {
TypeNode* n = make<TypeNode>();
type_nodes_.push_back(n);
n->resolved_type = t;
tmap_[t] = n;
return n;
}
}
/*!
* \brief Add relation node rel to the update queue
* \param rel The relation node
*/
void AddToQueue(RelationNode* rel) {
if (rel->inqueue) return;
CHECK(!rel->resolved);
rel->inqueue = true;
update_queue_.push(rel);
}
/*!
* \brief Merge rhs type node to lhs
* \param src The source operand
* \param dst The dst operand.
*/
void MergeFromTo(TypeNode* src, TypeNode* dst) {
if (src == dst) return;
src->parent = dst;
// move the link to the to dst
for (auto* rlink = src->rel_list.head; rlink != nullptr;) {
// store next pointer first before rlink get moved
auto* next = rlink->next;
// if the relation is not yet resolved
// send the relation to the new
if (!rlink->value->resolved) {
this->AddToQueue(rlink->value);
dst->rel_list.Push(rlink);
}
rlink = next;
}
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_SOLVER_H_
...@@ -108,11 +108,14 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> { ...@@ -108,11 +108,14 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
for (const Type& t : type_rel->args) { for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t)); new_args.push_back(this->VisitType(t));
} }
return TypeRelationNode::make(type_rel->name, type_rel->func_, new_args); return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
} }
Type VisitType_(const IncompleteTypeNode* op) override { Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<IncompleteType>(op); return GetRef<Type>(op);
} }
}; };
......
/*!
* Copyright (c) 2018 by Contributors
* \file include/tvm/relay/pass/unifier.h
* \brief The type unifier which solves a system of equations between
* incomplete types.
*/
#ifndef TVM_RELAY_PASS_UNIFIER_H_
#define TVM_RELAY_PASS_UNIFIER_H_
#include <tvm/relay/expr.h>
#include <string>
#include "./type_functor.h"
namespace tvm {
namespace relay {
struct UnionFindError : dmlc::Error {
explicit UnionFindError(const std::string& msg) : Error(msg) {}
};
struct UnificationError : dmlc::Error {
explicit UnificationError(const std::string& msg) : Error(msg) {}
};
struct SubstitutionError : dmlc::Error {
explicit SubstitutionError(const std::string& msg) : Error(msg) {}
};
/*! \brief A union-find data structure for the type-checker */
class UnionFind;
class UnionFindNode : public Node {
public:
/*! \brief The inernal map from incomplete types to their representatives. */
tvm::Map<IncompleteType, Type> uf_map;
UnionFindNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf_map", &uf_map); }
TVM_DLL static UnionFind make(tvm::Map<IncompleteType, Type> uf_map);
/*! \brief Insert it into the union find.
* \param it The type to add to the union find.
*/
void Insert(const IncompleteType& it);
/*! \brief Union operation, combine two equivalence classes.
* \param it The incomplete type to unify.
* \param ty The other type.
*/
void Unify(const IncompleteType& it, const Type& t);
/*! \brief Find operation, returns the representative of the argument.
* \param it The element to lookup.
*/
Type Find(const IncompleteType& it);
void debug();
void AssertAlphaEqual(const Type& l, const Type& r);
static constexpr const char* _type_key = "relay.UnionFind";
TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node);
};
class UnionFind : public NodeRef {
public:
UnionFind() {}
explicit UnionFind(NodePtr<tvm::Node> p) : NodeRef(p) {}
// The union find structure is mutable so we do not use the standard macros
// and expose the pointer via `->`.
UnionFindNode* operator->() const {
return static_cast<UnionFindNode*>(node_.get());
}
using ContainerType = UnionFindNode;
};
class TypeUnifier;
class TypeUnifierNode : public Node,
private TypeFunctor<Type(const Type&, const Type)> {
public:
UnionFind union_find;
TypeUnifierNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("union_find", &union_find); }
TVM_DLL static TypeUnifier make(UnionFind uf);
/*! \brief Introduces a new type var into the unifier */
void Insert(const IncompleteType& v);
/*! \brief Unifies two types if possible, throws a unification error if it
* cannot */
Type Unify(const Type& t1, const Type& t2);
/*! \brief Attempts to substitute all type vars in t with concrete types,
* throws substitution error if it cannot concretize*/
Type Subst(const Type& t);
// /*! \brief Checks the kinds in the given type */
// Type CheckKinds(const Type& t);
static constexpr const char* _type_key = "relay.TypeUnifier";
TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node);
private:
/*! \brief Unify incomplete type with another type. */
Type UnifyWithIncompleteType(const Type& t1, const IncompleteType tvn2);
/*! \brief Implements unification between two types with incomplete portions.
*/
Type VisitType(const Type& t1, const Type t2) override;
// Visitor Cases
Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override;
Type VisitType_(const TensorTypeNode* t1, const Type t2) override;
Type VisitType_(const TypeParamNode* t1, const Type t2) override;
Type VisitType_(const FuncTypeNode* t1, const Type t2) override;
Type VisitType_(const TupleTypeNode* t1, const Type t2) override;
Type VisitType_(const TypeRelationNode* s1, const Type t2) override;
};
class TypeUnifier : public NodeRef {
public:
TypeUnifier() {}
explicit TypeUnifier(NodePtr<tvm::Node> p) : NodeRef(p) {}
// no const so that unifier can be mutable as a member of typechecker
inline TypeUnifierNode* operator->() const {
return static_cast<TypeUnifierNode*>(node_.get());
}
using ContainerType = TypeUnifierNode;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_UNIFIER_H_
...@@ -27,8 +27,8 @@ def test_tensor_type(): ...@@ -27,8 +27,8 @@ def test_tensor_type():
def test_type_param(): def test_type_param():
tp = relay.TypeParam('name', relay.Kind.Shape) tp = relay.TypeParam('name', relay.Kind.Type)
assert tp.kind == relay.Kind.Shape assert tp.kind == relay.Kind.Type
# assert tp.span # TODO allow us to set span # assert tp.span # TODO allow us to set span
str(tp) str(tp)
......
...@@ -76,10 +76,10 @@ def test_add_broadcast_op(): ...@@ -76,10 +76,10 @@ def test_add_broadcast_op():
assert_has_type(func.to_func(), expected_ty) assert_has_type(func.to_func(), expected_ty)
def test_dual_op(): def test_dual_op():
"""Program: """Program:
fn (x : Tensor[f32, (10, 10)]) { fn (x : Tensor[f32, (10, 10)]) {
let t1 = log(x); let t1 = log(x);
let t2 = add(t1, x); let t2 = add(t1, x);
return t1; return t1;
} }
""" """
...@@ -93,8 +93,8 @@ def test_dual_op(): ...@@ -93,8 +93,8 @@ def test_dual_op():
def test_decl(): def test_decl():
"""Program: """Program:
def f(x : Tensor[f32, (10, 10)]) { def f(x : Tensor[f32, (10, 10)]) {
let lx = log(x); let lx = log(x);
return lx; return lx;
} }
...@@ -125,7 +125,7 @@ def test_recursion(): ...@@ -125,7 +125,7 @@ def test_recursion():
n = b.param('n', ty='int32') n = b.param('n', ty='int32')
data = b.param('data', ty='float32') data = b.param('data', ty='float32')
with b.decl(f, n, data): with b.decl(f, n, data):
with b.if_scope(equal(n, convert(0.0))): with b.if_scope(equal(n, convert(0))):
b.ret(f(subtract(n, convert(1)), log(data))) b.ret(f(subtract(n, convert(1)), log(data)))
with b.else_scope(): with b.else_scope():
b.ret(data) b.ret(data)
...@@ -152,11 +152,12 @@ def test_concat(): ...@@ -152,11 +152,12 @@ def test_concat():
assert_decl_has_type(ib.env, try_concat2, fn_ty) assert_decl_has_type(ib.env, try_concat2, fn_ty)
if __name__ == "__main__": if __name__ == "__main__":
# test_monomorphic_let() test_recursion()
# test_single_op()
# test_add_op() test_monomorphic_let()
# test_add_broadcast_op() test_single_op()
# test_dual_op() test_add_op()
# test_decl() test_add_broadcast_op()
# test_recursion() test_dual_op()
test_decl()
test_concat() test_concat()
import tvm
from tvm import relay
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
def make_rel(name, args, num_inputs=None, attrs=None):
func = tvm.get_env_func("tvm.relay.type_relation." + name)
if num_inputs is None:
num_inputs = len(args) - 1
return relay.ty.TypeRelation(func, args, num_inputs, attrs)
def make_solver():
solver = relay._ir_pass._test_type_solver()
solver.Solve = solver("Solve")
solver.Unify = solver("Unify")
solver.Resolve = solver("Resolve")
solver.AddConstraint = solver("AddConstraint")
def gen_type(name, args, out=None):
out = out if out else relay.ty.IncompleteType()
solver.AddConstraint(make_rel(name, args + [out]))
return out
solver.gen_type = gen_type
return solver
def test_bcast():
solver = make_solver()
t0 = relay.ty.TensorType((10, 20), "float32")
t1 = relay.ty.TensorType((10, 1), "float32")
tc = relay.ty.TensorType((10, 1, 1), "float32")
t2 = solver.gen_type("Broadcast", [t0, t1])
t3 = solver.gen_type("Identity", [t2])
t4 = solver.gen_type("Broadcast", [t3, tc])
assert solver.Solve()
assert solver.Resolve(t2) == relay.ty.TensorType((10, 20), "float32")
assert solver.Resolve(t4) == relay.ty.TensorType((10, 10, 20), "float32")
def test_backward_solving():
solver = make_solver()
t0 = relay.ty.TensorType((10, 20), "float32")
tc = relay.ty.TensorType((10, 1, 1), "float32")
t1 = relay.ty.IncompleteType()
t3 = solver.gen_type("Broadcast", [t0, t1])
t2 = solver.gen_type("Identity", [t1], out=tc)
assert solver.Solve()
assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32")
if __name__ == "__main__":
test_bcast()
test_backward_solving()
*.pb
*.mlmodel
*.ttf
*.txt
*synset*txt
*.cfg
ssd_model
*.names
*.jpg
*.pbtxt
*.weights
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment