Commit 0858c5ad by Lianmin Zheng Committed by Tianqi Chen

[IR] Make iterators compatible with constructors of STL containers (#3624)

parent 97e333ca
......@@ -110,18 +110,28 @@ template<typename Converter,
typename TIter>
class IterAdapter {
public:
using difference_type = typename std::iterator_traits<TIter>::difference_type;
using value_type = typename std::iterator_traits<TIter>::value_type;
using pointer = typename std::iterator_traits<TIter>::pointer;
using reference = typename std::iterator_traits<TIter>::reference;
using iterator_category = typename std::iterator_traits<TIter>::iterator_category;
explicit IterAdapter(TIter iter) : iter_(iter) {}
inline IterAdapter& operator++() { // NOLINT(*)
++iter_;
return *this;
}
inline IterAdapter& operator++(int) { // NOLINT(*)
inline IterAdapter& operator++() {
++iter_;
return *this;
}
inline IterAdapter operator+(int offset) const { // NOLINT(*)
inline IterAdapter operator+(difference_type offset) const {
return IterAdapter(iter_ + offset);
}
template<typename T = IterAdapter>
typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
typename T::difference_type>::type
inline operator-(const IterAdapter& rhs) const {
return iter_ - rhs.iter_;
}
inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_;
}
......
......@@ -35,16 +35,6 @@
namespace tvm {
namespace relay {
template<typename T>
inline std::vector<T> AsVector(const Array<T> &array) {
std::vector<T> result;
result.reserve(array.size());
for (const T& ele : array) {
result.push_back(ele);
}
return result;
}
/*! Quick helper macro
* - Expose a positional make function to construct the node.
* - Register op to the registry.
......
......@@ -229,7 +229,7 @@ bool ArgReduceRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
......@@ -254,7 +254,7 @@ bool ReduceRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
......
......@@ -265,7 +265,7 @@ bool ConcatenateRel(const Array<Type>& types,
}
axis = axis < 0 ? ndim + axis : axis;
// Calculate shape
std::vector<IndexExpr>&& oshape = AsVector(first->shape);
std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
IndexExpr &concat_dim = oshape[axis];
bool has_any = false;
if (concat_dim.as<Any>()) {
......@@ -834,7 +834,7 @@ bool TakeRel(const Array<Type>& types,
CHECK(param != nullptr);
if (!param->axis.defined()) {
std::vector<IndexExpr>&& oshape = AsVector(indices->shape);
std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end());
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}
......@@ -1990,7 +1990,7 @@ bool SplitRel(const Array<Type>& types,
<< "indices_or_sections need to be able to divide input.shape[axis]";
std::vector<Type> fields;
for (int i = 0; i < sections->value; ++i) {
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] /= int32_t(sections->value);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
......@@ -2003,7 +2003,7 @@ bool SplitRel(const Array<Type>& types,
for (unsigned int i = 0; i < indices.size(); ++i) {
CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
<< "indices_or_sections need to be a sorted ascending list";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = IndexExpr(indices[i]) - begin;
begin = IndexExpr(indices[i]);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
......@@ -2011,7 +2011,7 @@ bool SplitRel(const Array<Type>& types,
}
CHECK(reporter->Assert(begin < data->shape[axis]))
<< "The sum of sections must match the input.shape[axis]";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
......@@ -2105,9 +2105,9 @@ bool SliceLikeRel(const Array<Type>& types,
const auto param = attrs.as<SliceLikeAttrs>();
CHECK(param != nullptr);
const Array<IndexExpr> dshape = data->shape;
const Array<IndexExpr> target_shape = target->shape;
std::vector<IndexExpr>&& oshape = AsVector(dshape);
const Array<IndexExpr>& dshape = data->shape;
const Array<IndexExpr>& target_shape = target->shape;
std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
if (!param->axes.defined()) {
for (size_t i = 0; i < dshape.size(); ++i) {
......
......@@ -53,7 +53,7 @@ bool YoloReorgRel(const Array<Type>& types,
CHECK(param != nullptr);
CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[1] = oshape[1] * param->stride * param->stride;
oshape[2] = oshape[2] / param->stride;
oshape[3] = oshape[3] / param->stride;
......
......@@ -17,6 +17,8 @@
* under the License.
*/
#include <vector>
#include <unordered_map>
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/packed_func_ext.h>
......@@ -42,6 +44,13 @@ TEST(Array, Mutate) {
CHECK(list2[1].same_as(z));
}
TEST(Array, Iterator) {
using namespace tvm;
Array<Expr> array{1, 2, 3};
std::vector<Expr> vector(array.begin(), array.end());
CHECK(vector[1].as<IntImm>()->value == 2);
}
TEST(Map, Expr) {
using namespace tvm;
Var x("x");
......@@ -86,6 +95,14 @@ TEST(Map, Mutate) {
LOG(INFO) << dict;
}
TEST(Map, Iterator) {
using namespace tvm;
Expr a = 1, b = 2;
Map<Expr, Expr> map1{{a, b}};
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> map2(map1.begin(), map1.end());
CHECK(map2[a].as<IntImm>()->value == 2);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment