Commit dee11b41 by 雾雨魔理沙 Committed by Jared Roesch

[Relay][Training] Small refactoring (#3893)

* init

* fix
parent a6bb84a8
......@@ -44,6 +44,7 @@ def log_grad(orig, grad):
x = orig.args[0]
return [grad * ones_like(x) / x]
@register_gradient("cos")
def cos_grad(orig, grad):
"""Returns [grad * (-sin(x))]"""
......@@ -51,12 +52,14 @@ def cos_grad(orig, grad):
ones = ones_like(x)
return [grad * (-ones * sin(x))]
@register_gradient("sin")
def sin_grad(orig, grad):
"""Returns [grad * cos(x)]"""
x = orig.args[0]
return [grad * cos(x)]
@register_gradient("exp")
def exp_grad(orig, grad):
"""Returns [grad * exp(x)]"""
......@@ -173,6 +176,7 @@ def clip_grad(orig, grad):
ones = ones_like(x)
return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
attrs = orig.attrs
......@@ -181,6 +185,7 @@ def max_pool2d_grad(orig, grad):
layout=attrs.layout, ceil_mode=attrs.ceil_mode)
return [pool_grad]
@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
attrs = orig.attrs
......@@ -190,6 +195,7 @@ def avg_pool2d_grad(orig, grad):
count_include_pad=attrs.count_include_pad)
return [pool_grad]
# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
......@@ -201,6 +207,7 @@ def concatenate_grad(orig, grad):
# In the real implementation, concatenate_grad probably need to be implemented by an operator.
return [Tuple([zeros_like(x), zeros_like(y)])]
@register_gradient("nn.conv2d")
def conv2d_grad(orig, grad):
"""Gradient of conv2d"""
......@@ -268,8 +275,8 @@ def softmax_grad(orig, grad):
@register_gradient("nn.bias_add")
def bias_grad(orig, grad):
"""Returns grad"""
def bias_add_grad(orig, grad):
"""Returns gradient of bias_add"""
data, bias = orig.args
return [collapse_sum_like(grad, data),
collapse_sum_like(grad, bias)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment