Commit b840e960 by Siju Committed by Tianqi Chen

[YOLO]yolo op added in frontend and removed from topi (#1974)

parent dd1558af
......@@ -317,12 +317,19 @@ def _darknet_region(inputs, attrs):
def _darknet_yolo(inputs, attrs):
"""Process the yolo operation."""
op_name, new_attrs = 'yolov3_yolo', {}
if 'n' in attrs:
new_attrs['n'] = attrs.get('n', 1)
if 'classes' in attrs:
new_attrs['classes'] = attrs.get('classes', 1)
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
num = attrs.get('n', 1)
classes = attrs.get('classes', 1)
input_shape = attrs.get('shape')
split_size = classes + 5
intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3])
data_block = _sym.reshape(inputs[0], shape=intermediate_shape)
split_indices = (2, 4)
split_res = _sym.split(data_block, indices_or_sections=split_indices, axis=2)
split_res0 = _sym.sigmoid(split_res[0])
split_res2 = _sym.sigmoid(split_res[2])
concat_list = [split_res0, split_res[1], split_res2]
out = _sym.concatenate(*concat_list, axis=2)
return _sym.reshape(out, shape=input_shape), None
def _darknet_activations(inputs, attrs):
"""Process the activation function."""
......@@ -635,6 +642,7 @@ class GraphProto(object):
elif LAYERTYPE.YOLO == layer.type:
attr.update({'n' : layer.n})
attr.update({'classes' : layer.classes})
attr.update({'shape' : (1, layer.c, layer.h, layer.w)})
elif LAYERTYPE.UPSAMPLE == layer.type:
attr.update({'scale' : layer.stride})
......
......@@ -38,21 +38,6 @@ def schedule_region(attrs, outs, target):
reg.register_pattern("yolo_region", OpPattern.OPAQUE)
@reg.register_compute("yolov3_yolo")
def compute_yolo(attrs, inputs, _):
"""Compute definition of yolo"""
n = attrs.get_int("n")
classes = attrs.get_int("classes")
return topi.vision.yolo.yolo(inputs[0], n, classes)
@reg.register_schedule("yolov3_yolo")
def schedule_yolo(attrs, outs, target):
"""Schedule definition of yolo"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)
reg.register_pattern("yolov3_yolo", OpPattern.OPAQUE)
# multibox_prior
@reg.register_schedule("multibox_prior")
def schedule_multibox_prior(_, outs, target):
......
/*!
* Copyright (c) 2018 by Contributors
* \file yolo.cc
* \brief Property def of yolo operators.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "../../elemwise_op_common.h"
namespace nnvm {
namespace top {
NNVM_REGISTER_OP(yolov3_yolo)
.describe(R"code(Yolo layer
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Input data")
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInplaceOption>(
"FInplaceOption",
[](const NodeAttrs &attrs) {
return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
})
.set_attr<FGradient>("FGradient", [](const NodePtr &n,
const std::vector<NodeEntry> &ograds) {
return std::vector<NodeEntry>{ograds[0], ograds[0]};
});
} // namespace top
} // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \brief YOLO op constructions
* \file vision/yolo/yolo.h
*/
#ifndef TOPI_VISION_YOLO_YOLO_H_
#define TOPI_VISION_YOLO_YOLO_H_
#include <algorithm>
#include <string>
#include "topi/detail/constant_utils.h"
#include "topi/tags.h"
#include "topi/transform.h"
#include "tvm/tvm.h"
namespace topi {
namespace vision {
namespace yolo {
using namespace tvm;
using namespace nn;
/*!
* \brief yolo operation
*
* \param data The input tensor.
* \param num Darknet layer parameter n
* \param classes number of classes in the yolo model
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the yolo operation
*/
inline Tensor yolo(const Tensor &data,
int num,
int classes,
std::string name = "tensor",
std::string tag = "yolo_output") {
auto input_shape = data->shape;
int split_size = classes + 5;
Array <Expr> intermediate_shape = {input_shape[0],
num,
split_size,
input_shape[2],
input_shape[3]};
auto data_block = reshape(data, intermediate_shape);
Array <Expr> split_indices = {2, 4};
Array <Tensor> split_res = split(data_block, split_indices, 2);
split_res.Set(0, sigmoid(split_res[0]));
split_res.Set(2, sigmoid(split_res[2]));
Tensor out = concatenate(split_res, 2);
return reshape(out, input_shape);
}
} // namespace yolo
} // namespace vision
} // namespace topi
#endif // TOPI_VISION_YOLO_YOLO_H_
......@@ -15,7 +15,6 @@ from .upsampling_python import upsampling_python
from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python
from .region_python import region_python
from .yolo_python import yolo_python
from .shortcut_python import shortcut_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Yolo operator in python"""
import numpy as np
def entry_index(batch, w, h, outputs, classes, coords, location, entry):
n = int(location/(w*h))
loc = location%(w*h)
return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
def yolo_python(a_np, N, classes):
"""Yolo operator
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
N : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = a_np.shape
a_np_temp = np.reshape(a_np, batch*in_channel*in_height*in_width)
outputs = batch*in_channel*in_height*in_width
b_np = np.zeros(batch*in_channel*in_height*in_width)
for i in range(batch*in_channel*in_height*in_width):
b_np[i] = a_np_temp[i]
for b in range(batch):
for n in range(N):
index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 0)
b_np[index: index+2*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+2*in_width*in_height]))
index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 4)
b_np[index: index+(1+classes)*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+(1+classes)*in_width*in_height]))
b_np = np.reshape(b_np, (batch, in_channel, in_height, in_width))
return b_np
......@@ -3,4 +3,3 @@
from __future__ import absolute_import as _abs
from .region import *
from .yolo import *
# pylint: disable=invalid-name, unused-variable
"""
YOLO Operator
=============
YOLO operator, used in darknet.
"""
from __future__ import absolute_import as _abs
import tvm
from ... import cpp
@tvm.target.generic_func
def yolo(data, num, classes):
"""YOLO forward operators.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
num : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
Returns
-------
out : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
"""
return cpp.yolo.yolo(data, num, classes)
......@@ -32,7 +32,6 @@
#include <topi/vision/reorg.h>
#include <topi/image/resize.h>
#include <topi/vision/yolo/region.h>
#include <topi/vision/yolo/yolo.h>
#include <topi/generic/default.h>
#include <topi/generic/extern.h>
#include <topi/generic/injective.h>
......@@ -413,11 +412,6 @@ TVM_REGISTER_GLOBAL("topi.vision.yolo.region")
*rv = vision::yolo::region(args[0], args[1], args[2], args[3], args[4], args[5]);
});
TVM_REGISTER_GLOBAL("topi.vision.yolo.yolo")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::yolo::yolo(args[0], args[1], args[2]);
});
/* Ops from image/resize.h */
TVM_REGISTER_GLOBAL("topi.image.resize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
"""Test code for yolo op"""
import logging
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
def verify_yolo(ishape, n, classes):
'''Verify yolo operator by comparing outputs from tvm and numpy implementation'''
A = tvm.placeholder(ishape, name='A')
B = topi.cpp.yolo.yolo(A, n, classes)
dtype = A.dtype
def get_ref_data_yolo():
'''Randomly initialize the data variables and get refernce output for the yolo operation'''
a_np = np.random.uniform(size=ishape).astype(dtype)
b_np = topi.testing.yolo_python(a_np, n, classes)
return a_np, b_np
a_np, b_np = get_ref_data_yolo()
def check_device(device):
'''Check the device is available and if so, build and run the program'''
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device, name="yolo")
func(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']:
check_device(device)
def test_yolo():
verify_yolo((1, 425, 19, 19), 5, 80)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_yolo()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment