Commit 1e270aa4 by Zhi Committed by Tianqi Chen

[Relay]fix heterogeneous annotation bug (#2622)

parent f23a7a54
......@@ -337,12 +337,17 @@ class DeviceInfo {
private:
class PostDfsOrderVisitor : private ExprVisitor {
public:
void Visit(const Expr& expr) { this->VisitExpr(expr); }
// Entry point of the traversal. When the given expression is itself a
// function, the walk starts from the function's body instead of from the
// function node; any other expression is visited as-is.
void Visit(const Expr& expr) {
  const auto* fn = expr.as<FunctionNode>();
  if (fn == nullptr) {
    this->VisitExpr(expr);
    return;
  }
  this->VisitExpr(fn->body);
}
private:
// Post order traversal.
// Handles a function node by delegating to the base ExprVisitor
// implementation. This override does no additional bookkeeping for the
// function node itself (see the TODO below).
void VisitExpr_(const FunctionNode* fn) final {
ExprVisitor::VisitExpr_(fn);
// TODO(zhiics) Skip annotation of function node for now.
}
......@@ -356,7 +361,7 @@ class DeviceInfo {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(call);
if (IsDeviceCopyNode(call)) {
if (GetDeviceCopyNode(call)) {
num_device_copy_ops_++;
}
}
......@@ -389,6 +394,26 @@ class DeviceInfo {
friend DeviceInfo;
};
/*
 * \brief Look up the device copy node associated with an expression node.
 *
 * A device copy node is found in either of two situations:
 *  - the given node is itself a device copy node, in which case the node is
 *    returned unchanged; or
 *  - the given node is a call whose callee is a function whose body is a
 *    device copy node (i.e. the fused function of a device copy call node),
 *    in which case that body is returned.
 *
 * \param node The expression node to inspect.
 * \return The device copy node, or nullptr when neither situation applies.
 */
static const ExprNode* GetDeviceCopyNode(const ExprNode* node) {
  // Direct hit: the node is a device copy on its own.
  if (IsDeviceCopyNode(node)) {
    return node;
  }

  // Otherwise only a call node can wrap a device copy.
  const auto* call = dynamic_cast<const CallNode*>(node);
  if (call == nullptr) {
    return nullptr;
  }

  // The callee has to be a function; inspect what the function body is.
  const auto* callee = call->op.as<FunctionNode>();
  if (callee == nullptr) {
    return nullptr;
  }

  const ExprNode* callee_body = callee->body.operator->();
  return IsDeviceCopyNode(callee_body) ? callee_body : nullptr;
}
void PropagateDeviceId() {
// Bottom-up propagation.
BottomUpPropagation();
......@@ -401,11 +426,11 @@ class DeviceInfo {
int cur_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (IsDeviceCopyNode(*it)) {
last_copy_node = dynamic_cast<const CallNode*>(*it);
if (const auto* node = GetDeviceCopyNode(*it)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
device_map_.Set(GetRef<Expr>(last_copy_node), attrs->dst_dev_type);
device_map_.Set(GetRef<Expr>(*it), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(*it);
CHECK_EQ(device_map_.count(expr), 0U);
......@@ -418,8 +443,8 @@ class DeviceInfo {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
for (const auto& it : post_visitor_.post_dfs_order_) {
if (IsDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(it);
if (const auto* node = GetDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->dst_dev_type;
} else if (last_copy_node) {
......
"""Unit tests for heterogeneous compilation and execution."""
import json
import numpy as np
import tvm
......@@ -72,6 +73,7 @@ def test_annotate_all():
annotated_func = relay.ir_pass.infer_type(annotated())
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
def test_annotate_none():
ctx1 = tvm.context(1)
......@@ -203,7 +205,7 @@ def test_conv_network():
for did in storage_dev_type[1]:
device_types.append(did.value)
assert len(storage_ids) == 10
assert len(set(storage_ids)) == 7
assert len(set(storage_ids)) == 8
assert len(set(device_types)) == 2
assert set(device_types) == {1, 2}
......@@ -245,7 +247,8 @@ def test_fusible_network():
func = relay.Function([x, y], exp)
return func
def test_runtime(target, device, func, fallback_device=None):
def test_runtime(target, device, func, fallback_device=None,
expected_index=None):
params = {"x": x_data, "y": y_data}
config = {"opt_level": 1}
if fallback_device:
......@@ -256,6 +259,10 @@ def test_fusible_network():
target,
params=params)
contexts = [tvm.cpu(0), tvm.context(device)]
graph_json = json.loads(graph)
if "device_index" in graph_json["attrs"]:
device_index = graph_json["attrs"]["device_index"][1]
assert device_index == expected_index
mod = graph_runtime.create(graph, lib, contexts)
mod.set_input(**params)
mod.run()
......@@ -302,8 +309,10 @@ def test_fusible_network():
annotated_func = annotated()
expected_func = expected()
expected_index = [1, 1, 1, 2, 2, 1, 1, 2, 2]
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
test_runtime(target, device, annotated_func, fallback_device,
expected_index)
def test_fuse_all(device, tgt):
"""Fuse all operators."""
......@@ -344,6 +353,7 @@ def test_fusible_network():
fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", device: tgt}
cpu_ctx = fallback_device
dev_ctx = tvm.context(device)
def annotated():
add = relay.add(x, y)
......@@ -357,15 +367,28 @@ def test_fusible_network():
relay.Tuple(tvm.convert([_exp, exp])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
dev_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[1]),
func.body[1])
def expected():
    """Build the reference function the annotated graph is compared against.

    The body computes ``sqrt(x + y) - log(x + y)``, passes that result through
    ``relay.device_copy`` with ``dev_ctx`` and ``cpu_ctx``, and applies ``exp``
    to the copied value. The returned ``relay.Function`` takes ``x`` and ``y``
    as its parameters.
    """
    summed = relay.add(x, y)
    difference = relay.subtract(relay.sqrt(summed), relay.log(summed))
    copied = relay.device_copy(difference, dev_ctx, cpu_ctx)
    return relay.Function([x, y], relay.exp(copied))
annotated_func = annotated()
expected_func = get_func()
expected_func = expected()
expected_index = [2, 2, 2, 1, 1]
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func, fallback_device)
test_runtime(target, device, annotated_func, fallback_device,
expected_index)
def test_fallback_all_operators(device, tgt):
target = {device: tgt}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment