packed_func_ext.h 6.37 KB
Newer Older
1 2
/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/packed_func_ext.h
 * \brief Extension package to PackedFunc
 *   This enables passing NodeRef types into/from PackedFunc.
 */
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>

#include "./base.h"
#include "./expr.h"
17
#include "./tensor.h"
18
#include "./runtime/packed_func.h"
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36

namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

namespace runtime {
/*!
 * \brief Runtime type checker for node type.
 * \tparam T the type to be checked.
 */
template<typename T>
struct NodeTypeChecker {
  static inline bool Check(Node* sptr) {
    // This is the only place in the project where RTTI is used
    // It can be turned off, but will make non strict checking.
    // TODO(tqchen) possibly find alternative to turn of RTTI
    using ContainerType = typename T::ContainerType;
37
    return sptr->derived_from<ContainerType>();
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
  }
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    using ContainerType = typename T::ContainerType;
    os << ContainerType::_type_key;
  }
};

/*!
 * \brief NodeTypeChecker specialization for Array<T>.
 *
 *  Accepts a non-null ArrayNode whose every element passes
 *  NodeTypeChecker<T>. The printed name is "array<" + name(T) + ">".
 */
template<typename T>
struct NodeTypeChecker<Array<T> > {
  static inline bool Check(Node* sptr) {
    // The container itself has to be a live ArrayNode before elements are inspected.
    if (sptr == nullptr || !sptr->is_type<ArrayNode>()) return false;
    const ArrayNode* array_node = static_cast<const ArrayNode*>(sptr);
    for (const auto& elem : array_node->data) {
      if (!NodeTypeChecker<T>::Check(elem.get())) return false;
    }
    return true;
  }
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "array<";
    NodeTypeChecker<T>::PrintName(os);  // element type name
    os << ">";
  }
};

63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
template<typename V>
struct NodeTypeChecker<Map<std::string, V> > {
  static inline bool Check(Node* sptr) {
    if (sptr == nullptr) return false;
    if (!sptr->is_type<StrMapNode>()) return false;
    StrMapNode* n = static_cast<StrMapNode*>(sptr);
    for (const auto& kv : n->data) {
      if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
    }
    return true;
  }
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "map<string";
    os << ',';
    NodeTypeChecker<V>::PrintName(os);
    os << '>';
  }
};

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
  static inline bool Check(Node* sptr) {
    if (sptr == nullptr) return false;
    if (!sptr->is_type<MapNode>()) return false;
    MapNode* n = static_cast<MapNode*>(sptr);
    for (const auto& kv : n->data) {
      if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
      if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
    }
    return true;
  }
  static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
    os << "map<";
    NodeTypeChecker<K>::PrintName(os);
    os << ',';
    NodeTypeChecker<V>::PrintName(os);
    os << '>';
  }
};

/*!
 * \brief Render the name of a node type as printed by NodeTypeChecker.
 * \tparam T the NodeRef type (or Array/Map of NodeRef types).
 * \return the text written by NodeTypeChecker<T>::PrintName.
 */
template<typename T>
inline std::string NodeTypeName() {
  std::ostringstream name_stream;
  NodeTypeChecker<T>::PrintName(name_stream);
  std::string name = name_stream.str();
  return name;
}

// extensions for tvm arg value

112 113
template<typename TNodeRef>
inline TNodeRef TVMArgValue::AsNodeRef() const {
114 115 116 117 118 119 120 121 122 123 124 125
  static_assert(
      std::is_base_of<NodeRef, TNodeRef>::value,
      "Conversion only works for NodeRef");
  if (type_code_ == kNull) return TNodeRef();
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
  CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<TNodeRef>()
      << " but get " << sptr->type_key();
  return TNodeRef(sptr);
}

126
inline TVMArgValue::operator HalideIR::Expr() const {
127
  if (type_code_ == kNull) return Expr();
128
  if (type_code_ == kDLInt) {
129 130
    return Expr(static_cast<int>(value_.v_int64));
  }
131
  if (type_code_ == kDLFloat) {
132 133 134 135 136 137 138
    return Expr(static_cast<float>(value_.v_float64));
  }
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
  if (sptr->is_type<IterVarNode>()) {
    return IterVar(sptr)->var;
  }
139 140 141
  if (sptr->is_type<TensorNode>()) {
    return Tensor(sptr)();
  }
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
  CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<Expr>()
      << " but get " << sptr->type_key();
  return Expr(sptr);
}

/*!
 * \brief Access the shared pointer stored in a kNodeHandle argument.
 * \return a reference to the stored std::shared_ptr<Node>.
 */
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>* handle = ptr<std::shared_ptr<Node> >();
  return *handle;
}


template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
  std::shared_ptr<Node>& sptr =
      *ptr<std::shared_ptr<Node> >();
  return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}

// extensions for TVMRetValue
/*!
 * \brief Store a node in the return value.
 * \param other the node; a null pointer is stored as kNull.
 * \return *this.
 */
inline TVMRetValue& TVMRetValue::operator=(
    const std::shared_ptr<Node>& other) {
  if (other) {
    SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
  } else {
    // An empty pointer is recorded as kNull, not as an empty handle.
    SwitchToPOD(kNull);
  }
  return *this;
}

/*!
 * \brief Store a NodeRef in the return value.
 * \param other the reference; an undefined NodeRef is stored as kNull.
 * \return *this.
 */
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
  if (other.defined()) {
    SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
  } else {
    SwitchToPOD(kNull);
  }
  return *this;
}

182 183
template<typename TNodeRef>
inline TNodeRef TVMRetValue::AsNodeRef() const {
184 185 186 187 188
  static_assert(
      std::is_base_of<NodeRef, TNodeRef>::value,
      "Conversion only works for NodeRef");
  if (type_code_ == kNull) return TNodeRef();
  TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
189 190 191 192 193
  std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
  CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
      << "Expected type " << NodeTypeName<TNodeRef>()
      << " but get " << sptr->type_key();
  return TNodeRef(sptr);
194 195
}

196
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const {  // NOLINT(*)
197 198 199 200 201 202
  if (other.defined()) {
    values_[i].v_handle = const_cast<std::shared_ptr<Node>*>(&(other.node_));
    type_codes_[i] = kNodeHandle;
  } else {
    type_codes_[i] = kNull;
  }
203 204
}

205
// type related stuffs
206
inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) {
207 208 209
  return this->operator=(Type2TVMType(t));
}

210
inline TVMRetValue::operator HalideIR::Type() const {
211 212 213
  return TVMType2Type(operator TVMType());
}

214
inline TVMArgValue::operator HalideIR::Type() const {
215 216 217 218
  return TVMType2Type(operator TVMType());
}

inline void TVMArgsSetter::operator()(
219
    size_t i, const HalideIR::Type& t) const {
220 221 222 223 224
  this->operator()(i, Type2TVMType(t));
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_PACKED_FUNC_EXT_H_