Unverified Commit 2f8a01f7 by Tianqi Chen Committed by GitHub

[REFACTOR] Get rid of packed_func_ext. (#4735)

Move the conversion extensions to the specific class definitions
so that we longer need to include packed_func_ext.
parent 703ed9b7
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/expr_operator.h>
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <iostream> #include <iostream>
#include <limits>
#include "node/node.h" #include "node/node.h"
#include "node/container.h" #include "node/container.h"
#include "node/functor.h" #include "node/functor.h"
...@@ -460,6 +461,26 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) { ...@@ -460,6 +461,26 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
} }
} // namespace tvm } // namespace tvm
namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kTVMNullptr) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
namespace std { namespace std {
template <> template <>
struct hash<::tvm::IterVar> : public ::tvm::ObjectHash { struct hash<::tvm::IterVar> : public ::tvm::ObjectHash {
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <tvm/ir/span.h> #include <tvm/ir/span.h>
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <string> #include <string>
#include <limits>
namespace tvm { namespace tvm {
...@@ -114,6 +115,11 @@ class PrimExpr : public BaseExpr { ...@@ -114,6 +115,11 @@ class PrimExpr : public BaseExpr {
} }
TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
private:
// Internal function for conversion.
friend class runtime::TVMPODValue_;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
}; };
/*! /*!
...@@ -322,4 +328,25 @@ inline const TTypeNode* RelayExprNode::type_as() const { ...@@ -322,4 +328,25 @@ inline const TTypeNode* RelayExprNode::type_as() const {
} }
} // namespace tvm } // namespace tvm
namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return PrimExpr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
return PrimExpr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return PrimExpr::FromObject_(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_IR_EXPR_H_ #endif // TVM_IR_EXPR_H_
...@@ -655,6 +655,65 @@ class Map<std::string, V, T1, T2> : public ObjectRef { ...@@ -655,6 +655,65 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key)); return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key));
} }
}; };
} // namespace tvm
namespace tvm {
namespace runtime {
// Additional overloads for PackedFunc checking.
template<typename T>
struct ObjectTypeChecker<Array<T> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<ArrayNode>()) return false;
const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (const auto& p : n->data) {
if (!ObjectTypeChecker<T>::Check(p.get())) {
return false;
}
}
return true;
}
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};
template<typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<StrMapNode>()) return false;
const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
template<typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<MapNode>()) return false;
const MapNode* n = static_cast<const MapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
} // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_NODE_CONTAINER_H_ #endif // TVM_NODE_CONTAINER_H_
...@@ -56,6 +56,9 @@ using runtime::Downcast; ...@@ -56,6 +56,9 @@ using runtime::Downcast;
using runtime::ObjectHash; using runtime::ObjectHash;
using runtime::ObjectEqual; using runtime::ObjectEqual;
using runtime::make_object; using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
} // namespace tvm } // namespace tvm
#endif // TVM_NODE_NODE_H_ #endif // TVM_NODE_NODE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass ObjectRef types into/from PackedFunc.
*/
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
#include <tvm/top/tensor.h>
#include <string>
#include <memory>
#include <limits>
#include <type_traits>
#include "expr.h"
#include "runtime/packed_func.h"
namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
namespace runtime {
template<typename T>
struct ObjectTypeChecker<Array<T> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<ArrayNode>()) return false;
const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (const auto& p : n->data) {
if (!ObjectTypeChecker<T>::Check(p.get())) {
return false;
}
}
return true;
}
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};
template<typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<StrMapNode>()) return false;
const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
template<typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<MapNode>()) return false;
const MapNode* n = static_cast<const MapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
// extensions for tvm arg value
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return PrimExpr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
return PrimExpr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ObjectPtr<Object>(ptr))->var;
}
if (ptr->IsInstance<top::TensorNode>()) {
return top::Tensor(ObjectPtr<Object>(ptr))();
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return PrimExpr(ObjectPtr<Object>(ptr));
}
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kTVMNullptr) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#ifndef TVM_RELAY_TRANSFORM_H_ #ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_
#include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h> #include <tvm/ir/transform.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -184,7 +183,7 @@ TVM_DLL Pass InferType(); ...@@ -184,7 +183,7 @@ TVM_DLL Pass InferType();
* *
* \return The pass. * \return The pass.
*/ */
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr);
/*! /*!
* \brief Combine parallel 2d convolutions into a single convolution if the * \brief Combine parallel 2d convolutions into a single convolution if the
......
...@@ -28,9 +28,7 @@ ...@@ -28,9 +28,7 @@
#include <tvm/ir/type_relation.h> #include <tvm/ir/type_relation.h>
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <string> #include <string>
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
......
...@@ -25,8 +25,6 @@ ...@@ -25,8 +25,6 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/serialization.h> #include <tvm/node/serialization.h>
namespace tvm { namespace tvm {
......
...@@ -26,8 +26,6 @@ ...@@ -26,8 +26,6 @@
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
......
...@@ -27,8 +27,6 @@ ...@@ -27,8 +27,6 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <tvm/top/schedule_pass.h> #include <tvm/top/schedule_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include "../top/schedule/graph.h" #include "../top/schedule/graph.h"
......
...@@ -26,8 +26,6 @@ ...@@ -26,8 +26,6 @@
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
// Attrs used to python API // Attrs used to python API
......
...@@ -26,8 +26,6 @@ ...@@ -26,8 +26,6 @@
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
......
...@@ -26,8 +26,6 @@ ...@@ -26,8 +26,6 @@
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <stack> #include <stack>
#include <vector> #include <vector>
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
/*! /*!
* \file codegen_c_host.cc * \file codegen_c_host.cc
*/ */
#include <tvm/packed_func_ext.h> #include <tvm/codegen.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include "codegen_c_host.h" #include "codegen_c_host.h"
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_C_HOST_H_ #define TVM_CODEGEN_CODEGEN_C_HOST_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h> #include <tvm/ir.h>
#include <string> #include <string>
#include "codegen_c.h" #include "codegen_c.h"
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include <string> #include <string>
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_CUDA_H_ #define TVM_CODEGEN_CODEGEN_CUDA_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h> #include <tvm/ir.h>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "codegen_c.h" #include "codegen_c.h"
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
/*! /*!
* \file codegen_metal.cc * \file codegen_metal.cc
*/ */
#include <tvm/packed_func_ext.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <algorithm> #include <algorithm>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define TVM_CODEGEN_CODEGEN_METAL_H_ #define TVM_CODEGEN_CODEGEN_METAL_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string> #include <string>
#include "codegen_c.h" #include "codegen_c.h"
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
/*! /*!
* \file codegen_opencl.cc * \file codegen_opencl.cc
*/ */
#include <tvm/packed_func_ext.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include <string> #include <string>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define TVM_CODEGEN_CODEGEN_OPENCL_H_ #define TVM_CODEGEN_CODEGEN_OPENCL_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string> #include <string>
#include "codegen_c.h" #include "codegen_c.h"
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
* We are targeting OpenGL 3.3. The reason of not targeting a recent version * We are targeting OpenGL 3.3. The reason of not targeting a recent version
* of OpenGL is to have better compatibility of WebGL 2. * of OpenGL is to have better compatibility of WebGL 2.
*/ */
#include <tvm/packed_func_ext.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define TVM_CODEGEN_CODEGEN_OPENGL_H_ #define TVM_CODEGEN_CODEGEN_OPENGL_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_VHLS_H_ #define TVM_CODEGEN_CODEGEN_VHLS_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/packed_func_ext.h> #include <tvm/ir.h>
#include <string> #include <string>
#include "codegen_c.h" #include "codegen_c.h"
......
...@@ -16,15 +16,15 @@ ...@@ -16,15 +16,15 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
#include "registry.h"
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include "registry.h"
namespace tvm { namespace tvm {
namespace datatype { namespace datatype {
using runtime::TVMArgs;
using runtime::TVMRetValue;
TVM_REGISTER_GLOBAL("_datatype_register") TVM_REGISTER_GLOBAL("_datatype_register")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int())); datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \file intrin_rule_default.cc * \file intrin_rule_default.cc
* \brief Default intrinsic rules. * \brief Default intrinsic rules.
*/ */
#include <tvm/expr_operator.h>
#include "intrin_rule.h" #include "intrin_rule.h"
namespace tvm { namespace tvm {
......
...@@ -27,9 +27,6 @@ ...@@ -27,9 +27,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <string> #include <string>
......
...@@ -25,8 +25,6 @@ ...@@ -25,8 +25,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <sstream> #include <sstream>
namespace tvm { namespace tvm {
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <sstream> #include <sstream>
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* \file intrin_rule_spirv.cc * \file intrin_rule_spirv.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <GLSL.std.450.h> #include <GLSL.std.450.h>
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* \file codegen_stackvm.cc * \file codegen_stackvm.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <limits> #include <limits>
#include <utility> #include <utility>
#include "codegen_stackvm.h" #include "codegen_stackvm.h"
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#define TVM_IR_ATTR_FUNCTOR_H_ #define TVM_IR_ATTR_FUNCTOR_H_
#include <tvm/node/functor.h> #include <tvm/node/functor.h>
#include <tvm/ir.h>
#include <utility> #include <utility>
namespace tvm { namespace tvm {
......
...@@ -22,8 +22,6 @@ ...@@ -22,8 +22,6 @@
*/ */
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include "attr_functor.h" #include "attr_functor.h"
namespace tvm { namespace tvm {
......
...@@ -24,11 +24,11 @@ ...@@ -24,11 +24,11 @@
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
// NOTE on dependencies on relay AsText. // NOTE: reverse dependency on relay.
// We calls into relay's printing module for better rendering. // These dependencies do not happen at the interface-level,
// These dependency does not happen at the interface-level. // and are only used in minimum cases where they are clearly marked.
// And is only used to enhance developer experiences when relay //
// functions are presented. // Rationale: use relay's printer for astext.
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
......
...@@ -23,9 +23,30 @@ ...@@ -23,9 +23,30 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
// NOTE: reverse dependency on top/tir.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: convert from IterVar and top::Tensor
#include <tvm/top/tensor.h>
#include <tvm/expr.h>
namespace tvm { namespace tvm {
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
using runtime::ObjectTypeChecker;
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ptr)->var;
}
if (ptr->IsInstance<top::TensorNode>()) {
return top::Tensor(ptr)();
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return PrimExpr(ptr);
}
IntImm::IntImm(DataType dtype, int64_t value) { IntImm::IntImm(DataType dtype, int64_t value) {
CHECK(dtype.is_scalar()) CHECK(dtype.is_scalar())
<< "ValueError: IntImm can only take scalar."; << "ValueError: IntImm can only take scalar.";
......
...@@ -23,12 +23,11 @@ ...@@ -23,12 +23,11 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
// NOTE on dependencies on relay analysis. // NOTE: reverse dependency on relay.
// We calls into relay's analysis module to verify correctness // These dependencies do not happen at the interface-level,
// when a relay function is presented. // and are only used in minimum cases where they are clearly marked.
// These dependency does not happen at the interface-level. //
// And is only used to enhance developer experiences when relay // Rationale: We calls into relay's analysis module to verify correctness.
// functions are presented.
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
*/ */
#include <tvm/ir/span.h> #include <tvm/ir/span.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
......
...@@ -23,8 +23,6 @@ ...@@ -23,8 +23,6 @@
*/ */
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
PrimType::PrimType(runtime::DataType dtype) { PrimType::PrimType(runtime::DataType dtype) {
......
...@@ -24,8 +24,6 @@ ...@@ -24,8 +24,6 @@
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h> #include <tvm/ir/type_relation.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
TypeCall::TypeCall(Type func, tvm::Array<Type> args) { TypeCall::TypeCall(Type func, tvm::Array<Type> args) {
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
*/ */
#include <tvm/arith/pattern.h> #include <tvm/arith/pattern.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "../arith/pattern_match.h" #include "../arith/pattern_match.h"
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
* \file ir_functor.cc * \file ir_functor.cc
*/ */
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/packed_func_ext.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h>
#include "../codegen/datatype/registry.h" #include "../codegen/datatype/registry.h"
namespace tvm { namespace tvm {
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <unordered_set> #include <unordered_set>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "compile_engine.h" #include "compile_engine.h"
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/top/operation.h> #include <tvm/top/operation.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h> #include <tvm/relay/attrs/device_copy.h>
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \brief Memory index assignment pass for executing * \brief Memory index assignment pass for executing
* the program in the graph runtime. * the program in the graph runtime.
*/ */
#include <tvm/expr_operator.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* \file src/tvm/relay/interpreter.cc * \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR. * \brief An interpreter for the Relay IR.
*/ */
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_RELAY_BACKEND_PARAM_DICT_H_ #define TVM_RELAY_BACKEND_PARAM_DICT_H_
#include <tvm/node/node.h> #include <tvm/node/node.h>
#include <tvm/packed_func_ext.h> #include <tvm/ir.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/base.h> #include <tvm/relay/base.h>
namespace tvm { namespace tvm {
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \brief The type system AST nodes of Relay. * \brief The type system AST nodes of Relay.
*/ */
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/expr_operator.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \file multibox_op.cc * \file multibox_op.cc
* \brief Multibox related operators * \brief Multibox related operators
*/ */
#include <tvm/expr_operator.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h> #include <tvm/relay/attrs/vision.h>
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \file type_solver.cc * \file type_solver.cc
* \brief Type solver implementations. * \brief Type solver implementations.
*/ */
#include <tvm/expr_operator.h>
#include <string> #include <string>
#include <memory> #include <memory>
#include <tuple> #include <tuple>
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#define TVM_RELAY_QNN_UTIL_H_ #define TVM_RELAY_QNN_UTIL_H_
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/expr_operator.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include <limits> #include <limits>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include "op_util.h" #include "op_util.h"
#include "compute_op.h" #include "compute_op.h"
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h> #include <tvm/ir.h>
namespace tvm { namespace tvm {
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include <topi/cuda/injective.h> #include <topi/cuda/injective.h>
#include <tvm/top/operation.h> #include <tvm/top/operation.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <string> #include <string>
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/packed_func_ext.h> #include <tvm/expr_operator.h>
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
#include <new> #include <new>
#include <unordered_map> #include <unordered_map>
......
...@@ -21,8 +21,6 @@ ...@@ -21,8 +21,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir.h> #include <tvm/ir.h>
TEST(PackedFunc, Basic) { TEST(PackedFunc, Basic) {
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
TVM_REGISTER_GLOBAL("test.sch") TVM_REGISTER_GLOBAL("test.sch")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) { .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <topi/generic/injective.h> #include <topi/generic/injective.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
......
...@@ -33,7 +33,6 @@ ...@@ -33,7 +33,6 @@
#include <topi/generic/injective.h> #include <topi/generic/injective.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/top/operation.h> #include <tvm/top/operation.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/ir/expr.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <topi/broadcast.h> #include <topi/broadcast.h>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment