packed_func.h 40.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/runtime/packed_func.h
 * \brief Type-erased function used across TVM API.
 */
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_
26

27 28 29
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
#include <sstream>
#endif
30
#include <dmlc/logging.h>
31 32
#include <functional>
#include <tuple>
33 34
#include <vector>
#include <string>
35 36
#include <limits>
#include <memory>
37
#include <utility>
38
#include <type_traits>
39 40 41
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
42
#include "object.h"
43
#include "node_base.h"
44

45 46 47 48 49
// Selects header-only use of the TVM runtime. Builds may predefine this;
// otherwise it defaults to 0 (header-only mode disabled).
#if !defined(TVM_RUNTIME_HEADER_ONLY)
#define TVM_RUNTIME_HEADER_ONLY 0
#endif  // !defined(TVM_RUNTIME_HEADER_ONLY)

50
namespace tvm {
51 52
// forward declarations of compiler-level types that TVMArgValue and
// TVMRetValue declare conversion operators for further down this header.
class Integer;
class DataType;
class Expr;

namespace runtime {
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79

/*!
 * \brief Runtime utility for getting the name of a custom type from its code.
 * \param type_code Custom type code.
 * \return Custom type name.
 */
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);

/*!
 * \brief Runtime utility for checking whether a custom type is registered.
 * \param type_code Custom type code.
 * \return Whether a type is registered under this code.
 */
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);

/*!
 * \brief Runtime utility for parsing a string of the form "custom[<typename>]".
 * \param s String to parse.
 * \param scan Pointer to the parsing cursor, which scans across s.
 * \return Type code of the parsed custom type.
 */
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);

// forward declarations
class TVMArgs;
class TVMArgValue;
class TVMRetValue;
class TVMArgsSetter;

/*!
87 88
 * \brief Packed function is a type-erased function.
 *  The arguments are passed by packed format.
89
 *
90 91 92
 *  This is an useful unified interface to call generated functions,
 *  It is the unified function function type of TVM.
 *  It corresponds to TVMFunctionHandle in C runtime API.
93 94 95
 */
class PackedFunc {
 public:
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
  /*!
   * \brief The internal std::function
   * \param args The arguments to the function.
   * \param rv The return value.
   *
   * \code
   *   // Example code on how to implemented FType
   *   void MyPackedFunc(TVMArgs args, TVMRetValue* rv) {
   *     // automatically convert arguments to desired type.
   *     int a0 = args[0];
   *     float a1 = args[1];
   *     ...
   *     // automatically assign values to rv
   *     std::string my_return_value = "x";
   *     *rv = my_return_value;
   *   }
   * \endcode
   */
  using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
115
  /*! \brief default constructor */
116
  PackedFunc() {}
117 118
  /*! \brief constructor from null */
  PackedFunc(std::nullptr_t null) {}  // NOLINT(*)
119 120 121 122
  /*!
   * \brief constructing a packed function from a std::function.
   * \param body the internal container of packed function.
   */
123 124
  explicit PackedFunc(FType body) : body_(body) {}
  /*!
125
   * \brief Call packed function by directly passing in unpacked format.
126 127
   * \param args Arguments to be passed.
   * \tparam Args arguments to be passed.
128 129 130 131 132 133 134 135 136
   *
   * \code
   *   // Example code on how to call packed function
   *   void CallPacked(PackedFunc f) {
   *     // call like normal functions by pass in arguments
   *     // return value is automatically converted back
   *     int rvalue = f(1, 2.0);
   *   }
   * \endcode
137 138
   */
  template<typename... Args>
139
  inline TVMRetValue operator()(Args&& ...args) const;
140 141 142
  /*!
   * \brief Call the function in packed format.
   * \param args The arguments
143
   * \param rv The return value.
144
   */
145
  inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
146
  /*! \return the internal body function */
147
  inline FType body() const;
148 149 150 151 152 153 154 155
  /*! \return Whether the packed function is nullptr */
  bool operator==(std::nullptr_t null) const {
    return body_ == nullptr;
  }
  /*! \return Whether the packed function is not nullptr */
  bool operator!=(std::nullptr_t null) const {
    return body_ != nullptr;
  }
156 157 158 159 160 161

 private:
  /*! \brief internal container of packed function */
  FType body_;
};

/*!
 * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args...)>"
 */
template<typename FType>
class TypedPackedFunc;

/*!
 * \anchor TypedPackedFuncAnchor
 * \brief A PackedFunc wrapper to provide typed function signature.
 * It is backed by a PackedFunc internally.
 *
 * TypedPackedFunc enables compile time type checking.
 * TypedPackedFunc works with the runtime system:
 * - It can be passed as an argument of PackedFunc.
 * - It can be assigned to TVMRetValue.
 * - It can be directly converted to a type-erased PackedFunc.
 *
 * Developers should prefer TypedPackedFunc over PackedFunc in C++ code
 * as it enables compile time checking.
 * We can construct a TypedPackedFunc from a lambda function
 * with the same signature.
 *
 * \code
 *  // user defined lambda function.
 *  auto addone = [](int x)->int {
 *    return x + 1;
 *  };
 *  // We can directly convert
 *  // lambda function to TypedPackedFunc
 *  TypedPackedFunc<int(int)> ftyped(addone);
 *  // invoke the function.
 *  int y = ftyped(1);
 *  // Can be directly converted to PackedFunc
 *  PackedFunc packed = ftype;
 * \endcode
 * \tparam R The return value of the function.
 * \tparam Args The argument signature of the function.
 */
template<typename R, typename ...Args>
class TypedPackedFunc<R(Args...)> {
 public:
  /*! \brief short hand for this function type */
  using TSelf = TypedPackedFunc<R(Args...)>;
  /*! \brief default constructor */
  TypedPackedFunc() {}
207 208
  /*! \brief constructor from null */
  TypedPackedFunc(std::nullptr_t null) {}  // NOLINT(*)
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
  /*!
   * \brief construct by wrap a PackedFunc
   *
   * Example usage:
   * \code
   * PackedFunc packed([](TVMArgs args, TVMRetValue *rv) {
   *   int x = args[0];
   *   *rv = x + 1;
   *  });
   * // construct from packed function
   * TypedPackedFunc<int(int)> ftyped(packed);
   * // call the typed version.
   * CHECK_EQ(ftyped(1), 2);
   * \endcode
   *
   * \param packed The packed function
   */
226 227 228 229 230 231 232 233 234 235 236
  inline TypedPackedFunc(PackedFunc packed);  // NOLINT(*)
  /*!
   * \brief constructor from TVMRetValue
   * \param value The TVMRetValue
   */
  inline TypedPackedFunc(const TVMRetValue& value);  // NOLINT(*)
  /*!
   * \brief constructor from TVMArgValue
   * \param value The TVMArgValue
   */
  inline TypedPackedFunc(const TVMArgValue& value);  // NOLINT(*)
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
  /*!
   * \brief construct from a lambda function with the same signature.
   *
   * Example usage:
   * \code
   * auto typed_lambda = [](int x)->int { return x + 1; }
   * // construct from packed function
   * TypedPackedFunc<int(int)> ftyped(typed_lambda);
   * // call the typed version.
   * CHECK_EQ(ftyped(1), 2);
   * \endcode
   *
   * \param typed_lambda typed lambda function.
   * \tparam FLambda the type of the lambda function.
   */
  template<typename FLambda,
           typename = typename std::enable_if<
             std::is_convertible<FLambda,
                                 std::function<R(Args...)>
                                 >::value>::type>
257
  TypedPackedFunc(const FLambda& typed_lambda) {  // NOLINT(*)
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
    this->AssignTypedLambda(typed_lambda);
  }
  /*!
   * \brief copy assignment operator from typed lambda
   *
   * Example usage:
   * \code
   * // construct from packed function
   * TypedPackedFunc<int(int)> ftyped;
   * ftyped = [](int x) { return x + 1; }
   * // call the typed version.
   * CHECK_EQ(ftyped(1), 2);
   * \endcode
   *
   * \param typed_lambda typed lambda function.
   * \tparam FLambda the type of the lambda function.
   * \returns reference to self.
   */
  template<typename FLambda,
           typename = typename std::enable_if<
             std::is_convertible<FLambda,
                                 std::function<R(Args...)>
                                 >::value>::type>
  TSelf& operator=(FLambda typed_lambda) {  // NOLINT(*)
    this->AssignTypedLambda(typed_lambda);
    return *this;
  }
  /*!
   * \brief copy assignment operator from PackedFunc.
   * \param packed The packed function.
   * \returns reference to self.
   */
  TSelf& operator=(PackedFunc packed) {
    packed_ = packed;
    return *this;
  }
  /*!
   * \brief Invoke the operator.
   * \param args The arguments
   * \returns The return value.
   */
  inline R operator()(Args ...args) const;
  /*!
   * \brief convert to PackedFunc
   * \return the internal PackedFunc
   */
  operator PackedFunc() const {
    return packed();
  }
  /*!
   * \return reference the internal PackedFunc
   */
  const PackedFunc& packed() const {
    return packed_;
  }
313 314 315 316 317 318 319 320
  /*! \return Whether the packed function is nullptr */
  bool operator==(std::nullptr_t null) const {
    return packed_ == nullptr;
  }
  /*! \return Whether the packed function is not nullptr */
  bool operator!=(std::nullptr_t null) const {
    return packed_ != nullptr;
  }
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336

 private:
  friend class TVMRetValue;
  /*! \brief The internal packed function */
  PackedFunc packed_;
  /*!
   * \brief Assign the packed field using a typed lambda function.
   *
   * \param flambda The lambda function.
   * \tparam FLambda The lambda function type.
   * \note We capture the lambda when possible for maximum efficiency.
   */
  template<typename FLambda>
  inline void AssignTypedLambda(FLambda flambda);
};

337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378
/*!
 * \brief Arguments into TVM functions.
 *
 *  Holds raw pointers to a value array and a parallel type-code array,
 *  together with the argument count. The class never frees these arrays.
 */
class TVMArgs {
 public:
  /*! \brief the argument values */
  const TVMValue* values;
  /*! \brief the type code of each entry in values */
  const int* type_codes;
  /*! \brief how many arguments there are */
  int num_args;
  /*!
   * \brief Build an argument view from its three raw components.
   * \param values The argument values
   * \param type_codes The argument type codes
   * \param num_args number of arguments.
   */
  TVMArgs(const TVMValue* values, const int* type_codes, int num_args)
      : values(values), type_codes(type_codes), num_args(num_args) {}
  /*! \return size of the arguments */
  inline int size() const;
  /*!
   * \brief Fetch one argument by position.
   * \param i the index.
   * \return the ith argument.
   */
  inline TVMArgValue operator[](int i) const;
};

/*!
 * \brief Convert type code to its name
 * \param type_code The type code .
 * \return The name of type code.
 */
inline const char* TypeCode2Str(int type_code);

/*!
 * \brief convert a string to TVM type.
 * \param s The string to be converted.
 * \return The corresponding tvm type.
 */
inline TVMType String2TVMType(std::string s);

379 380 381 382 383 384 385
/*!
 * \brief convert a TVM type to string.
 * \param t The type to be converted.
 * \return The corresponding tvm type in string.
 */
inline std::string TVMType2String(TVMType t);

386 387 388 389 390 391
// Macro to check a type code. On mismatch it reports both the expected and
// the actual type names. Use it as a statement: TVM_CHECK_TYPE_CODE(a, b);
// The last line deliberately has no trailing backslash, so the definition
// cannot accidentally absorb whatever source line follows it.
#define TVM_CHECK_TYPE_CODE(CODE, T)                           \
  CHECK_EQ(CODE, T) << " expected "                            \
  << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)

/*!
 * \brief Type traits to mark if a class is tvm extension type.
 *
 * To enable an extension type in C++, specialize this trait and then
 * register the type via the macro TVM_REGISTER_EXT_TYPE(TypeName).
 *
 * Extension class can be passed and returned via PackedFunc in all tvm runtime.
 * Internally extension class is stored as T*.
 *
 * The primary template reports code 0, meaning "not an extension type".
 *
 * \tparam T the typename
 */
template<typename T>
struct extension_type_info {
  static const int code = 0;
};

/*!
408
 * \brief Runtime function table about extension type.
409
 */
410 411
class ExtTypeVTable {
 public:
412 413 414 415
  /*! \brief function to be called to delete a handle */
  void (*destroy)(void* handle);
  /*! \brief function to be called when clone a handle */
  void* (*clone)(void* handle);
416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432
  /*!
   * \brief Register type
   * \tparam T The type to be register.
   * \return The registered vtable.
   */
  template <typename T>
  static inline ExtTypeVTable* Register_();
  /*!
   * \brief Get a vtable based on type code.
   * \param type_code The type code
   * \return The registered vtable.
   */
  TVM_DLL static ExtTypeVTable* Get(int type_code);

 private:
  // Internal registration function.
  TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
433 434 435
};

/*!
436 437 438 439 440 441
 * \brief Internal base class to
 *  handle conversion to POD values.
 */
class TVMPODValue_ {
 public:
  operator double() const {
442 443 444 445 446 447
    // Allow automatic conversion from int to float
    // This avoids errors when user pass in int from
    // the frontend while the API expects a float.
    if (type_code_ == kDLInt) {
      return static_cast<double>(value_.v_int64);
    }
448
    TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
449 450 451
    return value_.v_float64;
  }
  operator int64_t() const {
452
    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
453 454 455
    return value_.v_int64;
  }
  operator uint64_t() const {
456
    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
457 458 459
    return value_.v_int64;
  }
  operator int() const {
460
    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
461 462 463 464 465
    CHECK_LE(value_.v_int64,
             std::numeric_limits<int>::max());
    return static_cast<int>(value_.v_int64);
  }
  operator bool() const {
466
    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
467 468 469 470 471 472 473 474
    return value_.v_int64 != 0;
  }
  operator void*() const {
    if (type_code_ == kNull) return nullptr;
    if (type_code_ == kArrayHandle) return value_.v_handle;
    TVM_CHECK_TYPE_CODE(type_code_, kHandle);
    return value_.v_handle;
  }
475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490
  operator DLTensor*() const {
    if (type_code_ == kArrayHandle ||
        type_code_ == kNDArrayContainer) {
      return static_cast<DLTensor*>(value_.v_handle);
    } else {
      if (type_code_ == kNull) return nullptr;
      LOG(FATAL) << "Expected "
                 << "DLTensor* or NDArray but get "
                 << TypeCode2Str(type_code_);
      return nullptr;
    }
  }
  operator NDArray() const {
    if (type_code_ == kNull) return NDArray();
    TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
    return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
491
  }
492 493
  operator Object() const {
    if (type_code_ == kNull) return Object();
494
    TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
495 496
    return Object(static_cast<ObjectCell*>(value_.v_handle));
  }
497 498 499 500
  operator TVMContext() const {
    TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
    return value_.v_ctx;
  }
501 502 503 504 505 506 507 508 509
  template<typename TNDArray,
           typename = typename std::enable_if<
           std::is_base_of<NDArray, TNDArray>::value>::type>
  TNDArray AsNDArray() const {
    if (type_code_ == kNull) return TNDArray(nullptr);
    auto *container = static_cast<NDArray::Container*>(value_.v_handle);
    CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
    return TNDArray(container);
  }
510 511
  template<typename TExtension>
  const TExtension& AsExtension() const {
512 513
    CHECK_LT(type_code_, kExtEnd);
    return static_cast<TExtension*>(value_.v_handle)[0];
514
  }
515 516 517 518 519 520 521 522 523 524 525 526
  int type_code() const {
    return type_code_;
  }
  /*!
   * \brief return handle as specific pointer type.
   * \tparam T the data type.
   * \return The pointer type.
   */
  template<typename T>
  T* ptr() const {
    return static_cast<T*>(value_.v_handle);
  }
527 528 529 530 531 532 533 534

 protected:
  friend class TVMArgsSetter;
  friend class TVMRetValue;
  TVMPODValue_() : type_code_(kNull) {}
  TVMPODValue_(TVMValue value, int type_code)
      : value_(value), type_code_(type_code) {}

535 536 537 538 539 540 541 542 543 544 545 546 547 548
  /*! \brief The value */
  TVMValue value_;
  /*! \brief the type code */
  int type_code_;
};

/*!
 * \brief A single argument value to PackedFunc.
 *  Containing both type_code and TVMValue
 *
 *  Provides utilities to do type cast into other types.
 */
class TVMArgValue : public TVMPODValue_ {
 public:
549 550
  /*! \brief default constructor */
  TVMArgValue() {}
551 552 553 554 555 556 557 558 559 560 561 562 563 564 565
  /*!
   * \brief constructor
   * \param value of the function
   * \param type_code The type code.
   */
  TVMArgValue(TVMValue value, int type_code)
      : TVMPODValue_(value, type_code) {
  }
  // reuse converter from parent
  using TVMPODValue_::operator double;
  using TVMPODValue_::operator int64_t;
  using TVMPODValue_::operator uint64_t;
  using TVMPODValue_::operator int;
  using TVMPODValue_::operator bool;
  using TVMPODValue_::operator void*;
566 567
  using TVMPODValue_::operator DLTensor*;
  using TVMPODValue_::operator NDArray;
568
  using TVMPODValue_::operator TVMContext;
569
  using TVMPODValue_::operator Object;
570

571 572
  // conversion operator.
  operator std::string() const {
573 574
    if (type_code_ == kTVMType) {
      return TVMType2String(operator TVMType());
575 576 577 578 579 580
    } else if (type_code_ == kBytes) {
      TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
      return std::string(arr->data, arr->size);
    } else {
      TVM_CHECK_TYPE_CODE(type_code_, kStr);
      return std::string(value_.v_str);
581
    }
582 583 584 585 586
  }
  operator TVMType() const {
    if (type_code_ == kStr) {
      return String2TVMType(operator std::string());
    }
587 588 589 590 591 592
    // None type
    if (type_code_ == kNull) {
      TVMType t;
      t.code = kHandle; t.bits = 0; t.lanes = 0;
      return t;
    }
593 594 595 596
    TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
    return value_.v_type;
  }
  operator PackedFunc() const {
597
    if (type_code_ == kNull) return PackedFunc();
598 599 600
    TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
    return *ptr<PackedFunc>();
  }
601 602 603 604
  template<typename FType>
  operator TypedPackedFunc<FType>() const {
    return TypedPackedFunc<FType>(operator PackedFunc());
  }
605 606 607 608
  operator Module() const {
    TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
    return *ptr<Module>();
  }
609 610 611
  const TVMValue& value() const {
    return value_;
  }
612 613 614 615
  // Deferred extension handler.
  template<typename TNodeRef>
  inline TNodeRef AsNodeRef() const;
  template<typename T,
616
           typename = typename std::enable_if<
617
           std::is_class<T>::value>::type>
618
  inline operator T() const;
619 620 621 622
  template<typename TNodeRef,
           typename = typename std::enable_if<
             std::is_class<TNodeRef>::value>::type>
  inline bool IsNodeType() const;
623 624
  inline operator tvm::DataType() const;
  inline operator tvm::Expr() const;
625
  inline operator tvm::Integer() const;
626
  // get internal node ptr, if it is node
627
  inline NodePtr<Node>& node_sptr();
628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647
};

/*!
 * \brief Return Value container,
 *  Unlike TVMArgValue, which only holds reference and do not delete
 *  the underlying container during destruction.
 *
 *  TVMRetValue holds value and will manage the underlying containers
 *  when it stores a complicated data type.
 */
class TVMRetValue : public TVMPODValue_ {
 public:
  /*! \brief default constructor */
  TVMRetValue() {}
  /*!
   * \brief move constructor from anoter return value.
   * \param other The other return value.
   */
  TVMRetValue(TVMRetValue&& other)
      : TVMPODValue_(other.value_, other.type_code_) {
648 649
    other.value_.v_handle = nullptr;
    other.type_code_ = kNull;
650 651 652 653 654 655 656 657 658 659 660 661
  }
  /*! \brief destructor */
  ~TVMRetValue() {
    this->Clear();
  }
  // reuse converter from parent
  using TVMPODValue_::operator double;
  using TVMPODValue_::operator int64_t;
  using TVMPODValue_::operator uint64_t;
  using TVMPODValue_::operator int;
  using TVMPODValue_::operator bool;
  using TVMPODValue_::operator void*;
662
  using TVMPODValue_::operator DLTensor*;
663
  using TVMPODValue_::operator TVMContext;
664
  using TVMPODValue_::operator NDArray;
665
  using TVMPODValue_::operator Object;
666
  TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
667 668 669 670
    this->Assign(other);
  }
  // conversion operators
  operator std::string() const {
671 672
    if (type_code_ == kTVMType) {
      return TVMType2String(operator TVMType());
673 674
    } else if (type_code_ == kBytes) {
      return *ptr<std::string>();
675
    }
676 677 678 679 680 681 682 683 684 685 686
    TVM_CHECK_TYPE_CODE(type_code_, kStr);
    return *ptr<std::string>();
  }
  operator TVMType() const {
    if (type_code_ == kStr) {
      return String2TVMType(operator std::string());
    }
    TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
    return value_.v_type;
  }
  operator PackedFunc() const {
687
    if (type_code_ == kNull) return PackedFunc();
688 689 690
    TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
    return *ptr<PackedFunc>();
  }
691 692 693 694
  template<typename FType>
  operator TypedPackedFunc<FType>() const {
    return TypedPackedFunc<FType>(operator PackedFunc());
  }
695 696 697 698
  operator Module() const {
    TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
    return *ptr<Module>();
  }
699 700 701 702 703 704 705 706 707
  // Assign operators
  TVMRetValue& operator=(TVMRetValue&& other) {
    this->Clear();
    value_ = other.value_;
    type_code_ = other.type_code_;
    other.type_code_ = kNull;
    return *this;
  }
  TVMRetValue& operator=(double value) {
708
    this->SwitchToPOD(kDLFloat);
709 710 711 712 713 714 715 716 717 718 719 720 721 722
    value_.v_float64 = value;
    return *this;
  }
  TVMRetValue& operator=(std::nullptr_t value) {
    this->SwitchToPOD(kNull);
    value_.v_handle = value;
    return *this;
  }
  TVMRetValue& operator=(void* value) {
    this->SwitchToPOD(kHandle);
    value_.v_handle = value;
    return *this;
  }
  TVMRetValue& operator=(int64_t value) {
723
    this->SwitchToPOD(kDLInt);
724 725 726 727
    value_.v_int64 = value;
    return *this;
  }
  TVMRetValue& operator=(int value) {
728
    this->SwitchToPOD(kDLInt);
729 730 731
    value_.v_int64 = value;
    return *this;
  }
732 733 734 735 736
  TVMRetValue& operator=(TVMContext value) {
    this->SwitchToPOD(kTVMContext);
    value_.v_ctx = value;
    return *this;
  }
737 738 739 740 741 742
  TVMRetValue& operator=(TVMType t) {
    this->SwitchToPOD(kTVMType);
    value_.v_type = t;
    return *this;
  }
  TVMRetValue& operator=(bool value) {
743
    this->SwitchToPOD(kDLInt);
744 745 746 747 748 749 750
    value_.v_int64 = value;
    return *this;
  }
  TVMRetValue& operator=(std::string value) {
    this->SwitchToClass(kStr, value);
    return *this;
  }
751 752 753 754
  TVMRetValue& operator=(TVMByteArray value) {
    this->SwitchToClass(kBytes, std::string(value.data, value.size));
    return *this;
  }
755 756 757 758 759 760 761
  TVMRetValue& operator=(NDArray other) {
    this->Clear();
    type_code_ = kNDArrayContainer;
    value_.v_handle = other.data_;
    other.data_ = nullptr;
    return *this;
  }
762 763
  TVMRetValue& operator=(Object other) {
    this->Clear();
764
    type_code_ = kObjectCell;
765 766 767 768
    value_.v_handle = other.ptr_.data_;
    other.ptr_.data_ = nullptr;
    return *this;
  }
769 770 771 772
  TVMRetValue& operator=(PackedFunc f) {
    this->SwitchToClass(kFuncHandle, f);
    return *this;
  }
773 774 775 776
  template<typename FType>
  TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
    return operator=(f.packed());
  }
777 778 779 780
  TVMRetValue& operator=(Module m) {
    this->SwitchToClass(kModuleHandle, m);
    return *this;
  }
781 782 783 784
  TVMRetValue& operator=(const TVMRetValue& other) {  // NOLINT(*0
    this->Assign(other);
    return *this;
  }
785
  TVMRetValue& operator=(const TVMArgValue& other) {
786 787 788
    this->Assign(other);
    return *this;
  }
789 790
  template<typename T,
           typename = typename std::enable_if<
791
             extension_type_info<T>::code != 0>::type>
792 793
  TVMRetValue& operator=(const T& other) {
    this->SwitchToClass<T>(
794
        extension_type_info<T>::code, other);
795 796
    return *this;
  }
797 798 799 800 801 802 803 804 805 806 807 808
  /*!
   * \brief Move the value back to front-end via C API.
   *  This marks the current container as null.
   *  The managed resources is moved to front-end and
   *  the front end should take charge in managing them.
   *
   * \param ret_value The return value.
   * \param ret_type_code The return type code.
   */
  void MoveToCHost(TVMValue* ret_value,
                   int* ret_type_code) {
    // cannot move str; need specially handle.
809
    CHECK(type_code_ != kStr && type_code_ != kBytes);
810 811 812 813
    *ret_value = value_;
    *ret_type_code = type_code_;
    type_code_ = kNull;
  }
814 815 816 817
  /*! \return The value field, if the data is POD */
  const TVMValue& value() const {
    CHECK(type_code_ != kNodeHandle &&
          type_code_ != kFuncHandle &&
818
          type_code_ != kModuleHandle &&
819 820 821
          type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
    return value_;
  }
822
  // NodeRef related extenstions: in tvm/packed_func_ext.h
823 824 825 826 827 828
  template<typename T,
           typename = typename std::enable_if<
             std::is_class<T>::value>::type>
  inline operator T() const;
  template<typename TNodeRef>
  inline TNodeRef AsNodeRef() const;
829
  inline TVMRetValue& operator=(const NodeRef& other);
830
  inline TVMRetValue& operator=(const NodePtr<Node>& other);
831
  // type related
832 833
  inline operator tvm::DataType() const;
  inline TVMRetValue& operator=(const tvm::DataType& other);
834 835 836 837 838

 private:
  template<typename T>
  void Assign(const T& other) {
    switch (other.type_code()) {
839
      case kStr: {
840 841 842
        SwitchToClass<std::string>(kStr, other);
        break;
      }
843 844 845 846
      case kBytes: {
        SwitchToClass<std::string>(kBytes, other);
        break;
      }
847 848 849 850
      case kFuncHandle: {
        SwitchToClass<PackedFunc>(kFuncHandle, other);
        break;
      }
851
      case kModuleHandle: {
852
        SwitchToClass<Module>(kModuleHandle, other);
853 854
        break;
      }
855 856 857 858
      case kNDArrayContainer: {
        *this = other.operator NDArray();
        break;
      }
859
      case kNodeHandle: {
860 861
        SwitchToClass<NodePtr<Node> >(
            kNodeHandle, *other.template ptr<NodePtr<Node> >());
862 863
        break;
      }
864
      case kObjectCell: {
865 866 867
        *this = other.operator Object();
        break;
      }
868
      default: {
869 870 871 872
        if (other.type_code() < kExtBegin) {
          SwitchToPOD(other.type_code());
          value_ = other.value_;
        } else {
873 874 875
#if TVM_RUNTIME_HEADER_ONLY
          LOG(FATAL) << "Header only mode do not support ext type";
#else
876 877 878 879 880
          this->Clear();
          type_code_ = other.type_code();
          value_.v_handle =
              (*(ExtTypeVTable::Get(other.type_code())->clone))(
                  other.value().v_handle);
881
#endif
882
        }
883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908
        break;
      }
    }
  }
  // get the internal container.
  void SwitchToPOD(int type_code) {
    if (type_code_ != type_code) {
      this->Clear();
      type_code_ = type_code;
    }
  }
  template<typename T>
  void SwitchToClass(int type_code, T v) {
    if (type_code_ != type_code) {
      this->Clear();
      type_code_ = type_code;
      value_.v_handle = new T(v);
    } else {
      *static_cast<T*>(value_.v_handle) = v;
    }
  }
  void Clear() {
    if (type_code_ == kNull) return;
    switch (type_code_) {
      case kStr: delete ptr<std::string>(); break;
      case kFuncHandle: delete ptr<PackedFunc>(); break;
909
      case kModuleHandle: delete ptr<Module>(); break;
910
      case kNodeHandle: delete ptr<NodePtr<Node> >(); break;
911 912 913 914
      case kNDArrayContainer: {
        static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
        break;
      }
915
      case kObjectCell: {
916 917 918
        static_cast<ObjectCell*>(value_.v_handle)->DecRef();
        break;
      }
919
    }
920
    if (type_code_ > kExtBegin) {
921 922 923
#if TVM_RUNTIME_HEADER_ONLY
          LOG(FATAL) << "Header only mode do not support ext type";
#else
924
      (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
925
#endif
926
    }
927 928 929 930 931 932 933
    type_code_ = kNull;
  }
};

// implementation details
inline const char* TypeCode2Str(int type_code) {
  switch (type_code) {
934 935 936
    case kDLInt: return "int";
    case kDLUInt: return "uint";
    case kDLFloat: return "float";
937
    case kStr: return "str";
938
    case kBytes: return "bytes";
939
    case kHandle: return "handle";
940 941 942 943
    case kNull: return "NULL";
    case kNodeHandle: return "NodeHandle";
    case kArrayHandle: return "ArrayHandle";
    case kTVMType: return "TVMType";
944
    case kTVMContext: return "TVMContext";
945
    case kFuncHandle: return "FunctionHandle";
946
    case kModuleHandle: return "ModuleHandle";
947
    case kNDArrayContainer: return "NDArrayContainer";
948
    case kObjectCell: return "ObjectCell";
949 950 951 952 953
    default: LOG(FATAL) << "unknown type_code="
                        << static_cast<int>(type_code); return "";
  }
}

nhynes committed
954
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
955
inline std::ostream& operator<<(std::ostream& os, TVMType t) {  // NOLINT(*)
956 957 958
  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
    os << "bool"; return os;
  }
959
  if (t.code < kCustomBegin) {
960
    os << TypeCode2Str(t.code);
961 962
  } else {
    os << "custom[" << GetCustomTypeName(t.code) << "]";
963
  }
964 965
  if (t.code == kHandle) return os;
  os << static_cast<int>(t.bits);
966 967 968 969 970
  if (t.lanes != 1) {
    os << 'x' << static_cast<int>(t.lanes);
  }
  return os;
}
971

nhynes committed
972
#endif
973 974

/*!
 * \brief Convert a TVMType to a string such as "int32", "float32x4" or "bool".
 * \param t The type.
 * \return "" when t.bits == 0 (the "None" type), otherwise the same text the
 *  iostream operator<< for TVMType produces.
 */
inline std::string TVMType2String(TVMType t) {
  if (t.bits == 0) return "";
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
  std::ostringstream os;
  os << t;
  return os.str();
#else
  // iostream-free fallback; must stay in sync with operator<<(ostream, TVMType).
  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
    return "bool";
  }
  // Previously used without a declaration, which broke this build mode.
  std::string repr;
  if (t.code < kCustomBegin) {
    repr += TypeCode2Str(t.code);
  } else {
    repr += "custom[" + GetCustomTypeName(t.code) + "]";
  }
  if (t.code == kHandle) return repr;
  repr += std::to_string(static_cast<int>(t.bits));
  if (t.lanes != 1) {
    repr += "x" + std::to_string(static_cast<int>(t.lanes));
  }
  return repr;
#endif
}

/*!
 * \brief Parse a type string such as "int32", "uint8", "float16x4",
 *  "handle", "bool" or a "custom[...]" type.
 * \param s The string to parse. An empty string yields the "None" type
 *  (bits = 0, lanes = 0, code = kHandle).
 * \return The parsed type. Omitted (or zero) bits default to 32 (64 for
 *  "handle"); omitted lanes default to 1. Unknown names, trailing garbage
 *  and bit/lane counts that do not fit the TVMType fields fail a CHECK.
 */
inline TVMType String2TVMType(std::string s) {
  TVMType t;
  // handle None type
  if (s.length() == 0) {
    t.bits = 0; t.lanes = 0; t.code = kHandle;
    return t;
  }
  t.bits = 32; t.lanes = 1;
  const char* scan;
  // compare() tests the prefix without allocating a temporary substring.
  if (s.compare(0, 3, "int") == 0) {
    t.code = kDLInt;  scan = s.c_str() + 3;
  } else if (s.compare(0, 4, "uint") == 0) {
    t.code = kDLUInt; scan = s.c_str() + 4;
  } else if (s.compare(0, 5, "float") == 0) {
    t.code = kDLFloat; scan = s.c_str() + 5;
  } else if (s.compare(0, 6, "handle") == 0) {
    t.code = kHandle;
    t.bits = 64;  // handle uses 64 bit by default.
    scan = s.c_str() + 6;
  } else if (s == "bool") {
    t.code = kDLUInt;
    t.bits = 1;
    t.lanes = 1;
    return t;
  } else if (s.compare(0, 6, "custom") == 0) {
    t.code = ParseCustomDatatype(s, &scan);
  } else {
    scan = s.c_str();
    LOG(FATAL) << "unknown type " << s;
  }
  char* xdelim;  // emulate sscanf("%ux%u", bits, lanes)
  uint64_t raw_bits = strtoul(scan, &xdelim, 10);
  // Narrowing used to wrap silently (e.g. "int256" became the default int32).
  CHECK_LE(raw_bits, 255U) << "type bits out of range in " << s;
  uint8_t bits = static_cast<uint8_t>(raw_bits);
  if (bits != 0) t.bits = bits;
  char* endpt = xdelim;
  if (*xdelim == 'x') {
    uint64_t raw_lanes = strtoul(xdelim + 1, &endpt, 10);
    CHECK_LE(raw_lanes, 65535U) << "type lanes out of range in " << s;
    t.lanes = static_cast<uint16_t>(raw_lanes);
  }
  CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
  return t;
}

/*!
 * \brief Access the i-th packed argument.
 * \param i Zero-based index; must satisfy 0 <= i < num_args.
 * \return A TVMArgValue constructed from values[i] and type_codes[i].
 */
inline TVMArgValue TVMArgs::operator[](int i) const {
  // Without this, a negative index slips past CHECK_LT and reads out of bounds.
  CHECK_GE(i, 0) << "negative argument index " << i << " requested.";
  CHECK_LT(i, num_args)
      << "not enough argument passed, "
      << num_args << " passed"
      << " but request arg[" << i << "].";
  return TVMArgValue(values[i], type_codes[i]);
}

/*! \return The number of packed arguments (num_args). */
inline int TVMArgs::size() const { return num_args; }

/*!
 * \brief Invoke the stored body with already-packed arguments.
 * \param args The packed arguments.
 * \param rv Slot the callee writes its return value into.
 */
inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
  body_(args, rv);
}

/*! \return A copy of the type-erased callable held by this PackedFunc. */
inline PackedFunc::FType PackedFunc::body() const { return body_; }

// internal namespace
namespace detail {

/*!
 * \brief Recursive worker for for_each: calls f(I, value) on the first
 *  argument, then recurses on the rest with index I + 1.
 * \tparam stop True once no arguments remain (selects the terminal case).
 * \tparam I Zero-based position of the argument handled at this level.
 * \tparam F Visitor type, invoked as f(std::size_t, T&&).
 */
template<bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
  template<typename T, typename ...Args>
  static void run(const F& f, T&& value, Args&&... args) {  // NOLINT(*)
    f(I, std::forward<T>(value));
    for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
        ::run(f, std::forward<Args>(args)...);
  }
};

/*! \brief Terminal case: every argument has been visited. */
template<std::size_t I, typename F>
struct for_each_dispatcher<true, I, F>  {
  static void run(const F&) {}  // NOLINT(*)
};

/*!
 * \brief Call f(i, arg_i) for each argument in order, forwarding each one.
 * \param f The visitor.
 * \param args The arguments to visit (possibly none).
 */
template<typename F, typename ...Args>
inline void for_each(const F& f, Args&&... args) {  // NOLINT(*)
  for_each_dispatcher<sizeof...(Args) == 0, 0, F>
      ::run(f, std::forward<Args>(args)...);
}
}  // namespace detail

/* \brief argument settter to PackedFunc */
class TVMArgsSetter {
 public:
1089 1090
  TVMArgsSetter(TVMValue* values, int* type_codes)
      : values_(values), type_codes_(type_codes) {}
1091 1092
  // setters for POD types
  template<typename T,
1093 1094
           typename = typename std::enable_if<
             std::is_integral<T>::value>::type>
1095 1096
  void operator()(size_t i, T value) const {
    values_[i].v_int64 = static_cast<int64_t>(value);
1097
    type_codes_[i] = kDLInt;
1098 1099 1100 1101 1102
  }
  void operator()(size_t i, uint64_t value) const {
    values_[i].v_int64 = static_cast<int64_t>(value);
    CHECK_LE(value,
             static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
1103
    type_codes_[i] = kDLInt;
1104 1105 1106
  }
  void operator()(size_t i, double value) const {
    values_[i].v_float64 = value;
1107
    type_codes_[i] = kDLFloat;
1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120
  }
  void operator()(size_t i, std::nullptr_t value) const {
    values_[i].v_handle = value;
    type_codes_[i] = kNull;
  }
  void operator()(size_t i, const TVMArgValue& value) const {
    values_[i] = value.value_;
    type_codes_[i] = value.type_code_;
  }
  void operator()(size_t i, void* value) const {
    values_[i].v_handle = value;
    type_codes_[i] = kHandle;
  }
1121
  void operator()(size_t i, DLTensor* value) const {
1122 1123 1124
    values_[i].v_handle = value;
    type_codes_[i] = kArrayHandle;
  }
1125 1126 1127 1128
  void operator()(size_t i, TVMContext value) const {
    values_[i].v_ctx = value;
    type_codes_[i] = kTVMContext;
  }
1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139
  void operator()(size_t i, TVMType value) const {
    values_[i].v_type = value;
    type_codes_[i] = kTVMType;
  }
  void operator()(size_t i, const char* value) const {
    values_[i].v_str = value;
    type_codes_[i] = kStr;
  }
  // setters for container type
  // They must be reference(instead of const ref)
  // to make sure they are alive in the tuple(instead of getting converted)
1140
  void operator()(size_t i, const std::string& value) const {  // NOLINT(*)
1141 1142 1143
    values_[i].v_str = value.c_str();
    type_codes_[i] = kStr;
  }
1144 1145
  void operator()(size_t i, const TVMByteArray& value) const {  // NOLINT(*)
    values_[i].v_handle = const_cast<TVMByteArray*>(&value);
1146 1147
    type_codes_[i] = kBytes;
  }
1148 1149
  void operator()(size_t i, const PackedFunc& value) const {  // NOLINT(*)
    values_[i].v_handle = const_cast<PackedFunc*>(&value);
1150 1151
    type_codes_[i] = kFuncHandle;
  }
1152 1153 1154 1155
  template<typename FType>
  void operator()(size_t i, const TypedPackedFunc<FType>& value) const {  // NOLINT(*)
    operator()(i, value.packed());
  }
1156 1157
  void operator()(size_t i, const Module& value) const {  // NOLINT(*)
    values_[i].v_handle = const_cast<Module*>(&value);
1158 1159
    type_codes_[i] = kModuleHandle;
  }
1160 1161 1162 1163
  void operator()(size_t i, const NDArray& value) const {  // NOLINT(*)
    values_[i].v_handle = value.data_;
    type_codes_[i] = kNDArrayContainer;
  }
1164
  void operator()(size_t i, const TVMRetValue& value) const {  // NOLINT(*)
1165 1166 1167 1168
    if (value.type_code() == kStr) {
      values_[i].v_str = value.ptr<std::string>()->c_str();
      type_codes_[i] = kStr;
    } else {
1169
      CHECK_NE(value.type_code(), kBytes) << "not handled.";
1170 1171 1172 1173
      values_[i] = value.value_;
      type_codes_[i] = value.type_code();
    }
  }
1174 1175 1176
  // extension
  template<typename T,
           typename = typename std::enable_if<
1177
             extension_type_info<T>::code != 0>::type>
1178
  inline void operator()(size_t i, const T& value) const;
1179
  // NodeRef related extenstions: in tvm/packed_func_ext.h
1180
  inline void operator()(size_t i, const NodeRef& other) const;  // NOLINT(*)
1181
  inline void operator()(size_t i, const tvm::DataType& t) const;
1182 1183 1184 1185 1186 1187 1188 1189

 private:
  /*! \brief The values fields */
  TVMValue* values_;
  /*! \brief The type code fields */
  int* type_codes_;
};

1190
template<typename... Args>
1191
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
1192
  const int kNumArgs = sizeof...(Args);
1193 1194 1195
  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
  TVMValue values[kArraySize];
  int type_codes[kArraySize];
1196
  detail::for_each(TVMArgsSetter(values, type_codes),
1197
                   std::forward<Args>(args)...);
1198 1199 1200
  TVMRetValue rv;
  body_(TVMArgs(values, type_codes, kNumArgs), &rv);
  return rv;
1201
}
1202

1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267
namespace detail {
/*!
 * \brief Extracts arguments from a TVMArgs pack one by one and, once all are
 *  collected, calls the user function.
 * \tparam R Return type of the user function (void has its own terminal case).
 * \tparam remaining How many arguments are still to be extracted.
 * \tparam next Position in the pack of the next argument to extract.
 * \tparam Fn Type of the user function.
 */
template<typename R, int remaining, int next, typename Fn>
struct unpack_call_dispatcher {
  template<typename ...Collected>
  static void run(const Fn& fn,
                  const TVMArgs& packed,
                  TVMRetValue* result,
                  Collected&&... collected) {
    // Append packed[next] to the collected arguments and recurse.
    unpack_call_dispatcher<R, remaining - 1, next + 1, Fn>
        ::run(fn, packed, result,
              std::forward<Collected>(collected)...,
              packed[next]);
  }
};

/*! \brief Terminal case: call fn and store R(result) into *result. */
template<typename R, int next, typename Fn>
struct unpack_call_dispatcher<R, 0, next, Fn> {
  template<typename ...Collected>
  static void run(const Fn& fn,
                  const TVMArgs& /*packed*/,
                  TVMRetValue* result,
                  Collected&&... collected) {
    *result = R(fn(std::forward<Collected>(collected)...));
  }
};

/*! \brief Terminal case for void functions: call fn, leave *result alone. */
template<int next, typename Fn>
struct unpack_call_dispatcher<void, 0, next, Fn> {
  template<typename ...Collected>
  static void run(const Fn& fn,
                  const TVMArgs& /*packed*/,
                  TVMRetValue* /*result*/,
                  Collected&&... collected) {
    fn(std::forward<Collected>(collected)...);
  }
};

/*!
 * \brief Call f with the first nargs entries of args, storing its result
 *  into *rv unless R is void.
 */
template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
  unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}

/*! \brief Call pf with args and convert its return value to R. */
template<typename R, typename ...Args>
inline R call_packed(const PackedFunc& pf, Args&& ...args) {
  return R(pf(std::forward<Args>(args)...));
}

/*! \brief Call pf and return its result converted to R. */
template<typename R>
struct typed_packed_call_dispatcher {
  template<typename ...Args>
  static inline R run(const PackedFunc& pf, Args&& ...args) {
    return pf(std::forward<Args>(args)...);
  }
};

/*! \brief void variant: call pf and discard its result. */
template<>
struct typed_packed_call_dispatcher<void> {
  template<typename ...Args>
  static inline void run(const PackedFunc& pf, Args&& ...args) {
    pf(std::forward<Args>(args)...);
  }
};
}  // namespace detail

/*!
 * \brief Wrap an existing PackedFunc.
 *  The argument is taken by value and moved into the member.
 */
template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
  : packed_(std::move(packed)) {}

/*! \brief Construct from a return value via its PackedFunc conversion. */
template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMRetValue& value)
    : packed_(value.operator PackedFunc()) {}

/*! \brief Construct from an argument value via its PackedFunc conversion. */
template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value)
    : packed_(value.operator PackedFunc()) {}

template<typename R, typename ...Args>
1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292
template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
  packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
      detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
    });
}

/*!
 * \brief Call the wrapped PackedFunc with typed arguments.
 * \return The callee's result converted to R (nothing when R is void).
 */
template<typename R, typename ...Args>
inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
  return detail::typed_packed_call_dispatcher<R>
      ::run(packed_, std::forward<Args>(args)...);
}

// extension and node type handling
namespace detail {
/*!
 * \brief Selects how a TVMArgValue/TVMRetValue (TSrc) is converted to T.
 * \tparam is_ext True when extension_type_info<T>::code != 0.
 * \tparam is_nd True when array_type_info<T>::code > 0.
 *  The primary template handles the case where both are false and converts
 *  through AsNodeRef<T>(); other combinations fail its static_assert.
 */
template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast {
  static T Apply(const TSrc* self) {
    static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
    return self->template AsNodeRef<T>();
  }
};

/*! \brief Extension types convert through AsExtension<T>(). */
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true, false> {
  static T Apply(const TSrc* self) {
    return self->template AsExtension<T>();
  }
};

/*! \brief Types with a positive array_type_info code convert through AsNDArray<T>(). */
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
  static T Apply(const TSrc* self) {
    return self->template AsNDArray<T>();
  }
};

}  // namespace detail

/*!
 * \brief Generic conversion to T, dispatched by detail::TVMValueCast on
 *  whether T has a non-zero extension_type_info code or a positive
 *  array_type_info code.
 */
template<typename T, typename>
inline TVMArgValue::operator T() const {
  return detail::
      TVMValueCast<T, TVMArgValue,
                   (extension_type_info<T>::code != 0),
                   (array_type_info<T>::code > 0)>
      ::Apply(this);
}

/*!
 * \brief Generic conversion to T, dispatched by detail::TVMValueCast on
 *  whether T has a non-zero extension_type_info code or a positive
 *  array_type_info code.
 */
template<typename T, typename>
inline TVMRetValue::operator T() const {
  return detail::
      TVMValueCast<T, TVMRetValue,
                   (extension_type_info<T>::code != 0),
                   (array_type_info<T>::code > 0)>
      ::Apply(this);
}

/*!
 * \brief Pass an extension-type argument by address, tagged with
 *  extension_type_info<T>::code. Only the address is stored, so value must
 *  stay alive until the packed call returns.
 */
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
  static_assert(extension_type_info<T>::code != 0,
                "Need to have extesion code");
  type_codes_[i] = extension_type_info<T>::code;
  values_[i].v_handle = const_cast<T*>(&value);
}

// extension type handling
/*!
 * \brief Type-erased copy/free hooks for extension type T; these are the
 *  functions ExtTypeVTable::Register_<T>() installs in the vtable.
 */
template<typename T>
struct ExtTypeInfo {
  /*! \brief Delete the T that handle points to (it must come from new). */
  static void destroy(void* handle) {
    T* const obj = static_cast<T*>(handle);
    delete obj;
  }
  /*! \brief Return a new heap copy of the T that handle points to. */
  static void* clone(void* handle) {
    T& src = *static_cast<T*>(handle);
    T* const copy = new T(src);
    return copy;
  }
};

1356 1357
template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
1358
  const int code = extension_type_info<T>::code;
1359
  static_assert(code != 0,
1360
                "require extension_type_info traits to be declared with non-zero code");
1361 1362 1363 1364
  ExtTypeVTable vt;
  vt.clone = ExtTypeInfo<T>::clone;
  vt.destroy = ExtTypeInfo<T>::destroy;
  return ExtTypeVTable::RegisterInternal(code, vt);
1365
}
1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379

// Module::GetFunction is defined here, not next to the Module declaration,
// because it needs the complete PackedFunc definition.
/*!
 * \brief Look up name in this module and, when query_imports is true, in
 *  each direct import (node_->imports_) in order; the first hit wins.
 * \return The function found, or a null PackedFunc when there is none.
 */
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
  PackedFunc found = node_->GetFunction(name, node_);
  if (found != nullptr || !query_imports) return found;
  for (const Module& imported : node_->imports_) {
    found = imported.node_->GetFunction(name, imported.node_);
    if (found != nullptr) return found;
  }
  return found;
}
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_PACKED_FUNC_H_