/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2018 by Contributors. * * \file tvm/relay/pass/pattern_util.h * \brief Header of internal operator functions * These can be used for writing passes. */ #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_ #define TVM_RELAY_PASS_PATTERN_UTIL_H_ #include <builtin_fp16.h> #include <tvm/data_layout.h> #include <tvm/relay/op.h> #include <tvm/relay/expr.h> #include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/transform.h> #include <string> namespace tvm { namespace relay { /*! * \brief Dispatch DataType to the C++ data type * during runtime. */ #define TVM_DTYPE_DISPATCH(type, DType, ...) 
\ if (type == Float(64)) { \ typedef double DType; \ {__VA_ARGS__} \ } else if (type == Float(32)) { \ typedef float DType; \ {__VA_ARGS__} \ } else if (type == Float(16)) { \ typedef uint16_t DType; \ {__VA_ARGS__} \ } else if (type == Int(64)) { \ typedef int64_t DType; \ {__VA_ARGS__} \ } else if (type == Int(32)) { \ typedef int32_t DType; \ {__VA_ARGS__} \ } else if (type == Int(16)) { \ typedef int16_t DType; \ {__VA_ARGS__} \ } else if (type == Int(8)) { \ typedef int8_t DType; \ {__VA_ARGS__} \ } else if (type == UInt(64)) { \ typedef uint64_t DType; \ {__VA_ARGS__} \ } else if (type == UInt(32)) { \ typedef uint32_t DType; \ {__VA_ARGS__} \ } else if (type == UInt(16)) { \ typedef uint16_t DType; \ {__VA_ARGS__} \ } else if (type == UInt(8)) { \ typedef uint8_t DType; \ {__VA_ARGS__} \ } else { \ LOG(FATAL) << "unknown data type " << type; \ } /*! * \brief Try to match lhs and rhs via broadcasting rule, such that: * * rhs matches the dimension of lhs specified by lhs_axes * rhs's value equals 1 on rest of dimensions. * * \param tlhs The type of left operand (data) * \param trhs The type right operand (bias) * \param lhs_axes The axes on lhs to match. * \param rhs_value A squeezed version of rhs which only contains matched dimension. * \return Whether match is successful. 
*/
inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
                                     const TensorTypeNode* trhs,
                                     const Array<Integer>& lhs_axes,
                                     Expr* rhs_value = nullptr) {
  const size_t lhs_ndim = tlhs->shape.size();
  const size_t rhs_ndim = trhs->shape.size();
  // rhs may not carry more dimensions than lhs.
  if (lhs_ndim < rhs_ndim) return false;

  AttrsEqual dim_equal;
  // Number of leading lhs dimensions that have no rhs counterpart;
  // lhs dimension `d` lines up with rhs dimension `d - offset`.
  const size_t offset = lhs_ndim - rhs_ndim;

  // The squeeze attributes are only needed when the caller asked for
  // the squeezed rhs expression.
  const bool want_squeeze = (rhs_value != nullptr);
  NodePtr<SqueezeAttrs> squeeze_attrs;
  if (want_squeeze) {
    squeeze_attrs = make_node<SqueezeAttrs>();
  }

  // Index of the next entry of lhs_axes that still has to be matched.
  size_t next_axis = 0;
  for (size_t dim = 0; dim < lhs_ndim; ++dim) {
    const bool is_matched_axis =
        next_axis < lhs_axes.size() &&
        dim == static_cast<size_t>(lhs_axes[next_axis]->value);

    if (is_matched_axis) {
      // A matched axis must exist on rhs and have the same extent as on lhs.
      if (dim < offset) return false;
      if (!dim_equal(tlhs->shape[dim], trhs->shape[dim - offset])) return false;
      ++next_axis;
      continue;
    }

    // Unmatched lhs dimensions without an rhs counterpart need no check.
    if (dim < offset) continue;

    // Every other rhs dimension has to be 1; record it for squeezing.
    if (!is_const_int(trhs->shape[dim - offset], 1)) return false;
    if (want_squeeze) {
      squeeze_attrs->axis.push_back(static_cast<int>(dim - offset));
    }
  }

  // Replace *rhs_value by squeeze(*rhs_value) when there is anything to drop.
  if (want_squeeze && squeeze_attrs->axis.size() != 0) {
    static const Op& squeeze_op = Op::Get("squeeze");
    *rhs_value = CallNode::make(squeeze_op, {*rhs_value}, Attrs(squeeze_attrs), {});
  }
  return true;
}

/*!
 * \brief Expand 1D Tensor to match axis.
 *
 * The result bias can be used to add or multiply to
 * the target Tensor on the specified axis via broadcasting rule.
 *
 * \param bias The bias.
 * \param target_ndim Target dimension.
 * \param axes The axis on the output we want to match on.
*/ inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Integer>& axes) { static const Op& expand_dims = Op::Get("expand_dims"); for (size_t i = axes.size(); i != 0; --i) { if (i == axes.size()) { int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1; if (num_pad_axis > 0) { auto attrs = make_node<ExpandDimsAttrs>(); attrs->axis = i; attrs->num_newaxis = static_cast<int>(num_pad_axis); bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); } } else { int64_t diff = axes[i]->value - axes[i - 1]->value; CHECK_GE(diff, 0L); if (diff > 0) { auto attrs = make_node<ExpandDimsAttrs>(); attrs->axis = i; attrs->num_newaxis = static_cast<int>(diff); bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); } } } return bias; } /*! * \brief Check if the call is depthwise conv2d. * * \param call The conv2d call. * \param param The conv2d attributes. * \return Whether it is depthwise_conv2d. */ inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, const Layout& kernel_layout) { static const Layout kOIHW("OIHW"); const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW); auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape); return is_const_int(wshape[0], param->groups) && is_const_int(wshape[1], 1); } /*! * \brief Get super-dimension of output channels of conv2d * \param call The conv2d call. * \return Super-dimension size of output channels of conv2d. */ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { auto param = call->attrs.as<Conv2DAttrs>(); auto tweight = call->args[1]->type_as<TensorTypeNode>(); auto index = param->kernel_layout.find('O'); CHECK_NE(index, std::string::npos); auto channels = as_const_int(tweight->shape[index]); return *channels; } /*! * \brief Create a Constant with a scalar * * \param dtype The data type. * \param value The value of the scalar. * \return A Constant. 
*/ template<typename T> inline Constant MakeConstantScalar(DataType dtype, T value) { runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0}); TVM_DTYPE_DISPATCH(dtype, DType, { if (dtype == Float(16)) { // convert to float16 // storage is uint16_t *static_cast<DType*>(arr->data) = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value)); } else { *static_cast<DType*>(arr->data) = value; } }) return ConstantNode::make(arr); } /*! * \brief Check if two expressions are equal scalars. * \param a The expression to be checked. * \param b The expression to be checked * \return Whether two expressions are equal scalars. */ inline bool IsEqualScalar(const Expr& a, const Expr& b) { const auto* constant_a = a.as<ConstantNode>(); const auto* constant_b = b.as<ConstantNode>(); if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) { return false; } return AlphaEqual(a, b); } inline Expr GetField(Expr t, size_t i) { return TupleGetItemNode::make(t, i); } inline Expr Pair(Expr l, Expr r) { return TupleNode::make({l, r}); } inline Expr Exp(Expr e) { static const Op& op = Op::Get("exp"); return CallNode::make(op, {e}); } inline Expr Log(Expr e) { static const Op& op = Op::Get("log"); return CallNode::make(op, {e}); } /*! * \brief Get an immediate scalar from a Constant expr. * * \param expr The Constant expr. * \return A scalar with type T. 
*/ template <typename T> T GetScalarFromConstant(Expr expr) { const auto* n = expr.as<ConstantNode>(); CHECK(n->is_scalar()); return static_cast<T*>(n->data->data)[0]; } inline Expr Cast(Expr x, DataType dtype) { static const Op& op = Op::Get("cast"); auto attrs = make_node<CastAttrs>(); attrs->dtype = dtype; return CallNode::make(op, {x}, Attrs(attrs), {}); } inline Expr Negative(Expr x) { static const Op& op = Op::Get("negative"); return CallNode::make(op, {x}, Attrs(), {}); } inline Expr Sqrt(Expr x) { static const Op& op = Op::Get("sqrt"); return CallNode::make(op, {x}, Attrs(), {}); } inline Expr Relu(Expr x) { static const Op& op = Op::Get("nn.relu"); return CallNode::make(op, {x}, Attrs(), {}); } inline Expr Round(Expr x) { static const Op& op = Op::Get("round"); return CallNode::make(op, {x}, Attrs(), {}); } inline Expr Clip(Expr x, double a_min, double a_max) { static const Op& op = Op::Get("clip"); auto attrs = make_node<ClipAttrs>(); attrs->a_min = a_min; attrs->a_max = a_max; return CallNode::make(op, {x}, Attrs(attrs), {}); } inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } inline Expr Subtract(Expr lhs, Expr rhs) { static const Op& op = Op::Get("subtract"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } inline Expr Multiply(Expr lhs, Expr rhs) { static const Op& op = Op::Get("multiply"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } inline Expr Divide(Expr lhs, Expr rhs) { static const Op& op = Op::Get("divide"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } inline Expr ZerosLike(Expr e) { static const Op& op = Op::Get("zeros_like"); return CallNode::make(op, {e}); } inline Expr OnesLike(Expr e) { static const Op& op = Op::Get("ones_like"); return CallNode::make(op, {e}); } inline Expr CollapseSumLike(Expr e) { static const Op& op = Op::Get("collapse_sum_like"); return CallNode::make(op, {e}); } inline Expr Power(Expr lhs, Expr rhs) { static const 
Op& op = Op::Get("power"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } inline Expr RightShift(Expr x, Expr nbit) { static const Op& op = Op::Get("right_shift"); return CallNode::make(op, {x, nbit}, Attrs(), {}); } inline Expr LeftShift(Expr x, Expr nbit) { static const Op& op = Op::Get("left_shift"); return CallNode::make(op, {x, nbit}, Attrs(), {}); } inline Expr ReshapeLike(Expr lhs, Expr rhs) { static const Op& op = Op::Get("reshape_like"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } inline Expr Copy(Expr data) { static const Op& op = Op::Get("copy"); return CallNode::make(op, {data}, Attrs(), {}); } Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides); Expr StopFusion(Expr data); Expr ForceCast(Expr data); } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_