Unverified Commit fc75de9d by Tianqi Chen Committed by GitHub

[RUNTIME][IR] Allow non-nullable ObjectRef, introduce Optional<T>. (#5314)

* [RUNTIME] Allow non-nullable ObjectRef, introduce Optional<T>.

We use ObjectRef and their sub-classes extensively throughout our codebase.
Each of ObjectRef's sub-classes are nullable, which means they can hold nullptr
as their values.

While in some places we need nullptr as an alternative value. The implicit support
for nullptr in all ObjectRef creates additional burdens for the developer
to explicitly check defined in many places of the codebase.

Moreover, it is unclear from the API's intentional point of view whether
we want a nullable object or not-null version(many cases we want the later).

Borrowing existing wisdoms from languages like Rust. We propose to
introduce non-nullable ObjectRef, and Optional<T> container that
represents a nullable variant.

To keep backward compatiblity, we will start by allowing most ObjectRef to be nullable.
However, we should start to use Optional<T> as the type in places where
we know nullable is a requirement. Gradually, we will move most of the ObjectRef
to be non-nullable and use Optional<T> in the nullable cases.

Such explicitness in typing can help reduce the potential problems
in our codebase overall.

Changes in this PR:
- Introduce _type_is_nullable attribute to ObjectRef
- Introduce Optional<T>
- Change String to be non-nullable.
- Change the API of function->GetAttr to return Optional<T>

* Address review comments

* Upgrade all compiler flags to c++14

* Update as per review comment
parent 3df8d560
...@@ -31,7 +31,7 @@ include $(config) ...@@ -31,7 +31,7 @@ include $(config)
APP_ABI ?= all APP_ABI ?= all
APP_STL := c++_shared APP_STL := c++_shared
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti
ifeq ($(USE_OPENCL), 1) ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif endif
......
...@@ -27,7 +27,7 @@ include $(config) ...@@ -27,7 +27,7 @@ include $(config)
APP_STL := c++_static APP_STL := c++_static
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti
ifeq ($(USE_OPENCL), 1) ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif endif
...@@ -31,7 +31,7 @@ include $(config) ...@@ -31,7 +31,7 @@ include $(config)
APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips
APP_STL := c++_shared APP_STL := c++_shared
APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti
ifeq ($(USE_OPENCL), 1) ifeq ($(USE_OPENCL), 1)
APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1
endif endif
......
...@@ -28,7 +28,7 @@ else ...@@ -28,7 +28,7 @@ else
LINK_PTHREAD= LINK_PTHREAD=
endif endif
PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\ PKG_CFLAGS = -std=c++14 -O2 -fPIC -Wall\
-I${TVM_ROOT}/include\ -I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\ -I${DMLC_CORE}/include\
-I${TVM_ROOT}/3rdparty/dlpack/include -I${TVM_ROOT}/3rdparty/dlpack/include
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
TVM_ROOT=$(shell cd ../..; pwd) TVM_ROOT=$(shell cd ../..; pwd)
PKG_CFLAGS = -std=c++11 -O2 -fPIC\ PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\ -I${TVM_ROOT}/include\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include -I${TVM_ROOT}/3rdparty/dlpack/include
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# Minimum Makefile for the extension package # Minimum Makefile for the extension package
TVM_ROOT=$(shell cd ../..; pwd) TVM_ROOT=$(shell cd ../..; pwd)
PKG_CFLAGS = -std=c++11 -O2 -fPIC\ PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\ -I${TVM_ROOT}/include\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include -I${TVM_ROOT}/3rdparty/dlpack/include
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
TVM_ROOT=$(shell cd ../..; pwd) TVM_ROOT=$(shell cd ../..; pwd)
DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core
PKG_CFLAGS = -std=c++11 -O2 -fPIC\ PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\ -I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\ -I${DMLC_CORE}/include\
-I${TVM_ROOT}/3rdparty/dlpack/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
* include in your project. * include in your project.
* *
* - Copy this file into your project which depends on tvm runtime. * - Copy this file into your project which depends on tvm runtime.
* - Compile with -std=c++11 * - Compile with -std=c++14
* - Add the following include path * - Add the following include path
* - /path/to/tvm/include/ * - /path/to/tvm/include/
* - /path/to/tvm/3rdparty/dmlc-core/include/ * - /path/to/tvm/3rdparty/dmlc-core/include/
......
...@@ -21,7 +21,7 @@ ROCM_PATH=/opt/rocm ...@@ -21,7 +21,7 @@ ROCM_PATH=/opt/rocm
TVM_ROOT=$(shell cd ../..; pwd) TVM_ROOT=$(shell cd ../..; pwd)
DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core
PKG_CFLAGS = -std=c++11 -O2 -fPIC\ PKG_CFLAGS = -std=c++14 -O2 -fPIC\
-I${TVM_ROOT}/include\ -I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\ -I${DMLC_CORE}/include\
-I${TVM_ROOT}/3rdparty/dlpack/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
cmake_minimum_required(VERSION 3.2) cmake_minimum_required(VERSION 3.2)
project(tf_tvmdsoop C CXX) project(tf_tvmdsoop C CXX)
set(TFTVM_COMPILE_FLAGS -std=c++11) set(TFTVM_COMPILE_FLAGS -std=c++14)
set(BUILD_TVMDSOOP_ONLY ON) set(BUILD_TVMDSOOP_ONLY ON)
set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT}) set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build) set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)
......
...@@ -25,7 +25,7 @@ NATIVE_SRC = tvm_runtime_pack.cc ...@@ -25,7 +25,7 @@ NATIVE_SRC = tvm_runtime_pack.cc
GOPATH=$(CURDIR)/gopath GOPATH=$(CURDIR)/gopath
GOPATHDIR=${GOPATH}/src/${TARGET}/ GOPATHDIR=${GOPATH}/src/${TARGET}/
CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/"
CGO_CXXFLAGS="-std=c++11" CGO_CXXFLAGS="-std=c++14"
CGO_CFLAGS="-I${TVM_BASE}" CGO_CFLAGS="-I${TVM_BASE}"
CGO_LDFLAGS="-ldl -lm" CGO_LDFLAGS="-ldl -lm"
......
...@@ -85,6 +85,8 @@ namespace tvm { ...@@ -85,6 +85,8 @@ namespace tvm {
*/ */
template<typename TObjectRef> template<typename TObjectRef>
inline TObjectRef NullValue() { inline TObjectRef NullValue() {
static_assert(TObjectRef::_type_is_nullable,
"Can only get NullValue for nullable types");
return TObjectRef(ObjectPtr<Object>(nullptr)); return TObjectRef(ObjectPtr<Object>(nullptr));
} }
......
...@@ -312,6 +312,47 @@ class FloatImm : public PrimExpr { ...@@ -312,6 +312,47 @@ class FloatImm : public PrimExpr {
}; };
/*! /*!
* \brief Boolean constant.
*
* This reference type is useful to add additional compile-time
* type checks and helper functions for Integer equal comparisons.
*/
class Bool : public IntImm {
public:
explicit Bool(bool value)
: IntImm(DataType::Bool(), value) {
}
Bool operator!() const {
return Bool((*this)->value == 0);
}
operator bool() const {
return (*this)->value != 0;
}
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode);
};
// Overload operators to make sure we have the most fine grained types.
inline Bool operator||(const Bool& a, bool b) {
return Bool(a.operator bool() || b);
}
inline Bool operator||(bool a, const Bool& b) {
return Bool(a || b.operator bool());
}
inline Bool operator||(const Bool& a, const Bool& b) {
return Bool(a.operator bool() || b.operator bool());
}
inline Bool operator&&(const Bool& a, bool b) {
return Bool(a.operator bool() && b);
}
inline Bool operator&&(bool a, const Bool& b) {
return Bool(a && b.operator bool());
}
inline Bool operator&&(const Bool& a, const Bool& b) {
return Bool(a.operator bool() && b.operator bool());
}
/*!
* \brief Container of constant int that adds more constructors. * \brief Container of constant int that adds more constructors.
* *
* This is used to store and automate type check * This is used to store and automate type check
...@@ -340,10 +381,10 @@ class Integer : public IntImm { ...@@ -340,10 +381,10 @@ class Integer : public IntImm {
* \tparam Enum The enum type. * \tparam Enum The enum type.
* \param value The enum value. * \param value The enum value.
*/ */
template<typename ENum, template<typename Enum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type> typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
explicit Integer(ENum value) : Integer(static_cast<int>(value)) { explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value, static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
"declare enum to be enum int to use visitor"); "declare enum to be enum int to use visitor");
} }
/*! /*!
...@@ -362,6 +403,24 @@ class Integer : public IntImm { ...@@ -362,6 +403,24 @@ class Integer : public IntImm {
<< " Trying to reference a null Integer"; << " Trying to reference a null Integer";
return (*this)->value; return (*this)->value;
} }
// comparators
Bool operator==(int other) const {
if (data_ == nullptr) return Bool(false);
return Bool((*this)->value == other);
}
Bool operator!=(int other) const {
return !(*this == other);
}
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator==(Enum other) const {
return *this == static_cast<int>(other);
}
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator!=(Enum other) const {
return *this != static_cast<int>(other);
}
}; };
/*! \brief range over one dimension */ /*! \brief range over one dimension */
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/container.h>
#include <type_traits> #include <type_traits>
#include <string> #include <string>
...@@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode { ...@@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode {
* \code * \code
* *
* void GetAttrExample(const BaseFunc& f) { * void GetAttrExample(const BaseFunc& f) {
* Integer value = f->GetAttr<Integer>("AttrKey", 0); * auto value = f->GetAttr<Integer>("AttrKey", 0);
* } * }
* *
* \endcode * \endcode
*/ */
template<typename TObjectRef> template<typename TObjectRef>
TObjectRef GetAttr(const std::string& attr_key, Optional<TObjectRef> GetAttr(
TObjectRef default_value = NullValue<TObjectRef>()) const { const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value, static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types."); "Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value; if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key); auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) { if (it != attrs->dict.end()) {
return Downcast<TObjectRef>((*it).second); return Downcast<Optional<TObjectRef>>((*it).second);
} else { } else {
return default_value; return default_value;
} }
} }
// variant that uses TObjectRef to enable implicit conversion to default value.
template<typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
/*! /*!
* \brief Check whether the function has an non-zero integer attr. * \brief Check whether the function has an non-zero integer attr.
* *
...@@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode { ...@@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode {
* \endcode * \endcode
*/ */
bool HasNonzeroAttr(const std::string& attr_key) const { bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0)->value != 0; return GetAttr<Integer>(attr_key, 0) != 0;
} }
static constexpr const char* _type_key = "BaseFunc"; static constexpr const char* _type_key = "BaseFunc";
......
...@@ -63,7 +63,6 @@ using runtime::make_object; ...@@ -63,7 +63,6 @@ using runtime::make_object;
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::TVMArgs; using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
using runtime::String;
} // namespace tvm } // namespace tvm
#endif // TVM_NODE_NODE_H_ #endif // TVM_NODE_NODE_H_
...@@ -354,6 +354,10 @@ class StringObj : public Object { ...@@ -354,6 +354,10 @@ class StringObj : public Object {
class String : public ObjectRef { class String : public ObjectRef {
public: public:
/*! /*!
* \brief Construct an empty string.
*/
String() : String(std::string()) {}
/*!
* \brief Construct a new String object * \brief Construct a new String object
* *
* \param other The moved/copied std::string object * \param other The moved/copied std::string object
...@@ -467,9 +471,6 @@ class String : public ObjectRef { ...@@ -467,9 +471,6 @@ class String : public ObjectRef {
*/ */
size_t size() const { size_t size() const {
const auto* ptr = get(); const auto* ptr = get();
if (ptr == nullptr) {
return 0;
}
return ptr->size; return ptr->size;
} }
...@@ -524,7 +525,7 @@ class String : public ObjectRef { ...@@ -524,7 +525,7 @@ class String : public ObjectRef {
/*! \return the internal StringObj pointer */ /*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); } const StringObj* get() const { return operator->(); }
TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
private: private:
/*! /*!
...@@ -610,7 +611,146 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { ...@@ -610,7 +611,146 @@ struct PackedFuncValueConverter<::tvm::runtime::String> {
} }
}; };
/*!
* \brief Optional container that to represent to a Nullable variant of T.
* \tparam T The original ObjectRef.
*
* \code
*
* Optional<String> opt0 = nullptr;
* Optional<String> opt1 = String("xyz");
* CHECK(opt0 == nullptr);
* CHECK(opt1 == "xyz");
*
* \endcode
*/
template<typename T>
class Optional : public ObjectRef {
public:
using ContainerType = typename T::ContainerType;
static_assert(std::is_base_of<ObjectRef, T>::value,
"Optional is only defined for ObjectRef.");
// default constructors.
Optional() = default;
Optional(const Optional<T>&) = default;
Optional(Optional<T>&&) = default;
Optional<T>& operator=(const Optional<T>&) = default;
Optional<T>& operator=(Optional<T>&&) = default;
/*!
* \brief Construct from an ObjectPtr
* whose type already matches the ContainerType.
* \param ptr
*/
explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
// nullptr handling.
// disallow implicit conversion as 0 can be implicitly converted to nullptr_t
explicit Optional(std::nullptr_t) {}
Optional<T>& operator=(std::nullptr_t) {
data_ = nullptr;
return *this;
}
// normal value handling.
Optional(T other) // NOLINT(*)
: ObjectRef(std::move(other)) {
}
Optional<T>& operator=(T other) {
ObjectRef::operator=(std::move(other));
return *this;
}
// delete the int constructor
// since Optional<Integer>(0) is ambiguious
// 0 can be implicitly casted to nullptr_t
explicit Optional(int val) = delete;
Optional<T>& operator=(int val) = delete;
/*!
* \return A not-null container value in the optional.
* \note This function performs not-null checking.
*/
T value() const {
CHECK(data_ != nullptr);
return T(data_);
}
/*!
* \return The contained value if the Optional is not null
* otherwise return the default_value.
*/
T value_or(T default_value) const {
return data_ != nullptr ? T(data_) : default_value;
}
/*! \return Whether the container is not nullptr.*/
explicit operator bool() const {
return *this != nullptr;
}
// operator overloadings
bool operator==(std::nullptr_t) const {
return data_ == nullptr;
}
bool operator!=(std::nullptr_t) const {
return data_ != nullptr;
}
auto operator==(const Optional<T>& other) const {
// support case where sub-class returns a symbolic ref type.
using RetType = decltype(value() == other.value());
if (same_as(other)) return RetType(true);
if (*this != nullptr && other != nullptr) {
return value() == other.value();
} else {
// one of them is nullptr.
return RetType(false);
}
}
auto operator!=(const Optional<T>& other) const {
// support case where sub-class returns a symbolic ref type.
using RetType = decltype(value() != other.value());
if (same_as(other)) return RetType(false);
if (*this != nullptr && other != nullptr) {
return value() != other.value();
} else {
// one of them is nullptr.
return RetType(true);
}
}
auto operator==(const T& other) const {
using RetType = decltype(value() == other);
if (same_as(other)) return RetType(true);
if (*this != nullptr) return value() == other;
return RetType(false);
}
auto operator!=(const T& other) const {
return !(*this == other);
}
template<typename U>
auto operator==(const U& other) const {
using RetType = decltype(value() == other);
if (*this == nullptr) return RetType(false);
return value() == other;
}
template<typename U>
auto operator!=(const U& other) const {
using RetType = decltype(value() != other);
if (*this == nullptr) return RetType(true);
return value() != other;
}
static constexpr bool _type_is_nullable = true;
};
template<typename T>
struct PackedFuncValueConverter<Optional<T>> {
static Optional<T> From(const TVMArgValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
static Optional<T> From(const TVMRetValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
};
} // namespace runtime } // namespace runtime
// expose the functions to the root namespace.
using runtime::String;
using runtime::Optional;
} // namespace tvm } // namespace tvm
namespace std { namespace std {
......
...@@ -546,7 +546,9 @@ class ObjectRef { ...@@ -546,7 +546,9 @@ class ObjectRef {
bool operator<(const ObjectRef& other) const { bool operator<(const ObjectRef& other) const {
return data_.get() < other.data_.get(); return data_.get() < other.data_.get();
} }
/*! \return whether the expression is null */ /*!
* \return whether the object is defined(not null).
*/
bool defined() const { bool defined() const {
return data_ != nullptr; return data_ != nullptr;
} }
...@@ -582,6 +584,8 @@ class ObjectRef { ...@@ -582,6 +584,8 @@ class ObjectRef {
/*! \brief type indicate the container type. */ /*! \brief type indicate the container type. */
using ContainerType = Object; using ContainerType = Object;
// Default type properties for the reference class.
static constexpr bool _type_is_nullable = true;
protected: protected:
/*! \brief Internal pointer that backs the reference. */ /*! \brief Internal pointer that backs the reference. */
...@@ -720,6 +724,17 @@ struct ObjectEqual { ...@@ -720,6 +724,17 @@ struct ObjectEqual {
TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
TypeName::_GetOrAllocRuntimeTypeIndex() TypeName::_GetOrAllocRuntimeTypeIndex()
/*
* \brief Define the default copy/move constructor and assign opeator
* \param TypeName The class typename.
*/
#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
TypeName(const TypeName& other) = default; \
TypeName(TypeName&& other) = default; \
TypeName& operator=(const TypeName& other) = default; \
TypeName& operator=(TypeName&& other) = default; \
/* /*
* \brief Define object reference methods. * \brief Define object reference methods.
* \param TypeName The object type name * \param TypeName The object type name
...@@ -727,16 +742,35 @@ struct ObjectEqual { ...@@ -727,16 +742,35 @@ struct ObjectEqual {
* \param ObjectName The type name of the object. * \param ObjectName The type name of the object.
*/ */
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \ TypeName() = default; \
explicit TypeName( \ explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \ : ParentType(n) {} \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { \ const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \ return static_cast<const ObjectName*>(data_.get()); \
} \ } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
/* /*
* \brief Define object reference methods that is not nullable.
*
* \param TypeName The object type name
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
*/
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \
} \
static constexpr bool _type_is_nullable = false; \
using ContainerType = ObjectName;
/*
* \brief Define object reference methods of whose content is mutable. * \brief Define object reference methods of whose content is mutable.
* \param TypeName The object type name * \param TypeName The object type name
* \param ParentType The parent type of the objectref * \param ParentType The parent type of the objectref
...@@ -745,7 +779,8 @@ struct ObjectEqual { ...@@ -745,7 +779,8 @@ struct ObjectEqual {
* This macro is only reserved for objects that stores runtime states. * This macro is only reserved for objects that stores runtime states.
*/ */
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ #define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \ TypeName() = default; \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
explicit TypeName( \ explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \ : ParentType(n) {} \
...@@ -869,11 +904,14 @@ inline const ObjectType* ObjectRef::as() const { ...@@ -869,11 +904,14 @@ inline const ObjectType* ObjectRef::as() const {
} }
} }
template <typename RelayRefType, typename ObjType> template <typename RefType, typename ObjType>
inline RelayRefType GetRef(const ObjType* ptr) { inline RefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of<typename RelayRefType::ContainerType, ObjType>::value, static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
"Can only cast to the ref of same container type"); "Can only cast to the ref of same container type");
return RelayRefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr)))); if (!RefType::_type_is_nullable) {
CHECK(ptr != nullptr);
}
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
} }
template <typename BaseType, typename ObjType> template <typename BaseType, typename ObjType>
...@@ -885,9 +923,15 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) { ...@@ -885,9 +923,15 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
template <typename SubRef, typename BaseRef> template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) { inline SubRef Downcast(BaseRef ref) {
CHECK(!ref.defined() || ref->template IsInstance<typename SubRef::ContainerType>()) if (ref.defined()) {
<< "Downcast from " << ref->GetTypeKey() << " to " CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
<< SubRef::ContainerType::_type_key << " failed."; << "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
} else {
CHECK(SubRef::_type_is_nullable)
<< "Downcast from nullptr to not nullable reference of "
<< SubRef::ContainerType::_type_key;
}
return SubRef(std::move(ref.data_)); return SubRef(std::move(ref.data_));
} }
......
...@@ -352,7 +352,7 @@ template<typename T> ...@@ -352,7 +352,7 @@ template<typename T>
struct ObjectTypeChecker { struct ObjectTypeChecker {
static bool Check(const Object* ptr) { static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType; using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true; if (ptr == nullptr) return T::_type_is_nullable;
return ptr->IsInstance<ContainerType>(); return ptr->IsInstance<ContainerType>();
} }
static std::string TypeName() { static std::string TypeName() {
...@@ -1400,7 +1400,11 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ...@@ -1400,7 +1400,11 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
std::is_base_of<ObjectRef, TObjectRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef"); "Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType; using ContainerType = typename TObjectRef::ContainerType;
if (type_code_ == kTVMNullptr) return TObjectRef(ObjectPtr<Object>(nullptr)); if (type_code_ == kTVMNullptr) {
CHECK(TObjectRef::_type_is_nullable)
<< "Expect a not null value of " << ContainerType::_type_key;
return TObjectRef(ObjectPtr<Object>(nullptr));
}
// NOTE: the following code can be optimized by constant folding. // NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) { if (std::is_base_of<NDArray, TObjectRef>::value) {
// Casting to a sub-class of NDArray // Casting to a sub-class of NDArray
......
...@@ -96,7 +96,7 @@ def config_cython(): ...@@ -96,7 +96,7 @@ def config_cython():
"../3rdparty/dmlc-core/include", "../3rdparty/dmlc-core/include",
"../3rdparty/dlpack/include", "../3rdparty/dlpack/include",
], ],
extra_compile_args=["-std=c++11"], extra_compile_args=["-std=c++14"],
library_dirs=library_dirs, library_dirs=library_dirs,
libraries=libraries, libraries=libraries,
language="c++")) language="c++"))
......
...@@ -244,8 +244,9 @@ split_dev_host_funcs(IRModule mod_mixed, ...@@ -244,8 +244,9 @@ split_dev_host_funcs(IRModule mod_mixed,
auto host_pass_list = { auto host_pass_list = {
FilterBy([](const tir::PrimFunc& f) { FilterBy([](const tir::PrimFunc& f) {
int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value; return f->GetAttr<Integer>(
return value != static_cast<int>(CallingConv::kDeviceKernelLaunch); tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch;
}), }),
BindTarget(target_host), BindTarget(target_host),
tir::transform::LowerTVMBuiltin(), tir::transform::LowerTVMBuiltin(),
...@@ -259,8 +260,9 @@ split_dev_host_funcs(IRModule mod_mixed, ...@@ -259,8 +260,9 @@ split_dev_host_funcs(IRModule mod_mixed,
// device pipeline // device pipeline
auto device_pass_list = { auto device_pass_list = {
FilterBy([](const tir::PrimFunc& f) { FilterBy([](const tir::PrimFunc& f) {
int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value; return f->GetAttr<Integer>(
return value == static_cast<int>(CallingConv::kDeviceKernelLaunch); tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch;
}), }),
BindTarget(target), BindTarget(target),
tir::transform::LowerWarpMemory(), tir::transform::LowerWarpMemory(),
......
...@@ -620,14 +620,14 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -620,14 +620,14 @@ class CompileEngineImpl : public CompileEngineNode {
if (src_func->GetAttr<String>(attr::kCompiler).defined()) { if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<String>(attr::kCompiler); auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set"; CHECK(code_gen.defined()) << "No external codegen is set";
std::string code_gen_name = code_gen; std::string code_gen_name = code_gen.value();
if (ext_mods.find(code_gen_name) == ext_mods.end()) { if (ext_mods.find(code_gen_name) == ext_mods.end()) {
ext_mods[code_gen_name] = IRModule({}, {}); ext_mods[code_gen_name] = IRModule({}, {});
} }
auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol); auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n" CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false); << AsText(src_func, false);
auto gv = GlobalVar(std::string(symbol_name)); auto gv = GlobalVar(std::string(symbol_name.value()));
ext_mods[code_gen_name]->Add(gv, src_func); ext_mods[code_gen_name]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first); cached_ext_funcs.push_back(it.first);
} }
...@@ -698,7 +698,7 @@ class CompileEngineImpl : public CompileEngineNode { ...@@ -698,7 +698,7 @@ class CompileEngineImpl : public CompileEngineNode {
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol); key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) CHECK(name_node.defined())
<< "External function has not been attached a name yet."; << "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node); cache_node->func_name = std::string(name_node.value());
cache_node->target = tvm::target::ext_dev(); cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node); value->cached_func = CachedFunc(cache_node);
return value; return value;
......
...@@ -72,7 +72,7 @@ class CSourceModuleCodegenBase { ...@@ -72,7 +72,7 @@ class CSourceModuleCodegenBase {
const auto name_node = const auto name_node =
func->GetAttr<String>(tvm::attr::kGlobalSymbol); func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol."; CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node); return std::string(name_node.value());
} }
}; };
......
...@@ -446,7 +446,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -446,7 +446,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
const Expr& outputs) { const Expr& outputs) {
std::vector<Index> argument_registers; std::vector<Index> argument_registers;
CHECK_NE(func->GetAttr<Integer>(attr::kPrimitive, 0)->value, 0) CHECK(func->GetAttr<Integer>(attr::kPrimitive, 0) != 0)
<< "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
auto input_tuple = inputs.as<TupleNode>(); auto input_tuple = inputs.as<TupleNode>();
......
...@@ -45,7 +45,7 @@ inline std::string GenerateName(const Function& func) { ...@@ -45,7 +45,7 @@ inline std::string GenerateName(const Function& func) {
} }
bool IsClosure(const Function& func) { bool IsClosure(const Function& func) {
return func->GetAttr<Integer>(attr::kClosure, 0)->value != 0; return func->GetAttr<Integer>(attr::kClosure, 0) != 0;
} }
Function MarkClosure(Function func) { Function MarkClosure(Function func) {
......
...@@ -145,8 +145,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, ...@@ -145,8 +145,8 @@ IRModule FunctionPassNode::operator()(IRModule mod,
} }
bool FunctionPassNode::SkipFunction(const Function& func) const { bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 || return (func->GetAttr<String>(attr::kCompiler).defined()) ||
(func->GetAttr<String>(attr::kCompiler).defined()); func->GetAttr<Integer>(attr::kSkipOptimization, 0) != 0;
} }
Pass CreateFunctionPass( Pass CreateFunctionPass(
......
...@@ -158,9 +158,9 @@ class AnnotateTargetWrapper : public ExprMutator { ...@@ -158,9 +158,9 @@ class AnnotateTargetWrapper : public ExprMutator {
// if it is in the target list. // if it is in the target list.
Function func = Downcast<Function>(cn->op); Function func = Downcast<Function>(cn->op);
CHECK(func.defined()); CHECK(func.defined());
auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) { if (auto comp_name = func->GetAttr<String>(attr::kComposite)) {
std::string comp_name_str = comp_name; std::string comp_name_str = comp_name.value();
size_t i = comp_name_str.find('.'); size_t i = comp_name_str.find('.');
if (i != std::string::npos) { if (i != std::string::npos) {
std::string comp_target = comp_name_str.substr(0, i); std::string comp_target = comp_name_str.substr(0, i);
......
...@@ -51,14 +51,14 @@ ExtractFuncInfo(const IRModule& mod) { ...@@ -51,14 +51,14 @@ ExtractFuncInfo(const IRModule& mod) {
for (size_t i = 0; i < f->params.size(); ++i) { for (size_t i = 0; i < f->params.size(); ++i) {
info.arg_types.push_back(f->params[i].dtype()); info.arg_types.push_back(f->params[i].dtype());
} }
auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis); if (auto opt = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
if (thread_axis.defined()) { auto thread_axis = opt.value();
for (size_t i = 0; i < thread_axis.size(); ++i) { for (size_t i = 0; i < thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
} }
} }
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol)] = info; fmap[static_cast<std::string>(global_symbol.value())] = info;
} }
return fmap; return fmap;
} }
......
...@@ -130,7 +130,7 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) { ...@@ -130,7 +130,7 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) {
CHECK(global_symbol.defined()) CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back( export_system_symbols_.emplace_back(
std::make_pair(global_symbol.operator std::string(), std::make_pair(global_symbol.value().operator std::string(),
builder_->CreatePointerCast(function_, t_void_p_))); builder_->CreatePointerCast(function_, t_void_p_)));
} }
AddDebugInformation(function_); AddDebugInformation(function_);
......
...@@ -131,12 +131,12 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { ...@@ -131,12 +131,12 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined()) CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
CHECK(module_->getFunction(static_cast<std::string>(global_symbol)) == nullptr) CHECK(module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr)
<< "Function " << global_symbol << " already exist in module"; << "Function " << global_symbol << " already exist in module";
function_ = llvm::Function::Create( function_ = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage, ftype, llvm::Function::ExternalLinkage,
global_symbol.operator std::string(), module_.get()); global_symbol.value().operator std::string(), module_.get());
function_->setCallingConv(llvm::CallingConv::C); function_->setCallingConv(llvm::CallingConv::C);
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
......
...@@ -216,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -216,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined()); CHECK(global_symbol.defined());
entry_func = global_symbol; entry_func = global_symbol.value();
} }
funcs.push_back(f); funcs.push_back(f);
} }
......
...@@ -138,8 +138,7 @@ runtime::Module BuildCUDA(IRModule mod) { ...@@ -138,8 +138,7 @@ runtime::Module BuildCUDA(IRModule mod) {
<< "CodeGenCUDA: Can only take PrimFunc"; << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() && CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
......
...@@ -45,8 +45,7 @@ runtime::Module BuildAOCL(IRModule mod, ...@@ -45,8 +45,7 @@ runtime::Module BuildAOCL(IRModule mod,
<< "CodegenOpenCL: Can only take PrimFunc"; << "CodegenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() && CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
......
...@@ -84,7 +84,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { ...@@ -84,7 +84,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix(); this->PrintFuncPrefix();
this->stream << " " << static_cast<std::string>(global_symbol) << "("; this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
for (size_t i = 0; i < f->params.size(); ++i) { for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i]; tir::Var v = f->params[i];
......
...@@ -61,7 +61,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ...@@ -61,7 +61,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
// Function header. // Function header.
this->stream << "kernel void " << static_cast<std::string>(global_symbol) << "("; this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
// Buffer arguments // Buffer arguments
size_t num_buffer = 0; size_t num_buffer = 0;
...@@ -91,7 +91,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ...@@ -91,7 +91,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
size_t nargs = f->params.size() - num_buffer; size_t nargs = f->params.size() - num_buffer;
std::string varg = GetUniqueName("arg"); std::string varg = GetUniqueName("arg");
if (nargs != 0) { if (nargs != 0) {
std::string arg_buf_type = static_cast<std::string>(global_symbol) + "_args_t"; std::string arg_buf_type =
static_cast<std::string>(global_symbol.value()) + "_args_t";
stream << " constant " << arg_buf_type << "& " << varg stream << " constant " << arg_buf_type << "& " << varg
<< " [[ buffer(" << num_buffer << ") ]],\n"; << " [[ buffer(" << num_buffer << ") ]],\n";
// declare the struct // declare the struct
...@@ -120,8 +121,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ...@@ -120,8 +121,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
int work_dim = 0; int work_dim = 0;
auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis); auto thread_axis = f->GetAttr<Array<tir::IterVar>>(
CHECK(thread_axis.defined()); tir::attr::kDeviceThreadAxis).value();
for (IterVar iv : thread_axis) { for (IterVar iv : thread_axis) {
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
...@@ -278,8 +279,7 @@ runtime::Module BuildMetal(IRModule mod) { ...@@ -278,8 +279,7 @@ runtime::Module BuildMetal(IRModule mod) {
<< "CodeGenMetal: Can only take PrimFunc"; << "CodeGenMetal: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() && CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
......
...@@ -249,8 +249,7 @@ runtime::Module BuildOpenCL(IRModule mod) { ...@@ -249,8 +249,7 @@ runtime::Module BuildOpenCL(IRModule mod) {
<< "CodeGenOpenCL: Can only take PrimFunc"; << "CodeGenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() && CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
......
...@@ -160,7 +160,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { ...@@ -160,7 +160,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) {
CHECK(global_symbol.defined()) CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
shaders_[static_cast<std::string>(global_symbol)] = runtime::OpenGLShader( shaders_[static_cast<std::string>(global_symbol.value())] = runtime::OpenGLShader(
this->decl_stream.str() + this->stream.str(), this->decl_stream.str() + this->stream.str(),
std::move(arg_names), std::move(arg_kinds), std::move(arg_names), std::move(arg_kinds),
this->thread_extent_var_); this->thread_extent_var_);
...@@ -299,8 +299,7 @@ runtime::Module BuildOpenGL(IRModule mod) { ...@@ -299,8 +299,7 @@ runtime::Module BuildOpenGL(IRModule mod) {
<< "CodeGenOpenGL: Can only take PrimFunc"; << "CodeGenOpenGL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() && CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; << "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
......
...@@ -138,8 +138,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { ...@@ -138,8 +138,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
<< "CodeGenVHLS: Can only take PrimFunc"; << "CodeGenVHLS: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() && CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
...@@ -164,7 +163,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { ...@@ -164,7 +163,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined()) CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
kernel_info.push_back({global_symbol, code}); kernel_info.push_back({global_symbol.value(), code});
} }
std::string xclbin; std::string xclbin;
......
...@@ -87,14 +87,13 @@ runtime::Module BuildSPIRV(IRModule mod) { ...@@ -87,14 +87,13 @@ runtime::Module BuildSPIRV(IRModule mod) {
<< "CodeGenSPIRV: Can only take PrimFunc"; << "CodeGenSPIRV: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() && CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined()) CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol; std::string f_name = global_symbol.value();
f = PointerValueTypeRewrite(std::move(f)); f = PointerValueTypeRewrite(std::move(f));
VulkanShader shader; VulkanShader shader;
shader.data = cg.BuildFunction(f); shader.data = cg.BuildFunction(f);
......
...@@ -82,7 +82,8 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) { ...@@ -82,7 +82,8 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
CHECK(global_symbol.defined()) CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
builder_->CommitKernelFunction(func_ptr, static_cast<std::string>(global_symbol)); builder_->CommitKernelFunction(
func_ptr, static_cast<std::string>(global_symbol.value()));
return builder_->Finalize(); return builder_->Finalize();
} }
......
...@@ -539,7 +539,7 @@ runtime::Module BuildStackVM(const IRModule& mod) { ...@@ -539,7 +539,7 @@ runtime::Module BuildStackVM(const IRModule& mod) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined()) CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol; std::string f_name = global_symbol.value();
StackVM vm = codegen::CodeGenStackVM().Compile(f); StackVM vm = codegen::CodeGenStackVM().Compile(f);
CHECK(!fmap.count(f_name)) CHECK(!fmap.count(f_name))
<< "Function name " << f_name << "already exist in list"; << "Function name " << f_name << "already exist in list";
......
...@@ -195,7 +195,7 @@ void VerifyMemory(const IRModule& mod) { ...@@ -195,7 +195,7 @@ void VerifyMemory(const IRModule& mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget); auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute"; << "LowerWarpMemory: Require the target attribute";
MemoryAccessVerifier v(func, target->device_type); MemoryAccessVerifier v(func, target.value()->device_type);
v.Run(); v.Run();
if (v.Failed()) { if (v.Failed()) {
LOG(FATAL) LOG(FATAL)
......
...@@ -99,7 +99,7 @@ Pass BindDeviceType() { ...@@ -99,7 +99,7 @@ Pass BindDeviceType() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) CHECK(target.defined())
<< "BindDeviceType: Require the target attribute"; << "BindDeviceType: Require the target attribute";
n->body = DeviceTypeBinder(target->device_type)(std::move(n->body)); n->body = DeviceTypeBinder(target.value()->device_type)(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {}); return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {});
......
...@@ -141,7 +141,7 @@ Pass LowerCustomDatatypes() { ...@@ -141,7 +141,7 @@ Pass LowerCustomDatatypes() {
CHECK(target.defined()) CHECK(target.defined())
<< "LowerCustomDatatypes: Require the target attribute"; << "LowerCustomDatatypes: Require the target attribute";
n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body)); n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
......
...@@ -293,7 +293,7 @@ Pass LowerIntrin() { ...@@ -293,7 +293,7 @@ Pass LowerIntrin() {
<< "LowerIntrin: Require the target attribute"; << "LowerIntrin: Require the target attribute";
arith::Analyzer analyzer; arith::Analyzer analyzer;
n->body = n->body =
IntrinInjecter(&analyzer, target->target_name)(std::move(n->body)); IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
......
...@@ -348,7 +348,7 @@ Pass LowerThreadAllreduce() { ...@@ -348,7 +348,7 @@ Pass LowerThreadAllreduce() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) CHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute"; << "LowerThreadAllreduce: Require the target attribute";
n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body);
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
......
...@@ -393,7 +393,7 @@ Pass LowerWarpMemory() { ...@@ -393,7 +393,7 @@ Pass LowerWarpMemory() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute"; << "LowerWarpMemory: Require the target attribute";
n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body)); n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
......
...@@ -48,9 +48,9 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { ...@@ -48,9 +48,9 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
PrimFunc MakePackedAPI(PrimFunc&& func, PrimFunc MakePackedAPI(PrimFunc&& func,
int num_unpacked_args) { int num_unpacked_args) {
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined()) CHECK(global_symbol)
<< "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
std::string name_hint = global_symbol; std::string name_hint = global_symbol.value();
auto* func_ptr = func.CopyOnWrite(); auto* func_ptr = func.CopyOnWrite();
const Stmt nop = EvaluateNode::make(0); const Stmt nop = EvaluateNode::make(0);
...@@ -240,8 +240,9 @@ Pass MakePackedAPI(int num_unpacked_args) { ...@@ -240,8 +240,9 @@ Pass MakePackedAPI(int num_unpacked_args) {
for (const auto& kv : mptr->functions) { for (const auto& kv : mptr->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) { if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n); PrimFunc func = GetRef<PrimFunc>(n);
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value if (func->GetAttr<Integer>(
== static_cast<int>(CallingConv::kDefault)) { tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) == CallingConv::kDefault) {
auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args);
updates.push_back({kv.first, updated_func}); updates.push_back({kv.first, updated_func});
} }
......
...@@ -82,7 +82,10 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) ...@@ -82,7 +82,10 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map)
tmap[kv.first] = kv.second; tmap[kv.first] = kv.second;
} }
auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis); auto opt_thread_axis = f->GetAttr<Array<IterVar>>(tir::attr::kDeviceThreadAxis);
CHECK(opt_thread_axis != nullptr)
<< "Require attribute " << tir::attr::kDeviceThreadAxis;
auto thread_axis = opt_thread_axis.value();
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
// replace the thread axis // replace the thread axis
......
...@@ -277,7 +277,9 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { ...@@ -277,7 +277,9 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
<< "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
HostDeviceSplitter splitter( HostDeviceSplitter splitter(
device_mod, target, static_cast<std::string>(global_symbol)); device_mod,
target.value(),
static_cast<std::string>(global_symbol.value()));
auto* n = func.CopyOnWrite(); auto* n = func.CopyOnWrite();
n->body = splitter(std::move(n->body)); n->body = splitter(std::move(n->body));
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/function.h>
#include <new> #include <new>
#include <unordered_map> #include <unordered_map>
...@@ -401,6 +402,74 @@ TEST(String, Cast) { ...@@ -401,6 +402,74 @@ TEST(String, Cast) {
String s2 = Downcast<String>(r); String s2 = Downcast<String>(r);
} }
TEST(Optional, Composition) {
Optional<String> opt0(nullptr);
Optional<String> opt1 = String("xyz");
Optional<String> opt2 = String("xyz1");
// operator bool
CHECK(!opt0);
CHECK(opt1);
// comparison op
CHECK(opt0 != "xyz");
CHECK(opt1 == "xyz");
CHECK(opt1 != nullptr);
CHECK(opt0 == nullptr);
CHECK(opt0.value_or("abc") == "abc");
CHECK(opt1.value_or("abc") == "xyz");
CHECK(opt0 != opt1);
CHECK(opt1 == Optional<String>(String("xyz")));
CHECK(opt0 == Optional<String>(nullptr));
opt0 = opt1;
CHECK(opt0 == opt1);
CHECK(opt0.value().same_as(opt1.value()));
opt0 = std::move(opt2);
CHECK(opt0 != opt2);
}
TEST(Optional, IntCmp) {
Integer val(CallingConv::kDefault);
Optional<Integer> opt = Integer(0);
CHECK(0 == static_cast<int>(CallingConv::kDefault));
CHECK(val == CallingConv::kDefault);
CHECK(opt == CallingConv::kDefault);
// check we can handle implicit 0 to nullptr conversion.
Optional<Integer> opt1(nullptr);
CHECK(opt1 != 0);
CHECK(opt1 != false);
CHECK(!(opt1 == 0));
}
TEST(Optional, PackedCall) {
auto tf = [](Optional<String> s, bool isnull) {
if (isnull) {
CHECK(s == nullptr);
} else {
CHECK(s != nullptr);
}
return s;
};
auto func = TypedPackedFunc<Optional<String>(Optional<String>, bool)>(tf);
CHECK(func(String("xyz"), false) == "xyz");
CHECK(func(Optional<String>(nullptr), true) == nullptr);
auto pf = [](TVMArgs args, TVMRetValue* rv) {
Optional<String> s = args[0];
bool isnull = args[1];
if (isnull) {
CHECK(s == nullptr);
} else {
CHECK(s != nullptr);
}
*rv = s;
};
auto packedfunc = PackedFunc(pf);
CHECK(packedfunc("xyz", false).operator String() == "xyz");
CHECK(packedfunc("xyz", false).operator Optional<String>() == "xyz");
CHECK(packedfunc(nullptr, true).operator Optional<String>() == nullptr);
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -39,7 +39,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -39,7 +39,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
kwargs = {} kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
tmp_path = util.tempdir() tmp_path = util.tempdir()
lib_name = 'lib.so' lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name) lib_path = tmp_path.relpath(lib_name)
......
...@@ -468,13 +468,13 @@ def run_extern(label, get_extern_src, **kwargs): ...@@ -468,13 +468,13 @@ def run_extern(label, get_extern_src, **kwargs):
def test_dso_extern(): def test_dso_extern():
run_extern("lib", generate_csource_module, options=["-O2", "-std=c++11"]) run_extern("lib", generate_csource_module, options=["-O2", "-std=c++14"])
def test_engine_extern(): def test_engine_extern():
run_extern("engine", run_extern("engine",
generate_engine_module, generate_engine_module,
options=["-O2", "-std=c++11", "-I" + tmp_path.relpath("")]) options=["-O2", "-std=c++14", "-I" + tmp_path.relpath("")])
def test_json_extern(): def test_json_extern():
if not tvm.get_global_func("module.loadfile_examplejson", True): if not tvm.get_global_func("module.loadfile_examplejson", True):
......
...@@ -42,7 +42,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -42,7 +42,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
kwargs = {} kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
tmp_path = util.tempdir() tmp_path = util.tempdir()
lib_name = 'lib.so' lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name) lib_path = tmp_path.relpath(lib_name)
......
...@@ -182,7 +182,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -182,7 +182,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
kwargs = {} kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
tmp_path = util.tempdir() tmp_path = util.tempdir()
lib_name = 'lib.so' lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name) lib_path = tmp_path.relpath(lib_name)
......
...@@ -191,7 +191,7 @@ def test_mod_export(): ...@@ -191,7 +191,7 @@ def test_mod_export():
path_lib = temp.relpath(file_name) path_lib = temp.relpath(file_name)
resnet18_cpu_lib.import_module(f) resnet18_cpu_lib.import_module(f)
resnet18_cpu_lib.import_module(engine_module) resnet18_cpu_lib.import_module(engine_module)
kwargs = {"options": ["-O2", "-std=c++11", "-I" + header_file_dir_path.relpath("")]} kwargs = {"options": ["-O2", "-std=c++14", "-I" + header_file_dir_path.relpath("")]}
resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs) resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs)
loaded_lib = tvm.runtime.load_module(path_lib) loaded_lib = tvm.runtime.load_module(path_lib)
assert loaded_lib.type_key == "library" assert loaded_lib.type_key == "library"
......
...@@ -107,7 +107,7 @@ def server_start(): ...@@ -107,7 +107,7 @@ def server_start():
if pkg.same_config(old_cfg): if pkg.same_config(old_cfg):
logging.info("Skip reconfig_runtime due to same config.") logging.info("Skip reconfig_runtime due to same config.")
return return
cflags = ["-O2", "-std=c++11"] cflags = ["-O2", "-std=c++14"]
cflags += pkg.cflags cflags += pkg.cflags
ldflags = pkg.ldflags ldflags = pkg.ldflags
lib_name = dll_path lib_name = dll_path
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment