Unverified Commit 2c0c1849 by Tianqi Chen Committed by GitHub

[REFACTOR][TYPE] Finish move all types to IR. (#4746)

* [REFACTOR][TYPE] Finish move all types to IR.

- Move definition of Ref and TensorType to ir
- Move type_functor.h to public header.
- Rename RefType -> RelayRefType for clarity.

* Add atol
parent ee0af843
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/tensor_type.h
* \brief Polymorphic tensor types.
*/
#ifndef TVM_IR_TENSOR_TYPE_H_
#define TVM_IR_TENSOR_TYPE_H_
#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
namespace tvm {
/*!
* \brief Base of all Tensor types
* This container can hold TensorType or GenericTensorType.
* \sa BaseTensorType, TensorTypeNode
*/
class BaseTensorTypeNode : public TypeNode {
 public:
  /*! \brief Type key used by the TVM object registry. */
  static constexpr const char* _type_key = "relay.BaseTensorType";
  // Declared as a *base* object: subclasses (e.g. TensorTypeNode) may derive.
  TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};
/*!
* \brief Managed reference to BaseTensorTypeNode.
* \sa BaseTensorTypeNode.
*/
class BaseTensorType : public Type {
 public:
  // Standard object-reference boilerplate (default ctors, operator->, get()).
  TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
};
/*!
 * \brief This is the most commonly used type in relay.
 * A TensorType has a fixed dimension and data type.
 *
 * The elements of shape can be either IntImm (constant integer),
 * or any symbolic integer expression.
 * The symbolic integer allows generic shape inference in certain cases.
 * \sa TensorType
 */
class TensorTypeNode : public BaseTensorTypeNode {
 public:
  /*!
   * \brief The shape of the tensor,
   *  represented by PrimExpr(tvm::Expr).
   *
   * Elements may be constant integers (IntImm) or symbolic integer
   * expressions; the latter enables generic shape inference.
   */
  Array<PrimExpr> shape;
  /*! \brief The content data type */
  DataType dtype;

  /*! \brief Visitor hook used by TVM's reflection/serialization machinery. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
    v->Visit("span", &span);
  }

  /*!
   * \brief Return product of elements in the shape.
   * \return (d_1 * d_2 * ... * d_n) if shape is (d_1, d_2, ..., d_n)
   *         and 1 if shape size is zero.
   */
  TVM_DLL PrimExpr Size() const;

  /*! \brief Type key used by the TVM object registry. */
  static constexpr const char* _type_key = "relay.TensorType";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};
/*!
* \brief Managed reference to TensorTypeNode.
* \sa TensorTypeNode.
*/
class TensorType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param shape The shape of the tensor.
   * \param dtype The runtime dtype of the tensor's elements.
   */
  TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype);

  /*!
   * \brief Construct a scalar (rank-0, empty shape) containing elements of dtype.
   * \param dtype The runtime dtype of the tensor's elements.
   * \return The constructed type.
   */
  TVM_DLL static TensorType Scalar(DataType dtype);

  TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};
// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// stores a DataType.
class GenericShape;
} // namespace tvm
#endif // TVM_IR_TENSOR_TYPE_H_
......@@ -352,5 +352,75 @@ class FuncType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
/*!
 * \brief Intermediate value used to indicate an incomplete type
 * during type inference.
 *
 * If we view the type relations as a "computational graph of types",
 * then IncompleteType represents intermediate values of the graph,
 * while TypeVar represents the inputs to the graph.
 *
 * \sa IncompleteType
 */
class IncompleteTypeNode : public TypeNode {
 public:
  /*! \brief kind of the type. */
  TypeKind kind;

  /*! \brief Visitor hook used by TVM's reflection/serialization machinery. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("kind", &kind);
    v->Visit("span", &span);
  }

  /*! \brief Type key used by the TVM object registry. */
  static constexpr const char* _type_key = "relay.IncompleteType";
  TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
/*!
* \brief Managed reference to IncompleteTypeNode.
* \sa IncompleteTypeNode
*/
class IncompleteType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param kind kind of the type.
   */
  TVM_DLL explicit IncompleteType(TypeKind kind);

  TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};
/*!
* \brief Reference Type High-level Relay IR.
*
* \sa RelayRefType.
*/
class RelayRefTypeNode : public TypeNode {
 public:
  /*! \brief The type of value in the Reference. */
  Type value;

  RelayRefTypeNode() {}

  /*! \brief Visitor hook used by TVM's reflection/serialization machinery. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("value", &value);
    v->Visit("span", &span);
  }

  // NOTE(review): the type key stays "relay.RefType" (the pre-rename name)
  // even though the class is now RelayRefTypeNode — presumably to keep
  // registry/serialized names stable; confirm before changing.
  static constexpr const char* _type_key = "relay.RefType";
  TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};
/*!
* \brief Managed reference to RelayRefTypeNode.
* \sa RelayRefTypeNode.
*/
class RelayRefType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param value The type of the value the reference holds.
   */
  TVM_DLL explicit RelayRefType(Type value);
  TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
};
} // namespace tvm
#endif // TVM_IR_TYPE_H_
......@@ -18,11 +18,11 @@
*/
/*!
* \file type_functor.h
* \file tvm/ir/type_functor.h
* \brief A way to define an arbitrary function signature with dispatch on types.
*/
#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
#ifndef TVM_IR_TYPE_FUNCTOR_H_
#define TVM_IR_TYPE_FUNCTOR_H_
#include <tvm/node/functor.h>
#include <tvm/relay/expr.h>
......@@ -32,17 +32,16 @@
#include <utility>
namespace tvm {
namespace relay {
template <typename FType>
class TypeFunctor;
// functions to be overriden.
#define TYPE_FUNCTOR_DEFAULT \
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitType_(static_cast<const OP*>(n.get()), \
......@@ -89,10 +88,11 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; // unreachable, written to stop compiler warning
......@@ -103,25 +103,29 @@ class TypeFunctor<R(const Type& n, Args...)> {
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
return vtable;
}
};
#undef TVM_TYPE_FUNCTOR_DISPATCH
/*!
* \brief A type visitor that recursively visit types.
*/
class TypeVisitor : public TypeFunctor<void(const Type& n)> {
class TVM_DLL TypeVisitor :
public TypeFunctor<void(const Type& n)> {
public:
void VisitType_(const TypeVarNode* op) override;
void VisitType_(const IncompleteTypeNode* op) override;
......@@ -129,14 +133,18 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
void VisitType_(const RefTypeNode* op) override;
void VisitType_(const RelayRefTypeNode* op) override;
void VisitType_(const GlobalTypeVarNode* op) override;
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
void VisitType_(const PrimTypeNode* op) override;
};
// Mutator that transform a type to another one.
class TypeMutator : public TypeFunctor<Type(const Type& n)> {
/*!
* \brief TypeMutator that mutates expressions.
*/
class TVM_DLL TypeMutator :
public TypeFunctor<Type(const Type& n)> {
public:
Type VisitType(const Type& t) override;
Type VisitType_(const TypeVarNode* op) override;
......@@ -145,10 +153,11 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const RefTypeNode* op) override;
Type VisitType_(const RelayRefTypeNode* op) override;
Type VisitType_(const GlobalTypeVarNode* op) override;
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
Type VisitType_(const PrimTypeNode* op) override;
private:
Array<Type> MutateArray(Array<Type> arr);
......@@ -161,6 +170,5 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
*/
Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
#endif // TVM_IR_TYPE_FUNCTOR_H_
......@@ -25,6 +25,7 @@
#define TVM_RELAY_TYPE_H_
#include <tvm/ir/type.h>
#include <tvm/ir/tensor_type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
......@@ -54,6 +55,12 @@ using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using IncompleteType = tvm::IncompleteType;
using IncompleteTypeNode = tvm::IncompleteTypeNode;
using RelayRefType = tvm::RelayRefType;
using RelayRefTypeNode = tvm::RelayRefTypeNode;
using TensorType = tvm::TensorType;
using TensorTypeNode = tvm::TensorTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
......@@ -62,136 +69,6 @@ using TypeRelationFn = tvm::TypeRelationFn;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;
/*!
* \brief Base of all Tensor types
* This container can hold TensorType or GenericTensorType.
*/
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};
class BaseTensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
};
/*!
* \brief This is the most commonly used type in relay.
* TensorType have a fixed dimension, data type.
*
* The elements of shape can be either IntImm(constant integer),
* or any symbolic integer expression.
* The symbolic integer allows generic shape inference in certain cases.
* \sa TensorTypeNode The container class of TensorType.
*/
class TensorType;
/*! \brief TensorType container node */
class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by IndexExpr(tvm::Expr).
*/
Array<IndexExpr> shape;
/*! \brief The content data type */
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
TVM_DLL IndexExpr Size() const;
TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
/*! \brief Construct an scalar containing elements of dtype. */
TVM_DLL static TensorType Scalar(DataType dtype);
static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};
class TensorType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
};
/*!
* \brief IncompleteType.
* This is intermediate values that is used during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeVar represents the input to the graph.
*/
class IncompleteType;
/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static IncompleteType make(Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
class IncompleteType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};
/*!
* \brief The type of reference values.
*/
class RefType;
/*!
* \brief Reference Type in relay.
*/
class RefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;
RefTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}
TVM_DLL static RefType make(Type value);
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode);
};
class RefType : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
};
// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// stores a DataType.
class GenericShape;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TYPE_H_
......@@ -317,8 +317,8 @@ class Object {
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr);
template <typename RelayRefType, typename ObjectType>
inline RelayRefType GetRef(const ObjectType* ptr);
/*!
* \brief Downcast a base reference type to a more specific type.
......@@ -484,8 +484,8 @@ class ObjectPtr {
friend class TVMArgsSetter;
friend class TVMRetValue;
friend class TVMArgValue;
template <typename RefType, typename ObjType>
friend RefType GetRef(const ObjType* ptr);
template <typename RelayRefType, typename ObjType>
friend RelayRefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
};
......@@ -848,11 +848,11 @@ inline const ObjectType* ObjectRef::as() const {
}
}
template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
template <typename RelayRefType, typename ObjType>
inline RelayRefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of<typename RelayRefType::ContainerType, ObjType>::value,
"Can only cast to the ref of same container type");
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
return RelayRefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
template <typename BaseType, typename ObjType>
......
......@@ -18,35 +18,35 @@
*/
/*!
* \file src/tvm/ir/type.cc
* \file src/tvm/ir/tensor_type.cc
* \brief The type system AST nodes of Relay.
*/
#include <tvm/relay/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/tensor_type.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace relay {
using tvm::NodePrinter;
using namespace tvm::runtime;
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
TensorType::TensorType(Array<PrimExpr> shape, DataType dtype) {
ObjectPtr<TensorTypeNode> n = make_object<TensorTypeNode>();
n->shape = std::move(shape);
n->dtype = std::move(dtype);
return TensorType(n);
data_ = std::move(n);
}
TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype);
TensorType TensorType::Scalar(DataType dtype) {
return TensorType({}, dtype);
}
IndexExpr TensorTypeNode::Size() const {
PrimExpr TensorTypeNode::Size() const {
if (shape.size() == 0) {
return tir::make_const(DataType::Int(64), 1);
}
IndexExpr size = shape[0];
PrimExpr size = shape[0];
for (size_t i = 1; i < shape.size(); ++i) {
size *= shape[i];
}
......@@ -56,7 +56,9 @@ IndexExpr TensorTypeNode::Size() const {
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TensorType")
.set_body_typed(TensorTypeNode::make);
.set_body_typed([](Array<PrimExpr> shape, DataType dtype) {
return TensorType(shape, dtype);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
......@@ -64,45 +66,4 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
IncompleteType IncompleteTypeNode::make(Kind kind) {
auto n = make_object<IncompleteTypeNode>();
n->kind = std::move(kind);
return IncompleteType(n);
}
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
.set_body_typed([](int kind) {
return IncompleteTypeNode::make(static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
RefType RefTypeNode::make(Type value) {
ObjectPtr<RefTypeNode> n = make_object<RefTypeNode>();
n->value = std::move(value);
return RefType(n);
}
TVM_REGISTER_GLOBAL("relay._make.RefType")
.set_body_typed(RefTypeNode::make);
TVM_REGISTER_NODE_TYPE(RefTypeNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefTypeNode*>(ref.get());
p->stream << "RefTypeNode(" << node->value << ")";
});
TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); });
} // namespace relay
} // namespace tvm
......@@ -118,6 +118,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< node->type_constraints << ")";
});
TupleType::TupleType(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
n->fields = std::move(fields);
......@@ -141,4 +142,44 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "TupleTypeNode(" << node->fields << ")";
});
IncompleteType::IncompleteType(TypeKind kind) {
auto n = make_object<IncompleteTypeNode>();
n->kind = std::move(kind);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
.set_body_typed([](int kind) {
return IncompleteType(static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
RelayRefType::RelayRefType(Type value) {
ObjectPtr<RelayRefTypeNode> n = make_object<RelayRefTypeNode>();
n->value = std::move(value);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.RefType")
.set_body_typed([](Type value) {
return RelayRefType(value);
});
TVM_REGISTER_NODE_TYPE(RelayRefTypeNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
p->stream << "RelayRefTypeNode(" << node->value << ")";
});
} // namespace tvm
......@@ -21,11 +21,10 @@
* \file type_functor.cc
* \brief Implementations of type functors.
*/
#include <tvm/ir/type_functor.h>
#include <utility>
#include "type_functor.h"
namespace tvm {
namespace relay {
void TypeVisitor::VisitType_(const TypeVarNode* op) {
}
......@@ -57,7 +56,7 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) {
}
}
void TypeVisitor::VisitType_(const RefTypeNode* op) {
void TypeVisitor::VisitType_(const RelayRefTypeNode* op) {
this->VisitType(op->value);
}
......@@ -91,6 +90,9 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
}
}
void TypeVisitor::VisitType_(const PrimTypeNode* op) {
}
Type TypeMutator::VisitType(const Type& t) {
return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
}
......@@ -169,8 +171,8 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) {
}
}
Type TypeMutator::VisitType_(const RefTypeNode* op) {
return RefTypeNode::make(this->VisitType(op->value));
Type TypeMutator::VisitType_(const RelayRefTypeNode* op) {
return RelayRefType(this->VisitType(op->value));
}
Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
......@@ -203,6 +205,10 @@ Type TypeMutator::VisitType_(const TypeDataNode* op) {
return GetRef<Type>(op);
}
Type TypeMutator::VisitType_(const PrimTypeNode* op) {
return GetRef<Type>(op);
}
// Implements bind.
class TypeBinder : public TypeMutator {
public:
......@@ -227,5 +233,4 @@ Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
return TypeBinder(args_map).VisitType(type);
}
} // namespace relay
} // namespace tvm
......@@ -21,8 +21,7 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compilation engine.
*/
#include "compile_engine.h"
#include <tvm/ir/type_functor.h>
#include <tvm/top/schedule.h>
#include <tvm/top/operation.h>
#include <tvm/top/schedule_pass.h>
......@@ -42,7 +41,8 @@
#include <functional>
#include <vector>
#include <unordered_map>
#include "../ir/type_functor.h"
#include "compile_engine.h"
namespace tvm {
namespace relay {
......@@ -239,12 +239,12 @@ class ScheduleGetter :
// TODO(@icemelon): Support recursive tuple
Type call_node_type = call_node->checked_type();
if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
call_node_type = TensorType(GetShape(tt->shape), tt->dtype);
} else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
std::vector<Type> new_fields;
for (auto field : tuple_t->fields) {
if (const auto* tt = field.as<TensorTypeNode>()) {
new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
new_fields.push_back(TensorType(GetShape(tt->shape), tt->dtype));
} else {
new_fields.push_back(field);
}
......
......@@ -529,7 +529,7 @@ class Interpreter :
if (is_dyn) {
auto sh = out_shapes[i];
auto tt = Downcast<TensorType>(rtype->fields[i]);
fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
fields.push_back(fset_output(i, TensorType(sh, tt->dtype)));
} else {
fields.push_back(fset_output(i, rtype->fields[i]));
}
......@@ -542,7 +542,7 @@ class Interpreter :
CHECK_EQ(out_shapes.size(), 1);
auto sh = out_shapes[0];
auto tt = Downcast<TensorType>(ret_type);
out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
out_tensor = fset_output(0, TensorType(sh, tt->dtype));
} else {
out_tensor = fset_output(0, ret_type);
}
......
......@@ -21,6 +21,7 @@
* \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
......@@ -28,7 +29,6 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "type_functor.h"
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {
......@@ -277,8 +277,8 @@ class AlphaEqualHandler:
}
}
bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final {
if (const RelayRefTypeNode* rhs = other.as<RelayRefTypeNode>()) {
return TypeEqual(lhs->value, rhs->value);
}
return false;
......
......@@ -59,7 +59,7 @@ TensorType ConstantNode::tensor_type() const {
tvm::IntImm(DataType::Int(32), data->shape[i]));
}
return TensorTypeNode::make(shape, dtype);
return TensorType(shape, dtype);
}
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
......@@ -129,12 +129,12 @@ FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
Type param_type = (param->type_annotation.defined()) ? param->type_annotation
: IncompleteTypeNode::make(Kind::kType);
: IncompleteType(Kind::kType);
param_types.push_back(param_type);
}
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteTypeNode::make(Kind::kType);
: IncompleteType(Kind::kType);
return FuncType(param_types, ret_type, this->type_params, {});
}
......@@ -359,5 +359,8 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
return FunctionSetAttr(func, name, ref);
});
TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); });
} // namespace relay
} // namespace tvm
......@@ -24,10 +24,10 @@
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
namespace tvm {
namespace relay {
......
......@@ -21,13 +21,13 @@
* \file src/tvm/relay/ir/hash.cc
* \brief Hash functions for Relay types and expressions.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/ir/attrs.h>
#include "type_functor.h"
#include "../../ir/attr_functor.h"
namespace tvm {
......@@ -201,8 +201,8 @@ class RelayHashHandler:
return hash;
}
size_t VisitType_(const RefTypeNode* rtn) final {
size_t hash = std::hash<std::string>()(RefTypeNode::_type_key);
size_t VisitType_(const RelayRefTypeNode* rtn) final {
size_t hash = std::hash<std::string>()(RelayRefTypeNode::_type_key);
hash = Combine(hash, TypeHash(rtn->value));
return hash;
}
......
......@@ -30,13 +30,12 @@
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/node/serialization.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/ir/module.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../ir/attr_functor.h"
......@@ -779,7 +778,7 @@ class PrettyPrinter :
return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
}
Doc VisitType_(const RefTypeNode* node) final {
Doc VisitType_(const RelayRefTypeNode* node) final {
Doc doc;
return doc << "ref(" << Print(node->value) << ")";
}
......
......@@ -43,7 +43,7 @@ bool ArgsortRel(const Array<Type>& types,
<< types[0];
return false;
}
reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
reporter->Assign(types[1], TensorType(data->shape, param->dtype));
return true;
}
......
......@@ -52,8 +52,8 @@ bool TopKRel(const Array<Type>& types,
out_shape.push_back(param->k);
}
}
auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
auto values_ty = TensorType(out_shape, data->dtype);
auto indices_ty = TensorType(out_shape, param->dtype);
if (param->ret_type == "both") {
reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
......
......@@ -60,7 +60,7 @@ bool ResizeRel(const Array<Type>& types,
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
TensorType(layout_converter.BackwardShape(oshape),
out_dtype));
return true;
}
......@@ -143,7 +143,7 @@ bool CropAndResizeRel(const Array<Type>& types,
auto bshape = layout_converter.BackwardShape(oshape);
// assign output type
reporter->Assign(types[3],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
TensorType(layout_converter.BackwardShape(oshape),
out_dtype));
return true;
}
......
......@@ -154,11 +154,11 @@ bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
for (auto i = 0u; i < dims; i++) {
out_shape.push_back(tvm::Integer(sh[i]));
}
alloc_type = TensorTypeNode::make(out_shape, alloc_attrs->dtype);
alloc_type = TensorType(out_shape, alloc_attrs->dtype);
} else {
CHECK(alloc_attrs->assert_shape.defined())
<< "the assert_shape must be set when const_shape is not";
alloc_type = TensorTypeNode::make(alloc_attrs->assert_shape, alloc_attrs->dtype);
alloc_type = TensorType(alloc_attrs->assert_shape, alloc_attrs->dtype);
return true;
}
......@@ -309,13 +309,13 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
shape_func_ins.push_back(in_type);
} else {
auto shape = RankShape(in_type->shape);
shape_func_ins.push_back(TensorTypeNode::make(shape, DataType::Int(64)));
shape_func_ins.push_back(TensorType(shape, DataType::Int(64)));
}
}
for (auto out_type : out_types) {
auto rank_shape = RankShape(out_type->shape);
shape_func_outs.push_back(TensorTypeNode::make(rank_shape, DataType::Int(64)));
shape_func_outs.push_back(TensorType(rank_shape, DataType::Int(64)));
}
auto input_type = TupleType(shape_func_ins);
......
......@@ -81,7 +81,7 @@ bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
out_shape.push_back(bits);
}
reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type));
reporter->Assign(types[1], TensorType(out_shape, pack_type));
return true;
}
......@@ -144,7 +144,7 @@ bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
DataType out_dtype = param->out_dtype;
oshape = trans_in_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......@@ -220,7 +220,7 @@ bool BinaryDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
}
// Assign output type.
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......
......@@ -271,7 +271,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
channels = param->channels;
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
......@@ -310,7 +310,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......@@ -434,7 +434,7 @@ bool Conv1DTransposeRel(const Array<Type>& types,
channels = param->channels;
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
......@@ -469,7 +469,7 @@ bool Conv1DTransposeRel(const Array<Type>& types,
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......@@ -616,7 +616,7 @@ bool Conv2DWinogradRel(const Array<Type>& types,
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......@@ -702,7 +702,7 @@ bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
data->shape[1],
};
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
......@@ -817,7 +817,7 @@ bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types,
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape), out_dtype));
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
return true;
}
......@@ -1025,7 +1025,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// assign result to reporter
reporter->Assign(types[2], TensorTypeNode::make(wshape, data->dtype));
reporter->Assign(types[2], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
......@@ -1066,12 +1066,12 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
// infer offset shape
Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
oshape[2], oshape[3]});
reporter->Assign(types[1], TensorTypeNode::make(offset_shape, data->dtype));
reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
......
......@@ -81,7 +81,7 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
weight_dtype = weight->dtype;
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
......@@ -117,7 +117,7 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......@@ -179,7 +179,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
weight_dtype = weight->dtype;
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
......@@ -226,7 +226,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......@@ -290,7 +290,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
......@@ -346,7 +346,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......
......@@ -61,7 +61,7 @@ bool BiasAddRel(const Array<Type>& types,
<< "axis " << param->axis << " is out of range";
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(
reporter->Assign(types[1], TensorType(
{data->shape[axis]}, data->dtype));
reporter->Assign(types[2], types[0]);
return true;
......@@ -138,7 +138,7 @@ bool FIFOBufferRel(const Array<Type>& types,
Array<tvm::PrimExpr> oshape = buffer->shape;
reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
reporter->Assign(types[2], TensorType(oshape, buffer->dtype));
return true;
}
......@@ -260,10 +260,10 @@ bool PReluRel(const Array<Type>& types,
// assign alpha type
Array<IndexExpr> alpha_shape({data->shape[param->axis]});
reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype));
reporter->Assign(types[1], TensorType(alpha_shape, data->dtype));
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
reporter->Assign(types[2], TensorType(data->shape, data->dtype));
return true;
}
......@@ -419,7 +419,7 @@ bool BatchFlattenRel(const Array<Type>& types,
std::vector<IndexExpr> oshape({data->shape[0], target_dim});
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......@@ -585,7 +585,7 @@ bool DropoutRel(const Array<Type>& types,
// dropout returns the original tensor with dropout applied
// and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
auto ret_type = TensorType(data->shape, data->dtype);
reporter->Assign(types[1], TupleType(Array<Type>({ret_type, ret_type})));
return true;
}
......@@ -661,17 +661,17 @@ bool BatchNormRel(const Array<Type>& types,
auto axis_size = data->shape[axis];
// if we are using beta and gamma, they need to be of shape (dim,)
reporter->Assign(types[1], TensorTypeNode::make({axis_size}, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make({axis_size}, data->dtype));
reporter->Assign(types[3], TensorTypeNode::make({axis_size}, data->dtype));
reporter->Assign(types[4], TensorTypeNode::make({axis_size}, data->dtype));
reporter->Assign(types[1], TensorType({axis_size}, data->dtype));
reporter->Assign(types[2], TensorType({axis_size}, data->dtype));
reporter->Assign(types[3], TensorType({axis_size}, data->dtype));
reporter->Assign(types[4], TensorType({axis_size}, data->dtype));
// output is a tuple of the normed data (same shape as input), new running mean,
// and new running average (the latter two are both vectors of length dim)
std::vector<Type> fields;
auto vec_ty = TensorTypeNode::make(Array<IndexExpr>({data->shape[axis]}),
auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}),
data->dtype);
fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
fields.push_back(TensorType(data->shape, data->dtype));
fields.push_back(vec_ty);
fields.push_back(vec_ty);
reporter->Assign(types[5], TupleType(Array<Type>(fields)));
......@@ -754,9 +754,9 @@ bool InstanceNormRel(const Array<Type>& types,
const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorType(data->shape, data->dtype));
return true;
}
......@@ -824,9 +824,9 @@ bool LayerNormRel(const Array<Type>& types,
const LayerNormAttrs* param = attrs.as<LayerNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorType(data->shape, data->dtype));
return true;
}
......@@ -881,7 +881,7 @@ bool BatchMatmulRel(const Array<Type>& types,
oshape.Set(2, y->shape[1]);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype));
reporter->Assign(types[2], TensorType(oshape, x->dtype));
return true;
}
......@@ -940,7 +940,7 @@ bool CrossEntropyRel(const Array<Type>& types,
<< "x shape = " << x->shape << ", "
<< "y shape = " << y->shape;
// assign output type
reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
reporter->Assign(types[2], TensorType({}, x->dtype));
return true;
}
......@@ -1016,7 +1016,7 @@ bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
// Assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));
TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
......@@ -1074,7 +1074,7 @@ bool SpaceToDepthRel(const Array<Type>& types, int num_inputs, const Attrs& attr
// Assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));
TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
......
......@@ -52,7 +52,7 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// data dtype as the weight dtype. However if weight dtype is explicitly
// present we will use that.
auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
reporter->Assign(types[1], TensorType(wshape, weight_dtype));
oshape.Set((oshape.size() - 1), param->units);
} else {
if (weight == nullptr) return false;
......@@ -70,7 +70,7 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
......
......@@ -155,7 +155,7 @@ bool PadRel(const Array<Type>& types,
}
}
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
......@@ -260,7 +260,7 @@ bool MirrorPadRel(const Array<Type>& types,
oshape.push_back(data->shape[i] + padding);
}
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
......
......@@ -161,7 +161,7 @@ bool Pool2DRel(const Array<Type>& types,
}
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......@@ -327,7 +327,7 @@ bool GlobalPool2DRel(const Array<Type>& types,
oshape.Set(widx, 1);
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......@@ -462,7 +462,7 @@ bool AdaptivePool2DRel(const Array<Type>& types,
oshape.Set(widx, output_width);
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......@@ -792,7 +792,7 @@ bool Pool1DRel(const Array<Type>& types,
}
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......@@ -987,7 +987,7 @@ bool Pool3DRel(const Array<Type>& types,
}
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......
......@@ -47,7 +47,7 @@ bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
if (weight_data->shape.size() == 1) {
// CSR case.
Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[4], TensorType(oshape, data->dtype));
return true;
}
......@@ -56,7 +56,7 @@ bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
Array<IndexExpr> oshape({
data->shape[0],
(weight_indptr->shape[0] - 1) * weight_data->shape[1]});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[4], TensorType(oshape, data->dtype));
return true;
}
LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
......@@ -105,9 +105,9 @@ bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
const auto* sparse_indptr = types[2].as<TensorTypeNode>();
std::vector<Type> output_types;
output_types.push_back(TensorTypeNode::make(sparse_data->shape, sparse_data->dtype));
output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
output_types.push_back(TensorType(sparse_data->shape, sparse_data->dtype));
output_types.push_back(TensorType(sparse_indices->shape, sparse_indices->dtype));
output_types.push_back(TensorType(sparse_indptr->shape, sparse_indptr->dtype));
reporter->Assign(types[3], TupleType(Array<Type>(output_types)));
return true;
......
......@@ -87,7 +87,7 @@ bool UpSamplingRel(const Array<Type>& types,
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
TensorType(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
......@@ -167,7 +167,7 @@ bool UpSampling3DRel(const Array<Type>& types,
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
TensorType(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
......
......@@ -272,7 +272,7 @@ bool ArgReduceRel(const Array<Type>& types,
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorTypeNode::make(oshape, DataType::Int(32)));
reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
return true;
}
......@@ -297,7 +297,7 @@ bool ReduceRel(const Array<Type>& types,
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......@@ -594,7 +594,7 @@ bool VarianceRel(const Array<Type>& types,
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
......
......@@ -119,7 +119,7 @@ bool ConcatenateRel(const Array<Type>& types,
concat_dim = Any::make();
}
auto rtype = TensorTypeNode::make(oshape, dtype);
auto rtype = TensorType(oshape, dtype);
reporter->Assign(types[1], rtype);
return true;
}
......
......@@ -286,7 +286,7 @@ bool ShapeOfRel(const Array<Type>& types,
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
auto rank_shape = RankShape(tt->shape);
reporter->Assign(types[1], TensorTypeNode::make(rank_shape, param->dtype));
reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
return true;
}
......@@ -337,7 +337,7 @@ bool NdarraySizeRel(const Array<Type>& types,
CHECK(tt != nullptr);
const auto* param = attrs.as<NdarraySizeAttrs>();
CHECK(param != nullptr);
reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
reporter->Assign(types[1], TensorType({1}, param->dtype));
return true;
}
......
......@@ -96,7 +96,7 @@ Type ConcreteBroadcast(const TensorType& t1,
for (; i <= max_ndim; ++i) {
oshape.push_back(rshape[max_ndim - i]);
}
return TensorTypeNode::make(Array<IndexExpr>(
return TensorType(Array<IndexExpr>(
oshape.rbegin(), oshape.rend()), output_dtype);
}
......
......@@ -50,7 +50,7 @@ bool MultiboxPriorRel(const Array<Type>& types,
{1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......@@ -122,8 +122,8 @@ bool MultiBoxTransformLocRel(const Array<Type>& types,
std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
std::vector<IndexExpr> oshape1({cls_shape[0]});
std::vector<Type> fields;
fields.push_back(TensorTypeNode::make(oshape0, cls_prob->dtype));
fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32)));
fields.push_back(TensorType(oshape0, cls_prob->dtype));
fields.push_back(TensorType(oshape1, DataType::Int(32)));
// assign output type
reporter->Assign(types[3], TupleType(Array<Type>(fields)));
......
......@@ -40,8 +40,8 @@ bool GetValidCountRel(const Array<Type>& types,
std::vector<IndexExpr> oshape({data->shape[0]});
std::vector<Type> fields;
fields.push_back(TensorTypeNode::make(oshape, DataType::Int(32)));
fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
fields.push_back(TensorType(oshape, DataType::Int(32)));
fields.push_back(TensorType(data->shape, data->dtype));
// assign output type
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
......@@ -95,9 +95,9 @@ bool NMSRel(const Array<Type>& types,
// assign output type
if (param->return_indices) {
std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, DataType::Int(32)));
reporter->Assign(types[2], TensorType(oshape, DataType::Int(32)));
} else {
reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
reporter->Assign(types[2], TensorType(dshape, data->dtype));
}
return true;
}
......
......@@ -45,7 +45,7 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
......@@ -96,7 +96,7 @@ bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
......@@ -155,7 +155,7 @@ bool ProposalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
std::vector<IndexExpr> oshape(
{batch * proposal_attrs->rpn_post_nms_top_n, 5});
reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype));
reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype));
return true;
}
......
......@@ -56,7 +56,7 @@ bool YoloReorgRel(const Array<Type>& types,
oshape[1] = oshape[1] * param->stride * param->stride;
oshape[2] = indexdiv(oshape[2], param->stride);
oshape[3] = indexdiv(oshape[3], param->stride);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
......
......@@ -22,11 +22,10 @@
* \file de_duplicate.cc
* \brief Use a fresh Id for every Var to make the result well-formed.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/pattern_functor.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......
......@@ -23,10 +23,10 @@
* \brief Add an abstraction over constructors and/or global variables bound to a function.
*
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr_functor.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......
......@@ -21,7 +21,7 @@
* \file ad.cc
* \brief API for Automatic Differentiation for the Relay IR.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/top/operation.h>
#include <tvm/relay/expr_functor.h>
......@@ -30,7 +30,6 @@
#include "pattern_util.h"
#include "pass_util.h"
#include "let_list.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......@@ -265,7 +264,7 @@ TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleType({t, RefTypeNode::make(t)});
return TupleType({t, RelayRefType(t)});
}
};
......
......@@ -31,9 +31,9 @@
* We check this by ensuring the `dtype` field of a Tensor always
* contains a data type such as `int`, `float`, `uint`.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/ir/error.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......@@ -107,9 +107,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
return Kind::kType;
}
Kind VisitType_(const RefTypeNode* op) override {
Kind VisitType_(const RelayRefTypeNode* op) override {
// ref types should only contain normal types
RefType rt = GetRef<RefType>(op);
RelayRefType rt = GetRef<RelayRefType>(op);
CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
return Kind::kType;
}
......
......@@ -89,12 +89,12 @@
*
* These assumptions do not affect the correctness of the algorithm, however.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include "../ir/type_functor.h"
#include "pass_util.h"
#include "let_list.h"
......@@ -863,7 +863,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
subst.Set(func->type_params[i], IncompleteType(kType));
}
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
} else {
......
......@@ -48,9 +48,9 @@ bool SimulatedQuantizeRel(const Array<Type>& types,
CHECK(data != nullptr);
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
reporter->Assign(types[1], TensorTypeNode::make({}, DataType::Float(32))); // dom_scale
reporter->Assign(types[2], TensorTypeNode::make({}, DataType::Float(32))); // clip_min
reporter->Assign(types[3], TensorTypeNode::make({}, DataType::Float(32))); // clip_max
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale
reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min
reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max
reporter->Assign(types[4], types[0]); // output
return true;
}
......
......@@ -50,10 +50,10 @@
* All cases in the transform must return via the mcont,
* wheter directly invoking it, or indirectly by recursion.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "../ir/type_functor.h"
#include "let_list.h"
#include "pass_util.h"
......
......@@ -37,7 +37,7 @@
* If we can not infer a type or there are conflicting typing
* constraints we will trigger an error.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
......@@ -45,7 +45,6 @@
#include <tvm/relay/transform.h>
#include "./pass_util.h"
#include "type_solver.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......@@ -180,7 +179,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
if (op->type_annotation.defined()) {
return op->type_annotation;
} else {
return IncompleteTypeNode::make(Kind::kType);
return IncompleteType(Kind::kType);
}
}
......@@ -215,7 +214,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
}
Type tuple_type = GetType(op->tuple);
Type rtype = IncompleteTypeNode::make(Kind::kType);
Type rtype = IncompleteType(Kind::kType);
auto attrs = make_object<TupleGetItemAttrs>();
attrs->index = op->index;
solver_.AddConstraint(TypeRelation(
......@@ -233,7 +232,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// we can expect a certain number of arguments
Array<Type> unknown_args;
for (size_t i = 0; i < td->type_vars.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
unknown_args.push_back(IncompleteType(Kind::kType));
}
Type expected = TypeCall(con->constructor->belong_to, unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(con));
......@@ -275,7 +274,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// we can expect a certain number of arguments
Array<Type> unknown_args;
for (size_t i = 0; i < tup->patterns.size(); i++) {
unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
unknown_args.push_back(IncompleteType(Kind::kType));
}
Type expected = TupleType(unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(tup));
......@@ -302,7 +301,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (const auto& c : op->clauses) {
VisitPattern(c->lhs, dtype);
}
Type rtype = IncompleteTypeNode::make(Kind::kType);
Type rtype = IncompleteType(Kind::kType);
for (const auto& c : op->clauses) {
rtype = this->Unify(rtype,
GetType(c->rhs),
......@@ -336,7 +335,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type VisitExpr_(const LetNode* let) final {
// if the definition is a function literal, permit recursion
bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
Type let_type = IncompleteTypeNode::make(Kind::kType);
Type let_type = IncompleteType(Kind::kType);
if (is_functional_literal) {
let_type = GetType(let->var);
......@@ -362,7 +361,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// that is a rank-0 boolean tensor.
Type cond_type = this->GetType(ite->cond);
this->Unify(cond_type,
TensorTypeNode::Scalar(tvm::DataType::Bool()),
TensorType::Scalar(tvm::DataType::Bool()),
ite->cond);
Type checked_true = this->GetType(ite->true_branch);
Type checked_false = this->GetType(ite->false_branch);
......@@ -385,7 +384,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
for (size_t i = 0; i < op->type_params.size(); ++i) {
if (!op->type_params[i].same_as(rel->args[i])) return Type();
}
Type rtype = IncompleteTypeNode::make(Kind::kType);
Type rtype = IncompleteType(Kind::kType);
arg_types.push_back(rtype);
// we can do simple replacement here
solver_.AddConstraint(TypeRelation(
......@@ -404,7 +403,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
subst_map.Set(fn_ty->type_params[i], IncompleteTypeNode::make(Kind::kType));
subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType));
}
Type ret_type = fn_ty->ret_type;
......@@ -415,7 +414,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// This is a temporary work around to check recursive functions whose
// return type is not yet known.
if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(Kind::kType);
ret_type = IncompleteType(Kind::kType);
}
Type inst_ty = FuncType(fn_ty->arg_types,
......@@ -433,7 +432,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Array<Type> type_args;
for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
type_args.push_back(IncompleteTypeNode::make(Kind::kType));
type_args.push_back(IncompleteType(Kind::kType));
}
return InstantiateFuncType(fn_ty, type_args);
}
......@@ -466,7 +465,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// incomplete type => it must be a function taking the arg types
// with an unknown return type
if (inc_ty_node != nullptr) {
Type ret_type = IncompleteTypeNode::make(Kind::kType);
Type ret_type = IncompleteType(Kind::kType);
Type func_type = FuncType(arg_types, ret_type, {}, {});
Type unified = this->Unify(ftype, func_type, GetRef<Call>(call));
fn_ty_node = unified.as<FuncTypeNode>();
......@@ -562,18 +561,18 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
Type VisitExpr_(const RefCreateNode* op) final {
return RefTypeNode::make(GetType(op->value));
return RelayRefType(GetType(op->value));
}
Type VisitExpr_(const RefReadNode* op) final {
Type it = IncompleteTypeNode::make(Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefRead>(op));
Type it = IncompleteType(Kind::kType);
this->Unify(GetType(op->ref), RelayRefType(it), GetRef<RefRead>(op));
return it;
}
Type VisitExpr_(const RefWriteNode* op) final {
Type it = IncompleteTypeNode::make(Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
Type it = IncompleteType(Kind::kType);
this->Unify(GetType(op->ref), RelayRefType(it), GetRef<RefWrite>(op));
this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
return TupleType::Empty();
}
......
......@@ -21,13 +21,13 @@
* \file type_solver.cc
* \brief Type solver implementations.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/tir/op.h>
#include <string>
#include <memory>
#include <tuple>
#include <utility>
#include "type_solver.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......@@ -270,7 +270,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return Type(nullptr);
}
return TensorTypeNode::make(shape, tt1->dtype);
return TensorType(shape, tt1->dtype);
}
Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
......@@ -312,7 +312,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
}
for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
subst_map.Set(op->type_params[i], IncompleteType(kType));
}
FuncType ft = FuncType(op->arg_types,
......@@ -343,12 +343,12 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return FuncType(arg_types, ret_type, ft2->type_params, type_constraints);
}
// Unify two reference types: Ref[a] ~ Ref[b] succeeds iff a ~ b succeeds,
// yielding Ref[Unify(a, b)]. A null Type signals unification failure to the
// caller (consistent with the other VisitType_ overloads in this unifier).
// The diff rendering had duplicated pre-refactor lines (RefTypeNode signature
// and RefTypeNode::make return); only the post-refactor RelayRefType form is kept.
Type VisitType_(const RelayRefTypeNode* op, const Type& tn) final {
  const auto* rtn = tn.as<RelayRefTypeNode>();
  if (!rtn) {
    // The other side is not a reference type: cannot unify.
    return Type(nullptr);
  }
  // Recursively unify the referenced value types.
  return RelayRefType(Unify(op->value, rtn->value));
}
Type VisitType_(const TypeCallNode* op, const Type& tn) override {
......@@ -690,7 +690,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
Expr e = VarNode::make("dummy_var",
IncompleteTypeNode::make(Kind::kType));
IncompleteType(Kind::kType));
return solver->AddConstraint(c, e);
});
} else {
......
......@@ -23,12 +23,12 @@
*
* \brief Utility functions for Relay.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......
......@@ -52,7 +52,7 @@ bool DequantizeRel(const Array<Type>& types,
const Array<tvm::PrimExpr> oshape = data->shape;
// assign output type, output will always be float 32.
reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
reporter->Assign(types[3], TensorType(oshape, DataType::Float(32)));
return true;
}
......
......@@ -63,7 +63,7 @@ bool QuantizeRel(const Array<Type>& types,
out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
// assign output type
reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
......
......@@ -197,7 +197,7 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
reporter->Assign(types[5], TensorTypeNode::make(oshape, out_dtype));
reporter->Assign(types[5], TensorType(oshape, out_dtype));
return true;
}
......
......@@ -176,7 +176,7 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons
const auto tensor_dtype = tensor_type->dtype;
CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
if (tensor_type->shape.size() != 0) {
reporter->Assign(expr_type, TensorTypeNode::make({shape}, tensor_type->dtype));
reporter->Assign(expr_type, TensorType({shape}, tensor_type->dtype));
}
}
......
......@@ -36,7 +36,7 @@ TVM_REGISTER_GLOBAL("test.sch")
TEST(Relay, BuildModule) {
using namespace tvm;
auto tensor_type = relay::TensorTypeNode::make({2, 3}, DataType::Float(32));
auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto add_op = relay::Op::Get("add");
......
......@@ -26,7 +26,7 @@
TEST(Relay, SelfReference) {
using namespace tvm;
auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool());
auto tensor_type = relay::TensorType({}, DataType::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>());
......
......@@ -36,7 +36,7 @@ TVM_REGISTER_GLOBAL("schedule")
TEST(Relay, Sequential) {
using namespace tvm;
auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, DataType::Float(32));
auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32));
auto c_data =
tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
......
......@@ -51,7 +51,7 @@ TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue*
TEST(MicroStandaloneRuntime, BuildModule) {
using namespace tvm;
auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto add_op = relay::Op::Get("add");
......
......@@ -793,7 +793,7 @@ def test_forward_layer_norm():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
verify((2, 5))
verify((2, 5), axis=0)
verify((2, 5, 6))
......@@ -809,7 +809,7 @@ def test_forward_one_hot():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x.astype("float32"))
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
verify((3,), 3, 1, 0, "int32")
verify((3,), 3, 1.0, 0.0, "float32")
verify((2, 2), 5, 2, -2, "int32")
......@@ -898,7 +898,7 @@ def test_forward_deconvolution():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, weight, bias)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment