Commit 50484b38 by Siju Committed by Tianqi Chen

Bugfix for path issues (#3038)

parent 2a7f7548
...@@ -165,7 +165,7 @@ def do_nms_sort(dets, classes, thresh): ...@@ -165,7 +165,7 @@ def do_nms_sort(dets, classes, thresh):
if _box_iou(a, b) > thresh: if _box_iou(a, b) > thresh:
dets[j]['prob'][k] = 0 dets[j]['prob'][k] = 0
def draw_detections(im, dets, thresh, names, classes): def draw_detections(font_path, im, dets, thresh, names, classes):
"Draw the markings around the detected region" "Draw the markings around the detected region"
for det in dets: for det in dets:
labelstr = [] labelstr = []
...@@ -198,7 +198,7 @@ def draw_detections(im, dets, thresh, names, classes): ...@@ -198,7 +198,7 @@ def draw_detections(im, dets, thresh, names, classes):
if bot > imh-1: if bot > imh-1:
bot = imh-1 bot = imh-1
_draw_box_width(im, left, top, right, bot, width, red, green, blue) _draw_box_width(im, left, top, right, bot, width, red, green, blue)
label = _get_label(''.join(labelstr), rgb) label = _get_label(font_path, ''.join(labelstr), rgb)
_draw_label(im, top + width, left, label, rgb) _draw_label(im, top + width, left, label, rgb)
def _get_pixel(im, x, y, c): def _get_pixel(im, x, y, c):
...@@ -223,7 +223,7 @@ def _draw_label(im, r, c, label, rgb): ...@@ -223,7 +223,7 @@ def _draw_label(im, r, c, label, rgb):
val = _get_pixel(label, i, j, k) val = _get_pixel(label, i, j, k)
_set_pixel(im, i+c, j+r, k, val)#rgb[k] * val) _set_pixel(im, i+c, j+r, k, val)#rgb[k] * val)
def _get_label(labelstr, rgb): def _get_label(font_path, labelstr, rgb):
from PIL import Image from PIL import Image
from PIL import ImageDraw from PIL import ImageDraw
from PIL import ImageFont from PIL import ImageFont
...@@ -231,7 +231,7 @@ def _get_label(labelstr, rgb): ...@@ -231,7 +231,7 @@ def _get_label(labelstr, rgb):
text = labelstr text = labelstr
colorText = "black" colorText = "black"
testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1))) testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
font = ImageFont.truetype("arial.ttf", 25) font = ImageFont.truetype(font_path, 25)
width, height = testDraw.textsize(labelstr, font=font) width, height = testDraw.textsize(labelstr, font=font)
img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255), img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
int(rgb[2]*255))) int(rgb[2]*255)))
......
...@@ -153,7 +153,7 @@ elif MODEL_NAME == 'yolov3': ...@@ -153,7 +153,7 @@ elif MODEL_NAME == 'yolov3':
# do the detection and bring up the bounding boxes # do the detection and bring up the bounding boxes
thresh = 0.5 thresh = 0.5
nms_thresh = 0.45 nms_thresh = 0.45
img = nnvm.testing.darknet.load_image_color(test_image) img = nnvm.testing.darknet.load_image_color(img_path)
_, im_h, im_w = img.shape _, im_h, im_w = img.shape
dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh,
1, tvm_out) 1, tvm_out)
...@@ -172,6 +172,6 @@ with open(coco_path) as f: ...@@ -172,6 +172,6 @@ with open(coco_path) as f:
names = [x.strip() for x in content] names = [x.strip() for x in content]
nnvm.testing.yolo_detection.draw_detections(img, dets, thresh, names, last_layer.classes) nnvm.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
plt.imshow(img.transpose(1, 2, 0)) plt.imshow(img.transpose(1, 2, 0))
plt.show() plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment