test_forward.py 28.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23
# pylint: disable=import-self, invalid-name, unused-argument
"""
TFLite testcases
================
This article is a test script to test TFLite operator with Relay.
"""
from __future__ import print_function
24
from functools import partial
25 26 27 28 29
import numpy as np
import tvm
from tvm import relay
import tensorflow as tf
from tensorflow.python.framework import constant_op
30 31
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
32 33 34
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
35 36 37 38
try:
    from tensorflow import lite as interpreter_wrapper
except ImportError:
    from tensorflow.contrib import lite as interpreter_wrapper
39

40
import tvm.relay.testing.tf as tf_testing
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69

#######################################################################
# Generic run functions for TVM & TFLite
# --------------------------------------
def convert_to_list(x):
    """ Return ``x`` itself when it is a list, otherwise wrap it in a one-element list """
    return x if isinstance(x, list) else [x]

def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                  out_names=None):
    """ Generic function to compile on relay and execute on tvm

    Parameters
    ----------
    tflite_model_buf :
        Serialized TFLite model, parsed with ``tflite.Model.Model.GetRootAsModel``.
    input_data :
        One input array or a list of input arrays; each must expose ``shape`` and ``dtype``.
    input_node :
        One input name or a list of input names, matched to ``input_data`` by position.
    num_output : int
        Number of outputs to read back from the graph runtime.
    target : str
        TVM build target; also used to create the execution context.
    out_names : list, optional
        Only used for a consistency check: when given, its length must equal ``num_output``.

    Returns
    -------
    list
        The first ``num_output`` outputs, each converted with ``asnumpy()``.

    Raises
    ------
    ImportError
        If the ``tflite`` package is not installed.
    """
    # Check the arguments up front so a mismatch fails before the costly build.
    assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
        out_names, num_output)

    try:
        import tflite.Model
    except ImportError:
        raise ImportError("The tflite package must be installed")

    # get TFLite model from buffer
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

    input_data = convert_to_list(input_data)
    input_node = convert_to_list(input_node)

    shape_dict = {}
    dtype_dict = {}
    for i, e in enumerate(input_node):
        shape_dict[e] = input_data[i].shape
        dtype_dict[e] = input_data[i].dtype.name

    mod, params = relay.frontend.from_tflite(tflite_model,
                                             shape_dict=shape_dict,
                                             dtype_dict=dtype_dict)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target, params=params)

    ctx = tvm.context(target, 0)
    from tvm.contrib import graph_runtime
    m = graph_runtime.create(graph, lib, ctx)
    # set inputs (the former ``astype`` to the array's own dtype was a no-op copy)
    for i, e in enumerate(input_node):
        m.set_input(e, tvm.nd.array(input_data[i]))

    m.set_input(**params)
    # execute
    m.run()
    # get outputs
    return [m.get_output(i).asnumpy() for i in range(num_output)]


def run_tflite_graph(tflite_model_buf, input_data):
    """ Run a TFLite model buffer through the TFLite interpreter and return all of its outputs

    ``input_data`` may be a single array or a list of arrays; the number of arrays
    must match the number of model inputs and they are assigned by position.
    """
    feeds = convert_to_list(input_data)

    interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # feed every model input with the array at the same position
    assert len(feeds) == len(input_details)
    for position, detail in enumerate(input_details):
        interpreter.set_tensor(detail['index'], feeds[position])

    interpreter.invoke()

    # collect the outputs in the order the interpreter reports them
    return [interpreter.get_tensor(detail['index']) for detail in output_details]


122 123
def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                            output_tensors, init_global_variables=False):
    """Generic function to generate and compare TFLite and TVM output

    Parameters
    ----------
    in_data :
        One input array or a list of input arrays fed to both runtimes.
    in_name :
        One tensor name or a list of tensor names (e.g. ``'in:0'``), matched to
        ``in_data`` by position.
    input_tensors, output_tensors : list
        TF tensors handed to ``TFLiteConverter.from_session``.
    init_global_variables : bool
        When true, run the global variables initializer in the session before converting.

    Raises
    ------
    AssertionError
        If any TVM output differs from the TFLite output beyond atol=rtol=1e-5.
    """
    in_data = convert_to_list(in_data)
    in_name = convert_to_list(in_name)
    # Node name is the tensor name without its ":<index>" suffix; split() alone
    # already leaves names without a colon untouched.
    in_node = [name.split(':')[0] for name in in_name]

    with tf.Session() as sess:
        if init_global_variables:
            sess.run(variables.global_variables_initializer())
        # convert to tflite model
        converter = interpreter_wrapper.TFLiteConverter.from_session(
            sess, input_tensors, output_tensors)
        tflite_model_buffer = converter.convert()
        tflite_output = run_tflite_graph(tflite_model_buffer, in_data)

        for device in ["llvm"]:
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                continue

            # Ask TVM for as many outputs as TFLite produced; the default of one
            # output would make the comparison below fail on multi-output graphs.
            tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
                                       num_output=len(tflite_output))
            for expected, actual in zip(tflite_output, tvm_output):
                tvm.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5)
149 150


151 152 153 154 155 156 157 158 159 160 161 162 163 164
def with_fused_activation_function(input_tensor, fn_name):
    """ Apply the activation selected by ``fn_name`` to ``input_tensor``

    ``None`` and "NONE" return the tensor unchanged; "RELU", "RELU6",
    "RELU_N1_TO_1" and "TANH" apply the matching op.  Any other name raises
    AssertionError.
    """
    if fn_name is None or fn_name == "NONE":
        return input_tensor

    # (name, builder) pairs, checked in order with ``==``
    activations = (
        ("RELU", lambda t: nn_ops.relu(t)),
        ("RELU6", lambda t: nn_ops.relu6(t)),
        # clamp into [-1, 1]
        ("RELU_N1_TO_1", lambda t: math_ops.maximum(-1, math_ops.minimum(t, 1))),
        ("TANH", lambda t: math_ops.tanh(t)),
    )
    for name, build in activations:
        if fn_name == name:
            return build(input_tensor)

    raise AssertionError("Unknown fused_activation_function {}".format(fn_name))


165 166 167 168 169 170 171 172 173 174 175 176 177
#######################################################################
# Pooling
# -------
def _test_pooling_iteration(input_shape, **kwargs):
    """ One iteration of pool operation with given shapes and attributes

    ``kwargs`` are forwarded unchanged to ``nn_ops.pool``.
    """

    # Strictly negative ramp: -1, -2, ..., -prod(input_shape)
    x = -np.arange(
        np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=input_shape, dtype='float32')
        out = nn_ops.pool(in_data, **kwargs)

        # the placeholder is unnamed, so it is addressed by the name 'Placeholder:0'
        compare_tflite_with_tvm(x, 'Placeholder:0', [in_data], [out])
179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255


def _test_pooling(input_shape, **kwargs):
    """ Run a single pooling iteration, forwarding every keyword attribute """
    _test_pooling_iteration(input_shape=input_shape, **kwargs)


def test_forward_pooling():
    """ Pooling """

    # (input_shape, window_shape, strides); padding is always 'SAME' and the
    # dilation rate always [1, 1]
    cases = (
        ([2, 9, 10, 2], [1, 1], [1, 1]),
        ([2, 10, 9, 2], [1, 1], [1, 1]),
        ([2, 9, 10, 2], [2, 1], [1, 1]),
        ([2, 10, 9, 2], [2, 3], [2, 1]),
    )

    for pool_type in ['AVG', 'MAX']:
        for input_shape, window_shape, strides in cases:
            # hand out fresh lists on every call, as literal arguments would
            _test_pooling(input_shape=list(input_shape),
                          window_shape=list(window_shape),
                          padding='SAME',
                          pooling_type=pool_type,
                          dilation_rate=[1, 1],
                          strides=list(strides))


#######################################################################
# Convolution
# -----------

def _test_convolution(tensor_in_sizes, filter_in_sizes,
                      dilations, strides, padding, data_format,
                      is_depthwise=False):
    """ One iteration of convolution with given shapes and attributes

    Parameters
    ----------
    tensor_in_sizes, filter_in_sizes : list of int
        Shapes of the input tensor and of the constant filter.
    dilations, strides : list of int
        Two-element lists; each is wrapped with a leading and trailing 1
        before being passed to the TF op.
    padding, data_format : str
        Forwarded to the TF convolution op.
    is_depthwise : bool
        Use ``depthwise_conv2d_native`` instead of ``conv2d``.
    """

    total_size_1 = int(np.prod(tensor_in_sizes))
    total_size_2 = int(np.prod(filter_in_sizes))
    # Initializes the input tensor and the filter with incrementing
    # numbers from 1.
    data_array = np.arange(1, total_size_1 + 1, dtype='float32').reshape(tensor_in_sizes)
    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
        strides = [1] + strides + [1]
        dilations = [1] + dilations + [1]

        # dilations used to be computed but never handed to the op, so a
        # non-unit dilation request was silently ignored.
        conv_op = nn_ops.depthwise_conv2d_native if is_depthwise else nn_ops.conv2d
        out = conv_op(in_data,
                      in_filter,
                      strides=strides,
                      padding=padding,
                      data_format=data_format,
                      dilations=dilations)
        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282


def test_forward_convolution():
    """ Convolution, regular and depthwise, on NHWC inputs """
    # (tensor_in_sizes, filter_in_sizes, dilations, strides, padding)
    regular_cases = (
        ([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME'),
        ([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID'),
        ([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME'),
        ([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID'),
    )
    for tensor_sizes, filter_sizes, dilations, strides, padding in regular_cases:
        _test_convolution(tensor_sizes, filter_sizes, dilations, strides, padding, 'NHWC')

    # depthwise convolution
    depthwise_cases = (
        ([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME'),
        ([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID'),
        ([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME'),
        ([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID'),
    )
    for tensor_sizes, filter_sizes, dilations, strides, padding in depthwise_cases:
        _test_convolution(tensor_sizes, filter_sizes, dilations, strides, padding, 'NHWC', True)


#######################################################################
# Reshape
# -------

def _test_reshape(data, out_shape):
    """ One iteration of reshape operation with given data and out shape

    ``data`` supplies both the placeholder's shape/dtype and the values fed in;
    ``out_shape`` is passed to ``array_ops.reshape``.
    """
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        out = array_ops.reshape(in_data, out_shape)

        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
284 285 286 287 288 289 290 291 292 293


def test_forward_reshape():
    """ Reshape """
    _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3])
    # integer input, with one dimension left for reshape to infer
    for out_shape in ([-1, 2], [3, -1], [-1]):
        _test_reshape(np.arange(6), out_shape)


#######################################################################
294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324
# Resize
# ------

def _test_resize(tf_resize_op, data, align_corners):
    """ One iteration of Resize

    ``data`` is a pair: the image array fed through a placeholder and the
    target size, which is embedded in the graph as a constant tensor.
    """

    assert len(data) == 2
    images_np = data[0]
    size_np = data[1]

    # Test with tensor and constant
    with tf.Graph().as_default():
        images_tensor = array_ops.placeholder(
            shape=images_np.shape, dtype=images_np.dtype, name='in')
        size = ops.convert_to_tensor(size_np, dtype=size_np.dtype)
        out_tensor = tf_resize_op(images=images_tensor, size=size, align_corners=align_corners)
        compare_tflite_with_tvm([images_np], ['in:0'], [images_tensor], [out_tensor])


def test_all_resize():
    """ Resize """
    images = np.random.rand(1, 16, 16, 3).astype("float32")
    target_size = np.array([8, 8], dtype=np.int32)
    data = [images, target_size]
    ### RESIZE_BILINEAR
    for align_corners in (False, True):
        _test_resize(tf.image.resize_bilinear, data, align_corners=align_corners)
    ### RESIZE_NEAREST_NEIGHBOR (was added in v1.13)
    # According to topi resize.h
    # Align corners not supported for nearest neighbour
    from tflite.BuiltinOperator import BuiltinOperator
    # run only when the installed tflite package lists the operator
    if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
        _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)


#######################################################################
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
# Concatenation
# -------------

def _test_concatenation(data, axis):
    """ One iteration of concatenation

    ``data`` is a non-empty list of arrays; each gets its own placeholder named
    ``in_<idx>`` and all of them are concatenated along ``axis``.
    """

    assert len(data) >= 1

    with tf.Graph().as_default():
        in_data = [
            array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
            for idx, tensor in enumerate(data)]
        out = array_ops.concat(in_data, axis=axis)
        name = ["in_{}:0".format(idx) for idx in range(len(data))]

        compare_tflite_with_tvm(data, name, in_data, [out])
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359


def test_forward_concatenation():
    """ Concatenation """
    # (shape shared by all operands, number of operands); always joined on axis 1
    cases = (
        ((1, 2, 1, 3), 2),
        ((3, 2), 2),
        ((2, 1, 1, 3), 3),
    )
    for shape, count in cases:
        operands = [np.arange(6).reshape(shape) for _ in range(count)]
        _test_concatenation(operands, 1)


#######################################################################
360
# Element-wise
# ------------

363
def _test_elemwise(math_op, data, fused_activation_function=None):
    """ One iteration of elemwise

    Parameters
    ----------
    math_op : callable
        Binary TF op applied to the two operands.
    data : sequence of two arrays
        Left and right operand.
    fused_activation_function : str, optional
        Activation name understood by ``with_fused_activation_function``,
        applied to the result of ``math_op``.
    """

    assert len(data) == 2

    # Test with two tensors
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
                   array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
        out = math_op(in_data[0], in_data[1])
        out = with_fused_activation_function(out, fused_activation_function)
        compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])

    # Test with tensor and constant
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
        out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
        out = with_fused_activation_function(out, fused_activation_function)
        compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
382 383


384 385 386 387
#######################################################################
# Add
# ---

388
def _test_add(data, fused_activation_function=None):
    """ One iteration of add

    Delegates to ``_test_elemwise`` with ``math_ops.add``.
    """
    return _test_elemwise(math_ops.add, data, fused_activation_function)
391

392 393
#######################################################################
# Subtract
# --------
395

396
def _test_sub(data, fused_activation_function=None):
    """ One iteration of subtract

    Delegates to ``_test_elemwise`` with ``math_ops.subtract``.
    """
    return _test_elemwise(math_ops.subtract, data, fused_activation_function)
399
#######################################################################
400 401
# Mul
# ---
402
def _test_mul(data, fused_activation_function=None):
    """ One iteration of mul

    Delegates to ``_test_elemwise`` with ``math_ops.multiply``.
    """
    return _test_elemwise(math_ops.multiply, data, fused_activation_function)
405

406 407
#######################################################################
# Divide
# ------
409

410
def _test_div(data, fused_activation_function=None):
    """ One iteration of divide

    Delegates to ``_test_elemwise`` with ``math_ops.divide``.
    """
    return _test_elemwise(math_ops.divide, data, fused_activation_function)
413 414
#######################################################################
# Power
# -----
416 417 418 419 420 421

def _test_pow(data):
    """ Run the element-wise check with ``math_ops.pow`` and no fused activation """
    return _test_elemwise(math_op=math_ops.pow, data=data)
#######################################################################
# Maximum
# -------
423 424 425 426 427 428

def _test_maximum(data):
    """ Run the element-wise check with ``math_ops.maximum`` and no fused activation """
    return _test_elemwise(math_op=math_ops.maximum, data=data)
#######################################################################
# Minimum
# -------
430 431 432 433

def _test_minimum(data):
    """ Run the element-wise check with ``math_ops.minimum`` and no fused activation """
    return _test_elemwise(math_op=math_ops.minimum, data=data)
434

435 436 437
def _test_forward_elemwise(testop):
    """ Elemwise

    Calls ``testop`` three times, each with a pair of equally shaped float32
    operands, of rank 4, 3 and 2 in turn.  The second operand starts at 1.0,
    so it contains no zero.
    """
    testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))])
    testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))])
    testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
            np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))])
443

444 445
def test_all_elemwise():
    """ Run every element-wise operator test

    add/sub/mul/div are run plain and with fused RELU and RELU6; pow, maximum
    and minimum are run without a fused activation.
    """
    for testop in (_test_add, _test_sub, _test_mul, _test_div):
        _test_forward_elemwise(testop)
        _test_forward_elemwise(partial(testop, fused_activation_function="RELU"))
        _test_forward_elemwise(partial(testop, fused_activation_function="RELU6"))
    for testop in (_test_pow, _test_maximum, _test_minimum):
        _test_forward_elemwise(testop)
460 461

#######################################################################
462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
# Reduce
# ------

def _test_reduce(math_op, data, keep_dims=None):
    """ One iteration of reduce

    ``data`` is a pair: the array fed through a placeholder and the value
    passed as the second positional argument of ``math_op``; ``keep_dims`` is
    passed as the third.
    """

    assert len(data) == 2
    tensor_np = data[0]
    reduce_arg = data[1]

    # Test with tensor and constant
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(
            shape=tensor_np.shape, dtype=tensor_np.dtype, name='in')
        reduced = math_op(placeholder, reduce_arg, keep_dims)
        compare_tflite_with_tvm([tensor_np], ['in:0'], [placeholder], [reduced])


#######################################################################
# Reduce_min
# ----------

def _test_reduce_min(data, keep_dims=None):
    """ Run the reduce check with ``math_ops.reduce_min`` """
    return _test_reduce(math_op=math_ops.reduce_min, data=data, keep_dims=keep_dims)

#######################################################################
# Reduce_max
# ----------

def _test_reduce_max(data, keep_dims=None):
    """ Run the reduce check with ``math_ops.reduce_max`` """
    return _test_reduce(math_op=math_ops.reduce_max, data=data, keep_dims=keep_dims)

#######################################################################
# Reduce_mean
# -----------

def _test_reduce_mean(data, keep_dims=None):
    """ Run the reduce check with ``math_ops.reduce_mean`` """
    return _test_reduce(math_op=math_ops.reduce_mean, data=data, keep_dims=keep_dims)

#######################################################################
# Reduce_prod
# -----------

def _test_reduce_prod(data, keep_dims=None):
    """ Run the reduce check with ``math_ops.reduce_prod`` """
    return _test_reduce(math_op=math_ops.reduce_prod, data=data, keep_dims=keep_dims)


def _test_forward_reduce(testop):
    """ Reduce

    Calls ``testop`` on two random float32 inputs, the first paired with
    ``None`` and the second with the int32 array [1, 2].  Each input is run
    with keep_dims left at its default, then False, then True.
    """
    no_axis = [np.random.rand(16, 16, 16, 16).astype("float32"), None]
    two_axes = [np.random.rand(16, 16, 16, 16).astype("float32"),
                np.array([1, 2], dtype=np.int32)]
    for data in (no_axis, two_axes):
        testop(data)
        for keep in (False, True):
            testop(data, keep_dims=keep)


def test_all_reduce():
    """ Run the reduce checks for min, max, mean and prod, in that order """
    reduce_tests = (_test_reduce_min, _test_reduce_max, _test_reduce_mean, _test_reduce_prod)
    for reduce_test in reduce_tests:
        _test_forward_reduce(reduce_test)


#######################################################################
530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546
# Squeeze
# -------

def _test_squeeze(data, squeeze_dims=None):
    """ One iteration of squeeze

    ``squeeze_dims`` is passed to ``array_ops.squeeze`` when non-empty; when it
    is ``None`` or empty, ``array_ops.squeeze`` is called without it.
    """

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

        # None and [] are both falsy, so no separate None normalisation is needed
        if squeeze_dims:
            out = array_ops.squeeze(in_data, squeeze_dims)
        else:
            out = array_ops.squeeze(in_data)

        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
548 549 550 551 552 553 554


def test_forward_squeeze():
    """ Squeeze """
    # (input shape, dimensions to squeeze)
    cases = (
        ((1, 2, 1, 3), [0, 2]),
        ((2, 1, 3, 1), [1, 3]),
    )
    for shape, dims in cases:
        _test_squeeze(np.arange(6).reshape(shape), dims)

555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583

#######################################################################
# Pad
# ---

def _test_pad(data):
    """ One iteration of PAD

    ``data`` is a pair: the array fed through a placeholder and the paddings
    array, which is embedded in the graph as a constant tensor.
    """

    assert len(data) == 2
    tensor_np = data[0]
    paddings_np = data[1]

    # Test with tensor and constant
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(
            shape=tensor_np.shape, dtype=tensor_np.dtype, name='in')
        paddings = ops.convert_to_tensor(paddings_np, dtype=paddings_np.dtype)
        padded = array_ops.pad(placeholder, paddings)
        compare_tflite_with_tvm([tensor_np], ['in:0'], [placeholder], [padded])


def test_forward_pad():
    """ Pad """
    # (exclusive end of the 1.0-based value ramp, input shape, paddings per dimension)
    cases = (
        (7.0, (2, 1, 1, 3), [[1, 1], [2, 2], [1, 1], [2, 2]]),
        (7.0, (2, 1, 3), [[2, 2], [1, 1], [1, 1]]),
        (7.0, (2, 3), [[1, 1], [2, 2]]),
        (4.0, (1, 3), [[1, 1], [2, 2]]),
    )
    for stop, shape, paddings in cases:
        tensor = np.arange(1.0, stop, dtype=np.float32).reshape(shape)
        _test_pad([tensor, np.array(paddings, dtype=np.int32)])


584
#######################################################################
585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600
# Logistic
# --------

def _test_logistic(data):
    """ One iteration of LOGISTIC: apply ``math_ops.sigmoid`` to a placeholder shaped like ``data`` """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(dtype=data.dtype, shape=data.shape)
        activated = math_ops.sigmoid(placeholder)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [activated])

def test_forward_logistic():
    """ LOGISTIC """
    # a single row holding 0.0 .. 5.0
    row = np.arange(6.0, dtype=np.float32).reshape((1, 6))
    _test_logistic(row)


#######################################################################
601 602 603 604 605 606 607 608
# Softmax
# -------

def _test_softmax(data):
    """ One iteration of softmax

    Builds a placeholder with the shape and dtype of ``data``, applies
    ``nn_ops.softmax`` and compares the TFLite and TVM results.
    """
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        out = nn_ops.softmax(in_data)
        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
610 611 612 613 614

def test_forward_softmax():
    """ Softmax """
    # a single row holding 0.0 .. 5.0
    row = np.arange(6.0, dtype=np.float32).reshape((1, 6))
    _test_softmax(row)

615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651

#######################################################################
# Fully Connected
# ---------------

def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
    """ One iteration of fully connected

    Parameters
    ----------
    tensor_in_sizes : list of int
        Input shape; everything after the first dimension is flattened.
    filter_in_sizes : list of int
        Two-element weight shape; its first entry must equal the flattened
        per-sample input size.
    bias_in_size : list of int, optional
        One-element bias shape; its entry must equal ``filter_in_sizes[1]``.
        When falsy, no bias is added.
    """

    total_size_1 = int(np.prod(tensor_in_sizes))
    total_size_2 = int(np.prod(filter_in_sizes))
    # Integer floor division avoids the float round trip of int(a / b).
    assert total_size_1 // tensor_in_sizes[0] == filter_in_sizes[0], \
        "input size and filter size are mismatched"

    # Initializes the input tensor and the filter with incrementing
    # numbers from 1.
    data_array = np.arange(1, total_size_1 + 1, dtype='float32').reshape(tensor_in_sizes)
    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')

        # reshape N H W C into N H*W*C
        in_data_reshape = array_ops.reshape(in_data, [tensor_in_sizes[0], -1])

        out = math_ops.mat_mul(in_data_reshape, in_filter)

        # if we have bias
        if bias_in_size:
            assert bias_in_size[0] == filter_in_sizes[1], "bias and filter size are mismatched"
            bias_array = [f * 1.0 for f in range(1, bias_in_size[0] + 1)]
            in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32')
            out = nn_ops.bias_add(out, in_bias)

        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
654 655 656 657 658 659 660 661 662 663


def test_forward_fully_connected():
    """ Fully Connected """
    # each batch size is run first without and then with a bias
    for batch in (1, 5):
        _test_fully_connected([batch, 1, 1, 150], [150, 100])
        _test_fully_connected([batch, 1, 1, 150], [150, 100], [100])


664 665 666
#######################################################################
# Mobilenet
# ---------
667

668
def test_forward_mobilenet_v1():
    """Test the Mobilenet V1 TF Lite model.

    Fetches the model file, runs one random (1, 224, 224, 3) float32 input
    through TFLite and TVM, and compares the first output of each.
    """
    # MobilenetV1
    tflite_model_file = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
        "mobilenet_v1_1.0_224.tflite")
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)
681

682
def test_forward_mobilenet_v2():
    """Test the Mobilenet V2 TF Lite model.

    Fetches the model file, runs one random (1, 224, 224, 3) float32 input
    through TFLite and TVM, and compares the first output of each.
    """
    # MobilenetV2
    tflite_model_file = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz",
        "mobilenet_v2_1.0_224.tflite")
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)

696
#######################################################################
697
# Inception
# ---------

def test_forward_inception_v3_net():
    """Test the Inception V3 TF Lite model.

    Fetches the model file, runs one random (1, 299, 299, 3) float32 input
    through TFLite and TVM, and compares the first output of each.
    """
    # InceptionV3
    tflite_model_file = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
        "inception_v3.tflite")
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)
713

714 715 716 717 718 719 720 721 722 723
def test_forward_inception_v4_net():
    """Test the Inception V4 TF Lite model.

    Fetches the model file, runs one random (1, 299, 299, 3) float32 input
    through TFLite and TVM, and compares the first output of each.
    """
    # InceptionV4
    tflite_model_file = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
        "inception_v4.tflite")
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()
    data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)

728
#######################################################################
729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746
# SSD Mobilenet
# -------------

def test_forward_ssd_mobilenet_v1():
    """Test the SSD Mobilenet V1 TF Lite model."""
    # SSD MobilenetV1
    model_url = "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz"
    model_name = "ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite"
    model_path = tf_testing.get_workload_official(model_url, model_name)
    with open(model_path, "rb") as model_file:
        model_buf = model_file.read()

    # one random image-shaped input, fed to both runtimes
    image = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
    expected = run_tflite_graph(model_buf, image)
    actual = run_tvm_graph(model_buf, image, 'normalized_input_image_tensor')

    # only the first output of each runtime is compared
    tvm.testing.assert_allclose(np.squeeze(actual[0]), np.squeeze(expected[0]),
                                rtol=1e-5, atol=1e-5)

#######################################################################
747 748 749 750
# Main
# ----
if __name__ == '__main__':
    # Run every test in this file when it is executed as a script.

    # Transforms
    test_forward_concatenation()
    test_forward_pad()
    test_forward_reshape()
    test_all_resize()
    test_forward_squeeze()

    # NN
    test_forward_convolution()
    test_forward_logistic()
    test_forward_pooling()
    test_forward_softmax()
    test_forward_fully_connected()

    # Elemwise
    test_all_elemwise()

    # Reduce
    test_all_reduce()

    # End to End
    test_forward_mobilenet_v1()
    test_forward_mobilenet_v2()
    test_forward_inception_v3_net()
    test_forward_inception_v4_net()
    test_forward_ssd_mobilenet_v1()