/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file attrs.cc
 */
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include "attr_functor.h"

namespace tvm {

// Expose the whole dictionary to the visitor as a single field named "__dict__".
void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
  v->Visit("__dict__", &(this->dict));
}

// The dictionary is always visited, under the field name "__dict__";
// no per-entry default filtering is performed here.
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
  v->Visit("__dict__", &(this->dict));
}

// Fill the dictionary from a flat argument list laid out as
// (key0, value0, key1, value1, ...).
//
// \param args          Packed arguments, read two at a time as key/value.
// \param allow_unknown Not consulted by this implementation.
//
// Values are stored according to their runtime kind:
//  - object references are stored as-is,
//  - strings are wrapped through PrimExpr(std::string),
//  - everything else is converted through `operator PrimExpr`.
void DictAttrsNode::InitByPackedArgs(
    const runtime::TVMArgs& args, bool allow_unknown) {
  const int num_args = args.size();
  for (int pos = 0; pos < num_args; pos += 2) {
    std::string key = args[pos];
    runtime::TVMArgValue val = args[pos + 1];
    if (val.IsObjectRef<ObjectRef>()) {
      dict.Set(key, val.operator ObjectRef());
      continue;
    }
    if (val.type_code() == kTVMStr) {
      dict.Set(key, PrimExpr(val.operator std::string()));
      continue;
    }
    dict.Set(key, val.operator PrimExpr());
  }
}

// A dictionary-backed attribute node reports no field descriptions:
// the returned list is always empty.
Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
  return Array<AttrFieldInfo>();
}

// Build an Attrs reference around a freshly allocated DictAttrsNode.
//
// \param dict The key/value content; it is moved into the new node.
// \return An Attrs handle that refers to the new node.
Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
  auto node = make_object<DictAttrsNode>();
  node->dict = std::move(dict);
  return Attrs(node);
}

// Printing a DictAttrsNode streams its underlying dictionary.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* dict_attrs = static_cast<const DictAttrsNode*>(node.get());
    p->stream << dict_attrs->dict;
});

// Register the node types defined for attributes.
// NOTE(review): the exact effect of TVM_REGISTER_NODE_TYPE is defined by the
// macro elsewhere — confirm there what gets registered.
TVM_REGISTER_NODE_TYPE(DictAttrsNode);

TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);


// The handlers below name expression node types without qualification;
// presumably several of them live in namespace tir — verify before removing.
using namespace tir;
// Equal handler.
// Entry point for comparing two attribute values.
//
// Identical references (including two undefined ones) are equal; a defined
// value never equals an undefined one. Otherwise the comparison is
// dispatched through VisitAttr.
bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
  if (lhs.same_as(rhs)) {
    return true;
  }
  const bool both_defined = lhs.defined() && rhs.defined();
  if (!both_defined) {
    return false;
  }
  return this->VisitAttr(lhs, rhs);
}

// Fallback comparison for node types without a dedicated overload.
//
// Attribute nodes (BaseAttrsNode) are compared by content, with this handler
// plugged into the AttrsEqual functor so nested values come back through it.
// Any other object is equal only to itself (pointer identity).
bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) {
  if (!lhs->IsInstance<BaseAttrsNode>()) {
    return lhs == other.get();
  }
  AttrsEqual equal;
  equal.handler_ = this;
  const auto* attrs = static_cast<const BaseAttrsNode*>(lhs);
  return attrs->ContentEqual(other.get(), equal);
}

// Integer immediates are equal when `other` is also an IntImmNode holding the
// same value. Only `value` is compared.
bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<IntImmNode>();
  if (rhs == nullptr) {
    return false;
  }
  return lhs->value == rhs->value;
}

// Floating-point immediates are equal when `other` is also a FloatImmNode and
// the two values compare equal with `==` (exact comparison, no tolerance).
bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<FloatImmNode>();
  if (rhs == nullptr) {
    return false;
  }
  return lhs->value == rhs->value;
}

// String immediates are equal when `other` is also a StringImmNode with the
// same string content.
bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<StringImmNode>();
  if (rhs == nullptr) {
    return false;
  }
  return lhs->value == rhs->value;
}

// Arrays are equal when `other` is also an ArrayNode of the same length and
// every pair of elements at the same index is Equal().
//
// \param lhs   The array on the left-hand side.
// \param other The value to compare against; may be any object type.
// \return true only for an element-wise equal array.
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<ArrayNode>();
  // A right-hand side that is not an array can never equal an array. This
  // matches the type-mismatch behaviour of the other VisitAttr_ overloads.
  if (rhs == nullptr) return false;
  if (rhs->data.size() != lhs->data.size()) return false;
  for (size_t i = 0; i < lhs->data.size(); ++i) {
    if (!Equal(lhs->data[i], rhs->data[i])) return false;
  }
  return true;
}

// String-keyed maps are equal when `other` is also a StrMapNode with the same
// number of entries and every key of `lhs` maps to an Equal() value in `other`.
//
// \param lhs   The map on the left-hand side.
// \param other The value to compare against; may be any object type.
// \return true only for a map with equal keys and equal values.
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<StrMapNode>();
  // A right-hand side that is not a string map can never equal one. This
  // matches the type-mismatch behaviour of the other VisitAttr_ overloads.
  if (rhs == nullptr) return false;
  // With equal sizes, checking that every lhs key is present in rhs is enough
  // to establish that both key sets coincide.
  if (rhs->data.size() != lhs->data.size()) return false;
  for (const auto& kv : lhs->data) {
    auto it = rhs->data.find(kv.first);
    if (it == rhs->data.end()) return false;
    if (!Equal(kv.second, it->second)) return false;
  }
  return true;
}

/*!
 * \brief Define AttrsEqualHandler::VisitAttr_ for a binary node type.
 *
 * The generated overload returns false when `other` is not a NodeName;
 * otherwise it compares operand `a` first and operand `b` second with
 * Equal(), stopping at the first mismatch.
 */
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                          \
  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
    const auto* rhs = other.as<NodeName>();                             \
    if (rhs == nullptr) return false;                                   \
    return Equal(lhs->a, rhs->a) && Equal(lhs->b, rhs->b);              \
  }

// Instantiate the two-operand equality overload for each binary node type:
// arithmetic, min/max, comparisons and logical and/or.
TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);

// Logical negations are equal when `other` is also a NotNode and the single
// operands `a` are Equal().
bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<NotNode>();
  if (rhs == nullptr) {
    return false;
  }
  return Equal(lhs->a, rhs->a);
}

// Casts are equal when `other` is also a CastNode with the same target dtype
// and an Equal() operand. The dtype is checked before the operand.
bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<CastNode>();
  if (rhs == nullptr) {
    return false;
  }
  if (lhs->dtype != rhs->dtype) {
    return false;
  }
  return Equal(lhs->value, rhs->value);
}

// Calls are equal when `other` is also a CallNode with the same name, dtype
// and call type, and Equal() argument lists. The scalar fields are checked
// before the arguments.
bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<CallNode>();
  if (rhs == nullptr) {
    return false;
  }
  if (!(lhs->name == rhs->name)) return false;
  if (!(lhs->dtype == rhs->dtype)) return false;
  if (!(lhs->call_type == rhs->call_type)) return false;
  return Equal(lhs->args, rhs->args);
}

// Selects are equal when `other` is also a SelectNode and the condition, the
// true branch and the false branch are each Equal(), checked in that order.
bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
  const auto* rhs = other.as<SelectNode>();
  if (rhs == nullptr) {
    return false;
  }
  if (!Equal(lhs->condition, rhs->condition)) return false;
  if (!Equal(lhs->true_value, rhs->true_value)) return false;
  return Equal(lhs->false_value, rhs->false_value);
}

// Hash Handler.
// Fallback hash for node types without a dedicated overload.
//
// Attribute nodes (BaseAttrsNode) are hashed by content, with this handler
// plugged into the AttrsHash functor so nested values come back through it.
// Any other object is hashed with ObjectHash on its reference.
size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
  if (!value->IsInstance<BaseAttrsNode>()) {
    return ObjectHash()(GetRef<ObjectRef>(value));
  }
  AttrsHash hasher;
  hasher.handler_ = this;
  const auto* attrs = static_cast<const BaseAttrsNode*>(value);
  return attrs->ContentHash(hasher);
}

// Hash an integer immediate by its value using std::hash<int64_t>.
size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
  std::hash<int64_t> int_hasher;
  return int_hasher(op->value);
}

// Hash a floating-point immediate by its value using std::hash<double>.
size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
  std::hash<double> float_hasher;
  return float_hasher(op->value);
}

// Hash a string immediate by its content using std::hash<std::string>.
size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
  std::hash<std::string> string_hasher;
  return string_hasher(op->value);
}

// Hash an array: start from the element count, then fold in the hash of each
// element in index order.
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
  size_t acc = op->data.size();
  for (const auto& element : op->data) {
    acc = Combine(acc, this->Hash(element));
  }
  return acc;
}

size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
    using Entry = std::pair<std::string, ObjectRef>;
    std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
    std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
        return a.first < b.first;
      });
    size_t result = 0;
    for (const Entry& kv : data) {
      result = Combine(result, std::hash<std::string>()(kv.first));
      result = Combine(result, this->Hash(kv.second));
    }
    return result;
}


/*!
 * \brief Define AttrsHashHandler::VisitAttr_ for a binary node type.
 *
 * The generated overload combines a hash of NodeName::_type_key (computed
 * once and cached in a function-local static) with the combined hashes of
 * operands `a` and `b`.
 */
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName)                           \
  size_t AttrsHashHandler::VisitAttr_(const NodeName* op) {             \
    static size_t type_hash =                                           \
        std::hash<std::string>()(NodeName::_type_key);                  \
    const size_t operands_hash = Combine(Hash(op->a), Hash(op->b));     \
    return Combine(type_hash, operands_hash);                           \
  }

// Instantiate the two-operand hash overload for each binary node type:
// arithmetic, min/max, comparisons and logical and/or.
TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);

// Hash a logical negation: the cached hash of the node's type key combined
// with the hash of its single operand.
size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
  static size_t type_hash = std::hash<std::string>()(NotNode::_type_key);
  const size_t operand_hash = Hash(op->a);
  return Combine(type_hash, operand_hash);
}

size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
  static size_t key = std::hash<std::string>()(CastNode::_type_key);
  AttrsHash hasher;
  size_t res = key;
  res = Combine(res, hasher(op->dtype));
  res = Combine(res, Hash(op->value));
  return res;
}

size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
  static size_t key = std::hash<std::string>()(CallNode::_type_key);
  AttrsHash hasher;
  size_t res = key;
  res = Combine(res, hasher(op->name));
  res = Combine(res, hasher(op->dtype));
  res = Combine(res, Hash(op->args));
  return res;
}

// Hash a select: the cached type-key hash, then the condition, the true
// branch and the false branch, in that order.
size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
  static size_t type_hash = std::hash<std::string>()(SelectNode::_type_key);
  const size_t with_cond = Combine(type_hash, Hash(op->condition));
  const size_t with_true = Combine(with_cond, Hash(op->true_value));
  return Combine(with_true, Hash(op->false_value));
}


// Functor entry points: use a fresh default handler when none is attached.
// Compare two object references for attribute equality.
//
// Identical references are equal right away. Otherwise the comparison is
// delegated to the attached handler, or to a temporary AttrsEqualHandler
// when no handler has been attached.
bool AttrsEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
  if (lhs.same_as(rhs)) {
    return true;
  }
  if (handler_ != nullptr) {
    return handler_->Equal(lhs, rhs);
  }
  AttrsEqualHandler fallback;
  return fallback.Equal(lhs, rhs);
}

// Hash an object reference.
//
// An undefined reference hashes to 0. Otherwise hashing is delegated to the
// attached handler, or to a temporary AttrsHashHandler when no handler has
// been attached.
size_t AttrsHash::operator()(const ObjectRef& node) const {
  if (!node.defined()) {
    return 0;
  }
  if (handler_ != nullptr) {
    return handler_->Hash(node);
  }
  AttrsHashHandler fallback;
  return fallback.Hash(node);
}

// The content hash of a dictionary attribute node is the hash of its
// dictionary, computed by the supplied functor.
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
  const size_t dict_hash = hasher(this->dict);
  return dict_hash;
}

bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
  if (this == other) return true;
  if (other == nullptr) return false;
  if (this->type_index() != other->type_index()) return false;
  return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}

// Global function "ir.AttrsListFieldInfo": converts the first argument to
// Attrs and returns the result of its ListFieldInfo().
TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  const Attrs attrs = args[0].operator Attrs();
  *ret = attrs->ListFieldInfo();
});

}  // namespace tvm