/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file attrs.cc
 */
#include <tvm/attrs.h>
#include <tvm/api_registry.h>
#include "attr_functor.h"

namespace tvm {

/*!
 * \brief Present this node's fields to a visitor.
 *
 * The only field visited is the `dict` member, reported under the
 * name "__dict__".
 *
 * \param v The visitor that receives the field.
 */
void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
  auto* dict_field = &dict;
  v->Visit("__dict__", dict_field);
}

/*!
 * \brief Present the non-default fields of this node to a visitor.
 *
 * This implementation always visits the `dict` member under the name
 * "__dict__"; no filtering is applied.
 *
 * \param v The visitor that receives the field.
 */
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
  auto* dict_field = &dict;
  v->Visit("__dict__", dict_field);
}

/*!
 * \brief Fill `dict` from a flat list of (key, value) packed arguments.
 *
 * Arguments are consumed pairwise: args[i] is read as the string key and
 * args[i + 1] as the value. Node handles are stored as-is, strings are
 * wrapped into an Expr, and every other type code is converted to Expr.
 *
 * \param args Packed arguments laid out as key0, value0, key1, value1, ...
 * \param allow_unknown Not consulted by this implementation.
 */
void DictAttrsNode::InitByPackedArgs(
    const runtime::TVMArgs& args, bool allow_unknown) {
  for (int pos = 0; pos < args.size(); pos += 2) {
    std::string key = args[pos];
    runtime::TVMArgValue val = args[pos + 1];
    switch (val.type_code()) {
      case kNodeHandle: {
        dict.Set(key, val.operator NodeRef());
        break;
      }
      case kStr: {
        dict.Set(key, Expr(val.operator std::string()));
        break;
      }
      default: {
        dict.Set(key, val.operator Expr());
        break;
      }
    }
  }
}

/*!
 * \brief List the declared fields of this attrs node.
 * \return Always a default-constructed (empty) array.
 */
Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
  return Array<AttrFieldInfo>();
}

/*!
 * \brief Build an Attrs reference backed by a new DictAttrsNode.
 * \param dict The key/value map; it is moved into the new node.
 * \return Attrs wrapping the freshly created node.
 */
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
  auto node = make_node<DictAttrsNode>();
  node->dict = std::move(dict);
  return Attrs(node);
}

// IRPrinter dispatch entry for DictAttrsNode: streams the node's `dict`
// member into the printer's output stream.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const DictAttrsNode *op, IRPrinter *p) {
    p->stream << op->dict;
});

// Node type registration for the dictionary-backed attrs node.
TVM_REGISTER_NODE_TYPE(DictAttrsNode);

// Node type registration for the per-field info node.
// NOTE(review): what registration enables (e.g. reflection/serialization)
// is defined by the macro elsewhere -- confirm against its definition.
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);


using namespace ir;
// Equal handler.
/*!
 * \brief Entry point for comparing two node references.
 *
 * References to the same node are equal. Otherwise, if either side is
 * undefined the result is false; when both are defined the comparison is
 * delegated to VisitAttr.
 *
 * \param lhs Left-hand side reference.
 * \param rhs Right-hand side reference.
 * \return Whether the two references compare equal.
 */
bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) {
  if (lhs.same_as(rhs)) {
    return true;
  }
  const bool both_defined = lhs.defined() && rhs.defined();
  return both_defined && this->VisitAttr(lhs, rhs);
}

bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) {
  if (lhs->derived_from<BaseAttrsNode>()) {
    AttrsEqual equal;
    equal.handler_ = this;
    return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
        other.get(), equal);
  }
  return lhs == other.get();
}

/*!
 * \brief Compare an IntImm against another node.
 * \return true only if `other` is also an IntImm holding the same value.
 */
bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<IntImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}

/*!
 * \brief Compare a UIntImm against another node.
 * \return true only if `other` is also a UIntImm holding the same value.
 */
bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<UIntImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}

/*!
 * \brief Compare a FloatImm against another node.
 * \return true only if `other` is also a FloatImm whose value compares
 *         equal with operator==.
 */
bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<FloatImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}

/*!
 * \brief Compare a StringImm against another node.
 * \return true only if `other` is also a StringImm holding the same value.
 */
bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<StringImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}

/*!
 * \brief Compare an array node against another node, element by element.
 *
 * \param lhs The array on the left-hand side.
 * \param other The right-hand side reference.
 * \return true only if `other` is also an ArrayNode of the same length and
 *         every pair of elements at the same index compares equal via
 *         Equal(). A non-array `other` is never equal.
 */
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) {
  const auto* rhs = other.as<ArrayNode>();
  // A type mismatch must be reported as "not equal"; previously this case
  // fell through to the trailing `return true`.
  if (rhs == nullptr) return false;
  if (rhs->data.size() != lhs->data.size()) return false;
  for (size_t i = 0; i < lhs->data.size(); ++i) {
    if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
  }
  return true;
}

/*!
 * \brief Compare a string-keyed map node against another node.
 *
 * \param lhs The map on the left-hand side.
 * \param other The right-hand side reference.
 * \return true only if `other` is also a StrMapNode with the same number
 *         of entries, every key of `lhs` is present in it, and the values
 *         stored under each key compare equal via Equal(). A non-map
 *         `other` is never equal.
 */
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) {
  const auto* rhs = other.as<StrMapNode>();
  // A type mismatch must be reported as "not equal"; previously this case
  // fell through to the trailing `return true`.
  if (rhs == nullptr) return false;
  if (rhs->data.size() != lhs->data.size()) return false;
  for (const auto& kv : lhs->data) {
    auto it = rhs->data.find(kv.first);
    if (it == rhs->data.end()) return false;
    if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false;
  }
  return true;
}

#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                          \
  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \
    if (const auto* rhs = other.as<NodeName>()) {                       \
      if (!Equal(lhs->a, rhs->a)) return false;                         \
      if (!Equal(lhs->b, rhs->b)) return false;                         \
      return true;                                                      \
    } else {                                                            \
      return false;                                                     \
    }                                                                   \
  }                                                                     \

TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);

/*!
 * \brief Compare a Not node against another node.
 * \return true only if `other` is also a Not whose operand compares equal.
 */
bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Not>();
  return rhs != nullptr && Equal(lhs->a, rhs->a);
}

/*!
 * \brief Compare a Cast node against another node.
 * \return true only if `other` is also a Cast with the same target type
 *         and an equal operand value.
 */
bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Cast>();
  if (rhs == nullptr) {
    return false;
  }
  // The type is checked first so the operand is not compared when the
  // types already differ.
  if (lhs->type != rhs->type) {
    return false;
  }
  return Equal(lhs->value, rhs->value);
}

/*!
 * \brief Compare a Call node against another node.
 * \return true only if `other` is also a Call with the same name, type
 *         and call_type, and its argument list compares equal.
 */
bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Call>();
  if (rhs == nullptr) {
    return false;
  }
  // Scalar fields are checked before the argument list is compared.
  const bool same_signature =
      lhs->name == rhs->name &&
      lhs->type == rhs->type &&
      lhs->call_type == rhs->call_type;
  return same_signature && Equal(lhs->args, rhs->args);
}

/*!
 * \brief Compare a Select node against another node.
 * \return true only if `other` is also a Select whose condition,
 *         true_value and false_value all compare equal (checked in that
 *         order, stopping at the first mismatch).
 */
bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Select>();
  if (rhs == nullptr) {
    return false;
  }
  if (!Equal(lhs->condition, rhs->condition)) {
    return false;
  }
  if (!Equal(lhs->true_value, rhs->true_value)) {
    return false;
  }
  return Equal(lhs->false_value, rhs->false_value);
}

// Hash Handler.
size_t AttrsHashHandler::VisitAttrDefault_(const Node* value) {
  if (value->derived_from<BaseAttrsNode>()) {
    AttrsHash hasher;
    hasher.handler_ = this;
    return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
  } else {
    return NodeHash()(GetRef<NodeRef>(value));
  }
}

/*! \brief Hash an IntImm by applying std::hash<int64_t> to its value. */
size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
  std::hash<int64_t> hash_value;
  return hash_value(op->value);
}

/*! \brief Hash a UIntImm by applying std::hash<uint64_t> to its value. */
size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
  std::hash<uint64_t> hash_value;
  return hash_value(op->value);
}

/*! \brief Hash a FloatImm by applying std::hash<double> to its value. */
size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
  std::hash<double> hash_value;
  return hash_value(op->value);
}

/*! \brief Hash a StringImm by applying std::hash<std::string> to its value. */
size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
  std::hash<std::string> hash_value;
  return hash_value(op->value);
}

/*!
 * \brief Hash an array node.
 *
 * The hash is seeded with the element count and then combined with the
 * hash of each element in index order.
 *
 * \param op The array to hash.
 * \return The combined hash value.
 */
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
  size_t result = op->data.size();
  for (const auto& element : op->data) {
    result = Combine(result, this->Hash(NodeRef(element)));
  }
  return result;
}

size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
    using Entry = std::pair<std::string, NodePtr<Node> >;
    std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
    std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
        return a.first < b.first;
      });
    size_t result = 0;
    for (const Entry& kv : data) {
      result = Combine(result, std::hash<std::string>()(kv.first));
      result = Combine(result, this->Hash(NodeRef(kv.second)));
    }
    return result;
}


#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName)                           \
  size_t AttrsHashHandler::VisitAttr_(const NodeName* op) {             \
    static size_t key = std::hash<std::string>()(NodeName::_type_key);  \
    return Combine(key, Combine(Hash(op->a), Hash(op->b)));             \
  }                                                                     \

TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
TVM_DEFINE_ATTRS_BINOP_HASH(GT);
TVM_DEFINE_ATTRS_BINOP_HASH(LE);
TVM_DEFINE_ATTRS_BINOP_HASH(LT);
TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
TVM_DEFINE_ATTRS_BINOP_HASH(NE);
TVM_DEFINE_ATTRS_BINOP_HASH(And);
TVM_DEFINE_ATTRS_BINOP_HASH(Or);

/*!
 * \brief Hash a Not node.
 * \return The per-type key (hash of Not::_type_key, computed once)
 *         combined with the hash of the operand.
 */
size_t AttrsHashHandler::VisitAttr_(const Not* op) {
  static size_t key = std::hash<std::string>()(Not::_type_key);
  const size_t operand_hash = Hash(op->a);
  return Combine(key, operand_hash);
}

/*!
 * \brief Hash a Cast node.
 * \return The per-type key (hash of Cast::_type_key, computed once)
 *         combined first with the hash of the target type, then with the
 *         hash of the operand value.
 */
size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
  static size_t key = std::hash<std::string>()(Cast::_type_key);
  AttrsHash hasher;
  const size_t type_hash = hasher(op->type);
  const size_t with_type = Combine(key, type_hash);
  const size_t value_hash = Hash(op->value);
  return Combine(with_type, value_hash);
}

/*!
 * \brief Hash a Call node.
 * \return The per-type key (hash of Call::_type_key, computed once)
 *         combined, in order, with the hashes of the name, the type and
 *         the argument list. The call_type field is not part of the hash.
 */
size_t AttrsHashHandler::VisitAttr_(const Call* op) {
  static size_t key = std::hash<std::string>()(Call::_type_key);
  AttrsHash hasher;
  const size_t name_hash = hasher(op->name);
  const size_t with_name = Combine(key, name_hash);
  const size_t type_hash = hasher(op->type);
  const size_t with_type = Combine(with_name, type_hash);
  const size_t args_hash = Hash(op->args);
  return Combine(with_type, args_hash);
}

/*!
 * \brief Hash a Select node.
 * \return The per-type key (hash of Select::_type_key, computed once)
 *         combined, in order, with the hashes of condition, true_value
 *         and false_value.
 */
size_t AttrsHashHandler::VisitAttr_(const Select* op) {
  static size_t key = std::hash<std::string>()(Select::_type_key);
  const size_t cond_hash = Hash(op->condition);
  const size_t with_cond = Combine(key, cond_hash);
  const size_t true_hash = Hash(op->true_value);
  const size_t with_true = Combine(with_cond, true_hash);
  const size_t false_hash = Hash(op->false_value);
  return Combine(with_true, false_hash);
}


// Default case
/*!
 * \brief Compare two node references for attribute equality.
 *
 * References to the same node are equal. Otherwise the comparison is
 * forwarded to the installed handler, or to a temporary
 * AttrsEqualHandler when no handler has been set.
 *
 * \param lhs Left-hand side reference.
 * \param rhs Right-hand side reference.
 * \return Whether the two compare equal.
 */
bool AttrsEqual::operator()(const NodeRef& lhs, const NodeRef& rhs) const {
  if (lhs.same_as(rhs)) {
    return true;
  }
  if (handler_ != nullptr) {
    return handler_->Equal(lhs, rhs);
  }
  return AttrsEqualHandler().Equal(lhs, rhs);
}

/*!
 * \brief Hash a node reference.
 *
 * An undefined reference hashes to 0. Otherwise hashing is forwarded to
 * the installed handler, or to a temporary AttrsHashHandler when no
 * handler has been set.
 *
 * \param node The reference to hash.
 * \return The hash value.
 */
size_t AttrsHash::operator()(const NodeRef& node) const {
  if (!node.defined()) {
    return 0;
  }
  if (handler_ != nullptr) {
    return handler_->Hash(node);
  }
  return AttrsHashHandler().Hash(node);
}

/*!
 * \brief Hash the content of this node.
 * \param hasher The functor used to hash the `dict` member.
 * \return The hash the functor produces for `dict`.
 */
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
  const auto& attrs_dict = this->dict;
  return hasher(attrs_dict);
}

bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
  if (this == other) return true;
  if (other == nullptr) return false;
  if (this->type_index() != other->type_index()) return false;
  return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}

// Registers the "_AttrsListFieldInfo" API function: converts its first
// argument to Attrs and returns the result of calling ListFieldInfo() on it.
TVM_REGISTER_API("_AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = args[0].operator Attrs()->ListFieldInfo();
});

}  // namespace tvm