/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file relay/backend/compile_engine.h
 * \brief Internal compilation engine that handles the function cache
 *  and the interface to low level code generation.
 */
#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_

#include <tvm/lowered_func.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/pass.h>
#include <string>
#include <functional>

namespace tvm {
namespace relay {

/*!
 * \brief Node container to represent a cached function.
 *
 * Stores the target, the function name, the input/output tensors and
 * the lowered functions of one cached function. Every field is
 * reported to attribute visitors through VisitAttrs.
 */
struct CachedFuncNode : public Node {
  /*! \brief compiled target */
  tvm::Target target;
  /*! \brief Function name */
  std::string func_name;
  /*! \brief The inputs to the function */
  tvm::Array<Tensor> inputs;
  /*! \brief The outputs to the function */
  tvm::Array<Tensor> outputs;
  /*! \brief The lowered functions to support the function. */
  tvm::Array<tvm::LoweredFunc> funcs;

  /*!
   * \brief Report every field of this node to an attribute visitor.
   * \param v The visitor that receives each field together with its name.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("target", &target);
    v->Visit("func_name", &func_name);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("funcs", &funcs);
  }

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "relay.CachedFunc";
  TVM_DECLARE_NODE_TYPE_INFO(CachedFuncNode, Node);
};

/*!
 * \brief Declares CachedFunc, the reference type paired with
 *        CachedFuncNode, through the TVM_DEFINE_NODE_REF macro.
 */
TVM_DEFINE_NODE_REF(CachedFunc, CachedFuncNode);


// Forward declaration: CCacheKeyNode::make returns a CCacheKey, which is
// only defined after CCacheKeyNode.
class CCacheKey;
/*!
 * \brief Compile cache key
 *
 * A key is the pair (source function, target). Hash() and Equal()
 * provide the content-based hashing and comparison used by CCacheKey.
 */
class CCacheKeyNode : public Node {
 public:
  /*! \brief The source function to be lowered. */
  Function source_func;
  /*! \brief The hardware target.*/
  Target target;

  /*!
   * \brief Report the fields of this node to an attribute visitor.
   * \param v The visitor that receives each field together with its name.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("source_func", &source_func);
    v->Visit("target", &target);
  }
  /*!
   * \brief Get the hash of this key.
   *
   * The value is computed on first use and cached in hash_; it is never 0.
   * \return The hash value of CCacheKey.
   */
  inline size_t Hash() const;
  /*!
   * \brief check content equality
   *
   * Two keys are equal when their hashes match, their target strings
   * match and their source functions are alpha-equal.
   * \param other The other value. It is dereferenced without a null check.
   * \return The result of equality check.
   */
  inline bool Equal(const CCacheKeyNode* other) const;
  /*!
   * \brief create a cache key.
   * \param source_func The source function.
   * \param target The target device.
   * \return the created key.
   */
  TVM_DLL static CCacheKey make(Function source_func,
                                Target target);

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "relay.CCacheKey";
  TVM_DECLARE_NODE_TYPE_INFO(CCacheKeyNode, tvm::Node);

 private:
  /*!
   * \brief internal cached hash value.
   *
   * 0 means "not computed yet"; Hash() fills it in lazily, which is why
   * the member is mutable.
   */
  mutable size_t hash_{0};
};

/*! \brief cache entry used in compile engine */
class CCacheKey : public NodeRef {
 public:
  CCacheKey() {}
  explicit CCacheKey(NodePtr<Node> n) : NodeRef(n) {}
  const CCacheKeyNode* operator->() const {
    return static_cast<CCacheKeyNode*>(node_.get());
  }
  // comparator
  inline bool operator==(const CCacheKey& other) const {
    CHECK(defined() && other.defined());
    return (*this)->Equal(other.operator->());
  }
  using ContainerType = CCacheKeyNode;
};

/*!
 * \brief Node container for compile cache.
 *
 * Bundles the cached function, the JIT-generated packed function and a
 * usage counter for one cache entry.
 */
class CCacheValueNode : public Node {
 public:
  /*! \brief The corresponding function */
  CachedFunc cached_func;
  /*! \brief Result of Packed function generated by JIT */
  PackedFunc packed_func;
  /*! \brief usage statistics */
  int use_count{0};

  /*!
   * \brief Report fields of this node to an attribute visitor.
   *
   * Only cached_func and use_count are visited; packed_func is not.
   * \param v The visitor that receives each field together with its name.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("cached_func", &cached_func);
    v->Visit("use_count", &use_count);
  }
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "relay.CCacheValue";
  TVM_DECLARE_NODE_TYPE_INFO(CCacheValueNode, tvm::Node);
};

/*!
 * \brief Reference class for CCacheValueNode, the cache entry used in
 *        the compile engine.
 *
 * Provides both mutable and read-only access to the underlying node.
 */
class CCacheValue : public NodeRef {
 public:
  /*! \brief The node type this reference points to. */
  using ContainerType = CCacheValueNode;

  /*! \brief Default constructor; the base NodeRef is default-constructed. */
  CCacheValue() {}
  /*!
   * \brief Construct a value reference from a node pointer.
   * \param n The node pointer handed to the NodeRef base.
   */
  explicit CCacheValue(NodePtr<Node> n) : NodeRef(n) {}
  /*! \return Read-only pointer to the underlying CCacheValueNode. */
  const CCacheValueNode* operator->() const {
    const CCacheValueNode* entry =
        static_cast<const CCacheValueNode*>(node_.get());
    return entry;
  }
  /*! \return Mutable pointer to the underlying CCacheValueNode. */
  CCacheValueNode* operator->() {
    CCacheValueNode* entry = static_cast<CCacheValueNode*>(node_.get());
    return entry;
  }
};

/*!
 * \brief Backend compilation engine for
 *        low level code generation.
 *
 * Abstract interface: Lower, JIT and Clear are pure virtual and have to
 * be supplied by a concrete engine. Cached functions are addressed by
 * CCacheKey.
 */
class CompileEngineNode : public Node {
 public:
  /*!
   * \brief Get lowered result.
   * \param key The key to the cached function.
   * \return The result.
   */
  virtual CachedFunc Lower(const CCacheKey& key) = 0;
  /*!
   * \brief Just in time compile to get a PackedFunc.
   * \param key The key to the cached function.
   * \return The result.
   */
  virtual PackedFunc JIT(const CCacheKey& key) = 0;
  /*! \brief clear the cache. */
  virtual void Clear() = 0;

  // VisitAttrs
  // Intentionally empty: this node reports no fields to attribute visitors.
  void VisitAttrs(AttrVisitor*) final {}

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "relay.CompileEngine";
  TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
};

/*!
 * \brief Reference class for CompileEngineNode, the backend compile engine.
 *
 * NOTE(review): operator-> is only available on non-const references,
 * while Global() returns a const reference; callers presumably copy the
 * returned reference before calling through it - confirm against usage.
 */
class CompileEngine : public NodeRef {
 public:
  /*! \brief Default constructor; the base NodeRef is default-constructed. */
  CompileEngine() {}
  /*!
   * \brief Construct an engine reference from a node pointer.
   * \param n The node pointer handed to the NodeRef base.
   */
  explicit CompileEngine(NodePtr<Node> n) : NodeRef(n) {}
  /*! \return Mutable pointer to the underlying CompileEngineNode. */
  CompileEngineNode* operator->() {
    return static_cast<CompileEngineNode*>(node_.get());
  }
  /*! \brief The node type this reference points to. */
  using ContainerType = CompileEngineNode;
  /*! \brief The global compile engine. */
  TVM_DLL static const CompileEngine& Global();
};

// implementations
/*!
 * \brief Compute the hash of this key, or return the cached one.
 *
 * The hash combines the structural hash of source_func with the hash of
 * the target string. The result is cached in hash_, where 0 is reserved
 * for "not computed yet", so a computed hash of 0 is remapped to 1.
 *
 * The value is assembled in a local variable and written to hash_ only
 * once it is complete. Writing the intermediate structural hash into
 * hash_ first would leave a non-zero partial value cached if computing
 * the target string failed part-way, and later calls would then return
 * that partial value as the final hash.
 *
 * \return The non-zero hash value of this key.
 */
inline size_t CCacheKeyNode::Hash() const {
  if (hash_ != 0) return hash_;
  // do structural hash, avoid 0.
  size_t value = StructuralHash()(this->source_func);
  value = dmlc::HashCombine(
      value, std::hash<std::string>()(target->str()));
  if (value == 0) value = 1;
  // Publish the finished value in a single store.
  hash_ = value;
  return value;
}

/*!
 * \brief Content equality of two cache keys.
 *
 * The checks run from cheapest to most expensive and stop at the first
 * mismatch: hash values, then target strings, then alpha-equality of
 * the source functions.
 *
 * \param other The key node to compare against; it is dereferenced
 *        without a null check.
 * \return true when all three checks agree.
 */
inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
  // Keys whose hashes differ are reported as unequal right away.
  if (Hash() != other->Hash()) return false;
  // Targets are compared through their string form.
  if (this->target->str() != other->target->str()) return false;
  // Only now pay for the comparison of the functions themselves.
  return AlphaEqual(this->source_func, other->source_func);
}

}  // namespace relay
}  // namespace tvm

namespace std {
/*!
 * \brief std::hash specialization for CCacheKey.
 *
 * Forwards to CCacheKeyNode::Hash() of the node behind the key. The key
 * must be defined; this is enforced with CHECK.
 */
template<>
struct hash<::tvm::relay::CCacheKey> {
  /*!
   * \param key The cache key to hash.
   * \return The hash value reported by the key's node.
   */
  size_t operator()(const ::tvm::relay::CCacheKey& key) const {
    CHECK(key.defined());
    const ::tvm::relay::CCacheKeyNode* key_node = key.operator->();
    return key_node->Hash();
  }
};
}  // namespace std
#endif  // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_