Commit 12257ddd by Yinghai Lu Committed by Leyuan Wang

Implement relay nn.bias_add compute in C++ (#3027)

* Implement nn.bias_add compute in C++

* Address comments

* Remove unnecessary check
parent c91f7141
...@@ -182,20 +182,6 @@ def schedule_conv2d_transpose(attrs, outs, target): ...@@ -182,20 +182,6 @@ def schedule_conv2d_transpose(attrs, outs, target):
reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
# bias_add # bias_add
@reg.register_compute("nn.bias_add")
def compute_bias_add(attrs, inputs, out_dtype, target):
"""Compute definition of conv2d_transpose"""
axis = attrs.axis
bias = inputs[1]
data_ndim = len(inputs[0].shape)
if axis < 0:
axis = axis + data_ndim
num_newaxis = data_ndim - axis - 1
if num_newaxis:
bias = topi.expand_dims(bias, axis=1, num_newaxis=num_newaxis)
return [topi.add(inputs[0], bias)]
reg.register_schedule("nn.bias_add", schedule_injective) reg.register_schedule("nn.bias_add", schedule_injective)
reg.register_pattern("nn.bias_add", OpPattern.BROADCAST) reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h> #include <tvm/relay/attrs/image.h>
#include <topi/nn.h> #include <topi/nn.h>
#include <topi/nn/bias_add.h>
#include <topi/nn/softmax.h> #include <topi/nn/softmax.h>
#include <topi/nn/flatten.h> #include <topi/nn/flatten.h>
#include <vector> #include <vector>
...@@ -90,7 +91,12 @@ RELAY_REGISTER_OP("nn.bias_add") ...@@ -90,7 +91,12 @@ RELAY_REGISTER_OP("nn.bias_add")
.add_argument("data", "nD Tensor", "Input data.") .add_argument("data", "nD Tensor", "Input data.")
.add_argument("bias", "1D Tensor", "Bias.") .add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1) .set_support_level(1)
.add_type_rel("BiasAdd", BiasAddRel); .add_type_rel("BiasAdd", BiasAddRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_type, const Target& target) {
const auto* param = attrs.as<BiasAddAttrs>();
return tvm::Array<tvm::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
});
// relay.nn.dense // relay.nn.dense
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief bias_add op constructions
* \file nn/bias_add.h
*/
#ifndef TOPI_NN_BIAS_ADD_H_
#define TOPI_NN_BIAS_ADD_H_
#include <string>
#include "topi/tags.h"
#include "topi/broadcast.h"
#include "topi/transform.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
/*!
* \brief Creates an operation that calculates data + bias
*
* \param data Tensor with shape [batch, in_dim]
* \param bias Tensor with shape [batch].
*
* \return Tensor with shape [batch, in_dim]
*/
inline tvm::Tensor bias_add(const tvm::Tensor& data, const tvm::Tensor& bias, int axis) {
int data_ndim = data->shape.size();
if (axis < 0) {
axis += data_ndim;
}
int num_newaxis = data_ndim - axis - 1;
return add(data, (num_newaxis ? expand_dims(bias, 1, num_newaxis) : bias));
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_BIAS_ADD_H_
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include <topi/reduction.h> #include <topi/reduction.h>
#include <topi/transform.h> #include <topi/transform.h>
#include <topi/nn/bias_add.h>
#include <topi/nn/bnn.h> #include <topi/nn/bnn.h>
#include <topi/nn/dense.h> #include <topi/nn/dense.h>
#include <topi/nn/dilate.h> #include <topi/nn/dilate.h>
...@@ -400,6 +401,12 @@ TVM_REGISTER_GLOBAL("topi.nn.dense") ...@@ -400,6 +401,12 @@ TVM_REGISTER_GLOBAL("topi.nn.dense")
*rv = nn::dense(args[0], args[1], args[2]); *rv = nn::dense(args[0], args[1], args[2]);
}); });
/* Ops from nn/bias_add.h */
TVM_REGISTER_GLOBAL("topi.nn.bias_add")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::bias_add(args[0], args[1], args[2]);
});
/* Ops from nn/batch_matmul.h */ /* Ops from nn/batch_matmul.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_matmul") TVM_REGISTER_GLOBAL("topi.nn.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment