packed_func_ext.h 7.75 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/packed_func_ext.h
 * \brief Extension package to PackedFunc
 *   This enables passing NodeRef types into/from PackedFunc.
 */
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <sstream>
#include <string>
#include <memory>
31
#include <limits>
32 33
#include <type_traits>

34 35 36 37
#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "runtime/packed_func.h"
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55

namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

namespace runtime {
/*!
 * \brief Runtime type checker for node type.
 * \tparam T the type to be checked.
 */
template<typename T>
struct NodeTypeChecker {
  static inline bool Check(Node* sptr) {
    // This is the only place in the project where RTTI is used
    // It can be turned off, but will make non strict checking.
    // TODO(tqchen) possibly find alternative to turn of RTTI
    using ContainerType = typename T::ContainerType;
56 57
    // always allow nullptr.
    if (sptr == nullptr) return true;
58
    return sptr->derived_from<ContainerType>();
59 60 61 62 63 64 65 66 67 68
  }
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    using ContainerType = typename T::ContainerType;
    os << ContainerType::_type_key;
  }
};

/*!
 * \brief Type checker specialization for Array<T>.
 *
 *  A node matches when it is an ArrayNode and every element matches T.
 * \tparam T the element type.
 */
template<typename T>
struct NodeTypeChecker<Array<T> > {
  /*!
   * \brief Check that sptr is an ArrayNode whose elements all match T.
   * \param sptr the node to inspect; nullptr is always accepted.
   * \return true on match, false otherwise.
   */
  static inline bool Check(Node* sptr) {
    if (sptr == nullptr) return true;
    if (!sptr->is_type<ArrayNode>()) return false;
    // Safe: the exact type was verified just above.
    ArrayNode* n = static_cast<ArrayNode*>(sptr);
    for (const auto& p : n->data) {
      if (!NodeTypeChecker<T>::Check(p.get())) return false;
    }
    return true;
  }
  /*!
   * \brief Write "array<" + element name + ">" to os.
   * \param os the stream receiving the name.
   */
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "array<";
    NodeTypeChecker<T>::PrintName(os);
    os << ">";
  }
};

84 85 86
template<typename V>
struct NodeTypeChecker<Map<std::string, V> > {
  static inline bool Check(Node* sptr) {
87
    if (sptr == nullptr) return true;
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
    if (!sptr->is_type<StrMapNode>()) return false;
    StrMapNode* n = static_cast<StrMapNode*>(sptr);
    for (const auto& kv : n->data) {
      if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
    }
    return true;
  }
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "map<string";
    os << ',';
    NodeTypeChecker<V>::PrintName(os);
    os << '>';
  }
};

103 104 105
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
  static inline bool Check(Node* sptr) {
106
    if (sptr == nullptr) return true;
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
    if (!sptr->is_type<MapNode>()) return false;
    MapNode* n = static_cast<MapNode*>(sptr);
    for (const auto& kv : n->data) {
      if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
      if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
    }
    return true;
  }
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "map<";
    NodeTypeChecker<K>::PrintName(os);
    os << ',';
    NodeTypeChecker<V>::PrintName(os);
    os << '>';
  }
};

/*!
 * \brief Printable name of a node reference type, used in error messages.
 * \tparam T the type whose name is wanted (e.g. Expr, Array<Expr>).
 * \return the text written by NodeTypeChecker<T>::PrintName.
 */
template<typename T>
inline std::string NodeTypeName() {
  std::ostringstream name_stream;
  NodeTypeChecker<T>::PrintName(name_stream);
  std::string type_name = name_stream.str();
  return type_name;
}

// extensions for tvm arg value

133 134
template<typename TNodeRef>
inline TNodeRef TVMArgValue::AsNodeRef() const {
135 136 137
  static_assert(
      std::is_base_of<NodeRef, TNodeRef>::value,
      "Conversion only works for NodeRef");
138
  if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr));
139
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
140
  NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
141 142 143 144 145 146
  CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<TNodeRef>()
      << " but get " << sptr->type_key();
  return TNodeRef(sptr);
}

147
inline TVMArgValue::operator HalideIR::Expr() const {
148
  if (type_code_ == kNull) return Expr();
149
  if (type_code_ == kDLInt) {
150 151
    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
    CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
152 153
    return Expr(static_cast<int>(value_.v_int64));
  }
154
  if (type_code_ == kDLFloat) {
155 156 157
    return Expr(static_cast<float>(value_.v_float64));
  }
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
158
  NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
159 160 161
  if (sptr->is_type<IterVarNode>()) {
    return IterVar(sptr)->var;
  }
162 163 164
  if (sptr->is_type<TensorNode>()) {
    return Tensor(sptr)();
  }
165 166 167 168 169 170
  CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<Expr>()
      << " but get " << sptr->type_key();
  return Expr(sptr);
}

171 172 173 174 175 176 177 178 179 180 181 182 183 184
inline TVMArgValue::operator tvm::Integer() const {
  if (type_code_ == kNull) return Integer();
  if (type_code_ == kDLInt) {
    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
    CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
    return Integer(static_cast<int>(value_.v_int64));
  }
  NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
  CHECK(NodeTypeChecker<Integer>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<Expr>()
      << " but get " << sptr->type_key();
  return Integer(sptr);
}

185
inline NodePtr<Node>& TVMArgValue::node_sptr() {
186
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
187
  return *ptr<NodePtr<Node> >();
188 189 190 191 192 193
}


template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
194 195
  NodePtr<Node>& sptr =
      *ptr<NodePtr<Node> >();
196 197 198 199 200
  return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}

// extensions for TVMRetValue

/*!
 * \brief Store a node pointer in this return value.
 * \param other the node to store; a null pointer stores kNull instead.
 * \return *this
 */
inline TVMRetValue& TVMRetValue::operator=(
    const NodePtr<Node>& other) {
  if (other.get() == nullptr) {
    SwitchToPOD(kNull);
  } else {
    SwitchToClass<NodePtr<Node> >(kNodeHandle, other);
  }
  return *this;
}

/*!
 * \brief Store a node reference in this return value.
 * \param other the reference to store; an undefined reference stores kNull.
 * \return *this
 */
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
  if (!other.defined()) {
    SwitchToPOD(kNull);
  } else {
    SwitchToClass<NodePtr<Node> >(kNodeHandle, other.node_);
  }
  return *this;
}

219 220
template<typename TNodeRef>
inline TNodeRef TVMRetValue::AsNodeRef() const {
221 222 223 224 225
  static_assert(
      std::is_base_of<NodeRef, TNodeRef>::value,
      "Conversion only works for NodeRef");
  if (type_code_ == kNull) return TNodeRef();
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
226
  NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
227 228 229 230
  CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<TNodeRef>()
      << " but get " << sptr->type_key();
  return TNodeRef(sptr);
231 232
}

233
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const {  // NOLINT(*)
234
  if (other.defined()) {
235
    values_[i].v_handle = const_cast<NodePtr<Node>*>(&(other.node_));
236 237 238 239
    type_codes_[i] = kNodeHandle;
  } else {
    type_codes_[i] = kNull;
  }
240 241
}

242
// type related stuffs
243
inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) {
244 245 246
  return this->operator=(Type2TVMType(t));
}

247
inline TVMRetValue::operator HalideIR::Type() const {
248 249 250
  return TVMType2Type(operator TVMType());
}

251
inline TVMArgValue::operator HalideIR::Type() const {
252 253 254 255
  return TVMType2Type(operator TVMType());
}

inline void TVMArgsSetter::operator()(
256
    size_t i, const HalideIR::Type& t) const {
257 258 259 260 261
  this->operator()(i, Type2TVMType(t));
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_PACKED_FUNC_EXT_H_