Skip to content
Snippets Groups Projects
Commit bc3959b2 authored by tqchen's avatar tqchen
Browse files

Change array to copy on write semnatics

parent 3c0dc79d
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,8 @@ class ArrayNode : public Node {
/*!
* \brief Immutable array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
* \tparam T The content NodeRef type.
*/
template<typename T,
......@@ -128,6 +130,62 @@ class Array : public NodeRef {
if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size();
}
/*! \brief copy on write semantics */
inline void CopyOnWrite() {
if (node_.get() == nullptr || node_.unique()) return;
node_ = std::make_shared<ArrayNode>(
*static_cast<const ArrayNode*>(node_.get()));
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data.push_back(item.node_);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param other The value to be setted.
*/
inline void Set(size_t i, const T& value) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data[i] = value.node_;
}
/*! \brief wrapper class to represent an array reference */
struct ArrayItemRef {
/*! \brief reference to parent */
Array<T>* parent;
/*! \brief The index */
size_t index;
/*!
* \brief assign operator
* \param value The value to be assigned
* \return reference to self.
*/
inline ArrayItemRef& operator=(const T& other) {
parent->Set(index, other);
return *this;
}
/*! \brief The conversion operator */
inline operator T() const {
return (*static_cast<const Array<T>*>(parent))[index];
}
// overload print
friend std::ostream& operator<<(
std::ostream &os, const typename Array<T>::ArrayItemRef& r) { // NOLINT(*0
return os << r.operator T();
}
};
/*!
* \brief Get reference of i-th element from array.
* \param i The index
* \return the ref to i-th element.
*/
inline ArrayItemRef operator[](size_t i) {
return ArrayItemRef{this, i};
}
friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*)
for (size_t i = 0; i < r.size(); ++i) {
if (i == 0) {
......
......@@ -9,6 +9,7 @@
#include <string>
#include <vector>
#include <type_traits>
#include "./base.h"
#include "./expr.h"
#include "./array.h"
......
......@@ -12,6 +12,18 @@ TEST(Array, Expr) {
LOG(INFO) << list[1];
}
TEST(Array, Mutate) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Array<Expr> list{x, z, z};
auto list2 = list;
list[1] = x;
LOG(INFO) << list[1];
LOG(INFO) << list2[1];
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment