Commit fb88b74e by Pariksheet Pinjari Committed by Tianqi Chen

CPP implementation of L2Norm and LRN ops (#1157)

parent 464c8c26
...@@ -368,6 +368,41 @@ struct NMSParam : public dmlc::Parameter<NMSParam> {
  }
};
struct LRNParam : public dmlc::Parameter<LRNParam> {
int size;
int axis;
float alpha;
float beta;
float bias;
DMLC_DECLARE_PARAMETER(LRNParam) {
DMLC_DECLARE_FIELD(size)
.describe("The size of the local region to be considered for normalization.");
DMLC_DECLARE_FIELD(axis)
.describe("input data layout channel axis");
DMLC_DECLARE_FIELD(alpha)
.describe("The scaling parameter.");
DMLC_DECLARE_FIELD(beta)
.describe("The exponent parameter.");
DMLC_DECLARE_FIELD(bias)
.describe("The offset parameter.");
}
// constants
static const constexpr int kData = 0;
};
struct L2NormalizeParam : public dmlc::Parameter<L2NormalizeParam> {
float eps;
Tuple<int> axis;
DMLC_DECLARE_PARAMETER(L2NormalizeParam) {
DMLC_DECLARE_FIELD(eps)
.describe("float type epsilon value.");
DMLC_DECLARE_FIELD(axis)
.describe("axis over the normalization applied");
}
};
} // namespace top
} // namespace nnvm
...@@ -243,3 +243,36 @@ def schedule_upsampling(_, outs, target):
    return topi.generic.schedule_injective(outs)

reg.register_pattern("upsampling", OpPattern.INJECTIVE)
@reg.register_compute("lrn")
def compute_lrn(attrs, inputs, _):
"""Compute definition of lrn"""
size = attrs.get_int("size")
axis = attrs.get_int("axis")
alpha = attrs.get_float("alpha")
beta = attrs.get_float("beta")
bias = attrs.get_float("bias")
return topi.nn.lrn(inputs[0], size, axis, alpha, beta, bias)
@reg.register_schedule("lrn")
def schedule_lrn(attrs, outs, target):
"""Schedule definition of lrn"""
with tvm.target.create(target):
return topi.generic.schedule_lrn(outs)
reg.register_pattern("lrn", OpPattern.OPAQUE)
@reg.register_compute("l2_normalize")
def compute_l2_normalize(attrs, inputs, _):
"""Compute definition of l2 normalize"""
eps = attrs.get_float("eps")
axis = attrs.get_int_tuple("axis")
return topi.nn.l2_normalize(inputs[0], eps, axis)
@reg.register_schedule("l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
"""Schedule definition of l2 normalize"""
with tvm.target.create(target):
return topi.generic.schedule_l2_normalize(outs)
reg.register_pattern("l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
...@@ -712,5 +712,52 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
})
.set_support_level(1);
DMLC_REGISTER_PARAMETER(LRNParam);
inline bool LRNInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
TShape dshape = (*in_shape)[0];
TShape oshape = dshape;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}
NNVM_REGISTER_OP(lrn)
.describe(R"code(LRN layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.set_attr_parser(ParamParser<LRNParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<LRNParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", LRNInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1);
DMLC_REGISTER_PARAMETER(L2NormalizeParam);
inline bool L2NormalizeInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
TShape dshape = (*in_shape)[0];
TShape oshape = dshape;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}
NNVM_REGISTER_OP(l2_normalize)
.describe(R"code(L2NORMALIZE layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.set_attr_parser(ParamParser<L2NormalizeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<L2NormalizeParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", L2NormalizeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_support_level(1);
} // namespace top
} // namespace nnvm
...@@ -6,7 +6,6 @@ import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list

def helper(symbol, inputs, dtype,
           np_forward, np_backward=None, need_input=True, need_head_grads=True):
    ishapes = {}
...@@ -365,6 +364,65 @@ def test_pad():
    inputs = [('x', (1, 3, 28, 28), x)]
    helper(y, inputs, dtype, forward)
def verify_lrn(ishape, size, axis, bias, alpha, beta):
    x = sym.Variable("x")
    y = sym.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
    dtype = "float32"
    x_np = np.random.uniform(size=ishape).astype(dtype)

    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.lrn_python(x_np, size, axis, bias, alpha, beta)
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

    # Checking LRN op followed by elementwise op relu
    z = sym.relu(y)
    x_np = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype)
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(z, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.lrn_python(x_np, size, axis, bias, alpha, beta)
        out_np = (out_np > 0) * out_np
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def verify_l2_normalize(ishape, eps, axis):
    x = sym.Variable("x")
    y = sym.l2_normalize(x, eps=eps, axis=axis)
    dtype = "float32"
    x_np = np.random.uniform(size=ishape).astype(dtype)

    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.l2_normalize_python(x_np, eps, axis)
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

    # Checking L2 normalization op followed by elementwise op relu
    z = sym.relu(y)
    x_np = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype)
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(z, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.l2_normalize_python(x_np, eps, axis)
        out_np = (out_np > 0) * out_np
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def test_lrn():
    verify_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

def test_l2_normalize():
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
if __name__ == "__main__": if __name__ == "__main__":
test_split() test_split()
...@@ -384,3 +442,5 @@ if __name__ == "__main__": ...@@ -384,3 +442,5 @@ if __name__ == "__main__":
test_softmax() test_softmax()
test_squeeze() test_squeeze()
test_pad() test_pad()
test_lrn()
test_l2_normalize()
/*!
* Copyright (c) 2018 by Contributors
* \file cuda/normalization.h
* \brief CUDA schedule for LRN and l2 normalization operations
*/
#ifndef TOPI_CUDA_NORMALIZATION_H_
#define TOPI_CUDA_NORMALIZATION_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Create a CUDA schedule for LRN
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
Schedule s = create_schedule(out_ops);
int num_thread = 64;
IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
Tensor lrn = outs[0];
Tensor sqr_sum_up = lrn->op->InputTensors()[1];
Tensor sqr_sum = sqr_sum_up->op->InputTensors()[0];
Tensor set_pad = sqr_sum->op->InputTensors()[0];
s[set_pad].bind(set_pad->op.as<ComputeOpNode>()->axis[0], block_x);
IterVar rxk = sqr_sum->op.as<ComputeOpNode>()->reduce_axis[0];
IterVar xko, xki;
s[sqr_sum].split(rxk, num_thread, &xko, &xki);
Tensor srf = s.rfactor(sqr_sum, xki)[0];
s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->axis[0], block_x);
s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
s[srf].compute_at(s[sqr_sum], s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0]);
s[sqr_sum_up].bind(sqr_sum_up->op.as<ComputeOpNode>()->axis[0], block_x);
IterVar xto, xti;
s[lrn].split_by_nparts(lrn->op.as<ComputeOpNode>()->axis[1], num_thread, &xto, &xti);
s[lrn].bind(lrn->op.as<ComputeOpNode>()->axis[0], block_x);
s[lrn].bind(xto, thread_x);
return s;
}
/*!
* \brief Create a CUDA schedule for L2 normalization
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
Schedule s = create_schedule(out_ops);
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_injective(op->tag) || op->tag == "l2_normalize") {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag == "comm_reduce") {
ScheduleReduce(target, op, s, false);
for (auto tensor : op->InputTensors()) {
traverse(tensor->op);
}
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
int num_thread = 64;
Tensor l2_normalize = outs[0];
IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
IterVar xto, xti;
s[l2_normalize].split_by_nparts(l2_normalize->op.as<ComputeOpNode>()->axis[1],
num_thread, &xto, &xti);
s[l2_normalize].bind(l2_normalize->op.as<ComputeOpNode>()->axis[0], block_x);
s[l2_normalize].bind(xto, thread_x);
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_NORMALIZATION_H_
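The schedule_lrn routine above performs a cross-thread reduction on the sqr_sum stage via split, rfactor, and thread binding. A minimal standalone sketch of that same pattern in the Python API of this TVM version (the tensor names and shapes are illustrative, not part of this commit):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n, 1024), name="A")
    k = tvm.reduce_axis((0, 1024), name="k")
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k] * A[i, k], axis=k), name="B")

    s = tvm.create_schedule(B.op)
    num_thread = 64
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
    BF = s.rfactor(B, ki)                          # per-thread partial sums
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B].bind(s[B].op.axis[0], block_x)            # one block per output element
    s[B].bind(s[B].op.reduce_axis[0], thread_x)    # cross-thread reduction
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    print(tvm.lower(s, [A, B], simple_mode=True))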
/*!
* Copyright (c) 2018 by Contributors
* \brief l2 normalization op constructions
* \file nn/l2_normalize.h
*/
#ifndef TOPI_NN_L2_NORMALIZE_H_
#define TOPI_NN_L2_NORMALIZE_H_
#include <string>
#include <algorithm>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief L2 normalization inference operator
*
* \param data The input tensor. 4-D with shape [batch, channel, height, width]
* \param eps Epsilon to prevent division by zero
* \param axis Axes over which the normalization is applied
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the l2 normalization operation
*/
inline Tensor l2_normalize(const Tensor& data,
float eps,
const Array<Expr>& axis,
std::string name = "tensor",
std::string tag = "l2_normalize") {
CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
auto input_shape = data->shape;
Tensor dot_value = pow(data, static_cast<float>(2.0));
Tensor sum_value = topi::sum(dot_value, axis, true);
Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
return topi::broadcast_div(data,
topi::sqrt(tvm::compute(expand_sum->shape,
[&](const Array<Var>& i){
return (max(expand_sum(i), eps));
}, name = name, tag = tag)));
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_L2_NORMALIZE_H_
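For reference, the compute defined above corresponds to the expression (eps keeps the denominator away from zero):

    \text{l2\_normalize}(x) = \frac{x}{\sqrt{\max\left(\sum_{i \in \text{axis}} x_i^{2},\ \epsilon\right)}}

where the sum runs over the requested axes and is broadcast back to the input shape before the division.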
/*!
* Copyright (c) 2018 by Contributors
* \brief local response normalization op constructions
* \file nn/local_response_norm.h
*/
#ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_
#define TOPI_NN_LOCAL_RESPONSE_NORM_H_
#include <string>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Local response normalization inference operator
*
* \param data The input tensor. 4-D shape NCHW or NHWC
* \param size Integer defining the normalization window size
* \param axis Input data layout channel axis
* \param alpha Float scaling factor
* \param beta Exponent value
* \param bias Offset to avoid dividing by zero
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the Local response normalization operation
*/
inline Tensor lrn(const Tensor& data,
int size,
int axis = 1,
float alpha = 0.0001,
float beta = 0.75,
float bias = 2,
std::string name = "tensor",
std::string tag = kBroadcast) {
CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
CHECK_EQ(size % 2, 1) << "size should be an odd number";
CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
auto input_shape = data->shape;
Array<Expr> pad_before{ 0, 0, 0, 0};
Array<Expr> pad_after{ 0, 0, 0, 0};
pad_before.Set(axis, static_cast<Expr>(size/2));
pad_after.Set(axis, static_cast<Expr>(size/2));
auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
Tensor sqr_sum;
if (axis == 1) {
sqr_sum = tvm::compute(input_shape,
[&](Var i, Var l, Var j, Var k) {
return tvm::sum(pad_data(i, l + rxs, j, k) *
pad_data(i, l + rxs, j, k),
{rxs});
});
} else if (axis == 3) {
sqr_sum = tvm::compute(input_shape,
[&](Var i, Var l, Var j, Var k) {
return tvm::sum(pad_data(i, l, j, k + rxs) *
pad_data(i, l, j, k + rxs),
{rxs});
});
}
auto sqrt_sum_up = tvm::compute(input_shape,
[&](Var i, Var j, Var k, Var l) {
return tvm::pow(bias +
(alpha * sqr_sum(i, j, k, l) / size),
beta);
});
return topi::broadcast_div(data, sqrt_sum_up);
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_LOCAL_RESPONSE_NORM_H_
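For reference, with r = size/2 (integer division) and zero padding along the channel axis, the compute above corresponds, for the NCHW case (axis = 1), to:

    \text{lrn}(x)_{n,c,h,w} = \frac{x_{n,c,h,w}}{\left(\text{bias} + \frac{\alpha}{\text{size}} \sum_{k=c-r}^{c+r} x_{n,k,h,w}^{2}\right)^{\beta}}

with out-of-range channel indices contributing zero; the NHWC case (axis = 3) sums over the last dimension instead.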
/*!
* Copyright (c) 2018 by Contributors
* \file rocm/normalization.h
* \brief rocm schedule for LRN and l2 normalization operations
*/
#ifndef TOPI_ROCM_NORMALIZATION_H_
#define TOPI_ROCM_NORMALIZATION_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Create a rocm schedule for LRN
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_lrn(target, outs);
}
/*!
* \brief Create a rocm schedule for L2 Normalization
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_l2_normalize(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_NORMALIZATION_H_
...@@ -17,4 +17,4 @@ from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
from .vision import schedule_region
from .vision import schedule_reorg
-from .nn import schedule_lrn, schedule_l2norm
+from .nn import schedule_lrn, schedule_l2_normalize
...@@ -4,8 +4,7 @@ from __future__ import absolute_import as _abs
import tvm
from .. import generic
-from .. import tag
-from .reduction import _schedule_reduce
+from .. import cpp

@generic.schedule_lrn.register(["cuda"])
def schedule_lrn(outs):
...@@ -22,37 +21,18 @@ def schedule_lrn(outs):
    sch: Schedule
        The computation schedule for the op.
    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    num_thread = 64
-    block_x = tvm.thread_axis("blockIdx.x")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-    lrn = outs[0]
-    sqr_sum_up = lrn.op.input_tensors[1]
-    sqr_sum = sqr_sum_up.op.input_tensors[0]
-    set_pad = sqr_sum.op.input_tensors[0]
-    s[set_pad].bind(set_pad.op.axis[0], block_x)
-    rxk = sqr_sum.op.reduce_axis[0]
-    _, xki = s[sqr_sum].split(rxk, factor=num_thread)
-    srf = s.rfactor(sqr_sum, xki)
-    s[sqr_sum].bind(s[sqr_sum].op.axis[0], block_x)
-    s[sqr_sum].bind(s[sqr_sum].op.reduce_axis[0], thread_x)
-    s[srf].compute_at(s[sqr_sum], s[sqr_sum].op.reduce_axis[0])
-    s[sqr_sum_up].bind(sqr_sum_up.op.axis[0], block_x)
-    xto, _ = s[lrn].split(lrn.op.axis[1], nparts=num_thread)
-    s[lrn].bind(lrn.op.axis[0], block_x)
-    s[lrn].bind(xto, thread_x)
-    return s
+    target = tvm.target.current_target(allow_none=False)
+    cpp_target = cpp.TEST_create_target(target.target_name)
+    return cpp.cuda.schedule_lrn(cpp_target, outs)

-@generic.schedule_l2norm.register(["cuda"])
-def schedule_l2norm(outs):
-    """Schedule for L2norm
+@generic.schedule_l2_normalize.register(["cuda"])
+def schedule_l2_normalize(outs):
+    """Schedule for L2 normalize

    Parameters
    ----------
    outs: Array of Tensor
-        The computation graph description of L2norm
+        The computation graph description of L2 normalize
        in the format of an array of tensors.

    Returns
...@@ -60,32 +40,6 @@ def schedule_l2norm(outs):
    sch: Schedule
        The computation schedule for the op.
    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-
-    def traverse(OP):
-        '''inline all one-to-one-mapping operators
-        except the last stage (output)'''
-        if tag.is_injective(OP.tag) or OP.tag == 'l2norm':
-            if OP not in s.outputs:
-                s[OP].compute_inline()
-            for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
-                    traverse(tensor.op)
-        elif OP.tag == 'comm_reduce':
-            _schedule_reduce(OP, s, is_idx_reduce=False)
-            for tensor in OP.input_tensors:
-                traverse(tensor.op)
-        else:
-            raise RuntimeError("Unsupported operator tag: %s" % OP.tag)
-    traverse(outs[0].op)
-    num_thread = 64
-    l2norm = outs[0]
-    block_x = tvm.thread_axis("blockIdx.x")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-    xto, _ = s[l2norm].split(l2norm.op.axis[1], nparts=num_thread)
-    s[l2norm].bind(l2norm.op.axis[0], block_x)
-    s[l2norm].bind(xto, thread_x)
-    return s
+    target = tvm.target.current_target(allow_none=False)
+    cpp_target = cpp.TEST_create_target(target.target_name)
+    return cpp.cuda.schedule_l2_normalize(cpp_target, outs)
...@@ -2,7 +2,7 @@
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm
+from .. import cpp

def _default_schedule(outs, auto_inline):
    """Default schedule for llvm."""
...@@ -273,17 +273,18 @@ def schedule_lrn(outs):
    sch: Schedule
        The computation schedule for the op.
    """
-    return _default_schedule(outs, False)
+    target = tvm.target.current_target(allow_none=False)
+    cpp_target = cpp.TEST_create_target(target.target_name)
+    return cpp.generic.default_schedule(cpp_target, outs, False)

@tvm.target.generic_func
-def schedule_l2norm(outs):
-    """Schedule for l2norm
+def schedule_l2_normalize(outs):
+    """Schedule for l2 normalize

    Parameters
    ----------
    outs: Array of Tensor
-        The computation graph description of l2norm
+        The computation graph description of l2 normalize
        in the format of an array of tensors.

    Returns
...@@ -291,4 +292,6 @@ def schedule_l2norm(outs):
    sch: Schedule
        The computation schedule for the op.
    """
-    return _default_schedule(outs, False)
+    target = tvm.target.current_target(allow_none=False)
+    cpp_target = cpp.TEST_create_target(target.target_name)
+    return cpp.generic.default_schedule(cpp_target, outs, False)
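The dispatch mechanism relied on here is tvm.target.generic_func: the function body above is the fallback, and backend registrations such as the CUDA one earlier in this commit override it when the matching target is active. A minimal sketch of that mechanism, with hypothetical function names that are not part of this commit:

    import tvm

    @tvm.target.generic_func
    def my_schedule(outs):
        # hypothetical example, not part of this commit
        return "generic fallback"

    @my_schedule.register(["cuda"])
    def _my_schedule_cuda(outs):
        return "cuda override"

    with tvm.target.create("cuda"):
        print(my_schedule(None))   # -> "cuda override"
    with tvm.target.create("llvm"):
        print(my_schedule(None))   # -> "generic fallback"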
...@@ -16,4 +16,4 @@ from .conv2d_transpose import *
from .bnn import *
from .upsampling import *
from .local_response_norm import *
-from .l2_norm import *
+from .l2_normalize import *
# pylint: disable=invalid-name
-"""TVM operator for l2norm"""
+"""TVM operator for l2 normalize"""
from __future__ import absolute_import
import tvm
-import topi
+from .. import cpp

@tvm.target.generic_func
-def l2norm_instance(data, eps, axis=None):
-    """Perform L2norm on the input data
+def l2_normalize(data, eps, axis=None):
+    """Perform L2 normalization on the input data

    For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))
...@@ -26,10 +26,4 @@ def l2norm_instance(data, eps, axis=None):
    output : tvm.Tensor
        4-D output with same shape
    """
-    assert len(data.shape) == 4, "only support 4-dim lrn"
-    dot_value = topi.cpp.pow(data, 2.0)
-    sum_value = topi.sum(dot_value, axis=axis, keepdims=True)
-    expand_sum = topi.broadcast_to(sum_value, data.shape)
-    return topi.broadcast_div(data, topi.sqrt(\
-        tvm.compute(expand_sum.shape, lambda i, j, k, l:\
-        tvm.max(expand_sum[i, j, k, l], eps), tag='l2norm')))
+    return cpp.nn.l2_normalize(data, eps, axis)
...@@ -2,8 +2,7 @@
"""TVM operator for local response norm compute."""
from __future__ import absolute_import
import tvm
-import topi
-from .pad import pad
+from .. import cpp

@tvm.target.generic_func
def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
...@@ -42,27 +41,4 @@ def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
    output : tvm.Tensor
        4-D output with same shape
    """
-    assert len(data.shape) == 4, "only support 4-dim lrn"
-    assert (size % 2) == 1, "size should be odd number"
-    assert (axis == 1) or (axis == 3), "axis should 1 or 3 for NCHW and NHWC"
-    ##Add padding on left & right of size radius first
-    pad_after = pad_before = [0, 0, 0, 0]
-    pad_after[axis] = pad_before[axis] = (size//2)
-    pad_data = pad(data, pad_before, pad_after, name="pad_data")
-    rxs = tvm.reduce_axis((0, size), name='rxs')
-    if axis == 1:
-        #NCHW layout
-        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
-            pad_data[i, j + rxs, k, l] * pad_data[i, j + rxs, k, l],
-            axis=rxs))
-    elif axis == 3:
-        #NHWC layout
-        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
-            pad_data[i, j, k, l + rxs] * pad_data[i, j, k, l + rxs],
-            axis=rxs))
-    sqr_sum_up = tvm.compute(data.shape, lambda i, j, k, l: tvm.power(
-        (bias + (alpha * sqr_sum[i, j, k, l] / size)), beta))
-    return topi.broadcast_div(data, sqr_sum_up)
+    return cpp.nn.lrn(data, size, axis, alpha, beta, bias)
"""scheduler for normalization functions on rocm backend""" """scheduler for normalization functions on rocm backend"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import topi import tvm
from .. import generic from .. import generic
from .. import cpp
@generic.schedule_lrn.register(["rocm", "gpu"]) @generic.schedule_lrn.register(["rocm", "gpu"])
def schedule_lrn(outs): def schedule_lrn(outs):
return topi.cuda.schedule_lrn(outs) target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_lrn(cpp_target, outs)
@generic.schedule_l2norm.register(["rocm", "gpu"]) @generic.schedule_l2_normalize.register(["rocm", "gpu"])
def schedule_l2norm(outs): def schedule_l2_normalize(outs):
return topi.cuda.schedule_l2norm(outs) target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_l2_normalize(cpp_target, outs)
...@@ -16,3 +16,5 @@ from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python
from .region_python import region_python
from .shortcut_python import shortcut_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""L2 normalize in python"""
import numpy as np

def l2_normalize_python(a_np, eps, axis=None):
    """L2 normalize operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    eps : float
        epsilon constant value

    axis : list of int
        axes over which the normalization is applied

    Returns
    -------
    l2_normalize_out : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    dot_value = np.power(a_np, 2.0)
    sqr_sum = np.sum(dot_value, axis, keepdims=True)
    sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
    l2_normalize_out = np.divide(a_np, sqrt_sum)
    return l2_normalize_out
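A quick sanity check of this helper (illustrative values, not part of the commit): when the squared sums dominate eps, the norms of the output along the normalized axes should be close to 1.

    import numpy as np
    # assumes l2_normalize_python from above is in scope
    a = np.random.uniform(size=(1, 3, 20, 20)).astype("float32")
    out = l2_normalize_python(a, 1e-6, axis=(1,))
    norms = np.sqrt(np.sum(out * out, axis=1))
    assert np.allclose(norms, 1.0, atol=1e-3)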
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""LRN in python"""
from itertools import product
import numpy as np

def lrn_python(a_np, size, axis, bias, alpha, beta):
    """Local response normalization operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    size : int
        normalization window size

    axis : int
        input data layout channel axis

    bias : float
        offset constant to avoid dividing by zero

    alpha : float
        scaling constant value

    beta : float
        exponent constant value

    Returns
    -------
    lrn_out : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    radius = size // 2
    sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
    for i, j, k, l in product(*[range(_axis) for _axis in a_np.shape]):
        axis_size = a_np.shape[axis]
        if axis == 1:
            # NCHW layout
            sum_start = j - radius if j - radius >= 0 else 0
            sum_end = j + radius + 1 if j + radius + 1 < axis_size else axis_size
            sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] *
                                      a_np[i, sum_start:sum_end, k, l])
        elif axis == 3:
            # NHWC layout
            sum_start = l - radius if l - radius >= 0 else 0
            sum_end = l + radius + 1 if l + radius + 1 < axis_size else axis_size
            sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] *
                                      a_np[i, j, k, sum_start:sum_end])
    sqr_sum_up = np.power((bias + (alpha * sqr_sum / size)), beta)
    lrn_out = np.divide(a_np, sqr_sum_up)
    return lrn_out
...@@ -24,6 +24,8 @@
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
#include <topi/nn/upsampling.h>
#include <topi/nn/l2_normalize.h>
#include <topi/nn/local_response_norm.h>
#include <topi/vision/reorg.h>
#include <topi/image/resize.h>
...@@ -39,6 +41,7 @@
#include <topi/cuda/reduction.h>
#include <topi/cuda/softmax.h>
#include <topi/cuda/vision.h>
#include <topi/cuda/normalization.h>
#include <topi/x86/bnn.h>
#include <topi/x86/default.h>
...@@ -46,6 +49,7 @@
#include <topi/rocm/dense.h>
#include <topi/rocm/vision.h>
#include <topi/rocm/normalization.h>
namespace topi {
...@@ -359,6 +363,20 @@ TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
  *rv = nn::log_softmax(args[0]);
  });
/* Ops from nn/l2_normalize.h */
TVM_REGISTER_GLOBAL("topi.nn.l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::l2_normalize(args[0], static_cast<double>(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::lrn(args[0], args[1], args[2],
static_cast<double>(args[3]),
static_cast<double>(args[4]),
static_cast<double>(args[5]));
});
TVM_REGISTER_GLOBAL("topi.vision.reorg") TVM_REGISTER_GLOBAL("topi.vision.reorg")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::reorg(args[0], args[1]); *rv = vision::reorg(args[0], args[1]);
...@@ -435,6 +453,17 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_region") ...@@ -435,6 +453,17 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_region")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_region(args[0], args[1]); *rv = topi::rocm::schedule_region(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_lrn(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_l2_normalize(args[0], args[1]);
});
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
...@@ -481,6 +510,16 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_region")
  *rv = topi::cuda::schedule_region(args[0], args[1]);
  });
TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_lrn(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_l2_normalize(args[0], args[1]);
});
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
  tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
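Once registered, these packed functions are reachable from Python through the topi.cpp namespace; a minimal sketch mirroring the C++-backed tests added below, with illustrative shapes and the argument order of the registrations above:

    import tvm
    import topi

    A = tvm.placeholder((1, 3, 20, 20), name="A")
    # lrn is registered as (data, size, axis, alpha, beta, bias)
    B = topi.cpp.nn.lrn(A, 3, 1, 0.0001, 0.75, 2.0)
    # l2_normalize is registered as (data, eps, axis)
    C = topi.cpp.nn.l2_normalize(A, 0.001, [1, 2])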
"""Test code for L2 norm""" """Test code for L2 normalization"""
import numpy as np import numpy as np
import tvm import tvm
import topi import topi
from topi.util import get_const_tuple from topi.util import get_const_tuple
import topi.testing
def l2norm_instance_python(a_np, eps, axis=None): def verify_l2_normalize(ishape, eps, axis=None):
"""L2 norm operator in NCHW layout.
Parameters A = tvm.placeholder(ishape, name='A')
---------- B = topi.nn.l2_normalize(A, eps, axis)
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
eps : float
epsilon constant value
axis : list of int
axis over the normalization applied
Returns
-------
l2norm_out : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, axis1, axis2, axis3 = a_np.shape
sqr_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
sqrt_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
l2norm_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
dot_value = np.power(a_np, 2.0)
sqr_sum = np.sum(dot_value, axis, keepdims=True)
sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
return np.divide(a_np, sqrt_sum)
def verify_l2norm(n, c, h, w, eps, axis=None):
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.l2norm_instance(A, eps, axis)
dtype = A.dtype dtype = A.dtype
a_np = np.random.uniform(size=(n, c, h, w)).astype(dtype) a_np = np.random.uniform(size=ishape).astype(dtype)
b_np = l2norm_instance_python(a_np, eps, axis) b_np = topi.testing.l2_normalize_python(a_np, eps, axis)
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -47,7 +21,10 @@ def verify_l2norm(n, c, h, w, eps, axis=None): ...@@ -47,7 +21,10 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
s = topi.generic.schedule_l2norm(B) if device == 'llvm':
s = topi.generic.schedule_l2_normalize([B])
else:
s = topi.cuda.schedule_l2_normalize([B])
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device) f = tvm.build(s, [A, B], device)
...@@ -57,14 +34,14 @@ def verify_l2norm(n, c, h, w, eps, axis=None): ...@@ -57,14 +34,14 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']: for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device) check_device(device)
def test_l2norm(): def test_l2_normalize():
verify_l2norm(1, 3, 20, 20, 0.001) verify_l2_normalize((1, 3, 20, 20), 0.001)
verify_l2norm(1, 3, 20, 20, 0.001, 1) verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
verify_l2norm(1, 3, 20, 20, 0.001, (1, 2)) verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
verify_l2norm(1, 3, 20, 20, 0.001, (2, 3)) verify_l2_normalize((1, 3, 20, 20), 0.001, (2, 3))
verify_l2norm(1, 3, 20, 20, 0.001, (0, 3)) verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 3))
verify_l2norm(1, 3, 20, 20, 0.001, (0, 2, 3)) verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 2, 3))
if __name__ == "__main__": if __name__ == "__main__":
test_l2norm() test_l2_normalize()
...@@ -3,63 +3,7 @@ import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
+import topi.testing

-def lrn_python(a_np, size, axis, bias, alpha, beta):
-    """Local response norm operator in NCHW layout.
-    Parameters
-    ----------
-    a_np : numpy.ndarray
-        4-D with shape [batch, in_channel, in_height, in_width]
-    size : int
-        normalisation window size
-    axis : int
-        input data layout channel axis
-    bias : float
-        offset to avoid dividing by 0. constant value
-    alpha : float
-        contant valie
-    beta : float
-        exponent constant value
-    Returns
-    -------
-    b_np : np.ndarray
-        4-D with shape [batch, out_channel, out_height, out_width]
-    """
-    axis0, axis1, axis2, axis3 = a_np.shape
-    radius = size // 2
-    sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
-    sqr_sum_up = np.zeros(shape=a_np.shape).astype(a_np.dtype)
-    lrn_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
-    def sum_dot_values(i, j, k, l):
-        axis_size = a_np.shape[axis]
-        if (axis == 1):
-            #NCHW layout
-            sum_start = j-radius if j-radius >= 0 else 0
-            sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size
-            sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \
-                                      a_np[i, sum_start:sum_end, k, l])
-        elif (axis == 3):
-            #NHWC layout
-            sum_start = l-radius if l-radius >= 0 else 0
-            sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size
-            sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
-                                      a_np[i, j, k, sum_start:sum_end])
-    for i in range(axis0):
-        for j in range(axis1):
-            for k in range(axis2):
-                for l in range(axis3):
-                    sum_dot_values(i, j, k, l)
-    sqr_sum_up = np.power((bias + (alpha * sqr_sum /size)), beta)
-    return np.divide(a_np, sqr_sum_up)
-
def verify_lrn(shape, size, axis, bias, alpha, beta):
    A = tvm.placeholder(shape, name='A')
...@@ -67,16 +11,19 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
    dtype = A.dtype
    a_np = np.random.uniform(size=shape).astype(dtype)
-    b_np = lrn_python(a_np, size, axis, bias, alpha, beta)
+    b_np = topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta)
    def check_device(device):
-        ctx = tvm.context(device, 0)
-        if not ctx.exist:
+        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
-            s = topi.generic.schedule_lrn(B)
+            if device == 'llvm':
+                s = topi.generic.schedule_lrn([B])
+            else:
+                s = topi.cuda.schedule_lrn([B])
+        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
...@@ -87,9 +34,9 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
        check_device(device)

def test_lrn():
-    verify_lrn((1, 3, 5, 5), 3, 1, 1, 1, 0.5)
-    verify_lrn((1, 3, 5, 5), 3, 3, 1, 1, 0.5)
-    verify_lrn((1, 3, 20, 20), 3, 1, 2, 1, 0.75)
+    verify_lrn((1, 3, 5, 5), 3, 1, 1.0, 1.0, 0.5)
+    verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5)
+    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

if __name__ == "__main__":
    test_lrn()
"""Test code for l2 normalization"""
import numpy as np
import tvm
import topi
import logging
from topi.util import get_const_tuple
import topi.testing

def verify_l2_normalize(shape, eps, axis=None):
    '''Verify L2 normalization operator by comparing outputs from tvm and numpy implementation'''
    A = tvm.placeholder(shape, name='A')
    B = topi.cpp.nn.l2_normalize(A, eps, axis)
    dtype = A.dtype
    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = topi.testing.l2_normalize_python(a_np, eps, axis)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
        if device == "llvm":
            s = topi.cpp.generic.default_schedule(target, [B], False)
        else:
            s = topi.cpp.cuda.schedule_l2_normalize(target, [B])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        func = tvm.build(s, [A, B], device, name="l2_normalize")
        func(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm']:
        check_device(device)

def test_l2_normalize():
    verify_l2_normalize((1, 3, 20, 20), 0.001)
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (2, 3))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 3))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 2, 3))

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    test_l2_normalize()
"""Test code for LRN"""
import numpy as np
import tvm
import topi
import logging
from topi.util import get_const_tuple
import topi.testing

def verify_lrn(shape, size, axis, bias, alpha, beta):
    '''Verify local response normalization operator by comparing outputs from tvm and numpy implementation'''
    A = tvm.placeholder(shape, name='A')
    B = topi.cpp.nn.lrn(A, size, axis, alpha, beta, bias)
    dtype = A.dtype
    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
        if device == "llvm":
            s = topi.cpp.generic.default_schedule(target, [B], False)
        else:
            s = topi.cpp.cuda.schedule_lrn(target, [B])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-1)

    for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm']:
        check_device(device)

def test_lrn():
    verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    test_lrn()