Commit 0858c5ad by Lianmin Zheng Committed by Tianqi Chen

[IR] Make iterators compatible with constructors of STL containers (#3624)

parent 97e333ca
...@@ -110,18 +110,28 @@ template<typename Converter, ...@@ -110,18 +110,28 @@ template<typename Converter,
typename TIter> typename TIter>
class IterAdapter { class IterAdapter {
public: public:
using difference_type = typename std::iterator_traits<TIter>::difference_type;
using value_type = typename std::iterator_traits<TIter>::value_type;
using pointer = typename std::iterator_traits<TIter>::pointer;
using reference = typename std::iterator_traits<TIter>::reference;
using iterator_category = typename std::iterator_traits<TIter>::iterator_category;
explicit IterAdapter(TIter iter) : iter_(iter) {} explicit IterAdapter(TIter iter) : iter_(iter) {}
inline IterAdapter& operator++() { // NOLINT(*) inline IterAdapter& operator++() {
++iter_;
return *this;
}
inline IterAdapter& operator++(int) { // NOLINT(*)
++iter_; ++iter_;
return *this; return *this;
} }
inline IterAdapter operator+(int offset) const { // NOLINT(*) inline IterAdapter operator+(difference_type offset) const {
return IterAdapter(iter_ + offset); return IterAdapter(iter_ + offset);
} }
template<typename T = IterAdapter>
typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
typename T::difference_type>::type
inline operator-(const IterAdapter& rhs) const {
return iter_ - rhs.iter_;
}
inline bool operator==(IterAdapter other) const { inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_; return iter_ == other.iter_;
} }
......
...@@ -35,16 +35,6 @@ ...@@ -35,16 +35,6 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
template<typename T>
inline std::vector<T> AsVector(const Array<T> &array) {
std::vector<T> result;
result.reserve(array.size());
for (const T& ele : array) {
result.push_back(ele);
}
return result;
}
/*! Quick helper macro /*! Quick helper macro
* - Expose a positional make function to construct the node. * - Expose a positional make function to construct the node.
* - Register op to the registry. * - Register op to the registry.
......
...@@ -229,7 +229,7 @@ bool ArgReduceRel(const Array<Type>& types, ...@@ -229,7 +229,7 @@ bool ArgReduceRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false; if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0); CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape); std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
const ReduceAttrs* param = attrs.as<ReduceAttrs>(); const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
...@@ -254,7 +254,7 @@ bool ReduceRel(const Array<Type>& types, ...@@ -254,7 +254,7 @@ bool ReduceRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false; if (data == nullptr) return false;
std::vector<IndexExpr>&& in_shape = AsVector(data->shape); std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
const ReduceAttrs* param = attrs.as<ReduceAttrs>(); const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
......
...@@ -265,7 +265,7 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -265,7 +265,7 @@ bool ConcatenateRel(const Array<Type>& types,
} }
axis = axis < 0 ? ndim + axis : axis; axis = axis < 0 ? ndim + axis : axis;
// Calculate shape // Calculate shape
std::vector<IndexExpr>&& oshape = AsVector(first->shape); std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
IndexExpr &concat_dim = oshape[axis]; IndexExpr &concat_dim = oshape[axis];
bool has_any = false; bool has_any = false;
if (concat_dim.as<Any>()) { if (concat_dim.as<Any>()) {
...@@ -834,7 +834,7 @@ bool TakeRel(const Array<Type>& types, ...@@ -834,7 +834,7 @@ bool TakeRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
if (!param->axis.defined()) { if (!param->axis.defined()) {
std::vector<IndexExpr>&& oshape = AsVector(indices->shape); std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end());
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true; return true;
} }
...@@ -1990,7 +1990,7 @@ bool SplitRel(const Array<Type>& types, ...@@ -1990,7 +1990,7 @@ bool SplitRel(const Array<Type>& types,
<< "indices_or_sections need to be able to divide input.shape[axis]"; << "indices_or_sections need to be able to divide input.shape[axis]";
std::vector<Type> fields; std::vector<Type> fields;
for (int i = 0; i < sections->value; ++i) { for (int i = 0; i < sections->value; ++i) {
std::vector<IndexExpr>&& oshape = AsVector(data->shape); std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] /= int32_t(sections->value); oshape[axis] /= int32_t(sections->value);
auto vec_type = TensorTypeNode::make(oshape, data->dtype); auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type); fields.push_back(vec_type);
...@@ -2003,7 +2003,7 @@ bool SplitRel(const Array<Type>& types, ...@@ -2003,7 +2003,7 @@ bool SplitRel(const Array<Type>& types,
for (unsigned int i = 0; i < indices.size(); ++i) { for (unsigned int i = 0; i < indices.size(); ++i) {
CHECK(reporter->Assert(IndexExpr(indices[i]) > begin)) CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
<< "indices_or_sections need to be a sorted ascending list"; << "indices_or_sections need to be a sorted ascending list";
std::vector<IndexExpr>&& oshape = AsVector(data->shape); std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = IndexExpr(indices[i]) - begin; oshape[axis] = IndexExpr(indices[i]) - begin;
begin = IndexExpr(indices[i]); begin = IndexExpr(indices[i]);
auto vec_type = TensorTypeNode::make(oshape, data->dtype); auto vec_type = TensorTypeNode::make(oshape, data->dtype);
...@@ -2011,7 +2011,7 @@ bool SplitRel(const Array<Type>& types, ...@@ -2011,7 +2011,7 @@ bool SplitRel(const Array<Type>& types,
} }
CHECK(reporter->Assert(begin < data->shape[axis])) CHECK(reporter->Assert(begin < data->shape[axis]))
<< "The sum of sections must match the input.shape[axis]"; << "The sum of sections must match the input.shape[axis]";
std::vector<IndexExpr>&& oshape = AsVector(data->shape); std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = data->shape[axis] - begin; oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype); auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type); fields.push_back(vec_type);
...@@ -2105,9 +2105,9 @@ bool SliceLikeRel(const Array<Type>& types, ...@@ -2105,9 +2105,9 @@ bool SliceLikeRel(const Array<Type>& types,
const auto param = attrs.as<SliceLikeAttrs>(); const auto param = attrs.as<SliceLikeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Array<IndexExpr> dshape = data->shape; const Array<IndexExpr>& dshape = data->shape;
const Array<IndexExpr> target_shape = target->shape; const Array<IndexExpr>& target_shape = target->shape;
std::vector<IndexExpr>&& oshape = AsVector(dshape); std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
if (!param->axes.defined()) { if (!param->axes.defined()) {
for (size_t i = 0; i < dshape.size(); ++i) { for (size_t i = 0; i < dshape.size(); ++i) {
......
...@@ -53,7 +53,7 @@ bool YoloReorgRel(const Array<Type>& types, ...@@ -53,7 +53,7 @@ bool YoloReorgRel(const Array<Type>& types,
CHECK(param != nullptr); CHECK(param != nullptr);
CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension."; CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
std::vector<IndexExpr>&& oshape = AsVector(data->shape); std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[1] = oshape[1] * param->stride * param->stride; oshape[1] = oshape[1] * param->stride * param->stride;
oshape[2] = oshape[2] / param->stride; oshape[2] = oshape[2] / param->stride;
oshape[3] = oshape[3] / param->stride; oshape[3] = oshape[3] / param->stride;
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
* under the License. * under the License.
*/ */
#include <vector>
#include <unordered_map>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
...@@ -42,6 +44,13 @@ TEST(Array, Mutate) { ...@@ -42,6 +44,13 @@ TEST(Array, Mutate) {
CHECK(list2[1].same_as(z)); CHECK(list2[1].same_as(z));
} }
TEST(Array, Iterator) {
using namespace tvm;
Array<Expr> array{1, 2, 3};
std::vector<Expr> vector(array.begin(), array.end());
CHECK(vector[1].as<IntImm>()->value == 2);
}
TEST(Map, Expr) { TEST(Map, Expr) {
using namespace tvm; using namespace tvm;
Var x("x"); Var x("x");
...@@ -86,6 +95,14 @@ TEST(Map, Mutate) { ...@@ -86,6 +95,14 @@ TEST(Map, Mutate) {
LOG(INFO) << dict; LOG(INFO) << dict;
} }
TEST(Map, Iterator) {
using namespace tvm;
Expr a = 1, b = 2;
Map<Expr, Expr> map1{{a, b}};
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> map2(map1.begin(), map1.end());
CHECK(map2[a].as<IntImm>()->value == 2);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment