Commit 21935dcb by Zhi Committed by Jared Roesch

[Relay][heterogeneous pass] remove on_device op after annotation (#3204)

* remove on_device op after annotation

* Update src/relay/pass/device_annotation.cc

Co-Authored-By: MORINAGA <34588258+imorinaga@users.noreply.github.com>
parent 91b181d4
......@@ -485,7 +485,52 @@ class DeviceInfo {
/*!
 * \brief Rewrite device-annotation ops in \p expr and strip leftover
 * OnDevice markers.
 *
 * \param expr The expression produced by the annotation step.
 * \param fallback_device Device type used for nodes with no annotation.
 * \return The rewritten expression with OnDevice leaves removed.
 *
 * Note: the diff residue here kept the pre-change early `return
 * rewrote.Rewrite(expr, fallback_device);`, which made every line below it
 * unreachable (and re-declared the same rewrite on the next line). The
 * stale line is removed so the OnDevice-stripping logic actually runs.
 */
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
  RewriteAnnotation rewrote = RewriteAnnotation();
  Expr new_expr = rewrote.Rewrite(expr, fallback_device);

  // Remove OnDevice operators. Note that these operators are only present at the
  // leaves after annotation. Therefore, we can simply reconstruct the
  // Function/Expr by removing them directly.
  if (const FunctionNode* fn = new_expr.as<FunctionNode>()) {
    auto params = fn->params;
    auto body = fn->body;
    std::vector<Expr> new_body;
    if (const TupleNode* tuple = body.as<TupleNode>()) {
      // Keep only the tuple fields that are not OnDevice markers.
      for (const auto& field : tuple->fields) {
        if (!IsOnDeviceNode(field.operator->())) {
          new_body.push_back(field);
        }
      }
      // At least one real output must remain after stripping markers.
      CHECK_GT(new_body.size(), 0U);
      if (new_body.size() == 1) {
        // Single remaining output: unwrap the tuple entirely.
        return FunctionNode::make(params, new_body[0], Type(nullptr),
                                  fn->type_params, fn->attrs);
      } else if (tuple->fields.size() == new_body.size()) {
        // Nothing was removed; the rewritten function is already correct.
        return new_expr;
      } else {
        // Some markers removed but multiple outputs remain: rebuild the tuple.
        Tuple tuple_body = TupleNode::make(new_body);
        return FunctionNode::make(params, tuple_body, Type(nullptr),
                                  fn->type_params, fn->attrs);
      }
    } else {
      // Non-tuple body cannot carry OnDevice leaves at this point.
      return new_expr;
    }
  } else if (const TupleNode* tuple = new_expr.as<TupleNode>()) {
    // Bare tuple expression: same stripping as the function-body case.
    std::vector<Expr> new_fields;
    for (const auto& field : tuple->fields) {
      if (!IsOnDeviceNode(field.operator->())) {
        new_fields.push_back(field);
      }
    }
    CHECK_GT(new_fields.size(), 0U);
    if (tuple->fields.size() == new_fields.size()) {
      return new_fields.size() == 1 ? new_fields[0] : new_expr;
    } else {
      return new_fields.size() == 1 ? new_fields[0]
                                    : TupleNode::make(new_fields);
    }
  } else {
    // Neither a function nor a tuple: no OnDevice leaves to strip.
    return new_expr;
  }
}
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
......
......@@ -42,9 +42,7 @@ def test_redundant_annotation():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
return func
def expected():
add = relay.add(x, y)
......@@ -58,6 +56,35 @@ def test_redundant_annotation():
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_annotate_expr():
    """Annotating a bare expression (not a Function) rewrites device copies
    and drops the on_device markers from the resulting tuple."""
    ctx1 = tvm.context(1)
    ctx2 = tvm.context(2)
    x = relay.var("x", shape=(3,))
    y = relay.var("y", shape=(3,))
    z = relay.var("z", shape=(3,))

    def annotated():
        # Build add on ctx1 and sub on ctx2, collect everything in a tuple,
        # then let the rewrite pass insert the cross-device copy.
        add = relay.add(x, y)
        _add = relay.annotation.on_device(add, ctx1)
        sub = relay.subtract(add, z)
        _sub = relay.annotation.on_device(sub, ctx2)
        expr = relay.Tuple([sub, _add, _sub])
        expr = relay.ir_pass.infer_type(expr)
        expr = relay.ir_pass.rewrite_annotated_ops(expr,
                                                   ctx1.device_type)
        return expr

    def expected():
        # The add result must be copied from ctx1 to ctx2 before the subtract.
        add = relay.add(x, y)
        copy_add_sub = relay.device_copy(add, ctx1, ctx2)
        sub = relay.subtract(copy_add_sub, z)
        return sub

    annotated_expr = relay.ir_pass.infer_type(annotated())
    expected_expr = relay.ir_pass.infer_type(expected())
    assert relay.ir_pass.graph_equal(annotated_expr, expected_expr)
def test_annotate_all():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
......@@ -77,9 +104,7 @@ def test_annotate_all():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
return func
def expected():
add = relay.add(x, y)
......@@ -91,6 +116,7 @@ def test_annotate_all():
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_annotate_none():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
......@@ -174,9 +200,7 @@ def test_conv_network():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[4]),
func.body[4])
return func
def expected():
conv2d_1 = relay.nn.conv2d(
......@@ -202,7 +226,7 @@ def test_conv_network():
kernel_size=(3, 3),
padding=(1, 1))
func = relay.Function([data1, weight, data2], conv2d_3)
func = relay.Function([data1, data2, weight], conv2d_3)
return func
def check_storage_and_device_types():
......@@ -306,9 +330,7 @@ def run_fusible_network(dev, tgt):
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[2]),
func.body[2])
return func
def expected():
add = relay.add(x, y)
......@@ -358,9 +380,7 @@ def run_fusible_network(dev, tgt):
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[5]),
func.body[5])
return func
annotated_func = annotated()
expected_func = get_func()
......@@ -386,9 +406,7 @@ def run_fusible_network(dev, tgt):
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[1]),
func.body[1])
return func
def expected():
add = relay.add(x, y)
......@@ -462,9 +480,7 @@ def run_unpropagatable_graph(dev, tgt):
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[3]),
func.body[3])
return func
def expected():
add = relay.add(a, b)
......@@ -506,6 +522,7 @@ def test_check_run():
if __name__ == "__main__":
test_redundant_annotation()
test_annotate_expr()
test_annotate_all()
test_annotate_none()
test_conv_network()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment