Commit 1e270aa4 by Zhi Committed by Tianqi Chen

[Relay]fix heterogenous annotation bug (#2622)

parent f23a7a54
...@@ -337,12 +337,17 @@ class DeviceInfo { ...@@ -337,12 +337,17 @@ class DeviceInfo {
private: private:
class PostDfsOrderVisitor : private ExprVisitor { class PostDfsOrderVisitor : private ExprVisitor {
public: public:
void Visit(const Expr& expr) { this->VisitExpr(expr); } void Visit(const Expr& expr) {
if (const auto* fn = expr.as<FunctionNode>()) {
this->VisitExpr(fn->body);
} else {
this->VisitExpr(expr);
}
}
private: private:
// Post order traversal. // Post order traversal.
void VisitExpr_(const FunctionNode* fn) final { void VisitExpr_(const FunctionNode* fn) final {
ExprVisitor::VisitExpr_(fn);
// TODO(zhiics) Skip annotation of function node for now. // TODO(zhiics) Skip annotation of function node for now.
} }
...@@ -356,7 +361,7 @@ class DeviceInfo { ...@@ -356,7 +361,7 @@ class DeviceInfo {
ExprVisitor::VisitExpr_(call); ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(call); post_dfs_order_.push_back(call);
if (IsDeviceCopyNode(call)) { if (GetDeviceCopyNode(call)) {
num_device_copy_ops_++; num_device_copy_ops_++;
} }
} }
...@@ -389,6 +394,26 @@ class DeviceInfo { ...@@ -389,6 +394,26 @@ class DeviceInfo {
friend DeviceInfo; friend DeviceInfo;
}; };
/*
* \brief Returns a device copy node based on the current expr node. It
* returns a device copy node either the current expr node is a device copy
* node or the current expr node is a function node whose body is a device
* copy node (i.e. the fused function of a device copy call node).
*/
static const ExprNode* GetDeviceCopyNode(const ExprNode* node) {
if (IsDeviceCopyNode(node)) {
return node;
} else if (const auto* call_node = dynamic_cast<const CallNode*>(node)) {
if (const auto* fn = call_node->op.as<FunctionNode>()) {
const ExprNode* body = fn->body.operator->();
if (IsDeviceCopyNode(body)) {
return body;
}
}
}
return nullptr;
}
void PropagateDeviceId() { void PropagateDeviceId() {
// Bottom-up propagation. // Bottom-up propagation.
BottomUpPropagation(); BottomUpPropagation();
...@@ -401,11 +426,11 @@ class DeviceInfo { ...@@ -401,11 +426,11 @@ class DeviceInfo {
int cur_dev_type = -1; int cur_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin(); for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) { it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (IsDeviceCopyNode(*it)) { if (const auto* node = GetDeviceCopyNode(*it)) {
last_copy_node = dynamic_cast<const CallNode*>(*it); last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>(); const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type; cur_dev_type = attrs->src_dev_type;
device_map_.Set(GetRef<Expr>(last_copy_node), attrs->dst_dev_type); device_map_.Set(GetRef<Expr>(*it), attrs->dst_dev_type);
} else if (last_copy_node) { } else if (last_copy_node) {
Expr expr = GetRef<Expr>(*it); Expr expr = GetRef<Expr>(*it);
CHECK_EQ(device_map_.count(expr), 0U); CHECK_EQ(device_map_.count(expr), 0U);
...@@ -418,8 +443,8 @@ class DeviceInfo { ...@@ -418,8 +443,8 @@ class DeviceInfo {
const CallNode* last_copy_node = nullptr; const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1; int cur_dev_type = -1;
for (const auto& it : post_visitor_.post_dfs_order_) { for (const auto& it : post_visitor_.post_dfs_order_) {
if (IsDeviceCopyNode(it)) { if (const auto* node = GetDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(it); last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>(); const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->dst_dev_type; cur_dev_type = attrs->dst_dev_type;
} else if (last_copy_node) { } else if (last_copy_node) {
......
"""Unit tests for heterogeneous compilation and execution.""" """Unit tests for heterogeneous compilation and execution."""
import json
import numpy as np import numpy as np
import tvm import tvm
...@@ -72,6 +73,7 @@ def test_annotate_all(): ...@@ -72,6 +73,7 @@ def test_annotate_all():
annotated_func = relay.ir_pass.infer_type(annotated()) annotated_func = relay.ir_pass.infer_type(annotated())
expected_func = relay.ir_pass.infer_type(expected()) expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_annotate_none(): def test_annotate_none():
ctx1 = tvm.context(1) ctx1 = tvm.context(1)
...@@ -203,7 +205,7 @@ def test_conv_network(): ...@@ -203,7 +205,7 @@ def test_conv_network():
for did in storage_dev_type[1]: for did in storage_dev_type[1]:
device_types.append(did.value) device_types.append(did.value)
assert len(storage_ids) == 10 assert len(storage_ids) == 10
assert len(set(storage_ids)) == 7 assert len(set(storage_ids)) == 8
assert len(set(device_types)) == 2 assert len(set(device_types)) == 2
assert set(device_types) == {1, 2} assert set(device_types) == {1, 2}
...@@ -245,7 +247,8 @@ def test_fusible_network(): ...@@ -245,7 +247,8 @@ def test_fusible_network():
func = relay.Function([x, y], exp) func = relay.Function([x, y], exp)
return func return func
def test_runtime(target, device, func, fallback_device=None): def test_runtime(target, device, func, fallback_device=None,
expected_index=None):
params = {"x": x_data, "y": y_data} params = {"x": x_data, "y": y_data}
config = {"opt_level": 1} config = {"opt_level": 1}
if fallback_device: if fallback_device:
...@@ -256,6 +259,10 @@ def test_fusible_network(): ...@@ -256,6 +259,10 @@ def test_fusible_network():
target, target,
params=params) params=params)
contexts = [tvm.cpu(0), tvm.context(device)] contexts = [tvm.cpu(0), tvm.context(device)]
graph_json = json.loads(graph)
if "device_index" in graph_json["attrs"]:
device_index = graph_json["attrs"]["device_index"][1]
assert device_index == expected_index
mod = graph_runtime.create(graph, lib, contexts) mod = graph_runtime.create(graph, lib, contexts)
mod.set_input(**params) mod.set_input(**params)
mod.run() mod.run()
...@@ -302,8 +309,10 @@ def test_fusible_network(): ...@@ -302,8 +309,10 @@ def test_fusible_network():
annotated_func = annotated() annotated_func = annotated()
expected_func = expected() expected_func = expected()
expected_index = [1, 1, 1, 2, 2, 1, 1, 2, 2]
check_annotated_graph(annotated_func, expected_func) check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device) test_runtime(target, device, annotated_func, fallback_device,
expected_index)
def test_fuse_all(device, tgt): def test_fuse_all(device, tgt):
"""Fuse all operators.""" """Fuse all operators."""
...@@ -344,6 +353,7 @@ def test_fusible_network(): ...@@ -344,6 +353,7 @@ def test_fusible_network():
fallback_device = tvm.context("cpu") fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt} target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device cpu_ctx = fallback_device
dev_ctx = tvm.context(device)
def annotated(): def annotated():
add = relay.add(x, y) add = relay.add(x, y)
...@@ -357,15 +367,28 @@ def test_fusible_network(): ...@@ -357,15 +367,28 @@ def test_fusible_network():
relay.Tuple(tvm.convert([_exp, exp]))) relay.Tuple(tvm.convert([_exp, exp])))
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type) dev_ctx.device_type)
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[1]), return relay.Function(relay.ir_pass.free_vars(func.body[1]),
func.body[1]) func.body[1])
def expected():
add = relay.add(x, y)
sqrt = relay.sqrt(add)
log = relay.log(add)
subtract = relay.subtract(sqrt, log)
copy_sub_exp = relay.device_copy(subtract, dev_ctx, cpu_ctx)
exp = relay.exp(copy_sub_exp)
func = relay.Function([x, y], exp)
return func
annotated_func = annotated() annotated_func = annotated()
expected_func = get_func() expected_func = expected()
expected_index = [2, 2, 2, 1, 1]
check_annotated_graph(annotated_func, expected_func) check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device) test_runtime(target, device, annotated_func, fallback_device,
expected_index)
def test_fallback_all_operators(device, tgt): def test_fallback_all_operators(device, tgt):
target = {device: tgt} target = {device: tgt}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment