Commit b692289e by 雾雨魔理沙 Committed by ziheng

[Relay] First order reverse mode Automatic Differentiation (#2321)

* init

staging on

final save before staging

save

init ver

init ver

ll stuff

save

add failing test

add file

pass test

fix error

huh?

save

Add test changes

Fix fusion with nested tuples

Fix reverse mode test

More hacking

Clean up AD

Hacking on reverse mode

Fix issue in reverse

save

fix lint

fix lint

fix lint

save

save

fix

support neg

address some comment

add back file

save

save

* save

* save

* save

* lint

* fix lint

* save

* fix
parent 242daeea
......@@ -124,6 +124,19 @@ using FForwardRewrite = runtime::TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx)>;
/*!
* \brief Gradient for a specific op.
*
 * \param orig_call the original call expression.
 *
 * \param output_grad the gradient of the output of the call.
 *
 * \return the gradient with respect to each input of the call.
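 *
 * For example, for a call add(x, y) with output gradient g, the gradient
 * registered for "add" later in this change returns {g, g}, each broadcast
 * to the shape of the corresponding input.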
*/
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
const Expr& output_grad)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
......@@ -28,6 +28,7 @@ namespace relay {
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Expr& expr, const Module& mod);
/*!
* \brief Infer the type of a function as if it is mapped to var in the mod.
*
......
......@@ -31,6 +31,19 @@ class TupleValue(Value):
def __getitem__(self, field_no):
return self.fields[field_no]
def __len__(self):
return len(self.fields)
def __str__(self):
body = ','.join(str(f) for f in self.fields)
return '({0})'.format(body)
def __repr__(self):
body = ','.join(repr(f) for f in self.fields)
return '({0})'.format(body)
def __iter__(self):
return iter(self.fields)
@register_relay_node
class Closure(Value):
......@@ -59,6 +72,12 @@ class TensorValue(Value):
def __eq__(self, other):
return self.data == other.data
def __repr__(self):
return repr(self.data)
def __str__(self):
return str(self.data)
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
......
......@@ -414,3 +414,22 @@ def collect_device_annotation_ops(expr):
annotation expressions.
"""
return _ir_pass.CollectDeviceAnnotationOps(expr)
def gradient(expr, mod=None):
""".
Parameters
----------
expr : tvm.relay.Expr
The input expression, which is a Function or a GlobalVar.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
ret : tvm.relay.Expr
        A function that calculates the original result, paired with the gradients of the inputs.
"""
return _ir_pass.first_order_gradient(expr, mod)
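For illustration, a minimal usage sketch of the new pass, mirroring the tests added in this commit; it assumes the interpreter from relay.create_executor and the infer_type pass, both of which appear elsewhere in this change:

import numpy as np
import tvm
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.ir_pass import gradient, infer_type

t = relay.TensorType((10, 10), 'float32')
x = relay.var("x", t)
func = relay.Function([x], x + x)          # f(x) = x + x
back_func = infer_type(gradient(func))     # fn (x) -> (f(x), (df/dx,))

ex = create_executor()
x_val = tvm.nd.array(np.random.rand(10, 10).astype('float32'))
forward, (grad,) = ex.evaluate(back_func)(x_val)
# forward is numerically 2 * x_val and grad is 2 * ones_like(x_val)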
......@@ -3,8 +3,21 @@
from __future__ import absolute_import
import topi
from .op import register_compute, register_schedule, register_pattern
from .op import register_gradient
from .op import schedule_injective, OpPattern
def add_grad(orig, grad):
from tvm.relay import op
return [op.broadcast_to_like(grad, orig.args[0]), op.broadcast_to_like(grad, orig.args[1])]
register_gradient("add", add_grad)
def subtract_grad(orig, grad):
from tvm.relay import op
return [op.broadcast_to_like(grad, orig.args[0]),
op.broadcast_to_like(op.negative(grad), orig.args[1])]
register_gradient("subtract", subtract_grad)
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
......
......@@ -168,6 +168,22 @@ def register_pattern(op_name, pattern, level=10):
"""
return register(op_name, "TOpPattern", pattern, level)
def register_gradient(op_name, fgradient, level=10):
"""Register operator pattern for an op.
Parameters
----------
op_name : str
The name of the op.
    fgradient : function (orig_call : Expr, output_grad : Expr) -> List[Expr]
        The gradient function, returning the gradient for each input of the call.
level : int
The priority level
"""
return register(op_name, "FPrimalGradient", fgradient, level)
_init_api("relay.op", __name__)
......
......@@ -62,6 +62,7 @@ def sqrt(data):
"""
return _make.sqrt(data)
def sigmoid(data):
"""Compute elementwise sigmoid of data.
......@@ -215,6 +216,7 @@ def add(lhs, rhs):
"""
return _make.add(lhs, rhs)
def subtract(lhs, rhs):
"""Subtraction with numpy-style broadcasting.
......@@ -232,6 +234,7 @@ def subtract(lhs, rhs):
"""
return _make.subtract(lhs, rhs)
def multiply(lhs, rhs):
"""Multiplication with numpy-style broadcasting.
......@@ -553,6 +556,7 @@ def ones_like(data):
"""
return _make.ones_like(data)
def clip(a, a_min, a_max):
"""Clip the elements in `a` between `a_min` and `a_max`.
`a_min` and `a_max` are cast to `a`'s dtype.
......
......@@ -21,7 +21,6 @@ namespace relay {
return {FTOPI(inputs[0], inputs[1])}; \
} \
// Addition
RELAY_REGISTER_BINARY_OP("add")
.describe("Elementwise add with with broadcasting")
......
......@@ -236,7 +236,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
for (const Expr& field : op->fields) {
this->Update(field, tuple_node, kInjective);
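      // Only tensor-typed tuple fields can be fused into the tuple as injective;
      // other fields (e.g. nested tuples) are treated as opaque.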
if (field->checked_type().as<TensorTypeNode>()) {
this->Update(field, tuple_node, kInjective);
} else {
this->Update(field, nullptr, kOpaque);
}
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
......
/*!
* Copyright (c) 2018 by Contributors
* \file ad.cc
* \brief API for Automatic Differentiation for the Relay IR.
*/
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h>
#include "pattern_util.h"
#include "let_list.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
/*! What is automatic differentiation (AD) and why is it important?
 * By AD, we roughly mean: given a term which denotes some mathematical function,
 * derive a term which denotes the derivative of that mathematical function.
 * Such a method can be compile-time, in which case it is a macro on a completely known function.
 * Formally speaking, this requires the input function to be a closed expression -
 * that is, it only refers to local variables that are its parameters or are defined inside it.
 * Every top-level definition satisfies this criterion.
 * AD can also be run-time, in which case it is merely a function term of type AD : (Float[] -> Float[]) -> (Float[] -> Float[]).
 * In Relay we currently only support compile-time AD, but it should be enough for a lot of use cases.
 *
 * In deep learning, the most common way to train a deep neural network is by gradient descent or one of its variants.
 * Such optimization methods require the gradient of the neural network as input, which can be obtained easily using AD.
 * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD!
 */
/*! In Relay, automatic differentiation (AD) is a macro
 * that transforms a closed expr (an expr without free variables/free type variables) of type
 * (x0, x1, x2, ...) -> Float[] into one of type
 * (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)),
 * where x0, x1, x2, ... are Floats of possibly different shapes.
 * The return value is a pair, with the left hand side being the original value and the right hand side the gradients of the inputs.
 * WithGradientType takes the type of the input and produces the type of the output.
 * There are multiple implementations of AD in Relay, with different characteristics.
 * However, they all transform the input expr according to WithGradientType.
 */
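/*! For example (this matches the tests added in this commit), the gradient of a
 * function of type
 *   fn (Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32]
 * is a function of type
 *   fn (Tensor[(10, 10), float32]) -> (Tensor[(10, 10), float32], (Tensor[(10, 10), float32],)).
 */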
Type WithGradientType(const Type&);
/*! Return an expression that represents the differentiation of e (according to WithGradientType).
 * This version only works on first-order code without control flow.
 */
Expr FirstOrderGradient(const Expr& e, const Module& mod);
Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking
auto ty = t.as<FuncTypeNode>();
CHECK(ty) << "input should be a function";
return FuncTypeNode::make(ty->arg_types,
TupleTypeNode::make({
ty->ret_type,
TupleTypeNode::make(ty->arg_types)}), {}, {});
}
//! \brief If the expression is a GlobalVar, replace it with the body of the function it refers to in the module.
Expr DeGlobal(const Module& mod, const Expr& e) {
if (auto x = e.as<GlobalVarNode>()) {
return mod->Lookup(GetRef<GlobalVar>(x))->body;
} else {
return e;
}
}
/*! \brief A fragment of the program being built by the automatic differentiation
* pass.
*/
struct ADValueNode {
virtual ~ADValueNode() { }
template <typename T>
T& get() {
auto ret = dynamic_cast<T*>(this);
CHECK(ret) << "cannot downcast";
return *ret;
}
};
using ADValue = std::shared_ptr<ADValueNode>;
/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
  Expr forward;
  mutable Expr reverse; // must be a variable to avoid duplication
  ADTensor(LetList* ll, const Expr& forward) :
    forward(ll->Push(forward)), reverse(ll->Push(ZeroLike(this->forward))) { }
};
/*! \brief A staged representation of the program: we reflect
 * Relay functions into a function over fragments of AD. We
* can compute away this function to obtain a reverse mode program.
*/
struct ADFunction : ADValueNode {
std::function<ADValue(const std::vector<ADValue>&,
const Attrs&,
const tvm::Array<Type>&)> func;
explicit ADFunction(const std::function<ADValue(const std::vector<ADValue>&,
const Attrs&,
const tvm::Array<Type>&)>& func) :
func(func) { }
};
struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
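  // Each visited call pushes a closure onto backprop_actions that adds its
  // contribution to the gradients of its arguments; FirstOrderGradient later
  // replays these closures in reverse order to perform backpropagation.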
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
LetList* ll;
ReverseAD(LetList* ll) : ll(ll) { }
ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
CHECK(rev_map.count(op_ref))
<< op->name << " does not have reverse mode defined";
return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
const Attrs& attrs,
const tvm::Array<Type>& type_args) {
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
          call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
});
return ret;
});
}
ADValue VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return std::make_shared<ADTensor>(ll, e);
}
ADValue VisitExpr_(const CallNode* op) final {
ADValue f = VisitExpr(op->op);
std::vector<ADValue> args;
for (const auto& arg : op->args) {
args.push_back(VisitExpr(arg));
}
return f->get<ADFunction>().func(args, op->attrs, op->type_args);
}
ADValue VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
// todo: assert no closure
return std::make_shared<ADFunction>([this, f](const std::vector<ADValue>& args,
const Attrs& attrs,
const tvm::Array<Type>& type_args) {
CHECK_EQ(f->params.size(), args.size());
for (size_t i = 0; i < f->params.size(); ++i) {
env[f->params[i]] = args[i];
}
return VisitExpr(f->body);
});
}
ADValue VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return env.at(v);
}
};
Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// Currently we first remove any global functions for the first
// order case.
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
  CHECK(f) << "FirstOrderGradient expects its argument to be a function: " << f;
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
// We will then build a sequence of lets which implement reverse mode.
Expr body = LetList::With([&](LetList* ll) {
ReverseAD reverse_ad(ll);
ADValue rev = reverse_ad(e);
std::vector<ADValue> args;
for (const auto& p : f->params) {
args.push_back(std::make_shared<ADTensor>(ll, p));
}
auto c = rev->get<ADFunction>().func(args, Attrs(), {});
const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) {
      res.reverse = OneLike(res.forward);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
}
return TupleNode::make(grad_res);
});
    return Pair(res.forward, grad);
});
std::vector<Type> vt;
for (const auto& p : f->params) {
vt.push_back(p->type_annotation);
}
return FunctionNode::make(f->params,
body,
                            TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}),
{});
}
TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
} // namespace relay
} // namespace tvm
......@@ -11,6 +11,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/nn.h>
#include <string>
#include "../op/layout.h"
......@@ -150,6 +151,23 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
return ConstantNode::make(arr);
}
inline Expr GetField(Expr t, size_t i) {
return TupleGetItemNode::make(t, i);
}
inline Expr Pair(Expr l, Expr r) {
return TupleNode::make({l, r});
}
inline Expr Exp(Expr e) {
static const Op& op = Op::Get("exp");
return CallNode::make(op, {e});
}
inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
}
inline Expr Negative(Expr x) {
static const Op& op = Op::Get("negative");
......@@ -180,6 +198,15 @@ inline Expr Divide(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr ZeroLike(Expr e) {
static const Op& op = Op::Get("zeros_like");
return CallNode::make(op, {e});
}
inline Expr OneLike(Expr e) {
static const Op& op = Op::Get("ones_like");
return CallNode::make(op, {e});
}
inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
......
import tvm
from tvm import relay
from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
from tvm.relay import create_executor
import numpy as np
def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
np.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
def test_add():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x + x)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), 2 * x.asnumpy())
np.testing.assert_allclose(grad.asnumpy(), 2 * np.ones_like(x.asnumpy()))
def test_temp_add():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = x + x
func = relay.Function([x], y + y)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
np.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
def test_sub():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x - x)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), np.zeros_like(x.asnumpy()))
np.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy()))
if __name__ == "__main__":
test_id()
test_add()
test_temp_add()
test_sub()