# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple

from ..expr import Tuple, TupleGetItem, const
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
from .tensor import (
    cos, exp, less, negative, ones_like, power, sin, zeros_like,
    equal, shape_of, log)
from .transform import (
    broadcast_to_like,
    collapse_sum_like,
    cast_like,
    reshape,
    reshape_like,
    strided_slice,
    take,
    tile,
    transpose,
    where,
    repeat,
    expand_dims,
    full_like
)
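
# Every function in this file is attached to its Relay operator through
# register_gradient and is later looked up by Relay's automatic-differentiation
# pass: given the original call `orig` and the adjoint `grad` of its output,
# a rule must return a list with one adjoint expression per input of `orig`.
# Rough usage sketch (comments only, nothing here runs at import time; the
# relay.transform.gradient API is assumed to be available as in mainline TVM):
#
#   from tvm import relay
#   x = relay.var("x", shape=(3,))
#   fn = relay.Function([x], relay.log(x))
#   bwd = relay.transform.gradient(fn, mode="first_order")  # fn must be type-checked first
#
# The pass consults the rule registered for each operator it encounters
# (here "log") and stitches the returned expressions into the backward program.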


@register_gradient("log")
def log_grad(orig, grad):
    """Returns [grad * (1 / x)]"""
    x = orig.args[0]
    return [grad * ones_like(x) / x]


@register_gradient("cos")
def cos_grad(orig, grad):
    """Returns [grad * (-sin(x))]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * (-ones * sin(x))]


@register_gradient("sin")
def sin_grad(orig, grad):
    """Returns [grad * cos(x)]"""
    x = orig.args[0]
    return [grad * cos(x)]


@register_gradient("atan")
def atan_grad(orig, grad):
    """Returns [grad * 1 / (1 + x ^ 2)]"""
    x = orig.args[0]
    a = const(2.0)
    return [grad * ones_like(x) / (ones_like(x) + power(x, a))]


@register_gradient("exp")
def exp_grad(orig, grad):
    """Returns [grad * exp(x)]"""
    return [grad * exp(orig.args[0])]


@register_gradient("sqrt")
def sqrt_grad(orig, grad):
    """Returns [grad * 0.5 * (x ^ -0.5)]"""
    a = const(0.5)  # (TODO) type?
    return [grad * a * power(orig.args[0], negative(a))]


@register_gradient("sigmoid")
def sigmoid_grad(orig, grad):
    """Returns [grad * sigmoid(x) * (1 - sigmoid(x))]."""
    return [grad * orig * (ones_like(orig) - orig)]


@register_gradient("tanh")
def tanh_grad(orig, grad):
    """Returns grad * (1 - tanh(x) * tanh(x))."""
    return [grad * (ones_like(orig) - orig * orig)]


@register_gradient("nn.relu")
def relu_grad(orig, grad):
    """Returns grad * (select(x < 0, 0, 1))."""
    x = orig.args[0]
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, zeros), zeros, ones * grad)]
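

# The binary elementwise gradients below (add, subtract, multiply, divide)
# share one subtlety: the forward op broadcasts its two inputs to a common
# shape, so each returned adjoint must be summed back down to the shape of its
# own input, which is what collapse_sum_like does.  Illustrative example
# (shapes only, not executed): if x has shape (2, 3) and y has shape (3,),
# then add(x, y) and grad have shape (2, 3), and the adjoint of y is
# collapse_sum_like(grad, y), i.e. grad summed over axis 0 down to shape (3,).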


@register_gradient("add")
def add_grad(orig, grad):
    """Returns [grad, grad]"""
    return [collapse_sum_like(grad, orig.args[0]),
            collapse_sum_like(grad, orig.args[1])]


@register_gradient("subtract")
def subtract_grad(orig, grad):
    """Returns [grad, -grad]"""
    return [collapse_sum_like(grad, orig.args[0]),
            collapse_sum_like(negative(grad), orig.args[1])]


@register_gradient("multiply")
def multiply_grad(orig, grad):
    """Returns [grad * y, grad * x]"""
    x, y = orig.args
    return [collapse_sum_like(grad * y, x),
            collapse_sum_like(grad * x, y)]


@register_gradient("divide")
def divide_grad(orig, grad):
    """Returns [grad / y, -grad * (x / y) / y]"""
    x, y = orig.args
    return [collapse_sum_like(grad / y, x),
            collapse_sum_like(- (grad * orig / y), y)]


@register_gradient("zeros")
def zeros_grad(orig, grad):
    """Returns []"""
    return []


@register_gradient("ones")
def ones_grad(orig, grad):
    """Returns []"""
    return []


@register_gradient("zeros_like")
def zeros_like_grad(orig, grad):
    """Returns [0]"""
    return [orig]


@register_gradient("ones_like")
def ones_like_grad(orig, grad):
    """Returns [0]"""
    return [zeros_like(orig.args[0])]


@register_gradient("collapse_sum_like")
def collapse_sum_like_grad(orig, grad):
    """Returns [broadcast_to_like(grad, x), 0]"""
    x, y = orig.args
    return [broadcast_to_like(grad, x), zeros_like(y)]


@register_gradient("abs")
def abs_grad(orig, grad):
    """Returns grad * (select(x < 0, -1, 1))."""
    x = orig.args[0]
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, zeros), -ones * grad, ones * grad)]


@register_gradient("clip")
def clip_grad(orig, grad):
    """Returns grad * (select(x < min || max < x, 0, 1))."""
    x = orig.args[0]
    a_min = orig.attrs.get_int("a_min")
    a_max = orig.attrs.get_int("a_max")
    a_mins = broadcast_to_like(const(a_min), x)
    a_maxs = broadcast_to_like(const(a_max), x)
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, a_mins), zeros,
                  where(less(a_maxs, x), zeros, ones * grad))]


@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
    """Returns the gradient of max_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode)
    return [pool_grad]


@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
    """Returns the gradient of avg_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                    strides=attrs.strides, padding=attrs.padding,
                                    layout=attrs.layout, ceil_mode=attrs.ceil_mode,
                                    count_include_pad=attrs.count_include_pad)
    return [pool_grad]


@register_gradient("nn.global_avg_pool2d")
def global_avg_pool2d_grad(orig, grad):
    """Returns the gradient of global_avg_pool2d."""
    data = orig.args[0]
    shape = data.checked_type.shape
    layout = orig.attrs.layout

    # we assume NCHW or NHWC layout for now, but it is easy to add more
    assert layout in ["NCHW", "NHWC"]
    if layout == "NCHW":
        pool_size = shape[2], shape[3]
    elif layout == "NHWC":
        pool_size = shape[1], shape[2]

    pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
                                    strides=(1, 1), padding=(0, 0),
                                    layout=layout)
    return [pool_grad]


# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
    """Placeholder gradient of concatenate; returns zeros and exists only for testing."""
    assert len(orig.args) == 1
    t = orig.args[0]
    x = TupleGetItem(t, 0)
    y = TupleGetItem(t, 1)
    # Assume there are only two elements in the tuple for now.
    # In a real implementation, concatenate_grad probably needs to be implemented
    # by an operator.
    return [Tuple([zeros_like(x), zeros_like(y)])]


@register_gradient("nn.conv2d")
def conv2d_grad(orig, grad):
    """Gradient of conv2d"""
    attrs = orig.attrs
    data, weight = orig.args
    data_shape = get_const_tuple(data.checked_type.shape)
    weight_shape = get_const_tuple(weight.checked_type.shape)
    _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
    batch, in_channel, in_h, in_w = data_shape
    out_channel, _, filter_h, filter_w = weight_shape

    # infer output_padding
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(get_const_tuple(attrs.padding),
                                                                 (filter_h, filter_w))
    stride_h, stride_w = get_const_tuple(attrs.strides)
    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
    out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
    output_padding = (in_h - out_h, in_w - out_w)

    assert attrs.data_layout == 'NCHW', 'only support NCHW data layout'
    assert attrs.kernel_layout == 'OIHW', 'only support OIHW kernel layout'
    assert attrs.out_layout in ['', 'NCHW'], 'only support NCHW output layout'

    # gradient w.r.t. the input: transposed convolution of the output gradient
    # with the original weight
    backward_data = _nn.conv2d_transpose(grad, weight,
                                         strides=attrs.strides,
                                         padding=attrs.padding,
                                         dilation=attrs.dilation,
                                         groups=attrs.groups,
                                         output_padding=output_padding)

    # gradient w.r.t. the weight: a grouped convolution of the reshaped input
    # with the reshaped output gradient used as the filter
    grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
    grad = reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
    data = reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw

    backward_weight = _nn.conv2d(data, grad,
                                 strides=attrs.dilation,
                                 padding=attrs.padding,
                                 dilation=attrs.strides,
                                 groups=in_channel * batch)
    # infer shape of backward_weight
    padded_weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom) \
                           // dilation_h + 1
    padded_weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right) \
                           // dilation_w + 1
    backward_weight = reshape(backward_weight,
                              [batch, in_channel // attrs.groups, out_channel,
                               padded_weight_grad_h, padded_weight_grad_w])
    backward_weight = _sum(backward_weight, axis=0)
    backward_weight = transpose(backward_weight, [1, 0, 2, 3])

    assert padded_weight_grad_h >= filter_h
    assert padded_weight_grad_w >= filter_w
    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0],
                                        end=[None, None, filter_h, filter_w])

    return [backward_data, backward_weight]
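

# The two helpers below are shared by the reduction gradients ("max", "sum").
# _get_reduce_axis normalizes the op's axis attribute into sorted nonnegative
# Python ints, resolving negative axes and the exclude flag, while
# _unreduce_expand re-inserts the reduced dimensions so the incoming gradient
# can be broadcast back to the input shape.  Illustrative example (not
# executed here): for an input of shape (2, 3, 4), axis=[-1, 0] normalizes to
# [0, 2]; with exclude=True it becomes [1].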


def _get_reduce_axis(call):
    """Helper function that returns the reduce axis of the call as plain python ints."""
    x, axis = call.args[0], call.attrs.axis
    shape = x.checked_type.concrete_shape

    # should never exclude when axis is None
    assert not (axis is None and call.attrs.exclude)

    if axis is None:
        return None

    # convert to nonnegative integers and sort
    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
    if call.attrs.exclude:
        axis = [ax for ax in range(len(shape)) if ax not in axis]
    return axis


def _unreduce_expand(x, axis):
    """Helper function that returns x expanded on the reduced dimensions in axis."""
    # assume axis is sorted nonnegative ints
    for ax in axis:
        x = expand_dims(x, ax)
    return x


@register_gradient("max")
def max_grad(orig, grad):
    """Returns the gradient of max"""
    x, axis = orig.args[0], _get_reduce_axis(orig)
    shape = x.checked_type.concrete_shape

    repeated = orig
    if axis is None:
        repeated = full_like(x, repeated)
    else:
        # expand dims (if necessary) and repeat along each axis
        if not orig.attrs.keepdims:
            repeated = _unreduce_expand(repeated, axis)
            grad = _unreduce_expand(grad, axis)
        for ax in axis:
            repeated = repeat(repeated, shape[ax], ax)

    indicators = cast_like(equal(repeated, x), grad)
    num_selected = _sum(indicators, axis, keepdims=True)
    # spread error across all max weights
    return [indicators * grad / num_selected]


@register_gradient("nn.softmax")
def softmax_grad(orig, grad):
    """Gradient of softmax"""
    # dL/dx_i = y_i * (g_i - sum_j g_j * y_j); keepdims=True so the sum broadcasts
    return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig]


@register_gradient("nn.log_softmax")
def log_softmax_grad(orig, grad):
    """Gradient of log_softmax"""
    x = orig.args[0]
    sm = _nn.softmax(x, axis=orig.attrs.axis)
    grad = grad / sm
    return softmax_grad(sm, grad)


@register_gradient("nn.bias_add")
def bias_add_grad(orig, grad):
    """Returns gradient of bias_add"""
    data = orig.args[0]
    return [collapse_sum_like(grad, data),
            _sum(grad, orig.attrs.axis, keepdims=False, exclude=True)]


@register_gradient("nn.dense")
def dense_grad(orig, grad):
    """Returns [grad' @ weight, data @ grad']"""
    data, weight = orig.args
    return [collapse_sum_like(transpose(grad) * weight, data),
            collapse_sum_like(data * transpose(grad), weight)]


@register_gradient("reshape")
def reshape_grad(orig, grad):
    """Gradient of reshape"""
    return [reshape_like(grad, orig.args[0])]


@register_gradient("cast")
def cast_grad(orig, grad):
    """Returns [cast_like(grad, x)], i.e. grad cast back to the input dtype."""
    x = orig.args[0]
    return [cast_like(grad, x)]


@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
    """Returns grad reshaped to data dims"""
    data = orig.args[0]
    return [reshape_like(grad, data)]


@register_gradient("transpose")
def transpose_grad(orig, grad):
    """Returns grad transposed over the complement of original transpose axes"""
    orig_axes = orig.attrs.axes
    if orig_axes:
        dims = len(orig_axes)
        new_axes = [0] * dims
        for i in range(dims):
            new_axes[int(orig_axes[i])] = i
    else:
        new_axes = None
    return [transpose(grad, axes=new_axes)]


@register_gradient("negative")
def negative_grad(orig, grad):
    """Returns -grad"""
    return [-grad]


@register_gradient("sum")
def sum_grad(orig, grad):
    """Returns grad broadcasted to data dims"""
    data, axis = orig.args[0], _get_reduce_axis(orig)
    if not orig.attrs.keepdims:
        if axis is None:
            axis = list(range(len(data.checked_type.concrete_shape)))
        grad = _unreduce_expand(grad, axis)
    return [broadcast_to_like(grad, data)]


@register_gradient("nn.cross_entropy")
def cross_entropy_grad(orig, grad):
    """Gradient of cross_entropy"""
    x, y = orig.args
    shape = shape_of(x)
    batch_size = take(shape, const(0, dtype='int32'), axis=0)
    grad = grad / batch_size.astype('float32')
    return [-grad * y / x, -grad * log(x)]


@register_gradient("nn.cross_entropy_with_logits")
def cross_entropy_with_logits_grad(orig, grad):
    """Gradient of cross_entropy_with_logits"""
    x, y = orig.args
    shape = shape_of(x)
    batch_size = take(shape, const(0, dtype='int32'), axis=0)
    grad = grad / batch_size.astype('float32')
    return [-grad * y, -grad * x]
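

# Sanity-checking note: the rules in this file are normally verified against
# central finite differences in the unit tests.  Rough sketch (comments only;
# the check_grad helper is assumed to live in tvm.relay.testing as in mainline
# TVM, and nothing here is executed at import time):
#
#   from tvm import relay
#   from tvm.relay.testing import check_grad
#
#   x = relay.var("x", shape=(10, 4), dtype="float32")
#   fwd_func = relay.Function([x], relay.exp(x))
#   check_grad(fwd_func)  # compares the registered exp gradient to numeric estimates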