/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file graph_runtime.h
 * \brief Tiny graph runtime that can run a graph containing only TVM PackedFuncs.
 */
#ifndef TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_
#define TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_

#include <dlpack/dlpack.h>

#include "load_json.h"
#include "ndarray.h"
#include "packed_func.h"
#include "module.h"

/*!
 * \brief Operator attributes about a tvm op.
 *
 * Stored by value inside every TVMGraphRuntimeNode (field `param`) and handed
 * by pointer to the LoadAttrs and CreateTVMOp hooks.
 */
typedef struct TVMOpParam {
  /*! \brief Function name, stored inline in a fixed 120-byte buffer. */
  char func_name[120];
  /*! \brief Number of inputs of the op. */
  uint32_t num_inputs;
  /*! \brief Number of outputs of the op. */
  uint32_t num_outputs;
  /*!
   * \brief The flatten_data attribute.
   * NOTE(review): looks like a boolean flag kept in an integer - confirm
   * against the attribute loader.
   */
  uint32_t flatten_data;
} TVMOpParam;

/*!
 * \brief Memory pool entry.
 *
 * NOTE(review): no other declaration in this header refers to this type;
 * presumably it is consumed by the storage setup code - confirm.
 */
typedef struct TVMGraphRuntimePoolEntry {
  /*!
   * \brief Size of the entry.
   * NOTE(review): the unit (bytes vs. elements) is not stated here - confirm.
   */
  size_t size;
  /*! \brief Device type of the entry, held as a plain int. */
  int device_type;
} TVMGraphRuntimePoolEntry;

/*!
 * \brief Node entry.
 *
 * Used for the `inputs` of a TVMGraphRuntimeNode and for the `outputs` of the
 * TVMGraphRuntime.
 */
typedef struct TVMGraphRuntimeNodeEntry {
  /*! \brief Id of the node this entry refers to. */
  uint32_t node_id;
  /*! \brief Index within that node. NOTE(review): presumably the output index - confirm. */
  uint32_t index;
  /*! \brief Version field of the entry. */
  uint32_t version;
  /*!
   * \brief JSON loader hook.
   * \param reader The JSON reader to load from.
   *
   * NOTE(review): unlike the Load hooks of TVMGraphRuntimeNode and
   * TVMGraphRuntime, this one receives no pointer to the entry itself, so the
   * callee cannot tell from its arguments which entry to populate.
   */
  void (*Load)(JSONReader *reader);
} TVMGraphRuntimeNodeEntry;

/*!
 * \brief A node of the graph.
 *
 * All members are stored inline in fixed-size buffers; the struct holds no
 * pointers to external storage apart from the two loader hooks.
 */
typedef struct TVMGraphRuntimeNode {
  /*! \brief Operator type as a string (fixed 16-byte buffer). */
  char op_type[16];
  /*! \brief Name of the op (fixed 120-byte buffer). */
  char name[120];
  /*! \brief Operator parameters. */
  TVMOpParam param;
  /*! \brief Input entries; capacity is GRAPH_RUNTIME_NODE_MAX_INPUTS. */
  TVMGraphRuntimeNodeEntry inputs[GRAPH_RUNTIME_NODE_MAX_INPUTS];
  /*! \brief Count associated with `inputs`. */
  size_t                   inputs_count;
  /*!
   * \brief Control dependencies.
   * NOTE(review): the capacity is a literal 200 rather than a named macro, and
   * this struct has no matching count field for the array.
   */
  uint32_t control_deps[200];
  /*!
   * \brief JSON loader hook for the node attributes.
   * \param node The node being loaded.
   * \param reader The JSON reader to load from.
   * \param param The operator parameters passed alongside the node.
   */
  void (*LoadAttrs)(struct TVMGraphRuntimeNode * node, JSONReader *reader, TVMOpParam* param);
  /*!
   * \brief JSON loader hook for the node itself.
   * \param node The node being loaded.
   * \param reader The JSON reader to load from.
   * \return An int status. NOTE(review): presumably 0 on success - confirm.
   */
  int (*Load)(struct TVMGraphRuntimeNode * node, JSONReader *reader);
} TVMGraphRuntimeNode;

/*!
 * \brief Graph attributes.
 *
 * Every array has GRAPH_RUNTIME_MAX_NODES rows. Only `dltype` and `shape`
 * have an accompanying count field in this struct.
 */
typedef struct TVMGraphRuntimeGraphAttr {
  /*!
   * \brief Number of storages not allocated.
   * NOTE(review): "alloctaed" is a misspelling of "allocated"; it is part of
   * the field name and therefore left untouched here.
   */
  uint32_t storage_num_not_alloctaed;
  /*! \brief Storage id per row. */
  uint32_t storage_id[GRAPH_RUNTIME_MAX_NODES];
  /*! \brief Device index per row. */
  uint32_t device_index[GRAPH_RUNTIME_MAX_NODES];
  /*! \brief Data type name per row, in a 10-byte buffer each. */
  char     dltype[GRAPH_RUNTIME_MAX_NODES][10];  // "int8", "int16", "float32"
  /*! \brief Count associated with `dltype`. */
  uint32_t dltype_count;
  /*! \brief Shape per row, with room for TVM_CRT_MAX_NDIM dimensions each. */
  int64_t  shape[GRAPH_RUNTIME_MAX_NODES][TVM_CRT_MAX_NDIM];
  /*!
   * \brief Dimension count per row.
   * NOTE(review): presumably the number of valid leading slots in the matching
   * `shape` row - confirm.
   */
  uint32_t ndim[GRAPH_RUNTIME_MAX_NODES];
  /*! \brief Count associated with `shape`. */
  uint32_t shape_count;
} TVMGraphRuntimeGraphAttr;

/*!
 * \brief Pointer to a DLTensor.
 *
 * Note that the typedef hides one level of indirection: the `DLTensorPtr * args`
 * parameter of CreateTVMOp is therefore a `DLTensor **`.
 */
typedef DLTensor* DLTensorPtr;

/*!
 * \brief Tiny graph runtime.
 *
 *  This runtime can be accessed from various languages via the
 *  TVM runtime PackedFunc API.
 *
 *  The struct combines a table of function pointers (each taking the runtime
 *  itself as its first argument) with the graph state. All state is embedded by
 *  value in fixed-capacity arrays, each followed by a `*_count` field.
 *
 *  NOTE(review): an earlier comment here read
 *  "class GraphRuntime : public ModuleNode", so this appears to be the C
 *  counterpart of that C++ class - confirm.
 */
typedef struct TVMGraphRuntime {
  /*!
   * \brief Run hook.
   * \param runtime The graph runtime.
   */
  void (*Run)(struct TVMGraphRuntime * runtime);

  /*!
   * \brief Initialize the graph executor with graph and context.
   * \param runtime The graph runtime.
   * \param graph_json The execution graph.
   * \param module The module containing the compiled functions for the host
   *  processor.
   * \param ctxs The context of the host and devices where graph nodes will be
   *  executed on.
   */
  void (*Init)(struct TVMGraphRuntime * runtime,
               const char * graph_json,
               const TVMModule * module,
               const TVMContext * ctxs);

  /*!
   * \brief Get the input index given the name of input.
   * \param runtime The graph runtime.
   * \param name The name of the input.
   * \return The index of input.
   *  NOTE(review): the return type is signed; a negative value presumably
   *  signals "not found" - confirm.
   */
  int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name);

  /*!
   * \brief Set an input of the graph, selected by name.
   * \param runtime The graph runtime.
   * \param name The name of the input.
   * \param data_in The input data.
   */
  void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
  /*!
   * \brief Get the output for a given output index.
   * \param runtime The graph runtime.
   * \param index The output index.
   * \param out The tensor that receives the output.
   *
   * \return An int status. NOTE(review): presumably 0 on success - confirm.
   */
  int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out);
  /*!
   * \brief Load parameters from parameter blob.
   * \param runtime The graph runtime.
   * \param param_blob A binary blob of parameter.
   * \param param_size Size of the blob. NOTE(review): presumably in bytes - confirm.
   * \return An int status. NOTE(review): presumably 0 on success - confirm.
   */
  int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob,
                    const uint32_t param_size);

  /*!
   * \brief JSON loader hook for the runtime.
   * \param runtime The graph runtime.
   * \param reader The JSON reader to load from.
   * \return An int status. NOTE(review): presumably 0 on success - confirm.
   */
  int (*Load)(struct TVMGraphRuntime * runtime, JSONReader *reader);
  /*! \brief Setup the temporal storage */
  void (*SetupStorage)(struct TVMGraphRuntime * runtime);
  /*! \brief Setup the executors. */
  int (*SetupOpExecs)(struct TVMGraphRuntime * runtime);

  /*!
   * \brief Create an execution function given input.
   * \param runtime The graph runtime.
   * \param attrs The node attributes.
   * \param args The arguments to the functor, including inputs and outputs.
   * \param args_count Number of entries in args.
   * \param num_inputs Number of inputs.
   * \param pf The packed function to fill in.
   *  NOTE(review): presumably this receives the created executor - confirm.
   * \return An int32_t status. NOTE(review): presumably 0 on success - confirm.
   */
  int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs,
                         DLTensorPtr * args, const uint32_t args_count,
                         uint32_t num_inputs, TVMPackedFunc * pf);

  /*!
   * \brief Get node entry index.
   * \param runtime The graph runtime.
   * \param nid The node id.
   * \param index The index within the node.
   * \return The entry id.
   *  NOTE(review): presumably derived from node_row_ptr - confirm.
   */
  uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index);

  /*! \brief The graph nodes. */
  TVMGraphRuntimeNode nodes[GRAPH_RUNTIME_MAX_NODES];
  /*! \brief Count associated with `nodes`. */
  uint32_t           nodes_count;
  /*! \brief The argument nodes. */
  uint32_t input_nodes[GRAPH_RUNTIME_MAX_INPUT_NODES];
  /*! \brief Count associated with `input_nodes`. */
  uint32_t   input_nodes_count;
  /*! \brief Used for quick entry indexing. */
  uint32_t node_row_ptr[GRAPH_RUNTIME_MAX_NODE_ROW_PTR];
  /*! \brief Count associated with `node_row_ptr`. */
  uint32_t node_row_ptr_count;
  /*! \brief Output entries. */
  TVMGraphRuntimeNodeEntry outputs[GRAPH_RUNTIME_MAX_OUTPUTS];
  /*! \brief Count associated with `outputs`. */
  uint32_t              outputs_count;
  /*! \brief Additional graph attributes. */
  TVMGraphRuntimeGraphAttr attrs;
  /*! \brief The code module that contains both host and device code. */
  TVMModule module;
  /*! \brief Execution context of all devices including the host. */
  TVMContext ctxs[GRAPH_RUNTIME_MAX_CONTEXTS];
  /*! \brief Count associated with `ctxs`. */
  uint32_t   ctxs_count;
  /*! \brief Common storage pool for all devices. */
  TVMNDArray  storage_pool[GRAPH_RUNTIME_MAX_NODES];
  /*! \brief Count associated with `storage_pool`. */
  uint32_t storage_pool_count;
  /*! \brief Data entry of each node. */
  TVMNDArray  data_entry[GRAPH_RUNTIME_MAX_NODES];
  /*! \brief Count associated with `data_entry`. */
  uint32_t data_entry_count;
  /*! \brief Operator on each node. */
  TVMPackedFunc op_execs[GRAPH_RUNTIME_MAX_NODES];
  /*! \brief Count associated with `op_execs`. */
  uint32_t op_execs_count;
} TVMGraphRuntime;

// public functions

/*!
 * \brief Create a graph runtime.
 * \param sym_json The graph JSON string.
 *  NOTE(review): presumably what Init receives as graph_json - confirm.
 * \param m The module.
 * \param ctxs The contexts.
 * \return Pointer to the created runtime.
 *  NOTE(review): presumably to be handed back to TVMGraphRuntimeRelease;
 *  behavior on failure is not visible from this header - confirm.
 */
TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, const TVMModule * m,
                                        const TVMContext * ctxs);
/*!
 * \brief Release a graph runtime.
 * \param runtime Address of the runtime pointer.
 *  NOTE(review): the double pointer suggests the callee also resets the
 *  caller's pointer - confirm.
 */
void TVMGraphRuntimeRelease(TVMGraphRuntime ** runtime);

// private functions
//
// The four prototypes below have the same signatures as the SetInput,
// LoadParams, Run and GetOutput members of TVMGraphRuntime (GetOutput names its
// index parameter `index`, this prototype names it `idx`).
// NOTE(review): presumably these are the implementations installed into those
// members; they are labelled private yet declared in this header - confirm
// whether external callers are expected to use them directly.
void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob,
                               const uint32_t param_size);
void TVMGraphRuntime_Run(TVMGraphRuntime * runtime);
int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out);

#endif  // TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_