Unverified Commit 12e51e6c by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Initialize Unified IR Expr Data Structure (#4673)

This PR moves a few base types from relay and low-level Expr into the ir sub-folder.
These classes will serve as a common type system across the stack.

Rationale:

- PrimExpr for low-level expressions
- RelayExpr for advanced features, including Function definition.
- Introduce BaseFunc to host all functions, including future PrimFunc(low-level expr functions, subject to discussion).

This is a minimum change we can do to unify the classes into a common hierarchy.
The main data structure that are variant specific will still be kept in the sub-namespaces.
We only include classes that is needed to allow a common Module class.
- BaseFunc
- GlobalVar
- Type definition part of ADT

We will only need the BaseFunc and their checked_type to decide the calling convention
across the function variants.
parent 86092de0
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef TVM_EXPR_H_ #ifndef TVM_EXPR_H_
#define TVM_EXPR_H_ #define TVM_EXPR_H_
#include <tvm/ir/expr.h>
#include <string> #include <string>
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
...@@ -37,58 +38,6 @@ ...@@ -37,58 +38,6 @@
namespace tvm { namespace tvm {
/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public Object {
public:
/*! \brief The data type of the expression. */
DataType dtype;
static constexpr const char* _type_key = "PrimExpr";
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object);
};
/*!
* \brief Container of all primitive expressions.
* \sa PrimExprNode
*/
class PrimExpr : public ObjectRef {
public:
PrimExpr() {}
explicit PrimExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
using ContainerType = PrimExprNode;
};
/*! \brief Base node of all statements. */ /*! \brief Base node of all statements. */
class StmtNode : public Object { class StmtNode : public Object {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/adt.h
* \brief Algebraic data type definitions.
*
* We adopt relay's ADT definition as a unified class
* for decripting structured data.
*/
#ifndef TVM_IR_ADT_H_
#define TVM_IR_ADT_H_
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
#include <string>
namespace tvm {
/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
* \sa Constructor
*/
class ConstructorNode : public RelayExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int32_t tag = -1;
ConstructorNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
/*!
* \brief Managed reference to ConstructorNode
* \sa ConstructorNode
*/
class Constructor : public RelayExpr {
public:
/*!
* \brief Constructor
* \param name_hint the name of the constructor.
* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
TVM_DLL Constructor(std::string name_hint,
Array<Type> inputs,
GlobalTypeVar belong_to);
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
Array<TypeVar> type_vars;
/*! \brief The constructors. */
Array<Constructor> constructors;
void VisitAttrs(AttrVisitor* v) {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData : public Type {
public:
/*!
* \brief Constructor
* \param header the name of ADT.
* \param type_vars type variables.
* \param constructors constructors field.
*/
TVM_DLL TypeData(GlobalTypeVar header,
Array<TypeVar> type_vars,
Array<Constructor> constructors);
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
} // namespace tvm
#endif // TVM_IR_ADT_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/expr.h
* \brief Base expr nodes in TVM.
*/
#ifndef TVM_IR_EXPR_H_
#define TVM_IR_EXPR_H_
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/ir/span.h>
#include <tvm/ir/type.h>
#include <string>
namespace tvm {
/*!
* \brief Base type of all the expressions.
* \sa Expr
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
/*!
* \brief Managed reference to BaseExprNode.
* \sa BaseExprNode
*/
class BaseExpr : public ObjectRef {
public:
/*! \brief Cosntructor */
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
};
/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public BaseExprNode {
public:
/*!
* \brief The runtime data type of the primitive expression.
*
* runtime::DataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in
* PrimExpr expression construction and can be used for
* quick type checking.
*
* dtype is sufficient to decide the Type of the PrimExpr
* when it corresponds to POD value types such as i32.
*
* When dtype is DataType::Handle(), the expression could corresponds to
* a more fine-grained Type, and we can get the type by running lazy type inference.
*/
DataType dtype;
static constexpr const char* _type_key = "PrimExpr";
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};
/*!
* \brief Reference to PrimExprNode.
* \sa PrimExprNode
*/
class PrimExpr : public BaseExpr {
public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
/*! \brief The container type. */
using ContainerType = PrimExprNode;
};
/*!
* \brief Base node of all non-primitive expressions.
*
* RelayExpr supports tensor types, functions and ADT as
* first class citizens. The life-cycle of the corresponding
* objects are implicitly managed by the language.
*
* \sa RelayExpr
*/
class RelayExprNode : public BaseExprNode {
public:
/*!
* \brief Span that points to the original source code.
* Reserved debug information.
*/
mutable Span span;
/*!
* \brief Stores the result of type inference(type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);
/*!
* \return The checked_type
*/
const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};
/*!
* \brief Managed reference to RelayExprNode.
* \sa RelayExprNode
*/
class RelayExpr : public BaseExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode);
};
class GlobalVar;
/*!
* \brief Global variable that leaves in the top-level module.
*
* A GlobalVar only refers to function definitions.
* This is used to enable recursive calls between function.
*
* \sa GlobalVarNode
*/
class GlobalVarNode : public RelayExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
/*!
* \brief Managed reference to GlobalVarNode.
* \sa GlobalVarNode
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(std::string name_hint);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions shares the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};
/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
// implementataions
inline const Type& RelayExprNode::checked_type() const {
CHECK(checked_type_.defined())
<< "internal error: the type checker has "
<< "not populated the checked_type "
<< "field for "
<< GetRef<RelayExpr>(this);
return this->checked_type_;
}
template<typename TTypeNode>
inline const TTypeNode* RelayExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->GetTypeKey();
return node;
}
} // namespace tvm
#endif // TVM_IR_EXPR_H_
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
* *
* ## Relation between Type and runtime::DataType * ## Relation between Type and runtime::DataType
* *
* Besides Type, we also store a dtype field in some of the low-level IR's Expr. * Besides Type, we also store a dtype field in the low-level PrimExpr.
* runtime::DataType(dtype) provides coarse grained type information * runtime::DataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in * during compile time and runtime. It is eagerly built in
* low-level expression construction and can be used for * low-level expression construction and can be used for
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define TVM_RELAY_ADT_H_ #define TVM_RELAY_ADT_H_
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/ir/adt.h>
#include <string> #include <string>
#include <functional> #include <functional>
#include "./base.h" #include "./base.h"
...@@ -34,6 +35,12 @@ ...@@ -34,6 +35,12 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using Constructor = tvm::Constructor;
using ConstructorNode = tvm::ConstructorNode;
using TypeData = tvm::TypeData;
using TypeDataNode = tvm::TypeDataNode;
/*! \brief Base type for declaring relay pattern. */ /*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode { class PatternNode : public RelayNode {
public: public:
...@@ -105,47 +112,6 @@ class PatternVar : public Pattern { ...@@ -105,47 +112,6 @@ class PatternVar : public Pattern {
TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode); TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
}; };
/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
*/
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
tvm::Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int32_t tag = -1;
ConstructorNode() {}
TVM_DLL static Constructor make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to);
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, ExprNode);
};
class Constructor : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode);
};
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor; class PatternConstructor;
/*! \brief PatternVar container node */ /*! \brief PatternVar container node */
...@@ -201,53 +167,6 @@ class PatternTuple : public Pattern { ...@@ -201,53 +167,6 @@ class PatternTuple : public Pattern {
TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode); TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
}; };
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData;
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
tvm::Array<TypeVar> type_vars;
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}
TVM_DLL static TypeData make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors);
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
class TypeData : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
/*! \brief A clause in a match expression. */ /*! \brief A clause in a match expression. */
class Clause; class Clause;
/*! \brief Clause container node. */ /*! \brief Clause container node. */
...@@ -306,7 +225,7 @@ class MatchNode : public ExprNode { ...@@ -306,7 +225,7 @@ class MatchNode : public ExprNode {
class Match : public Expr { class Match : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode); TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
}; };
} // namespace relay } // namespace relay
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_
#include <tvm/attrs.h> #include <tvm/attrs.h>
#include <tvm/ir/expr.h>
#include <string> #include <string>
#include <functional> #include <functional>
#include "./base.h" #include "./base.h"
...@@ -33,47 +34,12 @@ ...@@ -33,47 +34,12 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
/*! using Expr = tvm::RelayExpr;
* \brief A Relay expression. using ExprNode = tvm::RelayExprNode;
*/ using BaseFunc = tvm::BaseFunc;
class Expr; using BaseFuncNode = tvm::BaseFuncNode;
/*! using GlobalVar = tvm::GlobalVar;
* \brief Base type of the Relay expression hiearchy. using GlobalVarNode = tvm::GlobalVarNode;
*/
class ExprNode : public RelayNode {
public:
/*!
* \brief Stores the result of type inference(type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);
/*!
* \return The checked_type
*/
const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode);
};
class Expr : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode);
};
/*! /*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device. * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
...@@ -112,7 +78,7 @@ class ConstantNode : public ExprNode { ...@@ -112,7 +78,7 @@ class ConstantNode : public ExprNode {
class Constant : public Expr { class Constant : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
}; };
/*! \brief Tuple of multiple Exprs */ /*! \brief Tuple of multiple Exprs */
...@@ -137,7 +103,7 @@ class TupleNode : public ExprNode { ...@@ -137,7 +103,7 @@ class TupleNode : public ExprNode {
class Tuple : public Expr { class Tuple : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
}; };
/*! /*!
...@@ -193,37 +159,7 @@ class VarNode : public ExprNode { ...@@ -193,37 +159,7 @@ class VarNode : public ExprNode {
class Var : public Expr { class Var : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
/*!
* \brief Global variable that leaves in the top-level module.
* This is used to enable recursive calls between function.
*
* \note A GlobalVar may only point to functions.
*/
class GlobalVar;
/*! \brief A GlobalId from the node's current type to target type. */
class GlobalVarNode : public ExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static GlobalVar make(std::string name_hint);
static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode);
};
class GlobalVar : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode);
}; };
/*! /*!
...@@ -231,7 +167,7 @@ class GlobalVar : public Expr { ...@@ -231,7 +167,7 @@ class GlobalVar : public Expr {
*/ */
class Function; class Function;
/*! \brief Function container */ /*! \brief Function container */
class FunctionNode : public ExprNode { class FunctionNode : public BaseFuncNode {
public: public:
/*! \brief Function parameters */ /*! \brief Function parameters */
tvm::Array<Var> params; tvm::Array<Var> params;
...@@ -312,12 +248,12 @@ class FunctionNode : public ExprNode { ...@@ -312,12 +248,12 @@ class FunctionNode : public ExprNode {
tvm::Map<Var, Constant> GetParams() const; tvm::Map<Var, Constant> GetParams() const;
static constexpr const char* _type_key = "relay.Function"; static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
}; };
class Function : public Expr { class Function : public BaseFunc {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
}; };
...@@ -388,7 +324,7 @@ class CallNode : public ExprNode { ...@@ -388,7 +324,7 @@ class CallNode : public ExprNode {
class Call : public Expr { class Call : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
}; };
/*! /*!
...@@ -429,7 +365,7 @@ class LetNode : public ExprNode { ...@@ -429,7 +365,7 @@ class LetNode : public ExprNode {
class Let : public Expr { class Let : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode); TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode);
}; };
/*! /*!
...@@ -470,7 +406,7 @@ class IfNode : public ExprNode { ...@@ -470,7 +406,7 @@ class IfNode : public ExprNode {
class If : public Expr { class If : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
}; };
/*! \brief Get index-th field out of a tuple. */ /*! \brief Get index-th field out of a tuple. */
...@@ -497,7 +433,7 @@ class TupleGetItemNode : public ExprNode { ...@@ -497,7 +433,7 @@ class TupleGetItemNode : public ExprNode {
class TupleGetItem : public Expr { class TupleGetItem : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode);
}; };
/*! \brief Create a new Reference out of initial value. */ /*! \brief Create a new Reference out of initial value. */
...@@ -521,7 +457,7 @@ class RefCreateNode : public ExprNode { ...@@ -521,7 +457,7 @@ class RefCreateNode : public ExprNode {
class RefCreate : public Expr { class RefCreate : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode); TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode);
}; };
/*! \brief Get value out of Reference. */ /*! \brief Get value out of Reference. */
...@@ -545,7 +481,7 @@ class RefReadNode : public ExprNode { ...@@ -545,7 +481,7 @@ class RefReadNode : public ExprNode {
class RefRead : public Expr { class RefRead : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode); TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode);
}; };
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite; class RefWrite;
...@@ -571,7 +507,7 @@ class RefWriteNode : public ExprNode { ...@@ -571,7 +507,7 @@ class RefWriteNode : public ExprNode {
class RefWrite : public Expr { class RefWrite : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode); TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode);
}; };
/*! /*!
...@@ -600,32 +536,9 @@ class TempExprNode : public ExprNode { ...@@ -600,32 +536,9 @@ class TempExprNode : public ExprNode {
class TempExpr : public Expr { class TempExpr : public Expr {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode); TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
}; };
// implementataions
inline const Type& ExprNode::checked_type() const {
CHECK(checked_type_.defined())
<< "internal error: the type checker has "
<< "not populated the checked_type "
<< "field for "
<< GetRef<Expr>(this);
return this->checked_type_;
}
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed. Try to call infer_type pass.";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->GetTypeKey();
return node;
}
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ /*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const ObjectRef& node); std::string PrettyPrint(const ObjectRef& node);
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_RELAY_FEATURE_H_ #define TVM_RELAY_FEATURE_H_
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/expr.h> #include <tvm/relay/expr.h>
#include <bitset> #include <bitset>
namespace tvm { namespace tvm {
...@@ -132,7 +132,6 @@ class FeatureSet { ...@@ -132,7 +132,6 @@ class FeatureSet {
explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { } explicit FeatureSet(const std::bitset<feature_count>& bs) : bs_(bs) { }
}; };
class Expr;
/*! /*!
* \brief Calculate the feature of the program. * \brief Calculate the feature of the program.
* *
...@@ -140,7 +139,7 @@ class Expr; ...@@ -140,7 +139,7 @@ class Expr;
* *
* \return The FeatureSet. * \return The FeatureSet.
*/ */
FeatureSet DetectFeature(const Expr& expr); FeatureSet DetectFeature(const RelayExpr& expr);
struct Module; struct Module;
/*! /*!
......
...@@ -140,7 +140,7 @@ class Op : public relay::Expr { ...@@ -140,7 +140,7 @@ class Op : public relay::Expr {
/*! \brief default constructor */ /*! \brief default constructor */
Op() {} Op() {}
/*! \brief constructor from node pointer */ /*! \brief constructor from node pointer */
explicit Op(ObjectPtr<Object> n) : Expr(n) {} explicit Op(ObjectPtr<Object> n) : RelayExpr(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
......
...@@ -650,6 +650,7 @@ struct ObjectEqual { ...@@ -650,6 +650,7 @@ struct ObjectEqual {
* \param ParentType The name of the ParentType * \param ParentType The name of the ParentType
*/ */
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
static const uint32_t RuntimeTypeIndex() { \ static const uint32_t RuntimeTypeIndex() { \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \ return TypeName::_type_index; \
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/ir/adt.cc
* \brief ADT type definitions.
*/
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
namespace tvm {
Constructor::Constructor(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
n->belong_to = std::move(belong_to);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.Constructor")
.set_body_typed([](std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
return Constructor(name_hint, inputs, belong_to);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
});
TypeData::TypeData(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
ObjectPtr<TypeDataNode> n = make_object<TypeDataNode>();
n->header = std::move(header);
n->type_vars = std::move(type_vars);
n->constructors = std::move(constructors);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("relay._make.TypeData")
.set_body_typed([](GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
return TypeData(header, type_vars, constructors);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
});
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/ir/expr.cc
* \brief The expression AST nodes for the common IR infra.
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
namespace tvm {
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
.set_body_typed([](std::string name){
return GlobalVar(name);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});
} // namespace tvm
...@@ -628,7 +628,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -628,7 +628,7 @@ class CompileEngineImpl : public CompileEngineNode {
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>(); const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false); CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(symbol_name->value); auto gv = GlobalVar(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func); ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first); cached_ext_funcs.push_back(it.first);
} }
......
...@@ -101,7 +101,7 @@ class LambdaLifter : public ExprMutator { ...@@ -101,7 +101,7 @@ class LambdaLifter : public ExprMutator {
} }
auto name = GenerateName(func); auto name = GenerateName(func);
auto global = GlobalVarNode::make(name); auto global = GlobalVar(name);
auto free_vars = FreeVars(func); auto free_vars = FreeVars(func);
auto free_type_vars = FreeTypeVars(func, module_); auto free_type_vars = FreeTypeVars(func, module_);
......
...@@ -96,50 +96,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -96,50 +96,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "PatternTupleNode(" << node->patterns << ")"; p->stream << "PatternTupleNode(" << node->patterns << ")";
}); });
Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
n->belong_to = std::move(belong_to);
return Constructor(n);
}
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.Constructor")
.set_body_typed(ConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
});
TypeData TypeDataNode::make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
ObjectPtr<TypeDataNode> n = make_object<TypeDataNode>();
n->header = std::move(header);
n->type_vars = std::move(type_vars);
n->constructors = std::move(constructors);
return TypeData(n);
}
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("relay._make.TypeData")
.set_body_typed(TypeDataNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
});
Clause ClauseNode::make(Pattern lhs, Expr rhs) { Clause ClauseNode::make(Pattern lhs, Expr rhs) {
ObjectPtr<ClauseNode> n = make_object<ClauseNode>(); ObjectPtr<ClauseNode> n = make_object<ClauseNode>();
n->lhs = std::move(lhs); n->lhs = std::move(lhs);
......
...@@ -37,7 +37,8 @@ TVM_REGISTER_NODE_TYPE(IdNode); ...@@ -37,7 +37,8 @@ TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_GLOBAL("relay._base.set_span") TVM_REGISTER_GLOBAL("relay._base.set_span")
.set_body_typed([](ObjectRef node_ref, Span sp) { .set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) { if (auto* rn = node_ref.as<RelayNode>()) {
CHECK(rn); rn->span = sp;
} else if (auto* rn = node_ref.as<RelayExprNode>()) {
rn->span = sp; rn->span = sp;
} else if (auto* rn = node_ref.as<TypeNode>()) { } else if (auto* rn = node_ref.as<TypeNode>()) {
rn->span = sp; rn->span = sp;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file src/tvm/ir/expr.cc * \file src/tvm/relay/ir/expr.cc
* \brief The expression AST nodes of Relay. * \brief The expression AST nodes of Relay.
*/ */
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -109,24 +109,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -109,24 +109,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ")"; p->stream << ")";
}); });
GlobalVar GlobalVarNode::make(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
return GlobalVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
.set_body_typed(GlobalVarNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});
Function FunctionNode::make(tvm::Array<Var> params, Function FunctionNode::make(tvm::Array<Var> params,
Expr body, Expr body,
Type ret_type, Type ret_type,
......
...@@ -279,7 +279,7 @@ Module ModuleNode::FromExpr( ...@@ -279,7 +279,7 @@ Module ModuleNode::FromExpr(
} else { } else {
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {}); func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
} }
auto main_gv = GlobalVarNode::make("main"); auto main_gv = GlobalVar("main");
mod->Add(main_gv, func); mod->Add(main_gv, func);
return mod; return mod;
} }
......
...@@ -213,7 +213,7 @@ class ConstantFolder : public ExprMutator { ...@@ -213,7 +213,7 @@ class ConstantFolder : public ExprMutator {
{}, {},
module_->type_definitions, module_->type_definitions,
module_->Imports()); module_->Imports());
auto global = GlobalVarNode::make("main"); auto global = GlobalVar("main");
mod->Add(global, func); mod->Add(global, func);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
......
...@@ -155,7 +155,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const ...@@ -155,7 +155,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
auto gv = GetRef<GlobalVar>(op); auto gv = GetRef<GlobalVar>(op);
if (cm->count(gv) == 0) { if (cm->count(gv) == 0) {
auto cps_gv = GlobalVarNode::make(gv->name_hint + "_cps"); auto cps_gv = GlobalVar(gv->name_hint + "_cps");
cm->insert({gv, cps_gv}); cm->insert({gv, cps_gv});
m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm)); m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm));
} }
......
...@@ -662,7 +662,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") ...@@ -662,7 +662,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
using runtime::TypedPackedFunc; using runtime::TypedPackedFunc;
ErrorReporter *err_reporter = new ErrorReporter(); ErrorReporter *err_reporter = new ErrorReporter();
auto module = ModuleNode::make({}, {}); auto module = ModuleNode::make({}, {});
auto dummy_fn_name = GlobalVarNode::make("test"); auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {})); module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter); auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
......
...@@ -29,7 +29,7 @@ TEST(Relay, SelfReference) { ...@@ -29,7 +29,7 @@ TEST(Relay, SelfReference) {
auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool()); auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool());
auto x = relay::VarNode::make("x", relay::Type()); auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {}); auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>());
auto y = relay::VarNode::make("y", tensor_type); auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y }); auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {}); auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment