Commit b0b51f25 by Tianqi Chen Committed by Zhi

[REFACTOR][IR] attrs.h -> ir (#4709)

This PR moves attrs.h into the ir folder as it
can serve as a common infra for building ir dats structures.

We also moved common container(FloatImm) into ir/expr.h
parent 83da72f2
...@@ -677,13 +677,13 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { ...@@ -677,13 +677,13 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high)); return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
} }
} }
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value)); if (t.is_float()) return FloatImm(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the // For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format // datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype. // specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough? // TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) { if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
return ir::FloatImmNode::make(t, static_cast<double>(value)); return FloatImm(t, static_cast<double>(value));
} }
LOG(FATAL) << "cannot make const for type " << t; LOG(FATAL) << "cannot make const for type " << t;
return PrimExpr(); return PrimExpr();
......
...@@ -37,25 +37,9 @@ namespace tvm { ...@@ -37,25 +37,9 @@ namespace tvm {
namespace ir { namespace ir {
using IntImmNode = tvm::IntImmNode; using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
using VarNode = tvm::VarNode; using VarNode = tvm::VarNode;
/*! \brief Floating point constants. */
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
TVM_DLL static PrimExpr make(DataType t, double value);
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
/*! \brief String constants, only used in asserts. */ /*! \brief String constants, only used in asserts. */
class StringImmNode : public PrimExprNode { class StringImmNode : public PrimExprNode {
public: public:
......
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
/*! /*!
* \file tvm/attrs.h * \file tvm/ir/attrs.h
* \brief TVM attribute module * \brief Helpers for attribute objects.
* *
* This module enables declaration of named attributes * This module enables declaration of named attributes
* which support default value setup and bound checking. * which support default value setup and bound checking.
...@@ -42,20 +41,19 @@ ...@@ -42,20 +41,19 @@
* *
* \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
*/ */
#ifndef TVM_ATTRS_H_ #ifndef TVM_IR_ATTRS_H_
#define TVM_ATTRS_H_ #define TVM_IR_ATTRS_H_
#include <dmlc/common.h> #include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <functional> #include <functional>
#include <type_traits> #include <type_traits>
#include <string> #include <string>
#include <utility> #include <utility>
#include "ir.h"
#include "base.h"
#include "expr.h"
#include "packed_func_ext.h"
namespace tvm { namespace tvm {
/*! /*!
...@@ -63,10 +61,10 @@ namespace tvm { ...@@ -63,10 +61,10 @@ namespace tvm {
* \param ClassName The name of the class. * \param ClassName The name of the class.
* \param TypeKey The type key to be used by the TVM node system. * \param TypeKey The type key to be used by the TVM node system.
*/ */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \ static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template<typename FVisit> \ template<typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
...@@ -481,45 +479,36 @@ template<typename T> ...@@ -481,45 +479,36 @@ template<typename T>
inline void SetValue(T* ptr, const TVMArgValue& val) { inline void SetValue(T* ptr, const TVMArgValue& val) {
*ptr = val.operator T(); *ptr = val.operator T();
} }
template<typename T> template<typename T>
inline void SetIntValue(T* ptr, const TVMArgValue& val) { inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) { if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64); *ptr = static_cast<T>(val.value().v_int64);
} else { } else {
PrimExpr expr = val; IntImm expr = val;
CHECK(expr.defined()); *ptr = static_cast<T>(expr->value);
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
} }
} }
template<> template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) { inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) { if (val.type_code() == kStr) {
*ptr = val.operator std::string(); *ptr = val.operator std::string();
} else { } else {
PrimExpr expr = val; LOG(FATAL) << "Expect str";
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
} }
} }
template<>
inline void SetValue(DataType* ptr, const TVMArgValue& val) {
*ptr = val.operator DataType();
}
template<> template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) { inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double(); *ptr = val.operator double();
} else { } else {
PrimExpr expr = val; ObjectRef expr = val;
CHECK(expr.defined()); CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) { if (const IntImmNode* op = expr.as<IntImmNode>()) {
*ptr = static_cast<double>(op->value); *ptr = static_cast<double>(op->value);
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) { } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
*ptr = static_cast<double>(op->value); *ptr = static_cast<double>(op->value);
} else { } else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
...@@ -611,7 +600,7 @@ struct TypeName<uint64_t> { ...@@ -611,7 +600,7 @@ struct TypeName<uint64_t> {
template<> template<>
struct TypeName<DataType> { struct TypeName<DataType> {
static constexpr const char* value = "Type"; static constexpr const char* value = "DataType";
}; };
template<> template<>
...@@ -872,4 +861,4 @@ inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) ...@@ -872,4 +861,4 @@ inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
} }
} // namespace tvm } // namespace tvm
#endif // TVM_ATTRS_H_ #endif // TVM_IR_ATTRS_H_
...@@ -182,6 +182,56 @@ class IntImm : public PrimExpr { ...@@ -182,6 +182,56 @@ class IntImm : public PrimExpr {
}; };
/*! /*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
/*!
* \brief Managed reference class to FloatImmNode.
*
* \sa FloatImmNode
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;
};
/*!
* \brief Base node of all non-primitive expressions. * \brief Base node of all non-primitive expressions.
* *
* RelayExpr supports tensor types, functions and ADT as * RelayExpr supports tensor types, functions and ADT as
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#define TVM_IR_OP_H_ #define TVM_IR_OP_H_
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
...@@ -296,7 +296,8 @@ class OpRegistry { ...@@ -296,7 +296,8 @@ class OpRegistry {
// return internal pointer to op. // return internal pointer to op.
inline OpNode* get(); inline OpNode* get();
// update the attribute OpMap // update the attribute OpMap
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, TVM_DLL void UpdateAttr(const std::string& key,
runtime::TVMRetValue value,
int plevel); int plevel);
}; };
...@@ -316,7 +317,7 @@ class GenericOpMap { ...@@ -316,7 +317,7 @@ class GenericOpMap {
* \param op The key to the map * \param op The key to the map
* \return the const reference to the content value. * \return the const reference to the content value.
*/ */
inline const TVMRetValue& operator[](const Op& op) const; inline const runtime::TVMRetValue& operator[](const Op& op) const;
/*! /*!
* \brief get the corresponding value element at op with default value. * \brief get the corresponding value element at op with default value.
* \param op The key to the map * \param op The key to the map
...@@ -342,7 +343,7 @@ class GenericOpMap { ...@@ -342,7 +343,7 @@ class GenericOpMap {
// the attribute field. // the attribute field.
std::string attr_name_; std::string attr_name_;
// internal data // internal data
std::vector<std::pair<TVMRetValue, int> > data_; std::vector<std::pair<runtime::TVMRetValue, int> > data_;
// The value // The value
GenericOpMap() = default; GenericOpMap() = default;
}; };
...@@ -532,7 +533,7 @@ template <typename ValueType> ...@@ -532,7 +533,7 @@ template <typename ValueType>
inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value, int plevel) { const std::string& attr_name, const ValueType& value, int plevel) {
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
TVMRetValue rv; runtime::TVMRetValue rv;
rv = value; rv = value;
UpdateAttr(attr_name, rv, plevel); UpdateAttr(attr_name, rv, plevel);
return *this; return *this;
...@@ -548,7 +549,8 @@ inline int GenericOpMap::count(const Op& op) const { ...@@ -548,7 +549,8 @@ inline int GenericOpMap::count(const Op& op) const {
} }
} }
inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { inline const runtime::TVMRetValue&
GenericOpMap::operator[](const Op& op) const {
CHECK(op.defined()); CHECK(op.defined());
const uint32_t idx = op->index_; const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second != 0) CHECK(idx < data_.size() && data_[idx].second != 0)
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/ir/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
namespace tvm { namespace tvm {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ADT_H_ #ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_ #define TVM_RELAY_ADT_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h> #include <tvm/ir/adt.h>
#include <string> #include <string>
#include <functional> #include <functional>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_ #ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
#define TVM_RELAY_ATTRS_ALGORITHM_H_ #define TVM_RELAY_ATTRS_ALGORITHM_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/relay/base.h> #include <tvm/relay/base.h>
#include <string> #include <string>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_ #ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_ #define TVM_RELAY_ATTRS_ANNOTATION_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_ #ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
#define TVM_RELAY_ATTRS_BITSERIAL_H_ #define TVM_RELAY_ATTRS_BITSERIAL_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/relay/base.h> #include <tvm/relay/base.h>
#include <string> #include <string>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEBUG_H_ #ifndef TVM_RELAY_ATTRS_DEBUG_H_
#define TVM_RELAY_ATTRS_DEBUG_H_ #define TVM_RELAY_ATTRS_DEBUG_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_ #ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_ #define TVM_RELAY_ATTRS_DEVICE_COPY_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_IMAGE_H_ #ifndef TVM_RELAY_ATTRS_IMAGE_H_
#define TVM_RELAY_ATTRS_IMAGE_H_ #define TVM_RELAY_ATTRS_IMAGE_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/relay/base.h> #include <tvm/relay/base.h>
#include <string> #include <string>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_MEMORY_H_ #ifndef TVM_RELAY_ATTRS_MEMORY_H_
#define TVM_RELAY_ATTRS_MEMORY_H_ #define TVM_RELAY_ATTRS_MEMORY_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_NN_H_ #ifndef TVM_RELAY_ATTRS_NN_H_
#define TVM_RELAY_ATTRS_NN_H_ #define TVM_RELAY_ATTRS_NN_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/relay/base.h> #include <tvm/relay/base.h>
#include <string> #include <string>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_REDUCE_H_ #ifndef TVM_RELAY_ATTRS_REDUCE_H_
#define TVM_RELAY_ATTRS_REDUCE_H_ #define TVM_RELAY_ATTRS_REDUCE_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_ #ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_ #define TVM_RELAY_ATTRS_TRANSFORM_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/relay/base.h> #include <tvm/relay/base.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_VISION_H_ #ifndef TVM_RELAY_ATTRS_VISION_H_
#define TVM_RELAY_ATTRS_VISION_H_ #define TVM_RELAY_ATTRS_VISION_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/relay/base.h> #include <tvm/relay/base.h>
#include <string> #include <string>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_EXPR_H_ #ifndef TVM_RELAY_EXPR_H_
#define TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <string> #include <string>
#include <functional> #include <functional>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_QNN_ATTRS_H_ #ifndef TVM_RELAY_QNN_ATTRS_H_
#define TVM_RELAY_QNN_ATTRS_H_ #define TVM_RELAY_QNN_ATTRS_H_
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h> #include <tvm/ir/type_relation.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/ir/env_func.h> #include <tvm/ir/env_func.h>
...@@ -34,7 +35,7 @@ ...@@ -34,7 +35,7 @@
#include <string> #include <string>
#include "base.h" #include "base.h"
#include "../attrs.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") ...@@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
REGISTER_MAKE(Reduce); REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt); REGISTER_MAKE(AttrStmt);
REGISTER_MAKE(FloatImm);
REGISTER_MAKE(StringImm); REGISTER_MAKE(StringImm);
REGISTER_MAKE(Add); REGISTER_MAKE(Add);
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
......
...@@ -102,7 +102,7 @@ inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) { ...@@ -102,7 +102,7 @@ inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
if (pa && pb) return IntImm(rtype, pa->value + pb->value); if (pa && pb) return IntImm(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return b; if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a; if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value); if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
if (fa && fa->value == 0) return b; if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a; if (fb && fb->value == 0) return a;
}); });
...@@ -115,7 +115,7 @@ inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) { ...@@ -115,7 +115,7 @@ inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pa && pb) return IntImm(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return a; if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return a; if (fb && fb->value == 0) return a;
}); });
return PrimExpr(); return PrimExpr();
...@@ -134,7 +134,7 @@ inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) { ...@@ -134,7 +134,7 @@ inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
if (pb->value == 1) return a; if (pb->value == 1) return a;
if (pb->value == 0) return b; if (pb->value == 0) return b;
} }
if (fa && fb) return FloatImmNode::make(rtype, fa->value * fb->value); if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
if (fa) { if (fa) {
if (fa->value == 1) return b; if (fa->value == 1) return b;
if (fa->value == 0) return a; if (fa->value == 0) return a;
...@@ -165,7 +165,7 @@ inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) { ...@@ -165,7 +165,7 @@ inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) {
CHECK_NE(pb->value, 0) << "Divide by zero"; CHECK_NE(pb->value, 0) << "Divide by zero";
} }
if (fa && fb && fb->value != 0) { if (fa && fb && fb->value != 0) {
return FloatImmNode::make(rtype, fa->value / fb->value); return FloatImm(rtype, fa->value / fb->value);
} }
if (fa && fa->value == 0) return a; if (fa && fa->value == 0) return a;
if (fb) { if (fb) {
...@@ -210,7 +210,7 @@ inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) { ...@@ -210,7 +210,7 @@ inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) {
CHECK_NE(pb->value, 0) << "Divide by zero"; CHECK_NE(pb->value, 0) << "Divide by zero";
} }
if (fa && fb && fb->value != 0) { if (fa && fb && fb->value != 0) {
return FloatImmNode::make(rtype, std::floor(fa->value / fb->value)); return FloatImm(rtype, std::floor(fa->value / fb->value));
} }
if (fa && fa->value == 0) return a; if (fa && fa->value == 0) return a;
if (fb) { if (fb) {
...@@ -244,7 +244,7 @@ inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) { ...@@ -244,7 +244,7 @@ inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
}); });
if (a.same_as(b)) return a; if (a.same_as(b)) return a;
return PrimExpr(); return PrimExpr();
...@@ -255,7 +255,7 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) { ...@@ -255,7 +255,7 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
}); });
if (a.same_as(b)) return a; if (a.same_as(b)) return a;
return PrimExpr(); return PrimExpr();
......
...@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > > ...@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var}); feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
Array<PrimExpr> attr{std::string("_attr_"), Array<PrimExpr> attr{std::string("_attr_"),
FloatImmNode::make(DataType::Float(32), trans(fea.length)), FloatImm(DataType::Float(32), trans(fea.length)),
IntImm(DataType::Int(32), fea.nest_level), IntImm(DataType::Int(32), fea.nest_level),
FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)), FloatImm(DataType::Float(32), trans(fea.topdown_product)),
FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)), FloatImm(DataType::Float(32), trans(fea.bottomup_product)),
}; };
// one hot annotation // one hot annotation
for (int i = 0; i < kNum; i++) { for (int i = 0; i < kNum; i++) {
...@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > > ...@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
// arithmetic // arithmetic
feature_row.push_back(Array<PrimExpr>{std::string("_arith_"), feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)), FloatImm(DataType::Float(32), trans(fea.add_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)), FloatImm(DataType::Float(32), trans(fea.mul_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)), FloatImm(DataType::Float(32), trans(fea.div_ct)),
}); });
// touch map // touch map
...@@ -283,12 +283,12 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > > ...@@ -283,12 +283,12 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
TouchPattern &v = fea.touch_feature[k]; TouchPattern &v = fea.touch_feature[k];
feature_row.push_back( feature_row.push_back(
Array<PrimExpr>{k, Array<PrimExpr>{k,
FloatImmNode::make(DataType::Float(32), trans(v.stride)), FloatImm(DataType::Float(32), trans(v.stride)),
FloatImmNode::make(DataType::Float(32), trans(v.mod)), FloatImm(DataType::Float(32), trans(v.mod)),
FloatImmNode::make(DataType::Float(32), trans(v.count)), FloatImm(DataType::Float(32), trans(v.count)),
FloatImmNode::make(DataType::Float(32), trans(v.reuse)), FloatImm(DataType::Float(32), trans(v.reuse)),
FloatImmNode::make(DataType::Float(32), trans(v.thread_count)), FloatImm(DataType::Float(32), trans(v.thread_count)),
FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)), FloatImm(DataType::Float(32), trans(v.thread_reuse)),
}); });
} }
......
...@@ -95,7 +95,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ...@@ -95,7 +95,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
ir::CallNode::PureIntrinsic)), ir::CallNode::PureIntrinsic)),
MakeValue( MakeValue(
ir::BroadcastNode::make( ir::BroadcastNode::make(
ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())), FloatImm(DataType::Float(32), 0), from.lanes())),
/*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
/*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
}); });
......
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
* - array of attributes * - array of attributes
* - map of attributes * - map of attributes
*/ */
#ifndef TVM_LANG_ATTR_FUNCTOR_H_ #ifndef TVM_IR_ATTR_FUNCTOR_H_
#define TVM_LANG_ATTR_FUNCTOR_H_ #define TVM_IR_ATTR_FUNCTOR_H_
#include <tvm/node/functor.h> #include <tvm/node/functor.h>
#include <utility> #include <utility>
...@@ -230,4 +230,4 @@ class AttrsHashHandler : ...@@ -230,4 +230,4 @@ class AttrsHashHandler :
} }
}; };
} // namespace tvm } // namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_ #endif // TVM_IR_ATTR_FUNCTOR_H_
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
/*! /*!
* \file attrs.cc * \file attrs.cc
*/ */
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
......
...@@ -45,6 +45,22 @@ TVM_REGISTER_GLOBAL("make.IntImm") ...@@ -45,6 +45,22 @@ TVM_REGISTER_GLOBAL("make.IntImm")
return IntImm(dtype, value); return IntImm(dtype, value);
}); });
FloatImm::FloatImm(DataType dtype, double value) {
CHECK_EQ(dtype.lanes(), 1)
<< "ValueError: FloatImm can only take scalar.";
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = dtype;
node->value = value;
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("make.FloatImm")
.set_body_typed([](DataType dtype, double value) {
return FloatImm(dtype, value);
});
GlobalVar::GlobalVar(std::string name_hint) { GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>(); ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint); n->name_hint = std::move(name_hint);
......
...@@ -36,6 +36,10 @@ DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); ...@@ -36,6 +36,10 @@ DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
namespace tvm { namespace tvm {
using runtime::TVMRetValue;
using runtime::TVMArgs;
using runtime::PackedFunc;
::dmlc::Registry<OpRegistry>* OpRegistry::Registry() { ::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
return ::dmlc::Registry<OpRegistry>::Get(); return ::dmlc::Registry<OpRegistry>::Get();
} }
......
...@@ -33,7 +33,7 @@ PrimExpr::PrimExpr(int32_t value) ...@@ -33,7 +33,7 @@ PrimExpr::PrimExpr(int32_t value)
: PrimExpr(IntImm(DataType::Int(32), value)) {} : PrimExpr(IntImm(DataType::Int(32), value)) {}
PrimExpr::PrimExpr(float value) PrimExpr::PrimExpr(float value)
: PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {} : PrimExpr(FloatImm(DataType::Float(32), value)) {}
PrimExpr::PrimExpr(std::string str) PrimExpr::PrimExpr(std::string str)
: PrimExpr(ir::StringImmNode::make(str)) {} : PrimExpr(ir::StringImmNode::make(str)) {}
......
...@@ -108,11 +108,11 @@ PrimExpr max_value(const DataType& dtype) { ...@@ -108,11 +108,11 @@ PrimExpr max_value(const DataType& dtype) {
} }
} else if (dtype.is_float()) { } else if (dtype.is_float()) {
if (dtype.bits() == 64) { if (dtype.bits() == 64) {
return FloatImmNode::make(dtype, std::numeric_limits<double>::max()); return FloatImm(dtype, std::numeric_limits<double>::max());
} else if (dtype.bits() == 32) { } else if (dtype.bits() == 32) {
return FloatImmNode::make(dtype, std::numeric_limits<float>::max()); return FloatImm(dtype, std::numeric_limits<float>::max());
} else if (dtype.bits() == 16) { } else if (dtype.bits() == 16) {
return FloatImmNode::make(dtype, 65504.0); return FloatImm(dtype, 65504.0);
} }
} }
LOG(FATAL) << "Cannot decide max_value for type" << dtype; LOG(FATAL) << "Cannot decide max_value for type" << dtype;
...@@ -134,11 +134,11 @@ PrimExpr min_value(const DataType& dtype) { ...@@ -134,11 +134,11 @@ PrimExpr min_value(const DataType& dtype) {
return IntImm(dtype, 0); return IntImm(dtype, 0);
} else if (dtype.is_float()) { } else if (dtype.is_float()) {
if (dtype.bits() == 64) { if (dtype.bits() == 64) {
return FloatImmNode::make(dtype, std::numeric_limits<double>::lowest()); return FloatImm(dtype, std::numeric_limits<double>::lowest());
} else if (dtype.bits() == 32) { } else if (dtype.bits() == 32) {
return FloatImmNode::make(dtype, std::numeric_limits<float>::lowest()); return FloatImm(dtype, std::numeric_limits<float>::lowest());
} else if (dtype.bits() == 16) { } else if (dtype.bits() == 16) {
return FloatImmNode::make(dtype, -65504.0); return FloatImm(dtype, -65504.0);
} }
} }
LOG(FATAL) << "Cannot decide min_value for type" << dtype; LOG(FATAL) << "Cannot decide min_value for type" << dtype;
...@@ -219,7 +219,7 @@ PrimExpr operator-(PrimExpr a) { ...@@ -219,7 +219,7 @@ PrimExpr operator-(PrimExpr a) {
const IntImmNode* pa = a.as<IntImmNode>(); const IntImmNode* pa = a.as<IntImmNode>();
const FloatImmNode* fa = a.as<FloatImmNode>(); const FloatImmNode* fa = a.as<FloatImmNode>();
if (pa) return IntImm(a.dtype(), -pa->value); if (pa) return IntImm(a.dtype(), -pa->value);
if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value); if (fa) return FloatImm(a.dtype(), -fa->value);
return make_zero(a.dtype()) - a; return make_zero(a.dtype()) - a;
} }
...@@ -492,7 +492,7 @@ PrimExpr abs(PrimExpr x) { ...@@ -492,7 +492,7 @@ PrimExpr abs(PrimExpr x) {
using ir::FloatImmNode; using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) { if (fx) {
return ir::FloatImmNode::make(x.dtype(), std::fabs(fx->value)); return FloatImm(x.dtype(), std::fabs(fx->value));
} }
return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic); return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic);
} else if (x.dtype().is_uint()) { } else if (x.dtype().is_uint()) {
...@@ -593,28 +593,28 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) { ...@@ -593,28 +593,28 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
PrimExpr floor(PrimExpr x) { PrimExpr floor(PrimExpr x) {
using ir::FloatImmNode; using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value)); if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic); return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic);
} }
PrimExpr ceil(PrimExpr x) { PrimExpr ceil(PrimExpr x) {
using ir::FloatImmNode; using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value)); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic); return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic);
} }
PrimExpr round(PrimExpr x) { PrimExpr round(PrimExpr x) {
using ir::FloatImmNode; using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic); return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic);
} }
PrimExpr nearbyint(PrimExpr x) { PrimExpr nearbyint(PrimExpr x) {
using ir::FloatImmNode; using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic); return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic);
} }
...@@ -622,7 +622,7 @@ PrimExpr trunc(PrimExpr x) { ...@@ -622,7 +622,7 @@ PrimExpr trunc(PrimExpr x) {
using ir::FloatImmNode; using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>(); const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) { if (fx) {
return FloatImmNode::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
std::floor(fx->value))); std::floor(fx->value)));
} }
return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic); return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic);
......
...@@ -32,7 +32,7 @@ namespace ir { ...@@ -32,7 +32,7 @@ namespace ir {
// constructors // constructors
PrimExpr FloatImmNode::make(DataType t, double value) { PrimExpr FloatImm(DataType t, double value) {
CHECK_EQ(t.lanes(), 1) CHECK_EQ(t.lanes(), 1)
<< "ValueError: FloatImm can only take scalar"; << "ValueError: FloatImm can only take scalar";
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>(); ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
......
...@@ -25,10 +25,14 @@ ...@@ -25,10 +25,14 @@
#include <tvm/node/node.h> #include <tvm/node/node.h>
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/node/reflection.h> #include <tvm/node/reflection.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
namespace tvm { namespace tvm {
using runtime::TVMRetValue;
using runtime::TVMArgs;
using runtime::PackedFunc;
// Attr getter. // Attr getter.
class AttrGetter : public AttrVisitor { class AttrGetter : public AttrVisitor {
public: public:
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/node/reflection.h> #include <tvm/node/reflection.h>
#include <tvm/node/serialization.h> #include <tvm/node/serialization.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <string> #include <string>
#include <map> #include <map>
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_PASS_STORAGE_ACCESS_H_ #define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <vector> #include <vector>
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include "type_functor.h" #include "type_functor.h"
#include "../../lang/attr_functor.h" #include "../../ir/attr_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -26,9 +26,9 @@ ...@@ -26,9 +26,9 @@
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include "type_functor.h" #include "type_functor.h"
#include "../../lang/attr_functor.h" #include "../../ir/attr_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
#include "doc.h" #include "doc.h"
#include "type_functor.h" #include "type_functor.h"
#include "../pass/dependency_graph.h" #include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h" #include "../../ir/attr_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/expr_operator.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h> #include <tvm/ir.h>
namespace tvm { namespace tvm {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment