api_base.cc 1.99 KB
Newer Older
1
/*!
 *  Copyright (c) 2017 by Contributors
 *  Implementation of basic API functions
 * \file api_base.cc
 */
6
#include <dmlc/memory_io.h>
7 8 9 10 11
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>

namespace tvm {
12
/*!
 * \brief "_format_str": render a node as text.
 *
 * Requires args[0] to carry the kNodeHandle type code, streams the
 * corresponding NodeRef into a string stream with operator<<, and
 * returns the accumulated text through `ret`.
 */
TVM_REGISTER_API("_format_str")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    const NodeRef node = args[0].operator NodeRef();
    std::ostringstream text;
    text << node;
    *ret = text.str();
  });

20
/*!
 * \brief "_raw_ptr": expose the address of a node as an integer.
 *
 * Requires args[0] to carry the kNodeHandle type code, then returns the
 * pointer obtained from args[0].node_sptr().get() reinterpreted as an
 * int64_t. The pointer is only converted to a number, never dereferenced.
 */
TVM_REGISTER_API("_raw_ptr")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    CHECK(args[0].type_code() == kNodeHandle);
    const int64_t address = reinterpret_cast<int64_t>(
        args[0].node_sptr().get());
    *ret = address;
  });

27
/*!
 * \brief "_save_json": registers SaveJSON with the typed signature
 *  std::string(NodeRef), i.e. it takes a NodeRef and returns a string.
 */
TVM_REGISTER_API("_save_json")
.set_body_typed<std::string(NodeRef)>(SaveJSON);
29

30
/*!
 * \brief "_load_json": registers LoadJSON<NodeRef> with the typed signature
 *  NodeRef(std::string), i.e. it takes a string and returns a NodeRef.
 */
TVM_REGISTER_API("_load_json")
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
32

33 34 35 36
/*!
 * \brief "_TVMSetStream": forward the first three packed arguments,
 *  in order, to TVMSetStream.
 *
 * NOTE(review): `ret` is never written and anything TVMSetStream returns
 * is discarded here - confirm that callers do not expect a status value.
 */
TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    // The implicit conversions to TVMSetStream's parameter types happen
    // at the call site, exactly as when the arguments are passed inline.
    auto first = args[0];
    auto second = args[1];
    auto third = args[2];
    TVMSetStream(first, second, third);
  });
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
/*!
 * \brief "_save_param_dict": pack (name, tensor) pairs into one byte blob.
 *
 * The arguments are read as alternating pairs: args[2k] is converted to a
 * std::string name and args[2k + 1] to a DLTensor*. The argument count
 * must therefore be even.
 *
 * Bytes are written in this order:
 *   1. uint64 magic header (0xF7E58D4F05049CB7)
 *   2. uint64 reserved field, written as 0
 *   3. the vector of names
 *   4. uint64 number of tensors
 *   5. each tensor, emitted by tvm::runtime::SaveDLTensor
 *
 * The result is handed back through `rv` as a TVMByteArray.
 */
TVM_REGISTER_API("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    CHECK_EQ(args.size() % 2, 0u);
    constexpr uint64_t kListMagic = 0xF7E58D4F05049CB7;
    const size_t pair_count = args.size() / 2;

    // Split the flat argument list into parallel name / tensor vectors.
    std::vector<std::string> param_names;
    std::vector<DLTensor*> param_tensors;
    param_names.reserve(pair_count);
    param_tensors.reserve(pair_count);
    for (size_t k = 0; k < pair_count; ++k) {
      const size_t slot = 2 * k;
      param_names.emplace_back(args[slot].operator std::string());
      param_tensors.emplace_back(args[slot + 1].operator DLTensor*());
    }

    // Serialize into an in-memory string; every write goes through the
    // dmlc::Stream interface pointer.
    std::string blob;
    dmlc::MemoryStringStream mem_stream(&blob);
    dmlc::Stream* out = &mem_stream;

    uint64_t magic = kListMagic;
    uint64_t reserved = 0;
    out->Write(magic);
    out->Write(reserved);
    out->Write(param_names);

    uint64_t tensor_count = static_cast<uint64_t>(param_tensors.size());
    out->Write(tensor_count);
    for (DLTensor* tensor : param_tensors) {
      tvm::runtime::SaveDLTensor(out, tensor);
    }

    // NOTE(review): `blob` is a local, so this relies on the TVMRetValue
    // assignment copying the bytes rather than keeping the pointer - confirm.
    TVMByteArray packed;
    packed.data = blob.data();
    packed.size = blob.size();
    *rv = packed;
  });

70
}  // namespace tvm