"""End-to-end test of executing a compiled tvm_graph module over TVM RPC."""
import os

import numpy as np
import tvm
from tvm.contrib import util, rpc
import tvm_graph as tg


def test_rpc_executor():
    """Compile a small graph, run it through a local RPC server, check output.

    The graph computes ``exp(y + x) + exp(x + y)`` on two ``(10, 128)``
    inputs of ``tvm.float32`` filled with ones. The inputs are saved as a
    parameter file, the graph is compiled for ``llvm`` and the library is
    uploaded to the RPC server. The module is then loaded remotely with the
    parameter blob and run. Its first output is compared against the same
    expression evaluated with NumPy.
    """
    host = 'localhost'
    port = 9091
    server = rpc.Server(host, port)
    # Always shut the server down, even when compilation, the remote run or
    # the final assertion raises. Otherwise the server process leaks and
    # keeps the port bound for subsequent runs.
    try:
        tmp = util.tempdir()
        sym_fname = tmp.relpath('net.json')  # NOTE(review): currently unused
        lib_fname = tmp.relpath('net.o')
        param_fname = tmp.relpath('net.param')

        x = tg.Variable('x')
        y = tg.Variable('y')
        sym = tg.exp(y + x) + tg.exp(x + y)

        shape = (10, 128)
        dtype = tvm.float32
        na = tvm.nd.array(np.ones(shape).astype(dtype))
        nb = tvm.nd.array(np.ones(shape).astype(dtype))
        tg.save_params(param_fname, {'x': na, 'y': nb})

        remote = rpc.connect(host, port)
        ctx = remote.cpu(0)
        target = "llvm"
        shapes = {'x': shape, 'y': shape}

        sym_json = tg.compile_graph(lib_fname, sym, target, shapes)
        remote.upload(lib_fname)
        # Read the saved parameters with a context manager so the handle
        # is closed deterministically.
        with open(param_fname, "rb") as param_file:
            param_blob = bytearray(param_file.read())

        # The remote module is referenced by the uploaded file's base name.
        rm = tg.remote_load_exec(remote, sym_json,
                                 os.path.basename(lib_fname),
                                 param_blob, ctx)
        run, get_output = rm['run'], rm['get_output']

        nc = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
        run()
        get_output(0, nc)

        npa = na.asnumpy()
        npb = nb.asnumpy()
        np.testing.assert_allclose(nc.asnumpy(),
                                   np.exp(npa + npb) + np.exp(npb + npa))
    finally:
        server.terminate()


if __name__ == "__main__":
    test_rpc_executor()