Commit 2bf66660 by MORINAGA Committed by Yizhi Liu

[Heterogeneous][Bugfix] Fix bug of wrongly generated device_map (#2990)

* fix bug of device_index

* cpplint

* nose

* Update test_pass_annotation.py

* fix name of testcase

* delete comment
parent dc97e527
......@@ -334,9 +334,9 @@ class AnnotatationVisitor : private ExprVisitor {
* -Pass 1: Propagating the source device type to ops in a bottom-up way to the
* ancestors until encountering another copy op. For example, this way
* provides add, x, and y device types from the copy operator, `copy1`.
* -Pass 2: Propagating the destination device type of "the last" copy op in a
* top-down manner to the nodes on the output paths. For instance,
* this offers `subtract` and `exp` the same device type as `copy3`.
* -Pass 2: Propagating the destination device type of "the last" copy op to the
* remaining nodes. For instance, this offers `subtract` and `exp` the
* same device type as `copy3`.
*/
class DeviceInfo {
......@@ -371,17 +371,22 @@ class DeviceInfo {
}
void VisitExpr_(const ConstantNode* cn) final {
post_dfs_order_.push_back(cn);
post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
}
void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(call);
if (GetDeviceCopyNode(call)) {
num_device_copy_ops_++;
bool has_copy_prev = has_copy_;
has_copy_ = true;
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
has_copy_ = has_copy_prev;
} else {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
}
}
}
......@@ -393,23 +398,27 @@ class DeviceInfo {
// Delegates to the base visitor first, then records the tuple-get-item node
// itself so it takes part in the post-DFS ordering like every other visited
// expression. The entry is paired with the current has_copy_ flag, which is
// the element type the other VisitExpr_ overloads push.
void VisitExpr_(const TupleGetItemNode* op) final {
ExprVisitor::VisitExpr_(op);
// The pair has to be pushed: a bare std::make_pair(...) statement only
// builds a temporary and throws it away, leaving this node unrecorded.
post_dfs_order_.push_back(std::make_pair(op, has_copy_));
}
void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(vn); }
void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}
void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
post_dfs_order_.push_back(ln);
post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
}
void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
post_dfs_order_.push_back(in);
post_dfs_order_.push_back(std::make_pair(in, has_copy_));
}
int num_device_copy_ops_{0};
std::vector<const ExprNode*> post_dfs_order_;
bool has_copy_ = false;
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
friend DeviceInfo;
};
......@@ -435,46 +444,41 @@ class DeviceInfo {
void PropagateDeviceId() {
// Bottom-up propagation.
BottomUpPropagation();
// Top-down propagation.
TopDownPropagation();
int out_dev_type = BottomUpPropagation();
// Propagation for the remaining nodes.
FillPropagation(out_dev_type);
}
void BottomUpPropagation() {
int BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
int out_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(*it)) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
device_map_.Set(GetRef<Expr>(*it), attrs->dst_dev_type);
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
if (it->second) device_map_.Set(GetRef<Expr>(it->first),
attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(*it);
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
device_map_.Set(expr, cur_dev_type);
if (it->second) device_map_.Set(expr, cur_dev_type);
}
}
return out_dev_type;
}
void TopDownPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
if (const auto* node = GetDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->dst_dev_type;
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it);
if (device_map_.count(expr) == 0) {
device_map_.Set(expr, cur_dev_type);
}
}
Expr expr = GetRef<Expr>(it.first);
if (!it.second) device_map_.Set(expr, out_dev_type);
}
}
PostDfsOrderVisitor post_visitor_;
Map<Expr, Integer> device_map_;
};
......@@ -503,3 +507,4 @@ TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
} // namespace relay
} // namespace tvm
......@@ -231,7 +231,7 @@ def test_conv_network():
check_storage_and_device_types()
def test_fusible_network():
def run_fusible_network(dev, tgt):
R""" The network is as following:
x y
\ /
......@@ -417,15 +417,91 @@ def test_fusible_network():
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func)
test_fuse_log_add(dev, tgt)
test_fuse_all(dev, tgt)
test_fallback_exp(dev, tgt)
test_fallback_all_operators(dev, tgt)
def run_unpropagatable_graph(dev, tgt):
    R""" The network is as following:
            a     b  c     d
             \   /    \   /
              add      mul
                \      /
                subtract

    `add` and `subtract` are annotated for `dev`, `mul` for the CPU, so a
    device copy is expected between `mul` and `subtract`.

    Parameters
    ----------
    dev : str
        Name of the non-CPU device context to annotate with.
    tgt : str
        Build target used for `dev`.
    """
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    c = relay.var("c", shape=(10, 10))
    d = relay.var("d", shape=(10, 10))
    a_data = np.random.rand(10, 10).astype('float32')
    b_data = np.random.rand(10, 10).astype('float32')
    c_data = np.random.rand(10, 10).astype('float32')
    d_data = np.random.rand(10, 10).astype('float32')

    # NumPy reference result for (a + b) - (c * d).
    tmp_add = a_data + b_data
    tmp_mul = np.multiply(c_data, d_data)
    ref_res = np.subtract(tmp_add, tmp_mul)

    fallback_device = tvm.context("cpu")
    target = {"cpu": "llvm", dev: tgt}
    cpu_ctx = fallback_device
    dev_ctx = tvm.context(dev)

    def annotated():
        """Build the annotated graph and return the rewritten `subtract` path."""
        add = relay.add(a, b)
        _add = relay.annotation.on_device(add, dev_ctx)
        mul = relay.multiply(c, d)
        _mul = relay.annotation.on_device(mul, cpu_ctx)
        sub = relay.subtract(add, mul)
        _sub = relay.annotation.on_device(sub, dev_ctx)
        func = relay.Function([a, b, c, d],
                              relay.Tuple(tvm.convert([_add, _mul,
                                                       _sub, sub])))
        func = relay.ir_pass.infer_type(func)
        func = relay.ir_pass.rewrite_annotated_ops(func,
                                                   dev_ctx.device_type)
        func = relay.ir_pass.infer_type(func)
        # Only the last tuple field (the un-annotated `sub`) is kept.
        return relay.Function(relay.ir_pass.free_vars(func.body[3]),
                              func.body[3])

    def expected():
        """Build the graph with the explicit CPU -> `dev` copy after `mul`."""
        add = relay.add(a, b)
        mul = relay.multiply(c, d)
        copy_mul_sub = relay.device_copy(mul, cpu_ctx, dev_ctx)
        sub = relay.subtract(add, copy_mul_sub)
        func = relay.Function([a, b, c, d], sub)
        return func

    annotated_func = annotated()
    expected_func = expected()
    # Derive the expected device types from the contexts instead of
    # hard-coding 2 and 1: those literals only match CUDA and CPU, while this
    # function is also invoked with other devices such as OpenCL.
    dev_type = dev_ctx.device_type
    cpu_type = cpu_ctx.device_type
    expected_index = [dev_type] * 3 + [cpu_type] * 3 + [dev_type] * 2
    check_annotated_graph(annotated_func, expected_func)
    params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
    config = {"opt_level": 0}
    config["fallback_device"] = fallback_device
    with relay.build_config(**config):
        graph, lib, params = relay.build(annotated_func, target, params=params)
        contexts = [tvm.cpu(0), tvm.context(dev)]
        graph_json = json.loads(graph)
        # The device assignment is only checked when the graph records one.
        if "device_index" in graph_json["attrs"]:
            device_index = graph_json["attrs"]["device_index"][1]
            assert device_index == expected_index
        mod = graph_runtime.create(graph, lib, contexts)
        mod.set_input(**params)
        mod.run()
        res = mod.get_output(0).asnumpy()
        tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)
def test_check_run():
for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
("opencl", str(tvm.target.intel_graphics()))]:
if not tvm.module.enabled(dev):
print("Skip test because %s is not enabled." % dev)
continue
test_fuse_log_add(dev, tgt)
test_fuse_all(dev, tgt)
test_fallback_exp(dev, tgt)
test_fallback_all_operators(dev, tgt)
run_fusible_network(dev, tgt)
run_unpropagatable_graph(dev, tgt)
if __name__ == "__main__":
......@@ -433,4 +509,4 @@ if __name__ == "__main__":
test_annotate_all()
test_annotate_none()
test_conv_network()
test_fusible_network()
test_check_run()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment