/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of DSL API
 * \file dsl_api.cc
 */
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>

#include <vector>
#include <string>
#include <exception>
#include <limits>

#include "../runtime/dsl_api.h"

namespace tvm {
namespace runtime {
/*!
 * \brief Per-thread scratch storage that keeps string results alive after
 *  an API call has handed raw pointers back to its caller.
 */
struct TVMAPIThreadLocalEntry {
  /*! \brief owning storage for a returned list of strings */
  std::vector<std::string> ret_vec_str{};
  /*! \brief C-string pointers handed back to callers for a returned string list */
  std::vector<const char*> ret_vec_charp{};
  /*! \brief owning storage for a single returned string */
  std::string ret_str{};
};

/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;

50
using TVMAPINode = NodePtr<Node>;
51 52 53 54

struct APIAttrGetter : public AttrVisitor {
  std::string skey;
  TVMRetValue* ret;
55
  bool found_ref_object{false};
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
        << "cannot return too big constant";
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, int* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
74 75 76
  void Visit(const char* key, void** value) final {
    if (skey == key) *ret = static_cast<void*>(value[0]);
  }
77 78 79 80 81 82 83 84 85
  void Visit(const char* key, Type* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, NodeRef* value) final {
    if (skey == key) {
      *ret = value[0];
86 87 88 89 90 91 92
      found_ref_object = true;
    }
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
    }
  }
};

struct APIAttrDir : public AttrVisitor {
  std::vector<std::string>* names;

  void Visit(const char* key, double* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, int64_t* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, uint64_t* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, bool* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, int* value) final {
    names->push_back(key);
  }
115 116 117
  void Visit(const char* key, void** value) final {
    names->push_back(key);
  }
118 119 120 121 122 123 124 125 126
  void Visit(const char* key, Type* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, std::string* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, NodeRef* value) final {
    names->push_back(key);
  }
127 128 129
  void Visit(const char* key, runtime::NDArray* value) final {
    names->push_back(key);
  }
130 131
};

132 133 134 135 136 137 138 139 140 141
class DSLAPIImpl : public DSLAPI {
 public:
  void NodeFree(NodeHandle handle) const final {
    delete static_cast<TVMAPINode*>(handle);
  }
  void NodeTypeKey2Index(const char* type_key,
                        int* out_index) const final {
    *out_index = static_cast<int>(Node::TypeKey2Index(type_key));
  }
  void NodeGetTypeIndex(NodeHandle handle,
142
                        int* out_index) const final {
143 144 145 146
    *out_index = static_cast<int>(
        (*static_cast<TVMAPINode*>(handle))->type_index());
  }
  void NodeGetAttr(NodeHandle handle,
147 148 149 150
                   const char* key,
                   TVMValue* ret_val,
                   int* ret_type_code,
                   int* ret_success) const final {
151 152
    TVMRetValue rv;
    APIAttrGetter getter;
153
    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
154 155 156 157
    getter.skey = key;
    getter.ret = &rv;
    if (getter.skey == "type_key") {
      ret_val->v_str = (*tnode)->type_key();
158
      *ret_type_code = kStr;
159
      *ret_success = 1;
160 161
      return;
    } else if (!(*tnode)->is_type<DictAttrsNode>()) {
162
      (*tnode)->VisitAttrs(&getter);
163
      *ret_success = getter.found_ref_object || rv.type_code() != kNull;
164 165 166 167 168 169 170 171 172 173 174 175
    } else {
      // specially handle dict attr
      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
      auto it = dnode->dict.find(key);
      if (it != dnode->dict.end()) {
        *ret_success = 1;
        rv = (*it).second;
      } else {
        *ret_success = 0;
      }
    }
    if (*ret_success) {
176 177 178 179 180 181 182 183 184
      if (rv.type_code() == kStr ||
          rv.type_code() == kTVMType) {
        TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
        e->ret_str = rv.operator std::string();
        *ret_type_code = kStr;
        ret_val->v_str = e->ret_str.c_str();
      } else {
        rv.MoveToCHost(ret_val, ret_type_code);
      }
185 186
    }
  }
187 188 189 190 191 192 193 194
  void NodeListAttrNames(NodeHandle handle,
                        int *out_size,
                        const char*** out_array) const final {
    TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
    ret->ret_vec_str.clear();
    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
    APIAttrDir dir;
    dir.names = &(ret->ret_vec_str);
195 196 197 198 199 200 201 202 203 204

    if (!(*tnode)->is_type<DictAttrsNode>()) {
      (*tnode)->VisitAttrs(&dir);
    } else {
      // specially handle dict attr
      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
      for (const auto& kv : dnode->dict) {
        ret->ret_vec_str.push_back(kv.first);
      }
    }
205 206 207 208 209 210 211 212
    ret->ret_vec_charp.clear();
    for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
      ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
    }
    *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
    *out_size = static_cast<int>(ret->ret_vec_str.size());
  }
};
213

214 215 216 217 218 219 220 221
// Registers "dsl_api.singleton": every call returns the address of the same
// DSLAPIImpl instance as an opaque void*.
TVM_REGISTER_GLOBAL("dsl_api.singleton")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // Function-local static: built on the first call, reused afterwards.
    static DSLAPIImpl instance;
    *rv = static_cast<void*>(&instance);
  });
}  // namespace runtime
}  // namespace tvm