Commit 0de304ca by ziheng Committed by Tianqi Chen

Rename TShape::index_t to dim_t and change to int64_t (#114)

* Change TShape::index_t to int64_t

* Add comment

* Make Tuple::Save&Load dtype generic

* trigger update

* Fix lint

* Fix comment

* Change index_t to dim_t

* Remove legacy index_t
parent 56ea6d6b
......@@ -15,8 +15,8 @@
namespace nnvm {
/*! \brief data type to store array index */
typedef uint32_t index_t;
/*! \brief data type to store dim size */
typedef int64_t dim_t;
/*!
* \brief A dynamic sized array data structure that is optimized for storing
......@@ -151,7 +151,7 @@ class Tuple {
return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
}
/*! \return number of dimension of the tuple */
inline index_t ndim() const {
inline uint32_t ndim() const {
return ndim_;
}
/*!
......@@ -159,7 +159,7 @@ class Tuple {
* \param i dimension index
* \return the corresponding dimension size
*/
inline ValueType& operator[](index_t i) {
inline ValueType& operator[](size_t i) {
return begin()[i];
}
/*!
......@@ -167,7 +167,7 @@ class Tuple {
* \param i dimension index
* \return the corresponding dimension size
*/
inline const ValueType& operator[](index_t i) const {
inline const ValueType& operator[](size_t i) const {
return begin()[i];
}
/*!
......@@ -275,55 +275,48 @@ class Tuple {
/*!
* \brief save the content into binary stream
* \param strm the output stream
* \tparam DType data type to save the elements as
* \tparam TStream any stream type that has a Write method
*/
template<typename TStream>
inline void Save(TStream *strm) const {
strm->Write(&ndim_, sizeof(ndim_));
strm->Write(begin(), sizeof(ValueType) * ndim_);
}
template<typename DType = ValueType, typename TStream>
inline void Save(TStream *strm) const;
/*!
* \brief load the content from binary stream
* \param strm the input stream
* \tparam DType data type to load the elements from
* \tparam TStream any stream type that has a Read method
* \return whether the load is successful
*/
template<typename TStream>
inline bool Load(TStream *strm) {
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
this->SetDim(ndim_);
size_t nread = sizeof(ValueType) * ndim_;
if (strm->Read(begin(), nread) != nread) return false;
return true;
}
template<typename DType = ValueType, typename TStream>
inline bool Load(TStream *strm);
protected:
// stack cache size
static const uint32_t kStackCache = 4;
/*! \brief number of dimension of the tuple */
index_t ndim_{0};
uint32_t ndim_{0};
/*! \brief number of cells allocated in data_heap_ */
index_t num_heap_allocated_{0};
uint32_t num_heap_allocated_{0};
/*! \brief in stack space used to store shape when it is small */
ValueType data_stack_[kStackCache];
/*! \brief space to store shape when dimension is big*/
ValueType* data_heap_{nullptr};
// internal function to change the dimension
inline void SetDim(index_t dim) {
if (dim > kStackCache &&
dim > num_heap_allocated_) {
inline void SetDim(uint32_t ndim) {
if (ndim > kStackCache &&
ndim > num_heap_allocated_) {
delete [] data_heap_;
data_heap_ = new ValueType[dim];
num_heap_allocated_ = dim;
data_heap_ = new ValueType[ndim];
num_heap_allocated_ = ndim;
}
ndim_ = dim;
ndim_ = ndim;
}
};
/*!
* \brief A Shape class that is used to represent shape of each tensor.
*/
class TShape : public Tuple<index_t> {
class TShape : public Tuple<dim_t> {
public:
/*! \brief default constructor */
TShape() = default;
......@@ -331,7 +324,7 @@ class TShape : public Tuple<index_t> {
* constructor to construct a shape with all 1.
* \param ndim the number of dimension
*/
inline TShape(index_t ndim) { // NOLINT(*)
inline TShape(uint32_t ndim) { // NOLINT(*)
this->SetDim(ndim);
std::fill_n(begin(), ndim, 1);
}
......@@ -339,21 +332,21 @@ class TShape : public Tuple<index_t> {
* \brief copy constructor of TShape
* \param s source shape.
*/
inline TShape(const Tuple<index_t>& s) { // NOLINT(*)
inline TShape(const Tuple<dim_t>& s) { // NOLINT(*)
this->assign(s.begin(), s.end());
}
/*!
* \brief constructor from initializer list
* \param init the initializer_list
*/
inline TShape(std::initializer_list<index_t> init) {
inline TShape(std::initializer_list<dim_t> init) {
this->assign(init.begin(), init.end());
}
/*!
* \brief move constructor.
* \param s source shape.
*/
inline TShape(Tuple<index_t>&& s) { // NOLINT(*)
inline TShape(Tuple<dim_t>&& s) { // NOLINT(*)
this->swap(s);
}
/*!
......@@ -372,7 +365,7 @@ class TShape : public Tuple<index_t> {
* \param src source shape.
* \return self.
*/
inline TShape& operator=(const Tuple<index_t>& src) {
inline TShape& operator=(const Tuple<dim_t>& src) {
this->assign(src.begin(), src.end());
return *this;
}
......@@ -381,15 +374,15 @@ class TShape : public Tuple<index_t> {
* \param src source shape.
* \return self.
*/
inline TShape& operator=(Tuple<index_t>&& src) { // NOLINT(*)
inline TShape& operator=(Tuple<dim_t>&& src) { // NOLINT(*)
TShape(std::move(src)).swap(*this); // NOLINT(*)
return *this;
}
/*! \return total number of elements in the shape */
inline size_t Size() const {
size_t size = 1;
const index_t* start = begin(), *fin = end();
for (const index_t* it = start; it != fin; ++it) {
dim_t size = 1;
const dim_t* start = begin(), *fin = end();
for (const dim_t* it = start; it != fin; ++it) {
size *= *it;
}
return size;
......@@ -399,20 +392,20 @@ class TShape : public Tuple<index_t> {
* \param dimstart start dimension
* \param dimend end dimension
*/
inline index_t ProdShape(int dimstart, int dimend) const {
index_t num = 1;
const index_t *d = this->data();
inline size_t ProdShape(int dimstart, int dimend) const {
dim_t num = 1;
const dim_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
num *= d[i];
}
return num;
}
/*! \return the begin data pointer to content of the tuple */
inline const index_t *data() const {
inline const dim_t *data() const {
return begin();
}
/*! \return the begin data pointer to content of the tuple */
inline index_t *data() {
inline dim_t *data() {
return begin();
}
#ifdef MSHADOW_XINLINE
......@@ -445,7 +438,7 @@ class TShape : public Tuple<index_t> {
inline mshadow::Shape<dim> get() const {
CHECK_EQ(dim, static_cast<int>(ndim()))
<< "dimension do not match target dimension " << dim << " vs " << ndim();
const index_t *d = this->data();
const dim_t *d = this->data();
mshadow::Shape<dim> s;
for (int i = 0; i < dim; ++i) {
s[i] = d[i];
......@@ -459,10 +452,10 @@ class TShape : public Tuple<index_t> {
inline mshadow::Shape<2> FlatTo2D(void) const {
mshadow::Shape<2> s;
if (ndim() == 0) return mshadow::Shape2(0, 0);
const index_t *d = this->data();
const dim_t *d = this->data();
s.shape_[1] = d[ndim() - 1];
index_t ymax = 1;
for (index_t i = 1; i < ndim(); ++i) {
dim_t ymax = 1;
for (size_t i = 1; i < ndim(); ++i) {
ymax *= d[i - 1];
}
s.shape_[0] = ymax;
......@@ -474,22 +467,22 @@ class TShape : public Tuple<index_t> {
* \param axis_end The ending axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const {
inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const {
CHECK(axis_end >= axis_begin);
mshadow::Shape<3> s;
if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
const index_t *d = this->data();
const dim_t *d = this->data();
s.shape_[0] = 1;
s.shape_[1] = 1;
s.shape_[2] = 1;
for (index_t i = 0; i < axis_begin; ++i) {
for (size_t i = 0; i < axis_begin; ++i) {
s.shape_[0] *= d[i];
}
for (index_t i = axis_begin; i <= axis_end; ++i) {
for (size_t i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
for (index_t i = axis_end + 1; i < ndim(); ++i) {
for (size_t i = axis_end + 1; i < ndim(); ++i) {
s.shape_[2] *= d[i];
}
return s;
......@@ -499,7 +492,7 @@ class TShape : public Tuple<index_t> {
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis) const {
inline mshadow::Shape<3> FlatTo3D(size_t axis) const {
return FlatTo3D(axis, axis);
}
inline bool operator==(const TShape &s) const {
......@@ -517,8 +510,8 @@ class TShape : public Tuple<index_t> {
template<int dim>
inline bool operator==(const mshadow::Shape<dim> &s) const {
if (ndim_ != dim) return false;
const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
for (index_t i = 0; i < dim; ++i) {
const dim_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
for (size_t i = 0; i < dim; ++i) {
if (d[i] != s.shape_[i]) return false;
}
return true;
......@@ -535,6 +528,57 @@ class TShape : public Tuple<index_t> {
#endif
};
/*!
 * \brief Copy a range element by element, converting each value to the
 *  destination iterator's value type with static_cast.
 * \param begin start of the source range
 * \param end one past the last source element
 * \param dst_begin start of the destination range; it must have room for
 *  every source element
 * \return iterator one past the last element written to the destination
 * \tparam SrcIter iterator type of the source range
 * \tparam DstIter iterator type of the destination range
 */
template<typename SrcIter, typename DstIter>
inline DstIter ShapeTypeCast(const SrcIter begin,
                             const SrcIter end,
                             DstIter dst_begin) {
  typedef typename std::iterator_traits<SrcIter>::value_type SourceType;
  typedef typename std::iterator_traits<DstIter>::value_type TargetType;
  DstIter out = dst_begin;
  for (SrcIter it = begin; it != end; ++it) {
    // Bind through the source value type first, then convert explicitly.
    const SourceType& value = *it;
    *out = static_cast<TargetType>(value);
    ++out;
  }
  return out;
}
/*!
 * \brief Build a TShape from an iterator range, casting every element to the
 *  value type of the shape's iterator.
 * \param begin start of the source range
 * \param end one past the last source element
 * \return a TShape with one converted entry per source element
 * \tparam SrcIter iterator type of the source range
 */
template<typename SrcIter>
inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) {
  const size_t num_dims = std::distance(begin, end);
  // Size the shape first, then overwrite its entries with the converted values.
  TShape shape(num_dims);
  ShapeTypeCast(begin, end, shape.begin());
  return shape;
}
/*!
 * \brief Write the tuple to a binary stream: first the dimension count, then
 *  the elements stored as DType.
 * \param strm the stream to write to
 * \tparam ValueType The type of data stored inside tuple.
 * \tparam DType element type used in the stream
 * \tparam TStream any stream type that has a Write method
 */
template<typename ValueType>
template<typename DType, typename TStream>
inline void Tuple<ValueType>::Save(TStream *strm) const {
  strm->Write(&ndim_, sizeof(ndim_));
  const bool needs_cast = !(typeid(DType) == typeid(ValueType));
  if (!needs_cast) {
    // Stored type and stream type agree: write the elements directly.
    strm->Write(begin(), sizeof(ValueType) * ndim_);
    return;
  }
  // Types differ: convert into a temporary buffer of DType before writing.
  std::vector<DType> converted(ndim_);
  ShapeTypeCast(begin(), end(), converted.data());
  strm->Write(converted.data(), sizeof(DType) * ndim_);
}
/*!
 * \brief Read the tuple from a binary stream: first the dimension count, then
 *  the elements, which are stored in the stream as DType.
 * \param strm the stream to read from
 * \return whether the load is successful
 * \tparam ValueType The type of data stored inside tuple.
 * \tparam DType element type used in the stream
 * \tparam TStream any stream type that has a Read method; its return value is
 *  compared against the requested byte count to detect short reads
 */
template<typename ValueType>
template<typename DType, typename TStream>
inline bool Tuple<ValueType>::Load(TStream *strm) {
  // Read the count into a local so that a failed or short header read cannot
  // leave ndim_ out of sync with the storage managed by SetDim.
  uint32_t ndim = 0;
  if (strm->Read(&ndim, sizeof(ndim)) != sizeof(ndim)) return false;
  this->SetDim(ndim);
  const size_t nread = sizeof(DType) * ndim;
  if (typeid(DType) == typeid(ValueType)) {
    // Stream type matches the stored type: read straight into the tuple.
    if (strm->Read(begin(), nread) != nread) return false;
  } else {
    // Types differ: read into a temporary buffer and convert element-wise.
    std::vector<DType> buffer(ndim);
    if (strm->Read(buffer.data(), nread) != nread) return false;
    ShapeTypeCast(buffer.begin(), buffer.end(), begin());
  }
  return true;
}
} // namespace nnvm
#endif // NNVM_TUPLE_H_
......@@ -156,7 +156,7 @@ Graph Gradient(Graph src) {
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient";
} else if (CheckGradAllZero(out_agg_grads, zero_ops)) {
for (index_t i = 0; i < fwd_node->num_inputs(); ++i) {
for (size_t i = 0; i < fwd_node->num_inputs(); ++i) {
std::ostringstream os;
if (1 == fwd_node->num_inputs()) {
os << fwd_node->attrs.name << "_backward";
......
......@@ -17,7 +17,7 @@ TEST(Tuple, Basic) {
std::istringstream is(os.str());
is >> y;
CHECK_EQ(x, y);
Tuple<nnvm::index_t> ss{1, 2, 3};
Tuple<nnvm::dim_t> ss{1, 2, 3};
TShape s = ss;
s = std::move(ss);
CHECK((s == TShape{1, 2, 3}));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment