Unverified Commit 48082358 by hlu1 Committed by GitHub

[NDArray] Set shape_ in NDArray::FromDLPack (#5301)

parent 02d3a59b
......@@ -208,6 +208,10 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
// fill up content.
data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor;
// update shape_
data->shape_.resize(data->dl_tensor.ndim);
data->shape_.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim);
data->dl_tensor.shape = data->shape_.data();
return NDArray(GetObjectPtr<Object>(data));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment