Commit 887255a8 by Zhi Committed by Jared Roesch

[relay][heterogeneous] annotate using visitor (#3261)

* annotate using visitor

* retrigger CI
parent f6acf2e5
...@@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator { ...@@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator {
} }
Expr VisitExpr_(const CallNode* call_node) final { Expr VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) { if (IsOnDeviceNode(call_node)) {
return this->VisitExpr(call_node->args[0]);
}
if (IsDeviceCopyNode(call_node)) {
return ExprMutator::VisitExpr_(call_node); return ExprMutator::VisitExpr_(call_node);
} }
...@@ -358,6 +362,9 @@ class DeviceInfo { ...@@ -358,6 +362,9 @@ class DeviceInfo {
public: public:
void Visit(const Expr& expr) { void Visit(const Expr& expr) {
if (const auto* fn = expr.as<FunctionNode>()) { if (const auto* fn = expr.as<FunctionNode>()) {
for (const auto& param : fn->params) {
this->VisitExpr(param);
}
this->VisitExpr(fn->body); this->VisitExpr(fn->body);
} else { } else {
this->VisitExpr(expr); this->VisitExpr(expr);
...@@ -402,7 +409,7 @@ class DeviceInfo { ...@@ -402,7 +409,7 @@ class DeviceInfo {
} }
void VisitExpr_(const VarNode* vn) final { void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
} }
void VisitExpr_(const LetNode* ln) final { void VisitExpr_(const LetNode* ln) final {
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.relay.expr_functor import ExprMutator
def test_redundant_annotation(): def test_redundant_annotation():
...@@ -34,11 +35,10 @@ def test_redundant_annotation(): ...@@ -34,11 +35,10 @@ def test_redundant_annotation():
add = relay.add(x, y) add = relay.add(x, y)
_add1 = relay.annotation.on_device(add, ctx2) _add1 = relay.annotation.on_device(add, ctx2)
_add2 = relay.annotation.on_device(add, ctx2) _add2 = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z) sub1 = relay.subtract(_add1, z)
sub2 = relay.subtract(_add2, z)
func = relay.Function([x, y, z], func = relay.Function([x, y, z], relay.Tuple([sub1, sub2]))
relay.Tuple(tvm.convert([_add1, _add2,
sub])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type) ctx1.device_type)
...@@ -46,9 +46,11 @@ def test_redundant_annotation(): ...@@ -46,9 +46,11 @@ def test_redundant_annotation():
def expected(): def expected():
add = relay.add(x, y) add = relay.add(x, y)
copy_add_sub = relay.device_copy(add, ctx2, ctx1) copy_add_sub1 = relay.device_copy(add, ctx2, ctx1)
sub = relay.subtract(copy_add_sub, z) sub1 = relay.subtract(copy_add_sub1, z)
func = relay.Function([x, y, z], sub) copy_add_sub2 = relay.device_copy(add, ctx2, ctx1)
sub2 = relay.subtract(copy_add_sub2, z)
func = relay.Function([x, y, z], relay.Tuple([sub1, sub2]))
return func return func
annotated_func = relay.ir_pass.infer_type(annotated()) annotated_func = relay.ir_pass.infer_type(annotated())
...@@ -66,10 +68,9 @@ def test_annotate_expr(): ...@@ -66,10 +68,9 @@ def test_annotate_expr():
def annotated(): def annotated():
add = relay.add(x, y) add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx1) _add = relay.annotation.on_device(add, ctx1)
sub = relay.subtract(add, z) sub = relay.subtract(_add, z)
_sub = relay.annotation.on_device(sub, ctx2) _sub = relay.annotation.on_device(sub, ctx2)
expr = relay.Tuple([sub, _add, _sub]) expr = relay.ir_pass.infer_type(_sub)
expr = relay.ir_pass.infer_type(expr)
expr = relay.ir_pass.rewrite_annotated_ops(expr, expr = relay.ir_pass.rewrite_annotated_ops(expr,
ctx1.device_type) ctx1.device_type)
return expr return expr
...@@ -95,12 +96,10 @@ def test_annotate_all(): ...@@ -95,12 +96,10 @@ def test_annotate_all():
def annotated(): def annotated():
add = relay.add(x, y) add = relay.add(x, y)
_add = relay.annotation.on_device(add, ctx2) _add = relay.annotation.on_device(add, ctx2)
sub = relay.subtract(add, z) sub = relay.subtract(_add, z)
_sub = relay.annotation.on_device(sub, ctx2) _sub = relay.annotation.on_device(sub, ctx2)
func = relay.Function([x, y, z], func = relay.Function([x, y, z], _sub)
relay.Tuple(tvm.convert([_add, _sub,
sub])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type) ctx1.device_type)
...@@ -168,6 +167,34 @@ def test_conv_network(): ...@@ -168,6 +167,34 @@ def test_conv_network():
dev1 = tvm.context(1) dev1 = tvm.context(1)
dev2 = tvm.context(2) dev2 = tvm.context(2)
def original():
conv2d_1 = relay.nn.conv2d(
data1,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
conv2d_2 = relay.nn.conv2d(
data2,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
add = relay.add(conv2d_1, conv2d_2)
conv2d_3 = relay.nn.conv2d(
add,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1))
func = relay.Function([data1, data2, weight], conv2d_3)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
return func
def annotated(): def annotated():
conv2d_1 = relay.nn.conv2d( conv2d_1 = relay.nn.conv2d(
data1, data1,
...@@ -183,25 +210,40 @@ def test_conv_network(): ...@@ -183,25 +210,40 @@ def test_conv_network():
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1)) padding=(1, 1))
_conv2d_2 = relay.annotation.on_device(conv2d_2, dev2) _conv2d_2 = relay.annotation.on_device(conv2d_2, dev2)
add = relay.add(conv2d_1, conv2d_2) add = relay.add(_conv2d_1, _conv2d_2)
_add = relay.annotation.on_device(add, dev1) _add = relay.annotation.on_device(add, dev1)
conv2d_3 = relay.nn.conv2d( conv2d_3 = relay.nn.conv2d(
add, _add,
weight, weight,
channels=64, channels=64,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1)) padding=(1, 1))
_conv2d_3 = relay.annotation.on_device(conv2d_3, dev2) _conv2d_3 = relay.annotation.on_device(conv2d_3, dev2)
func = relay.Function([data1, data2, weight], func = relay.Function([data1, data2, weight], _conv2d_3)
relay.Tuple(tvm.convert([_conv2d_1, _conv2d_2,
_conv2d_3, _add,
conv2d_3])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type) tvm.context(3).device_type)
return func return func
class ScheduleConv2d(ExprMutator):
def __init__(self, device):
self.device = device
super().__init__()
def visit_call(self, expr):
visit = super().visit_call(expr)
if expr.op == tvm.relay.op.get("nn.conv2d"):
return relay.annotation.on_device(visit, self.device)
else:
return visit
def annotate_with_visitor(func):
sched = ScheduleConv2d(dev2)
func = sched.visit(func)
func = relay.ir_pass.rewrite_annotated_ops(func, dev1.device_type)
return func
def expected(): def expected():
conv2d_1 = relay.nn.conv2d( conv2d_1 = relay.nn.conv2d(
data1, data1,
...@@ -249,10 +291,19 @@ def test_conv_network(): ...@@ -249,10 +291,19 @@ def test_conv_network():
assert len(set(device_types)) == 2 assert len(set(device_types)) == 2
assert set(device_types) == {1, 2} assert set(device_types) == {1, 2}
annotated_func = annotated() def test_manual_annotation():
expected_func = expected() annotated_func = annotated()
check_annotated_graph(annotated_func, expected_func) expected_func = expected()
check_storage_and_device_types() check_annotated_graph(annotated_func, expected_func)
check_storage_and_device_types()
def test_visitor_annotation():
annotated_func = annotate_with_visitor(original())
expected_func = expected()
check_annotated_graph(annotated_func, expected_func)
test_manual_annotation()
test_visitor_annotation()
def run_fusible_network(dev, tgt): def run_fusible_network(dev, tgt):
...@@ -321,12 +372,11 @@ def run_fusible_network(dev, tgt): ...@@ -321,12 +372,11 @@ def run_fusible_network(dev, tgt):
sqrt = relay.sqrt(add) sqrt = relay.sqrt(add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx) _sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add) log = relay.log(add)
subtract = relay.subtract(sqrt, log) subtract = relay.subtract(_sqrt, log)
exp = relay.exp(subtract) exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, dev_ctx) _exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y], func = relay.Function([x, y], _exp)
relay.Tuple(tvm.convert([_sqrt, _exp, exp])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type) cpu_ctx.device_type)
...@@ -364,19 +414,16 @@ def run_fusible_network(dev, tgt): ...@@ -364,19 +414,16 @@ def run_fusible_network(dev, tgt):
def annotated(): def annotated():
add = relay.add(x, y) add = relay.add(x, y)
_add = relay.annotation.on_device(add, dev_ctx) _add = relay.annotation.on_device(add, dev_ctx)
sqrt = relay.sqrt(add) sqrt = relay.sqrt(_add)
_sqrt = relay.annotation.on_device(sqrt, dev_ctx) _sqrt = relay.annotation.on_device(sqrt, dev_ctx)
log = relay.log(add) log = relay.log(_add)
_log = relay.annotation.on_device(log, dev_ctx) _log = relay.annotation.on_device(log, dev_ctx)
subtract = relay.subtract(sqrt, log) subtract = relay.subtract(_sqrt, _log)
_subtract = relay.annotation.on_device(subtract, dev_ctx) _subtract = relay.annotation.on_device(subtract, dev_ctx)
exp = relay.exp(subtract) exp = relay.exp(_subtract)
_exp = relay.annotation.on_device(exp, dev_ctx) _exp = relay.annotation.on_device(exp, dev_ctx)
func = relay.Function([x, y], func = relay.Function([x, y], _exp)
relay.Tuple(tvm.convert([_add, _sqrt, _log,
_subtract, _exp,
exp])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type) cpu_ctx.device_type)
...@@ -401,8 +448,7 @@ def run_fusible_network(dev, tgt): ...@@ -401,8 +448,7 @@ def run_fusible_network(dev, tgt):
exp = relay.exp(subtract) exp = relay.exp(subtract)
_exp = relay.annotation.on_device(exp, cpu_ctx) _exp = relay.annotation.on_device(exp, cpu_ctx)
func = relay.Function([x, y], func = relay.Function([x, y], _exp)
relay.Tuple(tvm.convert([_exp, exp])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type) dev_ctx.device_type)
...@@ -472,11 +518,9 @@ def run_unpropagatable_graph(dev, tgt): ...@@ -472,11 +518,9 @@ def run_unpropagatable_graph(dev, tgt):
_add = relay.annotation.on_device(add, dev_ctx) _add = relay.annotation.on_device(add, dev_ctx)
mul = relay.multiply(c, d) mul = relay.multiply(c, d)
_mul = relay.annotation.on_device(mul, cpu_ctx) _mul = relay.annotation.on_device(mul, cpu_ctx)
sub = relay.subtract(add, mul) sub = relay.subtract(_add, _mul)
_sub = relay.annotation.on_device(sub, dev_ctx) _sub = relay.annotation.on_device(sub, dev_ctx)
func = relay.Function([a, b, c, d], func = relay.Function([a, b, c, d], _sub)
relay.Tuple(tvm.convert([_add, _mul,
_sub, sub])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type) dev_ctx.device_type)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment