Unverified commit d2019784 by Wuwei Lin, committed by GitHub

[Relay] Conv2d grad (#3636)

* [Relay] Conv2d grad

* Fix test

* Fix first order gradient
parent 7391fc00
@@ -17,9 +17,13 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from topi.util import get_const_tuple
from topi.nn.util import get_pad_tuple
from ..expr import const, Tuple, TupleGetItem
from .op import register_gradient
from .reduce import sum as _sum
from .transform import collapse_sum_like, broadcast_to_like, where, transpose, reshape, tile, \
    strided_slice
from .tensor import exp, negative, power, less, cos, sin
from .tensor import zeros_like, ones_like
from . import nn as _nn
@@ -187,3 +191,62 @@ def concatenate_grad(orig, grad):
    # Assume only two elements in the tuple right now.
    # In the real implementation, concatenate_grad probably needs to be implemented by an operator.
    return [Tuple([zeros_like(x), zeros_like(y)])]
@register_gradient("nn.conv2d")
def conv2d_grad(orig, grad):
    """Gradient of conv2d"""
    attrs = orig.attrs
    data, weight = orig.args
    data_shape = get_const_tuple(data.checked_type.shape)
    weight_shape = get_const_tuple(weight.checked_type.shape)
    _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
    batch, in_channel, in_h, in_w = data_shape
    out_channel, _, filter_h, filter_w = weight_shape

    # infer output_padding
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(get_const_tuple(attrs.padding),
                                                                 (filter_h, filter_w))
    stride_h, stride_w = get_const_tuple(attrs.strides)
    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
    out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
    output_padding = (in_h - out_h, in_w - out_w)

    assert attrs.data_layout == 'NCHW', 'only support NCHW data layout'
    assert attrs.kernel_layout == 'OIHW', 'only support OIHW kernel layout'
    assert attrs.out_layout in ['', 'NCHW'], 'only support NCHW output layout'

    # gradient w.r.t. data: transposed convolution of the output gradient with the weight
    backward_data = _nn.conv2d_transpose(grad, weight,
                                         strides=attrs.strides,
                                         padding=attrs.padding,
                                         dilation=attrs.dilation,
                                         groups=attrs.groups,
                                         output_padding=output_padding)

    # gradient w.r.t. weight: convolve the reshaped data with the reshaped output gradient,
    # with strides and dilation swapped and one group per (batch, input channel) pair
    grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
    grad = reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
    data = reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw

    backward_weight = _nn.conv2d(data, grad,
                                 strides=attrs.dilation,
                                 padding=attrs.padding,
                                 dilation=attrs.strides,
                                 groups=in_channel * batch)
    # infer shape of backward_weight
    padded_weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom) \
        // dilation_h + 1
    padded_weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right) \
        // dilation_w + 1
    backward_weight = reshape(backward_weight,
                              [batch, in_channel // attrs.groups, out_channel,
                               padded_weight_grad_h, padded_weight_grad_w])
    backward_weight = _sum(backward_weight, axis=0)
    backward_weight = transpose(backward_weight, [1, 0, 2, 3])

    # the grouped conv2d may produce a slightly larger spatial extent; trim it to the filter size
    assert padded_weight_grad_h >= filter_h
    assert padded_weight_grad_w >= filter_w
    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0],
                                        end=[None, None, filter_h, filter_w])

    return [backward_data, backward_weight]
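
To make the output_padding inference and the final strided_slice concrete, here is a rough shape-arithmetic sketch (not part of the patch; the helper name is made up, dilation is fixed to 1 and padding is symmetric) for two of the configurations exercised by test_conv2d_grad below:

def _trace_conv2d_grad_shapes(in_h, filter_h, stride_h, pad_h):
    # mirrors the formulas in conv2d_grad above, specialized to dilation 1
    grad_h = (in_h + 2 * pad_h - filter_h) // stride_h + 1      # forward output (= grad) size
    out_h = (grad_h - 1) * stride_h - 2 * pad_h + filter_h      # conv2d_transpose output size
    output_padding = in_h - out_h                               # extra rows needed to recover in_h
    padded_wgrad_h = in_h - (grad_h - 1) * stride_h - 1 + 2 * pad_h + 1
    return grad_h, output_padding, padded_wgrad_h

# 16x16 input, 3x3 filter, stride 1, pad 1:
#   grad_h = 16, output_padding = 0, padded_wgrad_h = 3 == filter_h, so no strided_slice.
print(_trace_conv2d_grad_shapes(16, 3, 1, 1))   # (16, 0, 3)
# 16x16 input, 1x1 filter, stride 2, pad 0:
#   grad_h = 8, output_padding = 1, padded_wgrad_h = 2 > filter_h, so strided_slice trims it to 1.
print(_trace_conv2d_grad_shapes(16, 1, 2, 0))   # (8, 1, 2)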
@@ -116,6 +116,7 @@ def conv2d_transpose(data,
                     kernel_size=None,
                     data_layout="NCHW",
                     kernel_layout="OIHW",
                     out_layout="",
                     output_padding=(0, 0),
                     out_dtype=""):
"""Two dimensional transposed convolution operator. """Two dimensional transposed convolution operator.
...@@ -152,6 +153,9 @@ def conv2d_transpose(data, ...@@ -152,6 +153,9 @@ def conv2d_transpose(data,
kernel_layout : str, optional kernel_layout : str, optional
Layout of the weight. Layout of the weight.
out_layout : Optional[str]
Layout of the output, by default, out_layout is the same as data_layout
output_padding : Tuple[int], optional output_padding : Tuple[int], optional
Additional zero-padding to be added to one side of the output. Additional zero-padding to be added to one side of the output.
@@ -165,7 +169,7 @@ def conv2d_transpose(data,
    """
    return _make.conv2d_transpose(data, weight, strides, padding, dilation,
                                  groups, channels, kernel_size, data_layout,
                                  kernel_layout, out_layout, output_padding, out_dtype)


def softmax(data, axis=-1):
...
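
Since the Python wrapper now forwards out_layout, a minimal usage sketch might look like the following (shapes and variable names are chosen here purely for illustration; the default empty string keeps the output in the data layout, NCHW in this case):

from tvm import relay

data = relay.var("data", shape=(1, 16, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(16, 4, 3, 3), dtype="float32")
# out_layout="" means the output layout follows data_layout (NCHW here)
y = relay.nn.conv2d_transpose(data, weight,
                              strides=(2, 2), padding=(1, 1),
                              kernel_size=(3, 3), channels=4,
                              out_layout="")
func = relay.Function([data, weight], y)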
@@ -320,6 +320,7 @@ Expr MakeConv2DTranspose(Expr data,
                         Array<IndexExpr> kernel_size,
                         std::string data_layout,
                         std::string kernel_layout,
                         std::string out_layout,
                         Array<IndexExpr> output_padding,
                         DataType out_dtype) {
  auto attrs = make_node<Conv2DTransposeAttrs>();
@@ -332,6 +333,7 @@ Expr MakeConv2DTranspose(Expr data,
  attrs->groups = groups;
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_layout = std::move(out_layout);
  attrs->out_dtype = std::move(out_dtype);
  static const Op& op = Op::Get("nn.conv2d_transpose");
  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
...
@@ -109,7 +109,9 @@ struct ADTensor : ADValueNode {
  Expr forward;
  mutable Expr reverse;  // must be a variable to avoid duplication
  ADTensor(LetList* ll, const Expr& forward) :
      forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
    this->forward->checked_type_ = forward->checked_type();
  }
};
/*! \brief A staged representation of the program, we reflect
@@ -117,10 +119,12 @@ struct ADTensor : ADValueNode {
 * can compute away this function to obtain a reverse mode program.
 */
struct ADFunction : ADValueNode {
  std::function<ADValue(const Type&,
                        const std::vector<ADValue>&,
                        const Attrs&,
                        const tvm::Array<Type>&)> func;
  explicit ADFunction(const std::function<ADValue(const Type&,
                                                  const std::vector<ADValue>&,
                                                  const Attrs&,
                                                  const tvm::Array<Type>&)>& func) :
    func(func) { }
@@ -139,7 +143,8 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
    Op op_ref = GetRef<Op>(op);
    CHECK(rev_map.count(op_ref))
        << op->name << " does not have reverse mode defined";
    return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type,
                                                       const std::vector<ADValue>& args,
                                                       const Attrs& attrs,
                                                       const tvm::Array<Type>& type_args) {
      std::vector<Expr> call_args;
@@ -147,6 +152,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
        call_args.push_back(adval->get<ADTensor>().forward);
      }
      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
      orig->checked_type_ = orig_type;
      auto ret = std::make_shared<ADTensor>(ll, orig);
      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
@@ -171,13 +177,14 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
    for (const auto& arg : op->args) {
      args.push_back(VisitExpr(arg));
    }
    return f->get<ADFunction>().func(op->checked_type(), args, op->attrs, op->type_args);
  }
  ADValue VisitExpr_(const FunctionNode* op) final {
    Function f = GetRef<Function>(op);
    // todo: assert no closure
    return std::make_shared<ADFunction>([this, f](const Type& orig_type,
                                                  const std::vector<ADValue>& args,
                                                  const Attrs& attrs,
                                                  const tvm::Array<Type>& type_args) {
      CHECK_EQ(f->params.size(), args.size());
@@ -227,7 +234,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
    for (const auto& p : f->params) {
      args.push_back(std::make_shared<ADTensor>(ll, p));
    }
    auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
    const auto& res = c->get<ADTensor>();
    Expr grad = LetList::With([&](LetList* ll) {
      res.reverse = OnesLike(res.forward);
@@ -271,7 +278,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
                LetList* ll) {
  CHECK(IsAtomic(e)) << e;
  if (t.as<TensorTypeNode>()) {
    auto ret = f(e);
    ret->checked_type_ = t;
    return ret;
  } else if (auto* tt = t.as<TupleTypeNode>()) {
    tvm::Array<Expr> fields;
    for (size_t i = 0; i < tt->fields.size(); ++i) {
@@ -280,7 +289,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
                            ll->Push(GetField(e, i)),
                            ll));
    }
    auto ret = TupleNode::make(fields);
    ret->checked_type_ = t;
    return std::move(ret);
  } else {
    LOG(FATAL) << "unsupported input/output type: " << tt;
    throw;
@@ -348,11 +359,14 @@ struct ReverseAD : ExprMutator {
      args.push_back(ll->Push(VisitExpr(arg)));
    }
    std::vector<Expr> orig_args;
    for (size_t i = 0; i < args.size(); i++) {
      orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
    }
    Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
    orig->checked_type_ = op->checked_type();
    Var orig_var = ll->Push(orig);
    orig_var->checked_type_ = op->checked_type();
    auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
    auto bpv = ll->Push(RefReadNode::make(bp));
    Expr nbp = FunctionNode::make(
        {},
...
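
The type-propagation changes above (stamping checked_type_ on the calls and variables that the AD passes rebuild) exist because gradient functions such as conv2d_grad read orig.checked_type.shape on the reconstructed call. A minimal Python sketch that exercises this path (shapes are illustrative; this mirrors the first_order test case below rather than adding anything new):

import tvm
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

data = relay.var("data", shape=(1, 4, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(8, 4, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, padding=[1, 1])
fwd = run_infer_type(relay.Function([data, weight], conv))
# conv2d_grad needs orig.checked_type.shape, so the AD pass must propagate types
bwd = run_infer_type(gradient(fwd, mode="first_order"))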
@@ -20,7 +20,7 @@ import topi
import topi.testing
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list, check_grad
from tvm.relay.testing import run_infer_type
@@ -83,6 +83,53 @@ def test_avg_pool2d_grad():
                            ceil_mode=False, count_include_pad=False)
def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        print('Skip because pytorch is not installed')
        return

    dtype = 'float32'
    data = relay.var('data', shape=dshape, dtype=dtype)
    weight = relay.var('weight', shape=wshape, dtype=dtype)
    conv = relay.nn.conv2d(data, weight, strides=strides, padding=padding, dilation=dilation,
                           groups=groups)
    fwd_func = relay.Function([data, weight], conv)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func, mode=mode))

    data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True)
    weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True)
    out_pt = F.conv2d(data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation,
                      groups=groups)
    grad_output_pt = torch.ones(out_pt.shape)
    grad_input_pt = F.grad.conv2d_input(dshape, weight_pt, grad_output_pt, stride=strides,
                                        padding=padding, dilation=dilation, groups=groups) \
                          .detach().numpy()
    grad_weight_pt = F.grad.conv2d_weight(data_pt, wshape, grad_output_pt, stride=strides,
                                          padding=padding, dilation=dilation, groups=groups) \
                           .detach().numpy()

    for target, ctx in ctx_list():
        data = tvm.nd.array(data_pt.detach().numpy(), ctx)
        weight = tvm.nd.array(weight_pt.detach().numpy(), ctx)
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (grad_input, grad_weight) = intrp.evaluate(bwd_func)(data, weight)
        np.testing.assert_allclose(grad_input.asnumpy(), grad_input_pt, rtol=1e-4, atol=1e-4)
        np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4)
def test_conv2d_grad():
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [2, 2], [0, 0], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')
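
The test file also starts importing check_grad from tvm.relay.testing. A possible numeric cross-check is sketched below; it is not part of the committed test and assumes check_grad(fwd_func) runs a finite-difference gradient check on a typed Relay function, which may not match this revision's exact signature:

def verify_conv2d_grad_numeric():
    # hedged sketch: relies on check_grad picking random inputs and comparing
    # the registered gradients against finite differences
    data = relay.var('data', shape=(1, 4, 8, 8), dtype='float32')
    weight = relay.var('weight', shape=(8, 4, 3, 3), dtype='float32')
    conv = relay.nn.conv2d(data, weight, strides=[1, 1], padding=[1, 1])
    fwd_func = run_infer_type(relay.Function([data, weight], conv))
    check_grad(fwd_func)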
if __name__ == "__main__":
    test_max_pool2d_grad()
    test_avg_pool2d_grad()
    test_conv2d_grad()