Unverified Commit f4c4fde4 by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Polish ir/type (#4705)

- Use consistent constructor style to construct objects.
- Move env_func to ir as it is mainly used to construct IRs.
- Make docs consistent.
parent 3f2abfbc
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file tvm/relay/adt.h * \file tvm/ir/adt.h
* \brief Algebraic data type definitions. * \brief Algebraic data type definitions.
* *
* We adopt relay's ADT definition as a unified class * We adopt relay's ADT definition as a unified class
......
...@@ -18,21 +18,24 @@ ...@@ -18,21 +18,24 @@
*/ */
/*! /*!
* \file tvm/node/env_func.h * \file tvm/ir/env_func.h
* \brief Serializable global function. * \brief Serializable global function used in IR.
*/ */
#ifndef TVM_NODE_ENV_FUNC_H_ #ifndef TVM_IR_ENV_FUNC_H_
#define TVM_NODE_ENV_FUNC_H_ #define TVM_IR_ENV_FUNC_H_
#include <tvm/node/reflection.h> #include <tvm/node/reflection.h>
#include <string> #include <string>
#include <utility> #include <utility>
namespace tvm { namespace tvm {
/*! /*!
* \brief Node container of EnvFunc * \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \sa EnvFunc * \sa EnvFunc
*/ */
class EnvFuncNode : public Object { class EnvFuncNode : public Object {
...@@ -53,11 +56,8 @@ class EnvFuncNode : public Object { ...@@ -53,11 +56,8 @@ class EnvFuncNode : public Object {
}; };
/*! /*!
* \brief A serializable function backed by TVM's global environment. * \brief Managed reference to EnvFuncNode.
* * \sa EnvFuncNode
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
*/ */
class EnvFunc : public ObjectRef { class EnvFunc : public ObjectRef {
public: public:
...@@ -140,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef { ...@@ -140,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
}; };
} // namespace tvm } // namespace tvm
#endif // TVM_NODE_ENV_FUNC_H_ #endif // TVM_IR_ENV_FUNC_H_
...@@ -198,7 +198,7 @@ class GlobalVarNode : public RelayExprNode { ...@@ -198,7 +198,7 @@ class GlobalVarNode : public RelayExprNode {
/*! \brief The name of the variable, this only acts as a hint. */ /*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint; std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint); v->Visit("name_hint", &name_hint);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
......
...@@ -48,13 +48,13 @@ class IRModule; ...@@ -48,13 +48,13 @@ class IRModule;
class IRModuleNode : public Object { class IRModuleNode : public Object {
public: public:
/*! \brief A map from ids to all global functions. */ /*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, BaseFunc> functions; Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */ /*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions; Map<GlobalTypeVar, TypeData> type_definitions;
IRModuleNode() {} IRModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("functions", &functions); v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions); v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_); v->Visit("global_var_map_", &global_var_map_);
...@@ -146,7 +146,7 @@ class IRModuleNode : public Object { ...@@ -146,7 +146,7 @@ class IRModuleNode : public Object {
* \brief Collect all global vars defined in this module. * \brief Collect all global vars defined in this module.
* \returns An array of global vars * \returns An array of global vars
*/ */
TVM_DLL tvm::Array<GlobalVar> GetGlobalVars() const; TVM_DLL Array<GlobalVar> GetGlobalVars() const;
/*! /*!
* \brief Look up a global function by its name. * \brief Look up a global function by its name.
...@@ -159,7 +159,7 @@ class IRModuleNode : public Object { ...@@ -159,7 +159,7 @@ class IRModuleNode : public Object {
* \brief Collect all global type vars defined in this module. * \brief Collect all global type vars defined in this module.
* \returns An array of global type vars * \returns An array of global type vars
*/ */
TVM_DLL tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const; TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;
/*! /*!
* \brief Look up a global function by its variable. * \brief Look up a global function by its variable.
...@@ -235,12 +235,12 @@ class IRModuleNode : public Object { ...@@ -235,12 +235,12 @@ class IRModuleNode : public Object {
/*! \brief A map from string names to global variables that /*! \brief A map from string names to global variables that
* ensures global uniqueness. * ensures global uniqueness.
*/ */
tvm::Map<std::string, GlobalVar> global_var_map_; Map<std::string, GlobalVar> global_var_map_;
/*! \brief A map from string names to global type variables (ADT names) /*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness. * that ensures global uniqueness.
*/ */
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_; Map<std::string, GlobalTypeVar> global_type_var_map_;
/*! \brief A map from constructor tags to constructor objects /*! \brief A map from constructor tags to constructor objects
* for convenient access * for convenient access
...@@ -266,8 +266,8 @@ class IRModule : public ObjectRef { ...@@ -266,8 +266,8 @@ class IRModule : public ObjectRef {
* \param type_definitions Type definitions in the module. * \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module * \param import_set Set of imported files in the module
*/ */
TVM_DLL explicit IRModule(tvm::Map<GlobalVar, BaseFunc> functions, TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions = {}, Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {}); std::unordered_set<std::string> import_set = {});
/*! \brief default constructor */ /*! \brief default constructor */
IRModule() {} IRModule() {}
...@@ -296,8 +296,8 @@ class IRModule : public ObjectRef { ...@@ -296,8 +296,8 @@ class IRModule : public ObjectRef {
*/ */
TVM_DLL static IRModule FromExpr( TVM_DLL static IRModule FromExpr(
const RelayExpr& expr, const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {}, const Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {}); const Map<GlobalTypeVar, TypeData>& type_definitions = {});
/*! /*!
* \brief Parse text format source file into an IRModule. * \brief Parse text format source file into an IRModule.
......
...@@ -91,7 +91,7 @@ class OpNode : public RelayExprNode { ...@@ -91,7 +91,7 @@ class OpNode : public RelayExprNode {
*/ */
int32_t support_level = 10; int32_t support_level = 10;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("op_type", &op_type); v->Visit("op_type", &op_type);
v->Visit("description", &description); v->Visit("description", &description);
...@@ -476,7 +476,7 @@ inline OpRegistry& OpRegistry::add_type_rel( ...@@ -476,7 +476,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
std::string input_name_prefix = "in"; std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) { for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i); auto name = input_name_prefix + std::to_string(i);
auto param = TypeVarNode::make(name, TypeKind::kType); auto param = TypeVar(name, TypeKind::kType);
type_params.push_back(param); type_params.push_back(param);
arg_types.push_back(param); arg_types.push_back(param);
} }
...@@ -484,7 +484,7 @@ inline OpRegistry& OpRegistry::add_type_rel( ...@@ -484,7 +484,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
Array<Type> ty_call_args = arg_types; Array<Type> ty_call_args = arg_types;
// Add output type. // Add output type.
auto out_param = TypeVarNode::make("out", TypeKind::kType); auto out_param = TypeVar("out", TypeKind::kType);
type_params.push_back(out_param); type_params.push_back(out_param);
// this will trigger copy on write. // this will trigger copy on write.
ty_call_args.push_back(out_param); ty_call_args.push_back(out_param);
...@@ -498,13 +498,13 @@ inline OpRegistry& OpRegistry::add_type_rel( ...@@ -498,13 +498,13 @@ inline OpRegistry& OpRegistry::add_type_rel(
// A common example is sum(x, axis), where the choice of axis // A common example is sum(x, axis), where the choice of axis
// can affect the type of the function. // can affect the type of the function.
TypeConstraint type_rel = TypeConstraint type_rel =
TypeRelationNode::make(env_type_rel_func, TypeRelation(env_type_rel_func,
ty_call_args, ty_call_args,
arg_types.size(), arg_types.size(),
Attrs()); Attrs());
auto func_type = auto func_type =
FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); FuncType(arg_types, out_param, type_params, {type_rel});
get()->op_type = func_type; get()->op_type = func_type;
......
...@@ -84,13 +84,13 @@ class PassContextNode : public Object { ...@@ -84,13 +84,13 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)}; int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */ /*! \brief The list of required passes. */
tvm::Array<tvm::PrimExpr> required_pass; Array<PrimExpr> required_pass;
/*! \brief The list of disabled passes. */ /*! \brief The list of disabled passes. */
tvm::Array<tvm::PrimExpr> disabled_pass; Array<PrimExpr> disabled_pass;
PassContextNode() = default; PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level); v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device); v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass); v->Visit("required_pass", &required_pass);
...@@ -118,7 +118,7 @@ class PassContextNode : public Object { ...@@ -118,7 +118,7 @@ class PassContextNode : public Object {
class PassContext : public ObjectRef { class PassContext : public ObjectRef {
public: public:
PassContext() {} PassContext() {}
explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! /*!
* \brief const accessor. * \brief const accessor.
* \return const access pointer. * \return const access pointer.
...@@ -158,7 +158,7 @@ class PassContext : public ObjectRef { ...@@ -158,7 +158,7 @@ class PassContext : public ObjectRef {
// Classes to get the Python `with` like syntax. // Classes to get the Python `with` like syntax.
friend class Internal; friend class Internal;
friend class tvm::With<PassContext>; friend class With<PassContext>;
}; };
/*! /*!
...@@ -174,11 +174,11 @@ class PassInfoNode : public Object { ...@@ -174,11 +174,11 @@ class PassInfoNode : public Object {
std::string name; std::string name;
/*! \brief The passes that are required to perform the current pass. */ /*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::PrimExpr> required; Array<PrimExpr> required;
PassInfoNode() = default; PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level); v->Visit("opt_level", &opt_level);
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("required", &required); v->Visit("required", &required);
...@@ -202,7 +202,7 @@ class PassInfo : public ObjectRef { ...@@ -202,7 +202,7 @@ class PassInfo : public ObjectRef {
*/ */
TVM_DLL PassInfo(int opt_level, TVM_DLL PassInfo(int opt_level,
std::string name, std::string name,
tvm::Array<tvm::PrimExpr> required); Array<PrimExpr> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
}; };
...@@ -241,7 +241,7 @@ class PassNode : public Object { ...@@ -241,7 +241,7 @@ class PassNode : public Object {
virtual IRModule operator()(const IRModule& mod, virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0; const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass"; static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
...@@ -289,7 +289,7 @@ class Sequential : public Pass { ...@@ -289,7 +289,7 @@ class Sequential : public Pass {
* \param passes The passes to apply. * \param passes The passes to apply.
* \param pass_info The pass metadata. * \param pass_info The pass metadata.
*/ */
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info); TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
/*! /*!
* \brief The constructor of `Sequential`. * \brief The constructor of `Sequential`.
...@@ -299,10 +299,10 @@ class Sequential : public Pass { ...@@ -299,10 +299,10 @@ class Sequential : public Pass {
* This allows users to only provide a list of passes and execute them * This allows users to only provide a list of passes and execute them
* under a given context. * under a given context.
*/ */
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential"); TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");
Sequential() = default; Sequential() = default;
explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {} explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
const SequentialNode* operator->() const; const SequentialNode* operator->() const;
using ContainerType = Sequential; using ContainerType = Sequential;
...@@ -322,7 +322,7 @@ Pass CreateModulePass( ...@@ -322,7 +322,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, int opt_level,
const std::string& name, const std::string& name,
const tvm::Array<tvm::PrimExpr>& required); const Array<PrimExpr>& required);
} // namespace transform } // namespace transform
} // namespace tvm } // namespace tvm
......
...@@ -50,15 +50,26 @@ ...@@ -50,15 +50,26 @@
#define TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/data_type.h>
#include <tvm/node/node.h> #include <tvm/node/node.h>
#include <tvm/node/env_func.h>
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/ir/span.h> #include <tvm/ir/span.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
/*! \brief Base type of all the types. */ /*!
* \brief Type is the base type of all types.
*
* Relay's type system contains following subclasses:
*
* - PrimType: type of primitive type values used in the low-level IR.
* - FuncType: type of a function.
* - TensorType: type of certain Tensor values in the expression.
*
* There are also advanced types to support generic(polymorphic types).
* \sa Type
*/
class TypeNode : public Object { class TypeNode : public Object {
public: public:
/*! /*!
...@@ -72,29 +83,58 @@ class TypeNode : public Object { ...@@ -72,29 +83,58 @@ class TypeNode : public Object {
}; };
/*! /*!
* \brief Type is the base type of all types. * \brief Managed reference to TypeNode.
* * \sa TypeNode
* Relay's type system contains following two key concepts:
*
* - PrimitiveType: type of primitive type values used in the low-level IR.
* - TensorType: type of certain Tensor values in the expression.
* - FunctionType: the type of the function.
*
* There are also advanced types to support generic(polymorphic types),
* which can be ignored when first reading the code base.
*/ */
class Type : public ObjectRef { class Type : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode); TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode);
}; };
/*!
* \brief Primitive data types used in the low-level IR.
*
* PrimType represents POD-values and handles that are
* not automatically managed by the runtime.
*
* \sa PrimType
*/
class PrimTypeNode : public TypeNode {
public:
/*!
* \brief The corresponding dtype field.
*/
runtime::DataType dtype;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
}
static constexpr const char* _type_key = "relay.PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
/*!
* \brief Managed reference to PrimTypeNode.
* \sa PrimTypeNode
*/
class PrimType : public Type {
public:
/*!
* \brief Constructor
* \param dtype The corresponding dtype.
*/
TVM_DLL PrimType(runtime::DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
/*! \brief Possible kinds of TypeVars. */ /*! \brief Possible kinds of TypeVars. */
enum TypeKind : int { enum TypeKind : int {
kType = 0, kType = 0,
/*! \brief Template variable in shape expression. */ /*! \brief Template variable in shape expression. */
kShapeVar = 1, kShapeVar = 1,
kBaseType = 2, kBaseType = 2,
kShape = 3,
kConstraint = 4, kConstraint = 4,
kAdtHandle = 5, kAdtHandle = 5,
kTypeData = 6 kTypeData = 6
...@@ -115,10 +155,8 @@ enum TypeKind : int { ...@@ -115,10 +155,8 @@ enum TypeKind : int {
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
* *
* \endcode * \endcode
* \sa TypeVarNode The actual container class of TypeVar * \sa TypeVar, TypeKind
*/ */
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode { class TypeVarNode : public TypeNode {
public: public:
/*! /*!
...@@ -130,28 +168,36 @@ class TypeVarNode : public TypeNode { ...@@ -130,28 +168,36 @@ class TypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */ /*! \brief The kind of type parameter */
TypeKind kind; TypeKind kind;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint); v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind); v->Visit("kind", &kind);
v->Visit("span", &span); v->Visit("span", &span);
} }
TVM_DLL static TypeVar make(std::string name, TypeKind kind);
static constexpr const char* _type_key = "relay.TypeVar"; static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
}; };
/*!
* \brief Managed reference to TypeVarNode
* \sa TypeVarNode
*/
class TypeVar : public Type { class TypeVar : public Type {
public: public:
/*!
* \brief Constructor
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL TypeVar(std::string name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
}; };
/*! /*!
* \brief A global type variable that is used for defining new types or type aliases. * \brief A global type variable that is used for defining new types or type aliases.
* \sa GlobalTypeVar
*/ */
class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode { class GlobalTypeVarNode : public TypeNode {
public: public:
/*! /*!
...@@ -163,47 +209,98 @@ class GlobalTypeVarNode : public TypeNode { ...@@ -163,47 +209,98 @@ class GlobalTypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */ /*! \brief The kind of type parameter */
TypeKind kind; TypeKind kind;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint); v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind); v->Visit("kind", &kind);
} }
TVM_DLL static GlobalTypeVar make(std::string name, TypeKind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar"; static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
}; };
/*!
* \brief Managed reference to GlobalTypeVarNode
* \sa GlobalTypeVarNode
*/
class GlobalTypeVar : public Type { class GlobalTypeVar : public Type {
public: public:
/*!
* \brief Constructor
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
}; };
/*! /*!
* \brief Potential Constraints in the type. * \brief The type of tuple values.
* \note This is reserved for future use. * \sa TupleType
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
/*!
* \brief Managed reference to TupleTypeNode.
* \sa TupleTypeNode.
*/
class TupleType : public Type {
public:
/*!
* \brief Constructor
* \param fields Fields in the tuple.
*/
TVM_DLL explicit TupleType(Array<Type> fields);
/*!
* \brief Create an empty tuple type that constains nothing.
* \return A empty tuple type.
*/
TVM_DLL TupleType static Empty();
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};
/*!
* \brief Potential Constraints in a function.
* \sa TypeConstraint
*/ */
class TypeConstraint;
/*! \brief TypeConstraint container node. */
class TypeConstraintNode : public TypeNode { class TypeConstraintNode : public TypeNode {
public: public:
static constexpr const char* _type_key = "relay.TypeConstraint"; static constexpr const char* _type_key = "relay.TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
}; };
/*!
* \brief Managed reference to TypeConstraintNode.
* \sa TypeConstraintNode, TypeRelation
*/
class TypeConstraint : public Type { class TypeConstraint : public Type {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode); TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
}; };
class FuncType;
/*! /*!
* \brief Function type in Relay. * \brief Function type.
* *
* Relay support polymorphic function type. * We support polymorphic function type.
* This can be roughly viewed as template function in C++. * This can be roughly viewed as template function in C++.
* *
* \sa TypeVar, TypeConstraint * \sa FuncType, TypeVar, TypeConstraint
*/ */
class FuncTypeNode : public TypeNode { class FuncTypeNode : public TypeNode {
public: public:
...@@ -221,7 +318,7 @@ class FuncTypeNode : public TypeNode { ...@@ -221,7 +318,7 @@ class FuncTypeNode : public TypeNode {
*/ */
Array<TypeConstraint> type_constraints; Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("arg_types", &arg_types); v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type); v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params); v->Visit("type_params", &type_params);
...@@ -229,17 +326,29 @@ class FuncTypeNode : public TypeNode { ...@@ -229,17 +326,29 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
TVM_DLL static FuncType make(Array<Type> arg_types,
Type ret_type,
Array<TypeVar> type_params,
Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType"; static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
}; };
/*!
* \brief Managed reference to FuncTypeNode.
* \sa FuncTypeNode
*/
class FuncType : public Type { class FuncType : public Type {
public: public:
/*!
* \brief Constructor
* \param arg_types The types of the arguments.
* \param ret_type The type of the return value.
* \param type_params The type parameters.
* \param type_constraints The type constraints.
* \sa FuncTypeNode for more docs about these fields.
*/
TVM_DLL FuncType(Array<Type> arg_types,
Type ret_type,
Array<TypeVar> type_params,
Array<TypeConstraint> type_constraints);
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
}; };
......
...@@ -19,19 +19,56 @@ ...@@ -19,19 +19,56 @@
/*! /*!
* \file tvm/ir/type_relation.h * \file tvm/ir/type_relation.h
* \brief Type relation function for type checking. * \brief Type relation and function for type inference(checking).
*/ */
#ifndef TVM_IR_TYPE_RELATION_H_ #ifndef TVM_IR_TYPE_RELATION_H_
#define TVM_IR_TYPE_RELATION_H_ #define TVM_IR_TYPE_RELATION_H_
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <tvm/ir/module.h>
#include <tvm/ir/env_func.h>
#include <tvm/attrs.h> #include <tvm/attrs.h>
namespace tvm { namespace tvm {
// TODO(tqchen): remove after migrate Module to ir. /*!
class IRModule; * \brief Type function application.
* \sa TypeCall
*/
class TypeCallNode : public TypeNode {
public:
/*!
* \brief The type-level function (ADT that takes type params).
*/
Type func;
/*! \brief The arguments. */
Array<Type> args;
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
/*!
* \brief Managed reference to TypeCallNode.
* \sa TypeCallNode
*/
class TypeCall : public Type {
public:
/*!
* \brief Constructor
* \param func The type function to apply.
* \param args The arguments to the type function.
*/
TVM_DLL TypeCall(Type func, Array<Type> args);
TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode);
};
/*! /*!
* \brief reporter that reports back to the * \brief reporter that reports back to the
...@@ -78,7 +115,7 @@ class TypeReporterNode : public Object { ...@@ -78,7 +115,7 @@ class TypeReporterNode : public Object {
TVM_DLL virtual IRModule GetModule() = 0; TVM_DLL virtual IRModule GetModule() = 0;
// solver is not serializable. // solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter"; static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
...@@ -91,7 +128,7 @@ class TypeReporterNode : public Object { ...@@ -91,7 +128,7 @@ class TypeReporterNode : public Object {
class TypeReporter : public ObjectRef { class TypeReporter : public ObjectRef {
public: public:
TypeReporter() {} TypeReporter() {}
explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {
} }
TypeReporterNode* operator->() const { TypeReporterNode* operator->() const {
return const_cast<TypeReporterNode*>( return const_cast<TypeReporterNode*>(
...@@ -127,12 +164,11 @@ using TypeRelationFn = ...@@ -127,12 +164,11 @@ using TypeRelationFn =
/*! /*!
* \brief User defined type relation, is an input-output relation on types. * \brief User defined type relation, is an input-output relation on types.
*/ *
class TypeRelation; * TypeRelation is more generalized than type call as it allows inference
/*! * of both inputs and outputs.
* \brief TypeRelation container. *
* \note This node is not directly serializable. * \sa TypeRelation
* The type function need to be lookedup in the module.
*/ */
class TypeRelationNode : public TypeConstraintNode { class TypeRelationNode : public TypeConstraintNode {
public: public:
...@@ -143,13 +179,13 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -143,13 +179,13 @@ class TypeRelationNode : public TypeConstraintNode {
*/ */
TypeRelationFn func; TypeRelationFn func;
/*! \brief The type arguments to the type function. */ /*! \brief The type arguments to the type function. */
tvm::Array<Type> args; Array<Type> args;
/*! \brief Number of inputs arguments */ /*! \brief Number of inputs arguments */
int num_inputs; int num_inputs;
/*! \brief Attributes to the relation function */ /*! \brief Attributes to the relation function */
Attrs attrs; Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func); v->Visit("func", &func);
v->Visit("args", &args); v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs); v->Visit("num_inputs", &num_inputs);
...@@ -157,17 +193,29 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -157,17 +193,29 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
TVM_DLL static TypeRelation make(TypeRelationFn func,
Array<Type> args,
int num_args,
Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation"; static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
}; };
/*!
* \brief Managed reference to TypeRelationNode.
* \sa TypeRelationNode
*/
class TypeRelation : public TypeConstraint { class TypeRelation : public TypeConstraint {
public: public:
/*!
* \brief Constructor
* \param func The relation function.
* \param args The arguments to the type relation.
* \param num_inputs Number of inputs.
* \param attrs Attributes to the relation function.
* \sa TypeRelationNode for more docs about these fields.
*/
TVM_DLL TypeRelation(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs);
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
}; };
} // namespace tvm } // namespace tvm
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <tvm/ir/type_relation.h> #include <tvm/ir/type_relation.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <string> #include <string>
...@@ -39,6 +39,8 @@ ...@@ -39,6 +39,8 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
// namespace update for backward compact
// will be removed later.
using Any = tvm::ir::AnyNode; using Any = tvm::ir::AnyNode;
using Kind = TypeKind; using Kind = TypeKind;
using Type = tvm::Type; using Type = tvm::Type;
...@@ -47,10 +49,14 @@ using TypeVar = tvm::TypeVar; ...@@ -47,10 +49,14 @@ using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode; using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar; using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode; using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint; using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode; using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType; using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode; using FuncTypeNode = tvm::FuncTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation; using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode; using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn; using TypeRelationFn = tvm::TypeRelationFn;
...@@ -119,37 +125,6 @@ class TensorType : public Type { ...@@ -119,37 +125,6 @@ class TensorType : public Type {
}; };
/*! /*!
* \brief Type application.
*/
class TypeCall;
/*! \brief TypeCall container node */
class TypeCallNode : public TypeNode {
public:
/*!
* \brief The type-level function (ADT that takes type params).
*/
Type func;
/*! \brief The arguments. */
tvm::Array<Type> args;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("span", &span);
}
TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args);
static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
class TypeCall : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode);
};
/*!
* \brief IncompleteType. * \brief IncompleteType.
* This is intermediate values that is used during type inference. * This is intermediate values that is used during type inference.
* *
...@@ -181,36 +156,6 @@ class IncompleteType : public Type { ...@@ -181,36 +156,6 @@ class IncompleteType : public Type {
}; };
/*! /*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
class TupleType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};
/*!
* \brief The type of reference values. * \brief The type of reference values.
*/ */
class RefType; class RefType;
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/node/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
/*! /*!
* \file env_func.cc * \file env_func.cc
*/ */
#include <tvm/node/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/expr.h> #include <tvm/expr.h>
......
...@@ -27,18 +27,38 @@ ...@@ -27,18 +27,38 @@
namespace tvm { namespace tvm {
TypeVar TypeVarNode::make(std::string name, TypeKind kind) { PrimType::PrimType(runtime::DataType dtype) {
ObjectPtr<PrimTypeNode> n = make_object<PrimTypeNode>();
n->dtype = dtype;
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PrimTypeNode);
TVM_REGISTER_GLOBAL("relay._make.PrimType")
.set_body_typed([](runtime::DataType dtype) {
return PrimType(dtype);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const PrimTypeNode*>(ref.get());
p->stream << node->dtype;
});
TypeVar::TypeVar(std::string name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>(); ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name); n->name_hint = std::move(name);
n->kind = std::move(kind); n->kind = std::move(kind);
return TypeVar(n); data_ = std::move(n);
} }
TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.TypeVar") TVM_REGISTER_GLOBAL("relay._make.TypeVar")
.set_body_typed([](std::string name, int kind) { .set_body_typed([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<TypeKind>(kind)); return TypeVar(name, static_cast<TypeKind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
...@@ -48,18 +68,19 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -48,18 +68,19 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< node->kind << ")"; << node->kind << ")";
}); });
GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) {
GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>(); ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->name_hint = std::move(name); n->name_hint = std::move(name);
n->kind = std::move(kind); n->kind = std::move(kind);
return GlobalTypeVar(n); data_ = std::move(n);
} }
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
.set_body_typed([](std::string name, int kind) { .set_body_typed([](std::string name, int kind) {
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind)); return GlobalTypeVar(name, static_cast<TypeKind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
...@@ -69,22 +90,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -69,22 +90,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< node->kind << ")"; << node->kind << ")";
}); });
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, FuncType::FuncType(tvm::Array<Type> arg_types,
Type ret_type, Type ret_type,
tvm::Array<TypeVar> type_params, tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) { tvm::Array<TypeConstraint> type_constraints) {
ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>(); ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
n->arg_types = std::move(arg_types); n->arg_types = std::move(arg_types);
n->ret_type = std::move(ret_type); n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params); n->type_params = std::move(type_params);
n->type_constraints = std::move(type_constraints); n->type_constraints = std::move(type_constraints);
return FuncType(n); data_ = std::move(n);
} }
TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType") TVM_REGISTER_GLOBAL("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make); .set_body_typed([](tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
return FuncType(arg_types, ret_type, type_params, type_constraints);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) { .set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
...@@ -94,4 +120,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -94,4 +120,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< node->type_constraints << ")"; << node->type_constraints << ")";
}); });
TupleType::TupleType(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
n->fields = std::move(fields);
data_ = std::move(n);
}
TupleType TupleType::Empty() {
return TupleType(Array<Type>());
}
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TupleType")
.set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});
} // namespace tvm } // namespace tvm
...@@ -27,22 +27,49 @@ ...@@ -27,22 +27,49 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args, TypeCall::TypeCall(Type func, tvm::Array<Type> args) {
int num_inputs, ObjectPtr<TypeCallNode> n = make_object<TypeCallNode>();
Attrs attrs) { n->func = std::move(func);
n->args = std::move(args);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_GLOBAL("relay._make.TypeCall")
.set_body_typed([](Type func, Array<Type> type) {
return TypeCall(func, type);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
});
TypeRelation::TypeRelation(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>(); ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>();
n->func = std::move(func); n->func = std::move(func);
n->args = std::move(args); n->args = std::move(args);
n->num_inputs = num_inputs; n->num_inputs = num_inputs;
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
return TypeRelation(n); data_ = std::move(n);
} }
TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("relay._make.TypeRelation") TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
.set_body_typed(TypeRelationNode::make); .set_body_typed([](TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
return TypeRelation(func, args, num_inputs, attrs);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) { .set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
......
...@@ -246,7 +246,7 @@ class ScheduleGetter : ...@@ -246,7 +246,7 @@ class ScheduleGetter :
new_fields.push_back(field); new_fields.push_back(field);
} }
} }
call_node_type = TupleTypeNode::make(new_fields); call_node_type = TupleType(new_fields);
} }
CHECK(call_node->op.as<OpNode>()) CHECK(call_node->op.as<OpNode>())
......
...@@ -135,7 +135,7 @@ FuncType FunctionNode::func_type_annotation() const { ...@@ -135,7 +135,7 @@ FuncType FunctionNode::func_type_annotation() const {
Type ret_type = (this->ret_type.defined()) ? this->ret_type Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteTypeNode::make(Kind::kType); : IncompleteTypeNode::make(Kind::kType);
return FuncTypeNode::make(param_types, ret_type, this->type_params, {}); return FuncType(param_types, ret_type, this->type_params, {});
} }
bool FunctionNode::IsPrimitive() const { bool FunctionNode::IsPrimitive() const {
......
...@@ -63,25 +63,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -63,25 +63,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
}); });
TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
ObjectPtr<TypeCallNode> n = make_object<TypeCallNode>();
n->func = std::move(func);
n->args = std::move(args);
return TypeCall(n);
}
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_GLOBAL("relay._make.TypeCall")
.set_body_typed(TypeCallNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
});
IncompleteType IncompleteTypeNode::make(Kind kind) { IncompleteType IncompleteTypeNode::make(Kind kind) {
auto n = make_object<IncompleteTypeNode>(); auto n = make_object<IncompleteTypeNode>();
n->kind = std::move(kind); n->kind = std::move(kind);
...@@ -102,23 +83,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -102,23 +83,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
}); });
TupleType TupleTypeNode::make(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
n->fields = std::move(fields);
return TupleType(n);
}
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TupleType")
.set_body_typed(TupleTypeNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});
RefType RefTypeNode::make(Type value) { RefType RefTypeNode::make(Type value) {
ObjectPtr<RefTypeNode> n = make_object<RefTypeNode>(); ObjectPtr<RefTypeNode> n = make_object<RefTypeNode>();
n->value = std::move(value); n->value = std::move(value);
......
...@@ -154,7 +154,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { ...@@ -154,7 +154,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
changed = changed || !new_ret_type.same_as(op->ret_type); changed = changed || !new_ret_type.same_as(op->ret_type);
if (!changed) return GetRef<Type>(op); if (!changed) return GetRef<Type>(op);
return FuncTypeNode::make(new_args, return FuncType(new_args,
new_ret_type, new_ret_type,
type_params, type_params,
type_constraints); type_constraints);
...@@ -165,7 +165,7 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) { ...@@ -165,7 +165,7 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) {
if (new_fields.same_as(op->fields)) { if (new_fields.same_as(op->fields)) {
return GetRef<Type>(op); return GetRef<Type>(op);
} else { } else {
return TupleTypeNode::make(new_fields); return TupleType(new_fields);
} }
} }
...@@ -178,7 +178,7 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { ...@@ -178,7 +178,7 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
if (new_args.same_as(type_rel->args)) { if (new_args.same_as(type_rel->args)) {
return GetRef<Type>(type_rel); return GetRef<Type>(type_rel);
} else { } else {
return TypeRelationNode::make(type_rel->func, return TypeRelation(type_rel->func,
new_args, new_args,
type_rel->num_inputs, type_rel->num_inputs,
type_rel->attrs); type_rel->attrs);
...@@ -195,7 +195,7 @@ Type TypeMutator::VisitType_(const TypeCallNode* op) { ...@@ -195,7 +195,7 @@ Type TypeMutator::VisitType_(const TypeCallNode* op) {
if (new_args.same_as(op->args) && new_func.same_as(op->func)) { if (new_args.same_as(op->args) && new_func.same_as(op->func)) {
return GetRef<TypeCall>(op); return GetRef<TypeCall>(op);
} else { } else {
return TypeCallNode::make(new_func, new_args); return TypeCall(new_func, new_args);
} }
} }
......
...@@ -55,7 +55,7 @@ bool TopKRel(const Array<Type>& types, ...@@ -55,7 +55,7 @@ bool TopKRel(const Array<Type>& types,
auto values_ty = TensorTypeNode::make(out_shape, data->dtype); auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
auto indices_ty = TensorTypeNode::make(out_shape, param->dtype); auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
if (param->ret_type == "both") { if (param->ret_type == "both") {
reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty})); reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
} else if (param->ret_type == "values") { } else if (param->ret_type == "values") {
reporter->Assign(types[1], values_ty); reporter->Assign(types[1], values_ty);
} else if (param->ret_type == "indices") { } else if (param->ret_type == "indices") {
......
...@@ -65,7 +65,7 @@ bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attr ...@@ -65,7 +65,7 @@ bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attr
auto mod = reporter->GetModule(); auto mod = reporter->GetModule();
CHECK(mod.defined()); CHECK(mod.defined());
auto storage_name = mod->GetGlobalTypeVar("Storage"); auto storage_name = mod->GetGlobalTypeVar("Storage");
auto storage = TypeCallNode::make(storage_name, {}); auto storage = TypeCall(storage_name, {});
reporter->Assign(types[2], storage); reporter->Assign(types[2], storage);
return true; return true;
} }
...@@ -136,7 +136,7 @@ bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs ...@@ -136,7 +136,7 @@ bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
auto mod = reporter->GetModule(); auto mod = reporter->GetModule();
CHECK(mod.defined()); CHECK(mod.defined());
auto storage_name = mod->GetGlobalTypeVar("Storage"); auto storage_name = mod->GetGlobalTypeVar("Storage");
auto storage = relay::TypeCallNode::make(storage_name, {}); auto storage = relay::TypeCall(storage_name, {});
reporter->Assign(types[0], storage); reporter->Assign(types[0], storage);
// Second argument should be shape tensor. // Second argument should be shape tensor.
auto tt = types[1].as<TensorTypeNode>(); auto tt = types[1].as<TensorTypeNode>();
...@@ -196,15 +196,15 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs ...@@ -196,15 +196,15 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
<< "internal invariant violated: invoke_tvm_op outputs must be a tuple"; << "internal invariant violated: invoke_tvm_op outputs must be a tuple";
Type ex_output; Type ex_output;
if (func_type->ret_type.as<TensorTypeNode>()) { if (func_type->ret_type.as<TensorTypeNode>()) {
ex_output = TupleTypeNode::make({func_type->ret_type}); ex_output = TupleType({func_type->ret_type});
} else { } else {
CHECK(func_type->ret_type.as<TupleTypeNode>()) << "should be tuple type"; CHECK(func_type->ret_type.as<TupleTypeNode>()) << "should be tuple type";
ex_output = func_type->ret_type; ex_output = func_type->ret_type;
} }
auto ex_input = TupleTypeNode::make(func_type->arg_types); auto ex_input = TupleType(func_type->arg_types);
reporter->Assign(ex_input, GetRef<Type>(input_type)); reporter->Assign(ex_input, GetRef<Type>(input_type));
reporter->Assign(ex_output, GetRef<Type>(output_type)); reporter->Assign(ex_output, GetRef<Type>(output_type));
reporter->Assign(types[3], TupleTypeNode::make({})); reporter->Assign(types[3], TupleType::Empty());
return true; return true;
} }
...@@ -236,7 +236,7 @@ bool KillRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, ...@@ -236,7 +236,7 @@ bool KillRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2u); CHECK_EQ(types.size(), 2u);
// TODO(@jroesch): should only support tensors. // TODO(@jroesch): should only support tensors.
reporter->Assign(types[1], TupleTypeNode::make({})); reporter->Assign(types[1], TupleType::Empty());
return true; return true;
} }
...@@ -297,7 +297,7 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, ...@@ -297,7 +297,7 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
auto func_type = types[0].as<FuncTypeNode>(); auto func_type = types[0].as<FuncTypeNode>();
CHECK(func_type != nullptr); CHECK(func_type != nullptr);
auto tuple = TupleTypeNode::make(func_type->arg_types); auto tuple = TupleType(func_type->arg_types);
auto in_types = FlattenType(tuple); auto in_types = FlattenType(tuple);
auto out_types = FlattenType(func_type->ret_type); auto out_types = FlattenType(func_type->ret_type);
...@@ -318,12 +318,12 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, ...@@ -318,12 +318,12 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
shape_func_outs.push_back(TensorTypeNode::make(rank_shape, DataType::Int(64))); shape_func_outs.push_back(TensorTypeNode::make(rank_shape, DataType::Int(64)));
} }
auto input_type = TupleTypeNode::make(shape_func_ins); auto input_type = TupleType(shape_func_ins);
auto output_type = TupleTypeNode::make(shape_func_outs); auto output_type = TupleType(shape_func_outs);
reporter->Assign(types[1], input_type); reporter->Assign(types[1], input_type);
reporter->Assign(types[2], output_type); reporter->Assign(types[2], output_type);
reporter->Assign(types[3], TupleTypeNode::make({})); reporter->Assign(types[3], TupleType::Empty());
return true; return true;
} }
......
...@@ -586,7 +586,7 @@ bool DropoutRel(const Array<Type>& types, ...@@ -586,7 +586,7 @@ bool DropoutRel(const Array<Type>& types,
// dropout returns the original tensor with dropout applied // dropout returns the original tensor with dropout applied
// and a mask tensor (1.0 where element not dropped, 0.0 where dropped) // and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
auto ret_type = TensorTypeNode::make(data->shape, data->dtype); auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>({ret_type, ret_type}))); reporter->Assign(types[1], TupleType(Array<Type>({ret_type, ret_type})));
return true; return true;
} }
...@@ -674,7 +674,7 @@ bool BatchNormRel(const Array<Type>& types, ...@@ -674,7 +674,7 @@ bool BatchNormRel(const Array<Type>& types,
fields.push_back(TensorTypeNode::make(data->shape, data->dtype)); fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
fields.push_back(vec_ty); fields.push_back(vec_ty);
fields.push_back(vec_ty); fields.push_back(vec_ty);
reporter->Assign(types[5], TupleTypeNode::make(Array<Type>(fields))); reporter->Assign(types[5], TupleType(Array<Type>(fields)));
return true; return true;
} }
......
...@@ -109,7 +109,7 @@ bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a ...@@ -109,7 +109,7 @@ bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype)); output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype)); output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(output_types))); reporter->Assign(types[3], TupleType(Array<Type>(output_types)));
return true; return true;
} }
......
...@@ -2150,7 +2150,7 @@ bool SplitRel(const Array<Type>& types, ...@@ -2150,7 +2150,7 @@ bool SplitRel(const Array<Type>& types,
auto vec_type = TensorTypeNode::make(oshape, data->dtype); auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type); fields.push_back(vec_type);
} }
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields))); reporter->Assign(types[1], TupleType(Array<Type>(fields)));
} else { } else {
auto indices = param->indices_or_sections.as<ArrayNode>()->data; auto indices = param->indices_or_sections.as<ArrayNode>()->data;
auto begin = IndexExpr(make_zero(DataType::Int(32))); auto begin = IndexExpr(make_zero(DataType::Int(32)));
...@@ -2170,7 +2170,7 @@ bool SplitRel(const Array<Type>& types, ...@@ -2170,7 +2170,7 @@ bool SplitRel(const Array<Type>& types,
oshape[axis] = data->shape[axis] - begin; oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype); auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type); fields.push_back(vec_type);
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields))); reporter->Assign(types[1], TupleType(Array<Type>(fields)));
} }
return true; return true;
} }
......
...@@ -125,7 +125,7 @@ bool MultiBoxTransformLocRel(const Array<Type>& types, ...@@ -125,7 +125,7 @@ bool MultiBoxTransformLocRel(const Array<Type>& types,
fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32))); fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32)));
// assign output type // assign output type
reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(fields))); reporter->Assign(types[3], TupleType(Array<Type>(fields)));
return true; return true;
} }
......
...@@ -44,7 +44,7 @@ bool GetValidCountRel(const Array<Type>& types, ...@@ -44,7 +44,7 @@ bool GetValidCountRel(const Array<Type>& types,
fields.push_back(TensorTypeNode::make(data->shape, data->dtype)); fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
// assign output type // assign output type
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields))); reporter->Assign(types[1], TupleType(Array<Type>(fields)));
return true; return true;
} }
......
...@@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) { ...@@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) {
public PatternMutator { public PatternMutator {
public: public:
TypeVar Fresh(const TypeVar& tv) { TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind); TypeVar ret = TypeVar(tv->name_hint, tv->kind);
type_rename_[tv] = ret; type_rename_[tv] = ret;
return ret; return ret;
} }
......
...@@ -42,7 +42,7 @@ class TypeVarReplacer : public TypeMutator { ...@@ -42,7 +42,7 @@ class TypeVarReplacer : public TypeMutator {
Type VisitType_(const TypeVarNode* type_var_node) final { Type VisitType_(const TypeVarNode* type_var_node) final {
const auto type_var = GetRef<TypeVar>(type_var_node); const auto type_var = GetRef<TypeVar>(type_var_node);
if (replace_map_.find(type_var) == replace_map_.end()) { if (replace_map_.find(type_var) == replace_map_.end()) {
replace_map_[type_var] = TypeVarNode::make("A", Kind::kType); replace_map_[type_var] = TypeVar("A", Kind::kType);
} }
return replace_map_[type_var]; return replace_map_[type_var];
} }
...@@ -109,7 +109,7 @@ class EtaExpander : public ExprMutator { ...@@ -109,7 +109,7 @@ class EtaExpander : public ExprMutator {
type_params.push_back(type_var_replacer_.VisitType(type_var)); type_params.push_back(type_var_replacer_.VisitType(type_var));
} }
Expr body = CallNode::make(cons, params, Attrs()); Expr body = CallNode::make(cons, params, Attrs());
Type ret_type = TypeCallNode::make(cons->belong_to, type_params); Type ret_type = TypeCall(cons->belong_to, type_params);
return FunctionNode::make( return FunctionNode::make(
Downcast<tvm::Array<Var>>(params), Downcast<tvm::Array<Var>>(params),
......
...@@ -73,10 +73,10 @@ Type WithGradientType(const Type& t) { ...@@ -73,10 +73,10 @@ Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking // TODO(M.K.): stricter checking
auto ty = t.as<FuncTypeNode>(); auto ty = t.as<FuncTypeNode>();
CHECK(ty) << "input should be a function"; CHECK(ty) << "input should be a function";
return FuncTypeNode::make(ty->arg_types, return FuncType(ty->arg_types,
TupleTypeNode::make({ TupleType({
ty->ret_type, ty->ret_type,
TupleTypeNode::make(ty->arg_types)}), {}, {}); TupleType(ty->arg_types)}), {}, {});
} }
//! \brief if the expression is a GlobalVar, transform to it's expression. //! \brief if the expression is a GlobalVar, transform to it's expression.
...@@ -219,7 +219,7 @@ Type GradRetType(const Function& f) { ...@@ -219,7 +219,7 @@ Type GradRetType(const Function& f) {
vt.push_back(p->type_annotation); vt.push_back(p->type_annotation);
} }
return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); return TupleType({f->ret_type, TupleType(vt)});
} }
Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
...@@ -265,7 +265,7 @@ TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient") ...@@ -265,7 +265,7 @@ TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
struct ReverseADType : TypeMutator { struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final { Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn); Type t = GetRef<Type>(ttn);
return TupleTypeNode::make({t, RefTypeNode::make(t)}); return TupleType({t, RefTypeNode::make(t)});
} }
}; };
...@@ -299,7 +299,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f, ...@@ -299,7 +299,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
types.push_back(field->checked_type_); types.push_back(field->checked_type_);
} }
auto ret = TupleNode::make(fields); auto ret = TupleNode::make(fields);
ret->checked_type_ = TupleTypeNode::make(types); ret->checked_type_ = TupleType(types);
return std::move(ret); return std::move(ret);
} else { } else {
LOG(FATAL) << "unsupported input/output type: " << tt; LOG(FATAL) << "unsupported input/output type: " << tt;
...@@ -385,7 +385,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { ...@@ -385,7 +385,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
} }
Expr BPEmpty() { Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {}); Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleType::Empty(), {});
return RefCreateNode::make(unitF); return RefCreateNode::make(unitF);
} }
...@@ -426,7 +426,7 @@ struct ReverseAD : ExprMutator { ...@@ -426,7 +426,7 @@ struct ReverseAD : ExprMutator {
ll->Push(CallNode::make(RefReadNode::make(dup_bp), {})); ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
return CallNode::make(bpv, {}); return CallNode::make(bpv, {});
}), }),
TupleTypeNode::make({}), TupleType::Empty(),
{}); {});
ll->Push(RefWriteNode::make(bp, nbp)); ll->Push(RefWriteNode::make(bp, nbp));
return ret; return ret;
...@@ -468,7 +468,7 @@ struct ReverseAD : ExprMutator { ...@@ -468,7 +468,7 @@ struct ReverseAD : ExprMutator {
} }
return CallNode::make(bpv, {}); return CallNode::make(bpv, {});
}), }),
TupleTypeNode::make({}), TupleType::Empty(),
{}); {});
ll->Push(RefWriteNode::make(bp, nbp)); ll->Push(RefWriteNode::make(bp, nbp));
return ret; return ret;
......
...@@ -63,7 +63,7 @@ namespace relay { ...@@ -63,7 +63,7 @@ namespace relay {
// we assume the data type has no closure - no idea how to look into datatype right now. // we assume the data type has no closure - no idea how to look into datatype right now.
Type Arrow(const Type& l, const Type& r) { Type Arrow(const Type& l, const Type& r) {
return FuncTypeNode::make({l}, r, {}, {}); return FuncType({l}, r, {}, {});
} }
Type CPSType(const Type& t, const TypeVar& answer); Type CPSType(const Type& t, const TypeVar& answer);
...@@ -74,7 +74,7 @@ FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) { ...@@ -74,7 +74,7 @@ FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) {
new_arg_types.push_back(CPSType(t, answer)); new_arg_types.push_back(CPSType(t, answer));
} }
new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer)); new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer));
return FuncTypeNode::make(new_arg_types, answer, f->type_params, f->type_constraints); return FuncType(new_arg_types, answer, f->type_params, f->type_constraints);
} }
Type CPSType(const Type& t, const TypeVar& answer) { Type CPSType(const Type& t, const TypeVar& answer) {
...@@ -302,7 +302,7 @@ Function ToCPS(const Function& f, ...@@ -302,7 +302,7 @@ Function ToCPS(const Function& f,
} }
Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
TypeVar answer = TypeVarNode::make("answer", kType); TypeVar answer = TypeVar("answer", kType);
VarMap var; VarMap var;
struct Remapper : ExprVisitor, PatternVisitor { struct Remapper : ExprVisitor, PatternVisitor {
Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { } Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { }
...@@ -348,7 +348,7 @@ Function UnCPS(const Function& f) { ...@@ -348,7 +348,7 @@ Function UnCPS(const Function& f) {
auto new_ret_type = Type(cont_type->arg_types[0]); auto new_ret_type = Type(cont_type->arg_types[0]);
std::vector<TypeVar> new_type_params; std::vector<TypeVar> new_type_params;
for (const auto& tp : f->type_params) { for (const auto& tp : f->type_params) {
new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind)); new_type_params.push_back(TypeVar(tp->name_hint, tp->kind));
} }
auto answer_type = new_type_params.back(); auto answer_type = new_type_params.back();
new_type_params.pop_back(); new_type_params.pop_back();
......
...@@ -206,7 +206,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -206,7 +206,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (Expr field : op->fields) { for (Expr field : op->fields) {
types.push_back(GetType(field)); types.push_back(GetType(field));
} }
return TupleTypeNode::make(types); return TupleType(types);
} }
Type VisitExpr_(const TupleGetItemNode* op) final { Type VisitExpr_(const TupleGetItemNode* op) final {
...@@ -218,7 +218,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -218,7 +218,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type rtype = IncompleteTypeNode::make(Kind::kType); Type rtype = IncompleteTypeNode::make(Kind::kType);
auto attrs = make_object<TupleGetItemAttrs>(); auto attrs = make_object<TupleGetItemAttrs>();
attrs->index = op->index; attrs->index = op->index;
solver_.AddConstraint(TypeRelationNode::make( solver_.AddConstraint(TypeRelation(
tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef<TupleGetItem>(op)); tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef<TupleGetItem>(op));
return rtype; return rtype;
} }
...@@ -235,7 +235,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -235,7 +235,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (size_t i = 0; i < td->type_vars.size(); i++) { for (size_t i = 0; i < td->type_vars.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
} }
Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); Type expected = TypeCall(con->constructor->belong_to, unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(con)); Type unified = Unify(t, expected, GetRef<ObjectRef>(con));
auto* tc = unified.as<TypeCallNode>(); auto* tc = unified.as<TypeCallNode>();
...@@ -277,7 +277,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -277,7 +277,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (size_t i = 0; i < tup->patterns.size(); i++) { for (size_t i = 0; i < tup->patterns.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
} }
Type expected = TupleTypeNode::make(unknown_args); Type expected = TupleType(unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(tup)); Type unified = Unify(t, expected, GetRef<ObjectRef>(tup));
auto* tt = unified.as<TupleTypeNode>(); auto* tt = unified.as<TupleTypeNode>();
...@@ -388,7 +388,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -388,7 +388,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type rtype = IncompleteTypeNode::make(Kind::kType); Type rtype = IncompleteTypeNode::make(Kind::kType);
arg_types.push_back(rtype); arg_types.push_back(rtype);
// we can do simple replacement here // we can do simple replacement here
solver_.AddConstraint(TypeRelationNode::make( solver_.AddConstraint(TypeRelation(
rel->func, arg_types, arg_types.size() - 1, attrs), loc); rel->func, arg_types, arg_types.size() - 1, attrs), loc);
return rtype; return rtype;
} }
...@@ -418,7 +418,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -418,7 +418,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
ret_type = IncompleteTypeNode::make(Kind::kType); ret_type = IncompleteTypeNode::make(Kind::kType);
} }
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, Type inst_ty = FuncType(fn_ty->arg_types,
ret_type, {}, ret_type, {},
fn_ty->type_constraints); fn_ty->type_constraints);
inst_ty = Bind(inst_ty, subst_map); inst_ty = Bind(inst_ty, subst_map);
...@@ -467,7 +467,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -467,7 +467,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// with an unknown return type // with an unknown return type
if (inc_ty_node != nullptr) { if (inc_ty_node != nullptr) {
Type ret_type = IncompleteTypeNode::make(Kind::kType); Type ret_type = IncompleteTypeNode::make(Kind::kType);
Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); Type func_type = FuncType(arg_types, ret_type, {}, {});
Type unified = this->Unify(ftype, func_type, GetRef<Call>(call)); Type unified = this->Unify(ftype, func_type, GetRef<Call>(call));
fn_ty_node = unified.as<FuncTypeNode>(); fn_ty_node = unified.as<FuncTypeNode>();
} }
...@@ -513,7 +513,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -513,7 +513,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (auto cs : fn_ty->type_constraints) { for (auto cs : fn_ty->type_constraints) {
if (const auto* tr = cs.as<TypeRelationNode>()) { if (const auto* tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint( solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs), TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs),
GetRef<Call>(call)); GetRef<Call>(call));
} else { } else {
solver_.AddConstraint(cs, GetRef<Call>(call)); solver_.AddConstraint(cs, GetRef<Call>(call));
...@@ -557,7 +557,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -557,7 +557,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f)); rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f));
} }
CHECK(rtype.defined()); CHECK(rtype.defined());
auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); auto ret = FuncType(arg_types, rtype, f->type_params, {});
return solver_.Resolve(ret); return solver_.Resolve(ret);
} }
...@@ -575,7 +575,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -575,7 +575,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type it = IncompleteTypeNode::make(Kind::kType); Type it = IncompleteTypeNode::make(Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op)); this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
this->Unify(GetType(op->value), it, GetRef<RefWrite>(op)); this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
return TupleTypeNode::make({}); return TupleType::Empty();
} }
Type VisitExpr_(const ConstructorNode* c) final { Type VisitExpr_(const ConstructorNode* c) final {
...@@ -587,7 +587,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -587,7 +587,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (const auto & t : td->type_vars) { for (const auto & t : td->type_vars) {
types.push_back(t); types.push_back(t);
} }
return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types), return FuncType(c->inputs, TypeCall(c->belong_to, types),
td->type_vars, {}); td->type_vars, {});
} }
......
...@@ -286,7 +286,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -286,7 +286,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
Type field = Unify(tt1->fields[i], tt2->fields[i]); Type field = Unify(tt1->fields[i], tt2->fields[i]);
new_fields.push_back(field); new_fields.push_back(field);
} }
return TupleTypeNode::make(new_fields); return TupleType(new_fields);
} }
Type VisitType_(const FuncTypeNode* op, const Type& tn) final { Type VisitType_(const FuncTypeNode* op, const Type& tn) final {
...@@ -314,7 +314,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -314,7 +314,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType)); subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
} }
FuncType ft = FuncTypeNode::make(op->arg_types, FuncType ft = FuncType(op->arg_types,
op->ret_type, op->ret_type,
ft_type_params, ft_type_params,
op->type_constraints); op->type_constraints);
...@@ -339,7 +339,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -339,7 +339,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
type_constraints.push_back(GetRef<TypeConstraint>(tcn)); type_constraints.push_back(GetRef<TypeConstraint>(tcn));
} }
return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints); return FuncType(arg_types, ret_type, ft2->type_params, type_constraints);
} }
Type VisitType_(const RefTypeNode* op, const Type& tn) final { Type VisitType_(const RefTypeNode* op, const Type& tn) final {
...@@ -361,7 +361,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -361,7 +361,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
for (size_t i = 0; i < op->args.size(); i++) { for (size_t i = 0; i < op->args.size(); i++) {
args.push_back(Unify(op->args[i], tcn->args[i])); args.push_back(Unify(op->args[i], tcn->args[i]));
} }
return TypeCallNode::make(func, args); return TypeCall(func, args);
} }
private: private:
......
...@@ -37,7 +37,7 @@ TEST(Relay, SelfReference) { ...@@ -37,7 +37,7 @@ TEST(Relay, SelfReference) {
mod = relay::transform::InferType()(mod); mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main"); auto type_fx = mod->Lookup("main");
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {}); auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(relay::AlphaEqual(type_fx->checked_type(), expected)); CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment