Commit bc3959b2 by tqchen

Change array to copy on write semnatics

parent 3c0dc79d
...@@ -32,6 +32,8 @@ class ArrayNode : public Node { ...@@ -32,6 +32,8 @@ class ArrayNode : public Node {
/*! /*!
* \brief Immutable array container of NodeRef in DSL graph. * \brief Immutable array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
* \tparam T The content NodeRef type. * \tparam T The content NodeRef type.
*/ */
template<typename T, template<typename T,
...@@ -128,6 +130,62 @@ class Array : public NodeRef { ...@@ -128,6 +130,62 @@ class Array : public NodeRef {
if (node_.get() == nullptr) return 0; if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size(); return static_cast<const ArrayNode*>(node_.get())->data.size();
} }
/*! \brief copy on write semantics */
inline void CopyOnWrite() {
if (node_.get() == nullptr || node_.unique()) return;
node_ = std::make_shared<ArrayNode>(
*static_cast<const ArrayNode*>(node_.get()));
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data.push_back(item.node_);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param other The value to be setted.
*/
inline void Set(size_t i, const T& value) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data[i] = value.node_;
}
/*! \brief wrapper class to represent an array reference */
struct ArrayItemRef {
/*! \brief reference to parent */
Array<T>* parent;
/*! \brief The index */
size_t index;
/*!
* \brief assign operator
* \param value The value to be assigned
* \return reference to self.
*/
inline ArrayItemRef& operator=(const T& other) {
parent->Set(index, other);
return *this;
}
/*! \brief The conversion operator */
inline operator T() const {
return (*static_cast<const Array<T>*>(parent))[index];
}
// overload print
friend std::ostream& operator<<(
std::ostream &os, const typename Array<T>::ArrayItemRef& r) { // NOLINT(*0
return os << r.operator T();
}
};
/*!
* \brief Get reference of i-th element from array.
* \param i The index
* \return the ref to i-th element.
*/
inline ArrayItemRef operator[](size_t i) {
return ArrayItemRef{this, i};
}
friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*) friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*)
for (size_t i = 0; i < r.size(); ++i) { for (size_t i = 0; i < r.size(); ++i) {
if (i == 0) { if (i == 0) {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <type_traits> #include <type_traits>
#include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./array.h" #include "./array.h"
......
...@@ -12,6 +12,18 @@ TEST(Array, Expr) { ...@@ -12,6 +12,18 @@ TEST(Array, Expr) {
LOG(INFO) << list[1]; LOG(INFO) << list[1];
} }
TEST(Array, Mutate) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Array<Expr> list{x, z, z};
auto list2 = list;
list[1] = x;
LOG(INFO) << list[1];
LOG(INFO) << list2[1];
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment