Unverified Commit d2a79a5f by Wei Chen Committed by GitHub

[Object] Add String container (#4628)

parent d04525e8
......@@ -28,7 +28,37 @@
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <cstring>
#include <initializer_list>
#include <string>
// We use c++14 std::experimental::string_view for optimizing hash computation
// only right now, its usage is limited in this file. Any broader usage of
// std::experiment in our core codebase is discouraged and needs community
// discussion for each use case. Reference for feature test macros of
// string_view:
// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations
// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros
#if defined(__cpp_lib_experimental_string_view) && \
__cpp_lib_experimental_string_view >= 201411
#define TVM_USE_CXX14_STRING_VIEW_HASH 1
#else
#define TVM_USE_CXX14_STRING_VIEW_HASH 0
#endif
// Tested with clang version 9.0.1 and c++17. It will detect string_view support
// correctly.
#if defined(__cpp_lib_string_view) && __cpp_lib_string_view >= 201606
#define TVM_USE_CXX17_STRING_VIEW_HASH 1
#else
#define TVM_USE_CXX17_STRING_VIEW_HASH 0
#endif
#if TVM_USE_CXX17_STRING_VIEW_HASH
#include <string_view>
#elif TVM_USE_CXX14_STRING_VIEW_HASH
#include <experimental/string_view>
#endif
#include <type_traits>
#include <utility>
#include <vector>
......@@ -274,7 +304,285 @@ class ADT : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
};
/*! \brief An object representing string. It's POD type. */
class StringObj : public Object {
public:
/*! \brief The pointer to string data. */
const char* data;
/*! \brief The length of the string object. */
uint64_t size;
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "runtime.String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);
private:
/*! \brief String object which is moved from std::string container. */
class FromStd;
friend class String;
};
/*!
* \brief Reference to string objects.
*
* \code
*
* // Example to create runtime String reference object from std::string
* std::string s = "hello world";
*
* // You can create the reference from existing std::string
* String ref{std::move(s)};
*
* // You can rebind the reference to another string.
* ref = std::string{"hello world2"};
*
* // You can use the reference as hash map key
* std::unordered_map<String, int32_t> m;
* m[ref] = 1;
*
* // You can compare the reference object with other string objects
* assert(ref == "hello world", true);
*
* // You can convert the reference to std::string again
* string s2 = (string)ref;
*
* \endcode
*/
class String : public ObjectRef {
public:
/*!
* \brief Construct a new String object
*
* \param other The moved/copied std::string object
*
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
explicit String(std::string other);
/*!
* \brief Change the value the reference object points to.
*
* \param other The value for the new String
*
*/
inline String operator=(std::string other);
/*!
* \brief Compare is equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator==(const std::string& other) const {
return this->compare(other) == 0;
}
/*!
* \brief Compare is not equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator!=(const std::string& other) const { return !operator==(other); }
/*!
* \brief Compare is equal to other char string
*
* \param other The other char string
*
* \return the comparison result
*/
bool operator==(const char* other) const { return compare(other) == 0; }
/*!
* \brief Compare is not equal to other char string
*
* \param other The other char string
*
* \return the comparison result
*/
bool operator!=(const char* other) const { return !operator==(other); }
/*!
* \brief Compares this String object to other
*
* \param other The String to compare with.
*
* \return zero if both char sequences compare equal. negative if this appear
* before other, positive otherwise.
*/
int compare(const String& other) const {
return memncmp(data(), other.data(), size(), other.size());
}
/*!
* \brief Compares this String object to other
*
* \param other The string to compare with.
*
* \return zero if both char sequences compare equal. negative if this appear
* before other, positive otherwise.
*/
int compare(const std::string& other) const {
return memncmp(data(), other.data(), size(), other.size());
}
/*!
* \brief Compares this to other
*
* \param other The character array to compare with.
*
* \return zero if both char sequences compare equal. negative if this appear
* before other, positive otherwise.
*/
int compare(const char* other) const {
return memncmp(data(), other, size(), std::strlen(other));
}
/*!
* \brief Returns a pointer to the char array in the string.
*
* \return const char*
*/
const char* c_str() const { return get()->data; }
/*!
* \brief Return the length of the string
*
* \return size_t string length
*/
size_t size() const {
const auto* ptr = get();
if (ptr == nullptr) {
return 0;
}
return ptr->size;
}
/*!
* \brief Return the length of the string
*
* \return size_t string length
*/
size_t length() const { return size(); }
/*!
* \brief Retun if the string is empty
*
* \return true if empty, false otherwise.
*/
bool empty() const { return size() == 0; }
/*!
* \brief Return the data pointer
*
* \return const char* data pointer
*/
const char* data() const { return get()->data; }
/*!
* \brief Convert String to an std::sting object
*
* \return std::string
*/
operator std::string() const { return std::string{get()->data, size()}; }
TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
private:
/*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); }
/*!
* \brief Compare two char sequence
*
* \param lhs Pointers to the char array to compare
* \param rhs Pointers to the char array to compare
* \param lhs_count Length of the char array to compare
* \param rhs_count Length of the char array to compare
* \return int zero if both char sequences compare equal. negative if this
* appear before other, positive otherwise.
*/
static int memncmp(const char* lhs, const char* rhs, size_t lhs_count,
size_t rhs_count);
};
/*! \brief An object representing string moved from std::string. */
class StringObj::FromStd : public StringObj {
public:
/*!
* \brief Construct a new FromStd object
*
* \param other The moved/copied std::string object
*
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
explicit FromStd(std::string other) : data_container{other} {}
private:
/*! \brief Container that holds the memory. */
std::string data_container;
friend class String;
};
inline String::String(std::string other) {
auto ptr = make_object<StringObj::FromStd>(std::move(other));
ptr->size = ptr->data_container.size();
ptr->data = ptr->data_container.data();
data_ = std::move(ptr);
}
inline String String::operator=(std::string other) {
String replace{std::move(other)};
data_.swap(replace.data_);
return Downcast<String>(*this);
}
inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count,
size_t rhs_count) {
if (lhs == rhs && lhs_count == rhs_count) return 0;
for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
if (lhs[i] < rhs[i]) return -1;
if (lhs[i] > rhs[i]) return 1;
}
if (lhs_count < rhs_count) {
return -1;
} else if (lhs_count > rhs_count) {
return 1;
} else {
return 0;
}
}
} // namespace runtime
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::runtime::String> {
std::size_t operator()(const ::tvm::runtime::String& str) const {
// This function falls back to string copy with c++11 compiler and is
// recommended to be compiled with c++14
#if TVM_USE_CXX17_STRING_VIEW_HASH
return std::hash<std::string_view>{}(
std::string_view{str.data(), str.size()});
#elif TVM_USE_CXX14_STRING_VIEW_HASH
return std::hash<std::experimental::string_view>{}(
std::experimental::string_view{str.data(), str.size()});
#else
return std::hash<std::string>()(str.operator std::string());
#endif
}
};
} // namespace std
#endif // TVM_RUNTIME_CONTAINER_H_
......@@ -19,8 +19,9 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tir/op.h>
#include <tvm/runtime/container.h>
#include <tvm/tir/op.h>
#include <new>
#include <unordered_map>
#include <vector>
......@@ -221,11 +222,185 @@ TEST(Map, Iterator) {
using namespace tvm;
PrimExpr a = 1, b = 2;
Map<PrimExpr, PrimExpr> map1{{a, b}};
std::unordered_map<PrimExpr, PrimExpr, ObjectHash, ObjectEqual>
map2(map1.begin(), map1.end());
std::unordered_map<PrimExpr, PrimExpr, ObjectHash, ObjectEqual> map2(
map1.begin(), map1.end());
CHECK(map2[a].as<IntImmNode>()->value == 2);
}
TEST(String, MoveFromStd) {
using namespace std;
string source = "this is a string";
string expect = source;
String s(std::move(source));
string copy = (string)s;
CHECK_EQ(copy, expect);
CHECK_EQ(source.size(), 0);
}
TEST(String, CopyFromStd) {
using namespace std;
string source = "this is a string";
string expect = source;
String s{source};
string copy = (string)s;
CHECK_EQ(copy, expect);
CHECK_EQ(source.size(), expect.size());
}
TEST(String, Assignment) {
using namespace std;
String s{string{"hello"}};
s = string{"world"};
CHECK_EQ(s == "world", true);
string s2{"world2"};
s = std::move(s2);
CHECK_EQ(s == "world2", true);
}
TEST(String, empty) {
using namespace std;
String s{"hello"};
CHECK_EQ(s.empty(), false);
s = "";
CHECK_EQ(s.empty(), true);
}
TEST(String, Comparisons) {
using namespace std;
string source = "a string";
string mismatch = "a string but longer";
String s{source};
CHECK_EQ(s == source, true);
CHECK_EQ(s == mismatch, false);
CHECK_EQ(s == source.data(), true);
CHECK_EQ(s == mismatch.data(), false);
}
// Check '\0' handling
TEST(String, null_byte_handling) {
using namespace std;
// Ensure string still compares equal if it contains '\0'.
string v1 = "hello world";
size_t v1_size = v1.size();
v1[5] = '\0';
CHECK_EQ(v1[5], '\0');
CHECK_EQ(v1.size(), v1_size);
String str_v1{v1};
CHECK_EQ(str_v1.compare(v1), 0);
CHECK_EQ(str_v1.size(), v1_size);
// Ensure bytes after '\0' are taken into account for mismatches.
string v2 = "aaa one";
string v3 = "aaa two";
v2[3] = '\0';
v3[3] = '\0';
String str_v2{v2};
String str_v3{v3};
CHECK_EQ(str_v2.compare(str_v3), -1);
CHECK_EQ(str_v2.size(), 7);
// strcmp won't be able to detect the mismatch
CHECK_EQ(strcmp(v2.data(), v3.data()), 0);
// string::compare can handle \0 since it knows size
CHECK_LT(v2.compare(v3), 0);
// If there is mismatch before '\0', should still handle it.
string v4 = "acc one";
string v5 = "abb two";
v4[3] = '\0';
v5[3] = '\0';
String str_v4{v4};
String str_v5{v5};
CHECK_GT(str_v4.compare(str_v5), 0);
CHECK_EQ(str_v4.size(), 7);
// strcmp is able to detect the mismatch
CHECK_GT(strcmp(v4.data(), v5.data()), 0);
// string::compare can handle \0 since it knows size
CHECK_GT(v4.compare(v5), 0);
}
TEST(String, compare_same_memory_region_different_size) {
using namespace std;
string source = "a string";
String str_source{source};
char* memory = const_cast<char*>(str_source.data());
CHECK_EQ(str_source.compare(memory), 0);
// This changes the string size
memory[2] = '\0';
// memory is logically shorter now
CHECK_GT(str_source.compare(memory), 0);
}
TEST(String, compare) {
using namespace std;
string source = "a string";
string mismatch1 = "a string but longer";
string mismatch2 = "a strin";
string mismatch3 = "a b";
string mismatch4 = "a t";
String str_source{source};
String str_mismatch1{mismatch1};
String str_mismatch2{mismatch2};
String str_mismatch3{mismatch3};
String str_mismatch4{mismatch4};
// compare with string
CHECK_EQ(str_source.compare(source), 0);
CHECK_LT(str_source.compare(mismatch1), 0);
CHECK_GT(str_source.compare(mismatch2), 0);
CHECK_GT(str_source.compare(mismatch3), 0);
CHECK_LT(str_source.compare(mismatch4), 0);
// compare with char*
CHECK_EQ(str_source.compare(source.data()), 0);
CHECK_LT(str_source.compare(mismatch1.data()), 0);
CHECK_GT(str_source.compare(mismatch2.data()), 0);
CHECK_GT(str_source.compare(mismatch3.data()), 0);
CHECK_LT(str_source.compare(mismatch4.data()), 0);
// compare with String
CHECK_LT(str_source.compare(str_mismatch1), 0);
CHECK_GT(str_source.compare(str_mismatch2), 0);
CHECK_GT(str_source.compare(str_mismatch3), 0);
CHECK_LT(str_source.compare(str_mismatch4), 0);
}
TEST(String, c_str) {
using namespace std;
string source = "this is a string";
string mismatch = "mismatch";
String s{source};
CHECK_EQ(std::strcmp(s.c_str(), source.data()), 0);
CHECK_NE(std::strcmp(s.c_str(), mismatch.data()), 0);
}
TEST(String, hash) {
using namespace std;
string source = "this is a string";
String s{source};
std::hash<String>()(s);
std::unordered_map<String, std::string> map;
String k1{string{"k1"}};
string v1{"v1"};
String k2{string{"k2"}};
string v2{"v2"};
map[k1] = v1;
map[k2] = v2;
CHECK_EQ(map[k1], v1);
CHECK_EQ(map[k2], v2);
}
TEST(String, Cast) {
using namespace std;
string source = "this is a string";
String s{source};
ObjectRef r = s;
String s2 = Downcast<String>(r);
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment