/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rpc_session.h
 * \brief Base RPC session interface.
 */
#ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
#define TVM_RUNTIME_RPC_RPC_SESSION_H_

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <mutex>
#include <string>
#include <memory>
#include <utility>
#include "../../support/ring_buffer.h"

namespace tvm {
namespace runtime {

/*! \brief Magic header for the RPC data plane. */
constexpr int kRPCMagic = 0xff271;
/*! \brief Magic header for the RPC tracker (control plane). */
constexpr int kRPCTrackerMagic = 0x2f271;
/*! \brief Response code: success. */
constexpr int kRPCSuccess = kRPCMagic + 0;
/*! \brief Response code: no matching key was found on the server. */
constexpr int kRPCMismatch = kRPCMagic + 2;

/*! \brief Enumeration code for the RPC tracker. */
enum class TrackerCode : int {
  kFail = -1,
  kSuccess = 0,
  kPing = 1,
  kStop = 2,
  kPut = 3,
  kRequest = 4,
  kUpdateInfo = 5,
  kSummary = 6,
  kGetPendingMatchKeys = 7,
};

/*! \brief Opaque handle that refers to a remote function. */
using RPCFuncHandle = void*;

/*! \brief Forward declaration of the RPC argument buffer (defined elsewhere). */
struct RPCArgBuffer;

/*!
 * \brief The RPC code.
 *
 *  The numeric values are spelled out explicitly because RPCSession::CallRemote
 *  writes the raw bytes of the code into the session writer; changing a value
 *  presumably breaks compatibility with a peer built from another version.
 */
enum class RPCCode : int {
  kNone = 0,
  kCallFunc = 1,
  kReturn = 2,
  kException = 3,
  kShutdown = 4,
  kCopyFromRemote = 5,
  kCopyToRemote = 6,
  kCopyAck = 7,
  // The following codes can be sent through CallRemote.
  kSystemFuncStart = 8,
  kGetGlobalFunc = 9,
  kGetTimeEvaluator = 10,
  kFreeFunc = 11,
  kDevSetDevice = 12,
  kDevGetAttr = 13,
  kDevAllocData = 14,
  kDevFreeData = 15,
  kDevStreamSync = 16,
  kCopyAmongRemote = 17,
  kModuleLoad = 18,
  kModuleImport = 19,
  kModuleFree = 20,
  kModuleGetFunc = 21,
  kModuleGetSource = 22,
  kNDArrayFree = 23
};

/*!
 * \brief Signature of a function that unwraps a remote object argument
 *  into its raw handle.
 * \param rpc_sess_table_index RPC session table index, used for validation.
 * \param obj The object argument to unwrap.
 * \return The raw handle corresponding to obj.
 */
using FUnwrapRemoteObject = void* (*)(int rpc_sess_table_index, const TVMArgValue& obj);

/*!
 * \brief Abstract channel interface that an RPCSession is created over.
 */
class RPCChannel {
 public:
  /*! \brief Virtual destructor so implementations can be destroyed through the base. */
  virtual ~RPCChannel() = default;
  /*!
   * \brief Send data over the channel.
   * \param data Pointer to the bytes to send.
   * \param size The number of bytes available at data.
   * \return The number of bytes actually sent.
   */
  virtual size_t Send(const void* data, size_t size) = 0;
  /*!
   * \brief Receive data from the channel.
   * \param data Pointer to the destination buffer.
   * \param size The size of the destination buffer in bytes.
   * \return The number of bytes actually received.
   */
  virtual size_t Recv(void* data, size_t size) = 0;
};

/*!
 * \brief Bidirectional communication session of PackedRPC.
 *
 *  The session owns its RPCChannel and stages traffic in internal
 *  reader/writer ring buffers.
 */
class RPCSession {
 public:
  /*!
   * \brief Destructor.
   * \note Not virtual: RPCSession declares no virtual member functions.
   */
  ~RPCSession();
  /*!
   * \brief The server loop that the server runs to handle RPC calls.
   */
  void ServerLoop();
  /*!
   * \brief Message handling function for an event driven server.
   *  Called when the server receives a message.
   *  The event driven handler never calls recv on the channel and always
   *  relies on ServerEventHandler to receive the data.
   *
   * \param in_bytes The incoming bytes.
   * \param event_flag  1: read_available, 2: write_available.
   * \return State flag.
   *     1: continue running, no need to write,
   *     2: need to write
   *     0: shutdown
   */
  int ServerEventHandler(const std::string& in_bytes,
                         int event_flag);
  /*!
   * \brief Call into a remote function.
   * \param handle The function handle.
   * \param args The arguments.
   * \param rv The return value.
   * \param funwrap Function that takes a remote object and returns the raw handle.
   * \param fwrap Wrapper function to turn a Function/Module handle into the real return value.
   */
  void CallFunc(RPCFuncHandle handle,
                TVMArgs args,
                TVMRetValue* rv,
                FUnwrapRemoteObject funwrap,
                const PackedFunc* fwrap);
  /*!
   * \brief Copy bytes from host memory into remote array content.
   * \param from The source host data.
   * \param from_offset The byte offset in from.
   * \param to The target remote array.
   * \param to_offset The byte offset in to.
   * \param nbytes The size of the memory in bytes.
   * \param ctx_to The target context.
   * \param type_hint Hint of content data type.
   */
  void CopyToRemote(void* from,
                    size_t from_offset,
                    void* to,
                    size_t to_offset,
                    size_t nbytes,
                    TVMContext ctx_to,
                    DLDataType type_hint);
  /*!
   * \brief Copy bytes from remote array content into host memory.
   * \param from The source remote array.
   * \param from_offset The byte offset in from.
   * \param to The target host buffer.
   * \param to_offset The byte offset in to.
   * \param nbytes The size of the memory in bytes.
   * \param ctx_from The source context.
   * \param type_hint Hint of content data type.
   */
  void CopyFromRemote(void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t nbytes,
                      TVMContext ctx_from,
                      DLDataType type_hint);
  /*!
   * \brief Get a remote timer function on ctx.
   *  This function consumes fhandle; the caller should not call Free on fhandle.
   *
   * \param fhandle The function handle.
   * \param ctx The ctx to run measurement on.
   * \param number The number of times to run this function for taking average.
   *  We call these runs one `repeat` of measurement.
   * \param repeat The number of times to repeat the measurement.
   *  In total, the function will be invoked (1 + number x repeat) times,
   *  where the first one is warm up and will be discarded.
   *  The returned result contains `repeat` costs,
   *  each of which is an average of `number` costs.
   * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
   *  By default, one `repeat` contains `number` runs. If this parameter is set,
   *  the parameter `number` will be dynamically adjusted to meet the
   *  minimum duration requirement of one `repeat`.
   *  i.e., when the run time of one `repeat` falls below this time,
   *  the `number` parameter will be automatically increased.
   * \return A remote timer function.
   */
  RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
                                 TVMContext ctx,
                                 int number,
                                 int repeat,
                                 int min_repeat_ms);
  /*!
   * \brief Call a remotely defined system function with arguments.
   * \param fcode The function code.
   * \param args The arguments.
   * \return The returned remote value.
   */
  template<typename... Args>
  inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args);
  /*!
   * \return The session table index of the session.
   */
  int table_index() const {
    return table_index_;
  }
  /*!
   * \brief Create an RPC session with the given channel.
   * \param channel The communication channel; ownership is transferred.
   * \param name The local name of the session, used for debugging.
   * \param remote_key The remote key of the session.
   *   If remote_key equals "%toinit", the session needs to be
   *   re-initialized by the event handler.
   * \return The created session.
   */
  static std::shared_ptr<RPCSession> Create(
      std::unique_ptr<RPCChannel> channel,
      std::string name,
      std::string remote_key);
  /*!
   * \brief Try to get a session from the global session table by table index.
   * \param table_index The table index of the session.
   * \return The shared_ptr to the session, can be nullptr.
   */
  static std::shared_ptr<RPCSession> Get(int table_index);

 private:
  class EventHandler;
  // Handle events until a return is received.
  // Also flushes channels so that the function advances.
  RPCCode HandleUntilReturnEvent(
      TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap);
  // Initialization.
  void Init();
  // Shutdown.
  void Shutdown();
  // Internal channel.
  std::unique_ptr<RPCChannel> channel_;
  // Internal mutex.
  std::recursive_mutex mutex_;
  // Internal ring buffers.
  support::RingBuffer reader_, writer_;
  // Event handler.
  std::shared_ptr<EventHandler> handler_;
  // Call remote with a specified function code.
  PackedFunc call_remote_;
  // The index of this session in the RPC session table.
  int table_index_{0};
  // The name of the session.
  std::string name_;
  // The remote key.
  std::string remote_key_;
};

/*!
 * \brief RPC channel that delegates to send & recv callbacks supplied by a
 *  frontend (Python/Java/etc.).
 */
class CallbackChannel final : public RPCChannel {
 public:
  /*!
   * \brief Construct the channel from the frontend callbacks.
   * \param fsend The frontend send callback.
   * \param frecv The frontend recv callback.
   */
  explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
      : fsend_(std::move(fsend)),
        frecv_(std::move(frecv)) {}
  /*! \brief Destructor. */
  ~CallbackChannel() = default;
  /*!
   * \brief Send data over the channel.
   * \param data Pointer to the bytes to send.
   * \param size The number of bytes available at data.
   * \return The number of bytes actually sent.
   */
  size_t Send(const void* data, size_t size) final;
  /*!
   * \brief Receive data from the channel.
   * \param data Pointer to the destination buffer.
   * \param size The size of the destination buffer in bytes.
   * \return The number of bytes actually received.
   */
  size_t Recv(void* data, size_t size) final;

 private:
  /*! \brief Frontend send callback. */
  PackedFunc fsend_;
  /*! \brief Frontend recv callback. */
  PackedFunc frecv_;
};

/*!
 * \brief Wrap a packed function into a timer that measures its time cost.
 * \param f The function to measure.
 * \param ctx The context to run the measurement on.
 * \param number The number of times to run this function for taking average.
 *  We call these runs one `repeat` of measurement.
 * \param repeat The number of times to repeat the measurement.
 *  In total, the function will be invoked (1 + number x repeat) times,
 *  where the first one is warm up and will be discarded.
 *  The returned result contains `repeat` costs,
 *  each of which is an average of `number` costs.
 * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
 *  By default, one `repeat` contains `number` runs. If this parameter is set,
 *  the parameter `number` will be dynamically adjusted to meet the
 *  minimum duration requirement of one `repeat`.
 *  i.e., when the run time of one `repeat` falls below this time,
 *  the `number` parameter will be automatically increased.
 * \return A timer function wrapping f.
 */
PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx,
                             int number, int repeat, int min_repeat_ms);

/*!
 * \brief Create the global RPC module bound to an RPC session.
 * \param sess The RPC session backing the global module.
 * \return The newly created module.
 */
Module CreateRPCModule(std::shared_ptr<RPCSession> sess);

/*!
 * \brief A pointer into the space of a remote RPC session, paired with
 *  the session it belongs to.
 */
struct RemoteSpace {
  /*!
   * \brief The raw data pointer on the remote side.
   *  Defaults to nullptr so a default-initialized RemoteSpace never holds
   *  an indeterminate pointer.
   */
  void* data{nullptr};
  /*! \brief The session the data belongs to. */
  std::shared_ptr<RPCSession> sess;
};

// Implementation of inline member functions.
template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode fcode, Args&& ...args) {
  // Hold mutex_ across both the code write and the call so the two form a
  // single critical section.
  std::lock_guard<std::recursive_mutex> guard(mutex_);
  // Prefix the request with the raw bytes of the function code.
  writer_.Write(&fcode, sizeof(fcode));
  return call_remote_(std::forward<Args>(args)...);
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_RPC_RPC_SESSION_H_