Unverified Commit 90b2a1eb by Animesh Jain Committed by GitHub

[Relay][Topi] Use SimplifyInference for L2 Normazlization. (#4795)

parent 10f85d03
......@@ -611,22 +611,6 @@ def schedule_lrn(attrs, outs, target):
reg.register_pattern("nn.lrn", OpPattern.OPAQUE)
# l2_normalize
@reg.register_compute("nn.l2_normalize")
def compute_l2_normalize(attrs, inputs, out_dtype, target):
"""Compute definition of l2 normalize"""
return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]
@reg.register_schedule("nn.l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
"""Schedule definition of l2 normalize"""
with target:
return topi.generic.schedule_l2_normalize(outs)
reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
# upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective)
......
......@@ -397,6 +397,11 @@ inline Expr Divide(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Maximum(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("maximum");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr ZerosLike(Expr e) {
static const Op& op = Op::Get("zeros_like");
return CallNode::make(op, {e});
......
......@@ -124,13 +124,26 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
return out;
}
Expr L2NormToInferUnpack(const Attrs attrs, Expr data) {
const auto param = attrs.as<L2NormalizeAttrs>();
CHECK(param);
Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast<float>(param->eps));
Expr sqr = Multiply(data, data);
Expr sum = Maximum(Sum(sqr, param->axis, true, false), epsilon);
Expr sqrt = Sqrt(sum);
return Divide(data, sqrt);
}
class InferenceSimplifier : public ExprMutator {
public:
InferenceSimplifier()
: batch_norm_op_(Op::Get("nn.batch_norm")),
dropout_op_(Op::Get("nn.dropout")),
instance_norm_op_(Op::Get("nn.instance_norm")),
layer_norm_op_(Op::Get("nn.layer_norm")) {}
layer_norm_op_(Op::Get("nn.layer_norm")),
l2_norm_op_(Op::Get("nn.l2_normalize")) {}
Expr VisitExpr_(const TupleGetItemNode* n) final {
Expr new_e = ExprMutator::VisitExpr_(n);
......@@ -155,12 +168,15 @@ class InferenceSimplifier : public ExprMutator {
ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
} else if (n->op == layer_norm_op_) {
const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
n->args[0]->checked_type());
} else if (n->op == instance_norm_op_) {
const auto* call = new_n.as<CallNode>();
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
n->args[0]->checked_type());
} else if (n->op == l2_norm_op_) {
const auto* call = new_n.as<CallNode>();
return L2NormToInferUnpack(call->attrs, call->args[0]);
}
return new_n;
}
......@@ -173,6 +189,7 @@ class InferenceSimplifier : public ExprMutator {
const Op& dropout_op_;
const Op& instance_norm_op_;
const Op& layer_norm_op_;
const Op& l2_norm_op_;
std::unordered_map<Expr, Type, ObjectHash, ObjectEqual> ty_map_;
};
......
......@@ -71,55 +71,6 @@ inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
return s;
}
/*!
* \brief Create a CUDA schedule for L2 normalization
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
Schedule s = create_schedule(out_ops);
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_injective(op->tag) || op->tag == "l2_normalize") {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag == "comm_reduce") {
ScheduleReduce(target, op, s, false);
for (auto tensor : op->InputTensors()) {
traverse(tensor->op);
}
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
int num_thread = 64;
Tensor l2_normalize = outs[0];
IterVar block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
IterVar thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
IterVar xto, xti;
s[l2_normalize].split_by_nparts(l2_normalize->op.as<ComputeOpNode>()->axis[1],
num_thread, &xto, &xti);
s[l2_normalize].bind(l2_normalize->op.as<ComputeOpNode>()->axis[0], block_x);
s[l2_normalize].bind(xto, thread_x);
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_NORMALIZATION_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief l2 normalization op constructions
* \file nn/l2_normalize.h
*/
#ifndef TOPI_NN_L2_NORMALIZE_H_
#define TOPI_NN_L2_NORMALIZE_H_
#include <tvm/te/operation.h>
#include <topi/tags.h>
#include <string>
#include <algorithm>
namespace topi {
namespace nn {
using namespace tvm;
using namespace tvm::te;
/*!
* \brief L2 normalization inference operator
*
* \param data The input tensor. 4-D with shape [batch, channel, height, width]
* \param eps Epsilon to prevent div by 0
* \param axis Axes over the normalization applied
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the l2 normalization operation
*/
inline Tensor l2_normalize(const Tensor& data,
float eps,
const Array<Integer>& axis,
std::string name = "tensor",
std::string tag = "l2_normalize") {
for (size_t i = 0; i < axis.size(); ++i) {
int ax = topi::detail::GetConstInt(axis[i]);
CHECK_LT(ax, data->shape.size()) <<
"Axis " << ax << " exceeds input data dim " <<
data->shape.size();
}
auto input_shape = data->shape;
Tensor dot_value = topi::power(data, static_cast<float>(2.0));
Tensor sum_value = topi::sum(dot_value, axis, true);
Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
return topi::divide(data,
topi::sqrt(tvm::te::compute(expand_sum->shape,
[&](const Array<Var>& i){
return (max(expand_sum(i), eps));
}, name, tag)));
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_L2_NORMALIZE_H_
......@@ -44,17 +44,6 @@ inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_lrn(target, outs);
}
/*!
* \brief Create a rocm schedule for L2 Normalization
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_l2_normalize(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_NORMALIZATION_H_
......@@ -31,7 +31,7 @@ from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_adaptive_pool
from .nn import schedule_lrn, schedule_l2_normalize
from .nn import schedule_lrn
from .batch_matmul import schedule_batch_matmul
from .vision import *
from . import ssd
......
......@@ -40,22 +40,3 @@ def schedule_lrn(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_lrn(cpp_target, outs)
@generic.schedule_l2_normalize.register(["cuda"])
def schedule_l2_normalize(outs):
"""Schedule for L2 normalize
Parameters
----------
outs: Array of Tensor
The computation graph description of L2 normalize
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_l2_normalize(cpp_target, outs)
......@@ -649,24 +649,6 @@ def schedule_lrn(outs):
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_l2_normalize(outs):
"""Schedule for l2 normalize
Parameters
----------
outs: Array of Tensor
The computation graph description of l2 normalize
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_sparse_dense(outs):
......
......@@ -38,7 +38,6 @@ from .upsampling import *
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *
from .sparse import *
from .pad import *
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""TVM operator for l2 normalize"""
from __future__ import absolute_import
import tvm
from .. import cpp
@tvm.target.generic_func
def l2_normalize(data, eps, axis=None):
"""Perform L2 normalization on the input data
For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))
Parameters
----------
data : tvm.Tensor
4-D with NCHW or NHWC layout
eps : float
epsilon value
axis : list of int
axis over the normalization applied
Returns
-------
output : tvm.Tensor
4-D output with same shape
"""
return cpp.nn.l2_normalize(data, eps, axis)
......@@ -26,9 +26,3 @@ def schedule_lrn(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_lrn(cpp_target, outs)
@generic.schedule_l2_normalize.register(["rocm", "gpu"])
def schedule_l2_normalize(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_l2_normalize(cpp_target, outs)
......@@ -43,7 +43,6 @@
#include <topi/nn/mapping.h>
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
#include <topi/nn/l2_normalize.h>
#include <topi/nn/local_response_norm.h>
#include <topi/nn/batch_matmul.h>
......@@ -554,12 +553,6 @@ TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
*rv = nn::log_softmax(args[0]);
});
/* Ops from nn/l2_normalize.h */
TVM_REGISTER_GLOBAL("topi.nn.l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::l2_normalize(args[0], static_cast<double>(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::lrn(args[0], args[1], args[2],
......@@ -674,11 +667,6 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
*rv = topi::rocm::schedule_lrn(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_l2_normalize(args[0], args[1]);
});
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -725,11 +713,6 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
*rv = topi::cuda::schedule_lrn(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_l2_normalize(args[0], args[1]);
});
/* Utility functions */
TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for L2 normalization"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
import topi.testing
def verify_l2_normalize(ishape, eps, axis=None):
A = tvm.placeholder(ishape, name='A')
B = topi.nn.l2_normalize(A, eps, axis)
dtype = A.dtype
a_np = np.random.uniform(size=ishape).astype(dtype)
b_np = topi.testing.l2_normalize_python(a_np, eps, axis)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
s = topi.generic.schedule_l2_normalize([B])
else:
s = topi.cuda.schedule_l2_normalize([B])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device)
def test_l2_normalize():
verify_l2_normalize((1, 3, 20, 20), 0.001)
verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
verify_l2_normalize((1, 3, 20, 20), 0.001, (2, 3))
verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 3))
verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 2, 3))
if __name__ == "__main__":
test_l2_normalize()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment