/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2017 by Contributors
 * \file graph_runtime.cc
 * \brief Interface code with TVM graph runtime.
 */
#include <dmlc/memory_io.h>
#include <utility>
#include "graph_runtime.h"
namespace nnvm {
namespace compiler {

using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;

DMLC_REGISTER_PARAMETER(TVMOpParam);

/*!
 * \brief Attribute parser for the tvm_op operator.
 *
 * Decodes the node's string attribute dictionary into a TVMOpParam
 * (via TVMOpParam::Init) and stores the result in attrs->parsed, where
 * the arity callbacks registered on tvm_op read it back.
 * \param attrs Node attributes; attrs->parsed is overwritten.
 */
inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
  nnvm::NodeAttrs& node_attrs = *attrs;
  TVMOpParam decoded;
  decoded.Init(node_attrs.dict);
  node_attrs.parsed = std::move(decoded);
}

// Register the tvm_op operator. Its input and output counts are not fixed;
// both are read from the TVMOpParam that TVMOpParamParser stores in
// attrs.parsed.
NNVM_REGISTER_OP(tvm_op)
.set_attr_parser(TVMOpParamParser)
.set_num_inputs([](const NodeAttrs& node_attrs) {
    return nnvm::get<TVMOpParam>(node_attrs.parsed).num_inputs;
  })
.set_num_outputs([](const NodeAttrs& node_attrs) {
    return nnvm::get<TVMOpParam>(node_attrs.parsed).num_outputs;
  });


/*!
 * \brief Serialize named parameter arrays into a byte blob.
 *
 * Arguments must come as alternating (name, array) pairs:
 * args[0] = name0, args[1] = array0, args[2] = name1, ...
 * The returned TVMByteArray contains, in order:
 *   - the kTVMNDArrayListMagic header (uint64),
 *   - a reserved uint64 (always 0),
 *   - the list of names,
 *   - the number of arrays (uint64),
 *   - each array as written by tvm::runtime::SaveDLTensor.
 */
TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    CHECK_EQ(args.size() % 2, 0u);
    size_t num_params = args.size() / 2;
    std::vector<std::string> names;
    names.reserve(num_params);
    std::vector<DLTensor*> arrays;
    arrays.reserve(num_params);
    for (size_t i = 0; i < num_params * 2; i += 2) {
      names.emplace_back(args[i].operator std::string());
      arrays.emplace_back(args[i + 1].operator DLTensor*());
    }
    std::string bytes;
    dmlc::MemoryStringStream strm(&bytes);
    dmlc::Stream* fo = &strm;
    uint64_t header = kTVMNDArrayListMagic;
    uint64_t reserved = 0;
    fo->Write(header);
    fo->Write(reserved);
    fo->Write(names);
    const uint64_t sz = static_cast<uint64_t>(arrays.size());
    fo->Write(sz);
    for (DLTensor* tensor : arrays) {
      tvm::runtime::SaveDLTensor(fo, tensor);
    }
    // NOTE(review): `bytes` is local to this lambda; this relies on
    // TVMRetValue copying the TVMByteArray contents on assignment — confirm.
    TVMByteArray arr;
    arr.data = bytes.c_str();
    arr.size = bytes.length();
    *rv = arr;
  });


/*!
 * \brief Deserialize a blob produced by nnvm.compiler._save_param_dict.
 *
 * args[0] holds the serialized bytes. The header, reserved field, name
 * list and array count are validated with CHECKs that fail with
 * "Invalid parameters file format". Returns a tvm::Array<NDArrayWrapper>
 * whose i-th element pairs the i-th saved name with its loaded NDArray.
 */
TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    std::string bytes = args[0];
    std::vector<std::string> names;
    dmlc::MemoryStringStream memstrm(&bytes);
    dmlc::Stream* strm = &memstrm;
    uint64_t header = 0, reserved = 0;
    CHECK(strm->Read(&header))
        << "Invalid parameters file format";
    CHECK(header == kTVMNDArrayListMagic)
        << "Invalid parameters file format";
    CHECK(strm->Read(&reserved))
        << "Invalid parameters file format";
    CHECK(strm->Read(&names))
        << "Invalid parameters file format";
    // Read the array count with the same checked, typed Read used for the
    // header fields, mirroring the typed fo->Write(sz) on the save side.
    // The former unchecked raw read left `sz` uninitialized on a truncated
    // stream, and the comparison below then read garbage.
    uint64_t sz = 0;
    CHECK(strm->Read(&sz))
        << "Invalid parameters file format";
    size_t size = static_cast<size_t>(sz);
    CHECK(size == names.size())
        << "Invalid parameters file format";
    tvm::Array<NDArrayWrapper> ret;
    for (size_t i = 0; i < size; ++i) {
      tvm::runtime::NDArray temp;
      temp.Load(strm);
      auto n = tvm::make_node<NDArrayWrapperNode>();
      n->name = std::move(names[i]);
      n->array = std::move(temp);
      ret.push_back(NDArrayWrapper(n));
    }
    *rv = ret;
  });
121

122
TVM_REGISTER_NODE_TYPE(NDArrayWrapperNode);
123 124
}  // namespace compiler
}  // namespace nnvm