Commit 887255a8 by Zhi Committed by Jared Roesch

[relay][heterogeneous] annotate using visitor (#3261)

* annotate using visitor

* retrigger CI
parent f6acf2e5
......@@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator {
}
Expr VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) {
if (IsOnDeviceNode(call_node)) {
return this->VisitExpr(call_node->args[0]);
}
if (IsDeviceCopyNode(call_node)) {
return ExprMutator::VisitExpr_(call_node);
}
......@@ -358,6 +362,9 @@ class DeviceInfo {
public:
void Visit(const Expr& expr) {
if (const auto* fn = expr.as<FunctionNode>()) {
for (const auto& param : fn->params) {
this->VisitExpr(param);
}
this->VisitExpr(fn->body);
} else {
this->VisitExpr(expr);
......@@ -402,7 +409,7 @@ class DeviceInfo {
}
void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}
void VisitExpr_(const LetNode* ln) final {
......
......@@ -21,6 +21,7 @@ import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.expr_functor import ExprMutator
def test_redundant_annotation():
......@@ -34,11 +35,10 @@ def test_redundant_annotation():
add = relay.add(x, y)
_add1 = relay.annotation.on_device(add, ctx2)
_add2 = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
sub1 = relay.subtract(_add1, z)
sub2 = relay.subtract(_add2, z)
func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add1, _add2,
sub])))
func = relay.Function([x, y, z], relay.Tuple([sub1, sub2]))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
......@@ -46,9 +46,11 @@ def test_redundant_annotation():
def expected():
add = relay.add(x, y)
copy_add_sub = relay.device_copy(add, ctx2, ctx1)
sub = relay.subtract(copy_add_sub, z)
func = relay.Function([x, y, z], sub)
copy_add_sub1 = relay.device_copy(add, ctx2, ctx1)
sub1 = relay.subtract(copy_add_sub1, z)
copy_add_sub2 = relay.device_copy(add, ctx2, ctx1)
sub2 = relay.subtract(copy_add_sub2, z)
func = relay.Function([x, y, z], relay.Tuple([sub1, sub2]))
return func
annotated_func = relay.ir_pass.infer_type(annotated())
......@@ -66,10 +68,9 @@ def test_annotate_expr():
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx1)
sub = relay.subtract(add, z)
sub = relay.subtract(_add, z)
_sub = relay.annotation.on_device(sub, ctx2)
expr = relay.Tuple([sub, _add, _sub])
expr = relay.ir_pass.infer_type(expr)
expr = relay.ir_pass.infer_type(_sub)
expr = relay.ir_pass.rewrite_annotated_ops(expr,
ctx1.device_type)
return expr
......@@ -95,12 +96,10 @@ def test_annotate_all():
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z)
sub = relay.subtract(_add, z)
_sub = relay.annotation.on_device(sub, ctx2)
func = relay.Function([x, y, z],
relay.Tuple(tvm.convert([_add, _sub,
sub])))
func = relay.Function([x, y, z], _sub)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
......@@ -168,6 +167,34 @@ def test_conv_network():
dev1 = tvm.context(1)
dev2 = tvm.context(2)
def original():
conv2d_1 = relay.nn.conv2d(
data1,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
conv2d_2 = relay.nn.conv2d(
data2,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
add = relay.add(conv2d_1, conv2d_2)
conv2d_3 = relay.nn.conv2d(
add,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
func = relay.Function([data1, data2, weight], conv2d_3)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
return func
def annotated():
conv2d_1 = relay.nn.conv2d(
data1,
......@@ -183,25 +210,40 @@ def test_conv_network():
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_2 = relay.annotation.on_device(conv2d_2, dev2)
add = relay.add(conv2d_1, conv2d_2)
add = relay.add(_conv2d_1, _conv2d_2)
_add = relay.annotation.on_device(add, dev1)
conv2d_3 = relay.nn.conv2d(
add,
_add,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
_conv2d_3 = relay.annotation.on_device(conv2d_3, dev2)
func = relay.Function([data1, data2, weight],
relay.Tuple(tvm.convert([_conv2d_1, _conv2d_2,
_conv2d_3, _add,
conv2d_3])))
func = relay.Function([data1, data2, weight], _conv2d_3)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
return func
class ScheduleConv2d(ExprMutator):
def __init__(self, device):
self.device = device
super().__init__()
def visit_call(self, expr):
visit = super().visit_call(expr)
if expr.op == tvm.relay.op.get("nn.conv2d"):
return relay.annotation.on_device(visit, self.device)
else:
return visit
def annotate_with_visitor(func):
sched = ScheduleConv2d(dev2)
func = sched.visit(func)
func = relay.ir_pass.rewrite_annotated_ops(func, dev1.device_type)
return func
def expected():
conv2d_1 = relay.nn.conv2d(
data1,
......@@ -249,10 +291,19 @@ def test_conv_network():
assert len(set(device_types)) == 2
assert set(device_types) == {1, 2}
annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
check_storage_and_device_types()
def test_manual_annotation():
annotated_func = annotated()
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
check_storage_and_device_types()
def test_visitor_annotation():
annotated_func = annotate_with_visitor(original())
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
test_manual_annotation()
test_visitor_annotation()
def run_fusible_network(dev, tgt):
......@@ -321,12 +372,11 @@ def run_fusible_network(dev, tgt):
sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
subtract = relay.subtract(_sqrt, log)
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_sqrt, _exp, exp])))
func = relay.Function([x, y], _exp)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
......@@ -364,19 +414,16 @@ def run_fusible_network(dev, tgt):
def annotated():
add = relay.add(x, y)
_add = relay.annotation.on_device(add, dev_ctx)
sqrt = relay.sqrt(add)
sqrt = relay.sqrt(_add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add)
log = relay.log(_add)
_log = relay.annotation.on_device(log, dev_ctx)
subtract = relay.subtract(sqrt, log)
subtract = relay.subtract(_sqrt, _log)
_subtract = relay.annotation.on_device(subtract, dev_ctx)
exp = relay.exp(subtract)
exp = relay.exp(_subtract)
_exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_add, _sqrt, _log,
_subtract, _exp,
exp])))
func = relay.Function([x, y], _exp)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
......@@ -401,8 +448,7 @@ def run_fusible_network(dev, tgt):
exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, cpu_ctx)
func = relay.Function([x, y],
relay.Tuple(tvm.convert([_exp, exp])))
func = relay.Function([x, y], _exp)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
......@@ -472,11 +518,9 @@ def run_unpropagatable_graph(dev, tgt):
_add = relay.annotation.on_device(add, dev_ctx)
mul = relay.multiply(c, d)
_mul = relay.annotation.on_device(mul, cpu_ctx)
sub = relay.subtract(add, mul)
sub = relay.subtract(_add, _mul)
_sub = relay.annotation.on_device(sub, dev_ctx)
func = relay.Function([a, b, c, d],
relay.Tuple(tvm.convert([_add, _mul,
_sub, sub])))
func = relay.Function([a, b, c, d], _sub)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment