Unverified Commit f4c4fde4 by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Polish ir/type (#4705)

- Use consistent constructor style to construct objects.
- Move env_func to ir as it is mainly used to construct IRs.
- Make docs consistent.
parent 3f2abfbc
......@@ -18,7 +18,7 @@
*/
/*!
* \file tvm/relay/adt.h
* \file tvm/ir/adt.h
* \brief Algebraic data type definitions.
*
* We adopt relay's ADT definition as a unified class
......
......@@ -18,21 +18,24 @@
*/
/*!
* \file tvm/node/env_func.h
* \brief Serializable global function.
* \file tvm/ir/env_func.h
* \brief Serializable global function used in IR.
*/
#ifndef TVM_NODE_ENV_FUNC_H_
#define TVM_NODE_ENV_FUNC_H_
#ifndef TVM_IR_ENV_FUNC_H_
#define TVM_IR_ENV_FUNC_H_
#include <tvm/node/reflection.h>
#include <string>
#include <utility>
namespace tvm {
/*!
* \brief Node container of EnvFunc
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \sa EnvFunc
*/
class EnvFuncNode : public Object {
......@@ -53,11 +56,8 @@ class EnvFuncNode : public Object {
};
/*!
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \brief Managed reference to EnvFuncNode.
* \sa EnvFuncNode
*/
class EnvFunc : public ObjectRef {
public:
......@@ -140,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
};
} // namespace tvm
#endif // TVM_NODE_ENV_FUNC_H_
#endif // TVM_IR_ENV_FUNC_H_
......@@ -198,7 +198,7 @@ class GlobalVarNode : public RelayExprNode {
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......
......@@ -48,13 +48,13 @@ class IRModule;
class IRModuleNode : public Object {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, BaseFunc> functions;
Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
Map<GlobalTypeVar, TypeData> type_definitions;
IRModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_);
......@@ -146,7 +146,7 @@ class IRModuleNode : public Object {
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
TVM_DLL tvm::Array<GlobalVar> GetGlobalVars() const;
TVM_DLL Array<GlobalVar> GetGlobalVars() const;
/*!
* \brief Look up a global function by its name.
......@@ -159,7 +159,7 @@ class IRModuleNode : public Object {
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
TVM_DLL tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;
/*!
* \brief Look up a global function by its variable.
......@@ -235,12 +235,12 @@ class IRModuleNode : public Object {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_var_map_;
Map<std::string, GlobalVar> global_var_map_;
/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
Map<std::string, GlobalTypeVar> global_type_var_map_;
/*! \brief A map from constructor tags to constructor objects
* for convenient access
......@@ -266,8 +266,8 @@ class IRModule : public ObjectRef {
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
*/
TVM_DLL explicit IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions = {},
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
/*! \brief default constructor */
IRModule() {}
......@@ -296,8 +296,8 @@ class IRModule : public ObjectRef {
*/
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
const Map<GlobalVar, BaseFunc>& global_funcs = {},
const Map<GlobalTypeVar, TypeData>& type_definitions = {});
/*!
* \brief Parse text format source file into an IRModule.
......
......@@ -91,7 +91,7 @@ class OpNode : public RelayExprNode {
*/
int32_t support_level = 10;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
......@@ -476,7 +476,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeVarNode::make(name, TypeKind::kType);
auto param = TypeVar(name, TypeKind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}
......@@ -484,7 +484,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
Array<Type> ty_call_args = arg_types;
// Add output type.
auto out_param = TypeVarNode::make("out", TypeKind::kType);
auto out_param = TypeVar("out", TypeKind::kType);
type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param);
......@@ -498,13 +498,13 @@ inline OpRegistry& OpRegistry::add_type_rel(
// A common example is sum(x, axis), where the choice of axis
// can affect the type of the function.
TypeConstraint type_rel =
TypeRelationNode::make(env_type_rel_func,
TypeRelation(env_type_rel_func,
ty_call_args,
arg_types.size(),
Attrs());
auto func_type =
FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
FuncType(arg_types, out_param, type_params, {type_rel});
get()->op_type = func_type;
......
......@@ -84,13 +84,13 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
tvm::Array<tvm::PrimExpr> required_pass;
Array<PrimExpr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::PrimExpr> disabled_pass;
Array<PrimExpr> disabled_pass;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
......@@ -118,7 +118,7 @@ class PassContextNode : public Object {
class PassContext : public ObjectRef {
public:
PassContext() {}
explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
......@@ -158,7 +158,7 @@ class PassContext : public ObjectRef {
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
friend class With<PassContext>;
};
/*!
......@@ -174,11 +174,11 @@ class PassInfoNode : public Object {
std::string name;
/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::PrimExpr> required;
Array<PrimExpr> required;
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
......@@ -202,7 +202,7 @@ class PassInfo : public ObjectRef {
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
tvm::Array<tvm::PrimExpr> required);
Array<PrimExpr> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
......@@ -241,7 +241,7 @@ class PassNode : public Object {
virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) {}
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
......@@ -289,7 +289,7 @@ class Sequential : public Pass {
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
TVM_DLL Sequential(Array<Pass> passes, PassInfo pass_info);
/*!
* \brief The constructor of `Sequential`.
......@@ -299,10 +299,10 @@ class Sequential : public Pass {
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");
Sequential() = default;
explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {}
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
const SequentialNode* operator->() const;
using ContainerType = Sequential;
......@@ -322,7 +322,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const Array<PrimExpr>& required);
} // namespace transform
} // namespace tvm
......
......@@ -50,15 +50,26 @@
#define TVM_IR_TYPE_H_
#include <tvm/runtime/object.h>
#include <tvm/runtime/data_type.h>
#include <tvm/node/node.h>
#include <tvm/node/env_func.h>
#include <tvm/node/container.h>
#include <tvm/ir/span.h>
#include <string>
namespace tvm {
/*! \brief Base type of all the types. */
/*!
* \brief Type is the base type of all types.
*
* Relay's type system contains following subclasses:
*
* - PrimType: type of primitive type values used in the low-level IR.
* - FuncType: type of a function.
* - TensorType: type of certain Tensor values in the expression.
*
* There are also advanced types to support generic(polymorphic types).
* \sa Type
*/
class TypeNode : public Object {
public:
/*!
......@@ -72,29 +83,58 @@ class TypeNode : public Object {
};
/*!
* \brief Type is the base type of all types.
*
* Relay's type system contains following two key concepts:
*
* - PrimitiveType: type of primitive type values used in the low-level IR.
* - TensorType: type of certain Tensor values in the expression.
* - FunctionType: the type of the function.
*
* There are also advanced types to support generic(polymorphic types),
* which can be ignored when first reading the code base.
* \brief Managed reference to TypeNode.
* \sa TypeNode
*/
class Type : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode);
};
/*!
* \brief Primitive data types used in the low-level IR.
*
* PrimType represents POD-values and handles that are
* not automatically managed by the runtime.
*
* \sa PrimType
*/
class PrimTypeNode : public TypeNode {
public:
/*!
* \brief The corresponding dtype field.
*/
runtime::DataType dtype;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
}
static constexpr const char* _type_key = "relay.PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
/*!
* \brief Managed reference to PrimTypeNode.
* \sa PrimTypeNode
*/
class PrimType : public Type {
public:
/*!
* \brief Constructor
* \param dtype The corresponding dtype.
*/
TVM_DLL PrimType(runtime::DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
/*! \brief Possible kinds of TypeVars. */
enum TypeKind : int {
kType = 0,
/*! \brief Template variable in shape expression. */
kShapeVar = 1,
kBaseType = 2,
kShape = 3,
kConstraint = 4,
kAdtHandle = 5,
kTypeData = 6
......@@ -115,10 +155,8 @@ enum TypeKind : int {
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeVarNode The actual container class of TypeVar
* \sa TypeVar, TypeKind
*/
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*!
......@@ -130,28 +168,36 @@ class TypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */
TypeKind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static TypeVar make(std::string name, TypeKind kind);
static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
/*!
* \brief Managed reference to TypeVarNode
* \sa TypeVarNode
*/
class TypeVar : public Type {
public:
/*!
* \brief Constructor
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL TypeVar(std::string name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
/*!
* \brief A global type variable that is used for defining new types or type aliases.
* \sa GlobalTypeVar
*/
class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode {
public:
/*!
......@@ -163,47 +209,98 @@ class GlobalTypeVarNode : public TypeNode {
/*! \brief The kind of type parameter */
TypeKind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
}
TVM_DLL static GlobalTypeVar make(std::string name, TypeKind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
/*!
* \brief Managed reference to GlobalTypeVarNode
* \sa GlobalTypeVarNode
*/
class GlobalTypeVar : public Type {
public:
/*!
* \brief Constructor
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
/*!
* \brief Potential Constraints in the type.
* \note This is reserved for future use.
* \brief The type of tuple values.
* \sa TupleType
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
/*!
* \brief Managed reference to TupleTypeNode.
* \sa TupleTypeNode.
*/
class TupleType : public Type {
public:
/*!
* \brief Constructor
* \param fields Fields in the tuple.
*/
TVM_DLL explicit TupleType(Array<Type> fields);
/*!
* \brief Create an empty tuple type that constains nothing.
* \return A empty tuple type.
*/
TVM_DLL TupleType static Empty();
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};
/*!
* \brief Potential Constraints in a function.
* \sa TypeConstraint
*/
class TypeConstraint;
/*! \brief TypeConstraint container node. */
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
/*!
* \brief Managed reference to TypeConstraintNode.
* \sa TypeConstraintNode, TypeRelation
*/
class TypeConstraint : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
};
class FuncType;
/*!
* \brief Function type in Relay.
* \brief Function type.
*
* Relay support polymorphic function type.
* We support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa TypeVar, TypeConstraint
* \sa FuncType, TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
......@@ -221,7 +318,7 @@ class FuncTypeNode : public TypeNode {
*/
Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
......@@ -229,17 +326,29 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}
TVM_DLL static FuncType make(Array<Type> arg_types,
Type ret_type,
Array<TypeVar> type_params,
Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
/*!
* \brief Managed reference to FuncTypeNode.
* \sa FuncTypeNode
*/
class FuncType : public Type {
public:
/*!
* \brief Constructor
* \param arg_types The types of the arguments.
* \param ret_type The type of the return value.
* \param type_params The type parameters.
* \param type_constraints The type constraints.
* \sa FuncTypeNode for more docs about these fields.
*/
TVM_DLL FuncType(Array<Type> arg_types,
Type ret_type,
Array<TypeVar> type_params,
Array<TypeConstraint> type_constraints);
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
......
......@@ -19,19 +19,56 @@
/*!
* \file tvm/ir/type_relation.h
* \brief Type relation function for type checking.
* \brief Type relation and function for type inference(checking).
*/
#ifndef TVM_IR_TYPE_RELATION_H_
#define TVM_IR_TYPE_RELATION_H_
#include <tvm/ir/type.h>
#include <tvm/ir/module.h>
#include <tvm/ir/env_func.h>
#include <tvm/attrs.h>
namespace tvm {
// TODO(tqchen): remove after migrate Module to ir.
class IRModule;
/*!
* \brief Type function application.
* \sa TypeCall
*/
class TypeCallNode : public TypeNode {
public:
/*!
* \brief The type-level function (ADT that takes type params).
*/
Type func;
/*! \brief The arguments. */
Array<Type> args;
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
/*!
* \brief Managed reference to TypeCallNode.
* \sa TypeCallNode
*/
class TypeCall : public Type {
public:
/*!
* \brief Constructor
* \param func The type function to apply.
* \param args The arguments to the type function.
*/
TVM_DLL TypeCall(Type func, Array<Type> args);
TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode);
};
/*!
* \brief reporter that reports back to the
......@@ -78,7 +115,7 @@ class TypeReporterNode : public Object {
TVM_DLL virtual IRModule GetModule() = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) {}
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
......@@ -91,7 +128,7 @@ class TypeReporterNode : public Object {
class TypeReporter : public ObjectRef {
public:
TypeReporter() {}
explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {
}
TypeReporterNode* operator->() const {
return const_cast<TypeReporterNode*>(
......@@ -127,12 +164,11 @@ using TypeRelationFn =
/*!
* \brief User defined type relation, is an input-output relation on types.
*/
class TypeRelation;
/*!
* \brief TypeRelation container.
* \note This node is not directly serializable.
* The type function need to be lookedup in the module.
*
* TypeRelation is more generalized than type call as it allows inference
* of both inputs and outputs.
*
* \sa TypeRelation
*/
class TypeRelationNode : public TypeConstraintNode {
public:
......@@ -143,13 +179,13 @@ class TypeRelationNode : public TypeConstraintNode {
*/
TypeRelationFn func;
/*! \brief The type arguments to the type function. */
tvm::Array<Type> args;
Array<Type> args;
/*! \brief Number of inputs arguments */
int num_inputs;
/*! \brief Attributes to the relation function */
Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) {
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
......@@ -157,17 +193,29 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}
TVM_DLL static TypeRelation make(TypeRelationFn func,
Array<Type> args,
int num_args,
Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
/*!
* \brief Managed reference to TypeRelationNode.
* \sa TypeRelationNode
*/
class TypeRelation : public TypeConstraint {
public:
/*!
* \brief Constructor
* \param func The relation function.
* \param args The arguments to the type relation.
* \param num_inputs Number of inputs.
* \param attrs Attributes to the relation function.
* \sa TypeRelationNode for more docs about these fields.
*/
TVM_DLL TypeRelation(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs);
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
};
} // namespace tvm
......
......@@ -28,7 +28,7 @@
#include <tvm/ir/type_relation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir.h>
#include <string>
......@@ -39,6 +39,8 @@
namespace tvm {
namespace relay {
// namespace update for backward compact
// will be removed later.
using Any = tvm::ir::AnyNode;
using Kind = TypeKind;
using Type = tvm::Type;
......@@ -47,10 +49,14 @@ using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn;
......@@ -119,37 +125,6 @@ class TensorType : public Type {
};
/*!
* \brief Type application.
*/
class TypeCall;
/*! \brief TypeCall container node */
class TypeCallNode : public TypeNode {
public:
/*!
* \brief The type-level function (ADT that takes type params).
*/
Type func;
/*! \brief The arguments. */
tvm::Array<Type> args;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("span", &span);
}
TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args);
static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
class TypeCall : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode);
};
/*!
* \brief IncompleteType.
* This is intermediate values that is used during type inference.
*
......@@ -181,36 +156,6 @@ class IncompleteType : public Type {
};
/*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
class TupleType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};
/*!
* \brief The type of reference values.
*/
class RefType;
......
......@@ -25,7 +25,7 @@
#include <tvm/tensor.h>
#include <tvm/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/env_func.h>
#include <tvm/ir/env_func.h>
#include <tvm/packed_func_ext.h>
......
......@@ -20,7 +20,7 @@
/*!
* \file env_func.cc
*/
#include <tvm/node/env_func.h>
#include <tvm/ir/env_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/expr.h>
......
......@@ -27,18 +27,38 @@
namespace tvm {
TypeVar TypeVarNode::make(std::string name, TypeKind kind) {
PrimType::PrimType(runtime::DataType dtype) {
ObjectPtr<PrimTypeNode> n = make_object<PrimTypeNode>();
n->dtype = dtype;
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PrimTypeNode);
TVM_REGISTER_GLOBAL("relay._make.PrimType")
.set_body_typed([](runtime::DataType dtype) {
return PrimType(dtype);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const PrimTypeNode*>(ref.get());
p->stream << node->dtype;
});
TypeVar::TypeVar(std::string name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
return TypeVar(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.TypeVar")
.set_body_typed([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<TypeKind>(kind));
return TypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
......@@ -48,18 +68,19 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< node->kind << ")";
});
GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) {
GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
return GlobalTypeVar(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
.set_body_typed([](std::string name, int kind) {
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
......@@ -69,22 +90,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< node->kind << ")";
});
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
FuncType::FuncType(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
n->arg_types = std::move(arg_types);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->type_constraints = std::move(type_constraints);
return FuncType(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make);
.set_body_typed([](tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
return FuncType(arg_types, ret_type, type_params, type_constraints);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
......@@ -94,4 +120,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< node->type_constraints << ")";
});
TupleType::TupleType(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
n->fields = std::move(fields);
data_ = std::move(n);
}
TupleType TupleType::Empty() {
return TupleType(Array<Type>());
}
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TupleType")
.set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});
} // namespace tvm
......@@ -27,22 +27,49 @@
#include <tvm/packed_func_ext.h>
namespace tvm {
TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
TypeCall::TypeCall(Type func, tvm::Array<Type> args) {
ObjectPtr<TypeCallNode> n = make_object<TypeCallNode>();
n->func = std::move(func);
n->args = std::move(args);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_GLOBAL("relay._make.TypeCall")
.set_body_typed([](Type func, Array<Type> type) {
return TypeCall(func, type);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
});
TypeRelation::TypeRelation(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>();
n->func = std::move(func);
n->args = std::move(args);
n->num_inputs = num_inputs;
n->attrs = std::move(attrs);
return TypeRelation(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
.set_body_typed(TypeRelationNode::make);
.set_body_typed([](TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
return TypeRelation(func, args, num_inputs, attrs);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
......
......@@ -246,7 +246,7 @@ class ScheduleGetter :
new_fields.push_back(field);
}
}
call_node_type = TupleTypeNode::make(new_fields);
call_node_type = TupleType(new_fields);
}
CHECK(call_node->op.as<OpNode>())
......
......@@ -135,7 +135,7 @@ FuncType FunctionNode::func_type_annotation() const {
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteTypeNode::make(Kind::kType);
return FuncTypeNode::make(param_types, ret_type, this->type_params, {});
return FuncType(param_types, ret_type, this->type_params, {});
}
bool FunctionNode::IsPrimitive() const {
......
......@@ -63,25 +63,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
ObjectPtr<TypeCallNode> n = make_object<TypeCallNode>();
n->func = std::move(func);
n->args = std::move(args);
return TypeCall(n);
}
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_GLOBAL("relay._make.TypeCall")
.set_body_typed(TypeCallNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
});
IncompleteType IncompleteTypeNode::make(Kind kind) {
auto n = make_object<IncompleteTypeNode>();
n->kind = std::move(kind);
......@@ -102,23 +83,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
});
TupleType TupleTypeNode::make(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
n->fields = std::move(fields);
return TupleType(n);
}
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TupleType")
.set_body_typed(TupleTypeNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});
RefType RefTypeNode::make(Type value) {
ObjectPtr<RefTypeNode> n = make_object<RefTypeNode>();
n->value = std::move(value);
......
......@@ -154,7 +154,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
changed = changed || !new_ret_type.same_as(op->ret_type);
if (!changed) return GetRef<Type>(op);
return FuncTypeNode::make(new_args,
return FuncType(new_args,
new_ret_type,
type_params,
type_constraints);
......@@ -165,7 +165,7 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) {
if (new_fields.same_as(op->fields)) {
return GetRef<Type>(op);
} else {
return TupleTypeNode::make(new_fields);
return TupleType(new_fields);
}
}
......@@ -178,7 +178,7 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
if (new_args.same_as(type_rel->args)) {
return GetRef<Type>(type_rel);
} else {
return TypeRelationNode::make(type_rel->func,
return TypeRelation(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
......@@ -195,7 +195,7 @@ Type TypeMutator::VisitType_(const TypeCallNode* op) {
if (new_args.same_as(op->args) && new_func.same_as(op->func)) {
return GetRef<TypeCall>(op);
} else {
return TypeCallNode::make(new_func, new_args);
return TypeCall(new_func, new_args);
}
}
......
......@@ -55,7 +55,7 @@ bool TopKRel(const Array<Type>& types,
auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
if (param->ret_type == "both") {
reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty}));
reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
reporter->Assign(types[1], values_ty);
} else if (param->ret_type == "indices") {
......
......@@ -65,7 +65,7 @@ bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attr
auto mod = reporter->GetModule();
CHECK(mod.defined());
auto storage_name = mod->GetGlobalTypeVar("Storage");
auto storage = TypeCallNode::make(storage_name, {});
auto storage = TypeCall(storage_name, {});
reporter->Assign(types[2], storage);
return true;
}
......@@ -136,7 +136,7 @@ bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
auto mod = reporter->GetModule();
CHECK(mod.defined());
auto storage_name = mod->GetGlobalTypeVar("Storage");
auto storage = relay::TypeCallNode::make(storage_name, {});
auto storage = relay::TypeCall(storage_name, {});
reporter->Assign(types[0], storage);
// Second argument should be shape tensor.
auto tt = types[1].as<TensorTypeNode>();
......@@ -196,15 +196,15 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
<< "internal invariant violated: invoke_tvm_op outputs must be a tuple";
Type ex_output;
if (func_type->ret_type.as<TensorTypeNode>()) {
ex_output = TupleTypeNode::make({func_type->ret_type});
ex_output = TupleType({func_type->ret_type});
} else {
CHECK(func_type->ret_type.as<TupleTypeNode>()) << "should be tuple type";
ex_output = func_type->ret_type;
}
auto ex_input = TupleTypeNode::make(func_type->arg_types);
auto ex_input = TupleType(func_type->arg_types);
reporter->Assign(ex_input, GetRef<Type>(input_type));
reporter->Assign(ex_output, GetRef<Type>(output_type));
reporter->Assign(types[3], TupleTypeNode::make({}));
reporter->Assign(types[3], TupleType::Empty());
return true;
}
......@@ -236,7 +236,7 @@ bool KillRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2u);
// TODO(@jroesch): should only support tensors.
reporter->Assign(types[1], TupleTypeNode::make({}));
reporter->Assign(types[1], TupleType::Empty());
return true;
}
......@@ -297,7 +297,7 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
auto func_type = types[0].as<FuncTypeNode>();
CHECK(func_type != nullptr);
auto tuple = TupleTypeNode::make(func_type->arg_types);
auto tuple = TupleType(func_type->arg_types);
auto in_types = FlattenType(tuple);
auto out_types = FlattenType(func_type->ret_type);
......@@ -318,12 +318,12 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
shape_func_outs.push_back(TensorTypeNode::make(rank_shape, DataType::Int(64)));
}
auto input_type = TupleTypeNode::make(shape_func_ins);
auto output_type = TupleTypeNode::make(shape_func_outs);
auto input_type = TupleType(shape_func_ins);
auto output_type = TupleType(shape_func_outs);
reporter->Assign(types[1], input_type);
reporter->Assign(types[2], output_type);
reporter->Assign(types[3], TupleTypeNode::make({}));
reporter->Assign(types[3], TupleType::Empty());
return true;
}
......
......@@ -586,7 +586,7 @@ bool DropoutRel(const Array<Type>& types,
// dropout returns the original tensor with dropout applied
// and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>({ret_type, ret_type})));
reporter->Assign(types[1], TupleType(Array<Type>({ret_type, ret_type})));
return true;
}
......@@ -674,7 +674,7 @@ bool BatchNormRel(const Array<Type>& types,
fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
fields.push_back(vec_ty);
fields.push_back(vec_ty);
reporter->Assign(types[5], TupleTypeNode::make(Array<Type>(fields)));
reporter->Assign(types[5], TupleType(Array<Type>(fields)));
return true;
}
......
......@@ -109,7 +109,7 @@ bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(output_types)));
reporter->Assign(types[3], TupleType(Array<Type>(output_types)));
return true;
}
......
......@@ -2150,7 +2150,7 @@ bool SplitRel(const Array<Type>& types,
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
} else {
auto indices = param->indices_or_sections.as<ArrayNode>()->data;
auto begin = IndexExpr(make_zero(DataType::Int(32)));
......@@ -2170,7 +2170,7 @@ bool SplitRel(const Array<Type>& types,
oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
}
return true;
}
......
......@@ -125,7 +125,7 @@ bool MultiBoxTransformLocRel(const Array<Type>& types,
fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32)));
// assign output type
reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(fields)));
reporter->Assign(types[3], TupleType(Array<Type>(fields)));
return true;
}
......
......@@ -44,7 +44,7 @@ bool GetValidCountRel(const Array<Type>& types,
fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
// assign output type
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
return true;
}
......
......@@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) {
public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind);
TypeVar ret = TypeVar(tv->name_hint, tv->kind);
type_rename_[tv] = ret;
return ret;
}
......
......@@ -42,7 +42,7 @@ class TypeVarReplacer : public TypeMutator {
Type VisitType_(const TypeVarNode* type_var_node) final {
const auto type_var = GetRef<TypeVar>(type_var_node);
if (replace_map_.find(type_var) == replace_map_.end()) {
replace_map_[type_var] = TypeVarNode::make("A", Kind::kType);
replace_map_[type_var] = TypeVar("A", Kind::kType);
}
return replace_map_[type_var];
}
......@@ -109,7 +109,7 @@ class EtaExpander : public ExprMutator {
type_params.push_back(type_var_replacer_.VisitType(type_var));
}
Expr body = CallNode::make(cons, params, Attrs());
Type ret_type = TypeCallNode::make(cons->belong_to, type_params);
Type ret_type = TypeCall(cons->belong_to, type_params);
return FunctionNode::make(
Downcast<tvm::Array<Var>>(params),
......
......@@ -73,10 +73,10 @@ Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking
auto ty = t.as<FuncTypeNode>();
CHECK(ty) << "input should be a function";
return FuncTypeNode::make(ty->arg_types,
TupleTypeNode::make({
return FuncType(ty->arg_types,
TupleType({
ty->ret_type,
TupleTypeNode::make(ty->arg_types)}), {}, {});
TupleType(ty->arg_types)}), {}, {});
}
//! \brief if the expression is a GlobalVar, transform to it's expression.
......@@ -219,7 +219,7 @@ Type GradRetType(const Function& f) {
vt.push_back(p->type_annotation);
}
return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
return TupleType({f->ret_type, TupleType(vt)});
}
Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
......@@ -265,7 +265,7 @@ TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleTypeNode::make({t, RefTypeNode::make(t)});
return TupleType({t, RefTypeNode::make(t)});
}
};
......@@ -299,7 +299,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
types.push_back(field->checked_type_);
}
auto ret = TupleNode::make(fields);
ret->checked_type_ = TupleTypeNode::make(types);
ret->checked_type_ = TupleType(types);
return std::move(ret);
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
......@@ -385,7 +385,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
}
Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleType::Empty(), {});
return RefCreateNode::make(unitF);
}
......@@ -426,7 +426,7 @@ struct ReverseAD : ExprMutator {
ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
TupleType::Empty(),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return ret;
......@@ -468,7 +468,7 @@ struct ReverseAD : ExprMutator {
}
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
TupleType::Empty(),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return ret;
......
......@@ -63,7 +63,7 @@ namespace relay {
// we assume the data type has no closure - no idea how to look into datatype right now.
Type Arrow(const Type& l, const Type& r) {
return FuncTypeNode::make({l}, r, {}, {});
return FuncType({l}, r, {}, {});
}
Type CPSType(const Type& t, const TypeVar& answer);
......@@ -74,7 +74,7 @@ FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) {
new_arg_types.push_back(CPSType(t, answer));
}
new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer));
return FuncTypeNode::make(new_arg_types, answer, f->type_params, f->type_constraints);
return FuncType(new_arg_types, answer, f->type_params, f->type_constraints);
}
Type CPSType(const Type& t, const TypeVar& answer) {
......@@ -302,7 +302,7 @@ Function ToCPS(const Function& f,
}
Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
TypeVar answer = TypeVarNode::make("answer", kType);
TypeVar answer = TypeVar("answer", kType);
VarMap var;
struct Remapper : ExprVisitor, PatternVisitor {
Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { }
......@@ -348,7 +348,7 @@ Function UnCPS(const Function& f) {
auto new_ret_type = Type(cont_type->arg_types[0]);
std::vector<TypeVar> new_type_params;
for (const auto& tp : f->type_params) {
new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind));
new_type_params.push_back(TypeVar(tp->name_hint, tp->kind));
}
auto answer_type = new_type_params.back();
new_type_params.pop_back();
......
......@@ -206,7 +206,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (Expr field : op->fields) {
types.push_back(GetType(field));
}
return TupleTypeNode::make(types);
return TupleType(types);
}
Type VisitExpr_(const TupleGetItemNode* op) final {
......@@ -218,7 +218,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type rtype = IncompleteTypeNode::make(Kind::kType);
auto attrs = make_object<TupleGetItemAttrs>();
attrs->index = op->index;
solver_.AddConstraint(TypeRelationNode::make(
solver_.AddConstraint(TypeRelation(
tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef<TupleGetItem>(op));
return rtype;
}
......@@ -235,7 +235,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (size_t i = 0; i < td->type_vars.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
}
Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args);
Type expected = TypeCall(con->constructor->belong_to, unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(con));
auto* tc = unified.as<TypeCallNode>();
......@@ -277,7 +277,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (size_t i = 0; i < tup->patterns.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
}
Type expected = TupleTypeNode::make(unknown_args);
Type expected = TupleType(unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(tup));
auto* tt = unified.as<TupleTypeNode>();
......@@ -388,7 +388,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type rtype = IncompleteTypeNode::make(Kind::kType);
arg_types.push_back(rtype);
// we can do simple replacement here
solver_.AddConstraint(TypeRelationNode::make(
solver_.AddConstraint(TypeRelation(
rel->func, arg_types, arg_types.size() - 1, attrs), loc);
return rtype;
}
......@@ -418,7 +418,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
ret_type = IncompleteTypeNode::make(Kind::kType);
}
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
Type inst_ty = FuncType(fn_ty->arg_types,
ret_type, {},
fn_ty->type_constraints);
inst_ty = Bind(inst_ty, subst_map);
......@@ -467,7 +467,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// with an unknown return type
if (inc_ty_node != nullptr) {
Type ret_type = IncompleteTypeNode::make(Kind::kType);
Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {});
Type func_type = FuncType(arg_types, ret_type, {}, {});
Type unified = this->Unify(ftype, func_type, GetRef<Call>(call));
fn_ty_node = unified.as<FuncTypeNode>();
}
......@@ -513,7 +513,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (auto cs : fn_ty->type_constraints) {
if (const auto* tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs),
TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs),
GetRef<Call>(call));
} else {
solver_.AddConstraint(cs, GetRef<Call>(call));
......@@ -557,7 +557,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f));
}
CHECK(rtype.defined());
auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
auto ret = FuncType(arg_types, rtype, f->type_params, {});
return solver_.Resolve(ret);
}
......@@ -575,7 +575,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type it = IncompleteTypeNode::make(Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
return TupleTypeNode::make({});
return TupleType::Empty();
}
Type VisitExpr_(const ConstructorNode* c) final {
......@@ -587,7 +587,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (const auto & t : td->type_vars) {
types.push_back(t);
}
return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types),
return FuncType(c->inputs, TypeCall(c->belong_to, types),
td->type_vars, {});
}
......
......@@ -286,7 +286,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
Type field = Unify(tt1->fields[i], tt2->fields[i]);
new_fields.push_back(field);
}
return TupleTypeNode::make(new_fields);
return TupleType(new_fields);
}
Type VisitType_(const FuncTypeNode* op, const Type& tn) final {
......@@ -314,7 +314,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
}
FuncType ft = FuncTypeNode::make(op->arg_types,
FuncType ft = FuncType(op->arg_types,
op->ret_type,
ft_type_params,
op->type_constraints);
......@@ -339,7 +339,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
type_constraints.push_back(GetRef<TypeConstraint>(tcn));
}
return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints);
return FuncType(arg_types, ret_type, ft2->type_params, type_constraints);
}
Type VisitType_(const RefTypeNode* op, const Type& tn) final {
......@@ -361,7 +361,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
for (size_t i = 0; i < op->args.size(); i++) {
args.push_back(Unify(op->args[i], tcn->args[i]));
}
return TypeCallNode::make(func, args);
return TypeCall(func, args);
}
private:
......
......@@ -37,7 +37,7 @@ TEST(Relay, SelfReference) {
mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main");
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment