/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file reorg.h
 */
#ifndef NNVM_TOP_VISION_YOLO_REORG_H_
#define NNVM_TOP_VISION_YOLO_REORG_H_

#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include <sstream>

namespace nnvm {
namespace top {

template <typename AttrType,
          bool (*is_none)(const AttrType &),
          bool (*assign)(AttrType *,
          const AttrType &),
          bool reverse_infer,
          std::string (*attr_string)(const AttrType &),
          int n_in = -1,
          int n_out = -1>
inline bool ReorgAttr(const nnvm::NodeAttrs &attrs,
                      std::vector<AttrType> *in_attrs,
                      std::vector<AttrType> *out_attrs,
                      const AttrType &none) {
  AttrType dattr = none;
  size_t in_size = in_attrs->size();
  size_t out_size = out_attrs->size();
  if (n_in != -1) {
    in_size = static_cast<size_t>(n_in);
  }
  if (n_out != -1) {
    out_size = static_cast<size_t>(n_out);
  }

  auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
    for (size_t i = 0; i < size; ++i) {
      if (i == 0) {
        CHECK(assign(&dattr, (*vec)[i]))
            << "Incompatible attr in node " << attrs.name << " at " << i
            << "-th " << name << ": "
            << "expected " << attr_string(dattr) << ", got "
            << attr_string((*vec)[i]);
      }
    }
  };
  deduce(in_attrs, in_size, "input");

  auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
    for (size_t i = 0; i < size; ++i) {
      CHECK(assign(&(*vec)[i], dattr))
          << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
          << name << ": "
          << "expected " << attr_string(dattr) << ", got "
          << attr_string((*vec)[i]);
    }
  };
  write(out_attrs, out_size, "output");

  if (is_none(dattr)) {
    return false;
  }
  return true;
}

template <int n_in, int n_out>
inline bool ReorgShape(const NodeAttrs &attrs,
                       std::vector<TShape> *in_attrs,
                       std::vector<TShape> *out_attrs) {
  if (n_in != -1) {
    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
        << " in operator " << attrs.name;
  }
  if (n_out != -1) {
    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
        << " in operator " << attrs.name;
  }
  return ReorgAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
      attrs, in_attrs, out_attrs, TShape());
}

template <int n_in, int n_out>
inline bool ReorgType(const NodeAttrs &attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
  if (n_in != -1) {
    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
        << " in operator " << attrs.name;
  }
  if (n_out != -1) {
    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
        << " in operator " << attrs.name;
  }
  return ReorgAttr<int, type_is_none, type_assign, true, type_string>(
      attrs, in_attrs, out_attrs, -1);
}

struct ReorgParam : public dmlc::Parameter<ReorgParam> {
  int stride;

  DMLC_DECLARE_PARAMETER(ReorgParam) {
    DMLC_DECLARE_FIELD(stride).set_default(1).describe("Stride value");
  }
};
}  // namespace top
}  // namespace nnvm
#endif  // NNVM_TOP_VISION_YOLO_REORG_H_