Unverified Commit 1ecd3ee2 by Tianqi Chen Committed by GitHub

[REFACTOR] Unified IR base types. (#4616)

This PR moves a few base types from relay to the ir sub-folder.
These types will serve as a common type system across the stack.

Notably, we want to be able to use the same FuncType for all function signatures.
I tried to make a minimum move to bring the necessary dependencies for a FuncType.
We can discuss what additional things we want to move as a follow-up.

Notably, because the TensorType will have a dependency on low-level Expr,
we will need to break the type.h into two files and introduce a
tensor_type.h(or leave them in relay for now).
parent 24e6fcb6
......@@ -125,6 +125,8 @@ assign_source_group("Include" ${GROUP_INCLUDE})
# Source file lists
file(GLOB COMPILER_SRCS
src/node/*.cc
src/ir/*.cc
src/api/*.cc
src/arithmetic/*.cc
src/autotvm/*.cc
......@@ -132,7 +134,6 @@ file(GLOB COMPILER_SRCS
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
src/node/*.cc
src/schedule/*.cc
)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/span.h
* \brief Span information for debugging purposes.
*/
#ifndef TVM_IR_SPAN_H_
#define TVM_IR_SPAN_H_
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <string>
namespace tvm {
/*!
* \brief The source name in the Span
* \sa SourceNameNode, Span
*/
class SourceName;
/*!
* \brief The name of a source fragment.
*/
class SourceNameNode : public Object {
public:
/*! \brief The source name. */
std::string name;
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
/*!
* \brief The source name of a file span.
* \sa SourceNameNode, Span
*/
class SourceName : public ObjectRef {
public:
/*!
* \brief Get an SourceName for a given operator name.
* Will raise an error if the source name has not been registered.
* \param name Name of the operator.
* \return SourceName valid throughout program lifetime.
*/
TVM_DLL static SourceName Get(const std::string& name);
TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
};
/*!
* \brief Span information for debugging purposes
*/
class Span;
/*!
* \brief Stores locations in frontend source that generated a node.
*/
class SpanNode : public Object {
public:
/*! \brief The source name */
SourceName source;
/*! \brief Line number */
int lineno;
/*! \brief column offset */
int col_offset;
// override attr visitor
void VisitAttrs(AttrVisitor* v) {
v->Visit("source", &source);
v->Visit("lineno", &lineno);
v->Visit("col_offset", &col_offset);
}
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "relay.Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};
class Span : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
} // namespace tvm
#endif // TVM_IR_SPAN_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/type.h
* \brief IR/AST nodes for the unified type system in TVM.
*
* We use Relay's type system as the unified type system
* throughout the stack.
*
* This file contains types that are common across IR variants.
*
* ## Relation between Type and runtime::DataType
*
* Besides Type, we also store a dtype field in some of the low-level IR's Expr.
* runtime::DataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in
* low-level expression construction and can be used for
* quick type checking in the low-level IR.
* For example, when an Expr's dtype is int32,
* we know for sure that its type is also int32.
*
* On the other hand, Type provides more fine grained information.
* For example, a low level expression can have DataType::Handle() as
* its dtype and MemRef[float32] as its type.
* Types are usually lazily constructed via type checking,
* so they may not readily be available during IR construction.
*
* The unified Type serves as a common bridge across IR dialects.
* For example, we require all the functions to have a type signature,
* which allow us to build cross dialect function calls.
*/
#ifndef TVM_IR_TYPE_H_
#define TVM_IR_TYPE_H_
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/ir/span.h>
#include <string>
namespace tvm {
/*! \brief Base type of all the types. */
class TypeNode : public Object {
public:
/*!
* \brief Span that points to the original source code.
* Reserved debug information.
*/
mutable Span span;
static constexpr const char* _type_key = "relay.Type";
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
/*!
* \brief Type is the base type of all types.
*
* Relay's type system contains following two key concepts:
*
* - PrimitiveType: type of primitive type values used in the low-level IR.
* - TensorType: type of certain Tensor values in the expression.
* - FunctionType: the type of the function.
*
* There are also advanced types to support generic(polymorphic types),
* which can be ignored when first reading the code base.
*/
class Type : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode);
};
/*! \brief Possible kinds of TypeVars. */
enum TypeKind : int {
kType = 0,
/*! \brief Template variable in shape expression. */
kShapeVar = 1,
kBaseType = 2,
kShape = 3,
kConstraint = 4,
kAdtHandle = 5,
kTypeData = 6
};
/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
* \code
*
* template<i32 n>
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeVarNode The actual container class of TypeVar
*/
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static TypeVar make(std::string name, TypeKind kind);
static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
class TypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
/*!
* \brief A global type variable that is used for defining new types or type aliases.
*/
class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode {
public:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
}
TVM_DLL static GlobalTypeVar make(std::string name, TypeKind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
class GlobalTypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
/*!
* \brief Potential Constraints in the type.
* \note This is reserved for future use.
*/
class TypeConstraint;
/*! \brief TypeConstraint container node. */
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
class TypeConstraint : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
};
class FuncType;
/*!
* \brief Function type in Relay.
*
* Relay support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
/*! \brief type type of arguments */
Array<Type> arg_types;
/*! \brief The type of return value. */
Type ret_type;
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
Array<TypeVar> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
*/
Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("type_constraints", &type_constraints);
v->Visit("span", &span);
}
TVM_DLL static FuncType make(Array<Type> arg_types,
Type ret_type,
Array<TypeVar> type_params,
Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
class FuncType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
} // namespace tvm
#endif // TVM_IR_TYPE_H_
......@@ -25,6 +25,7 @@
#define TVM_RELAY_BASE_H_
#include <tvm/api_registry.h>
#include <tvm/ir/span.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
#include <string>
......@@ -58,88 +59,9 @@ namespace relay {
*/
using IndexExpr = ::tvm::Expr;
/*!
* \brief The source name in the Span
* \sa SourceNameNode, Span
*/
class SourceName;
/*!
* \brief The name of a source fragment.
*/
class SourceNameNode : public Object {
public:
/*! \brief The source name. */
std::string name;
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
/*!
* \brief The source name of a file span.
* \sa SourceNameNode, Span
*/
class SourceName : public ObjectRef {
public:
/*! \brief default constructor */
SourceName() {}
/*! \brief constructor from node pointer */
explicit SourceName(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SourceNameNode* operator->() const {
return static_cast<const SourceNameNode*>(get());
}
/*!
* \brief Get an SourceName for a given operator name.
* Will raise an error if the source name has not been registered.
* \param name Name of the operator.
* \return SourceName valid throughout program lifetime.
*/
TVM_DLL static SourceName Get(const std::string& name);
/*! \brief specify container node */
using ContainerType = SourceNameNode;
};
/*!
* \brief Span information for debugging purposes
*/
class Span;
/*!
* \brief Stores locations in frontend source that generated a node.
*/
class SpanNode : public Object {
public:
/*! \brief The source name */
SourceName source;
/*! \brief Line number */
int lineno;
/*! \brief column offset */
int col_offset;
// override attr visitor
void VisitAttrs(AttrVisitor* v) {
v->Visit("source", &source);
v->Visit("lineno", &lineno);
v->Visit("col_offset", &col_offset);
}
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "relay.Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};
class Span : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
using SourceName = tvm::SourceName;
using Span = tvm::Span;
using SpanNode = tvm::SpanNode;
/*!
* \brief This is the base node container of all relay structures.
......
......@@ -25,8 +25,8 @@
#define TVM_RELAY_TYPE_H_
#include <tvm/api_registry.h>
#include <tvm/ir/type.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
#include <string>
#include "base.h"
......@@ -36,32 +36,17 @@ namespace tvm {
namespace relay {
using Any = tvm::ir::Any;
/*! \brief Base type of the Relay type hiearchy. */
class TypeNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Type";
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
/*!
* \brief Type is the base type of relay type hiearchy.
*
* Relay's type system contains following two key concepts:
*
* - TensorType: type of certain Tensor values in the expression.
* - FunctionType: the type of the function.
*
* There are also advanced types to support generic(polymorphic types),
* which can be ignored when first reading the code base.
*/
class Type : public ObjectRef {
public:
Type() {}
explicit Type(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
using ContainerType = TypeNode;
};
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
/*!
* \brief Base of all Tensor types
......@@ -124,90 +109,6 @@ class TensorType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};
/*! \brief Possible kinds of Type. */
enum Kind : int {
kType = 0,
/*! \brief Template variable in shape expression. */
kShapeVar = 1,
kBaseType = 2,
kShape = 3,
kConstraint = 4,
kAdtHandle = 5,
kTypeData = 6
};
/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
* \code
*
* template<i32 n>
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeVarNode The actual container class of TypeVar
*/
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*! \brief Name of the variable, it only acts as a hint. */
std::string name_hint;
/*! \brief The kind of type parameter */
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static TypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
class TypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
/*!
* \brief A global type variable that is used for defining new types or type aliases.
*/
class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode {
public:
/*! \brief Name of the variable, it only acts as a hint. */
std::string name_hint;
/*! \brief The kind of type parameter */
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static GlobalTypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
class GlobalTypeVar : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
/*!
* \brief Type application.
*/
......@@ -271,70 +172,6 @@ class IncompleteType : public Type {
};
/*!
* \brief Potential Constraints in the type.
* \note This is reserved for future use.
*/
class TypeConstraint;
/*! \brief TypeConstraint container node. */
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
class TypeConstraint : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
};
class FuncType;
/*!
* \brief Function type in Relay.
*
* Relay support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
/*! \brief type type of arguments */
tvm::Array<Type> arg_types;
/*! \brief The type of return value. */
Type ret_type;
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
tvm::Array<TypeVar> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
*/
tvm::Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("type_constraints", &type_constraints);
v->Visit("span", &span);
}
TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
class FuncType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
/*!
* \brief The type of tuple values.
*/
class TupleType;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file span.cc
* \brief The span data structure.
*/
#include <tvm/ir/span.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
// always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, ObjectPtr<SourceNameNode> > source_map;
auto sn = source_map.find(name);
if (sn == source_map.end()) {
ObjectPtr<SourceNameNode> n = make_object<SourceNameNode>();
source_map[name] = n;
n->name = std::move(name);
return n;
} else {
return sn->second;
}
}
SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name));
}
TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_global_key([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>();
n->source = std::move(source);
n->lineno = lineno;
n->col_offset = col_offset;
return Span(n);
}
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
});
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/ir/type.cc
* \brief Common type system AST nodes throughout the IR.
*/
#include <tvm/ir/type.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
TypeVar TypeVarNode::make(std::string name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
return TypeVar(n);
}
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.TypeVar")
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVar(" << node->name_hint << ", "
<< node->kind << ")";
});
GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
return GlobalTypeVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVar(" << node->name_hint << ", "
<< node->kind << ")";
});
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
n->arg_types = std::move(arg_types);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->type_constraints = std::move(type_constraints);
return FuncType(n);
}
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncType(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
<< node->type_constraints << ")";
});
} // namespace tvm
......@@ -22,76 +22,26 @@
* \brief The core base types for Relay.
*/
#include <tvm/api_registry.h>
#include <tvm/ir/type.h>
#include <tvm/relay/base.h>
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
// always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, ObjectPtr<SourceNameNode> > source_map;
auto sn = source_map.find(name);
if (sn == source_map.end()) {
ObjectPtr<SourceNameNode> n = make_object<SourceNameNode>();
source_map[name] = n;
n->name = std::move(name);
return n;
} else {
return sn->second;
}
}
SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name));
}
TVM_REGISTER_API("relay._make.SourceName")
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_global_key([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>();
n->source = std::move(source);
n->lineno = lineno;
n->col_offset = col_offset;
return Span(n);
}
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span")
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
});
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
.set_body_typed<void(ObjectRef, Span)>([](ObjectRef node_ref, Span sp) {
auto rn = node_ref.as<RelayNode>();
if (auto* rn = node_ref.as<RelayNode>()) {
CHECK(rn);
rn->span = sp;
} else if (auto* rn = node_ref.as<TypeNode>()) {
rn->span = sp;
} else {
LOG(FATAL) << "Expect Type or RelayNode ";
}
});
} // namespace relay
......
......@@ -228,11 +228,6 @@ class RelayHashHandler:
hash = Combine(hash, TypeHash(var_node->type_annotation));
}
hash_map_[var] = hash;
// TODO(tqchen) Introduce TypeVarExpr
// const auto* ty_param = var.as<TypeVarNode>();
// if (ty_param && ty_param->kind == Kind::kShapeVar) {
// hash_map_[ty_param->var] = hash;
// }
return hash;
}
......
......@@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const {
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->name_hint;
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
......
......@@ -63,48 +63,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
TypeVar TypeVarNode::make(std::string name, Kind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
return TypeVar(n);
}
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeVar")
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVarNode(" << node->name_hint << ", "
<< node->kind << ")";
});
GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
return GlobalTypeVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_API("relay._make.GlobalTypeVar")
.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVarNode(" << node->name_hint << ", "
<< node->kind << ")";
});
TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
ObjectPtr<TypeCallNode> n = make_object<TypeCallNode>();
n->func = std::move(func);
......@@ -143,31 +101,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints) {
ObjectPtr<FuncTypeNode> n = make_object<FuncTypeNode>();
n->arg_types = std::move(arg_types);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->type_constraints = std::move(type_constraints);
return FuncType(n);
}
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncTypeNode(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
<< node->type_constraints << ")";
});
TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args,
int num_inputs,
......
......@@ -38,7 +38,7 @@ TEST(Relay, SelfReference) {
auto type_fx = mod->Lookup("main");
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(AlphaEqual(type_fx->checked_type(), expected));
CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
}
int main(int argc, char ** argv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment