# test_graph_gradient.py
# Tests for NNVM gradient-graph construction and shape/dtype inference.
import nnvm.symbol as sym
from nnvm.compiler import graph_util

def test_cnn_gradients():
    """Check the gradient graph of a small conv -> pool -> dense classifier.

    The network is conv2d -> relu -> max_pool2d -> flatten -> dense -> relu
    -> dense, with a cross-entropy style loss built from log_softmax. The
    test verifies that:

    * the inferred input shapes of the gradient graph match the declared
      and expected parameter shapes,
    * every gradient output has the same shape as its input variable,
    * ``graph_util.gradients`` yields one gradient per input variable,
    * every gradient output is inferred as ``float32``.
    """
    # Input batch: 1000 three-channel images of 128x128.
    height = width = 128
    batch_shape = (1000, 3, height, width)
    # NOTE(review): dtype=0 is presumably the float32 type code (the dtype
    # assertion at the end expects float32 everywhere) -- confirm.
    images = sym.Variable('data', shape=batch_shape, dtype=0)

    # Convolution parameters are declared explicitly with known shapes.
    out_channels = 64
    ksize = 32
    kernel_shape = (out_channels, 3, ksize, ksize)
    bias_shape = (out_channels,)
    kernel = sym.Variable('conv_w', shape=kernel_shape)
    bias = sym.Variable('conv_b', shape=bias_shape)

    # Feature extractor: conv2d -> relu -> 2x2 max pooling -> flatten.
    conv_out = sym.conv2d(
        data=images,
        weight=kernel,
        bias=bias,
        channels=out_channels,
        kernel_size=(ksize, ksize),
        name='conv1')
    conv_act = sym.relu(data=conv_out, name='relu1')
    pooled = sym.max_pool2d(data=conv_act, pool_size=(2, 2), name='max_pooling1')
    flat = sym.flatten(data=pooled)

    # Number of features per sample after flatten; this is the second
    # dimension expected for dense1's weight below.
    # NOTE(review): presumably (height - ksize + 1) after the un-padded conv
    # and one less after the 2x2 pool at stride 1 -- confirm the operator
    # defaults before changing the sizes above.
    flat_features = (height - ksize) * (width - ksize) * out_channels

    # Classifier head: dense(100) -> relu -> dense(10).
    hidden_units = 100
    hidden = sym.dense(data=flat, name='dense1', units=hidden_units)
    hidden_act = sym.relu(data=hidden, name='relu2')
    num_classes = 10
    logits = sym.dense(data=hidden_act, name='dense2', units=num_classes)

    # The softmax output is only used as a template for the dense all-ones
    # "label" tensor.
    probs = sym.softmax(data=logits, name='softmax')
    fake_label = sym.full_like(probs, fill_value=1)

    # Cross-entropy style loss: per-sample sum of label * log_softmax(logits).
    log_probs = sym.log_softmax(logits)
    weighted = sym.elemwise_mul(log_probs, fake_label)
    ce_loss = sym.sum(weighted, axis=1, keepdims=True, name="ce_loss")

    # Per the original notes, the gradient graph's inputs are
    #   ['data', 'conv_w', 'conv_b',
    #    'dense1_weight', 'dense1_bias',
    #    'dense2_weight', 'dense2_bias']
    # and its outputs (one gradient per input) are
    #   ['conv1_grad_data', 'conv1_grad_weight', 'conv1_grad_bias',
    #    'dense1_grad_weight', 'dense1_grad_bias',
    #    'dense2_grad_weight', 'dense2_grad_bias']
    wrt = ce_loss.list_input_variables()
    grad_g = graph_util.get_gradient_graph(ce_loss, wrt)

    # Shape inference over the gradient graph.
    in_shapes, out_shapes = graph_util.infer_shape(grad_g)

    # Forward inputs: declared shapes plus the inferred dense parameters.
    expected_in_shapes = [
        list(batch_shape),
        list(kernel_shape),
        list(bias_shape),
        [hidden_units, flat_features],
        [hidden_units],
        [num_classes, hidden_units],
        [num_classes],
    ]
    assert in_shapes == expected_in_shapes
    # Each gradient must have the same shape as the variable it belongs to.
    assert in_shapes == out_shapes

    # gradients() must return exactly one symbol per input variable.
    grads = graph_util.gradients(ce_loss, wrt)
    assert len(grads) == len(wrt)

    # Type inference: all seven gradient outputs are float32.
    _, out_dtypes = graph_util.infer_dtype(grad_g)
    assert out_dtypes == ['float32'] * 7

def test_multi_loss_graph_gradients():
    """Check gradients of a graph with several losses, one of them blocked.

    Two outputs are differentiated together: ``total_loss`` (a squared-error
    term plus a gradient-blocked ``loss1``) and ``loss2`` (a relu of
    ``data1``). The test checks the gradient output shapes and dtypes, that
    the gradient for ``data1`` is a two-argument sum, that a ``zeros_like``
    node appears in the graph, and that shape/type inference leaves no
    unknown nodes even though the label is created without shape or dtype.
    """
    # First input: a (1000, 100) matrix.
    shape1 = (1000, 100)
    data1 = sym.Variable('data1', shape=shape1, dtype=0)

    # Dense constant "label"; created without a shape on purpose so that
    # inference has to work it out from its consumer.
    label = sym.full(fill_value=3)

    # Squared-error term: sum over axis 1 of (data1 - label) ** 2.
    diff = sym.elemwise_sub(data1, label, name="sub1")
    squared = diff ** 2
    square_loss = sym.sum(data=squared, axis=1, name="square_loss")

    # Second input and a placeholder loss built from it.
    shape2 = (1000, )
    data2 = sym.Variable('data2', shape=shape2, dtype=0)
    loss1 = sym.sqrt(data2, name="loss1")

    # Another placeholder loss that reuses data1.
    loss2 = sym.relu(data1, name='loss2')

    # total_loss adds square_loss to loss1 with loss1's gradient blocked.
    blocked_loss1 = sym.block_grad(loss1)
    total_loss = sym.elemwise_sum(
        blocked_loss1,
        square_loss,
        num_args=2,
        name="total_loss")

    # Per the original notes, grad_g.symbol.list_output_names() gives
    #   ['loss1_grad_0_output', 'grad_sum_output']
    heads = [total_loss, loss2]
    grad_g = graph_util.get_gradient_graph(heads, total_loss.list_input_variables())

    # Gradient outputs come back in the order (data2, data1).
    _, out_shapes = graph_util.infer_shape(grad_g)
    assert out_shapes == [list(shape2), list(shape1)]

    # data1 feeds both loss2 and square_loss, so its gradient is an
    # elemwise_sum of two contributions.
    grad_data1 = grad_g.symbol[1]
    assert grad_data1.list_attr()['num_args'] == '2'

    # The blocked branch should contribute a zero gradient. The output is
    # fetched but not inspected directly; the check looks for a zeros_like
    # node anywhere in the graph's IR dump.
    grad_data2 = grad_g.symbol[0]
    assert 'zeros_like' in grad_g.ir()

    # Reverse shape inference must resolve every node, including the label.
    shape_pass = grad_g.apply('InferShape')
    assert shape_pass.json_attr('shape_num_unknown_nodes') == 0

    # Both gradient outputs are float32.
    _, out_dtypes = graph_util.infer_dtype(grad_g)
    assert out_dtypes == ['float32', 'float32']

    # Reverse type inference must likewise resolve every node.
    type_pass = grad_g.apply('InferType')
    assert type_pass.json_attr('dtype_num_unknown_nodes') == 0


if __name__ == "__main__":
    # Run every test in this module, in order, when executed as a script.
    for _case in (test_cnn_gradients, test_multi_loss_graph_gradients):
        _case()