Commit 52fde8f7 by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] [Training] Fix ad for concatenate (#3729)

* reproduce error

* fix

* lint

* lint
parent 45827220
@@ -17,7 +17,7 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..expr import const
from ..expr import const, Tuple, TupleGetItem
from .op import register_gradient
from .transform import collapse_sum_like, broadcast_to_like, where
from .tensor import exp, negative, power, less, cos, sin
@@ -176,3 +176,14 @@ def avg_pool2d_grad(orig, grad):
layout=attrs.layout, ceil_mode=attrs.ceil_mode,
count_include_pad=attrs.count_include_pad)
return [pool_grad]
# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
assert len(orig.args) == 1
t = orig.args[0]
x = TupleGetItem(t, 0)
y = TupleGetItem(t, 1)
# Assume there are only two elements in the tuple for now.
# In the real implementation, concatenate_grad probably needs to be implemented as an operator.
return [Tuple([zeros_like(x), zeros_like(y)])]
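A hedged sketch, not part of this commit, of what a non-dummy concatenate gradient could look like: slice the incoming gradient back into one piece per input along the concatenation axis. The helper name concatenate_grad_sketch, the use of strided_slice, and the reliance on statically known shapes via checked_type are assumptions for illustration only.
def concatenate_grad_sketch(orig, grad):
    """Hypothetical gradient for concatenate: slice `grad` back into per-input pieces."""
    from tvm.relay.expr import Tuple
    from tvm.relay.op.transform import strided_slice
    axis = int(orig.attrs.axis)                    # assumes a non-negative, constant axis
    fields = orig.args[0].checked_type.fields      # assumes checked_type is populated
    pieces = []
    offset = 0
    for field in fields:
        shape = [int(dim) for dim in field.shape]  # assumes constant (IntImm) dimensions
        begin = [0] * len(shape)
        end = list(shape)
        begin[axis] = offset
        end[axis] = offset + shape[axis]
        pieces.append(strided_slice(grad, begin=begin, end=end))
        offset += shape[axis]
    return [Tuple(pieces)]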
@@ -117,9 +117,12 @@ class AlphaEqualHandler:
* \return the comparison result.
*/
bool TypeEqual(const Type& lhs, const Type& rhs) {
auto compute = [&](){
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
......
@@ -29,6 +29,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "let_list.h"
#include "../ir/type_functor.h"
@@ -257,11 +258,79 @@ struct ReverseADType : TypeMutator {
}
};
Type ReverseType(const Type& t) {
return ReverseADType()(t);
}
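For orientation, a small illustration of the mapping that ReverseType performs; this snippet is an assumption inferred from GetRev/GetGrad below rather than code from this commit: a tensor type becomes a pair of the value and a mutable reference cell that accumulates its gradient, and tuple types are mapped field-wise.
# Hypothetical Python illustration of the reverse-mode type mapping (an assumption
# based on GetRev/GetGrad below): Tensor[t] -> (Tensor[t], Ref[Tensor[t]]).
from tvm import relay

t = relay.TensorType((10, 10), "float32")
rev_t = relay.TupleType([t, relay.RefType(t)])   # reverse-mode form of a tensor type
rev_tuple = relay.TupleType([rev_t, rev_t])      # a 2-tuple of tensors maps field-wise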
/*! \brief Lift a function that transforms Tensors into a function that also transforms richer types
 * by doing a structure-preserving map.
 */
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
const Type& t,
const Expr& e,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (t.as<TensorTypeNode>()) {
return f(e);
} else if (auto* tt = t.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < tt->fields.size(); ++i) {
fields.push_back(LiftTensor(f,
tt->fields[i],
ll->Push(GetField(e, i)),
ll));
}
return TupleNode::make(fields);
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
throw;
}
}
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) {
return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
};
return LiftTensor(rev, t, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) {
return ll->Push(RefReadNode::make(GetField(e, 1)));
};
return LiftTensor(grad, t, e, ll);
}
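/*! \brief Accumulate a partial gradient into the gradient reference cell(s) stored in the
 * reverse-mode value `arg`, recursing field-wise through tuple types.
 */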
void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
if (t.as<TensorTypeNode>()) {
ll->Push(RefWriteNode::make(GetField(arg, 1),
Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
grad)));
} else if (auto* tt = t.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
UpdateGrad(tt->fields[i],
ll->Push(GetField(arg, i)),
ll->Push(GetField(grad, i)),
ll);
}
} else {
LOG(FATAL) << "unsupported arg type of operator: " << t;
throw;
}
}
struct ReverseAD : ExprMutator {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
explicit ReverseAD(const Var& bp) : bp(bp) { }
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
@@ -279,29 +348,26 @@ struct ReverseAD : ExprMutator {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
for (const auto& arg : args) {
orig_args.push_back(GetField(arg, 0));
for (size_t i = 0; i < args.size(); ++i) {
orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
}
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
Var orig_var = ll->Push(orig);
auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
auto ret = ll->Push(GetRev(op->checked_type(), ll->Push(orig), ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
ll->Push(RefWriteNode::make(GetField(args[i], 1),
Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
rev[i])));
UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
}
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return Pair(orig_var, ref);
return ret;
});
}
return ExprMutator::VisitExpr_(op);
@@ -319,7 +385,7 @@ struct ReverseAD : ExprMutator {
}
Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
return t.defined() ? ReverseType(t) : t;
}
};
......
@@ -18,11 +18,12 @@ import numpy as np
import tvm
from tvm import relay
from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
import tvm.relay.op as op
def test_id():
@@ -280,6 +281,20 @@ def test_grad_tuple():
tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
def test_concat():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
rt = relay.TensorType((10, 20), dtype)
x = relay.var("x", t)
y = op.concatenate([x, x], axis=1)
func = relay.Function([x], y)
func = run_infer_type(func)
back_func = run_infer_type(gradient(func))
assert_alpha_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
# No value validation, as concatenate has a dummy gradient right now.
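A hedged sketch of the value check that could be enabled once concatenate gets a real gradient; it reuses the check_grad helper already imported from tvm.relay.testing above, and assumes check_grad numerically verifies a type-inferred forward function.
def test_concat_value_check():
    # Hypothetical follow-up test, not part of this commit: once concatenate_grad
    # returns real gradients, numeric validation could look like this.
    x = relay.var("x", relay.TensorType((10, 10), 'float32'))
    func = relay.Function([x], op.concatenate([x, x], axis=1))
    check_grad(run_infer_type(func))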
if __name__ == "__main__":
test_id()
test_add()
......