Commit a1688998 by Lianmin Zheng Committed by Tianqi Chen

improve text summary (#1655)

parent 531efd6f
...@@ -104,11 +104,11 @@ You are supposed to find a free "android" in the queue status. ...@@ -104,11 +104,11 @@ You are supposed to find a free "android" in the queue status.
... ...
Queue Status Queue Status
---------------------------- -------------------------------
key free pending key total free pending
---------------------------- -------------------------------
android 1 0 android 1 1 0
---------------------------- -------------------------------
``` ```
......
...@@ -40,14 +40,14 @@ python3 -m tvm.exec.rpc_tracker ...@@ -40,14 +40,14 @@ python3 -m tvm.exec.rpc_tracker
For our test environment, one sample output can be For our test environment, one sample output can be
```bash ```bash
Queue Status Queue Status
------------------------------ ----------------------------------
key free pending key total free pending
------------------------------ ----------------------------------
mate10pro 1 0 mate10pro 1 1 0
p20pro 2 0 p20pro 2 2 0
pixel2 2 0 pixel2 2 2 0
rk3399 2 0 rk3399 2 2 0
rasp3b 8 0 rasp3b 8 8 0
``` ```
4. Run benchmark 4. Run benchmark
......
...@@ -218,6 +218,9 @@ class TrackerSession(object): ...@@ -218,6 +218,9 @@ class TrackerSession(object):
def text_summary(self): def text_summary(self):
"""Get a text summary of the tracker.""" """Get a text summary of the tracker."""
data = self.summary() data = self.summary()
total_ct = {}
res = "" res = ""
res += "Server List\n" res += "Server List\n"
res += "----------------------------\n" res += "----------------------------\n"
...@@ -225,8 +228,12 @@ class TrackerSession(object): ...@@ -225,8 +228,12 @@ class TrackerSession(object):
res += "----------------------------\n" res += "----------------------------\n"
for item in data["server_info"]: for item in data["server_info"]:
addr = item["addr"] addr = item["addr"]
res += addr[0] + ":" + str(addr[1])+ "\t" res += addr[0] + ":" + str(addr[1]) + "\t"
res += item["key"] + "\n" res += item["key"] + "\n"
key = item['key'].split(':')[1] # 'server:rasp3b` -> 'rasp3b'
if key not in total_ct:
total_ct[key] = 0
total_ct[key] += 1
res += "----------------------------\n" res += "----------------------------\n"
res += "\n" res += "\n"
...@@ -240,14 +247,16 @@ class TrackerSession(object): ...@@ -240,14 +247,16 @@ class TrackerSession(object):
max_key_len = 0 max_key_len = 0
res += "Queue Status\n" res += "Queue Status\n"
res += "----------------------------\n" title = ("%%-%ds" % max_key_len + " total free pending\n") % 'key'
res += ("%%-%ds" % max_key_len + "\tfree\tpending\n") % 'key' separate_line = '-' * len(title) + '\n'
res += "----------------------------\n" res += separate_line + title + separate_line
for k in keys: for k in keys:
res += ("%%-%ds" % max_key_len + "\t%d\t%g\n") % \ total = total_ct.get(k, 0)
(k, queue_info[k]["free"], queue_info[k]["pending"]) free, pending = queue_info[k]["free"], queue_info[k]["pending"]
if total or pending:
res += "----------------------------\n" res += ("%%-%ds" % max_key_len + " %-5d %-4d %-7d\n") % \
(k, total, free, pending)
res += separate_line
return res return res
def request(self, key, priority=1, session_timeout=0, max_retry=5): def request(self, key, priority=1, session_timeout=0, max_retry=5):
......
...@@ -66,6 +66,8 @@ class TCPHandler(object): ...@@ -66,6 +66,8 @@ class TCPHandler(object):
while self._pending_write: while self._pending_write:
try: try:
msg = self._pending_write[0] msg = self._pending_write[0]
if self._sock is None:
return
nsend = self._sock.send(msg) nsend = self._sock.send(msg)
if nsend != len(msg): if nsend != len(msg):
self._pending_write[0] = msg[nsend:] self._pending_write[0] = msg[nsend:]
......
...@@ -78,6 +78,16 @@ class Scheduler(object): ...@@ -78,6 +78,16 @@ class Scheduler(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def remove(self, value):
"""Remove a resource in the scheduler
Parameters
----------
value: object
The resource to remove
"""
pass
def summary(self): def summary(self):
"""Get summary information of the scheduler.""" """Get summary information of the scheduler."""
raise NotImplementedError() raise NotImplementedError()
...@@ -108,6 +118,11 @@ class PriorityScheduler(Scheduler): ...@@ -108,6 +118,11 @@ class PriorityScheduler(Scheduler):
heapq.heappush(self._requests, (-priority, time.time(), callback)) heapq.heappush(self._requests, (-priority, time.time(), callback))
self._schedule() self._schedule()
def remove(self, value):
if value in self._values:
self._values.remove(value)
self._schedule()
def summary(self): def summary(self):
"""Get summary information of the scheduler.""" """Get summary information of the scheduler."""
return {"free": len(self._values), return {"free": len(self._values),
...@@ -132,6 +147,7 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -132,6 +147,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
# list of pending match keys that has not been used. # list of pending match keys that has not been used.
self.pending_matchkeys = set() self.pending_matchkeys = set()
self._tracker._connections.add(self) self._tracker._connections.add(self)
self.put_values = []
def name(self): def name(self):
"""name of connection""" """name of connection"""
...@@ -199,9 +215,11 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -199,9 +215,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
self.pending_matchkeys.add(matchkey) self.pending_matchkeys.add(matchkey)
# got custom address (from rpc server) # got custom address (from rpc server)
if args[3] is not None: if args[3] is not None:
self._tracker.put(key, (self, args[3], port, matchkey)) value = (self, args[3], port, matchkey)
else: else:
self._tracker.put(key, (self, self._addr[0], port, matchkey)) value = (self, self._addr[0], port, matchkey)
self._tracker.put(key, value)
self.put_values.append(value)
self.ret_value(TrackerCode.SUCCESS) self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.REQUEST: elif code == TrackerCode.REQUEST:
key = args[1] key = args[1]
...@@ -239,7 +257,7 @@ class TCPEventHandler(tornado_util.TCPHandler): ...@@ -239,7 +257,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
self.close() self.close()
def on_close(self): def on_close(self):
self._tracker._connections.remove(self) self._tracker.close(self)
def on_error(self, err): def on_error(self, err):
logger.warning("%s: Error in RPC Tracker: %s", self.name(), err) logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
...@@ -285,6 +303,13 @@ class TrackerServerHandler(object): ...@@ -285,6 +303,13 @@ class TrackerServerHandler(object):
self._scheduler_map[key] = self.create_scheduler(key) self._scheduler_map[key] = self.create_scheduler(key)
self._scheduler_map[key].request(user, priority, callback) self._scheduler_map[key].request(user, priority, callback)
def close(self, conn):
self._connections.remove(conn)
if 'key' in conn._info:
key = conn._info['key'].split(':')[1] # 'server:rasp3b' -> 'rasp3b'
for value in conn.put_values:
self._scheduler_map[key].remove(value)
def stop(self): def stop(self):
"""Safely stop tracker.""" """Safely stop tracker."""
for conn in list(self._connections): for conn in list(self._connections):
......
import tvm import tvm
import os import os
import logging import logging
import numpy as np
import time import time
import multiprocessing
import numpy as np
from tvm import rpc from tvm import rpc
from tvm.contrib import util from tvm.contrib import util
from tvm.rpc.tracker import Tracker
def test_bigendian_rpc(): def test_bigendian_rpc():
...@@ -237,6 +240,79 @@ def test_local_func(): ...@@ -237,6 +240,79 @@ def test_local_func():
rev = client.download("dat.bin") rev = client.download("dat.bin")
assert rev == blob assert rev == blob
def test_rpc_tracker_register():
# test registration
tracker = Tracker('localhost', port=9000, port_end=10000)
device_key = 'test_device'
server = rpc.Server('localhost', port=9000, port_end=10000,
key=device_key,
tracker_addr=(tracker.host, tracker.port))
time.sleep(1)
client = rpc.connect_tracker(tracker.host, tracker.port)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 1
remote = client.request(device_key)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
del remote
time.sleep(1)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 1
server.terminate()
time.sleep(1)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
tracker.terminate()
def test_rpc_tracker_request():
# test concurrent request
tracker = Tracker('localhost', port=9000, port_end=10000)
device_key = 'test_device'
server = rpc.Server('localhost', port=9000, port_end=10000,
key=device_key,
tracker_addr=(tracker.host, tracker.port))
client = rpc.connect_tracker(tracker.host, tracker.port)
def target(host, port, device_key, timeout):
client = rpc.connect_tracker(host, port)
remote = client.request(device_key, session_timeout=timeout)
while True:
pass
remote.cpu()
proc1 = multiprocessing.Process(target=target,
args=(tracker.host, tracker.port, device_key, 4))
proc2 = multiprocessing.Process(target=target,
args=(tracker.host, tracker.port, device_key, 200))
proc1.start()
time.sleep(0.5)
proc2.start()
time.sleep(0.5)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
assert summary['queue_info'][device_key]['pending'] == 1
proc1.terminate()
proc1.join()
time.sleep(0.5)
summary = client.summary()
assert summary['queue_info'][device_key]['free'] == 0
assert summary['queue_info'][device_key]['pending'] == 0
proc2.terminate()
proc2.join()
server.terminate()
tracker.terminate()
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -248,3 +324,5 @@ if __name__ == "__main__": ...@@ -248,3 +324,5 @@ if __name__ == "__main__":
test_rpc_array() test_rpc_array()
test_rpc_simple() test_rpc_simple()
test_local_func() test_local_func()
test_rpc_tracker_register()
test_rpc_tracker_request()
...@@ -151,13 +151,13 @@ def get_network(name, batch_size): ...@@ -151,13 +151,13 @@ def get_network(name, batch_size):
# .. code-block:: bash # .. code-block:: bash
# #
# Queue Status # Queue Status
# ---------------------------- # ----------------------------------
# key free pending # key total free pending
# ---------------------------- # ----------------------------------
# mate10pro 2 0 # mate10pro 2 2 0
# rk3399 2 0 # rk3399 2 2 0
# rpi3b 11 0 # rpi3b 11 11 0
# ---------------------------- # ----------------------------------
########################################### ###########################################
# Set Tuning Options # Set Tuning Options
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment