Unverified Commit b171cf1d by Tianqi Chen Committed by GitHub

[REFACTOR] Polish runtime (#4729)

- Remove operator bool from base object ref macro
  - Raitionale: operator bool can be dangerous for sub-classes
    that also overloads other operators(e.g. ==).
  - If bool is still needed, use explicit operator bool.
- Use absolute include when necessary
- Move type related util to data_type
- Isolate stackvm code from compiler
parent eaa23800
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include "expr.h" #include "expr.h"
#include "runtime/util.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -1677,6 +1676,25 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; ...@@ -1677,6 +1676,25 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
*/ */
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
} // namespace intrinsic } // namespace intrinsic
/*! /*!
......
...@@ -49,15 +49,7 @@ class BaseExprNode : public Object { ...@@ -49,15 +49,7 @@ class BaseExprNode : public Object {
*/ */
class BaseExpr : public ObjectRef { class BaseExpr : public ObjectRef {
public: public:
/*! \brief Cosntructor */ TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
}; };
/*! /*!
...@@ -100,13 +92,6 @@ class PrimExprNode : public BaseExprNode { ...@@ -100,13 +92,6 @@ class PrimExprNode : public BaseExprNode {
*/ */
class PrimExpr : public BaseExpr { class PrimExpr : public BaseExpr {
public: public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*! /*!
* \brief construct from integer. * \brief construct from integer.
* \param value The value to be constructed. * \param value The value to be constructed.
...@@ -127,8 +112,8 @@ class PrimExpr : public BaseExpr { ...@@ -127,8 +112,8 @@ class PrimExpr : public BaseExpr {
DataType dtype() const { DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype; return static_cast<const PrimExprNode*>(get())->dtype;
} }
/*! \brief The container type. */
using ContainerType = PrimExprNode; TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
}; };
/*! /*!
...@@ -157,28 +142,13 @@ class IntImmNode : public PrimExprNode { ...@@ -157,28 +142,13 @@ class IntImmNode : public PrimExprNode {
class IntImm : public PrimExpr { class IntImm : public PrimExpr {
public: public:
/*! /*!
* \brief Constructor
*/
IntImm() {}
/*!
* \brief constructor from node.
*/
explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor. * \brief Constructor.
* \param dtype The data type of the value. * \param dtype The data type of the value.
* \param value The internal value. * \param value The internal value.
*/ */
TVM_DLL IntImm(DataType dtype, int64_t value); TVM_DLL IntImm(DataType dtype, int64_t value);
/*!
* \brief Get pointer to the internal value. TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
* \return the content of the integer.
*/
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = IntImmNode;
}; };
/*! /*!
...@@ -207,28 +177,13 @@ class FloatImmNode : public PrimExprNode { ...@@ -207,28 +177,13 @@ class FloatImmNode : public PrimExprNode {
class FloatImm : public PrimExpr { class FloatImm : public PrimExpr {
public: public:
/*! /*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor. * \brief Constructor.
* \param dtype The data type of the value. * \param dtype The data type of the value.
* \param value The internal value. * \param value The internal value.
*/ */
TVM_DLL FloatImm(DataType dtype, double value); TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container. TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;
}; };
/*! /*!
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#ifndef TVM_RUNTIME_C_BACKEND_API_H_ #ifndef TVM_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_ #define TVM_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h" #include <tvm/runtime/c_runtime_api.h>
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <type_traits> #include <type_traits>
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*! /*!
...@@ -233,6 +232,24 @@ inline int GetVectorBytes(DataType dtype) { ...@@ -233,6 +232,24 @@ inline int GetVectorBytes(DataType dtype) {
return data_bits / 8; return data_bits / 8;
} }
/*!
* \brief Check whether type matches the given spec.
* \param t The type
* \param code The type code.
* \param bits The number of bits to be matched.
* \param lanes The number of lanes in the type.
*/
inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime } // namespace runtime
using DataType = runtime::DataType; using DataType = runtime::DataType;
......
...@@ -24,9 +24,9 @@ ...@@ -24,9 +24,9 @@
#ifndef TVM_RUNTIME_DEVICE_API_H_ #ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_ #define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <string> #include <string>
#include "packed_func.h"
#include "c_runtime_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
#ifndef TVM_RUNTIME_MEMORY_H_ #ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_ #define TVM_RUNTIME_MEMORY_H_
#include <tvm/runtime/object.h>
#include <cstdlib> #include <cstdlib>
#include <utility> #include <utility>
#include <type_traits> #include <type_traits>
#include "object.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include <string> #include <string>
#include <utility> #include <utility>
/*! /*!
* \brief Whether or not use atomic reference counter. * \brief Whether or not use atomic reference counter.
* If the reference counter is not atomic, * If the reference counter is not atomic,
...@@ -715,7 +714,6 @@ struct ObjectEqual { ...@@ -715,7 +714,6 @@ struct ObjectEqual {
const ObjectName* operator->() const { \ const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \ return static_cast<const ObjectName*>(data_.get()); \
} \ } \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
/* /*
...@@ -734,7 +732,6 @@ struct ObjectEqual { ...@@ -734,7 +732,6 @@ struct ObjectEqual {
ObjectName* operator->() const { \ ObjectName* operator->() const { \
return static_cast<ObjectName*>(data_.get()); \ return static_cast<ObjectName*>(data_.get()); \
} \ } \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
/*! /*!
......
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
#include "c_runtime_api.h" #include <tvm/runtime/c_runtime_api.h>
#include "ndarray.h" #include <tvm/runtime/ndarray.h>
namespace dmlc { namespace dmlc {
namespace serializer { namespace serializer {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/util.h
* \brief Useful runtime util.
*/
#ifndef TVM_RUNTIME_UTIL_H_
#define TVM_RUNTIME_UTIL_H_
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief Check whether type matches the given spec.
* \param t The type
* \param code The type code.
* \param bits The number of bits to be matched.
* \param lanes The number of lanes in the type.
*/
inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime
} // namespace tvm
// Forward declare the intrinsic id we need
// in structure fetch to enable stackvm in runtime
namespace tvm {
namespace ir {
namespace intrinsic {
/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
} // namespace intrinsic
} // namespace ir
} // namespace tvm
#endif // TVM_RUNTIME_UTIL_H_
...@@ -32,6 +32,28 @@ namespace codegen { ...@@ -32,6 +32,28 @@ namespace codegen {
using namespace ir; using namespace ir;
// map struct field kind to runtime variants
// We keep two separate enums to ensure runtime/compiler isolation.
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
switch (val) {
case intrinsic::kArrData: return StackVM::kArrData;
case intrinsic::kArrShape: return StackVM::kArrShape;
case intrinsic::kArrAddr: return StackVM::kArrAddr;
case intrinsic::kArrStrides: return StackVM::kArrStrides;
case intrinsic::kArrNDim: return StackVM::kArrNDim;
case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode;
case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits;
case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes;
case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset;
case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId;
case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType;
case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent;
default: LOG(FATAL) << "Do not know how to map field " << kind;
}
return StackVM::kArrData;
}
StackVM CodeGenStackVM::Compile(LoweredFunc f) { StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) { for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i]; Var v = f->args[i];
...@@ -163,7 +185,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { ...@@ -163,7 +185,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
vm_.code.push_back(code); vm_.code.push_back(code);
code.v_int = index->value; code.v_int = index->value;
vm_.code.push_back(code); vm_.code.push_back(code);
code.v_int = kind; code.v_int = MapFieldKind(kind);
vm_.code.push_back(code); vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
CHECK_GE(op->args.size(), 5U); CHECK_GE(op->args.size(), 5U);
...@@ -431,7 +453,7 @@ void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) { ...@@ -431,7 +453,7 @@ void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
vm_.code.push_back(code); vm_.code.push_back(code);
code.v_int = index->value; code.v_int = index->value;
vm_.code.push_back(code); vm_.code.push_back(code);
code.v_int = op->args[2].as<IntImmNode>()->value; code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
vm_.code.push_back(code); vm_.code.push_back(code);
} else { } else {
this->Push(ev->value); this->Push(ev->value);
......
...@@ -189,7 +189,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { ...@@ -189,7 +189,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
then_for = IRTransform(for_stmt, nullptr, replace_then_case, then_for = IRTransform(for_stmt, nullptr, replace_then_case,
{PrimExpr("IfThenElse")}); {PrimExpr("IfThenElse")});
if (if_stmt.as<IfThenElseNode>()->else_case) { if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case, else_for = IRTransform(for_stmt, nullptr, replace_else_case,
{PrimExpr("IfThenElse")}); {PrimExpr("IfThenElse")});
} }
...@@ -221,7 +221,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { ...@@ -221,7 +221,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
for2if_map_[for_stmt.get()].push_back(head); for2if_map_[for_stmt.get()].push_back(head);
const IfThenElseNode* if_node = head.as<IfThenElseNode>(); const IfThenElseNode* if_node = head.as<IfThenElseNode>();
tracker.push(if_node->then_case); tracker.push(if_node->then_case);
if (if_node->else_case) { if (if_node->else_case.defined()) {
tracker.push(if_node->else_case); tracker.push(if_node->else_case);
} }
......
...@@ -31,14 +31,6 @@ ...@@ -31,14 +31,6 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
TensorType ToTensorType(const Type& t) {
if (const auto* tt_node = t.as<TensorTypeNode>()) {
return GetRef<TensorType>(tt_node);
} else {
return TensorType(nullptr);
}
}
bool IdentityRel(const Array<Type>& types, bool IdentityRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
...@@ -115,11 +107,11 @@ bool BroadcastRel(const Array<Type>& types, ...@@ -115,11 +107,11 @@ bool BroadcastRel(const Array<Type>& types,
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl; // << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) { if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto t1 = ToTensorType(types[1])) { if (auto* t1 = types[1].as<TensorTypeNode>()) {
CHECK_EQ(t0->dtype, t1->dtype); CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], reporter->Assign(types[2],
ConcreteBroadcast(t0, t1, t0->dtype)); ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
return true; return true;
} }
} }
...@@ -133,10 +125,11 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -133,10 +125,11 @@ bool BroadcastCompRel(const Array<Type>& types,
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl; // << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) { if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto t1 = ToTensorType(types[1])) { if (auto* t1 = types[1].as<TensorTypeNode>()) {
CHECK_EQ(t0->dtype, t1->dtype); CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::DataType::Bool())); reporter->Assign(types[2],
ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), DataType::Bool()));
return true; return true;
} }
} }
......
...@@ -749,7 +749,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -749,7 +749,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic r = VisitExpr(op->ref, ll); PStatic r = VisitExpr(op->ref, ll);
if (r->pstatic.defined()) { if (r->pstatic.defined()) {
PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>()); PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>());
if (ret) { if (ret.defined()) {
return ret; return ret;
} }
} }
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
*/ */
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include "gemm_common.h" #include "gemm_common.h"
extern "C" { extern "C" {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#pragma once #pragma once
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <algorithm> #include <algorithm>
namespace tvm { namespace tvm {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file Use external cblas library call. * \file Use external cblas library call.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include "../cblas/gemm_common.h" #include "../cblas/gemm_common.h"
#include "cublas_utils.h" #include "cublas_utils.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file Use external cudnn utils function * \file Use external cudnn utils function
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include "cudnn_utils.h" #include "cudnn_utils.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file Use external miopen utils function * \file Use external miopen utils function
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include "miopen_utils.h" #include "miopen_utils.h"
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <vector> #include <vector>
#include "../../metal/metal_common.h" #include "../../metal/metal_common.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
*/ */
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <nnpack.h> #include <nnpack.h>
#include "nnpack_utils.h" #include "nnpack_utils.h"
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file Use external nnpack library call. * \file Use external nnpack library call.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <nnpack.h> #include <nnpack.h>
#include "nnpack_utils.h" #include "nnpack_utils.h"
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ #ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
#define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ #define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <nnpack.h> #include <nnpack.h>
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file External random functions for tensor. * \file External random functions for tensor.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <algorithm> #include <algorithm>
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file Use external rocblas library call. * \file Use external rocblas library call.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/data_type.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include "rocblas.h" #include "rocblas.h"
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* \file stackvm.cc * \file stackvm.cc
*/ */
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include <algorithm> #include <algorithm>
#include "stackvm.h" #include "stackvm.h"
...@@ -392,50 +391,49 @@ void StackVM::Run(State* s) const { ...@@ -392,50 +391,49 @@ void StackVM::Run(State* s) const {
} }
// intrinsics // intrinsics
case TVM_STRUCT_GET: { case TVM_STRUCT_GET: {
using namespace ir;
int index = code[pc + 1].v_int; int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int; int kind = code[pc + 2].v_int;
DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle); DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle);
switch (kind) { switch (kind) {
case intrinsic::kArrData: { case StackVM::kArrData: {
stack[sp].v_handle = arr[index].data; break; stack[sp].v_handle = arr[index].data; break;
} }
case intrinsic::kArrShape: { case StackVM::kArrShape: {
stack[sp].v_handle = arr[index].shape; break; stack[sp].v_handle = arr[index].shape; break;
} }
case intrinsic::kArrStrides: { case StackVM::kArrStrides: {
stack[sp].v_handle = arr[index].strides; break; stack[sp].v_handle = arr[index].strides; break;
} }
case intrinsic::kArrNDim: { case StackVM::kArrNDim: {
stack[sp].v_int64 = arr[index].ndim; break; stack[sp].v_int64 = arr[index].ndim; break;
} }
case intrinsic::kArrTypeCode: { case StackVM::kArrTypeCode: {
stack[sp].v_int64 = static_cast<int64_t>( stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.code); break; arr[index].dtype.code); break;
} }
case intrinsic::kArrTypeBits: { case StackVM::kArrTypeBits: {
stack[sp].v_int64 = static_cast<int64_t>( stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.bits); break; arr[index].dtype.bits); break;
} }
case intrinsic::kArrTypeLanes: { case StackVM::kArrTypeLanes: {
stack[sp].v_int64 = static_cast<int64_t>( stack[sp].v_int64 = static_cast<int64_t>(
arr[index].dtype.lanes); break; arr[index].dtype.lanes); break;
} }
case intrinsic::kArrByteOffset: { case StackVM::kArrByteOffset: {
stack[sp].v_int64 = static_cast<int64_t>( stack[sp].v_int64 = static_cast<int64_t>(
arr[index].byte_offset); break; arr[index].byte_offset); break;
} }
case intrinsic::kArrDeviceId: { case StackVM::kArrDeviceId: {
stack[sp].v_int64 = arr[index].ctx.device_id; break; stack[sp].v_int64 = arr[index].ctx.device_id; break;
} }
case intrinsic::kArrDeviceType: { case StackVM::kArrDeviceType: {
stack[sp].v_int64 = static_cast<int64_t>( stack[sp].v_int64 = static_cast<int64_t>(
arr[index].ctx.device_type); break; arr[index].ctx.device_type); break;
} }
case intrinsic::kArrAddr: { case StackVM::kArrAddr: {
stack[sp].v_handle = arr + index; break; stack[sp].v_handle = arr + index; break;
} }
case intrinsic::kTVMValueContent: { case StackVM::kTVMValueContent: {
stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break; stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break;
} }
default: LOG(FATAL) << "unhandled get " << kind; default: LOG(FATAL) << "unhandled get " << kind;
...@@ -444,51 +442,50 @@ void StackVM::Run(State* s) const { ...@@ -444,51 +442,50 @@ void StackVM::Run(State* s) const {
break; break;
} }
case TVM_STRUCT_SET: { case TVM_STRUCT_SET: {
using namespace ir;
int index = code[pc + 1].v_int; int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int; int kind = code[pc + 2].v_int;
DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle); DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle);
switch (kind) { switch (kind) {
case intrinsic::kArrData: { case StackVM::kArrData: {
arr[index].data = stack[sp].v_handle; break; arr[index].data = stack[sp].v_handle; break;
} }
case intrinsic::kArrShape: { case StackVM::kArrShape: {
arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle); arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
break; break;
} }
case intrinsic::kArrStrides: { case StackVM::kArrStrides: {
arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle); arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
break; break;
} }
case intrinsic::kArrNDim: { case StackVM::kArrNDim: {
arr[index].ndim = static_cast<int>(stack[sp].v_int64); arr[index].ndim = static_cast<int>(stack[sp].v_int64);
break; break;
} }
case intrinsic::kArrTypeCode: { case StackVM::kArrTypeCode: {
arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64); arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
break; break;
} }
case intrinsic::kArrTypeBits: { case StackVM::kArrTypeBits: {
arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64); arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
break; break;
} }
case intrinsic::kArrTypeLanes: { case StackVM::kArrTypeLanes: {
arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64); arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
break; break;
} }
case intrinsic::kArrByteOffset: { case StackVM::kArrByteOffset: {
arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64); arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
break; break;
} }
case intrinsic::kArrDeviceId: { case StackVM::kArrDeviceId: {
arr[index].ctx.device_id = static_cast<int>(stack[sp].v_int64); arr[index].ctx.device_id = static_cast<int>(stack[sp].v_int64);
break; break;
} }
case intrinsic::kArrDeviceType: { case StackVM::kArrDeviceType: {
arr[index].ctx.device_type = static_cast<DLDeviceType>(stack[sp].v_int64); arr[index].ctx.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
break; break;
} }
case intrinsic::kTVMValueContent: { case StackVM::kTVMValueContent: {
static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break; static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break;
} }
default: LOG(FATAL) << "unhandled tvm_struct_set " << kind; default: LOG(FATAL) << "unhandled tvm_struct_set " << kind;
......
...@@ -38,6 +38,7 @@ namespace tvm { ...@@ -38,6 +38,7 @@ namespace tvm {
namespace runtime { namespace runtime {
using runtime::operator<<; using runtime::operator<<;
/*! /*!
* \brief A simple stack-based virtual machine program. * \brief A simple stack-based virtual machine program.
*/ */
...@@ -283,6 +284,25 @@ class StackVM { ...@@ -283,6 +284,25 @@ class StackVM {
*/ */
TVM_STRUCT_SET TVM_STRUCT_SET
}; };
/*! \brief The kind of structure field info */
enum StructFieldKind : int {
// array head address
kArrAddr,
kArrData,
kArrShape,
kArrStrides,
kArrNDim,
kArrTypeCode,
kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
};
/*! \brief The code structure */ /*! \brief The code structure */
union Code { union Code {
OpCode op_code; OpCode op_code;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment