"""Unit tests for nnvm graph creation, JSON serialization and graph passes."""
import json

import nnvm.symbol as sym
import nnvm.graph as graph
import nnvm.compiler.graph_util as graph_util


def test_json_pass():
    """Round-trip a graph through the SaveJSON/LoadJSON passes and graph.load_json."""
    x = sym.Variable('x')
    y = sym.dense(data=x, name='conv', units=30)
    g = graph.create(y)
    ret = g.apply('SaveJSON')
    # Re-attach the serialized text under the 'json' key before applying LoadJSON.
    ret._set_json_attr('json', ret.json_attr('json'))
    g2 = ret.apply('LoadJSON')
    assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
    # Named json_str (not json) so the local does not shadow the json module.
    json_str = g.json()
    g2 = graph.load_json(json_str)
    assert json_str == g2.json()


def test_json_pass_with_attr():
    """Check that a graph-level attribute survives a SaveJSON -> LoadJSON round trip."""
    x = sym.Variable('x')
    y = sym.dense(data=x, name='fc', units=30)
    g = graph.create(y)
    g._set_json_attr('version', '0.1.0')
    ret = g.apply('SaveJSON')
    json_str = ret.json_attr('json')
    # Re-attach the serialized text under the 'json' key before applying LoadJSON.
    ret._set_json_attr('json', json_str)
    g2 = ret.apply('LoadJSON')
    assert g2.json_attr('version') == '0.1.0'


def test_graph_json_attr():
    """Check that a typed ('list_int') graph attribute can be set and read back."""
    x = sym.Variable('x')
    y = sym.dense(data=x, name='fc', units=30)
    g = graph.create(y)
    g._set_json_attr('ilist', [1, 2, 3], 'list_int')
    assert g.json_attr('ilist') == [1, 2, 3]


def test_list_args():
    """Check the inputs listed for a dense layer combined with a free variable."""
    x = sym.Variable('x')
    z = sym.Variable('z')
    y = sym.dense(data=x, name='fc', units=30)
    y = sym.elemwise_add(y, z, name='add1')
    # Expects dense to add '<name>_weight' / '<name>_bias' inputs (conv2d adds
    # 'y_bias' in test_print_graph_ir). Also expects depth-first input order, so
    # the dense inputs precede 'z'. NOTE(review): ordering assumption, confirm.
    assert y.list_input_names() == ['x', 'fc_weight', 'fc_bias', 'z']


def test_infer_shape():
    """Check InferShape results when the input variable declares its shape."""
    x = sym.Variable('x', shape=(2, 4, 2))
    y = sym.elemwise_add(x, x, name='add1')
    y = sym.flatten(y, name="flatten")
    g = graph.create(y)
    g._set_json_attr("shape_attr_key", "shape")
    g = g.apply('InferShape')
    jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
    jnodes = jgraph['nodes']
    jnode_row_ptr = jgraph['node_row_ptr']
    nindex = {n['name']: i for i, n in enumerate(jnodes)}
    # The inferred 'shape' list is indexed via node_row_ptr[node index],
    # which locates the node's (first) output entry.
    shapes = g.json_attr('shape')
    assert shapes[jnode_row_ptr[nindex["flatten"]]] == [2, 8]
    assert shapes[jnode_row_ptr[nindex["add1"]]] == [2, 4, 2]


def test_infer_shape_known_partial():
    """Check InferShape when only the input's shape is pre-seeded in the 'shape' attribute."""
    x = sym.Variable('x')
    y = sym.elemwise_add(x, x, name='add1')
    y = sym.flatten(y, name="flatten1")
    g = graph.create(y)
    # Serialized before InferShape; used only to recover node names and row pointers.
    jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
    # Presumably one slot per output entry (x, add1, flatten1); empty lists
    # mark shapes left for InferShape to fill in.
    shape = [[2, 4, 2], [], []]
    g._set_json_attr("shape", shape, 'list_shape')
    g = g.apply("InferShape")
    jnodes = jgraph['nodes']
    jnode_row_ptr = jgraph['node_row_ptr']
    nindex = {n['name']: i for i, n in enumerate(jnodes)}
    shapes = g.json_attr('shape')
    assert shapes[jnode_row_ptr[nindex["flatten1"]]] == [2, 8]
    assert shapes[jnode_row_ptr[nindex["add1"]]] == [2, 4, 2]


def test_infer_type():
    """Check InferType propagates the input dtype and honours an explicit cast."""
    # dtype is an integer code. This test expects "float64" to map to 1;
    # 0 is presumably float32 (TODO confirm against nnvm's type codes).
    x = sym.Variable('x', dtype=0)
    y = sym.elemwise_add(x, x, name='add1')
    y = sym.cast(y, dtype="float64", name="cast1")
    g = graph.create(y)
    g._set_json_attr("dtype_attr_key", "dtype")
    g = g.apply('InferType')
    jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
    jnodes = jgraph['nodes']
    jnode_row_ptr = jgraph['node_row_ptr']
    nindex = {n['name']: i for i, n in enumerate(jnodes)}
    dtypes = g.json_attr('dtype')
    assert dtypes[jnode_row_ptr[nindex["cast1"]]] == 1
    assert dtypes[jnode_row_ptr[nindex["add1"]]] == 0


def test_plan_memory():
    """Check PlanMemory's storage-sharing decisions on a small chain of ops."""
    x = sym.Variable('x', shape=(4, 2))
    x2 = sym.elemwise_add(x, x, name='addk')
    y = sym.flatten(x2, name="reshapek")
    y = sym.elemwise_add(y, x2, name="add2")
    y = sym.elemwise_add(y, y)
    g = graph.create(y)
    g._set_json_attr("shape_attr_key", "shape")
    g = g.apply(["InferShape", "InferType", "PlanMemory"])
    jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
    jnodes = jgraph['nodes']
    jnode_row_ptr = jgraph['node_row_ptr']
    storage_id = g.json_attr('storage_id')
    nindex = {n['name']: i for i, n in enumerate(jnodes)}
    # 'addk' is still consumed by 'add2' after 'reshapek' runs, so the two
    # must not share storage.
    assert (storage_id[jnode_row_ptr[nindex["addk"]]] !=
            storage_id[jnode_row_ptr[nindex["reshapek"]]])
    # 'reshapek' has no consumer after 'add2', which is expected to reuse its storage.
    assert (storage_id[jnode_row_ptr[nindex["add2"]]] ==
            storage_id[jnode_row_ptr[nindex["reshapek"]]])


def test_print_graph_ir():
    """Check the textual IR dump, with and without joined per-entry shape attributes."""
    x = sym.Variable("x", shape=(1, 1, 10, 20))
    y = sym.conv2d(x + 1, name="y", channels=10, kernel_size=(3, 3))
    g = graph.create(y)
    g = g.apply("InferShape")
    ir1 = g.ir()
    ir2 = g.ir(join_entry_attrs=["shape"])
    # conv2d named "y" is expected to contribute a "y_bias" input to the IR.
    assert "y_bias" in ir1
    assert "shape=" in ir2


def test_gradient():
    """Check graph_util.gradients with implicit and explicit output gradients."""
    x = sym.Variable("x")
    y = sym.Variable("y")
    z1 = sym.elemwise_add(x, sym.sqrt(y))
    z2 = sym.log(x)
    gradient = graph_util.gradients([z1, z2], [x, y])
    # One gradient per requested input (x and y).
    assert len(gradient) == 2

    g1 = sym.Variable("g1")
    g2 = sym.Variable("g2")
    grad_ys = [g1, g2]
    gradient = graph_util.gradients(sym.Group([z1, z2]),
                                    sym.Group([x, y]), grad_ys=grad_ys)
    g_graph = graph.create(sym.Group(gradient)).ir()
    assert len(gradient) == 2
    # The explicitly supplied output gradients must appear in the gradient graph.
    assert "g1" in g_graph
    assert "g2" in g_graph


if __name__ == "__main__":
    # Allow running this file directly, without a test runner.
    test_print_graph_ir()
    test_json_pass_with_attr()
    test_graph_json_attr()
    test_json_pass()
    test_infer_shape()
    test_infer_shape_known_partial()
    test_infer_type()
    test_plan_memory()
    test_list_args()
    test_gradient()