Commit bc3959b2 by tqchen

Change array to copy on write semnatics

parent 3c0dc79d
......@@ -32,6 +32,8 @@ class ArrayNode : public Node {
/*!
* \brief Immutable array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
* \tparam T The content NodeRef type.
*/
template<typename T,
......@@ -128,6 +130,62 @@ class Array : public NodeRef {
if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size();
}
/*! \brief copy on write semantics */
inline void CopyOnWrite() {
if (node_.get() == nullptr || node_.unique()) return;
node_ = std::make_shared<ArrayNode>(
*static_cast<const ArrayNode*>(node_.get()));
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data.push_back(item.node_);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param other The value to be setted.
*/
inline void Set(size_t i, const T& value) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data[i] = value.node_;
}
/*! \brief wrapper class to represent an array reference */
struct ArrayItemRef {
/*! \brief reference to parent */
Array<T>* parent;
/*! \brief The index */
size_t index;
/*!
* \brief assign operator
* \param value The value to be assigned
* \return reference to self.
*/
inline ArrayItemRef& operator=(const T& other) {
parent->Set(index, other);
return *this;
}
/*! \brief The conversion operator */
inline operator T() const {
return (*static_cast<const Array<T>*>(parent))[index];
}
// overload print
friend std::ostream& operator<<(
std::ostream &os, const typename Array<T>::ArrayItemRef& r) { // NOLINT(*0
return os << r.operator T();
}
};
/*!
* \brief Get reference of i-th element from array.
* \param i The index
* \return the ref to i-th element.
*/
inline ArrayItemRef operator[](size_t i) {
return ArrayItemRef{this, i};
}
friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*)
for (size_t i = 0; i < r.size(); ++i) {
if (i == 0) {
......
......@@ -9,6 +9,7 @@
#include <string>
#include <vector>
#include <type_traits>
#include "./base.h"
#include "./expr.h"
#include "./array.h"
......
......@@ -12,6 +12,18 @@ TEST(Array, Expr) {
LOG(INFO) << list[1];
}
TEST(Array, Mutate) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Array<Expr> list{x, z, z};
auto list2 = list;
list[1] = x;
LOG(INFO) << list[1];
LOG(INFO) << list2[1];
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment