Unverified Commit 12e51e6c by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Initialize Unified IR Expr Data Structure (#4673)

This PR moves a few base types from relay and low-level Expr into the ir sub-folder.
These classes will serve as a common type system across the stack.

Rationale:

- PrimExpr for low-level expressions
- RelayExpr for advanced features, including Function definition.
- Introduce BaseFunc to host all functions, including future PrimFunc(low-level expr functions, subject to discussion).

This is a minimum change we can do to unify the classes into a common hierarchy.
The main data structure that are variant specific will still be kept in the sub-namespaces.
We only include classes that is needed to allow a common Module class.
- BaseFunc
- GlobalVar
- Type definition part of ADT

We will only need the BaseFunc and their checked_type to decide the calling convention
across the function variants.
parent 86092de0
......@@ -24,6 +24,7 @@
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <tvm/ir/expr.h>
#include <string>
#include <algorithm>
#include <unordered_map>
......@@ -37,58 +38,6 @@
namespace tvm {
/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public Object {
public:
/*! \brief The data type of the expression. */
DataType dtype;
static constexpr const char* _type_key = "PrimExpr";
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object);
};
/*!
* \brief Container of all primitive expressions.
* \sa PrimExprNode
*/
class PrimExpr : public ObjectRef {
public:
PrimExpr() {}
explicit PrimExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
using ContainerType = PrimExprNode;
};
/*! \brief Base node of all statements. */
class StmtNode : public Object {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/adt.h
* \brief Algebraic data type definitions.
*
* We adopt relay's ADT definition as a unified class
* for decripting structured data.
*/
#ifndef TVM_IR_ADT_H_
#define TVM_IR_ADT_H_
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
#include <string>
namespace tvm {
/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
* \sa Constructor
*/
class ConstructorNode : public RelayExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int32_t tag = -1;
ConstructorNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
/*!
* \brief Managed reference to ConstructorNode
* \sa ConstructorNode
*/
class Constructor : public RelayExpr {
public:
/*!
* \brief Constructor
* \param name_hint the name of the constructor.
* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
TVM_DLL Constructor(std::string name_hint,
Array<Type> inputs,
GlobalTypeVar belong_to);
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
Array<TypeVar> type_vars;
/*! \brief The constructors. */
Array<Constructor> constructors;
void VisitAttrs(AttrVisitor* v) {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData : public Type {
public:
/*!
* \brief Constructor
* \param header the name of ADT.
* \param type_vars type variables.
* \param constructors constructors field.
*/
TVM_DLL TypeData(GlobalTypeVar header,
Array<TypeVar> type_vars,
Array<Constructor> constructors);
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
} // namespace tvm
#endif // TVM_IR_ADT_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/expr.h
* \brief Base expr nodes in TVM.
*/
#ifndef TVM_IR_EXPR_H_
#define TVM_IR_EXPR_H_
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/ir/span.h>
#include <tvm/ir/type.h>
#include <string>
namespace tvm {
/*!
* \brief Base type of all the expressions.
* \sa Expr
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
/*!
* \brief Managed reference to BaseExprNode.
* \sa BaseExprNode
*/
class BaseExpr : public ObjectRef {
public:
/*! \brief Cosntructor */
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
};
/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public BaseExprNode {
public:
/*!
* \brief The runtime data type of the primitive expression.
*
* runtime::DataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in
* PrimExpr expression construction and can be used for
* quick type checking.
*
* dtype is sufficient to decide the Type of the PrimExpr
* when it corresponds to POD value types such as i32.
*
* When dtype is DataType::Handle(), the expression could corresponds to
* a more fine-grained Type, and we can get the type by running lazy type inference.
*/
DataType dtype;
static constexpr const char* _type_key = "PrimExpr";
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};
/*!
* \brief Reference to PrimExprNode.
* \sa PrimExprNode
*/
class PrimExpr : public BaseExpr {
public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
/*! \brief The container type. */
using ContainerType = PrimExprNode;
};
/*!
* \brief Base node of all non-primitive expressions.
*
* RelayExpr supports tensor types, functions and ADT as
* first class citizens. The life-cycle of the corresponding
* objects are implicitly managed by the language.
*
* \sa RelayExpr
*/
class RelayExprNode : public BaseExprNode {
public:
/*!
* \brief Span that points to the original source code.
* Reserved debug information.
*/
mutable Span span;
/*!
* \brief Stores the result of type inference(type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);
/*!
* \return The checked_type
*/
const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};
/*!
* \brief Managed reference to RelayExprNode.
* \sa RelayExprNode
*/
class RelayExpr : public BaseExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode);
};
class GlobalVar;
/*!
* \brief Global variable that leaves in the top-level module.
*
* A GlobalVar only refers to function definitions.
* This is used to enable recursive calls between function.
*
* \sa GlobalVarNode
*/
class GlobalVarNode : public RelayExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
/*!
* \brief Managed reference to GlobalVarNode.
* \sa GlobalVarNode
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(std::string name_hint);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions shares the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};
/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
// implementataions
inline const Type& RelayExprNode::checked_type() const {
CHECK(checked_type_.defined())
<< "internal error: the type checker has "
<< "not populated the checked_type "
<< "field for "
<< GetRef<RelayExpr>(this);
return this->checked_type_;
}
template<typename TTypeNode>
inline const TTypeNode* RelayExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->GetTypeKey();
return node;
}
} // namespace tvm
#endif // TVM_IR_EXPR_H_
......@@ -28,7 +28,7 @@
*
* ## Relation between Type and runtime::DataType
*
* Besides Type, we also store a dtype field in some of the low-level IR's Expr.
* Besides Type, we also store a dtype field in the low-level PrimExpr.
* runtime::DataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in
* low-level expression construction and can be used for
......
......@@ -25,6 +25,7 @@
#define TVM_RELAY_ADT_H_
#include <tvm/attrs.h>
#include <tvm/ir/adt.h>
#include <string>
#include <functional>
#include "./base.h"
......@@ -34,6 +35,12 @@
namespace tvm {
namespace relay {
using Constructor = tvm::Constructor;
using ConstructorNode = tvm::ConstructorNode;
using TypeData = tvm::TypeData;
using TypeDataNode = tvm::TypeDataNode;
/*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode {
public:
......@@ -105,47 +112,6 @@ class PatternVar : public Pattern {
TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
};
/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
*/
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
tvm::Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int32_t tag = -1;
ConstructorNode() {}
TVM_DLL static Constructor make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to);
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, ExprNode);
};
class Constructor : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode);
};
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternVar container node */
......@@ -201,53 +167,6 @@ class PatternTuple : public Pattern {
TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
};
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData;
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
tvm::Array<TypeVar> type_vars;
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}
TVM_DLL static TypeData make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors);
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
class TypeData : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
......@@ -306,7 +225,7 @@ class MatchNode : public ExprNode {
class Match : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode);
TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
};
} // namespace relay
......
......@@ -25,6 +25,7 @@
#define TVM_RELAY_EXPR_H_
#include <tvm/attrs.h>
#include <tvm/ir/expr.h>
#include <string>
#include <functional>
#include "./base.h"
......@@ -33,47 +34,12 @@
namespace tvm {
namespace relay {
/*!
* \brief A Relay expression.
*/
class Expr;
/*!
* \brief Base type of the Relay expression hiearchy.
*/
class ExprNode : public RelayNode {
public:
/*!
* \brief Stores the result of type inference(type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);
/*!
* \return The checked_type
*/
const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode);
};
class Expr : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode);
};
using Expr = tvm::RelayExpr;
using ExprNode = tvm::RelayExprNode;
using BaseFunc = tvm::BaseFunc;
using BaseFuncNode = tvm::BaseFuncNode;
using GlobalVar = tvm::GlobalVar;
using GlobalVarNode = tvm::GlobalVarNode;
/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
......@@ -112,7 +78,7 @@ class ConstantNode : public ExprNode {
class Constant : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode);
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
/*! \brief Tuple of multiple Exprs */
......@@ -137,7 +103,7 @@ class TupleNode : public ExprNode {
class Tuple : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
};
/*!
......@@ -193,37 +159,7 @@ class VarNode : public ExprNode {
class Var : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
};
/*!
* \brief Global variable that leaves in the top-level module.
* This is used to enable recursive calls between function.
*
* \note A GlobalVar may only point to functions.
*/
class GlobalVar;
/*! \brief A GlobalId from the node's current type to target type. */
class GlobalVarNode : public ExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static GlobalVar make(std::string name_hint);
static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode);
};
class GlobalVar : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode);
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
/*!
......@@ -231,7 +167,7 @@ class GlobalVar : public Expr {
*/
class Function;
/*! \brief Function container */
class FunctionNode : public ExprNode {
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
......@@ -312,12 +248,12 @@ class FunctionNode : public ExprNode {
tvm::Map<Var, Constant> GetParams() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
class Function : public Expr {
class Function : public BaseFunc {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
};
......@@ -388,7 +324,7 @@ class CallNode : public ExprNode {
class Call : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode);
TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
};
/*!
......@@ -429,7 +365,7 @@ class LetNode : public ExprNode {
class Let : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode);
TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode);
};
/*!
......@@ -470,7 +406,7 @@ class IfNode : public ExprNode {
class If : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
};
/*! \brief Get index-th field out of a tuple. */
......@@ -497,7 +433,7 @@ class TupleGetItemNode : public ExprNode {
class TupleGetItem : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode);
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode);
};
/*! \brief Create a new Reference out of initial value. */
......@@ -521,7 +457,7 @@ class RefCreateNode : public ExprNode {
class RefCreate : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode);
};
/*! \brief Get value out of Reference. */
......@@ -545,7 +481,7 @@ class RefReadNode : public ExprNode {
class RefRead : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode);
};
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
......@@ -571,7 +507,7 @@ class RefWriteNode : public ExprNode {
class RefWrite : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode);
};
/*!
......@@ -600,32 +536,9 @@ class TempExprNode : public ExprNode {
class TempExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode);
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};
// implementataions
inline const Type& ExprNode::checked_type() const {
CHECK(checked_type_.defined())
<< "internal error: the type checker has "
<< "not populated the checked_type "
<< "field for "
<< GetRef<Expr>(this);
return this->checked_type_;
}
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->GetTypeKey();
return node;
}
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const ObjectRef& node);
......
......@@ -25,7 +25,7 @@
#define TVM_RELAY_FEATURE_H_
#include <tvm/node/container.h>
#include <tvm/expr.h>
#include <tvm/relay/expr.h>
#include <bitset>
namespace tvm {
......@@ -132,7 +132,6 @@ class FeatureSet {
explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { }
};
class Expr;
/*!
* \brief Calculate the feature of the program.
*
......@@ -140,7 +139,7 @@ class Expr;
*
* \return The FeatureSet.
*/
FeatureSet DetectFeature(const Expr& expr);
FeatureSet DetectFeature(const RelayExpr& expr);
struct Module;
/*!
......
......@@ -140,7 +140,7 @@ class Op : public relay::Expr {
/*! \brief default constructor */
Op() {}
/*! \brief constructor from node pointer */
explicit Op(ObjectPtr<Object> n) : Expr(n) {}
explicit Op(ObjectPtr<Object> n) : RelayExpr(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......
......@@ -650,6 +650,7 @@ struct ObjectEqual {
* \param ParentType The name of the ParentType
*/
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
static const uint32_t RuntimeTypeIndex() { \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/ir/adt.cc
* \brief ADT type definitions.
*/
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
namespace tvm {
Constructor::Constructor(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
n->belong_to = std::move(belong_to);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.Constructor")
.set_body_typed([](std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
return Constructor(name_hint, inputs, belong_to);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
});
TypeData::TypeData(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
ObjectPtr<TypeDataNode> n = make_object<TypeDataNode>();
n->header = std::move(header);
n->type_vars = std::move(type_vars);
n->constructors = std::move(constructors);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("relay._make.TypeData")
.set_body_typed([](GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
return TypeData(header, type_vars, constructors);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
});
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/ir/expr.cc
* \brief The expression AST nodes for the common IR infra.
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
namespace tvm {
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
.set_body_typed([](std::string name){
return GlobalVar(name);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});
} // namespace tvm
......@@ -628,7 +628,7 @@ class CompileEngineImpl : public CompileEngineNode {
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(symbol_name->value);
auto gv = GlobalVar(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
......
......@@ -101,7 +101,7 @@ class LambdaLifter : public ExprMutator {
}
auto name = GenerateName(func);
auto global = GlobalVarNode::make(name);
auto global = GlobalVar(name);
auto free_vars = FreeVars(func);
auto free_type_vars = FreeTypeVars(func, module_);
......
......@@ -96,50 +96,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "PatternTupleNode(" << node->patterns << ")";
});
Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
n->belong_to = std::move(belong_to);
return Constructor(n);
}
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.Constructor")
.set_body_typed(ConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
});
TypeData TypeDataNode::make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
ObjectPtr<TypeDataNode> n = make_object<TypeDataNode>();
n->header = std::move(header);
n->type_vars = std::move(type_vars);
n->constructors = std::move(constructors);
return TypeData(n);
}
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("relay._make.TypeData")
.set_body_typed(TypeDataNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
});
Clause ClauseNode::make(Pattern lhs, Expr rhs) {
ObjectPtr<ClauseNode> n = make_object<ClauseNode>();
n->lhs = std::move(lhs);
......
......@@ -37,7 +37,8 @@ TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_GLOBAL("relay._base.set_span")
.set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) {
CHECK(rn);
rn->span = sp;
} else if (auto* rn = node_ref.as<RelayExprNode>()) {
rn->span = sp;
} else if (auto* rn = node_ref.as<TypeNode>()) {
rn->span = sp;
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/ir/expr.cc
* \file src/tvm/relay/ir/expr.cc
* \brief The expression AST nodes of Relay.
*/
#include <tvm/relay/expr.h>
......@@ -109,24 +109,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ")";
});
GlobalVar GlobalVarNode::make(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
return GlobalVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
.set_body_typed(GlobalVarNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});
Function FunctionNode::make(tvm::Array<Var> params,
Expr body,
Type ret_type,
......
......@@ -279,7 +279,7 @@ Module ModuleNode::FromExpr(
} else {
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVarNode::make("main");
auto main_gv = GlobalVar("main");
mod->Add(main_gv, func);
return mod;
}
......
......@@ -213,7 +213,7 @@ class ConstantFolder : public ExprMutator {
{},
module_->type_definitions,
module_->Imports());
auto global = GlobalVarNode::make("main");
auto global = GlobalVar("main");
mod->Add(global, func);
auto seq = transform::Sequential(passes);
mod = seq(mod);
......
......@@ -155,7 +155,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
auto gv = GetRef<GlobalVar>(op);
if (cm->count(gv) == 0) {
auto cps_gv = GlobalVarNode::make(gv->name_hint + "_cps");
auto cps_gv = GlobalVar(gv->name_hint + "_cps");
cm->insert({gv, cps_gv});
m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm));
}
......
......@@ -662,7 +662,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
using runtime::TypedPackedFunc;
ErrorReporter *err_reporter = new ErrorReporter();
auto module = ModuleNode::make({}, {});
auto dummy_fn_name = GlobalVarNode::make("test");
auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
......
......@@ -29,7 +29,7 @@ TEST(Relay, SelfReference) {
auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>());
auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment