Commit 0edb332f authored by Siju, committed by Tianqi Chen

REGION op removed from topi and added in darknet frontend (#2275)

parent 2a898181
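
This change drops the dedicated yolo_region compute/schedule from topi and nnvm and instead expresses the region layer in the darknet frontend as a composition of existing symbols (reshape, split, sigmoid, softmax, concatenate). For reference, the region transform applies a sigmoid to the box x/y offsets and to the objectness score (unless background is set) and a softmax to the class scores of each anchor box. Below is a minimal NumPy sketch of that transform, for illustration only; the function name, argument names, and the NCHW layout assumption are not part of the frontend API:

import numpy as np

def region_reference(data, num, classes, coords=4, background=0, use_softmax=True):
    """Illustrative darknet region transform on a (batch, C, H, W) array,
    where C packs num * (coords + 1 + classes) channels."""
    batch, channels, height, width = data.shape
    split_size = coords + 1 + classes
    x = data.reshape(batch, num, split_size, height, width).astype(np.float64)
    # Sigmoid on the x/y box offsets.
    x[:, :, 0:2] = 1.0 / (1.0 + np.exp(-x[:, :, 0:2]))
    # Sigmoid on the objectness score unless a background slot is kept.
    if not background:
        x[:, :, coords] = 1.0 / (1.0 + np.exp(-x[:, :, coords]))
    if use_softmax:
        # Softmax over the class scores (and the background slot, if any).
        start = coords + (0 if background else 1)
        cls = np.exp(x[:, :, start:] - x[:, :, start:].max(axis=2, keepdims=True))
        x[:, :, start:] = cls / cls.sum(axis=2, keepdims=True)
    return x.reshape(batch, channels, height, width)

With the configuration used in the removed test below (in_channel = 425, n = 5, classes = 80, coords = 4), split_size is 85 and 5 * 85 = 425, so the channel dimension reshapes cleanly into (num, split_size).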
......@@ -302,18 +302,29 @@ def _darknet_reorg(inputs, attrs):
def _darknet_region(inputs, attrs):
"""Process the region operation."""
op_name, new_attrs = 'yolo_region', {}
if 'n' in attrs:
new_attrs['n'] = attrs.get('n', 1)
if 'classes' in attrs:
new_attrs['classes'] = attrs.get('classes', 1)
if 'coords' in attrs:
new_attrs['coords'] = attrs.get('coords', 0)
if 'background' in attrs:
new_attrs['background'] = attrs.get('background', 0)
if 'softmax' in attrs:
new_attrs['softmax'] = attrs.get('softmax', 0)
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
num = attrs.get('n', 1)
classes = attrs.get('classes', 1)
coords = attrs.get('coords', 0)
background = attrs.get('background', 0)
softmax = attrs.get('softmax', True)
input_shape = attrs.get('shape')
split_size = classes + coords + 1
intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3])
data_block = _sym.reshape(inputs[0], shape=intermediate_shape)
split_indices = (2, 4, 5)
split_res = _sym.split(data_block, indices_or_sections=split_indices, axis=2)
split_res0 = _sym.sigmoid(split_res[0])
if not background:
split_res2 = _sym.sigmoid(split_res[2])
else:
split_res2 = split_res[2]
if softmax:
    split_res3 = _sym.softmax(split_res[3], axis=2)
else:
    split_res3 = split_res[3]
concat_list = [split_res0, split_res[1], split_res2, split_res3]
out = _sym.concatenate(*concat_list, axis=2)
return _sym.reshape(out, shape=input_shape), None
def _darknet_yolo(inputs, attrs):
"""Process the yolo operation."""
......@@ -638,6 +649,7 @@ class GraphProto(object):
attr.update({'coords' : layer.coords})
attr.update({'background' : layer.background})
attr.update({'softmax' : layer.softmax})
attr.update({'shape' : (1, layer.c, layer.h, layer.w)})
elif LAYERTYPE.YOLO == layer.type:
attr.update({'n' : layer.n})
......
......@@ -20,24 +20,6 @@ def schedule_reorg(attrs, outs, target):
reg.register_pattern("yolo_reorg", OpPattern.INJECTIVE)
@reg.register_compute("yolo_region")
def compute_region(attrs, inputs, _):
"""Compute definition of region"""
n = attrs.get_int("n")
classes = attrs.get_int("classes")
coords = attrs.get_int("coords")
background = attrs.get_int("background")
softmax = attrs.get_int("softmax")
return topi.vision.yolo.region(inputs[0], n, classes, coords, background, softmax)
@reg.register_schedule("yolo_region")
def schedule_region(attrs, outs, target):
"""Schedule definition of region"""
with tvm.target.create(target):
return topi.generic.vision.schedule_region(outs)
reg.register_pattern("yolo_region", OpPattern.OPAQUE)
# multibox_prior
@reg.register_schedule("multibox_prior")
def schedule_multibox_prior(_, outs, target):
......
/*!
* Copyright (c) 2018 by Contributors
* \file region.cc
* \brief Property def of pooling operators.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "../../op_common.h"
#include "region.h"
namespace nnvm {
namespace top {
NNVM_REGISTER_OP(yolo_region)
.describe(R"code(Region layer
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Input data")
.set_attr<FInferType>("FInferType", RegionType<1, 1>)
.set_attr<FInferShape>("FInferShape", RegionShape<1, 1>)
.set_attr<FInplaceOption>(
"FInplaceOption",
[](const NodeAttrs &attrs) {
return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
})
.set_attr<FGradient>("FGradient", [](const NodePtr &n,
const std::vector<NodeEntry> &ograds) {
return std::vector<NodeEntry>{ograds[0], ograds[0]};
});
} // namespace top
} // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \file region.h
*/
#ifndef NNVM_TOP_VISION_YOLO_REGION_H_
#define NNVM_TOP_VISION_YOLO_REGION_H_
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include <sstream>
namespace nnvm {
namespace top {
template <typename AttrType,
bool (*is_none)(const AttrType &),
bool (*assign)(AttrType *,
const AttrType &),
bool reverse_infer,
std::string (*attr_string)(const AttrType &),
int n_in = -1,
int n_out = -1>
inline bool RegionAttr(const nnvm::NodeAttrs &attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType &none) {
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
if (n_in != -1) {
in_size = static_cast<size_t>(n_in);
}
if (n_out != -1) {
out_size = static_cast<size_t>(n_out);
}
auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
if (i == 0)
CHECK(assign(&dattr, (*vec)[i]))
<< "Incompatible attr in node " << attrs.name << " at " << i
<< "-th " << name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
deduce(in_attrs, in_size, "input");
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
write(out_attrs, out_size, "output");
if (is_none(dattr)) {
return false;
}
return true;
}
template <int n_in, int n_out>
inline bool RegionShape(const NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return RegionAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
attrs, in_attrs, out_attrs, TShape());
}
template <int n_in, int n_out>
inline bool RegionType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return RegionAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_attrs, out_attrs, -1);
}
} // namespace top
} // namespace nnvm
#endif // NNVM_TOP_VISION_YOLO_REGION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file cuda/vision.h
* \brief CUDA schedule for vision operations
*/
#ifndef TOPI_CUDA_VISION_H_
#define TOPI_CUDA_VISION_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/contrib/cublas.h"
#include "topi/generic/extern.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Create a CUDA schedule for region
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_region(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto output = outs[0]->op.output(0);
auto num_thread = 64;
auto _schedule_softmax = [&](const Operation& softmax_op) {
auto softmax_inputs = softmax_op->InputTensors();
auto softmax = softmax_inputs[0];
auto max_elem = softmax_inputs[1];
auto expsum = softmax_inputs[2];
auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
IterVar ko, ki;
s[expsum].split(k, num_thread, &ko, &ki);
auto ef = s.rfactor(expsum, ki)[0];
s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->axis[0], block_x);
s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
s[ef].compute_at(s[expsum], s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0]);
s[expsum].set_store_predicate(static_cast<Expr>(thread_x) == 0);
IterVar tx, xi;
s[softmax_op].split_by_nparts(softmax_op.as<ComputeOpNode>()->axis[1], num_thread, &tx, &xi);
s[softmax_op].bind(tx, thread_x);
return max_elem->op.as<ComputeOpNode>()->InputTensors()[0];
};
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_injective(op->tag)) {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag == "softmax_output") {
auto tensor = _schedule_softmax(op);
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
auto k = output->op.as<ComputeOpNode>()->axis[0];
IterVar bx, tx;
s[output].split(k, num_thread, &bx, &tx);
s[output].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
s[output].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_VISION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file rocm/vision.h
* \brief rocm schedule for region operation
*/
#ifndef TOPI_ROCM_VISION_H_
#define TOPI_ROCM_VISION_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/contrib/rocblas.h"
#include "topi/generic/extern.h"
#include "topi/cuda/vision.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Create a rocm schedule for region
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_region(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_region(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_VISION_H_
/*!
* Copyright (c) 2018 by Contributors
* \brief Region op constructions
* \file vision/yolo/region.h
*/
#ifndef TOPI_VISION_YOLO_REGION_H_
#define TOPI_VISION_YOLO_REGION_H_
#include <algorithm>
#include <string>
#include "topi/detail/constant_utils.h"
#include "topi/reduction.h"
#include "topi/tags.h"
#include "topi/transform.h"
#include "topi/nn/softmax.h"
#include "tvm/tvm.h"
namespace topi {
namespace vision {
namespace yolo {
using namespace tvm;
using namespace nn;
/*!
* \brief region operation
*
* \param data The input tensor. Can be any dimension
* \param num Darknet layer parameter n
* \param classes number of classes in the yolo model
* \param coords Darknet layer parameter coords
* \param background Darknet layer parameter background
* \param l_softmax if true apply softmax
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the region operation
*/
inline Tensor region(const Tensor &data,
int num,
int classes,
int coords,
int background,
int l_softmax,
std::string name = "tensor",
std::string tag = "region_output") {
auto input_shape = data->shape;
int split_size = classes + coords + 1;
Array <Expr> intermediate_shape = {input_shape[0],
num,
split_size,
input_shape[2],
input_shape[3]};
auto data_block = reshape(data, intermediate_shape);
Array <Integer> split_indices;
for (int i = 1; i < split_size; ++i) {
split_indices.push_back(i);
}
Array <Tensor> split_res = split(data_block, split_indices, 2);
split_res.Set(0, sigmoid(split_res[0]));
split_res.Set(1, sigmoid(split_res[1]));
if (!background) {
split_res.Set(coords, sigmoid(split_res[coords]));
}
if (l_softmax) {
int offset = coords + static_cast<int>(!background);
Array <Tensor> softmax_input(split_res.begin() + offset, split_res.end());
auto softmax_output = softmax(concatenate(softmax_input, 2), 2);
Array <Tensor> data_block_1(split_res.begin(), split_res.begin() + offset);
data_block_1.push_back(softmax_output);
split_res = data_block_1;
}
Tensor out = concatenate(split_res, 2);
return reshape(out, input_shape);
}
} // namespace yolo
} // namespace vision
} // namespace topi
#endif // TOPI_VISION_YOLO_REGION_H_
......@@ -61,24 +61,6 @@ def schedule_reorg(outs):
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_injective(cpp_target, outs)
@generic.schedule_region.register(["cuda", "gpu"])
def schedule_region(outs):
"""Schedule for region operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of region
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for region.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_region(cpp_target, outs)
@generic.schedule_nms.register(["cuda", "gpu"])
def schedule_nms(outs):
"""Schedule for non-maximum suppression
......
......@@ -18,23 +18,6 @@ def _default_schedule(outs, auto_inline):
return s
@tvm.target.generic_func
def schedule_shortcut(outs):
"""Schedule for shortcut
Parameters
----------
outs: Array of Tensor
The computation graph description of shortcut
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_reorg(outs):
"""Schedule for reorg
......@@ -54,25 +37,6 @@ def schedule_reorg(outs):
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_region(outs):
"""Schedule for region
Parameters
----------
outs: Array of Tensor
The computation graph description of region
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_nms(outs):
"""Schedule for non-maximum suppression
......
......@@ -4,5 +4,4 @@ from __future__ import absolute_import as _abs
from .conv2d import *
from .dense import *
from .vision import *
from .nn import *
# pylint: disable=invalid-name, unused-variable
"""Schedule for vision operator"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from .. import cpp
@generic.schedule_region.register(["rocm"])
def schedule_region(outs):
"""Schedule for region operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of region
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for region.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_region(cpp_target, outs)
......@@ -14,9 +14,7 @@ from .softmax_python import softmax_python, log_softmax_python
from .upsampling_python import upsampling_python
from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python
from .region_python import region_python
from .roi_align_python import roi_align_nchw_python
from .shortcut_python import shortcut_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
......
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Region in python"""
import numpy as np
def entry_index(batch, w, h, outputs, classes, coords, location, entry):
n = int(location/(w*h))
loc = location%(w*h)
return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
def region_python(a_np, N, classes, coords, background, softmax):
"""Region operator
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
N : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
coords : int
Darknet layer parameter coords
background : int
Darknet layer parameter background
softmax : int
Darknet layer parameter softmax
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = a_np.shape
a_np_temp = np.reshape(a_np, batch*in_channel*in_height*in_width)
outputs = batch*in_channel*in_height*in_width
b_np = np.zeros(batch*in_channel*in_height*in_width)
for i in range(batch*in_channel*in_height*in_width):
b_np[i] = a_np_temp[i]
for b in range(batch):
for n in range(N):
index = entry_index(b, in_width, in_height, outputs, classes, coords, n*in_width*in_height, 0)
b_np[index: index+2*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+2*in_width*in_height]))
index = entry_index(b, in_width, in_height, outputs, classes, coords, n*in_width*in_height, coords)
if not background:
b_np[index: index+in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+in_width*in_height]))
b_np = np.reshape(b_np, (batch, in_channel, in_height, in_width))
def local_softmax(data_in):
data_c, data_h, data_w = data_in.shape
largest = np.max(data_in, axis=1)
data_out = np.zeros((data_c, data_h, data_w))
for i in range(data_h):
for j in range(data_w):
data_out[:, i, j] = np.exp(data_in[:, i, j] - largest[i, j])
return data_out/data_out.sum(axis=0)
if softmax:
index = coords + int(not background)
for b in range(batch):
for i in range(N):
b_np_index = int(i*(in_channel/N) + index)
b_np[b, b_np_index: b_np_index + classes+background, :, :] = local_softmax(b_np[b, b_np_index:b_np_index + classes+background, :, :])
return b_np
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Shortcut in python"""
import numpy as np
def shortcut_python(a_np1, a_np2):
"""Reorg operator
Parameters
----------
a_np1 : numpy.ndarray
4-D with shape [batch1, in_channel1, in_height1, in_width1]
a_np2 : numpy.ndarray
4-D with shape [batch2, in_channel2, in_height2, in_width2]
Returns
-------
b_np : np.ndarray
4-D with shape [batch1, out_channel1, out_height1, out_width1]
"""
batch1, in_channel1, in_height1, in_width1 = a_np1.shape
batch2, in_channel2, in_height2, in_width2 = a_np2.shape
a_np1_temp = np.reshape(a_np1, batch1*in_channel1*in_height1*in_width1)
a_np2_temp = np.reshape(a_np2, batch2*in_channel2*in_height2*in_width2)
b_np = np.zeros(batch1*in_channel1*in_height1*in_width1)
stride = int(in_width1/in_width2)
sample = int(in_width2/in_width1)
if stride < 1:
stride = 1
if sample < 1:
sample = 1
minw = min(in_width1, in_width2)
minh = min(in_height1, in_height2)
minc = min(in_channel1, in_channel2)
for i in range((batch1*in_channel1*in_height1*in_width1)):
b_np[i] = a_np1_temp[i]
for b in range(batch1):
for k in range(minc):
for j in range(minh):
for i in range(minw):
out_index = i*sample + in_width2*(j*sample + in_height2*(k + in_channel2*b))
add_index = i*stride + in_width1*(j*stride + in_height1*(k + in_channel1*b))
b_np[out_index] = a_np1_temp[out_index] + a_np2_temp[add_index]
b_np = np.reshape(b_np, (batch1, in_channel1, in_height1, in_width1))
return b_np
......@@ -2,8 +2,7 @@
"""VISION network operators"""
from __future__ import absolute_import as _abs
from . import yolo, ssd
from .shortcut import *
from . import ssd
from .reorg import *
from .nms import *
from .rcnn import *
"""Shortcut operators (short-cut connections)."""
from __future__ import absolute_import as _abs
import tvm
from .. import util
from .. import transform
@tvm.target.generic_func
def shortcut(inp1, inp2):
"""Shortcut forward operators.
Parameters
----------
inp1 : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
inp2 : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
_, inp1_c, inp1_h, inp1_w = util.get_const_tuple(inp1.shape)
batch, inp2_c, inp2_h, inp2_w = util.get_const_tuple(inp2.shape)
stride = int(max(inp2_w / inp1_w, 1))
sample = int(max(inp1_w / inp2_w, 1))
minc = min(inp2_c, inp1_c)
minh = min(inp2_h, inp1_h)
minw = min(inp2_w, inp1_w)
out = tvm.compute((batch, minc, minh, minw), lambda b, c, h, w:
inp1[b, c, h * sample, w * sample] +
inp2[b, c, h * stride, w * stride],
tag="shortcut")
split_indices = int(inp1_c / minc)
if split_indices > 1:
split_res = transform.split(inp1, split_indices, 1)
split_res[0] = out
out = transform.concatenate(split_res, 1)
return out
# pylint: disable=wildcard-import
"""VISION network operators"""
from __future__ import absolute_import as _abs
from .region import *
# pylint: disable=invalid-name, unused-variable
"""
REGION Operator
====================
Region operator, used in darknet.
"""
from __future__ import absolute_import as _abs
import tvm
from ... import cpp
@tvm.target.generic_func
def region(data, num, classes, coords, background, softmax=True):
"""Region forward operators.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
num : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
coords : int
Darknet layer parameter coords
background : int
Darknet layer parameter background
softmax : boolean
Darknet layer parameter softmax
Returns
-------
out : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
"""
return cpp.yolo.region(data, num, classes, coords, background, softmax)
......@@ -30,7 +30,6 @@
#include <topi/vision/reorg.h>
#include <topi/image/resize.h>
#include <topi/vision/yolo/region.h>
#include <topi/generic/default.h>
#include <topi/generic/extern.h>
#include <topi/generic/injective.h>
......@@ -41,7 +40,6 @@
#include <topi/cuda/pooling.h>
#include <topi/cuda/reduction.h>
#include <topi/cuda/softmax.h>
#include <topi/cuda/vision.h>
#include <topi/cuda/normalization.h>
#include <topi/x86/bnn.h>
......@@ -49,7 +47,6 @@
#include <topi/x86/injective.h>
#include <topi/rocm/dense.h>
#include <topi/rocm/vision.h>
#include <topi/rocm/normalization.h>
namespace topi {
......@@ -416,11 +413,6 @@ TVM_REGISTER_GLOBAL("topi.vision.reorg")
*rv = vision::reorg(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.vision.yolo.region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::yolo::region(args[0], args[1], args[2], args[3], args[4], args[5]);
});
/* Ops from image/resize.h */
TVM_REGISTER_GLOBAL("topi.image.bilinear_sample_nchw")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -488,11 +480,6 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
*rv = topi::rocm::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_region(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_lrn(args[0], args[1]);
......@@ -544,11 +531,6 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
*rv = topi::cuda::schedule_softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_region(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_lrn(args[0], args[1]);
......
"""Example code to do region."""
import numpy as np
import topi
from topi.util import get_const_tuple
import tvm
import topi.testing
def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_softmax):
'''Verify region operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.vision.yolo.region(A, n, classes, coords, background, l_softmax)
a_shape = get_const_tuple(A.shape)
dtype = A.dtype
def get_ref_data_region():
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_np = topi.testing.region_python(a_np, n, classes, coords, background, l_softmax)
return a_np, b_np
a_np, b_np = get_ref_data_region()
def check_device(device):
'''Check whether the target device is enabled'''
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
s = topi.generic.vision.schedule_region([B])
else:
s = topi.cuda.vision.schedule_region([B])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device)
func(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda']:
check_device(device)
def test_region():
verify_region(1, 19, 425, 5, 80, 4, 0, 1)
if __name__ == "__main__":
test_region()
"""Example code to do shortcut."""
import numpy as np
import topi
from topi.util import get_const_tuple
import tvm
def verify_shortcut(batch, in_size, in_channel):
'''Verify shortcut operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
A1 = tvm.placeholder((batch, in_channel, in_height, in_width), name='A1')
A2 = tvm.placeholder((batch, in_channel, in_height, in_width), name='A2')
B = topi.vision.shortcut(A1, A2)
a_shape = get_const_tuple(A1.shape)
dtype = A1.dtype
def get_ref_data_shortcut():
a_np1 = np.random.uniform(size=a_shape).astype(dtype)
a_np2 = np.random.uniform(size=a_shape).astype(dtype)
b_np = topi.testing.shortcut_python(a_np1, a_np2)
return a_np1, a_np2, b_np
a_np1, a_np2, b_np = get_ref_data_shortcut()
def check_device(device):
'''Check whether the target device is enabled'''
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective([B])
a1 = tvm.nd.array(a_np1, ctx)
a2 = tvm.nd.array(a_np2, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A1, A2, B], device)
func(a1, a2, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda']:
check_device(device)
def test_shortcut():
verify_shortcut(1, 144, 32)
if __name__ == "__main__":
test_shortcut()