Commit 1e57ee6c by Lianmin Zheng Committed by Tianqi Chen

[RUNTIME] Improve memory usage for RPC (#1741)

parent 51fe00fb
......@@ -31,8 +31,10 @@ python3 gpu_imagenet_bench.py --model titanx
```
### ARM CPU & Mali GPU
For embedded deivces, we use RPC infrastructure in TVM to make the management easy.
So you need to use it for reproducing benchmark results.
For embedded devices, we use RPC infrastructure in TVM to make the management easy.
You need to use it for reproducing benchmark results.
**Note**: We use llvm-4.0 in our tuning environment. Mismatch of the LLVM version during tuning and deployment can influence the performance, so you have to use a same version for reproduction.
0. Build TVM with LLVM enabled. [Help](https://docs.tvm.ai/install/from_source.html)
......@@ -87,6 +89,10 @@ python3 -m tvm.exec.rpc_tracker
python3 arm_cpu_imagenet_bench.py --model mate10pro --rpc-key mate10pro
# Mali GPU
# NOTE: To make the test environment more stable, we close GUI and lock the frequency
sudo /etc/init.d/lightdm stop
sudo -i
echo performance > /sys/class/misc/mali0/device/devfreq/ff9a0000.gpu/governor
python3 mobile_gpu_imagenet_bench.py --model rk3399 --rpc-key rk3399
```
......
......@@ -41,15 +41,12 @@ def evaluate_network(network, target, target_host, number):
print_progress("%-20s uploading..." % network)
ctx = remote.context(str(target), 0)
remote.upload(tmp.relpath(filename))
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
rlib = remote.load_module(filename)
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
del rparams
module.set_input(**params)
# evaluate
print_progress("%-20s evaluating..." % network)
......
......@@ -46,9 +46,7 @@ def evaluate_network(network, target, target_host, number):
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
del rparams
module.set_input(**params)
# evaluate
print_progress("%-20s evaluating..." % network)
......@@ -87,4 +85,4 @@ if __name__ == "__main__":
print("--------------------------------------------------")
for network in networks:
evaluate_network(network, target, target_host, args.number)
\ No newline at end of file
evaluate_network(network, target, target_host, args.number)
......@@ -8,12 +8,13 @@ import re
import os.path
import collections
import numpy as np
from tvm.contrib import util
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util
######################################################################
# Some helper functions
# ---------------------
......
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
import numpy as np
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from ..rpc import base as rpc_base
......@@ -97,9 +99,13 @@ class GraphModule(object):
"""
if key:
self._set_input(key, nd.array(value, ctx=self.ctx))
for k, v in params.items():
self._set_input(k, nd.array(v, ctx=self.ctx))
return self
if params:
# upload big arrays first to avoid memory issue in rpc mode
keys = list(params.keys())
keys.sort(key=lambda x: -np.prod(params[x].shape))
for k in keys:
self._set_input(k, nd.array(params[k], ctx=self.ctx))
def run(self, **input_dict):
"""Run forward execution of the graph
......
......@@ -36,19 +36,31 @@ class RingBuffer {
* \param n The size of capacity.
*/
void Reserve(size_t n) {
if (ring_.size() >= n) return;
size_t old_size = ring_.size();
size_t new_size = ring_.size();
while (new_size < n) {
new_size *= 2;
}
ring_.resize(new_size);
if (head_ptr_ + bytes_available_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
if (ring_.size() < n) {
size_t old_size = ring_.size();
size_t new_size = static_cast<size_t>(n * 1.2);
ring_.resize(new_size);
if (head_ptr_ + bytes_available_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
}
} else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) {
// shrink too large temporary buffer to avoid out of memory on some embedded devices
size_t old_bytes = bytes_available_;
std::vector<char> tmp(old_bytes);
Read(&tmp[0], old_bytes);
ring_.resize(kInitCapacity);
ring_.shrink_to_fit();
memcpy(&ring_[0], &tmp[0], old_bytes);
head_ptr_ = 0;
bytes_available_ = old_bytes;
}
}
/*!
* \brief Peform a non-blocking read from buffer
* size must be smaller than this->bytes_available()
......
......@@ -327,11 +327,10 @@ def tune_and_evaluate(tuning_opt):
# upload parameters to device
ctx = remote.context(str(target), 0)
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
module.set_input(**params)
# evaluate
print("Evaluate inference time cost...")
......
......@@ -229,11 +229,10 @@ def tune_and_evaluate(tuning_opt):
# load parameters
ctx = tvm.context(str(target), 0)
params_tvm = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module = runtime.create(graph, lib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params_tvm)
module.set_input(**params)
# evaluate
print("Evaluate inference time cost...")
......
......@@ -328,11 +328,10 @@ def tune_and_evaluate(tuning_opt):
# upload parameters to device
ctx = remote.context(str(target), 0)
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
module.set_input(**params)
# evaluate
print("Evaluate inference time cost...")
......@@ -357,9 +356,28 @@ def tune_and_evaluate(tuning_opt):
#
# Extract tasks...
# Tuning...
# [Task 1/17] Current/Best: 12.22/ 36.05 GFLOPS | Progress: (32/1000) | 42.12 s
# [Task 1/17] Current/Best: 25.30/ 39.12 GFLOPS | Progress: (992/1000) | 751.22 s Done.
# [Task 2/17] Current/Best: 40.70/ 45.50 GFLOPS | Progress: (736/1000) | 545.46 s Done.
# [Task 3/17] Current/Best: 38.83/ 42.35 GFLOPS | Progress: (992/1000) | 1549.85 s Done.
# [Task 4/17] Current/Best: 23.31/ 31.02 GFLOPS | Progress: (640/1000) | 1059.31 s Done.
# [Task 5/17] Current/Best: 0.06/ 2.34 GFLOPS | Progress: (544/1000) | 305.45 s Done.
# [Task 6/17] Current/Best: 10.97/ 17.20 GFLOPS | Progress: (992/1000) | 1050.00 s Done.
# [Task 7/17] Current/Best: 8.98/ 10.94 GFLOPS | Progress: (928/1000) | 421.36 s Done.
# [Task 8/17] Current/Best: 4.48/ 14.86 GFLOPS | Progress: (704/1000) | 582.60 s Done.
# [Task 9/17] Current/Best: 10.30/ 25.99 GFLOPS | Progress: (864/1000) | 899.85 s Done.
# [Task 10/17] Current/Best: 11.73/ 12.52 GFLOPS | Progress: (608/1000) | 304.85 s Done.
# [Task 11/17] Current/Best: 15.26/ 18.68 GFLOPS | Progress: (800/1000) | 747.52 s Done.
# [Task 12/17] Current/Best: 17.48/ 26.71 GFLOPS | Progress: (1000/1000) | 1166.40 s Done.
# [Task 13/17] Current/Best: 0.96/ 11.43 GFLOPS | Progress: (960/1000) | 611.65 s Done.
# [Task 14/17] Current/Best: 17.88/ 20.22 GFLOPS | Progress: (672/1000) | 670.29 s Done.
# [Task 15/17] Current/Best: 11.62/ 13.98 GFLOPS | Progress: (736/1000) | 449.25 s Done.
# [Task 16/17] Current/Best: 19.90/ 23.83 GFLOPS | Progress: (608/1000) | 708.64 s Done.
# [Task 17/17] Current/Best: 17.98/ 22.75 GFLOPS | Progress: (736/1000) | 1122.60 s Done.
# Compile...
# Upload...
# Evaluate inference time cost...
# Mean inference time (std dev): 128.05 ms (7.74 ms)
#
# (The following part is running, will update it later).
######################################################################
#
......
......@@ -132,7 +132,6 @@ batch_size = 1
num_classes = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)
######################################################################
# Compile The Graph
......@@ -197,20 +196,17 @@ else:
remote.upload(lib_fname)
rlib = remote.load_module('net.tar')
ctx = remote.cpu(0) if local_demo else remote.cl(0)
# upload the parameter
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
# create the remote runtime module
ctx = remote.cl(0) if not local_demo else remote.cpu(0)
module = runtime.create(graph, rlib, ctx)
# set parameter
module.set_input(**rparams)
# set parameter (upload params to the remote device. This may take a while)
module.set_input(**params)
# set input data
module.set_input('data', tvm.nd.array(x.astype('float32')))
# run
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape, ctx=ctx))
out = module.get_output(0)
# get top1 result
top1 = np.argmax(out.asnumpy())
print('TVM prediction top-1: {}'.format(synset[top1]))
......@@ -128,7 +128,6 @@ batch_size = 1
num_classes = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)
######################################################################
# Compile The Graph
......@@ -188,20 +187,17 @@ else:
remote.upload(lib_fname)
rlib = remote.load_module('net.tar')
# upload the parameter (this may take a while)
ctx = remote.cpu(0)
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
# create the remote runtime module
ctx = remote.cpu(0)
module = runtime.create(graph, rlib, ctx)
# set parameter
module.set_input(**rparams)
# set parameter (upload params to the remote device. This may take a while)
module.set_input(**params)
# set input data
module.set_input('data', tvm.nd.array(x.astype('float32')))
# run
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape, ctx=ctx))
out = module.get_output(0)
# get top1 result
top1 = np.argmax(out.asnumpy())
print('TVM prediction top-1: {}'.format(synset[top1]))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment