# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
import json
from tvm import rpc
from tvm.contrib import util, graph_runtime


def test_graph_simple():
    """Exercise the graph runtime on a single-operator graph (``B = A + 1``).

    A hand-written graph JSON with one input node ``x`` and one ``tvm_op``
    node calling the compiled function ``myadd`` is executed three ways:

    * ``check_verify``  - locally on ``tvm.cpu(0)``;
    * ``check_remote``  - through an RPC server started on localhost;
    * ``check_sharing`` - a Relay-built ``x + y`` graph whose parameters are
      shared between several runtime modules.

    Every helper prints a message and returns early when the ``llvm``
    target is not enabled.
    """
    n = 4
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)

    # Graph description: node0 is the input, node1 applies "myadd" to it.
    node0 = {"op": "null", "name": "x", "inputs": []}
    node1 = {"op": "tvm_op", "name": "add",
             "inputs": [[0, 0, 0]],
             "attrs": {"func_name": "myadd",
                       "flatten_data": "1",
                       "num_inputs": "1",
                       "num_outputs": "1"}}
    nodes = [node0, node1]
    arg_nodes = [0]
    node_row_ptr = [0, 1, 2]
    outputs = [[1, 0, 0]]
    shape = (4,)
    attrs = {
        "shape": ["list_shape", [shape, shape]],
        "dltype": ["list_str", ["float32", "float32"]],
        "storage_id": ["list_int", [0, 1]],
    }
    graph = {"nodes": nodes,
             "arg_nodes": arg_nodes,
             "node_row_ptr": node_row_ptr,
             "heads": outputs,
             "attrs": attrs}
    graph = json.dumps(graph)

    def check_verify():
        """Build ``myadd`` for llvm and run the graph on the local CPU."""
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        mlib = tvm.build(s, [A, B], "llvm", name="myadd")
        mod = graph_runtime.create(graph, mlib, tvm.cpu(0))
        a = np.random.uniform(size=(n,)).astype(A.dtype)
        mod.run(x=a)
        out = mod.get_output(0, tvm.nd.empty((n,)))
        np.testing.assert_equal(out.asnumpy(), a + 1)

    def check_remote():
        """Export the library, load it over RPC and run the graph remotely."""
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        mlib = tvm.build(s, [A, B], "llvm", name="myadd")
        server = rpc.Server("localhost")
        remote = rpc.connect(server.host, server.port)
        temp = util.tempdir()
        ctx = remote.cpu(0)
        path_dso = temp.relpath("dev_lib.so")
        mlib.export_library(path_dso)
        remote.upload(path_dso)
        mlib = remote.load_module("dev_lib.so")
        mod = graph_runtime.create(graph, mlib, ctx)
        a = np.random.uniform(size=(n,)).astype(A.dtype)
        mod.run(x=tvm.nd.array(a, ctx))
        out = tvm.nd.empty((n,), ctx=ctx)
        out = mod.get_output(0, out)
        np.testing.assert_equal(out.asnumpy(), a + 1)

    def check_sharing():
        """Check that modules sharing parameters outlive the module owning them."""
        from tvm import relay

        # The guard has to come before relay.build: compiling for the "llvm"
        # target is exactly the step that needs llvm to be enabled.
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return

        x = relay.var('x', shape=(1, 10))
        y = relay.var('y', shape=(1, 10))
        z = relay.add(x, y)
        func = relay.Function([x, y], z)

        x_in = np.ones((1, 10)).astype("float32")
        params = {'x': x_in}
        graph, lib, params = relay.build(func, target="llvm", params=params)

        # Serialize the parameters once and reuse the bytes for every module.
        param_bytes = relay.save_param_dict(params)

        mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
        mod_shared.load_params(param_bytes)
        num_mods = 10
        mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
                for _ in range(num_mods)]

        for mod in mods:
            mod.share_params(mod_shared, param_bytes)

        a = np.random.uniform(size=(1, 10)).astype("float32")
        for mod in mods:
            mod.run(y=a)
            out = mod.get_output(0, tvm.nd.empty((1, 10)))
            np.testing.assert_equal(out.asnumpy(), x_in + a)

        # Explicitly delete the shared module and verify correctness.
        del mod_shared
        for mod in mods:
            mod.run(y=a)
            out = mod.get_output(0, tvm.nd.empty((1, 10)))
            np.testing.assert_equal(out.asnumpy(), x_in + a)
            del mod

    check_verify()
    check_remote()
    check_sharing()


if __name__ == "__main__":
    test_graph_simple()