/*!
 *  Copyright (c) 2018 by Contributors
 * \file src/relay/op/layout.cc
 * \brief Layout expression.
 */

#include "layout.h"

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(LayoutNode);

std::vector<IndexExpr> ConvertLayout(
    std::vector<IndexExpr> src,
    const Layout& src_layout,
    const Layout& dst_layout) {
  CHECK_EQ(src_layout.ndim(), src.size());
  if (src_layout == dst_layout) {
    return src;
  } else if (!src_layout.defined()) {
    LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
  } else if (!dst_layout.defined()) {
    LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
  }

  CHECK(src_layout.Convertible(dst_layout))
    << "cannot convert from "
    << src_layout << " to " << dst_layout;

  std::vector<IndexExpr> dst(dst_layout.ndim());
  for (size_t i = 0; i < src_layout.ndim(); ++i) {
    Layout::LayoutDim src_dim = src_layout[i];
    if (Layout::IsSuperdim(src_dim)) {
      int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
      int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
      int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
      int src_factor = src_layout.Subsizeof(src_dim);
      int dst_factor = dst_layout.Subsizeof(src_dim);
      IndexExpr src_dim_size = src[i];

      if (src_minor_pos >= 0) {
        CHECK(is_const_int(src[src_minor_pos], src_factor))
          << "src shape " << Array<IndexExpr>(src)
          << " does not agree with layout "
          << src_layout;
        src_dim_size *= src_factor;
      }
      dst[dst_major_pos] = src_dim_size;
      if (dst_minor_pos >= 0) {
        CHECK_GT(dst_factor, 0);
        if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
          CHECK_LE(dst_factor, const_src_dim_size[0])
            << "Converting " << Array<IndexExpr>(src)
            << " from " << src_layout
            << " to " << dst_layout
            << ": cannot split dimension size of "
            << src_dim_size << " by " << dst_factor;
        }
        dst[dst_major_pos] /= dst_factor;
        dst[dst_minor_pos] = dst_factor;
      }
    }
  }
  return dst;
}

std::vector<IndexExpr> ConvertLayout(
    const Array<IndexExpr>& src,
    const Layout& src_layout,
    const Layout& dst_layout) {
  std::vector<IndexExpr> ret(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    ret[i] = src[i];
  }
  return ConvertLayout(ret, src_layout, dst_layout);
}

}  // namespace relay
}  // namespace tvm