Commit eae76b3c by 雾雨魔理沙 Committed by ziheng

[Relay] Higher order reverse mode automatic differentiation that work with control flow (#2496)

add test

remove dead code

stash

do it

add more test
parent 25c50fc9
...@@ -530,9 +530,11 @@ def to_graph_normal_form(expr): ...@@ -530,9 +530,11 @@ def to_graph_normal_form(expr):
return _ir_pass.to_graph_normal_form(expr) return _ir_pass.to_graph_normal_form(expr)
def gradient(expr, mod=None): def gradient(expr, mod=None, mode='higher_order'):
""" """
Transform a function to return original result paired with gradient of input. Transform the input function,
returning a function that calculate the original result,
paired with gradient of the input.
Parameters Parameters
---------- ----------
...@@ -541,12 +543,23 @@ def gradient(expr, mod=None): ...@@ -541,12 +543,23 @@ def gradient(expr, mod=None):
mod : Optional[tvm.relay.Module] mod : Optional[tvm.relay.Module]
mode : Optional[String]
The mode of the automatic differentiation algorithm.
'first_order' only work on first order code, but will not produce reference nor closure.
'higher_order' work on all code using reference and closure.
Returns Returns
------- -------
expr : tvm.relay.Expr expr : tvm.relay.Expr
The output expression. The transformed expression.
""" """
if mode == 'first_order':
return _ir_pass.first_order_gradient(expr, mod) return _ir_pass.first_order_gradient(expr, mod)
elif mode == 'higher_order':
return _ir_pass.gradient(expr, mod)
else:
raise Exception('unknown mode')
def get_total_mac_number(expr): def get_total_mac_number(expr):
......
...@@ -225,6 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -225,6 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
} }
node->pattern = op_pattern; node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>(); const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references. // pass the message back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) { for (size_t i = 0; i < call->args.size(); ++i) {
......
...@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>; ...@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;
/*! \brief AD over a program which generates a tensor output. */ /*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode { struct ADTensor : ADValueNode {
Expr foward; Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication mutable Expr reverse; // must be a variable to avoid duplication
ADTensor(LetList* ll, const Expr& foward) : ADTensor(LetList* ll, const Expr& forward) :
foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { } forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
}; };
/*! \brief A staged representation of the program, we reflect /*! \brief A staged representation of the program, we reflect
...@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode { ...@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
func(func) { } func(func) { }
}; };
struct ReverseAD : ExprFunctor<ADValue(const Expr &)> { struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient"); const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions; std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping // we assume no closure so no need for lexical scoping
std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env; std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
LetList* ll; LetList* ll;
ReverseAD(LetList* ll) : ll(ll) { } FirstOrderReverseAD(LetList* ll) : ll(ll) { }
ADValue VisitExpr_(const OpNode* op) final { ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op); Op op_ref = GetRef<Op>(op);
...@@ -123,12 +123,13 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> { ...@@ -123,12 +123,13 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
const tvm::Array<Type>& type_args) { const tvm::Array<Type>& type_args) {
std::vector<Expr> call_args; std::vector<Expr> call_args;
for (const ADValue& adval : args) { for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().foward); call_args.push_back(adval->get<ADTensor>().forward);
} }
auto orig = CallNode::make(op_ref, call_args, attrs, type_args); auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto ret = std::make_shared<ADTensor>(ll, orig); auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse); tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse = args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i])); ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
...@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> { ...@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
} }
}; };
Type GradRetType(const Function& f) {
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
if (!f->ret_type.defined()) {
return Type();
}
std::vector<Type> vt;
for (const auto& p : f->params) {
if (!p->type_annotation.defined()) {
return Type();
}
vt.push_back(p->type_annotation);
}
return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}
Expr FirstOrderGradient(const Expr& re, const Module& mod) { Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// Currently we first remove any global functions for the first // Currently we first remove any global functions for the first
// order case. // order case.
...@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { ...@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// We will then build a sequence of lets which implement reverse mode. // We will then build a sequence of lets which implement reverse mode.
Expr body = LetList::With([&](LetList* ll) { Expr body = LetList::With([&](LetList* ll) {
ReverseAD reverse_ad(ll); FirstOrderReverseAD reverse_ad(ll);
ADValue rev = reverse_ad(e); ADValue rev = reverse_ad(e);
std::vector<ADValue> args; std::vector<ADValue> args;
for (const auto& p : f->params) { for (const auto& p : f->params) {
...@@ -191,7 +209,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { ...@@ -191,7 +209,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
auto c = rev->get<ADFunction>().func(args, Attrs(), {}); auto c = rev->get<ADFunction>().func(args, Attrs(), {});
const auto& res = c->get<ADTensor>(); const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) { Expr grad = LetList::With([&](LetList* ll) {
res.reverse = OneLike(res.foward); res.reverse = OnesLike(res.forward);
for (auto it = reverse_ad.backprop_actions.rbegin(); for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend(); it != reverse_ad.backprop_actions.rend();
++it) { ++it) {
...@@ -203,34 +221,119 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { ...@@ -203,34 +221,119 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
} }
return TupleNode::make(grad_res); return TupleNode::make(grad_res);
}); });
return Pair(res.foward, grad); return Pair(res.forward, grad);
}); });
// if type annotations are provided, we will construct a ret type; return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
// otherwise, leave it to be inferred }
Type ret_type = Type();
std::vector<Type> vt; TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
bool missing = !f->ret_type.defined(); .set_body([](TVMArgs args, TVMRetValue* ret) {
for (const auto& p : f->params) { CHECK_EQ(args.size(), 2);
if (missing || !p->type_annotation.defined()) { *ret = FirstOrderGradient(args[0], args[1]);
missing = true; });
break;
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleTypeNode::make({t, RefTypeNode::make(t)});
} }
vt.push_back(p->type_annotation); };
struct ReverseAD : ExprMutator {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
ReverseAD(const Var& bp) : bp(bp) { }
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
} }
if (!missing) { Expr VisitExpr_(const CallNode* op) final {
ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); if (const OpNode* op_node = op->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);
CHECK(rev_map.count(op_ref))
<< op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : op->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
for (const auto& arg : args) {
orig_args.push_back(GetField(VisitExpr(arg), 0));
}
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
Var orig_var = ll->Push(orig);
auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
ll->Push(RefWriteNode::make(GetField(args[i], 1),
Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
rev[i])));
}
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return Pair(orig_var, ref);
});
}
return ExprMutator::VisitExpr_(op);
} }
return FunctionNode::make(f->params, body, ret_type, {}); Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreateNode::make(ZerosLike(e)));
}
Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
}
};
Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
} }
TVM_REGISTER_API("relay._ir_pass.first_order_gradient") Expr Gradient(const Expr& re, const Module& mod) {
.set_body([](TVMArgs args, TVMRetValue* ret) { auto e = DeGlobal(mod, re);
CHECK_EQ(args.size(), 2); auto f = e.as<FunctionNode>();
*ret = FirstOrderGradient(args[0], args[1]); CHECK(f) << "input need to be a function";
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
}
auto c = ll->Push(CallNode::make(rev, args));
ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
ll->Push(CallNode::make(RefReadNode::make(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
ret.push_back(RefReadNode::make(GetField(a, 1)));
}
return Pair(GetField(c, 0), TupleNode::make(ret));
}); });
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
TVM_REGISTER_API("relay._ir_pass.gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = Gradient(args[0], args[1]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -299,12 +299,12 @@ inline Expr Divide(Expr lhs, Expr rhs) { ...@@ -299,12 +299,12 @@ inline Expr Divide(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
} }
inline Expr ZeroLike(Expr e) { inline Expr ZerosLike(Expr e) {
static const Op& op = Op::Get("zeros_like"); static const Op& op = Op::Get("zeros_like");
return CallNode::make(op, {e}); return CallNode::make(op, {e});
} }
inline Expr OneLike(Expr e) { inline Expr OnesLike(Expr e) {
static const Op& op = Op::Get("ones_like"); static const Op& op = Op::Get("ones_like");
return CallNode::make(op, {e}); return CallNode::make(op, {e});
} }
......
...@@ -2,6 +2,7 @@ import tvm ...@@ -2,6 +2,7 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars, gradient from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
from tvm.relay import create_executor from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
import numpy as np import numpy as np
...@@ -123,6 +124,72 @@ def test_broadcast_subtract(): ...@@ -123,6 +124,72 @@ def test_broadcast_subtract():
-np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0)) -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
def test_tuple():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.var("y", t)
z = relay.var("z", t)
tup = relay.Var("tup")
func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
relay.TupleGetItem(tup, 0) +
relay.TupleGetItem(tup, 1) -
relay.TupleGetItem(tup, 2)))
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
x_nd = rand(dtype, *shape)
y_nd = rand(dtype, *shape)
z_nd = rand(dtype, *shape)
x_np = x_nd.asnumpy()
y_np = y_nd.asnumpy()
z_np = z_nd.asnumpy()
expected_forward = x_np + y_np - z_np
ex = create_executor()
forward, (grad_x, grad_y, grad_z) = ex.evaluate(back_func)(x_nd, y_nd, z_nd)
np.testing.assert_allclose(forward.asnumpy(), expected_forward)
np.testing.assert_allclose(grad_x.asnumpy(), np.ones_like(grad_x.asnumpy()))
np.testing.assert_allclose(grad_y.asnumpy(), np.ones_like(grad_y.asnumpy()))
np.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
def test_pow():
mod = relay.Module()
p = Prelude(mod)
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
double = relay.Function([x], x + x)
i = relay.var("i", t)
func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i]))
back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
i_nd = rand(dtype, *shape)
ex = create_executor(mod=mod)
forward, (grad_i,) = ex.evaluate(back_func)(i_nd)
np.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
np.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
def test_ref():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
r = relay.Var("r")
u = relay.Var("u")
body = relay.RefRead(r)
body = relay.Let(u, relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(x), body)
func = relay.Function([x], body)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
x_nd = rand(dtype, *shape)
ex = create_executor()
forward, (grad_x,) = ex.evaluate(back_func)(x_nd)
np.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
np.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
test_add() test_add()
...@@ -130,3 +197,6 @@ if __name__ == "__main__": ...@@ -130,3 +197,6 @@ if __name__ == "__main__":
test_sub() test_sub()
test_broadcast_add() test_broadcast_add()
test_broadcast_subtract() test_broadcast_subtract()
test_tuple()
test_pow()
test_ref()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment