# test_compiler_cache.py
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import nnvm.symbol as sym
import nnvm.compiler

def test_compile_cache():
    """Exercise lookups and assignment on ``nnvm.compiler.engine``.

    The test builds ``exp(y + x)`` for the "llvm" target, then:

    * asserts that ``engine[gkey]`` is not ``None`` for a key made from the
      graph and two placeholders, while ``engine[gkey2]`` (same graph, four
      placeholders) is ``None``;
    * builds the same symbol again, asserts the resulting graph has three
      nodes, and runs it against a NumPy reference;
    * clears the engine cache and assigns the previously fetched entry back
      under ``gkey``.

    NOTE(review): presumably the first ``build`` call is what populates the
    engine cache and the second one is served from it -- confirm against the
    ``nnvm.compiler`` implementation.
    """
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.exp(y + x)
    shape = (10, 1)
    dtype = tvm.float32
    shape_dict = {"x": shape, "y": shape}

    def verify(graph, lib):
        """Run ``graph``/``lib`` on CPU with random inputs and check the result."""
        m = graph_runtime.create(graph, lib, tvm.cpu(0))
        # Random inputs of the shape/dtype the graph was built for.
        na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
        nb = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
        m.run(x=na, y=nb)
        # Fetch output 0 into a freshly allocated array.
        out = m.get_output(0, tvm.nd.empty(shape, dtype))
        tvm.testing.assert_allclose(
            out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))

    engine = nnvm.compiler.engine
    graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
    # NOTE(review): the placeholders are (10,) while shape_dict uses (10, 1);
    # the assertion below relies on this key matching -- confirm why.
    inputs = [tvm.placeholder((10,)), tvm.placeholder((10,))]

    gkey = nnvm.compiler.graph_key(nnvm.graph.create(z), inputs, "llvm")
    # Same graph but twice as many inputs: expected to have no engine entry.
    gkey2 = nnvm.compiler.graph_key(nnvm.graph.create(z), inputs + inputs, "llvm")
    gf = engine[gkey]
    assert gf is not None
    assert engine[gkey2] is None
    graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
    assert graph.index.num_nodes == 3
    verify(graph, lib)
    # Test setting a cache entry externally after clearing the cache.
    engine.clear_cache()
    engine[gkey] = gf

# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_compile_cache()