Commit b30ae8ac by Pariksheet Pinjari Committed by Tianqi Chen

[TOPI][DARKNET]Yolo op added (#1372)

parent 4e7b548e
...@@ -276,14 +276,14 @@ def _darknet_route(inputs, attrs): ...@@ -276,14 +276,14 @@ def _darknet_route(inputs, attrs):
def _darknet_reorg(inputs, attrs): def _darknet_reorg(inputs, attrs):
"""Process the reorg operation.""" """Process the reorg operation."""
op_name, new_attrs = 'yolo2_reorg', {} op_name, new_attrs = 'yolo_reorg', {}
if 'stride' in attrs: if 'stride' in attrs:
new_attrs = {'stride': attrs.get('stride', 1)} new_attrs = {'stride': attrs.get('stride', 1)}
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_region(inputs, attrs): def _darknet_region(inputs, attrs):
"""Process the region operation.""" """Process the region operation."""
op_name, new_attrs = 'yolo2_region', {} op_name, new_attrs = 'yolo_region', {}
if 'n' in attrs: if 'n' in attrs:
new_attrs['n'] = attrs.get('n', 1) new_attrs['n'] = attrs.get('n', 1)
if 'classes' in attrs: if 'classes' in attrs:
......
...@@ -7,20 +7,20 @@ import topi ...@@ -7,20 +7,20 @@ import topi
from . import registry as reg from . import registry as reg
from .registry import OpPattern from .registry import OpPattern
@reg.register_compute("yolo2_reorg") @reg.register_compute("yolo_reorg")
def compute_reorg(attrs, inputs, _): def compute_reorg(attrs, inputs, _):
"""Compute definition of reorg""" """Compute definition of reorg"""
return topi.vision.reorg(inputs[0], attrs.get_int("stride")) return topi.vision.reorg(inputs[0], attrs.get_int("stride"))
@reg.register_schedule("yolo2_reorg") @reg.register_schedule("yolo_reorg")
def schedule_reorg(attrs, outs, target): def schedule_reorg(attrs, outs, target):
"""Schedule definition of reorg""" """Schedule definition of reorg"""
with tvm.target.create(target): with tvm.target.create(target):
return topi.generic.schedule_injective(outs) return topi.generic.schedule_injective(outs)
reg.register_pattern("yolo2_reorg", OpPattern.INJECTIVE) reg.register_pattern("yolo_reorg", OpPattern.INJECTIVE)
@reg.register_compute("yolo2_region") @reg.register_compute("yolo_region")
def compute_region(attrs, inputs, _): def compute_region(attrs, inputs, _):
"""Compute definition of region""" """Compute definition of region"""
n = attrs.get_int("n") n = attrs.get_int("n")
...@@ -28,15 +28,15 @@ def compute_region(attrs, inputs, _): ...@@ -28,15 +28,15 @@ def compute_region(attrs, inputs, _):
coords = attrs.get_int("coords") coords = attrs.get_int("coords")
background = attrs.get_int("background") background = attrs.get_int("background")
softmax = attrs.get_int("softmax") softmax = attrs.get_int("softmax")
return topi.vision.yolo2.region(inputs[0], n, classes, coords, background, softmax) return topi.vision.yolo.region(inputs[0], n, classes, coords, background, softmax)
@reg.register_schedule("yolo2_region") @reg.register_schedule("yolo_region")
def schedule_region(attrs, outs, target): def schedule_region(attrs, outs, target):
"""Schedule definition of region""" """Schedule definition of region"""
with tvm.target.create(target): with tvm.target.create(target):
return topi.generic.vision.schedule_region(outs) return topi.generic.vision.schedule_region(outs)
reg.register_pattern("yolo2_region", OpPattern.OPAQUE) reg.register_pattern("yolo_region", OpPattern.OPAQUE)
# multibox_prior # multibox_prior
@reg.register_schedule("multibox_prior") @reg.register_schedule("multibox_prior")
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
namespace nnvm { namespace nnvm {
namespace top { namespace top {
NNVM_REGISTER_OP(yolo2_region) NNVM_REGISTER_OP(yolo_region)
.describe(R"code(Region layer .describe(R"code(Region layer
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file region.h * \file region.h
*/ */
#ifndef NNVM_TOP_VISION_YOLO2_REGION_H_ #ifndef NNVM_TOP_VISION_YOLO_REGION_H_
#define NNVM_TOP_VISION_YOLO2_REGION_H_ #define NNVM_TOP_VISION_YOLO_REGION_H_
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -98,4 +98,4 @@ inline bool RegionType(const NodeAttrs &attrs, ...@@ -98,4 +98,4 @@ inline bool RegionType(const NodeAttrs &attrs,
} }
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
#endif // NNVM_TOP_VISION_YOLO2_REGION_H_ #endif // NNVM_TOP_VISION_YOLO_REGION_H_
...@@ -34,7 +34,7 @@ inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs, ...@@ -34,7 +34,7 @@ inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs,
return true; return true;
} }
NNVM_REGISTER_OP(yolo2_reorg) NNVM_REGISTER_OP(yolo_reorg)
.describe(R"(Perform reorg operation on input array based on the stride value. .describe(R"(Perform reorg operation on input array based on the stride value.
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width). - **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
- **out**: Output is 4D array of shape (batch_size, channels/(stride*stride), in_height*stride, in_width*stride). - **out**: Output is 4D array of shape (batch_size, channels/(stride*stride), in_height*stride, in_width*stride).
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file reorg.h * \file reorg.h
*/ */
#ifndef NNVM_TOP_VISION_YOLO2_REORG_H_ #ifndef NNVM_TOP_VISION_YOLO_REORG_H_
#define NNVM_TOP_VISION_YOLO2_REORG_H_ #define NNVM_TOP_VISION_YOLO_REORG_H_
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -107,4 +107,4 @@ struct ReorgParam : public dmlc::Parameter<ReorgParam> { ...@@ -107,4 +107,4 @@ struct ReorgParam : public dmlc::Parameter<ReorgParam> {
}; };
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
#endif // NNVM_TOP_VISION_YOLO2_REORG_H_ #endif // NNVM_TOP_VISION_YOLO_REORG_H_
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \brief Region op constructions * \brief Region op constructions
* \file vision/yolo2/region.h * \file vision/yolo/region.h
*/ */
#ifndef TOPI_VISION_YOLO2_REGION_H_ #ifndef TOPI_VISION_YOLO_REGION_H_
#define TOPI_VISION_YOLO2_REGION_H_ #define TOPI_VISION_YOLO_REGION_H_
#include <algorithm> #include <algorithm>
#include <string> #include <string>
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace topi { namespace topi {
namespace vision { namespace vision {
namespace yolo2 { namespace yolo {
using namespace tvm; using namespace tvm;
using namespace nn; using namespace nn;
...@@ -75,7 +75,7 @@ inline Tensor region(const Tensor &data, ...@@ -75,7 +75,7 @@ inline Tensor region(const Tensor &data,
Tensor out = concatenate(split_res, 2); Tensor out = concatenate(split_res, 2);
return reshape(out, input_shape); return reshape(out, input_shape);
} }
} // namespace yolo2 } // namespace yolo
} // namespace vision } // namespace vision
} // namespace topi } // namespace topi
#endif // TOPI_VISION_YOLO2_REGION_H_ #endif // TOPI_VISION_YOLO_REGION_H_
/*!
* Copyright (c) 2018 by Contributors
* \brief YOLO op constructions
* \file vision/yolo/yolo.h
*/
#ifndef TOPI_VISION_YOLO_YOLO_H_
#define TOPI_VISION_YOLO_YOLO_H_
#include <algorithm>
#include <string>
#include "topi/detail/constant_utils.h"
#include "topi/tags.h"
#include "topi/transform.h"
#include "tvm/tvm.h"
namespace topi {
namespace vision {
namespace yolo {
using namespace tvm;
using namespace nn;
/*!
* \brief yolo operation
*
* \param data The input tensor.
* \param num Darknet layer parameter n
* \param classes number of classes in the yolo model
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the yolo operation
*/
inline Tensor yolo(const Tensor &data,
int num,
int classes,
std::string name = "tensor",
std::string tag = "yolo_output") {
auto input_shape = data->shape;
int split_size = classes + 5;
Array <Expr> intermediate_shape = {input_shape[0],
num,
split_size,
input_shape[2],
input_shape[3]};
auto data_block = reshape(data, intermediate_shape);
Array <Expr> split_indices = {2, 4};
Array <Tensor> split_res = split(data_block, split_indices, 2);
split_res.Set(0, sigmoid(split_res[0]));
split_res.Set(2, sigmoid(split_res[2]));
Tensor out = concatenate(split_res, 2);
return reshape(out, input_shape);
}
} // namespace yolo
} // namespace vision
} // namespace topi
#endif // TOPI_VISION_YOLO_YOLO_H_
...@@ -46,7 +46,7 @@ x86 = _create_module("x86") ...@@ -46,7 +46,7 @@ x86 = _create_module("x86")
_init_api_prefix("topi.cpp.x86", "topi.x86") _init_api_prefix("topi.cpp.x86", "topi.x86")
vision = _create_module("vision") vision = _create_module("vision")
_init_api_prefix("topi.cpp.vision", "topi.vision") _init_api_prefix("topi.cpp.vision", "topi.vision")
yolo2 = _create_module("vision.yolo2") yolo = _create_module("vision.yolo")
_init_api_prefix("topi.cpp.vision.yolo2", "topi.vision.yolo2") _init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo")
image = _create_module("image") image = _create_module("image")
_init_api_prefix("topi.cpp.image", "topi.image") _init_api_prefix("topi.cpp.image", "topi.image")
...@@ -15,6 +15,7 @@ from .upsampling_python import upsampling_python ...@@ -15,6 +15,7 @@ from .upsampling_python import upsampling_python
from .bilinear_resize_python import bilinear_resize_python from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python from .reorg_python import reorg_python
from .region_python import region_python from .region_python import region_python
from .yolo_python import yolo_python
from .shortcut_python import shortcut_python from .shortcut_python import shortcut_python
from .lrn_python import lrn_python from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python from .l2_normalize_python import l2_normalize_python
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Yolo operator in python"""
import numpy as np
def entry_index(batch, w, h, outputs, classes, coords, location, entry):
n = int(location/(w*h))
loc = location%(w*h)
return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
def yolo_python(a_np, N, classes):
"""Yolo operator
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
N : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = a_np.shape
a_np_temp = np.reshape(a_np, batch*in_channel*in_height*in_width)
outputs = batch*in_channel*in_height*in_width
b_np = np.zeros(batch*in_channel*in_height*in_width)
for i in range(batch*in_channel*in_height*in_width):
b_np[i] = a_np_temp[i]
for b in range(batch):
for n in range(N):
index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 0)
b_np[index: index+2*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+2*in_width*in_height]))
index = entry_index(b, in_width, in_height, outputs, classes, 4, n*in_width*in_height, 4)
b_np[index: index+(1+classes)*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+(1+classes)*in_width*in_height]))
b_np = np.reshape(b_np, (batch, in_channel, in_height, in_width))
return b_np
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""VISION network operators""" """VISION network operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import yolo2, ssd from . import yolo, ssd
from .shortcut import * from .shortcut import *
from .reorg import * from .reorg import *
from .nms import * from .nms import *
...@@ -3,3 +3,4 @@ ...@@ -3,3 +3,4 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .region import * from .region import *
from .yolo import *
...@@ -36,4 +36,4 @@ def region(data, num, classes, coords, background, softmax=True): ...@@ -36,4 +36,4 @@ def region(data, num, classes, coords, background, softmax=True):
out : tvm.Tensor out : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in] 4-D with shape [batch, c_in, h_in, w_in]
""" """
return cpp.yolo2.region(data, num, classes, coords, background, softmax) return cpp.yolo.region(data, num, classes, coords, background, softmax)
# pylint: disable=invalid-name, unused-variable
"""
YOLO Operator
=============
YOLO operator, used in darknet.
"""
from __future__ import absolute_import as _abs
import tvm
from ... import cpp
@tvm.target.generic_func
def yolo(data, num, classes):
"""YOLO forward operators.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
num : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
Returns
-------
out : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
"""
return cpp.yolo.yolo(data, num, classes)
...@@ -29,7 +29,8 @@ ...@@ -29,7 +29,8 @@
#include <topi/vision/reorg.h> #include <topi/vision/reorg.h>
#include <topi/image/resize.h> #include <topi/image/resize.h>
#include <topi/vision/yolo2/region.h> #include <topi/vision/yolo/region.h>
#include <topi/vision/yolo/yolo.h>
#include <topi/generic/default.h> #include <topi/generic/default.h>
#include <topi/generic/extern.h> #include <topi/generic/extern.h>
#include <topi/generic/injective.h> #include <topi/generic/injective.h>
...@@ -386,9 +387,14 @@ TVM_REGISTER_GLOBAL("topi.vision.reorg") ...@@ -386,9 +387,14 @@ TVM_REGISTER_GLOBAL("topi.vision.reorg")
*rv = vision::reorg(args[0], args[1]); *rv = vision::reorg(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL("topi.vision.yolo2.region") TVM_REGISTER_GLOBAL("topi.vision.yolo.region")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::yolo2::region(args[0], args[1], args[2], args[3], args[4], args[5]); *rv = vision::yolo::region(args[0], args[1], args[2], args[3], args[4], args[5]);
});
TVM_REGISTER_GLOBAL("topi.vision.yolo.yolo")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::yolo::yolo(args[0], args[1], args[2]);
}); });
/* Ops from image/resize.h */ /* Ops from image/resize.h */
......
...@@ -10,7 +10,7 @@ def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_ ...@@ -10,7 +10,7 @@ def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_
in_height = in_width = in_size in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.vision.yolo2.region(A, n, classes, coords, background, l_softmax) B = topi.vision.yolo.region(A, n, classes, coords, background, l_softmax)
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
dtype = A.dtype dtype = A.dtype
......
...@@ -11,7 +11,7 @@ def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_ ...@@ -11,7 +11,7 @@ def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_
in_height = in_width = in_size in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.cpp.yolo2.region(A, n, classes, coords, background, l_softmax) B = topi.cpp.yolo.region(A, n, classes, coords, background, l_softmax)
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
dtype = A.dtype dtype = A.dtype
......
"""Test code for yolo op"""
import logging
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
def verify_yolo(ishape, n, classes):
'''Verify yolo operator by comparing outputs from tvm and numpy implementation'''
A = tvm.placeholder(ishape, name='A')
B = topi.cpp.yolo.yolo(A, n, classes)
dtype = A.dtype
def get_ref_data_yolo():
'''Randomly initialize the data variables and get refernce output for the yolo operation'''
a_np = np.random.uniform(size=ishape).astype(dtype)
b_np = topi.testing.yolo_python(a_np, n, classes)
return a_np, b_np
a_np, b_np = get_ref_data_yolo()
def check_device(device):
'''Check the device is available and if so, build and run the program'''
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device, name="yolo")
func(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']:
check_device(device)
def test_yolo():
verify_yolo((1, 425, 19, 19), 5, 80)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_yolo()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment