Commit 3294d72b authored by Zhi, committed by Yizhi Liu

[Relay][heterogeneous] Fix tuple annotation (#3311)

* [Relay][heterogeneous] Fix TupleGetItem

* retrigger ci

* retrigger ci
parent 3770368f
@@ -68,6 +68,7 @@ class ValidateAnnotation : private ExprVisitor {
 private:
  void VisitExpr_(const CallNode* call_node) final {
    ExprVisitor::VisitExpr_(call_node);
    if (IsOnDeviceNode(call_node)) {
      int device_type = GetDeviceId(call_node);
      if (annotation_map_.count(call_node)) {
@@ -86,7 +87,14 @@ class ValidateAnnotation : private ExprVisitor {
        annotation_map_.insert({node, GetDeviceId(call_node)});
      }
    }
    ExprVisitor::VisitExpr_(call_node);
  }

  void VisitExpr_(const TupleGetItemNode* get_elem) final {
    ExprVisitor::VisitExpr_(get_elem);
    const auto* tn = get_elem->tuple.operator->();
    if (annotation_map_.count(tn)) {
      annotation_map_.insert({get_elem, annotation_map_.at(tn)});
    }
  }

  /*
@@ -253,7 +261,9 @@ class RewriteAnnotation : public ExprMutator {
    if (src->is_type<CallNode>() || src->is_type<FunctionNode>()) {
      return annotation_map_.at(dst) != fallback_device_;
    } else {
      return false;
      // There shouldn't be any copy nodes between var/constant and another
      // expression.
      return !(src->is_type<VarNode>() || src->is_type<ConstantNode>());
    }
  } else {
    return false;
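The new VisitExpr_(const TupleGetItemNode*) handler above copies the device annotation recorded for the source tuple onto each TupleGetItem that reads from it, so the later rewrite can insert a device_copy per element. A rough standalone sketch of that bookkeeping, as a toy Python model with made-up node names (not TVM code):

# Toy model of the annotation_map_ bookkeeping; names are hypothetical.
annotation_map = {}

def visit_on_device(node, device_id):
    # CallNode visitor: remember the device id attached via on_device.
    annotation_map[node] = device_id

def visit_tuple_get_item(get_elem, tuple_node):
    # New behaviour: a TupleGetItem inherits the annotation of the tuple
    # it reads from, if that tuple already has one.
    if tuple_node in annotation_map:
        annotation_map[get_elem] = annotation_map[tuple_node]

visit_on_device("split", device_id=2)       # e.g. 2 standing in for the GPU
visit_tuple_get_item("split.0", "split")
visit_tuple_get_item("split.1", "split")
assert annotation_map["split.0"] == annotation_map["split.1"] == 2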
@@ -554,6 +554,7 @@ def run_unpropagatable_graph(dev, tgt):
    res = mod.get_output(0).asnumpy()
    tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)

def test_check_run():
    for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
                     ("opencl", str(tvm.target.intel_graphics()))]:
@@ -563,7 +564,41 @@ def test_check_run():
        run_fusible_network(dev, tgt)
        run_unpropagatable_graph(dev, tgt)

def test_tuple_get_item():
    dev = "cuda"
    if not tvm.module.enabled(dev):
        print("Skip test because %s is not enabled." % dev)
        return

    cpu_ctx = tvm.cpu(0)
    gpu_ctx = tvm.context(dev)

    def expected():
        x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32"))
        split = relay.op.split(x, 3)
        elem0 = relay.device_copy(split[0], gpu_ctx, cpu_ctx)
        elem1 = relay.device_copy(split[1], gpu_ctx, cpu_ctx)
        sub = elem0 - elem1
        func = relay.Function(relay.ir_pass.free_vars(sub), sub)
        return func

    def annotated():
        x = relay.var("x", relay.ty.TensorType((3, 3, 4), "float32"))
        split = relay.op.split(x, 3)
        split = split.astuple()
        split = relay.annotation.on_device(split, gpu_ctx)
        split = relay.TupleWrapper(split, 3)
        sub = split[0] - split[1]
        func = relay.Function(relay.ir_pass.free_vars(sub), sub)
        func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type)
        return func

    annotated_func = relay.ir_pass.infer_type(annotated())
    expected_func = relay.ir_pass.infer_type(expected())
    assert relay.ir_pass.graph_equal(annotated_func, expected_func)

if __name__ == "__main__":
    test_redundant_annotation()
    test_annotate_expr()
@@ -571,3 +606,4 @@ if __name__ == "__main__":
    test_annotate_none()
    test_conv_network()
    test_check_run()
    test_tuple_get_item()