/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file nn.cc
 * \brief Property def of nn operators.
 */

#include <tvm/tir/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h>
#include <topi/nn.h>
#include <topi/nn/bias_add.h>
#include <topi/nn/softmax.h>
#include <topi/nn/flatten.h>
#include <vector>
#include <string>
#include "../type_relations.h"
#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
#include "nn.h"

namespace tvm {
namespace relay {

// relay.nn.bias_add
TVM_REGISTER_NODE_TYPE(BiasAddAttrs);

bool BiasAddRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const BiasAddAttrs* param = attrs.as<BiasAddAttrs>();
  CHECK(param != nullptr);
  int axis = param->axis;
  if (axis < 0) {
    axis = data->shape.size() + axis;
  }
  CHECK_LE(axis, static_cast<int>(data->shape.size()))
      << "axis " << param->axis << " is out of range";

  // assign output type
  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], types[0]);
  return true;
}

// Positional relay function to create bias_add operator used by frontend FFI.
Expr MakeBiasAdd(Expr data,
                 Expr bias,
                 int axis) {
  auto attrs = make_object<BiasAddAttrs>();
  attrs->axis = axis;
  static const Op& op = Op::Get("nn.bias_add");
  return Call(op, {data, bias}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add")
.set_body_typed(MakeBiasAdd);

RELAY_REGISTER_OP("nn.bias_add")
.describe(R"code(Add bias to an axis of the input.

)code" TVM_ADD_FILELINE)
.set_attrs_type<BiasAddAttrs>()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1)
.add_type_rel("BiasAdd", BiasAddRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                         const Array<te::Tensor>& inputs,
                                         const Type& out_type) {
  const auto* param = attrs.as<BiasAddAttrs>();
  return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
});
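// Worked example (illustrative comment only; the concrete shapes are assumptions):
// for data of shape (n, c, h, w) and axis = 1, BiasAddRel types `bias` as a 1-D
// tensor of shape (c,) and gives the output the same type as `data`; a negative
// axis counts from the end, so axis = -1 selects the last dimension. From C++ the
// call can be built the same way the Python frontend does through the
// "relay.op.nn._make.bias_add" packed function, e.g.:
//
//   Expr data = Var("data", TensorType({1, 64, 56, 56}, DataType::Float(32)));
//   Expr bias = Var("bias", TensorType({64}, DataType::Float(32)));
//   Expr out  = MakeBiasAdd(data, bias, /*axis=*/1);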
// relay.nn.fifo_buffer
TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs);

Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) {
  auto attrs = make_object<FIFOBufferAttrs>();
  attrs->axis = axis;
  static const Op& op = Op::Get("nn.fifo_buffer");
  return Call(op, {input, buffer}, Attrs(attrs), {});
}

bool FIFOBufferRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* input = types[0].as<TensorTypeNode>();
  const auto* buffer = types[1].as<TensorTypeNode>();
  const FIFOBufferAttrs* param = attrs.as<FIFOBufferAttrs>();
  if (input == nullptr || buffer == nullptr) {
    return false;
  }
  CHECK(param != nullptr);
  CHECK_EQ(input->shape.size(), buffer->shape.size());

  const size_t buffer_axis = static_cast<size_t>(
      param->axis < 0 ? static_cast<int>(buffer->shape.size()) + param->axis : param->axis);

  reporter->Assert(buffer_axis < buffer->shape.size());
  for (size_t i = 0; i < buffer->shape.size(); ++i) {
    if (i != buffer_axis) {
      reporter->AssertEQ(input->shape[i], buffer->shape[i]);
    }
  }
  reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]);

  Array<tvm::PrimExpr> oshape = buffer->shape;

  reporter->Assign(types[2], TensorType(oshape, buffer->dtype));
  return true;
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer")
.set_body_typed(MakeFIFOBuffer);

RELAY_REGISTER_OP("nn.fifo_buffer")
.describe(R"code(FIFO buffer
Compute equivalent of

```
concat(buffer, data, axis=axis) \
.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis])
```

Useful for

* Encoding explicit re-use of computation in convolution ops that operate on a sliding window input
* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
)code" TVM_ADD_FILELINE)
.set_attrs_type<FIFOBufferAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "Latest input")
.add_argument("buffer", "Tensor", "Buffer storing latest [length_buffer] inputs")
.set_support_level(3)
.add_type_rel("FIFOBuffer", FIFOBufferRel);

// relay.nn.dense
TVM_REGISTER_NODE_TYPE(DenseAttrs);

// Positional relay function to create dense operator used by frontend FFI.
Expr MakeDense(Expr data,
               Expr weight,
               IndexExpr units,
               DataType out_dtype) {
  auto attrs = make_object<DenseAttrs>();
  attrs->units = units;
  attrs->out_dtype = out_dtype;
  static const Op& op = Op::Get("nn.dense");
  return Call(op, {data, weight}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.dense")
.set_body_typed(MakeDense);

RELAY_REGISTER_OP("nn.dense")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
.set_attrs_type<DenseAttrs>()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
.set_support_level(1)
.add_type_rel("Dense", DenseRel<DenseAttrs>);
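// Worked example (illustrative comment only; the concrete sizes are assumptions):
// with data of shape (batch, input_dim) = (32, 128) and weight of shape
// (units, input_dim) = (256, 128), nn.dense infers an output of shape (32, 256);
// the trailing dimension of data must match the trailing dimension of weight.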
// relay.leaky_relu
TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);

// Positional relay function to create leaky relu operator used by frontend FFI.
Expr MakeLeakyRelu(Expr data,
                   double alpha) {
  auto attrs = make_object<LeakyReluAttrs>();
  attrs->alpha = alpha;
  static const Op& op = Op::Get("nn.leaky_relu");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu")
.set_body_typed(MakeLeakyRelu);

RELAY_REGISTER_OP("nn.leaky_relu")
.describe(R"code(Leaky version of a Rectified Linear Unit.

`y = x > 0 ? x : alpha * x`

)code" TVM_ADD_FILELINE)
.set_attrs_type<LeakyReluAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const Attrs& attrs,
                    const Array<te::Tensor>& inputs,
                    const Type& out_type) {
    const auto* param = attrs.as<LeakyReluAttrs>();
    return Array<te::Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
});

// relay.prelu
TVM_REGISTER_NODE_TYPE(PReluAttrs);

bool PReluRel(const Array<Type>& types,
              int num_inputs,
              const Attrs& attrs,
              const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const PReluAttrs* param = attrs.as<PReluAttrs>();
  CHECK(param != nullptr);

  CHECK(param->axis < static_cast<int>(data->shape.size()))
      << "Wrong axis (" << param->axis << ") value.";

  // assign alpha type
  Array<IndexExpr> alpha_shape({data->shape[param->axis]});
  reporter->Assign(types[1], TensorType(alpha_shape, data->dtype));

  // assign output type
  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
  return true;
}

template<typename T>
Array<Array<Layout> > PReluInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  CHECK_EQ(old_in_layouts.size(), 2U);
  CHECK_EQ(old_in_types.size(), 2U);
  Layout data_layout = old_in_layouts[0];
  if (new_in_layouts.defined()) {
    CHECK_EQ(new_in_layouts.size(), 2U);
  }
  return Array<Array<Layout> >{{data_layout, Layout("C")},
                               {data_layout}};
}

// Positional relay function to create prelu operator used by frontend FFI.
Expr MakePRelu(Expr data,
               Expr alpha,
               int axis) {
  auto attrs = make_object<PReluAttrs>();
  attrs->axis = axis;
  static const Op& op = Op::Get("nn.prelu");
  return Call(op, {data, alpha}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu")
.set_body_typed(MakePRelu);

RELAY_REGISTER_OP("nn.prelu")
.describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`y = x > 0 ? x : alpha * x`,
where :math:`*` is a channelwise multiplication for each sample in the batch.

)code" TVM_ADD_FILELINE)
.set_attrs_type<PReluAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const Attrs& attrs,
                    const Array<te::Tensor>& inputs,
                    const Type& out_type) {
    const auto* param = attrs.as<PReluAttrs>();
    return Array<te::Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis) };
});
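// Worked example (illustrative comment only; the shapes are assumptions): for
// data of shape (n, c, h, w) and axis = 1, PReluRel types `alpha` as a (c,)
// vector, i.e. one learned slope per channel, and the output keeps the input
// shape. PReluInferCorrectLayout pins alpha to the 1-D "C" layout and reuses
// the first input's existing layout for the data input and the output.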
// relay.softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);

TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax")
.set_body_typed([](Expr data, int axis) {
  auto attrs = make_object<SoftmaxAttrs>();
  attrs->axis = axis;
  static const Op& op = Op::Get("nn.softmax");
  return Call(op, {data}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("nn.softmax")
.describe(R"code(Softmax layer.

.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}

.. note::
    This operator can be optimized away for inference.

- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type<SoftmaxAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);

// relay.nn.log_softmax
TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax")
.set_body_typed([](Expr data, int axis) {
  auto attrs = make_object<SoftmaxAttrs>();
  attrs->axis = axis;
  static const Op& op = Op::Get("nn.log_softmax");
  return Call(op, {data}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("nn.log_softmax")
.describe(R"code(Computes log softmax.

.. math:: \text{log_softmax}(x)_i = \log \frac{exp(x_i)}{\sum_j exp(x_j)}

.. note::
    This operator can be optimized away for inference.

- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type<SoftmaxAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                         const Array<te::Tensor>& inputs,
                                         const Type& out_type) {
  const auto* param = attrs.as<SoftmaxAttrs>();
  CHECK(param != nullptr);
  CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
      << "log_softmax currently only works on last dimension";
  return Array<te::Tensor>{ topi::nn::log_softmax(inputs[0]) };
});

// relay.nn.batch_flatten
bool BatchFlattenRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  if (data->shape.size() == 0) return false;

  auto target_dim = tir::make_const(DataType::Int(32), 1);

  for (uint32_t i = 1; i < data->shape.size(); ++i) {
    if (!data->shape[i].as<tir::AnyNode>()) {
      target_dim = target_dim * data->shape[i];
    } else {
      target_dim = data->shape[i];
      break;
    }
  }

  std::vector<IndexExpr> oshape({data->shape[0], target_dim});

  // assign output type
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
  return true;
}

Expr MakeBatchFlatten(Expr data) {
  static const Op& op = Op::Get("nn.batch_flatten");
  return Call(op, {data}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten")
.set_body_typed(MakeBatchFlatten);

RELAY_REGISTER_OP("nn.batch_flatten")
.describe(R"code(Flattens the input into a 2-D array.

For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes
the input array into an output array of shape ``(d1, d2*...*dk)``.

Example::

  x = [[
    [1,2,3],
    [4,5,6],
    [7,8,9]
  ],
  [   [1,2,3],
      [4,5,6],
      [7,8,9]
  ]],

  batch_flatten(x) = [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
                      [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]]

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("BatchFlatten", BatchFlattenRel)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const Attrs& attrs,
                    const Array<te::Tensor>& inputs,
                    const Type& out_type) {
    return Array<te::Tensor>{ topi::nn::flatten(inputs[0]) };
});
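// Worked example (illustrative comment only): BatchFlattenRel keeps the first
// axis and folds the remaining axes into one, so data of shape (2, 3, 4, 5) is
// typed as (2, 60). If any non-batch dimension is dynamic (tir::Any), the loop
// in BatchFlattenRel stops there and the flattened dimension becomes Any.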
// relu
TVM_REGISTER_GLOBAL("relay.op.nn._make.relu")
.set_body_typed([](Expr data) {
  static const Op& op = Op::Get("nn.relu");
  return Call(op, {data}, Attrs(), {});
});

RELAY_REGISTER_OP("nn.relu")
.describe(R"code(Returns the rectified linear unit of the input array, computed element-wise.

.. math::
   max(x, 0)

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                         const Array<te::Tensor>& inputs,
                                         const Type& out_type) {
  return Array<te::Tensor>{ topi::relu(inputs[0], 0.0f) };
});

// Positional relay function to create LRN operator used by frontend FFI.
TVM_REGISTER_NODE_TYPE(LRNAttrs);

Expr MakeLRN(Expr data,
             int size,
             int axis,
             double alpha,
             double beta,
             double bias) {
  auto attrs = make_object<LRNAttrs>();
  attrs->size = size;
  attrs->axis = axis;
  attrs->alpha = alpha;
  attrs->beta = beta;
  attrs->bias = bias;
  static const Op& op = Op::Get("nn.lrn");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn")
.set_body_typed(MakeLRN);

RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer.

Normalize the input in a local region across or within feature maps.
Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
where n is the size of each local region, and the sum is taken over the region
centered at that value (zero padding is added where necessary).

.. math::

    data / (bias + (alpha * sum_data^2 / size))^{beta}

- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type<LRNAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Identity", IdentityRel);

// Positional relay function to create L2Normalize operator used by frontend FFI.
TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);

Expr MakeL2Normalize(Expr data,
                     double eps,
                     Array<Integer> axis) {
  auto attrs = make_object<L2NormalizeAttrs>();
  attrs->eps = eps;
  attrs->axis = std::move(axis);
  static const Op& op = Op::Get("nn.l2_normalize");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize")
.set_body_typed(MakeL2Normalize);

RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer.

Normalizes along dimension axis using an L2 norm

.. math::
    output = x / sqrt(max(sum(x^2), epsilon))

- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type<L2NormalizeAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Identity", IdentityRel);
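// Worked example (illustrative comment only; eps and the shapes are assumptions):
// nn.l2_normalize with eps = 1e-10 and axis = [1] on data of shape (n, c, h, w)
// divides each element by sqrt(max(sum(x^2), eps)), the sum running over the
// channel axis at every (n, h, w) position. The output shape equals the input
// shape, which is why the Identity type relation suffices.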
// Dropout
TVM_REGISTER_NODE_TYPE(DropoutAttrs);

bool DropoutRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  // dropout returns the original tensor with dropout applied
  // and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
  auto ret_type = TensorType(data->shape, data->dtype);
  reporter->Assign(types[1], TupleType(Array<Type>({ret_type, ret_type})));
  return true;
}

Expr MakeDropout(Expr data, double rate) {
  auto attrs = make_object<DropoutAttrs>();
  attrs->rate = rate;
  static const Op& op = Op::Get("nn.dropout");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout")
.set_body_typed(MakeDropout);

RELAY_REGISTER_OP("nn.dropout")
.describe(R"code(Applies the dropout operation to the input array.

During training, each element of the input is set to zero with
probability ``p``. The whole array is rescaled by ``1/(1-p)``
to keep the expected sum of the input unchanged.

)code" TVM_ADD_FILELINE)
.set_attrs_type<DropoutAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Dropout", DropoutRel);

// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);

Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
                                                 const Array<Layout>& new_in_layouts,
                                                 const Array<Layout>& old_in_layouts,
                                                 const Array<tvm::relay::Type>& old_in_types) {
  BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());

  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    CHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }

  size_t axis =
      param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);

  Layout ret = Layout::Undef();

  // If new_in_layouts are defined, this code tries to modify the layout.
  if (new_in_layouts.defined() && old_in_layouts.defined()) {
    // Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
    const auto& bn_dim = old_in_layouts[0][axis];
    auto new_index = new_in_layouts[0].IndexOf(bn_dim);
    param->axis = new_index;
    ret = new_in_layouts[0];
  } else if (old_in_layouts.defined()) {
    ret = old_in_layouts[0];
  }

  // BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
  Layout c_layout = Layout("C");

  return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
                              {ret, c_layout, c_layout}};
}
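// Worked example (illustrative comment only): if batch_norm was written for
// NCHW data with axis = 1 and the data input is later converted to NHWC,
// BatchNormInferCorrectLayout looks up the 'C' dimension of the old layout and
// finds its position in the new layout, rewriting axis to 3; gamma, beta,
// moving_mean and moving_var stay in the 1-D "C" layout.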
bool BatchNormRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 6);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const BatchNormAttrs* param = attrs.as<BatchNormAttrs>();

  // axis of -1 means use the last dimension
  CHECK(param->axis >= -1 && param->axis < (int)data->shape.size());
  int axis = (param->axis != -1) ? param->axis : data->shape.size() - 1;
  auto axis_size = data->shape[axis];

  // if we are using beta and gamma, they need to be of shape (dim,)
  reporter->Assign(types[1], TensorType({axis_size}, data->dtype));
  reporter->Assign(types[2], TensorType({axis_size}, data->dtype));
  reporter->Assign(types[3], TensorType({axis_size}, data->dtype));
  reporter->Assign(types[4], TensorType({axis_size}, data->dtype));

  // output is a tuple of the normed data (same shape as input), new running mean,
  // and new running average (the latter two are both vectors of length dim)
  std::vector<Type> fields;
  auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}), data->dtype);
  fields.push_back(TensorType(data->shape, data->dtype));
  fields.push_back(vec_ty);
  fields.push_back(vec_ty);
  reporter->Assign(types[5], TupleType(Array<Type>(fields)));
  return true;
}

Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,
                   int axis, double epsilon, bool center, bool scale) {
  auto attrs = make_object<BatchNormAttrs>();
  attrs->axis = axis;
  attrs->epsilon = epsilon;
  attrs->center = center;
  attrs->scale = scale;
  static const Op& op = Op::Get("nn.batch_norm");
  return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm")
.set_body_typed(MakeBatchNorm);

RELAY_REGISTER_OP("nn.batch_norm")
.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2015).
Normalizes the input at each batch, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation
standard deviation close to 1.

.. math::

  data\_mean[i] = mean(data[:,i,:,...]) \\
  data\_var[i] = var(data[:,i,:,...])

Then compute the normalized output, which has the same shape as input, as follows:

.. math::

  out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}}
    * gamma[i] + beta[i]

Both *mean* and *var* return a scalar by treating the input as a vector.

Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.

Besides the inputs and the outputs, this operator accepts two auxiliary
states, ``moving_mean`` and ``moving_var``, which are *k*-length
vectors. They are global statistics for the whole dataset, which are updated by::

  moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
  moving_var = moving_var * momentum + data_var * (1 - momentum)

The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel' (separately normalized groups). The default is 1.
Specifying -1 sets the channel axis to be the last item in the input shape.

.. note::
    This operator can be optimized away for inference.

)code" TVM_ADD_FILELINE)
.set_attrs_type<BatchNormAttrs>()
.set_num_inputs(5)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.add_argument("moving_mean", "Tensor", "Running mean of input.")
.add_argument("moving_var", "Tensor", "Running variance of input.")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);
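// Worked example (illustrative comment only; the shapes are assumptions): for
// data of shape (n, c, h, w) with the default axis = 1, BatchNormRel types
// gamma, beta, moving_mean and moving_var as (c,) vectors and the result as a
// 3-tuple: the normalized data of shape (n, c, h, w), the new running mean of
// shape (c,), and the new running variance of shape (c,).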
// instance_norm
TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);

bool InstanceNormRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
  CHECK(axis >= 0 && axis < (int)data->shape.size());
  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[3], TensorType(data->shape, data->dtype));

  return true;
}

Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
                      bool center, bool scale) {
  auto attrs = make_object<InstanceNormAttrs>();
  attrs->axis = axis;
  attrs->epsilon = epsilon;
  attrs->center = center;
  attrs->scale = scale;
  static const Op& op = Op::Get("nn.instance_norm");
  return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
});

RELAY_REGISTER_OP("nn.instance_norm")
.describe(R"code(Instance Normalization (Ulyanov et al., 2016)
Applies instance normalization to the n-dimensional input array.

.. math::

    out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
      * gamma + beta

Instance normalization is similar to batch normalization, but unlike batch
normalization, the mean and var are calculated per-dimension separately for
each object (instance) in a mini-batch, not over a batch. And the same
normalization is applied both at test and train time.

Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.

The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel'. The default is 1. Specifying -1 sets the channel axis to be
the last item in the input shape.

.. note::
    This operator can be optimized away for inference.

)code" TVM_ADD_FILELINE)
.set_attrs_type<InstanceNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("InstanceNorm", InstanceNormRel);
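// Illustrative note (comment only): on data of shape (n, c, h, w) with axis = 1,
// instance normalization computes mean and variance over the spatial dimensions
// separately for every sample and channel, whereas batch_norm pools statistics
// across the batch; InstanceNormRel accordingly types gamma and beta as (c,).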
// layer_norm
TVM_REGISTER_NODE_TYPE(LayerNormAttrs);

bool LayerNormRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  const LayerNormAttrs* param = attrs.as<LayerNormAttrs>();
  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
  CHECK(axis >= 0 && axis < (int)data->shape.size());
  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[3], TensorType(data->shape, data->dtype));

  return true;
}

Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
                   bool center, bool scale) {
  auto attrs = make_object<LayerNormAttrs>();
  attrs->axis = axis;
  attrs->epsilon = epsilon;
  attrs->center = center;
  attrs->scale = scale;
  static const Op& op = Op::Get("nn.layer_norm");
  return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 7>(MakeLayerNorm, args, rv);
});

RELAY_REGISTER_OP("nn.layer_norm")
.describe(R"code(Layer normalization.

Normalizes the input along the ``axis`` dimension using the mean and variance
computed over that dimension, then applies the scale ``gamma`` and offset ``beta``,
both of shape *(k,)* where *k* is the size of the input on ``axis``.
)code" TVM_ADD_FILELINE)
.set_attrs_type<LayerNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("LayerNorm", LayerNormRel);

// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;
  CHECK(x->shape.size() == 3 && y->shape.size() == 3);
  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
      << "BatchDot: batch dimension doesn't match, "
      << " x shape=" << x->shape
      << ", y shape=" << y->shape;
  CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
      << "BatchDot: shapes of x and y are inconsistent, "
      << " x shape=" << x->shape
      << ", y shape=" << y->shape;

  Array<tvm::PrimExpr> oshape = x->shape;
  oshape.Set(2, y->shape[1]);

  // assign output type
  reporter->Assign(types[2], TensorType(oshape, x->dtype));
  return true;
}

// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x, Expr y) {
  static const Op& op = Op::Get("nn.batch_matmul");
  return Call(op, {x, y}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul")
.set_body_typed(MakeBatchMatmul);

RELAY_REGISTER_OP("nn.batch_matmul")
.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y`
are batches of matrices.

.. math::

  batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T)

- **x**: `(b, m, k)`
- **y**: `(b, n, k)`
- **out**: `(b, m, n)`.

)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "3D Tensor", "First input.")
.add_argument("y", "3D Tensor", "Second input.")
.set_support_level(10)
.add_type_rel("BatchMatmul", BatchMatmulRel);
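// Worked example (illustrative comment only; the concrete sizes are assumptions):
// BatchMatmulRel treats y as transposed on its last two axes, so x of shape
// (b, m, k) = (8, 32, 64) and y of shape (b, n, k) = (8, 16, 64) produce an
// output of shape (8, 32, 16).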
// relay.nn.cross_entropy
bool CrossEntropyRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;
  CHECK(x->shape.size() == 2 && y->shape.size() == 2)
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape = " << x->shape << ", "
      << "y shape = " << y->shape;
  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape = " << x->shape << ", "
      << "y shape = " << y->shape;
  CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape = " << x->shape << ", "
      << "y shape = " << y->shape;
  // assign output type
  reporter->Assign(types[2], TensorType({}, x->dtype));
  return true;
}

// Positional relay function to create cross_entropy operator used by frontend FFI.
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
  static const Op& op = Op::Get("nn.cross_entropy");
  return Call(op, {predictions, targets}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy")
.set_body_typed(MakeCrossEntropy);

RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(
Computes cross entropy given predictions and targets.
The predictions are expected to be probabilities; the operator applies the log
itself and does not accept logits.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "2D Tensor", "Predictions.")
.add_argument("y", "2D Tensor", "Targets.")
.set_support_level(10)
.add_type_rel("CrossEntropy", CrossEntropyRel);

// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
  static const Op& op = Op::Get("nn.cross_entropy_with_logits");
  return Call(op, {predictions, targets}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy_with_logits")
.set_body_typed(MakeCrossEntropyWithLogits);

RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
.describe(R"code(
Computes cross entropy given predictions and targets.
The predictions are expected to be logits.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "2D Tensor", "Predictions.")
.add_argument("y", "2D Tensor", "Targets.")
.set_support_level(10)
.add_type_rel("CrossEntropy", CrossEntropyRel);
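// Illustrative note (comment only): CrossEntropyRel requires predictions and
// targets to be 2-D tensors of matching shape (batch, num_classes) and types
// the result as a scalar in the prediction dtype. nn.cross_entropy expects
// probabilities, while nn.cross_entropy_with_logits expects raw logits; both
// share this type relation.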
// Depth to space and space to depth
TVM_REGISTER_NODE_TYPE(SubPixelAttrs);

bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");

  const SubPixelAttrs* param = attrs.as<SubPixelAttrs>();
  CHECK(param != nullptr);
  const int block_size = param->block_size;
  const Layout in_layout(param->layout);
  auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
  CHECK(layout_converter.defined())
      << "DepthToSpace only supports input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  auto oshape = layout_converter.ForwardShape(data->shape);
  oshape.Set(1, indexdiv(oshape[1], (block_size * block_size)));
  oshape.Set(2, oshape[2] * block_size);
  oshape.Set(3, oshape[3] * block_size);

  // Assign output type
  reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));

  return true;
}

// Positional relay function to create DepthToSpace operator
// used by frontend FFI
Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string mode) {
  auto attrs = make_object<SubPixelAttrs>();
  attrs->block_size = block_size;
  attrs->layout = std::move(layout);
  attrs->mode = std::move(mode);
  static const Op& op = Op::Get("nn.depth_to_space");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.depth_to_space").set_body_typed(MakeDepthToSpace);

RELAY_REGISTER_OP("nn.depth_to_space")
.describe(R"code(Rearrange input channels into spatial pixels.

- **data**: data is a 4D array of shape
            (batch, in_channels, in_height, in_width) for NCHW

- **out**: Output is a 4D array of shape
           (batch, in_channels / (block_size * block_size), in_height * block_size,
           in_width * block_size) for NCHW.

)code" TVM_ADD_FILELINE)
.set_attrs_type<SubPixelAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_support_level(5)
.add_type_rel("DepthToSpace", DepthToSpaceRel);

bool SpaceToDepthRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");

  const SubPixelAttrs* param = attrs.as<SubPixelAttrs>();
  CHECK(param != nullptr);
  const int block_size = param->block_size;
  const Layout in_layout(param->layout);
  auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
  CHECK(layout_converter.defined())
      << "SpaceToDepth only supports input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  auto oshape = layout_converter.ForwardShape(data->shape);
  oshape.Set(1, oshape[1] * (block_size * block_size));
  oshape.Set(2, indexdiv(oshape[2], block_size));
  oshape.Set(3, indexdiv(oshape[3], block_size));

  // Assign output type
  reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));

  return true;
}

// Positional relay function to create SpaceToDepth operator
// used by frontend FFI
Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) {
  auto attrs = make_object<SubPixelAttrs>();
  attrs->block_size = block_size;
  attrs->layout = std::move(layout);
  static const Op& op = Op::Get("nn.space_to_depth");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_depth").set_body_typed(MakeSpaceToDepth);

RELAY_REGISTER_OP("nn.space_to_depth")
.describe(R"code(Rearrange spatial pixels into new output channels.

- **data**: data is a 4D array of shape
            (batch, in_channels, in_height, in_width) for NCHW

- **out**: Output is a 4D array of shape
           (batch, in_channels * block_size * block_size, in_height / block_size,
           in_width / block_size) for NCHW.

)code" TVM_ADD_FILELINE)
.set_attrs_type<SubPixelAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_support_level(5)
.add_type_rel("SpaceToDepth", SpaceToDepthRel);

}  // namespace relay
}  // namespace tvm