# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from nnvm.testing.check_computation import check_function


def check_map(symfunc, np_func, np_backward=None, dtype="float32", rnd_min=-1, rnd_max=1,
              shape=(1, 3, 32, 32)):
    """Check a single-input NNVM operator against a NumPy reference.

    Builds the graph ``symfunc(x)`` on one input variable named ``"x"`` and
    passes it to ``check_function`` together with the NumPy reference.

    Parameters
    ----------
    symfunc : callable
        Maps an NNVM symbol to the symbol under test (e.g. ``sym.floor``).
    np_func : callable
        NumPy reference applied to the input array (e.g. ``np.floor``).
    np_backward : callable, optional
        Forwarded unchanged as ``check_function``'s third positional argument
        (presumably the reference gradient; ``None`` by default).
    dtype : str
        Data type of the input ``x``; defaults to ``"float32"``.
    rnd_min, rnd_max : number
        Passed to ``check_function`` as ``in_range=(rnd_min, rnd_max)``.
    shape : tuple of int
        Shape of ``x``. Defaults to ``(1, 3, 32, 32)``, the shape this helper
        always used before the parameter existed.
    """
    x = sym.Variable("x")
    y = symfunc(x)
    shape_dict = {'x': shape}
    # NOTE(review): np_func is wrapped in a lambda whose parameter is named
    # ``x`` instead of being passed directly. Presumably this adapts it to how
    # check_function invokes the reference (e.g. by input name). Keep the
    # wrapper unless that is confirmed unnecessary.
    check_function(y, lambda x: np_func(x), np_backward,
                   dtype=dtype, shape=shape_dict, in_range=(rnd_min, rnd_max))


def test_floor():
    """Check ``sym.floor`` element-wise against ``np.floor`` via ``check_map``."""
    check_map(symfunc=sym.floor, np_func=np.floor)

def test_ceil():
    """Check ``sym.ceil`` element-wise against ``np.ceil`` via ``check_map``."""
    check_map(symfunc=sym.ceil, np_func=np.ceil)

def test_trunc():
    """Check ``sym.trunc`` element-wise against ``np.trunc`` via ``check_map``."""
    check_map(symfunc=sym.trunc, np_func=np.trunc)

def test_round():
    """Check ``sym.round`` element-wise against ``np.round`` via ``check_map``."""
    # NOTE(review): np.round breaks .5 ties toward the even neighbour. The
    # inputs here use check_map's default range (-1, 1), so exact ties should
    # be rare. If this test ever flakes, confirm sym.round's tie-breaking rule.
    check_map(symfunc=sym.round, np_func=np.round)

def test_abs():
    """Check ``sym.abs`` against ``np.abs`` for float32, int32 and int8 inputs."""
    check_map(sym.abs, np.abs)
    # Integer inputs use the same [-100, 100] range as test_shift. With the
    # default (-1, 1) range, very few distinct integer values can be produced
    # (possibly only 0, depending on how check_function samples integers).
    # The bound also stays clear of int8's -128, whose absolute value
    # overflows int8.
    for dtype in ["int32", "int8"]:
        check_map(sym.abs, np.abs, dtype=dtype, rnd_min=-100, rnd_max=100)


def test_shift():
    """Compare ``>>`` and ``<<`` on integer symbols with NumPy's shifts.

    For each of int32 and int8, a right shift and then a left shift by 3 bits
    are checked on inputs drawn from the range (-100, 100).
    """
    bits = 3

    def shift_right(x):
        return x >> bits

    def shift_left(x):
        return x << bits

    for dtype in ("int32", "int8"):
        for shift in (shift_right, shift_left):
            # The same expression works on NNVM symbols and on NumPy arrays,
            # so one function serves as both the operator and its reference.
            # NOTE(review): for int8, left-shifting values near +/-100 by 3
            # overflows; the comparison relies on both sides wrapping the same way.
            check_map(shift, shift, dtype=dtype, rnd_min=-100, rnd_max=100)

if __name__ == "__main__":
    # Run every test in this module when the file is executed directly
    # rather than collected by pytest.
    test_shift()
    test_floor()
    test_ceil()
    test_round()
    test_abs()
    test_trunc()