Commit 80f8e982 by David Hirvonen Committed by Tianqi Chen

remove batch_norm_inference (#2626)

parent c59a78e5
/*!
* Copyright (c) 2017 by Contributors
* \brief Batch normalization op constructions
* \file nn/batch_norm.h
*/
#ifndef TOPI_NN_BATCH_NORM_H_
#define TOPI_NN_BATCH_NORM_H_
#include <string>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
 * \brief Batch normalization inference operator with NCHW layout.
 *
 * Normalizes each element with the stored moving statistics of its channel:
 * (x - moving_mean) / sqrt(moving_var + eps), then applies the affine
 * transform gamma * xhat + beta (gamma is skipped when fixed to 1).
 *
 * \param x The input tensor. 4-D with shape [batch, channel, height, width]
 * \param gamma 1-D with shape [channel]
 * \param beta 1-D with shape [channel]
 * \param moving_mean 1-D with shape [channel]
 * \param moving_var 1-D with shape [channel]
 * \param eps Epsilon to prevent div by 0
 * \param fix_gamma Fix gamma while training
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the batch normalization operation
 */
inline Tensor batch_norm_inference(const Tensor& x,
                                   const Tensor& gamma,
                                   const Tensor& beta,
                                   const Tensor& moving_mean,
                                   const Tensor& moving_var,
                                   float eps,
                                   bool fix_gamma,
                                   std::string name = "tensor",
                                   std::string tag = kBroadcast) {
  CHECK_EQ(x->shape.size(), 4) << "Batch norm requires 4-D input";
  return tvm::compute(
      x->shape,
      [&](const Array<Var>& indices) {
        // Statistics and affine parameters are indexed by channel (axis 1 in NCHW).
        auto channel = Array<Var>({ indices[1] });
        auto normalized =
            (x(indices) - moving_mean(channel)) / tvm::sqrt(moving_var(channel) + eps);
        // fix_gamma is a host-side flag: when set, gamma is treated as 1 and
        // the scale factor is omitted from the generated expression.
        return fix_gamma ? normalized + beta(channel)
                         : normalized * gamma(channel) + beta(channel);
      }, name, tag);
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_BATCH_NORM_H_
......@@ -2,7 +2,6 @@
"""Neural network operators"""
from __future__ import absolute_import as _abs
from .batch_norm import *
from .conv2d import *
from .depthwise_conv2d import *
from .elemwise import *
......
"""TVM operator batch normalization compute."""
from __future__ import absolute_import
import tvm
from .. import tag
@tvm.tag_scope(tag=tag.BROADCAST)
def batch_norm_inference(data, gamma, beta, moving_mean, moving_var, eps, fix_gamma):
    """Batch normalization inference operator in NCHW layout.

    Normalizes ``data`` per channel with the stored moving statistics,
    then applies the affine transform ``gamma * xhat + beta`` (the scale
    by ``gamma`` is skipped when ``fix_gamma`` is set).

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, channel, height, width]

    gamma : tvm.Tensor
        1-D with shape [channel]

    beta : tvm.Tensor
        1-D with shape [channel]

    moving_mean : tvm.Tensor
        1-D with shape [channel]

    moving_var : tvm.Tensor
        1-D with shape [channel]

    eps : float
        Epsilon to prevent div 0.

    fix_gamma : boolean
        Fix gamma while training

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, channel, height, width]

    mean : tvm.Tensor
        1-D with shape [channel]

    var : tvm.Tensor
        1-D with shape [channel]
    """
    assert len(data.shape) == 4, "only support 4-dim batch norm"
    batch, channel, height, width = data.shape
    if fix_gamma:
        # gamma treated as constant 1: normalize and shift only.
        out = tvm.compute(
            (batch, channel, height, width),
            lambda b, c, h, w: (data[b, c, h, w] - moving_mean[c]) /
            tvm.intrin.sqrt(moving_var[c] + eps) + beta[c])
    else:
        out = tvm.compute(
            (batch, channel, height, width),
            lambda b, c, h, w: (data[b, c, h, w] - moving_mean[c]) /
            tvm.intrin.sqrt(moving_var[c] + eps) * gamma[c] + beta[c])
    # Bug fix: the original referenced an undefined name `C` here, which
    # raised NameError at trace time; the channel extent unpacked above
    # is `channel`.
    mean = tvm.compute((channel,), lambda c: moving_mean[c])
    var = tvm.compute((channel,), lambda c: moving_var[c])
    return [out, mean, var]
......@@ -17,7 +17,6 @@
#include <topi/reduction.h>
#include <topi/transform.h>
#include <topi/nn/batch_norm.h>
#include <topi/nn/bnn.h>
#include <topi/nn/dense.h>
#include <topi/nn/dilate.h>
......@@ -328,18 +327,6 @@ TVM_REGISTER_GLOBAL("topi.nn.upsampling")
*rv = nn::upsampling(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/batch_norm.h */
// Expose nn::batch_norm_inference to the TVM FFI under the global name
// "topi.nn.batch_norm_inference". Tensor arguments (x, gamma, beta,
// moving_mean, moving_var) are forwarded positionally from the packed
// args; args[5] is eps and args[6] is fix_gamma.
TVM_REGISTER_GLOBAL("topi.nn.batch_norm_inference")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::batch_norm_inference(args[0],
args[1],
args[2],
args[3],
args[4],
// NOTE(review): the explicit cast pulls eps out of the packed args as a
// double (presumably the FFI's float representation) before it narrows
// to the float parameter — confirm against TVMArgValue's conversions.
static_cast<double>(args[5]),
args[6]);
});
/* Ops from nn/bnn.h */
TVM_REGISTER_GLOBAL("topi.nn.binarize_pack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment