Commit 6fdb662e by Pariksheet Pinjari Committed by Tianqi Chen

CPP support for region and reorg operators (#1115)

parent 431a42bf
/*!
* Copyright (c) 2018 by Contributors
* \file cuda/vision.h
* \brief CUDA schedule for vision operations
*/
#ifndef TOPI_CUDA_VISION_H_
#define TOPI_CUDA_VISION_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/contrib/cublas.h"
#include "topi/generic/extern.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Create a CUDA schedule for region
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_region(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto output = outs[0]->op.output(0);
auto num_thread = 64;
auto _schedule_softmax = [&](const Operation& softmax_op) {
auto softmax_inputs = softmax_op->InputTensors();
auto softmax = softmax_inputs[0];
auto max_elem = softmax_inputs[1];
auto expsum = softmax_inputs[2];
auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
IterVar ko, ki;
s[expsum].split(k, num_thread, &ko, &ki);
auto ef = s.rfactor(expsum, ki)[0];
s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->axis[0], block_x);
s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
s[ef].compute_at(s[expsum], s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0]);
s[expsum].set_store_predicate(static_cast<Expr>(thread_x) == 0);
IterVar tx, xi;
s[softmax_op].split_by_nparts(softmax_op.as<ComputeOpNode>()->axis[1], num_thread, &tx, &xi);
s[softmax_op].bind(tx, thread_x);
return max_elem->op.as<ComputeOpNode>()->InputTensors()[0];
};
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_injective(op->tag)) {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag == "softmax_output") {
auto tensor = _schedule_softmax(op);
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
auto k = output->op.as<ComputeOpNode>()->axis[0];
IterVar bx, tx;
s[output].split(k, num_thread, &bx, &tx);
s[output].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
s[output].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_VISION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file rocm/vision.h
* \brief rocm schedule for region operation
*/
#ifndef TOPI_ROCM_VISION_H_
#define TOPI_ROCM_VISION_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/contrib/rocblas.h"
#include "topi/generic/extern.h"
#include "topi/cuda/vision.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Create a rocm schedule for region
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_region(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_region(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_VISION_H_
/*!
* Copyright (c) 2018 by Contributors
* \brief Reorg op constructions
* \file vision/reorg.h
*/
#ifndef TOPI_VISION_REORG_H_
#define TOPI_VISION_REORG_H_
#include <algorithm>
#include <string>
#include "topi/detail/constant_utils.h"
#include "topi/reduction.h"
#include "topi/tags.h"
#include "topi/transform.h"
#include "tvm/tvm.h"
namespace topi {
namespace vision {
using namespace tvm;
/*!
* \brief Reorg operation
*
* \param data The input tensor. Can be any dimension
* \param stride The input integer used as stride in reorg operation
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the reorg operation
*/
inline Tensor reorg(const Tensor &data,
int stride = 1,
std::string name = "tensor",
std::string tag = "reorg_output") {
auto input_shape = data->shape;
int batch = GetConstInt(input_shape[0]);
int c_in = GetConstInt(input_shape[1]);
int h_in = GetConstInt(input_shape[2]);
int w_in = GetConstInt(input_shape[3]);
int out_c = c_in / (stride * stride);
auto out = tvm::compute(input_shape,
[&](Var b, Var k, Var j, Var i) {
return data(b * stride * stride,
(k % out_c) * stride * stride,
(j*stride + (k / out_c) / stride) * stride,
(i*stride + (k / out_c) % stride));
},
name,
tag);
out_c = c_in * stride * stride;
int out_h = h_in / stride;
int out_w = w_in / stride;
Array<Expr> out_shape = {batch, out_c, out_h, out_w};
return reshape(out, out_shape);
}
} // namespace vision
} // namespace topi
#endif // TOPI_VISION_REORG_H_
/*!
* Copyright (c) 2018 by Contributors
* \brief Region op constructions
* \file vision/yolo2/region.h
*/
#ifndef TOPI_VISION_YOLO2_REGION_H_
#define TOPI_VISION_YOLO2_REGION_H_
#include <algorithm>
#include <string>
#include "topi/detail/constant_utils.h"
#include "topi/reduction.h"
#include "topi/tags.h"
#include "topi/transform.h"
#include "topi/nn/softmax.h"
#include "tvm/tvm.h"
namespace topi {
namespace vision {
namespace yolo2 {
using namespace tvm;
using namespace nn;
/*!
* \brief region operation
*
* \param data The input tensor. Can be any dimension
* \param num Darknet layer parameter n
* \param classes number of classes in the yolo model
* \param coords Darknet layer parameter coords
* \param background Darknet layer parameter background
* \param l_softmax if true apply softmax
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the region operation
*/
inline Tensor region(const Tensor &data,
int num,
int classes,
int coords,
int background,
int l_softmax,
std::string name = "tensor",
std::string tag = "region_output") {
auto input_shape = data->shape;
int split_size = classes + coords + 1;
Array <Expr> intermediate_shape = {input_shape[0],
num,
split_size,
input_shape[2],
input_shape[3]};
auto data_block = reshape(data, intermediate_shape);
Array <Expr> split_indices;
for (int i = 1; i < split_size; ++i) {
split_indices.push_back(i);
}
Array <Tensor> split_res = split(data_block, split_indices, 2);
split_res.Set(0, sigmoid(split_res[0]));
split_res.Set(1, sigmoid(split_res[1]));
if (!background) {
split_res.Set(coords, sigmoid(split_res[coords]));
}
if (l_softmax) {
int offset = coords + static_cast<int>(!background);
Array <Tensor> softmax_input(split_res.begin() + offset, split_res.end());
auto softmax_output = softmax(concatenate(softmax_input, 2), 2);
Array <Tensor> data_block_1(split_res.begin(), split_res.begin() + offset);
data_block_1.push_back(softmax_output);
split_res = data_block_1;
}
Tensor out = concatenate(split_res, 2);
return reshape(out, input_shape);
}
} // namespace yolo2
} // namespace vision
} // namespace topi
#endif // TOPI_VISION_YOLO2_REGION_H_
......@@ -46,6 +46,10 @@ rocm = _create_module("rocm")
_init_api_prefix("topi.cpp.rocm", "topi.rocm")
x86 = _create_module("x86")
_init_api_prefix("topi.cpp.x86", "topi.x86")
vision = _create_module("vision")
_init_api_prefix("topi.cpp.vision", "topi.vision")
yolo2 = _create_module("vision.yolo2")
_init_api_prefix("topi.cpp.vision.yolo2", "topi.vision.yolo2")
class IntVector(object):
"""Handle to std::vector<int> instance """
......
......@@ -16,4 +16,5 @@ from .pooling import schedule_pool, schedule_global_pool
from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
from .vision import schedule_region
from .vision import schedule_reorg
from .nn import schedule_lrn, schedule_l2norm
......@@ -2,8 +2,26 @@
"""Schedule for vision operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
from .. import cpp
@generic.schedule_reorg.register(["cuda", "gpu"])
def schedule_reorg(outs):
"""Schedule for reorg operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of reorg
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for reorg.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_injective(cpp_target, outs)
@generic.schedule_region.register(["cuda", "gpu"])
def schedule_region(outs):
......@@ -19,47 +37,6 @@ def schedule_region(outs):
s: Schedule
The computation schedule for region.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
output = outs[0].op.output(0)
#thread = 64 for higher size tensors, give resource_unavailable error for higher values
num_thread = 64
def _schedule_softmax(softmax_op):
softmax = softmax_op.input_tensors[0]
max_elem = softmax_op.input_tensors[1]
expsum = softmax_op.input_tensors[2]
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
s[max_elem].bind(max_elem.op.axis[0], block_x)
k = expsum.op.reduce_axis[0]
ko, ki = s[expsum].split(k, factor=num_thread)
ef = s.rfactor(expsum, ki)
s[expsum].bind(s[expsum].op.axis[0], block_x)
s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
s[ef].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
s[expsum].set_store_predicate(thread_x.var.equal(0))
tx, xi = s[softmax_op].split(softmax_op.axis[1], nparts=num_thread)
s[softmax_op].bind(softmax_op.axis[0], block_x)
s[softmax_op].bind(tx, thread_x)
return max_elem.op.input_tensors[0]
def _traverse(op):
if tag.is_injective(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
_traverse(tensor.op)
elif op.tag == 'softmax_output':
tensor = _schedule_softmax(op)
if tensor.op.input_tensors:
_traverse(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)
_traverse(outs[0].op)
k = output.op.axis[0]
bx, tx = s[output].split(k, factor=num_thread)
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
return s
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_region(cpp_target, outs)
"""Generic vision operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import cpp
def _default_schedule(outs, auto_inline):
"""Default schedule for llvm."""
......@@ -47,7 +48,9 @@ def schedule_reorg(outs):
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_region(outs):
......@@ -64,4 +67,6 @@ def schedule_region(outs):
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)
# pylint: disable=invalid-name, unused-variable
"""Schedule for vision operator"""
from __future__ import absolute_import as _abs
import topi
import tvm
from .. import generic
from .. import cpp
@generic.schedule_region.register(["rocm"])
def schedule_region(outs):
......@@ -19,4 +20,6 @@ def schedule_region(outs):
s: Schedule
The computation schedule for region.
"""
return topi.cuda.schedule_region(outs)
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_region(cpp_target, outs)
......@@ -5,8 +5,7 @@ Reorg operator, used in darknet.
"""
from __future__ import absolute_import as _abs
import tvm
from .. import util
from .. import transform
from .. import cpp
@tvm.target.generic_func
def reorg(data, stride):
......@@ -25,15 +24,4 @@ def reorg(data, stride):
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, c_in, h_in, w_in = util.get_const_tuple(data.shape)
out_c = int(c_in / (stride * stride))
out = tvm.compute((batch, c_in, h_in, w_in), lambda b, k, j, i:
data[b * stride * stride,
(k % out_c) * stride * stride,
(j*stride + (k / out_c) / stride) * stride,
(i*stride + (k / out_c) % stride)],
tag="reorg")
out_c = int(c_in * stride * stride)
out_h = int(h_in / stride)
out_w = int(w_in / stride)
return transform.reshape(out, (batch, out_c, out_h, out_w))
return cpp.vision.reorg(data, stride)
......@@ -6,10 +6,7 @@ Region operator, used in darknet.
"""
from __future__ import absolute_import as _abs
import tvm
from ... import transform
from ... import util
from ... import math
from ... import nn
from ... import cpp
@tvm.target.generic_func
def region(data, num, classes, coords, background, softmax=True):
......@@ -39,25 +36,4 @@ def region(data, num, classes, coords, background, softmax=True):
out : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
"""
batch, c_in, h_in, w_in = util.get_const_tuple(data.shape)
split_indices = classes+coords+1
data_block = transform.reshape(data, (batch, num, split_indices, h_in, w_in))
split_res = transform.split(data_block, split_indices, 2)
split_res[0] = math.sigmoid(split_res[0])
split_res[1] = math.sigmoid(split_res[1])
if not background:
split_res[coords] = math.sigmoid(split_res[coords])
if softmax:
offset = coords + int(not background)
data_block_1 = []
data_block_1.append(transform.concatenate(split_res[0:offset], 2))
temp_out = transform.concatenate(split_res[offset:split_indices], 2)
temp_out = nn.softmax(temp_out, axis=2)
data_block_1.append(temp_out)
split_res = data_block_1
out = transform.concatenate(split_res, 2)
out = transform.reshape(out, data.shape)
return out
return cpp.yolo2.region(data, num, classes, coords, background, softmax)
......@@ -24,6 +24,8 @@
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
#include <topi/vision/reorg.h>
#include <topi/vision/yolo2/region.h>
#include <topi/generic/default.h>
#include <topi/generic/extern.h>
#include <topi/generic/injective.h>
......@@ -34,12 +36,14 @@
#include <topi/cuda/pooling.h>
#include <topi/cuda/reduction.h>
#include <topi/cuda/softmax.h>
#include <topi/cuda/vision.h>
#include <topi/x86/bnn.h>
#include <topi/x86/default.h>
#include <topi/x86/injective.h>
#include <topi/rocm/dense.h>
#include <topi/rocm/vision.h>
namespace tvm {
namespace runtime {
......@@ -338,6 +342,14 @@ TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
*rv = nn::log_softmax(args[0]);
});
TVM_REGISTER_GLOBAL("topi.vision.reorg")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::reorg(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.vision.yolo2.region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::yolo2::region(args[0], args[1], args[2], args[3], args[4], args[5]);
});
/* Generic schedules */
TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -394,6 +406,10 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
*rv = topi::rocm::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_region(args[0], args[1]);
});
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -435,6 +451,11 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
*rv = topi::cuda::schedule_softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_region(args[0], args[1]);
});
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
......
......@@ -3,6 +3,8 @@ import numpy as np
import topi
from topi.util import get_const_tuple
import tvm
import topi.testing
def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_softmax):
'''Verify region operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
......@@ -27,7 +29,10 @@ def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_
return
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
s = topi.generic.vision.schedule_region([B])
else:
s = topi.cuda.vision.schedule_region([B])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device)
......
......@@ -3,6 +3,7 @@ import numpy as np
import topi
from topi.util import get_const_tuple
import tvm
import topi.testing
def verify_reorg(batch, in_size, in_channel, stride):
'''Verify reorg operator by comparing outputs from tvm and numpy implementation'''
......@@ -29,8 +30,10 @@ def verify_reorg(batch, in_size, in_channel, stride):
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective([B])
if device == 'llvm':
s = topi.generic.schedule_reorg([B])
else:
s = topi.cuda.schedule_reorg([B])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device)
......
"""Test code for region"""
import logging
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_softmax):
'''Verify region operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.cpp.yolo2.region(A, n, classes, coords, background, l_softmax)
a_shape = get_const_tuple(A.shape)
dtype = A.dtype
def get_ref_data_region():
'''Randomly initialize the data variables and get refernce output for the region operation'''
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_np = topi.testing.region_python(a_np, n, classes, coords, background, l_softmax)
return a_np, b_np
a_np, b_np = get_ref_data_region()
def check_device(device):
'''Check the device is available and if so, build and run the program'''
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.rocm.schedule_region(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device, name="region")
func(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']:
check_device(device)
def test_region():
verify_region(1, 19, 425, 5, 80, 4, 0, 1)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_region()
"""Test code for reorg"""
import logging
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
def verify_reorg(batch, in_size, in_channel, stride):
'''Verify reorg operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.cpp.vision.reorg(A, stride)
a_shape = get_const_tuple(A.shape)
dtype = A.dtype
def get_ref_data_reorg():
'''Randomly initialize the data variables and get refernce output for the reorg operation'''
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_np = topi.testing.reorg_python(a_np, stride)
return a_np, b_np
a_np, b_np = get_ref_data_reorg()
def check_device(device):
'''Check the device is available and if so, build and run the program'''
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device, name="reorg")
func(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm', 'vulkan']:
check_device(device)
def test_reorg():
verify_reorg(1, 38, 64, 2)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_reorg()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment