Unverified commit 8f9796bd by Haichen Shen, committed by GitHub

[Relay] Fix memory leak when accessing NDArray (#5413)

parent 6cb5b882
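Why the old code leaked: NDArray::ToDLPack() follows the DLPack ownership convention, so it returns a heap-allocated DLManagedTensor* that keeps a reference to the underlying buffer alive and must be released by invoking its deleter callback. Each call site below read ->dl_tensor off that pointer and then discarded it, so neither the DLManagedTensor nor the reference it held was ever freed. The fix reads through NDArray::operator-> instead, which exposes the DLTensor fields without transferring ownership. A minimal before/after sketch (the function name is illustrative, not part of the patch):

#include <tvm/runtime/ndarray.h>

void ReadConstant(const tvm::runtime::NDArray& array) {
  // Before: leaks. The DLManagedTensor* returned by ToDLPack() is dropped
  // without its deleter ever being called, so the reference it holds on
  // the buffer is never released.
  const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
  (void)dl_tensor;

  // After: leak-free. operator-> yields a const DLTensor* view into the
  // NDArray; no ownership changes hands.
  const void* data = array->data;
  (void)data;
}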
@@ -68,7 +68,6 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
     runtime::NDArray array = cn->data;
     const auto& shape = array.Shape();
-    const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
     // Get the number of elements.
     int64_t num_elems = 1;
@@ -83,11 +82,11 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
     // to avoid possible stack overflow.
     buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
     if (dtype == "float") {
-      float* p_flt = static_cast<float*>(dl_tensor.data);
+      float* p_flt = static_cast<float*>(array->data);
       for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
       if (num_elems) buf_stream << p_flt[num_elems - 1];
     } else if (dtype == "int") {
-      int* p_flt = static_cast<int*>(dl_tensor.data);
+      int* p_flt = static_cast<int*>(array->data);
       for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
       if (num_elems) buf_stream << p_flt[num_elems - 1];
     } else {
......
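For context, this codegen path serializes a constant tensor's elements into the generated C source as a static array; the patch only changes where the element pointer comes from. A hypothetical emitted line for a float constant with num_elems == 4 (the buffer name comes from output.name and is illustrative):

float const_0[4] = {1.5, 2.5, 3.5, 4.5};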
@@ -169,7 +169,7 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
     CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
     std::ostringstream buf_stream;
-    const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
+    const float* ptr = static_cast<float*>(array->data);
     // Allocate large arrays on the static section to avoid stack overflow.
     // Note that this would probably increase compilation time as the source
......
@@ -193,35 +193,26 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause
   return else_branch;
 }
 
-std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
+std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
   std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-  // TODO(@jroesch): we really need to standardize the bit width of
-  // all of the shape manipulating code.
-  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
-  int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(int_ptr[i]);
-  }
-  return raw_shape;
-}
-
-std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
-  std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-  // TODO(@jroesch): we really need to standardize the bit width of
-  // all of the shape manipulating code.
-  CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
-  int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+  CHECK_EQ(shape->ndim, 1u);
+  CHECK_EQ(shape->dtype.code, 0U)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << DLDataType2String(shape->dtype);
+  CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << DLDataType2String(shape->dtype);
+  if (shape->dtype.bits == 64) {
+    int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(int_ptr[i]);
+    }
+  } else {  // int32
+    int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+    }
   }
   return raw_shape;
 }
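A hedged usage sketch of the consolidated helper, assuming the NDArray::Empty and CopyFromBytes runtime APIs and the DLContext naming of this era of TVM:

#include <tvm/runtime/ndarray.h>

#include <cstdint>
#include <vector>

// Build the 1-D int64 shape tensor {2, 3, 4} on CPU, as a Relay constant
// shape would be stored, then flatten it with the new helper.
std::vector<int64_t> Example() {
  auto nd = tvm::runtime::NDArray::Empty(
      {3}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t dims[3] = {2, 3, 4};
  nd.CopyFromBytes(dims, sizeof(dims));
  return ToAllocTensorShape(nd);  // {2, 3, 4}
}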
@@ -546,17 +537,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       if (const_shape) {
         NDArray shape = const_shape->data;
-        std::vector<int64_t> raw_shape;
-        DLTensor tensor = shape.ToDLPack()->dl_tensor;
-        // TODO(@jroesch): we need to get an RFC done to standardize this
-        if (tensor.dtype.bits == 64) {
-          raw_shape = ToAllocTensorShape64(shape);
-        } else if (tensor.dtype.bits == 32) {
-          raw_shape = ToAllocTensorShape32(shape);
-        } else {
-          LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
-        }
+        // TODO(@jroesch): we need to get an RFC done to standardize shape dtype
+        std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);
         // Add context field.
         Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
       } else {
......
@@ -23,6 +23,7 @@
  */
 #include <topi/elemwise.h>
 #include <tvm/runtime/data_type.h>
+#include <tvm/relay/attrs/memory.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
@@ -107,21 +108,22 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
 std::vector<int64_t> FromConstShape(Constant konst) {
   runtime::NDArray shape = konst->data;
   std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-  CHECK(tensor.dtype.bits == 64 || tensor.dtype.bits == 32)
-      << "found " << static_cast<int>(tensor.dtype.bits);
-  if (tensor.dtype.bits == 32) {
-    const int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
-    for (auto i = 0; i < tensor.shape[0]; i++) {
+  CHECK_EQ(shape->ndim, 1u);
+  CHECK_EQ(shape->dtype.code, 0U)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << runtime::DLDataType2String(shape->dtype);
+  CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << runtime::DLDataType2String(shape->dtype);
+  if (shape->dtype.bits == 32) {
+    const int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
       raw_shape.push_back(int_ptr[i]);
     }
-  } else if (tensor.dtype.bits == 64) {
-    const int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
-    for (auto i = 0; i < tensor.shape[0]; i++) {
+  } else if (shape->dtype.bits == 64) {
+    const int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
       raw_shape.push_back(int_ptr[i]);
     }
   }
......