"""Tests for the nnvm ``sym._assign`` operator (test_top_assign.py)."""
import numpy as np

import tvm
import tvm.testing
from tvm.contrib import graph_runtime

import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list


def test_update():
    """Check that ``sym._assign`` writes its result back into its target variable.

    The graph assigns ``w <- w + 1`` and then ``w2 <- w + 1``, where the second
    ``w`` is the already-assigned value. Starting from ``w = data``:

    * after the first run, input ``w`` holds ``data + 1`` and ``w2`` holds
      ``data + 2``;
    * the assignments persist in the module's input buffers, so a second run
      advances them to ``data + 2`` and ``data + 3``.

    The check is repeated for every (target, ctx) pair from ``ctx_list()``.
    """
    w = sym.Variable("w")
    w2 = sym.Variable("w2")
    w = sym._assign(w, w + 1)
    # `w` here is the output of the first _assign, so w2 becomes (w + 1) + 1.
    w2 = sym._assign(w2, w + 1)

    dshape = (5, 3, 18, 18)
    shape_dict = {"w": dshape, "w2": dshape}
    dtype = "float32"

    def check(target, ctx):
        graph, lib, _ = nnvm.compiler.build(w2, target, shape_dict)
        m = graph_runtime.create(graph, lib, ctx)

        def read_input(name):
            # Read the current contents of a named input buffer into a freshly
            # allocated array and return it as numpy.
            return m.get_input(name, tvm.nd.empty(dshape, dtype)).asnumpy()

        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        expected = data.asnumpy()
        m.set_input("w", data)

        # Each run adds 1 to w in place and rewrites w2 as (updated w) + 1,
        # so after run number `step`: w == data + step, w2 == data + step + 1.
        for step in (1, 2):
            m.run()
            tvm.testing.assert_allclose(read_input("w"), expected + step, rtol=1e-5)
            tvm.testing.assert_allclose(read_input("w2"), expected + step + 1, rtol=1e-5)

    for target, ctx in ctx_list():
        check(target, ctx)


if __name__ == "__main__":
    # Run the test directly when this file is executed as a script.
    test_update()