# test_top_level3.py
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
8
from nnvm.testing.check_computation import check_function
9


def check_map(symfunc, np_func, np_backward=None, dtype="float32", rnd_min=-1, rnd_max=1,
              shape=(1, 3, 32, 32)):
    """Check a one-input NNVM operator against a NumPy reference.

    Builds the graph ``y = symfunc(x)`` over a single input variable named
    ``"x"``. It then hands the graph and the reference implementation to
    ``check_function``.

    Parameters
    ----------
    symfunc : callable
        Maps an NNVM symbol to the symbol under test, e.g. ``sym.floor``.
    np_func : callable
        NumPy reference for the forward computation, called positionally
        with one array.
    np_backward : callable, optional
        Passed through unchanged to ``check_function`` (``None`` by default).
    dtype : str
        Input data type, ``"float32"`` by default.
    rnd_min, rnd_max : number
        Passed to ``check_function`` as ``in_range=(rnd_min, rnd_max)``.
        Presumably this is the range of the generated inputs; TODO confirm.
    shape : tuple of int
        Shape of input ``x``. Defaults to ``(1, 3, 32, 32)``, the previously
        hard-coded shape.
    """
    x = sym.Variable("x")
    y = symfunc(x)
    shapes = {'x': shape}
    # The wrapper gives the reference a parameter literally named ``x``, which
    # matches the graph input name. Presumably check_function calls the
    # reference with inputs passed by name. NumPy ufuncs such as np.floor do
    # not accept their input as a keyword, hence the lambda. TODO confirm.
    check_function(y, lambda x: np_func(x), np_backward,
                   dtype=dtype, shape=shapes, in_range=(rnd_min, rnd_max))


def test_floor():
    """Compare ``sym.floor`` with ``np.floor`` using check_map's default dtype (float32)."""
    check_map(symfunc=sym.floor, np_func=np.floor)

def test_ceil():
    """Compare ``sym.ceil`` with ``np.ceil`` using check_map's default dtype (float32)."""
    check_map(symfunc=sym.ceil, np_func=np.ceil)

def test_trunc():
    """Compare ``sym.trunc`` with ``np.trunc`` using check_map's default dtype (float32)."""
    check_map(symfunc=sym.trunc, np_func=np.trunc)

def test_round():
    """Compare ``sym.round`` with ``np.round`` using check_map's default dtype (float32)."""
    # NOTE(review): np.round rounds exact halves to the nearest even value.
    # Confirm sym.round uses the same tie-breaking; otherwise inputs that are
    # exactly x.5 could mismatch.
    check_map(sym.round, np.round)

def test_abs():
    """Compare ``sym.abs`` with ``np.abs`` for float32, int32 and int8 inputs."""
    # Same per-dtype loop style as test_shift. The order matches the original
    # three calls; the first of those used check_map's default dtype, "float32".
    for dtype in ["float32", "int32", "int8"]:
        check_map(sym.abs, np.abs, dtype=dtype)

def test_shift():
    """Check the ``>>`` and ``<<`` operators on int32 and int8 inputs.

    The NNVM symbol and the NumPy array both overload the shift operators.
    So the same callable serves as both the graph builder and the reference.
    """
    n = 3

    def shift_right(x):
        # Arithmetic right shift by a fixed amount.
        return x >> n

    def shift_left(x):
        # Left shift by a fixed amount.
        return x << n

    for dtype in ["int32", "int8"]:
        # NOTE(review): for int8, x << 3 exceeds the int8 range once |x| > 15.
        # This relies on NumPy and the compiled kernel wrapping the same way.
        # Confirm.
        check_map(shift_right, shift_right, dtype=dtype, rnd_min=-100, rnd_max=100)
        check_map(shift_left, shift_left, dtype=dtype, rnd_min=-100, rnd_max=100)

if __name__ == "__main__":
    # Run every test in this module directly, without a test runner.
    test_shift()
    test_floor()
    test_ceil()
    test_round()
    test_abs()
    test_trunc()