Commit eae76b3c by 雾雨魔理沙 Committed by ziheng

[Relay] Higher order reverse mode automatic differentiation that work with control flow (#2496)

add test

remove dead code

stash

do it

add more test
parent 25c50fc9
......@@ -530,9 +530,11 @@ def to_graph_normal_form(expr):
return _ir_pass.to_graph_normal_form(expr)
def gradient(expr, mod=None):
def gradient(expr, mod=None, mode='higher_order'):
"""
Transform a function to return original result paired with gradient of input.
Transform the input function,
returning a function that calculate the original result,
paired with gradient of the input.
Parameters
----------
......@@ -541,12 +543,23 @@ def gradient(expr, mod=None):
mod : Optional[tvm.relay.Module]
mode : Optional[String]
The mode of the automatic differentiation algorithm.
'first_order' only work on first order code, but will not produce reference nor closure.
'higher_order' work on all code using reference and closure.
Returns
-------
expr : tvm.relay.Expr
The output expression.
The transformed expression.
"""
return _ir_pass.first_order_gradient(expr, mod)
if mode == 'first_order':
return _ir_pass.first_order_gradient(expr, mod)
elif mode == 'higher_order':
return _ir_pass.gradient(expr, mod)
else:
raise Exception('unknown mode')
def get_total_mac_number(expr):
......
......@@ -225,6 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}
node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
......
......@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;
/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
Expr foward;
Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication
ADTensor(LetList* ll, const Expr& foward) :
foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
ADTensor(LetList* ll, const Expr& forward) :
forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
};
/*! \brief A staged representation of the program, we reflect
......@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
func(func) { }
};
struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
LetList* ll;
ReverseAD(LetList* ll) : ll(ll) { }
FirstOrderReverseAD(LetList* ll) : ll(ll) { }
ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
......@@ -121,21 +121,22 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
const Attrs& attrs,
const tvm::Array<Type>& type_args) {
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().foward);
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
});
return ret;
});
return ret;
});
}
ADValue VisitExpr_(const ConstantNode* op) final {
......@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
}
};
Type GradRetType(const Function& f) {
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
if (!f->ret_type.defined()) {
return Type();
}
std::vector<Type> vt;
for (const auto& p : f->params) {
if (!p->type_annotation.defined()) {
return Type();
}
vt.push_back(p->type_annotation);
}
return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}
Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// Currently we first remove any global functions for the first
// order case.
......@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// We will then build a sequence of lets which implement reverse mode.
Expr body = LetList::With([&](LetList* ll) {
ReverseAD reverse_ad(ll);
FirstOrderReverseAD reverse_ad(ll);
ADValue rev = reverse_ad(e);
std::vector<ADValue> args;
for (const auto& p : f->params) {
......@@ -191,46 +209,131 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
auto c = rev->get<ADFunction>().func(args, Attrs(), {});
const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) {
res.reverse = OneLike(res.foward);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
res.reverse = OnesLike(res.forward);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
}
return TupleNode::make(grad_res);
});
return Pair(res.forward, grad);
});
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleTypeNode::make({t, RefTypeNode::make(t)});
}
};
struct ReverseAD : ExprMutator {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
ReverseAD(const Var& bp) : bp(bp) { }
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}
Expr VisitExpr_(const CallNode* op) final {
if (const OpNode* op_node = op->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);
CHECK(rev_map.count(op_ref))
<< op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : op->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
std::vector<Expr> orig_args;
for (const auto& arg : args) {
orig_args.push_back(GetField(VisitExpr(arg), 0));
}
return TupleNode::make(grad_res);
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
Var orig_var = ll->Push(orig);
auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
ll->Push(RefWriteNode::make(GetField(args[i], 1),
Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
rev[i])));
}
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return Pair(orig_var, ref);
});
return Pair(res.foward, grad);
});
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
Type ret_type = Type();
std::vector<Type> vt;
bool missing = !f->ret_type.defined();
for (const auto& p : f->params) {
if (missing || !p->type_annotation.defined()) {
missing = true;
break;
}
vt.push_back(p->type_annotation);
return ExprMutator::VisitExpr_(op);
}
Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreateNode::make(ZerosLike(e)));
}
if (!missing) {
ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
}
};
return FunctionNode::make(f->params, body, ret_type, {});
Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
}
TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
Expr Gradient(const Expr& re, const Module& mod) {
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
}
auto c = ll->Push(CallNode::make(rev, args));
ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
ll->Push(CallNode::make(RefReadNode::make(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
ret.push_back(RefReadNode::make(GetField(a, 1)));
}
return Pair(GetField(c, 0), TupleNode::make(ret));
});
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
TVM_REGISTER_API("relay._ir_pass.gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = Gradient(args[0], args[1]);
});
} // namespace relay
} // namespace tvm
......@@ -299,12 +299,12 @@ inline Expr Divide(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr ZeroLike(Expr e) {
inline Expr ZerosLike(Expr e) {
static const Op& op = Op::Get("zeros_like");
return CallNode::make(op, {e});
}
inline Expr OneLike(Expr e) {
inline Expr OnesLike(Expr e) {
static const Op& op = Op::Get("ones_like");
return CallNode::make(op, {e});
}
......
......@@ -53,7 +53,7 @@ bool TupleGetItemRel(const Array<Type>& types,
const auto* param = attrs.as<TupleGetItemAttrs>();
CHECK(param != nullptr);
CHECK_GE(param->index, 0);
CHECK_LT(param->index, data->fields.size());
CHECK_LT(param->index, data->fields.size());
reporter->Assign(types[1], data->fields[param->index]);
return true;
}
......
......@@ -2,6 +2,7 @@ import tvm
from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
import numpy as np
......@@ -123,6 +124,72 @@ def test_broadcast_subtract():
-np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
def test_tuple():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.var("y", t)
z = relay.var("z", t)
tup = relay.Var("tup")
func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
relay.TupleGetItem(tup, 0) +
relay.TupleGetItem(tup, 1) -
relay.TupleGetItem(tup, 2)))
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
x_nd = rand(dtype, *shape)
y_nd = rand(dtype, *shape)
z_nd = rand(dtype, *shape)
x_np = x_nd.asnumpy()
y_np = y_nd.asnumpy()
z_np = z_nd.asnumpy()
expected_forward = x_np + y_np - z_np
ex = create_executor()
forward, (grad_x, grad_y, grad_z) = ex.evaluate(back_func)(x_nd, y_nd, z_nd)
np.testing.assert_allclose(forward.asnumpy(), expected_forward)
np.testing.assert_allclose(grad_x.asnumpy(), np.ones_like(grad_x.asnumpy()))
np.testing.assert_allclose(grad_y.asnumpy(), np.ones_like(grad_y.asnumpy()))
np.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
def test_pow():
mod = relay.Module()
p = Prelude(mod)
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
double = relay.Function([x], x + x)
i = relay.var("i", t)
func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i]))
back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
i_nd = rand(dtype, *shape)
ex = create_executor(mod=mod)
forward, (grad_i,) = ex.evaluate(back_func)(i_nd)
np.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
np.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
def test_ref():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
r = relay.Var("r")
u = relay.Var("u")
body = relay.RefRead(r)
body = relay.Let(u, relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(x), body)
func = relay.Function([x], body)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
x_nd = rand(dtype, *shape)
ex = create_executor()
forward, (grad_x,) = ex.evaluate(back_func)(x_nd)
np.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
np.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
if __name__ == "__main__":
test_id()
test_add()
......@@ -130,3 +197,6 @@ if __name__ == "__main__":
test_sub()
test_broadcast_add()
test_broadcast_subtract()
test_tuple()
test_pow()
test_ref()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment