/*!
 *  Copyright (c) 2016 by Contributors
 * \file packed_func_ext.h
 * \brief Extension package to PackedFunc
 *   This enables passing NodeRef types into and out of PackedFunc.
 */
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>

#include "./base.h"
#include "./expr.h"
#include "./runtime/packed_func.h"

namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

namespace runtime {
/*!
 * \brief Runtime type checker for node type.
 * \tparam T the type to be checked.
 */
template<typename T>
struct NodeTypeChecker {
  /*!
   * \brief Check whether a node is an instance of T::ContainerType.
   * \param sptr The node to inspect; may be nullptr.
   * \return false for nullptr, otherwise whether the node derives from
   *  T::ContainerType.
   */
  static inline bool Check(Node* sptr) {
    // Guard before dereferencing: an empty node handle (e.g. an undefined
    // NodeRef, or a null element inside an Array/Map) would otherwise be
    // dereferenced below. Rejecting it matches the Array/Map checkers.
    if (sptr == nullptr) return false;
    // This is the only place in the project where RTTI is used
    // It can be turned off, but will make non strict checking.
    // TODO(tqchen) possibly find alternative to turn off RTTI
    using ContainerType = typename T::ContainerType;
    return sptr->derived_from<ContainerType>();
  }
  /*! \brief Print T::ContainerType::_type_key into \p os. */
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    using ContainerType = typename T::ContainerType;
    os << ContainerType::_type_key;
  }
};

/*!
 * \brief Type checker specialization for Array<T>.
 *  A node passes when it is a non-null ArrayNode and every element passes
 *  NodeTypeChecker<T>.
 * \tparam T the element type.
 */
template<typename T>
struct NodeTypeChecker<Array<T> > {
  static inline bool Check(Node* sptr) {
    const bool is_array = sptr != nullptr && sptr->is_type<ArrayNode>();
    if (!is_array) return false;
    ArrayNode* arr = static_cast<ArrayNode*>(sptr);
    // Stop at the first element whose type does not match.
    for (const auto& elem : arr->data) {
      if (!NodeTypeChecker<T>::Check(elem.get())) {
        return false;
      }
    }
    return true;
  }
  /*! \brief Print "array<" + element name + ">" into \p os. */
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "array<";
    NodeTypeChecker<T>::PrintName(os);
    os << '>';
  }
};

/*!
 * \brief Type checker specialization for Map<K, V>.
 *  A node passes when it is a non-null MapNode and every key passes
 *  NodeTypeChecker<K> and every value passes NodeTypeChecker<V>.
 * \tparam K the key type.
 * \tparam V the value type.
 */
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
  static inline bool Check(Node* sptr) {
    if (sptr == nullptr || !sptr->is_type<MapNode>()) return false;
    MapNode* dict = static_cast<MapNode*>(sptr);
    for (const auto& entry : dict->data) {
      // The key is checked first; the value is only checked if the key passes.
      if (!NodeTypeChecker<K>::Check(entry.first.get()) ||
          !NodeTypeChecker<V>::Check(entry.second.get())) {
        return false;
      }
    }
    return true;
  }
  /*! \brief Print "map<" + key name + "," + value name + ">" into \p os. */
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "map<";
    NodeTypeChecker<K>::PrintName(os);
    os << ',';
    NodeTypeChecker<V>::PrintName(os);
    os << '>';
  }
};

/*!
 * \brief Human-readable name of node type T, as produced by
 *  NodeTypeChecker<T>::PrintName (used in type-mismatch messages).
 * \tparam T the reference type to name.
 * \return the printed name.
 */
template<typename T>
inline std::string NodeTypeName() {
  std::ostringstream name_stream;
  NodeTypeChecker<T>::PrintName(name_stream);
  std::string name = name_stream.str();
  return name;
}

// NodeRef-related extensions for TVMArgValue

/*!
 * \brief Convert a packed argument to a NodeRef subtype.
 *  kNull and empty node handles both yield an undefined TNodeRef; any other
 *  node must pass NodeTypeChecker<TNodeRef> or a CHECK failure is raised.
 * \tparam TNodeRef the target reference type (must derive from NodeRef).
 */
template<typename TNodeRef, typename>
inline TVMArgValue::operator TNodeRef() const {
  static_assert(
      std::is_base_of<NodeRef, TNodeRef>::value,
      "Conversion only works for NodeRef");
  if (type_code_ == kNull) return TNodeRef();
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
  // TVMArgsSetter passes an undefined NodeRef as a kNodeHandle whose
  // shared_ptr is empty. Treat it like kNull instead of dereferencing it in
  // the type check or in the error message below.
  if (sptr == nullptr) return TNodeRef();
  CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<TNodeRef>()
      << " but get " << sptr->type_key();
  return TNodeRef(sptr);
}

/*!
 * \brief Convert a packed argument to Halide::Expr.
 *  Conversions by type code:
 *  - kNull, or an empty node handle, gives an undefined Expr.
 *  - kInt gives an int immediate, after a range check.
 *  - kFloat gives a float immediate.
 *  - An IterVar node gives its var.
 *  - Any other node must pass NodeTypeChecker<Expr>.
 */
inline TVMArgValue::operator Halide::Expr() const {
  if (type_code_ == kNull) return Expr();
  if (type_code_ == kInt) {
    // Expr is constructed from an int here; reject values that would
    // otherwise wrap silently in the narrowing cast.
    CHECK(value_.v_int64 >= std::numeric_limits<int>::min() &&
          value_.v_int64 <= std::numeric_limits<int>::max())
        << "Integer argument " << value_.v_int64
        << " does not fit in int";
    return Expr(static_cast<int>(value_.v_int64));
  }
  if (type_code_ == kFloat) {
    return Expr(static_cast<float>(value_.v_float64));
  }
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
  // An undefined NodeRef passed by handle carries an empty shared_ptr; map
  // it to an undefined Expr rather than dereferencing it below.
  if (sptr == nullptr) return Expr();
  if (sptr->is_type<IterVarNode>()) {
    return IterVar(sptr)->var;
  }
  CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<Expr>()
      << " but get " << sptr->type_key();
  return Expr(sptr);
}

/*!
 * \brief Access the node shared_ptr stored in this argument.
 *  Raises a type-code CHECK failure unless the argument is a kNodeHandle.
 * \return a mutable reference to the stored shared_ptr.
 */
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>* holder = ptr<std::shared_ptr<Node> >();
  return *holder;
}


template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>& sptr =
      *ptr<std::shared_ptr<Node> >();
  return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}

// extensions for TVMRetValue
/*!
 * \brief Store a node in the return value.
 *  A non-empty pointer is stored as a kNodeHandle; an empty one resets the
 *  value to kNull.
 * \param other the node to store.
 * \return *this.
 */
inline TVMRetValue& TVMRetValue::operator=(
    const std::shared_ptr<Node>& other) {
  if (other) {
    SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
    return *this;
  }
  SwitchToPOD(kNull);
  return *this;
}

/*!
 * \brief Store a NodeRef in the return value.
 *  A defined reference is stored as a kNodeHandle sharing its node; an
 *  undefined reference resets the value to kNull.
 * \param other the reference to store.
 * \return *this.
 */
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
  if (other.defined()) {
    SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
  } else {
    SwitchToPOD(kNull);
  }
  return *this;
}

/*!
 * \brief Convert the return value to a NodeRef subtype.
 *  kNull yields an undefined TNodeRef. Otherwise the value must be a
 *  kNodeHandle. Note: unlike the TVMArgValue conversion, no
 *  NodeTypeChecker check is performed here.
 * \tparam TNodeRef the target reference type (must derive from NodeRef).
 */
template<typename TNodeRef, typename>
inline TVMRetValue::operator TNodeRef() const {
  static_assert(std::is_base_of<NodeRef, TNodeRef>::value,
                "Conversion only works for NodeRef");
  if (type_code_ != kNull) {
    TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
    std::shared_ptr<Node>& holder = *ptr<std::shared_ptr<Node> >();
    return TNodeRef(holder);
  }
  return TNodeRef();
}

/*!
 * \brief Pack a NodeRef into argument slot i as a kNodeHandle.
 *  The slot stores the address of other.node_ rather than a copy, so
 *  `other` must stay alive while the packed arguments are in use. An
 *  undefined NodeRef is still tagged kNodeHandle, with an empty shared_ptr.
 * \param i the argument index.
 * \param other the reference to pass.
 */
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const {  // NOLINT(*)
  type_codes_[i] = kNodeHandle;
  values_[i].v_handle = &other.node_;
}

// Halide::Type <-> TVMType conversion extensions
/*!
 * \brief Store a Halide::Type in the return value via its TVMType encoding
 *  (Type2TVMType).
 * \param t the type to store.
 * \return *this.
 */
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
  *this = Type2TVMType(t);
  return *this;
}

/*!
 * \brief Read the return value as a TVMType and convert it to Halide::Type
 *  (TVMType2Type).
 */
inline TVMRetValue::operator Halide::Type() const {
  TVMType encoded = this->operator TVMType();
  return TVMType2Type(encoded);
}

/*!
 * \brief Read the argument as a TVMType and convert it to Halide::Type
 *  (TVMType2Type).
 */
inline TVMArgValue::operator Halide::Type() const {
  TVMType encoded = this->operator TVMType();
  return TVMType2Type(encoded);
}

/*!
 * \brief Pack a Halide::Type into argument slot i by forwarding its TVMType
 *  encoding (Type2TVMType) to the TVMType overload.
 * \param i the argument index.
 * \param t the type to pass.
 */
inline void TVMArgsSetter::operator()(
    size_t i, const Halide::Type& t) const {
  (*this)(i, Type2TVMType(t));
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_PACKED_FUNC_EXT_H_