Unverified Commit 2f8a01f7 by Tianqi Chen Committed by GitHub

[REFACTOR] Get rid of packed_func_ext. (#4735)

Move the conversion extensions to the specific class definitions
so that we longer need to include packed_func_ext.
parent 703ed9b7
......@@ -26,8 +26,8 @@
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
#include <tvm/expr_operator.h>
using namespace tvm;
using namespace tvm::runtime;
......
......@@ -29,6 +29,7 @@
#include <algorithm>
#include <unordered_map>
#include <iostream>
#include <limits>
#include "node/node.h"
#include "node/container.h"
#include "node/functor.h"
......@@ -460,6 +461,26 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
}
} // namespace tvm
namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kTVMNullptr) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::IterVar> : public ::tvm::ObjectHash {
......
......@@ -30,6 +30,7 @@
#include <tvm/ir/span.h>
#include <tvm/ir/type.h>
#include <string>
#include <limits>
namespace tvm {
......@@ -114,6 +115,11 @@ class PrimExpr : public BaseExpr {
}
TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
private:
// Internal function for conversion.
friend class runtime::TVMPODValue_;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
};
/*!
......@@ -322,4 +328,25 @@ inline const TTypeNode* RelayExprNode::type_as() const {
}
} // namespace tvm
namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return PrimExpr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
return PrimExpr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return PrimExpr::FromObject_(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_IR_EXPR_H_
......@@ -655,6 +655,65 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key));
}
};
} // namespace tvm
namespace tvm {
namespace runtime {
// Additional overloads for PackedFunc checking.
template<typename T>
struct ObjectTypeChecker<Array<T> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<ArrayNode>()) return false;
const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (const auto& p : n->data) {
if (!ObjectTypeChecker<T>::Check(p.get())) {
return false;
}
}
return true;
}
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};
template<typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<StrMapNode>()) return false;
const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
template<typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<MapNode>()) return false;
const MapNode* n = static_cast<const MapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_CONTAINER_H_
......@@ -56,6 +56,9 @@ using runtime::Downcast;
using runtime::ObjectHash;
using runtime::ObjectEqual;
using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
} // namespace tvm
#endif // TVM_NODE_NODE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass ObjectRef types into/from PackedFunc.
*/
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
#include <tvm/top/tensor.h>
#include <string>
#include <memory>
#include <limits>
#include <type_traits>
#include "expr.h"
#include "runtime/packed_func.h"
namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
namespace runtime {
template<typename T>
struct ObjectTypeChecker<Array<T> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<ArrayNode>()) return false;
const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (const auto& p : n->data) {
if (!ObjectTypeChecker<T>::Check(p.get())) {
return false;
}
}
return true;
}
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};
template<typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<StrMapNode>()) return false;
const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
template<typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<MapNode>()) return false;
const MapNode* n = static_cast<const MapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
// extensions for tvm arg value
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return PrimExpr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
return PrimExpr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ObjectPtr<Object>(ptr))->var;
}
if (ptr->IsInstance<top::TensorNode>()) {
return top::Tensor(ObjectPtr<Object>(ptr))();
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return PrimExpr(ObjectPtr<Object>(ptr));
}
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kTVMNullptr) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
......@@ -24,7 +24,6 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
......@@ -184,7 +183,7 @@ TVM_DLL Pass InferType();
*
* \return The pass.
*/
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr);
/*!
* \brief Combine parallel 2d convolutions into a single convolution if the
......
......@@ -28,9 +28,7 @@
#include <tvm/ir/type_relation.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir.h>
#include <string>
......
......@@ -29,7 +29,6 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/top/tensor.h>
......
......@@ -25,8 +25,6 @@
#include <tvm/expr.h>
#include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/serialization.h>
namespace tvm {
......
......@@ -26,8 +26,6 @@
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
namespace codegen {
......
......@@ -24,7 +24,6 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/expr_operator.h>
......
......@@ -28,7 +28,6 @@
#include <tvm/buffer.h>
#include <tvm/top/schedule.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/build_module.h>
#include <tvm/data_layout.h>
......
......@@ -27,8 +27,6 @@
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
namespace ir {
......
......@@ -26,7 +26,6 @@
#include <tvm/top/schedule.h>
#include <tvm/top/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include "../top/schedule/graph.h"
......
......@@ -26,8 +26,6 @@
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/env_func.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
// Attrs used to python API
......
......@@ -26,8 +26,6 @@
#include <tvm/ir_functor_ext.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <unordered_set>
#include <unordered_map>
......
......@@ -26,8 +26,6 @@
#include <tvm/ir_functor_ext.h>
#include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <unordered_set>
#include <unordered_map>
......
......@@ -25,7 +25,6 @@
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <utility>
#include <algorithm>
......
......@@ -28,7 +28,6 @@
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <stack>
#include <vector>
......
......@@ -20,7 +20,7 @@
/*!
* \file codegen_c_host.cc
*/
#include <tvm/packed_func_ext.h>
#include <tvm/codegen.h>
#include <vector>
#include <string>
#include "codegen_c_host.h"
......
......@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_C_HOST_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <string>
#include "codegen_c.h"
......
......@@ -22,7 +22,6 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <cmath>
#include <vector>
#include <string>
......
......@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_CUDA_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <string>
#include <unordered_map>
#include "codegen_c.h"
......
......@@ -20,7 +20,6 @@
/*!
* \file codegen_metal.cc
*/
#include <tvm/packed_func_ext.h>
#include <vector>
#include <string>
#include <algorithm>
......
......@@ -25,7 +25,6 @@
#define TVM_CODEGEN_CODEGEN_METAL_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
#include "codegen_c.h"
......
......@@ -20,7 +20,6 @@
/*!
* \file codegen_opencl.cc
*/
#include <tvm/packed_func_ext.h>
#include <cmath>
#include <vector>
#include <string>
......
......@@ -25,7 +25,6 @@
#define TVM_CODEGEN_CODEGEN_OPENCL_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
#include "codegen_c.h"
......
......@@ -23,7 +23,6 @@
* We are targeting OpenGL 3.3. The reason of not targeting a recent version
* of OpenGL is to have better compatibility of WebGL 2.
*/
#include <tvm/packed_func_ext.h>
#include <vector>
#include <string>
#include <utility>
......
......@@ -25,7 +25,6 @@
#define TVM_CODEGEN_CODEGEN_OPENGL_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
#include <unordered_set>
#include <unordered_map>
......
......@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_VHLS_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <string>
#include "codegen_c.h"
......
......@@ -16,15 +16,15 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "registry.h"
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include "registry.h"
namespace tvm {
namespace datatype {
using runtime::TVMArgs;
using runtime::TVMRetValue;
TVM_REGISTER_GLOBAL("_datatype_register")
.set_body([](TVMArgs args, TVMRetValue* ret) {
datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
......
......@@ -21,6 +21,7 @@
* \file intrin_rule_default.cc
* \brief Default intrinsic rules.
*/
#include <tvm/expr_operator.h>
#include "intrin_rule.h"
namespace tvm {
......
......@@ -27,9 +27,6 @@
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <string>
namespace tvm {
......
......@@ -27,7 +27,6 @@
#include <tvm/ir.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/codegen.h>
#include <string>
......
......@@ -25,8 +25,6 @@
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <sstream>
namespace tvm {
......
......@@ -25,7 +25,6 @@
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <sstream>
......
......@@ -21,7 +21,6 @@
* \file intrin_rule_spirv.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <GLSL.std.450.h>
......
......@@ -21,7 +21,6 @@
* \file codegen_stackvm.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <limits>
#include <utility>
#include "codegen_stackvm.h"
......
......@@ -31,6 +31,7 @@
#define TVM_IR_ATTR_FUNCTOR_H_
#include <tvm/node/functor.h>
#include <tvm/ir.h>
#include <utility>
namespace tvm {
......
......@@ -22,8 +22,6 @@
*/
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include "attr_functor.h"
namespace tvm {
......
......@@ -24,11 +24,11 @@
#include <tvm/ir/module.h>
#include <tvm/ir/error.h>
// NOTE on dependencies on relay AsText.
// We calls into relay's printing module for better rendering.
// These dependency does not happen at the interface-level.
// And is only used to enhance developer experiences when relay
// functions are presented.
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: use relay's printer for astext.
#include <tvm/relay/expr.h>
#include <string>
......
......@@ -23,9 +23,30 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
// NOTE: reverse dependency on top/tir.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: convert from IterVar and top::Tensor
#include <tvm/top/tensor.h>
#include <tvm/expr.h>
namespace tvm {
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
using runtime::ObjectTypeChecker;
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ptr)->var;
}
if (ptr->IsInstance<top::TensorNode>()) {
return top::Tensor(ptr)();
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return PrimExpr(ptr);
}
IntImm::IntImm(DataType dtype, int64_t value) {
CHECK(dtype.is_scalar())
<< "ValueError: IntImm can only take scalar.";
......
......@@ -23,12 +23,11 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
// NOTE on dependencies on relay analysis.
// We calls into relay's analysis module to verify correctness
// when a relay function is presented.
// These dependency does not happen at the interface-level.
// And is only used to enhance developer experiences when relay
// functions are presented.
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: We calls into relay's analysis module to verify correctness.
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
......
......@@ -22,7 +22,6 @@
*/
#include <tvm/ir/span.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
......
......@@ -23,8 +23,6 @@
*/
#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
PrimType::PrimType(runtime::DataType dtype) {
......
......@@ -24,8 +24,6 @@
#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
TypeCall::TypeCall(Type func, tvm::Array<Type> args) {
......
......@@ -24,7 +24,6 @@
#include <tvm/ir_functor_ext.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <unordered_map>
#include <unordered_set>
......
......@@ -23,7 +23,6 @@
*/
#include <tvm/arith/pattern.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include "../arith/pattern_match.h"
......
......@@ -20,7 +20,6 @@
* \file ir_functor.cc
*/
#include <tvm/ir_functor_ext.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
namespace ir {
......
......@@ -23,7 +23,6 @@
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h>
#include "../codegen/datatype/registry.h"
namespace tvm {
......
......@@ -24,7 +24,6 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/expr_operator.h>
#include <unordered_set>
......
......@@ -25,7 +25,6 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
......
......@@ -24,7 +24,6 @@
#include "compile_engine.h"
#include <tvm/top/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/top/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
......
......@@ -22,6 +22,7 @@
* \brief Memory index assignment pass for executing
* the program in the graph runtime.
*/
#include <tvm/expr_operator.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
......
......@@ -21,7 +21,6 @@
* \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
......
......@@ -25,7 +25,7 @@
#define TVM_RELAY_BACKEND_PARAM_DICT_H_
#include <tvm/node/node.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
......
......@@ -24,7 +24,6 @@
#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/base.h>
namespace tvm {
......
......@@ -22,6 +22,7 @@
* \brief The type system AST nodes of Relay.
*/
#include <tvm/relay/type.h>
#include <tvm/expr_operator.h>
namespace tvm {
namespace relay {
......
......@@ -21,6 +21,7 @@
* \file multibox_op.cc
* \brief Multibox related operators
*/
#include <tvm/expr_operator.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
......
......@@ -21,6 +21,7 @@
* \file type_solver.cc
* \brief Type solver implementations.
*/
#include <tvm/expr_operator.h>
#include <string>
#include <memory>
#include <tuple>
......
......@@ -26,6 +26,7 @@
#define TVM_RELAY_QNN_UTIL_H_
#include <tvm/expr.h>
#include <tvm/expr_operator.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/qnn/attrs.h>
#include <limits>
......
......@@ -25,7 +25,6 @@
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include "op_util.h"
#include "compute_op.h"
......
......@@ -21,7 +21,6 @@
#include <gtest/gtest.h>
#include <tvm/ir/attrs.h>
#include <tvm/expr_operator.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
namespace tvm {
......
......@@ -22,7 +22,6 @@
#include <topi/cuda/injective.h>
#include <tvm/top/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/build_module.h>
#include <string>
......
......@@ -19,7 +19,7 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/packed_func_ext.h>
#include <tvm/expr_operator.h>
#include <tvm/runtime/container.h>
#include <new>
#include <unordered_map>
......
......@@ -21,8 +21,6 @@
#include <gtest/gtest.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir.h>
TEST(PackedFunc, Basic) {
......
......@@ -28,8 +28,6 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
TVM_REGISTER_GLOBAL("test.sch")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
......
......@@ -20,7 +20,6 @@
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/build_module.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
......
......@@ -33,7 +33,6 @@
#include <topi/generic/injective.h>
#include <tvm/build_module.h>
#include <tvm/top/operation.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
......
......@@ -26,7 +26,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir/expr.h>
#include <tvm/build_module.h>
#include <topi/broadcast.h>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment