Commit 93d610a1 by Altan Haan, committed by Jared Roesch

[Relay][Training] Add checkpoint annotation for checkpointing memory optimization (#4146)

* add checkpoint annotation for checkpointing memory optimization

* add alpha-equivalence checkpoint test and fix gradient type issue

* fix build issues

* ignore checkpoint annotation when checking missing gradients

* refactor, fix checkpoint compute for tuple and add tests
parent 7732873e
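Checkpointing trades memory for compute: the backward pass recomputes any subexpression wrapped in the annotation instead of keeping its intermediate results alive. A minimal usage sketch, assuming the 0.6-era Relay API exercised by the tests in this commit (variable names are illustrative):

import tvm
from tvm import relay
from tvm.relay import transform

# Wrap an intermediate result in a checkpoint; the gradient pass will
# re-emit this subgraph in the backward computation rather than saving it.
x = relay.var("x", shape=(1,))
y = relay.var("y", shape=(1,))
body = relay.annotation.checkpoint(relay.multiply(relay.add(x, y), y))
func = relay.Function([x, y], body)

# Type-check, then differentiate.
mod = transform.InferType()(relay.Module.from_expr(func))
grad_func = transform.gradient(mod["main"])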
@@ -17,10 +17,10 @@
"""Annotation operations."""
from __future__ import absolute_import as _abs
from . import _make
from ..op import register_schedule, schedule_injective
from .... import nd as _nd
from .... import TVMContext as _TVMContext


def on_device(data, device):
    """Annotate an expression with a certain device type.
@@ -61,3 +61,20 @@ def stop_fusion(data):
        The annotated expression.
    """
    return _make.stop_fusion(data)


def checkpoint(data):
    """Annotate an expression to be a checkpoint for the checkpointing memory optimization.

    Parameters
    ----------
    data : tvm.relay.Expr
        The expression to be annotated.

    Returns
    -------
    result : tvm.relay.Expr
        The annotated expression.
    """
    return _make.checkpoint(data)


register_schedule("annotation.checkpoint", schedule_injective)
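The op is registered with an injective schedule, and (as the C++ registration below shows) its compute emits one identity per output, so `checkpoint` can also wrap tuples. A small sketch mirroring the tuple case in the gradient test further down (names are illustrative):

from tvm import relay

a = relay.var("a", shape=(1,))
b = relay.var("b", shape=(1,))
# Checkpoint a tuple; individual fields are still projected out with
# TupleGetItem, and each is recomputed during the backward pass.
tup = relay.annotation.checkpoint(relay.Tuple([relay.add(a, b), relay.multiply(a, b)]))
out = relay.subtract(relay.TupleGetItem(tup, 0), relay.TupleGetItem(tup, 1))
func = relay.Function([a, b], out)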
@@ -144,5 +144,32 @@ Mark the end of bitpacking.
  return {topi::identity(inputs[0])};
});

TVM_REGISTER_API("relay.op.annotation._make.checkpoint")
.set_body_typed<Expr(Expr)>([](Expr data) {
  static const Op& op = Op::Get("annotation.checkpoint");
  return CallNode::make(op, {data}, Attrs{}, {});
});

RELAY_REGISTER_OP("annotation.checkpoint")
.describe(R"code(
Mark a checkpoint for checkpointing memory optimization.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
                       [](const Attrs& attrs, const Array<Tensor>& inputs,
                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
                         Array<Tensor> outputs;
                         for (size_t i = 0; i < inputs.size(); ++i) {
                           outputs.push_back(topi::identity(inputs[i]));
                         }
                         return outputs;
                       });

}  // namespace relay
}  // namespace tvm
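Since FTVMCompute lowers the annotation to topi::identity on every input, a checkpointed function computes exactly the same forward result as the unannotated one; only the gradient program differs. A quick sketch of that equivalence, assuming an LLVM CPU target (the function here is illustrative):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1,))
plain = relay.Function([x], relay.exp(x))
marked = relay.Function([x], relay.annotation.checkpoint(relay.exp(x)))

# Forward results are identical; the annotation is a runtime no-op.
intrp = relay.create_executor("debug", ctx=tvm.cpu(), target="llvm")
data = np.ones((1,), dtype="float32")
tvm.testing.assert_allclose(intrp.evaluate(plain)(data).asnumpy(),
                            intrp.evaluate(marked)(data).asnumpy())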
@@ -52,7 +52,9 @@ Expr DeDup(const Expr& e) {
   }

   Expr VisitExpr(const Expr& e) final {
-    return ExprMutator::VisitExpr(e);
+    auto ret = ExprMutator::VisitExpr(e);
+    ret->checked_type_ = e->checked_type_;
+    return ret;
   }

   Expr VisitExpr_(const VarNode* op) final {
......
@@ -30,6 +30,18 @@ def test_cross_entropy_with_logits_grad():
    x = relay.var("x", shape=(2, 5))
    y = relay.var("y", shape=(2, 5))
    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)),
               eps=0.01, scale=0.1, mean=1)


def test_checkpoint():
    inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
    output = relay.multiply(relay.add(inputs[0], inputs[1]),
                            relay.add(inputs[2], inputs[3]))
    check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))

    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
                             relay.multiply(inputs[2], inputs[3])])
    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
                                relay.TupleGetItem(out_tuple, 1))
    check_grad(relay.Function(inputs, out_single))


if __name__ == "__main__":
......
@@ -31,6 +31,127 @@ def run_infer_type(expr):
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_checkpoint():
    dtype = "float32"
    xs = [relay.var("x{}".format(i), dtype) for i in range(4)]
    f = relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
    f_checkpoint = relay.annotation.checkpoint(f)

    func, func_checkpoint = relay.Function(xs, f), relay.Function(xs, f_checkpoint)
    f, f_checkpoint = run_infer_type(func), run_infer_type(func_checkpoint)
    assert f.checked_type == f_checkpoint.checked_type

    inputs = [np.random.uniform() for _ in range(len(xs))]
    for target, ctx in ctx_list():
        for kind in ["graph", "debug"]:
            intrp = relay.create_executor(kind, ctx=ctx, target=target)
            f_res = intrp.evaluate(f)(*inputs)
            f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs)
            tvm.testing.assert_allclose(f_res.asnumpy(), f_checkpoint_res.asnumpy(), 0, 0)

def test_checkpoint_alpha_equal():
    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
    f = relay.Function(xs, relay.annotation.checkpoint(
        relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
    ))
    df = transform.gradient(run_infer_type(f))

    # run PE and DCE
    with transform.PassContext(opt_level=3):
        passes = [transform.PartialEvaluate(),
                  transform.DeadCodeElimination(inline_once=True)]
        mod = transform.Sequential(passes)(relay.Module.from_expr(df))
        df = mod["main"]

    df_parsed = relay.parser.fromtext(
        """
        v0.0.4
        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
            -> (Tensor[(1), float32],
                (Tensor[(1), float32], Tensor[(1), float32],
                 Tensor[(1), float32], Tensor[(1), float32])) {
            %0 = add(%x, %y);
            %1 = add(%z, %w);
            let %x1: Tensor[(1), float32] = multiply(%0, %1);
            let %x2: Tensor[(1), float32] = ones_like(%x1);
            let %x3: Tensor[(1), float32] = add(%x, %y);
            let %x4: Tensor[(1), float32] = add(%z, %w);
            %2 = zeros_like(%x3);
            %3 = multiply(%x2, %x4);
            %4 = collapse_sum_like(%3, %x3);
            let %x5: Tensor[(1), float32] = add(%2, %4);
            %5 = zeros_like(%x4);
            %6 = multiply(%x2, %x3);
            %7 = collapse_sum_like(%6, %x4);
            let %x6: Tensor[(1), float32] = add(%5, %7);
            %8 = zeros_like(%x);
            %9 = collapse_sum_like(%x5, %x);
            %10 = add(%8, %9);
            %11 = zeros_like(%y);
            %12 = collapse_sum_like(%x5, %y);
            %13 = add(%11, %12);
            %14 = zeros_like(%z);
            %15 = collapse_sum_like(%x6, %z);
            %16 = add(%14, %15);
            %17 = zeros_like(%w);
            %18 = collapse_sum_like(%x6, %w);
            %19 = add(%17, %18);
            %20 = (%10, %13, %16, %19);
            (%x1, %20)
        }
        """
    )

    relay.analysis.assert_alpha_equal(df, df_parsed)

def test_checkpoint_alpha_equal_tuple():
    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
    f = relay.Function(xs, relay.annotation.checkpoint(
        relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])
    ))
    df = transform.gradient(run_infer_type(f))

    # run PE and DCE
    with transform.PassContext(opt_level=3):
        passes = [transform.PartialEvaluate(),
                  transform.DeadCodeElimination(inline_once=True)]
        mod = transform.Sequential(passes)(relay.Module.from_expr(df))
        df = mod["main"]

    df_parsed = relay.parser.fromtext(
        """
        v0.0.4
        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
            -> ((Tensor[(1), float32], Tensor[(1), float32]),
                (Tensor[(1), float32], Tensor[(1), float32],
                 Tensor[(1), float32], Tensor[(1), float32])) {
            let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */;
            let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */;
            let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */;
            let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */;
            %0 = (%x1, %x2);
            %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */;
            %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */;
            %3 = add(%1, %2) /* ty=Tensor[(1), float32] */;
            %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */;
            %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */;
            %6 = add(%4, %5) /* ty=Tensor[(1), float32] */;
            %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */;
            %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */;
            %9 = add(%7, %8) /* ty=Tensor[(1), float32] */;
            %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */;
            %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */;
            %12 = add(%10, %11) /* ty=Tensor[(1), float32] */;
            %13 = (%3, %6, %9, %12);
            (%0, %13)
        }
        """
    )

    relay.analysis.assert_alpha_equal(df, df_parsed)

def test_collapse_sum_like():
    shape = (3, 4, 5, 6)
    shape_like = (4, 5, 6)
......