Commit a1688998 by Lianmin Zheng Committed by Tianqi Chen

improve text summary (#1655)

parent 531efd6f
......@@ -104,11 +104,11 @@ You are supposed to find a free "android" in the queue status.
...
Queue Status
----------------------------
key free pending
----------------------------
android 1 0
----------------------------
-------------------------------
key total free pending
-------------------------------
android 1 1 0
-------------------------------
```
......
......@@ -40,14 +40,14 @@ python3 -m tvm.exec.rpc_tracker
For our test environment, one sample output can be
```bash
Queue Status
------------------------------
key free pending
------------------------------
mate10pro 1 0
p20pro 2 0
pixel2 2 0
rk3399 2 0
rasp3b 8 0
----------------------------------
key total free pending
----------------------------------
mate10pro 1 1 0
p20pro 2 2 0
pixel2 2 2 0
rk3399 2 2 0
rasp3b 8 8 0
```
4. Run benchmark
......
......@@ -218,6 +218,9 @@ class TrackerSession(object):
def text_summary(self):
"""Get a text summary of the tracker."""
data = self.summary()
total_ct = {}
res = ""
res += "Server List\n"
res += "----------------------------\n"
......@@ -225,8 +228,12 @@ class TrackerSession(object):
res += "----------------------------\n"
for item in data["server_info"]:
addr = item["addr"]
res += addr[0] + ":" + str(addr[1])+ "\t"
res += addr[0] + ":" + str(addr[1]) + "\t"
res += item["key"] + "\n"
key = item['key'].split(':')[1] # 'server:rasp3b` -> 'rasp3b'
if key not in total_ct:
total_ct[key] = 0
total_ct[key] += 1
res += "----------------------------\n"
res += "\n"
......@@ -240,14 +247,16 @@ class TrackerSession(object):
max_key_len = 0
res += "Queue Status\n"
res += "----------------------------\n"
res += ("%%-%ds" % max_key_len + "\tfree\tpending\n") % 'key'
res += "----------------------------\n"
title = ("%%-%ds" % max_key_len + " total free pending\n") % 'key'
separate_line = '-' * len(title) + '\n'
res += separate_line + title + separate_line
for k in keys:
res += ("%%-%ds" % max_key_len + "\t%d\t%g\n") % \
(k, queue_info[k]["free"], queue_info[k]["pending"])
res += "----------------------------\n"
total = total_ct.get(k, 0)
free, pending = queue_info[k]["free"], queue_info[k]["pending"]
if total or pending:
res += ("%%-%ds" % max_key_len + " %-5d %-4d %-7d\n") % \
(k, total, free, pending)
res += separate_line
return res
def request(self, key, priority=1, session_timeout=0, max_retry=5):
......
......@@ -66,6 +66,8 @@ class TCPHandler(object):
while self._pending_write:
try:
msg = self._pending_write[0]
if self._sock is None:
return
nsend = self._sock.send(msg)
if nsend != len(msg):
self._pending_write[0] = msg[nsend:]
......
......@@ -78,6 +78,16 @@ class Scheduler(object):
"""
raise NotImplementedError()
def remove(self, value):
"""Remove a resource in the scheduler
Parameters
----------
value: object
The resource to remove
"""
pass
def summary(self):
"""Get summary information of the scheduler."""
raise NotImplementedError()
......@@ -108,6 +118,11 @@ class PriorityScheduler(Scheduler):
heapq.heappush(self._requests, (-priority, time.time(), callback))
self._schedule()
def remove(self, value):
if value in self._values:
self._values.remove(value)
self._schedule()
def summary(self):
"""Get summary information of the scheduler."""
return {"free": len(self._values),
......@@ -132,6 +147,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
# list of pending match keys that has not been used.
self.pending_matchkeys = set()
self._tracker._connections.add(self)
self.put_values = []
def name(self):
"""name of connection"""
......@@ -199,9 +215,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
self.pending_matchkeys.add(matchkey)
# got custom address (from rpc server)
if args[3] is not None:
self._tracker.put(key, (self, args[3], port, matchkey))
value = (self, args[3], port, matchkey)
else:
self._tracker.put(key, (self, self._addr[0], port, matchkey))
value = (self, self._addr[0], port, matchkey)
self._tracker.put(key, value)
self.put_values.append(value)
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.REQUEST:
key = args[1]
......@@ -239,7 +257,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
self.close()
def on_close(self):
self._tracker._connections.remove(self)
self._tracker.close(self)
def on_error(self, err):
logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
......@@ -285,6 +303,13 @@ class TrackerServerHandler(object):
self._scheduler_map[key] = self.create_scheduler(key)
self._scheduler_map[key].request(user, priority, callback)
def close(self, conn):
self._connections.remove(conn)
if 'key' in conn._info:
key = conn._info['key'].split(':')[1] # 'server:rasp3b' -> 'rasp3b'
for value in conn.put_values:
self._scheduler_map[key].remove(value)
def stop(self):
"""Safely stop tracker."""
for conn in list(self._connections):
......
import tvm
import os
import logging
import numpy as np
import time
import multiprocessing
import numpy as np
from tvm import rpc
from tvm.contrib import util
from tvm.rpc.tracker import Tracker
def test_bigendian_rpc():
......@@ -237,6 +240,79 @@ def test_local_func():
rev = client.download("dat.bin")
assert rev == blob
def test_rpc_tracker_register():
# test registration
tracker = Tracker('localhost', port=9000, port_end=10000)
device_key = 'test_device'
server = rpc.Server('localhost', port=9000, port_end=10000,
key=device_key,
tracker_addr=(tracker.host, tracker.port))
time.sleep(1)
client = rpc.connect_tracker(tracker.host, tracker.port)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 1
remote = client.request(device_key)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
del remote
time.sleep(1)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 1
server.terminate()
time.sleep(1)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
tracker.terminate()
def test_rpc_tracker_request():
# test concurrent request
tracker = Tracker('localhost', port=9000, port_end=10000)
device_key = 'test_device'
server = rpc.Server('localhost', port=9000, port_end=10000,
key=device_key,
tracker_addr=(tracker.host, tracker.port))
client = rpc.connect_tracker(tracker.host, tracker.port)
def target(host, port, device_key, timeout):
client = rpc.connect_tracker(host, port)
remote = client.request(device_key, session_timeout=timeout)
while True:
pass
remote.cpu()
proc1 = multiprocessing.Process(target=target,
args=(tracker.host, tracker.port, device_key, 4))
proc2 = multiprocessing.Process(target=target,
args=(tracker.host, tracker.port, device_key, 200))
proc1.start()
time.sleep(0.5)
proc2.start()
time.sleep(0.5)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
assert summary['queue_info'][device_key]['pending'] == 1
proc1.terminate()
proc1.join()
time.sleep(0.5)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
assert summary['queue_info'][device_key]['pending'] == 0
proc2.terminate()
proc2.join()
server.terminate()
tracker.terminate()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
......@@ -248,3 +324,5 @@ if __name__ == "__main__":
test_rpc_array()
test_rpc_simple()
test_local_func()
test_rpc_tracker_register()
test_rpc_tracker_request()
......@@ -151,13 +151,13 @@ def get_network(name, batch_size):
# .. code-block:: bash
#
# Queue Status
# ----------------------------
# key free pending
# ----------------------------
# mate10pro 2 0
# rk3399 2 0
# rpi3b 11 0
# ----------------------------
# ----------------------------------
# key total free pending
# ----------------------------------
# mate10pro 2 2 0
# rk3399 2 2 0
# rpi3b 11 11 0
# ----------------------------------
###########################################
# Set Tuning Options
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment