Commit dee11b41 by 雾雨魔理沙, committed by Jared Roesch

[Relay][Training] Small refactoring (#3893)

* init

* fix
parent a6bb84a8
@@ -44,6 +44,7 @@ def log_grad(orig, grad):
     x = orig.args[0]
     return [grad * ones_like(x) / x]
 
+
 @register_gradient("cos")
 def cos_grad(orig, grad):
     """Returns [grad * (-sin(x))]"""
@@ -51,12 +52,14 @@ def cos_grad(orig, grad):
     ones = ones_like(x)
     return [grad * (-ones * sin(x))]
 
+
 @register_gradient("sin")
 def sin_grad(orig, grad):
     """Returns [grad * cos(x)]"""
     x = orig.args[0]
     return [grad * cos(x)]
 
+
 @register_gradient("exp")
 def exp_grad(orig, grad):
     """Returns [grad * exp(x)]"""
@@ -173,6 +176,7 @@ def clip_grad(orig, grad):
     ones = ones_like(x)
     return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
 
+
 @register_gradient("nn.max_pool2d")
 def max_pool2d_grad(orig, grad):
     attrs = orig.attrs
@@ -181,6 +185,7 @@ def max_pool2d_grad(orig, grad):
                             layout=attrs.layout, ceil_mode=attrs.ceil_mode)
     return [pool_grad]
 
+
 @register_gradient("nn.avg_pool2d")
 def avg_pool2d_grad(orig, grad):
     attrs = orig.attrs
@@ -190,6 +195,7 @@ def avg_pool2d_grad(orig, grad):
                             count_include_pad=attrs.count_include_pad)
     return [pool_grad]
 
+
 # not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
@@ -201,6 +207,7 @@ def concatenate_grad(orig, grad):
     # In the real implementation, concatenate_grad probably need to be implemented by an operator.
     return [Tuple([zeros_like(x), zeros_like(y)])]
 
+
 @register_gradient("nn.conv2d")
 def conv2d_grad(orig, grad):
     """Gradient of conv2d"""
@@ -268,8 +275,8 @@ def softmax_grad(orig, grad):
 @register_gradient("nn.bias_add")
-def bias_grad(orig, grad):
-    """Returns grad"""
+def bias_add_grad(orig, grad):
+    """Returns gradient of bias_add"""
     data, bias = orig.args
     return [collapse_sum_like(grad, data),
             collapse_sum_like(grad, bias)]
...
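For context, the registrations in this file are what Relay's automatic differentiation pass looks up when it differentiates a function. A rough usage sketch follows, assuming the relay.transform.gradient entry point and the Module/InferType API as they existed around this commit (the exact module class and entry-function name vary across TVM versions):

    from tvm import relay

    x = relay.var("x", shape=(3, 4), dtype="float32")
    fwd = relay.Function([x], relay.exp(x))

    # Run type inference first so the gradient bodies (ones_like, zeros_like, ...)
    # see concrete shapes.
    mod = relay.Module.from_expr(fwd)
    mod = relay.transform.InferType()(mod)
    fwd = mod["main"]

    # The AD pass consults the gradient registered under "exp" above; the resulting
    # function returns the forward value together with a tuple of gradients, one per input.
    bwd = relay.transform.gradient(fwd, mode="first_order")
    print(bwd)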