Unverified Commit b171cf1d by Tianqi Chen Committed by GitHub

[REFACTOR] Polish runtime (#4729)

- Remove operator bool from base object ref macro
  - Raitionale: operator bool can be dangerous for sub-classes
    that also overloads other operators(e.g. ==).
  - If bool is still needed, use explicit operator bool.
- Use absolute include when necessary
- Move type related util to data_type
- Isolate stackvm code from compiler
parent eaa23800
......@@ -30,7 +30,6 @@
#include <vector>
#include <utility>
#include "expr.h"
#include "runtime/util.h"
namespace tvm {
namespace ir {
......@@ -1677,6 +1676,25 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
*/
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
} // namespace intrinsic
/*!
......
......@@ -49,15 +49,7 @@ class BaseExprNode : public Object {
*/
class BaseExpr : public ObjectRef {
public:
/*! \brief Cosntructor */
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
};
/*!
......@@ -100,13 +92,6 @@ class PrimExprNode : public BaseExprNode {
*/
class PrimExpr : public BaseExpr {
public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
......@@ -127,8 +112,8 @@ class PrimExpr : public BaseExpr {
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
/*! \brief The container type. */
using ContainerType = PrimExprNode;
TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
};
/*!
......@@ -157,28 +142,13 @@ class IntImmNode : public PrimExprNode {
class IntImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
IntImm() {}
/*!
* \brief constructor from node.
*/
explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL IntImm(DataType dtype, int64_t value);
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = IntImmNode;
TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
};
/*!
......@@ -207,28 +177,13 @@ class FloatImmNode : public PrimExprNode {
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;
TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};
/*!
......
......@@ -28,7 +28,7 @@
#ifndef TVM_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h"
#include <tvm/runtime/c_runtime_api.h>
#ifdef __cplusplus
extern "C" {
......
......@@ -28,7 +28,6 @@
#include <dmlc/logging.h>
#include <type_traits>
namespace tvm {
namespace runtime {
/*!
......@@ -233,6 +232,24 @@ inline int GetVectorBytes(DataType dtype) {
return data_bits / 8;
}
/*!
* \brief Check whether type matches the given spec.
* \param t The type
* \param code The type code.
* \param bits The number of bits to be matched.
* \param lanes The number of lanes in the type.
*/
inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime
using DataType = runtime::DataType;
......
......@@ -24,9 +24,9 @@
#ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <string>
#include "packed_func.h"
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
......
......@@ -23,10 +23,10 @@
#ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_
#include <tvm/runtime/object.h>
#include <cstdlib>
#include <utility>
#include <type_traits>
#include "object.h"
namespace tvm {
namespace runtime {
......
......@@ -29,7 +29,6 @@
#include <string>
#include <utility>
/*!
* \brief Whether or not use atomic reference counter.
* If the reference counter is not atomic,
......@@ -715,7 +714,6 @@ struct ObjectEqual {
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
/*
......@@ -734,7 +732,6 @@ struct ObjectEqual {
ObjectName* operator->() const { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
/*!
......
......@@ -27,8 +27,8 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "c_runtime_api.h"
#include "ndarray.h"
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h>
namespace dmlc {
namespace serializer {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/util.h
* \brief Useful runtime util.
*/
#ifndef TVM_RUNTIME_UTIL_H_
#define TVM_RUNTIME_UTIL_H_
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief Check whether type matches the given spec.
* \param t The type
* \param code The type code.
* \param bits The number of bits to be matched.
* \param lanes The number of lanes in the type.
*/
inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime
} // namespace tvm
// Forward declare the intrinsic id we need
// in structure fetch to enable stackvm in runtime
namespace tvm {
namespace ir {
namespace intrinsic {
/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
} // namespace intrinsic
} // namespace ir
} // namespace tvm
#endif // TVM_RUNTIME_UTIL_H_
......@@ -32,6 +32,28 @@ namespace codegen {
using namespace ir;
// map struct field kind to runtime variants
// We keep two separate enums to ensure runtime/compiler isolation.
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
switch (val) {
case intrinsic::kArrData: return StackVM::kArrData;
case intrinsic::kArrShape: return StackVM::kArrShape;
case intrinsic::kArrAddr: return StackVM::kArrAddr;
case intrinsic::kArrStrides: return StackVM::kArrStrides;
case intrinsic::kArrNDim: return StackVM::kArrNDim;
case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode;
case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits;
case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes;
case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset;
case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId;
case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType;
case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent;
default: LOG(FATAL) << "Do not know how to map field " << kind;
}
return StackVM::kArrData;
}
StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
......@@ -163,7 +185,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = kind;
code.v_int = MapFieldKind(kind);
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
CHECK_GE(op->args.size(), 5U);
......@@ -431,7 +453,7 @@ void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = op->args[2].as<IntImmNode>()->value;
code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
vm_.code.push_back(code);
} else {
this->Push(ev->value);
......
......@@ -189,7 +189,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
then_for = IRTransform(for_stmt, nullptr, replace_then_case,
{PrimExpr("IfThenElse")});
if (if_stmt.as<IfThenElseNode>()->else_case) {
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case,
{PrimExpr("IfThenElse")});
}
......@@ -221,7 +221,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
for2if_map_[for_stmt.get()].push_back(head);
const IfThenElseNode* if_node = head.as<IfThenElseNode>();
tracker.push(if_node->then_case);
if (if_node->else_case) {
if (if_node->else_case.defined()) {
tracker.push(if_node->else_case);
}
......
......@@ -31,14 +31,6 @@
namespace tvm {
namespace relay {
TensorType ToTensorType(const Type& t) {
if (const auto* tt_node = t.as<TensorTypeNode>()) {
return GetRef<TensorType>(tt_node);
} else {
return TensorType(nullptr);
}
}
bool IdentityRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
......@@ -115,11 +107,11 @@ bool BroadcastRel(const Array<Type>& types,
CHECK_EQ(types.size(), 3);
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2],
ConcreteBroadcast(t0, t1, t0->dtype));
ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
return true;
}
}
......@@ -133,10 +125,11 @@ bool BroadcastCompRel(const Array<Type>& types,
CHECK_EQ(types.size(), 3);
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::DataType::Bool()));
reporter->Assign(types[2],
ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), DataType::Bool()));
return true;
}
}
......
......@@ -749,7 +749,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic r = VisitExpr(op->ref, ll);
if (r->pstatic.defined()) {
PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>());
if (ret) {
if (ret.defined()) {
return ret;
}
}
......
......@@ -22,7 +22,7 @@
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include "gemm_common.h"
extern "C" {
......
......@@ -24,7 +24,7 @@
#pragma once
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <algorithm>
namespace tvm {
......
......@@ -21,7 +21,7 @@
* \file Use external cblas library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include "../cblas/gemm_common.h"
#include "cublas_utils.h"
......
......@@ -21,7 +21,7 @@
* \file Use external cudnn utils function
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include "cudnn_utils.h"
......
......@@ -21,7 +21,7 @@
* \file Use external miopen utils function
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include "miopen_utils.h"
......
......@@ -29,7 +29,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <vector>
#include "../../metal/metal_common.h"
......
......@@ -22,7 +22,7 @@
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include <nnpack.h>
#include "nnpack_utils.h"
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file Use external nnpack library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include <nnpack.h>
#include "nnpack_utils.h"
......
......@@ -23,7 +23,7 @@
#ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
#define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <nnpack.h>
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file External random functions for tensor.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <algorithm>
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file Use external rocblas library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include "rocblas.h"
......
......@@ -22,7 +22,6 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>
......
......@@ -22,7 +22,6 @@
* \file stackvm.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/c_backend_api.h>
#include <algorithm>
#include "stackvm.h"
......@@ -392,50 +391,49 @@ void StackVM::Run(State* s) const {
}
// intrinsics
case TVM_STRUCT_GET: {
using namespace ir;
int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int;
DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle);
switch (kind) {
case intrinsic::kArrData: {
case StackVM::kArrData: {
stack[sp].v_handle = arr[index].data; break;
}
case intrinsic::kArrShape: {
case StackVM::kArrShape: {
stack[sp].v_handle = arr[index].shape; break;
}
case intrinsic::kArrStrides: {
case StackVM::kArrStrides: {
stack[sp].v_handle = arr[index].strides; break;
}
case intrinsic::kArrNDim: {
case StackVM::kArrNDim: {
stack[sp].v_int64 = arr[index].ndim; break;
}
case intrinsic::kArrTypeCode: {
case StackVM::kArrTypeCode: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.code); break;
}
case intrinsic::kArrTypeBits: {
case StackVM::kArrTypeBits: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.bits); break;
}
case intrinsic::kArrTypeLanes: {
case StackVM::kArrTypeLanes: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.lanes); break;
}
case intrinsic::kArrByteOffset: {
case StackVM::kArrByteOffset: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].byte_offset); break;
}
case intrinsic::kArrDeviceId: {
case StackVM::kArrDeviceId: {
stack[sp].v_int64 = arr[index].ctx.device_id; break;
}
case intrinsic::kArrDeviceType: {
case StackVM::kArrDeviceType: {
stack[sp].v_int64 = static_cast<int64_t>(
arr[index].ctx.device_type); break;
}
case intrinsic::kArrAddr: {
case StackVM::kArrAddr: {
stack[sp].v_handle = arr + index; break;
}
case intrinsic::kTVMValueContent: {
case StackVM::kTVMValueContent: {
stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break;
}
default: LOG(FATAL) << "unhandled get " << kind;
......@@ -444,51 +442,50 @@ void StackVM::Run(State* s) const {
break;
}
case TVM_STRUCT_SET: {
using namespace ir;
int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int;
DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle);
switch (kind) {
case intrinsic::kArrData: {
case StackVM::kArrData: {
arr[index].data = stack[sp].v_handle; break;
}
case intrinsic::kArrShape: {
case StackVM::kArrShape: {
arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
break;
}
case intrinsic::kArrStrides: {
case StackVM::kArrStrides: {
arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
break;
}
case intrinsic::kArrNDim: {
case StackVM::kArrNDim: {
arr[index].ndim = static_cast<int>(stack[sp].v_int64);
break;
}
case intrinsic::kArrTypeCode: {
case StackVM::kArrTypeCode: {
arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrTypeBits: {
case StackVM::kArrTypeBits: {
arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrTypeLanes: {
case StackVM::kArrTypeLanes: {
arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrByteOffset: {
case StackVM::kArrByteOffset: {
arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
break;
}
case intrinsic::kArrDeviceId: {
case StackVM::kArrDeviceId: {
arr[index].ctx.device_id = static_cast<int>(stack[sp].v_int64);
break;
}
case intrinsic::kArrDeviceType: {
case StackVM::kArrDeviceType: {
arr[index].ctx.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
break;
}
case intrinsic::kTVMValueContent: {
case StackVM::kTVMValueContent: {
static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break;
}
default: LOG(FATAL) << "unhandled tvm_struct_set " << kind;
......
......@@ -38,6 +38,7 @@ namespace tvm {
namespace runtime {
using runtime::operator<<;
/*!
* \brief A simple stack-based virtual machine program.
*/
......@@ -283,6 +284,25 @@ class StackVM {
*/
TVM_STRUCT_SET
};
/*! \brief The kind of structure field info */
enum StructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
/*! \brief The code structure */
union Code {
OpCode op_code;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment