"""Unit tests for heterogeneous compilation and execution.""" import json import numpy as np import tvm from tvm import relay from tvm.contrib import graph_runtime def test_redundant_annotation(): ctx1 = tvm.context(1) ctx2 = tvm.context(2) x = relay.var("x", shape=(3,)) y = relay.var("y", shape=(3,)) z = relay.var("z", shape=(3,)) def annotated(): add = relay.add(x, y) _add1 = relay.annotation.on_device(add, ctx2) _add2 = relay.annotation.on_device(add, ctx2) sub = relay.subtract(add, z) func = relay.Function([x, y, z], relay.Tuple(tvm.convert([_add1, _add2, sub]))) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) func = relay.ir_pass.infer_type(func) return relay.Function(relay.ir_pass.free_vars(func.body[2]), func.body[2]) def expected(): add = relay.add(x, y) copy_add_sub = relay.device_copy(add, ctx2, ctx1) sub = relay.subtract(copy_add_sub, z) func = relay.Function([x, y, z], sub) return func annotated_func = relay.ir_pass.infer_type(annotated()) expected_func = relay.ir_pass.infer_type(expected()) assert relay.ir_pass.alpha_equal(annotated_func, expected_func) def test_annotate_all(): ctx1 = tvm.context(1) ctx2 = tvm.context(2) x = relay.var("x", shape=(3,)) y = relay.var("y", shape=(3,)) z = relay.var("z", shape=(3,)) def annotated(): add = relay.add(x, y) _add = relay.annotation.on_device(add, ctx2) sub = relay.subtract(add, z) _sub = relay.annotation.on_device(sub, ctx2) func = relay.Function([x, y, z], relay.Tuple(tvm.convert([_add, _sub, sub]))) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) func = relay.ir_pass.infer_type(func) return relay.Function(relay.ir_pass.free_vars(func.body[2]), func.body[2]) def expected(): add = relay.add(x, y) sub = relay.subtract(add, z) func = relay.Function([x, y, z], sub) return func annotated_func = relay.ir_pass.infer_type(annotated()) expected_func = relay.ir_pass.infer_type(expected()) assert relay.ir_pass.alpha_equal(annotated_func, expected_func) def test_annotate_none(): ctx1 = tvm.context(1) ctx2 = tvm.context(2) x = relay.var("x", shape=(3,)) y = relay.var("y", shape=(3,)) z = relay.var("z", shape=(3,)) def annotated(): add = relay.add(x, y) sub = relay.subtract(add, z) func = relay.Function([x, y, z], sub) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) return func def expected(): add = relay.add(x, y) sub = relay.subtract(add, z) func = relay.Function([x, y, z], sub) return func annotated_func = relay.ir_pass.infer_type(annotated()) expected_func = relay.ir_pass.infer_type(expected()) assert relay.ir_pass.alpha_equal(annotated_func, expected_func) def check_annotated_graph(annotated_func, expected_func): annotated_func = relay.ir_pass.infer_type(annotated_func) expected_func = relay.ir_pass.infer_type(expected_func) assert relay.ir_pass.alpha_equal(annotated_func, expected_func) def test_conv_network(): R""" The network is as following: data1 data2 | | conv2d conv2d \ / add | conv2d """ batch_size = 1 dshape = (batch_size, 64, 56, 56) weight = relay.var("weight", shape=(64, 64, 3, 3)) data1 = relay.var("data1", shape=dshape) data2 = relay.var("data2", shape=dshape) dev1 = tvm.context(1) dev2 = tvm.context(2) def annotated(): conv2d_1 = relay.nn.conv2d( data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) _conv2d_1 = relay.annotation.on_device(conv2d_1, dev2) conv2d_2 = relay.nn.conv2d( data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) _conv2d_2 = relay.annotation.on_device(conv2d_2, dev2) add = relay.add(conv2d_1, conv2d_2) _add = relay.annotation.on_device(add, dev1) conv2d_3 = relay.nn.conv2d( add, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) _conv2d_3 = relay.annotation.on_device(conv2d_3, dev2) func = relay.Function([data1, data2, weight], relay.Tuple(tvm.convert([_conv2d_1, _conv2d_2, _conv2d_3, _add, conv2d_3]))) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, tvm.context(3).device_type) func = relay.ir_pass.infer_type(func) return relay.Function(relay.ir_pass.free_vars(func.body[4]), func.body[4]) def expected(): conv2d_1 = relay.nn.conv2d( data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) device_copy1 = relay.device_copy(conv2d_1, dev2, dev1) conv2d_2 = relay.nn.conv2d( data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) device_copy2 = relay.device_copy(conv2d_2, dev2, dev1) add = relay.add(device_copy1, device_copy2) device_copy3 = relay.device_copy(add, dev1, dev2) conv2d_3 = relay.nn.conv2d( device_copy3, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) func = relay.Function([data1, weight, data2], conv2d_3) return func def check_storage_and_device_types(): func = annotated() func = relay.ir_pass.rewrite_annotated_ops(func, 3) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.fuse_ops(func, opt_level=2) func = relay.ir_pass.infer_type(func) smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = [] device_types = [] for _, storage_dev_type in smap.items(): assert len(storage_dev_type) == 2 for sid in storage_dev_type[0]: storage_ids.append(sid.value) for did in storage_dev_type[1]: device_types.append(did.value) assert len(storage_ids) == 10 assert len(set(storage_ids)) == 8 assert len(set(device_types)) == 2 assert set(device_types) == {1, 2} annotated_func = annotated() expected_func = expected() check_annotated_graph(annotated_func, expected_func) check_storage_and_device_types() def test_fusible_network(): R""" The network is as following: x y \ / add / \ sqrt log \ / subtract | exp """ x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(10, 10)) x_data = np.random.rand(1, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32') tmp_add = x_data + y_data tmp_sqrt = np.sqrt(tmp_add) tmp_log = np.log(tmp_add) tmp_sub = np.subtract(tmp_sqrt, tmp_log) ref_res = np.exp(tmp_sub) def get_func(): add = relay.add(x, y) sqrt = relay.sqrt(add) log = relay.log(add) subtract = relay.subtract(sqrt, log) exp = relay.exp(subtract) func = relay.Function([x, y], exp) return func def test_runtime(target, device, func, fallback_device=None, expected_index=None): params = {"x": x_data, "y": y_data} config = {"opt_level": 1} if fallback_device: config["fallback_device"] = fallback_device with relay.build_config(**config): graph, lib, params = relay.build( func, target, params=params) contexts = [tvm.cpu(0), tvm.context(device)] graph_json = json.loads(graph) if "device_index" in graph_json["attrs"]: device_index = graph_json["attrs"]["device_index"][1] assert device_index == expected_index mod = graph_runtime.create(graph, lib, contexts) mod.set_input(**params) mod.run() res = mod.get_output(0).asnumpy() tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) def test_fuse_log_add(device, tgt): """ Only log and add are fused.""" fallback_device = tvm.context("cpu") target = {"cpu": "llvm", device: tgt} cpu_ctx = fallback_device dev_ctx = tvm.context(device) def annotated(): add = relay.add(x, y) sqrt = relay.sqrt(add) _sqrt = relay.annotation.on_device(sqrt, dev_ctx) log = relay.log(add) subtract = relay.subtract(sqrt, log) exp = relay.exp(subtract) _exp = relay.annotation.on_device(exp, dev_ctx) func = relay.Function([x, y], relay.Tuple(tvm.convert([_sqrt, _exp, exp]))) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) func = relay.ir_pass.infer_type(func) return relay.Function(relay.ir_pass.free_vars(func.body[2]), func.body[2]) def expected(): add = relay.add(x, y) copy_add_sqrt = relay.device_copy(add, cpu_ctx, dev_ctx) sqrt = relay.sqrt(copy_add_sqrt) log = relay.log(add) copy_sqrt_subtract = relay.device_copy(sqrt, dev_ctx, cpu_ctx) subtract = relay.subtract(copy_sqrt_subtract, log) copy_sub_exp = relay.device_copy(subtract, cpu_ctx, dev_ctx) exp = relay.exp(copy_sub_exp) func = relay.Function([x, y], exp) return func annotated_func = annotated() expected_func = expected() expected_index = [1, 1, 1, 2, 2, 1, 1, 2, 2] check_annotated_graph(annotated_func, expected_func) test_runtime(target, device, annotated_func, fallback_device, expected_index) def test_fuse_all(device, tgt): """Fuse all operators.""" fallback_device = tvm.context("cpu") target = {"cpu": "llvm", device: tgt} cpu_ctx = fallback_device dev_ctx = tvm.context(device) def annotated(): add = relay.add(x, y) _add = relay.annotation.on_device(add, dev_ctx) sqrt = relay.sqrt(add) _sqrt = relay.annotation.on_device(sqrt, dev_ctx) log = relay.log(add) _log = relay.annotation.on_device(log, dev_ctx) subtract = relay.subtract(sqrt, log) _subtract = relay.annotation.on_device(subtract, dev_ctx) exp = relay.exp(subtract) _exp = relay.annotation.on_device(exp, dev_ctx) func = relay.Function([x, y], relay.Tuple(tvm.convert([_add, _sqrt, _log, _subtract, _exp, exp]))) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) func = relay.ir_pass.infer_type(func) return relay.Function(relay.ir_pass.free_vars(func.body[5]), func.body[5]) annotated_func = annotated() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) test_runtime(target, device, annotated_func, fallback_device) def test_fallback_exp(device, tgt): fallback_device = tvm.context("cpu") target = {"cpu": "llvm", device: tgt} cpu_ctx = fallback_device dev_ctx = tvm.context(device) def annotated(): add = relay.add(x, y) sqrt = relay.sqrt(add) log = relay.log(add) subtract = relay.subtract(sqrt, log) exp = relay.exp(subtract) _exp = relay.annotation.on_device(exp, cpu_ctx) func = relay.Function([x, y], relay.Tuple(tvm.convert([_exp, exp]))) func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, dev_ctx.device_type) func = relay.ir_pass.infer_type(func) return relay.Function(relay.ir_pass.free_vars(func.body[1]), func.body[1]) def expected(): add = relay.add(x, y) sqrt = relay.sqrt(add) log = relay.log(add) subtract = relay.subtract(sqrt, log) copy_sub_exp = relay.device_copy(subtract, dev_ctx, cpu_ctx) exp = relay.exp(copy_sub_exp) func = relay.Function([x, y], exp) return func annotated_func = annotated() expected_func = expected() expected_index = [2, 2, 2, 1, 1] check_annotated_graph(annotated_func, expected_func) test_runtime(target, device, annotated_func, fallback_device, expected_index) def test_fallback_all_operators(device, tgt): target = {device: tgt} annotated_func = get_func() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) test_runtime(target, device, annotated_func) for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"), ("opencl", str(tvm.target.intel_graphics()))]: if not tvm.module.enabled(dev): print("Skip test because %s is not enabled." % dev) continue test_fuse_log_add(dev, tgt) test_fuse_all(dev, tgt) test_fallback_exp(dev, tgt) test_fallback_all_operators(dev, tgt) if __name__ == "__main__": test_redundant_annotation() test_annotate_all() test_annotate_none() test_conv_network() test_fusible_network()