/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file c_api_symbolic.cc
 * \brief C API related to symbolic graph composition.
 */
#include <nnvm/c_api.h>
#include <nnvm/op.h>
#include <nnvm/symbolic.h>
#include "c_api_common.h"

using namespace nnvm;

/*!
 * \brief Report the names of every Op in dmlc::Registry<Op>.
 *
 * The names are copied into the NNAPIThreadLocalEntry string buffer
 * (ret_vec_str), and *out_array points at a parallel array of C-string
 * pointers (ret_vec_charp). Both buffers are reused by other calls in this
 * file, so the returned pointers are only valid until such a call reuses them.
 */
int NNListAllOpNames(nn_uint *out_size,
                     const char*** out_array) {
  API_BEGIN();
  NNAPIThreadLocalEntry *entry = NNAPIThreadLocalStore::Get();
  auto &names = entry->ret_vec_str;
  names = dmlc::Registry<Op>::ListAllNames();
  auto &name_ptrs = entry->ret_vec_charp;
  name_ptrs.clear();
  name_ptrs.reserve(names.size());
  for (const auto &op_name : names) {
    name_ptrs.push_back(op_name.c_str());
  }
  *out_size = static_cast<nn_uint>(names.size());
  *out_array = dmlc::BeginPtr(name_ptrs);
  API_END();
}

int NNGetOpHandle(const char* op_name,
                  OpHandle* op_out) {
  API_BEGIN();
  *op_out = (OpHandle)Op::Get(op_name);  // NOLINT(*)
  API_END();
}

/*!
 * \brief Expose every registered Op as an array of OpHandle.
 *
 * The returned array aliases the storage of dmlc::Registry<Op>::List()
 * directly; no per-call scratch buffer is used.
 */
int NNListUniqueOps(nn_uint *out_size,
                    OpHandle **out_array) {
  API_BEGIN();
  auto &registered_ops = dmlc::Registry<Op>::List();
  OpHandle *first = (OpHandle*)(dmlc::BeginPtr(registered_ops));  //  NOLINT(*)
  *out_array = first;
  *out_size = static_cast<nn_uint>(registered_ops.size());
  API_END();
}

/*!
 * \brief Add the outputs of src_dep as control dependencies of handle.
 * \param handle the Symbol being modified.
 * \param src_dep the Symbol whose nodes become control dependencies.
 */
int NNAddControlDeps(SymbolHandle handle,
                     SymbolHandle src_dep) {
  API_BEGIN();
  Symbol *target = static_cast<Symbol*>(handle);
  const Symbol &dependency = *static_cast<Symbol*>(src_dep);
  target->AddControlDeps(dependency);
  API_END();
}

/*!
 * \brief Describe an Op: its name, description and documented arguments.
 *
 * The three per-argument string arrays (names, type infos, descriptions)
 * are laid out back to back in one ret_vec_charp buffer. *arg_names,
 * *arg_type_infos and *arg_descriptions point at its first, second and
 * third segment. When the Op has no documented arguments all three are the
 * empty buffer's BeginPtr.
 * return_type is optional; if non-null it is always set to nullptr.
 */
int NNGetOpInfo(OpHandle handle,
                const char **name,
                const char **description,
                nn_uint *num_doc_args,
                const char ***arg_names,
                const char ***arg_type_infos,
                const char ***arg_descriptions,
                const char **return_type) {
  const Op *op = static_cast<const Op *>(handle);
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();

  API_BEGIN();
  const size_t nargs = op->arguments.size();
  *name = op->name.c_str();
  *description = op->description.c_str();
  *num_doc_args = static_cast<nn_uint>(nargs);
  if (return_type != nullptr) {
    *return_type = nullptr;
  }
  // Layout: [0, n) names | [n, 2n) type infos | [2n, 3n) descriptions.
  auto &table = ret->ret_vec_charp;
  table.assign(nargs * 3, nullptr);
  for (size_t i = 0; i < nargs; ++i) {
    const auto &arg = op->arguments[i];
    table[i] = arg.name.c_str();
    table[nargs + i] = arg.type_info_str.c_str();
    table[2 * nargs + i] = arg.description.c_str();
  }
  *arg_names = dmlc::BeginPtr(table);
  *arg_type_infos = dmlc::BeginPtr(table) + nargs;
  *arg_descriptions = dmlc::BeginPtr(table) + 2 * nargs;
  API_END();
}

/*!
 * \brief Create a Symbol that applies a single Op with string parameters.
 *
 * keys[i] / vals[i] for i < num_param form the attribute map passed to
 * Symbol::CreateFunctor; if a key repeats, its first value is kept.
 * On an error caught by API_END_HANDLE_ERROR the new Symbol is deleted and
 * *out is not written.
 */
int NNSymbolCreateAtomicSymbol(OpHandle creator,
                               nn_uint num_param,
                               const char **keys,
                               const char **vals,
                               SymbolHandle *out) {
  Symbol *atomic = new Symbol();
  API_BEGIN();
  const Op* op = static_cast<const Op*>(creator);
  std::unordered_map<std::string, std::string> params;
  for (nn_uint idx = 0; idx != num_param; ++idx) {
    // emplace, like insert, leaves an existing key untouched.
    params.emplace(keys[idx], vals[idx]);
  }
  *atomic = Symbol::CreateFunctor(op, std::move(params));
  *out = atomic;
  API_END_HANDLE_ERROR(delete atomic);
}

int NNSymbolCreateVariable(const char *name, SymbolHandle *out) {
  Symbol *s = new Symbol();
  API_BEGIN();
  *s = Symbol::CreateVariable(name);
  *out = s;
  API_END_HANDLE_ERROR(delete s);
}

/*!
 * \brief Combine num_symbols Symbols into one grouped Symbol.
 *
 * Each input Symbol is copied by value into the list passed to
 * Symbol::CreateGroup; the inputs themselves are not modified.
 */
int NNSymbolCreateGroup(nn_uint num_symbols,
                        SymbolHandle *symbols,
                        SymbolHandle *out) {
  Symbol *group = new Symbol();
  Symbol **members_in = reinterpret_cast<Symbol**>(symbols);
  API_BEGIN();
  std::vector<Symbol> members;
  members.reserve(num_symbols);
  for (nn_uint k = 0; k != num_symbols; ++k) {
    members.push_back(*members_in[k]);
  }
  *group = Symbol::CreateGroup(members);
  *out = group;
  API_END_HANDLE_ERROR(delete group);
}

/*!
 * \brief Return a new Symbol that selects output `index` of `symbol`.
 * The result is freshly allocated; it is deleted if an error is caught.
 */
int NNSymbolGetOutput(SymbolHandle symbol,
                      nn_uint index,
                      SymbolHandle *out) {
  Symbol *selected = new Symbol();
  API_BEGIN();
  const Symbol &source = *static_cast<Symbol*>(symbol);
  *selected = source[index];
  *out = selected;
  API_END_HANDLE_ERROR(delete selected);
}

/*!
 * \brief Return a new Symbol built from symbol->GetInternals().
 * The result is freshly allocated; it is deleted if an error is caught.
 */
int NNSymbolGetInternals(SymbolHandle symbol,
                         SymbolHandle *out) {
  Symbol *internals = new Symbol();
  API_BEGIN();
  const Symbol *source = static_cast<Symbol*>(symbol);
  *internals = source->GetInternals();
  *out = internals;
  API_END_HANDLE_ERROR(delete internals);
}

/*!
 * \brief Return a new Symbol built from symbol->GetChildren().
 * The result is freshly allocated; it is deleted if an error is caught.
 */
int NNSymbolGetChildren(SymbolHandle symbol,
                        SymbolHandle *out) {
  Symbol *children = new Symbol();
  API_BEGIN();
  const Symbol *source = static_cast<Symbol*>(symbol);
  *children = source->GetChildren();
  *out = children;
  API_END_HANDLE_ERROR(delete children);
}

/*!
 * \brief Delete a Symbol previously allocated by this API.
 */
int NNSymbolFree(SymbolHandle symbol) {
  API_BEGIN();
  Symbol *doomed = static_cast<Symbol*>(symbol);
  delete doomed;
  API_END();
}

/*!
 * \brief Return a new Symbol holding symbol->Copy().
 * The result is freshly allocated; it is deleted if an error is caught.
 */
int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
  Symbol *duplicate = new Symbol();
  API_BEGIN();
  const Symbol *source = static_cast<const Symbol*>(symbol);
  *duplicate = source->Copy();
  *out = duplicate;
  API_END_HANDLE_ERROR(delete duplicate);
}

/*!
 * \brief Render `symbol` via Symbol::Print into a C string.
 *
 * The text is stored in the entry's ret_str buffer, which other calls in
 * this file also overwrite; *out_str is valid only until such a call.
 */
int NNSymbolPrint(SymbolHandle symbol, const char **out_str) {
  Symbol *s = static_cast<Symbol*>(symbol);
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  API_BEGIN();
  std::ostringstream rendered;
  s->Print(rendered);
  std::string &text = ret->ret_str;
  text = rendered.str();
  *out_str = text.c_str();
  API_END();
}

/*!
 * \brief Look up attribute `key` on `symbol`.
 *
 * If Symbol::GetAttr finds it, *out points at the value (stored in the
 * entry's ret_str buffer) and *success is 1; otherwise *out is nullptr and
 * *success is 0.
 */
int NNSymbolGetAttr(SymbolHandle symbol,
                    const char* key,
                    const char** out,
                    int* success) {
  Symbol *s = static_cast<Symbol*>(symbol);
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  API_BEGIN();
  std::string &value = ret->ret_str;
  const bool found = s->GetAttr(key, &value);
  *out = found ? value.c_str() : nullptr;
  *success = found ? 1 : 0;
  API_END();
}

/*!
 * \brief Set num_param attributes keys[i] = vals[i] on `symbol`.
 * The pairs are forwarded, in order, to Symbol::SetAttrs.
 */
int NNSymbolSetAttrs(SymbolHandle symbol,
                     nn_uint num_param,
                     const char** keys,
                     const char** vals) {
  Symbol *s = static_cast<Symbol*>(symbol);
  API_BEGIN();
  std::vector<std::pair<std::string, std::string> > pairs;
  pairs.reserve(num_param);
  for (nn_uint idx = 0; idx != num_param; ++idx) {
    pairs.emplace_back(keys[idx], vals[idx]);
  }
  s->SetAttrs(pairs);
  API_END();
}

/*!
 * \brief List the attributes of `symbol` as a flat key/value array.
 *
 * \param option cast to Symbol::ListAttrOption and forwarded to ListAttrs.
 * \param out_size receives the number of key/value PAIRS; *out holds
 *        2 * (*out_size) strings laid out as key0, val0, key1, val1, ...
 * \param out receives pointers into the entry's ret_vec_str / ret_vec_charp
 *        buffers, which other calls in this file also overwrite.
 */
int NNSymbolListAttrs(SymbolHandle symbol,
                      int option,
                      nn_uint *out_size,
                      const char*** out) {
  Symbol *s = static_cast<Symbol*>(symbol);
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  API_BEGIN();
  std::unordered_map<std::string, std::string> attr =
      s->ListAttrs(static_cast<Symbol::ListAttrOption>(option));  // NOLINT(*)

  std::vector<std::string>& attr_list = ret->ret_vec_str;
  attr_list.clear();
  // Every attribute contributes two strings (key, then value).
  attr_list.reserve(attr.size() * 2);
  for (const auto& kv : attr) {
    attr_list.push_back(kv.first);
    attr_list.push_back(kv.second);
  }
  ret->ret_vec_charp.clear();
  ret->ret_vec_charp.reserve(attr_list.size());
  for (const std::string& str : attr_list) {
    ret->ret_vec_charp.push_back(str.c_str());
  }
  // Count pairs, not strings; cast explicitly like the sibling APIs.
  *out_size = static_cast<nn_uint>(attr.size());
  *out = dmlc::BeginPtr(ret->ret_vec_charp);
  API_END();
}

/*!
 * \brief Return one new single-output Symbol per input node of `symbol`.
 *
 * \param option cast to Symbol::ListInputOption and forwarded to ListInputs.
 * Each returned handle is a newly allocated Symbol whose single output is
 * NodeEntry{node, 0, 0}. Presumably the caller releases each one with
 * NNSymbolFree (TODO confirm against the frontends). The handle array
 * itself lives in the entry's ret_handles buffer.
 */
int NNSymbolListInputVariables(SymbolHandle symbol,
                               int option,
                               nn_uint *out_size,
                               SymbolHandle** out_sym_array) {
  Symbol *s = static_cast<Symbol*>(symbol);
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  API_BEGIN();
  const std::vector<NodePtr> inputs =
      s->ListInputs(static_cast<Symbol::ListInputOption>(option));
  auto &handles = ret->ret_handles;
  handles.clear();
  handles.reserve(inputs.size());
  for (const NodePtr &node : inputs) {
    nnvm::Symbol* wrapper = new nnvm::Symbol();
    wrapper->outputs.push_back(NodeEntry{node, 0, 0});
    handles.push_back(wrapper);
  }
  *out_size = static_cast<nn_uint>(inputs.size());
  *out_sym_array = dmlc::BeginPtr(handles);
  API_END();
}

/*!
 * \brief List the input names of `symbol` as C strings.
 *
 * \param option cast to Symbol::ListInputOption and forwarded to
 *        ListInputNames.
 * The strings live in the entry's ret_vec_str / ret_vec_charp buffers and
 * are valid until another call overwrites them.
 */
int NNSymbolListInputNames(SymbolHandle symbol,
                           int option,
                           nn_uint *out_size,
                           const char ***out_str_array) {
  Symbol *s = static_cast<Symbol*>(symbol);
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  API_BEGIN();
  auto &names = ret->ret_vec_str;
  names = s->ListInputNames(static_cast<Symbol::ListInputOption>(option));
  auto &name_ptrs = ret->ret_vec_charp;
  name_ptrs.clear();
  name_ptrs.reserve(names.size());
  for (const auto &input_name : names) {
    name_ptrs.push_back(input_name.c_str());
  }
  *out_size = static_cast<nn_uint>(name_ptrs.size());
  *out_str_array = dmlc::BeginPtr(name_ptrs);
  API_END();
}

/*!
 * \brief List the output names of `symbol` as C strings.
 *
 * The strings live in the entry's ret_vec_str / ret_vec_charp buffers and
 * are valid until another call overwrites them.
 */
int NNSymbolListOutputNames(SymbolHandle symbol,
                            nn_uint *out_size,
                            const char ***out_str_array) {
  Symbol *s = static_cast<Symbol*>(symbol);
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  API_BEGIN();
  auto &names = ret->ret_vec_str;
  names = s->ListOutputNames();
  auto &name_ptrs = ret->ret_vec_charp;
  name_ptrs.clear();
  name_ptrs.reserve(names.size());
  for (const auto &output_name : names) {
    name_ptrs.push_back(output_name.c_str());
  }
  *out_size = static_cast<nn_uint>(name_ptrs.size());
  *out_str_array = dmlc::BeginPtr(name_ptrs);
  API_END();
}

/*!
 * \brief Report how many entries are in symbol->outputs.
 */
int NNSymbolGetNumOutputs(SymbolHandle symbol,
                           nn_uint *output_count) {
  const Symbol *s = static_cast<const Symbol*>(symbol);
  API_BEGIN();
  const size_t count = s->outputs.size();
  *output_count = static_cast<nn_uint>(count);
  API_END();
}

/*!
 * \brief Compose `sym` in place with the given input Symbols.
 *
 * If keys is nullptr and num_args is nonzero, args[0..num_args) are passed
 * as positional inputs. Otherwise each args[i] is bound to keyword keys[i].
 * name may be nullptr, in which case an empty name is used.
 *
 * The name and keyword map are staged in the entry's ret_str and
 * kwarg_symbol buffers. Overwriting ret_str invalidates any string
 * previously returned through it (e.g. by NNSymbolPrint or NNSymbolGetAttr).
 */
int NNSymbolCompose(SymbolHandle sym,
                    const char *name,
                    nn_uint num_args,
                    const char** keys,
                    SymbolHandle* args) {
  API_BEGIN();
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  std::string& s_name = ret->ret_str;
  std::unordered_map<std::string, const Symbol*>& kwargs
      = ret->kwarg_symbol;
  // Drop keyword bindings left over from a previous call; this also
  // guarantees an empty kwargs for the positional branch below.
  kwargs.clear();
  if (name != nullptr) {
    s_name = name;
  } else {
    s_name.clear();
  }
  Symbol* s = static_cast<Symbol*>(sym);
  if (keys == nullptr && num_args != 0) {
    Symbol** positional = reinterpret_cast<Symbol**>(args);
    array_view<const Symbol*> parg(positional, positional + num_args);
    s->Compose(parg, kwargs, s_name);
  } else {
    for (nn_uint i = 0; i < num_args; ++i) {
      kwargs[keys[i]] = static_cast<const Symbol*>(args[i]);
    }
    s->Compose(array_view<const Symbol*>(), kwargs, s_name);
  }
  API_END();
}