# test_top_level3.py
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from test_top_level1 import helper

def check_map(symfunc, np_func, np_backward=None, dtype="float32", rnd_min=-1, rnd_max=1,
              dshape=(1, 3, 32, 32)):
    """Check a single-input symbol against its NumPy reference via ``helper``.

    Builds ``symfunc(x)`` on one input variable named ``"x"`` and forwards it,
    together with the NumPy reference, to ``helper`` from ``test_top_level1``.

    Parameters
    ----------
    symfunc : callable
        Maps an ``nnvm.symbol`` Variable to the output symbol under test.
    np_func : callable
        NumPy reference applied to the input array.
    np_backward : callable, optional
        Reference gradient, forwarded to ``helper`` unchanged.
    dtype : str
        Input dtype forwarded to ``helper`` (default ``"float32"``).
    rnd_min, rnd_max : number
        Forwarded to ``helper``; presumably the range of the random input
        data — confirm against ``helper``.
    dshape : tuple of int
        Shape of the input ``x`` (default ``(1, 3, 32, 32)``).
    """
    x = sym.Variable("x")
    y = symfunc(x)
    inputs = [('x', dshape, x)]
    # Wrap np_func so the reference takes a parameter literally named ``x``,
    # matching the input name above. NOTE(review): this matters if helper
    # calls the reference with keyword arguments — confirm in helper.
    helper(y, inputs, dtype, lambda x: np_func(x), np_backward,
           rnd_min=rnd_min, rnd_max=rnd_max)


def test_floor():
    """Compare ``sym.floor`` with ``np.floor`` using check_map's default dtype."""
    check_map(symfunc=sym.floor, np_func=np.floor)

def test_ceil():
    """Compare ``sym.ceil`` with ``np.ceil`` using check_map's default dtype."""
    check_map(symfunc=sym.ceil, np_func=np.ceil)

def test_trunc():
    """Compare ``sym.trunc`` with ``np.trunc`` using check_map's default dtype."""
    check_map(symfunc=sym.trunc, np_func=np.trunc)

def test_round():
    """Compare ``sym.round`` with ``np.round`` using check_map's default dtype."""
    check_map(symfunc=sym.round, np_func=np.round)

def test_abs():
    """Compare ``sym.abs`` with ``np.abs`` for float32, int32 and int8 inputs."""
    for dtype in ("float32", "int32", "int8"):
        check_map(sym.abs, np.abs, dtype=dtype)

def test_shift():
    """Check ``>>`` and ``<<`` by a constant amount on int32 and int8 inputs.

    Each shift function serves as both the symbolic builder and the NumPy
    reference. Its parameter stays named ``x`` to match the input name used
    by check_map. NOTE(review): this matters if the reference is called with
    keyword arguments — confirm in helper.
    """
    amount = 3

    def shift_right(x):
        return x >> amount

    def shift_left(x):
        return x << amount

    for dtype in ("int32", "int8"):
        for op in (shift_right, shift_left):
            check_map(op, op, dtype=dtype, rnd_min=-100, rnd_max=100)

if __name__ == "__main__":
    # Run every test in this module directly, in the original order.
    test_shift()
    test_floor()
    test_ceil()
    test_round()
    test_abs()
    test_trunc()