/*!
 *  Copyright (c) 2016 by Contributors
 * \file c_api_graph.cc
 * \brief C API related to Graph IR.
 */
#include <nnvm/c_api.h>
#include <nnvm/op.h>
#include <nnvm/symbolic.h>
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <dmlc/json.h>
#include "./c_api_common.h"

using namespace nnvm;

/*!
 * \brief Allocate a new Graph whose output entries are copied from a Symbol.
 * \param symbol handle to an existing Symbol; it is only read.
 * \param graph  receives the newly allocated Graph on success.
 * \return the status value produced by API_BEGIN()/API_END_HANDLE_ERROR()
 *         (macros defined in c_api_common.h).
 */
int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) {
  // Allocated up front so the error handler below can reclaim it.
  Graph* new_graph = new Graph();
  API_BEGIN();
  Symbol* source = static_cast<Symbol*>(symbol);
  new_graph->outputs = source->outputs;
  *graph = new_graph;
  // NOTE(review): the delete expression is presumably only evaluated on the
  // error path -- confirm against API_END_HANDLE_ERROR in c_api_common.h.
  API_END_HANDLE_ERROR(delete new_graph);
}

/*!
 * \brief Destroy a Graph handle.
 * \param handle Graph to delete; it must not be used afterwards.
 * \return the status value produced by API_BEGIN()/API_END().
 */
int NNGraphFree(GraphHandle handle) {
  API_BEGIN();
  Graph* doomed = static_cast<Graph*>(handle);
  delete doomed;
  API_END();
}

/*!
 * \brief Allocate a new Symbol whose outputs are copied from a Graph.
 * \param graph  Graph handle to read from; it is not modified.
 * \param symbol receives the newly allocated Symbol on success.
 * \return the status value produced by API_BEGIN()/API_END_HANDLE_ERROR().
 */
int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
  // Allocated before the guarded region so the error handler can free it.
  Symbol* result = new Symbol();
  API_BEGIN();
  Graph* source = static_cast<Graph*>(graph);
  result->outputs = source->outputs;
  *symbol = result;
  // NOTE(review): the delete expression is presumably only evaluated on the
  // error path -- confirm against API_END_HANDLE_ERROR in c_api_common.h.
  API_END_HANDLE_ERROR(delete result);
}

/*!
 * \brief Store a Symbol's output entry list as a Graph attribute.
 *
 * The Symbol's `outputs` vector is copied into a new shared `any` and
 * assigned to `attrs[key]`, replacing any previous value under that key.
 *
 * \param handle Graph whose attribute map is updated.
 * \param key    attribute name (NUL-terminated).
 * \param list   Symbol whose outputs become the attribute value.
 * \return the status value produced by API_BEGIN()/API_END().
 */
int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
                                 const char* key,
                                 SymbolHandle list) {
  API_BEGIN();
  Graph* target = static_cast<Graph*>(handle);
  Symbol* entries = static_cast<Symbol*>(list);
  const std::string attr_name(key);
  target->attrs[attr_name] = std::make_shared<any>(entries->outputs);
  API_END();
}

/*!
 * \brief Parse a JSON string and store the result as a Graph attribute.
 *
 * The text is decoded into an nnvm::any with dmlc::JSONReader and assigned
 * to `attrs[key]`, replacing any previous value under that key. Which value
 * types can be decoded depends on the JSON/any registrations made elsewhere
 * (not visible here).
 *
 * \param handle     Graph whose attribute map is updated.
 * \param key        attribute name (NUL-terminated).
 * \param json_value JSON text of the attribute value (NUL-terminated).
 * \return the status value produced by API_BEGIN()/API_END().
 */
int NNGraphSetJSONAttr(GraphHandle handle,
                       const char* key,
                       const char* json_value) {
  API_BEGIN();
  Graph* g = static_cast<Graph*>(handle);
  // JSONReader consumes a std::istream, so wrap the C string in one.
  std::istringstream is(json_value);
  dmlc::JSONReader reader(&is);
  nnvm::any value;
  reader.Read(&value);
  g->attrs[std::string(key)] = std::make_shared<any>(std::move(value));
  API_END();
}

int NNGraphGetJSONAttr(GraphHandle handle,
64
                      const char* key,
65
                      const char** json_out,
66
                      int *success) {
67
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
68 69 70 71 72
  API_BEGIN();
  Graph* g = static_cast<Graph*>(handle);
  std::string skey(key);
  auto it = g->attrs.find(skey);
  if (it != g->attrs.end()) {
73 74 75 76 77
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    writer.Write(*it->second.get());
    ret->ret_str = os.str();
    *json_out = (ret->ret_str).c_str();
78 79 80 81 82 83 84
    *success = 1;
  } else {
    *success = 0;
  }
  API_END();
}

/*!
 * \brief Run a sequence of named passes over a Graph.
 *
 * The source Graph is passed to ApplyPasses together with the pass names.
 * The resulting Graph is stored in a newly allocated Graph returned via
 * `*dst`.
 *
 * \param src        Graph to transform.
 * \param num_pass   number of entries in \p pass_names.
 * \param pass_names array of `num_pass` NUL-terminated pass names; may be
 *                   null when `num_pass` is 0.
 * \param dst        receives the newly allocated result Graph on success.
 * \return the status value produced by API_BEGIN()/API_END_HANDLE_ERROR().
 */
int NNGraphApplyPasses(GraphHandle src,
                       nn_uint num_pass,
                       const char** pass_names,
                       GraphHandle *dst) {
  // Allocated before the guarded region so the error handler can free it.
  Graph* g = new Graph();
  API_BEGIN();
  // Each const char* is converted to std::string by the range constructor.
  std::vector<std::string> vpass(pass_names, pass_names + num_pass);
  *g = ApplyPasses(*static_cast<Graph*>(src), vpass);
  *dst = g;
  // NOTE(review): the delete expression is presumably only evaluated on the
  // error path -- confirm against API_END_HANDLE_ERROR in c_api_common.h.
  API_END_HANDLE_ERROR(delete g);
}