# test_op_level4.py: Relay type-inference tests for binary and comparison operators.
import tvm
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment


def assert_has_type(expr, typ, env=None):
    """Run type inference on ``expr`` and check the result equals ``typ``.

    Parameters
    ----------
    expr : relay expression
        Expression handed to ``infer_type``.
    typ : relay type
        Expected type; compared with ``!=`` against the inferred
        ``checked_type``.
    env : Environment, optional
        Environment to infer in. When omitted, a fresh ``Environment({})``
        is created for this call only.

    Raises
    ------
    RuntimeError
        If the inferred type differs from ``typ``.
    """
    # A signature default of ``Environment({})`` would be constructed once at
    # import time and shared by every call, so build a fresh one per call.
    if env is None:
        env = Environment({})
    checked_type = infer_type(env, expr).checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (
            checked_type, typ))

def test_cmp_type():
    """Comparison operators broadcast their inputs and yield ``uint1``.

    For each comparison op, build
    ``fn (x: Tensor[(10, 4), float32], y: Tensor[(5, 10, 1), float32]) { op(x, y) }``
    and check that the inferred return type is ``Tensor[(5, 10, 4), uint1]``.
    """
    for op in (relay.greater,
               relay.greater_equal,
               relay.less,
               relay.less_equal,
               relay.equal,
               relay.not_equal):
        ib = relay.ir_builder.IRBuilder()
        x = ib.param("x", relay.TensorType((10, 4), "float32"))
        y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
        with ib.function(x, y) as func:
            ib.ret(op(x.var, y.var))
        ib.ret(func)
        func = relay.ir_pass.infer_type(ib.env, func.to_func())
        ftype = func.checked_type
        # Name the operator on failure; the loop covers six ops.
        assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1"), \
            "%s: inferred %s" % (op, ftype.ret_type)


def test_binary_broadcast():
    """Integer binary ops broadcast (10, 4) with (5, 10, 1) to (5, 10, 4).

    For each op, build a two-parameter ``int32`` function returning
    ``op(x, y)`` and check that the inferred return type is
    ``Tensor[(5, 10, 4), int32]``.
    """
    for op in [relay.right_shift,
               relay.left_shift,
               relay.maximum,
               relay.minimum]:
        ib = relay.ir_builder.IRBuilder()
        x = ib.param("x", relay.TensorType((10, 4), "int32"))
        y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
        with ib.function(x, y) as func:
            ib.ret(op(x.var, y.var))
        ib.ret(func)
        func = relay.ir_pass.infer_type(ib.env, func.to_func())
        ftype = func.checked_type
        # Name the operator on failure; the loop covers several ops.
        assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32"), \
            "%s: inferred %s" % (op, ftype.ret_type)

def test_binary_op():
    """Element-wise binary ops on equal shapes keep that shape."""
    def check_binary_op(opfunc):
        """
        Program:
            fn (x, y) {
                return x <op> y;
            }
        """
        b = IRBuilder()

        x = b.param('x', tensor_type(5, 5, 5))
        y = b.param('y', tensor_type(5, 5, 5))
        with b.function(x, y) as func:
            b.ret(opfunc(x.var, y.var))
        b.ret(func)
        # Infer in the builder's own environment rather than discarding it;
        # the returned program object itself is not needed here.
        _, env = b.get()
        ttype = tensor_type(5, 5, 5)
        expected_ty = func_type([ttype, ttype], ttype)
        assert_has_type(func.to_func(), expected_ty, env)

    for opfunc in [relay.pow]:
        check_binary_op(opfunc)


def test_binary_broadcast_op():
    """Binary ops broadcast (10, 4) with (5, 10, 1) to (5, 10, 4)."""
    def check_binary_broadcast_op(opfunc):
        """
        Program:
            fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
                return x <op> y;
            }
        """
        b = IRBuilder()
        x = b.param('x', tensor_type(10, 4))
        y = b.param('y', tensor_type(5, 10, 1))
        with b.function(x, y) as func:
            b.ret(opfunc(x.var, y.var))
        b.ret(func)
        # Infer in the builder's own environment rather than discarding it;
        # the returned program object itself is not needed here.
        _, env = b.get()

        expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
                                tensor_type(5, 10, 4))
        assert_has_type(func.to_func(), expected_ty, env)

    for opfunc in [relay.pow]:
        check_binary_broadcast_op(opfunc)

def test_cmp_type():
    """Every comparison operator broadcasts a (10, 4) float32 tensor with a
    (5, 10, 1) float32 tensor into a (5, 10, 4) result of dtype ``uint1``."""
    comparison_ops = (
        relay.greater,
        relay.greater_equal,
        relay.less,
        relay.less_equal,
        relay.equal,
        relay.not_equal,
    )
    for cmp_op in comparison_ops:
        builder = relay.ir_builder.IRBuilder()
        lhs = builder.param("x", relay.TensorType((10, 4), "float32"))
        rhs = builder.param("y", relay.TensorType((5, 10, 1), "float32"))
        with builder.function(lhs, rhs) as fn_scope:
            builder.ret(cmp_op(lhs.var, rhs.var))
        builder.ret(fn_scope)
        typed_fn = relay.ir_pass.infer_type(builder.env, fn_scope.to_func())
        inferred_ret = typed_fn.checked_type.ret_type
        assert inferred_ret == relay.TensorType((5, 10, 4), "uint1")

def test_binary_broadcast():
    """Integer binary ops broadcast (10, 4) with (5, 10, 1) to (5, 10, 4).

    For each op, build a two-parameter ``int32`` function returning
    ``op(x, y)`` and check that the inferred return type is
    ``Tensor[(5, 10, 4), int32]``.
    """
    for op in [relay.right_shift,
               relay.left_shift,
               relay.maximum,
               relay.minimum]:
        ib = relay.ir_builder.IRBuilder()
        x = ib.param("x", relay.TensorType((10, 4), "int32"))
        y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
        with ib.function(x, y) as func:
            ib.ret(op(x.var, y.var))
        ib.ret(func)
        func = relay.ir_pass.infer_type(ib.env, func.to_func())
        ftype = func.checked_type
        # Name the operator on failure; the loop covers several ops.
        assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32"), \
            "%s: inferred %s" % (op, ftype.ret_type)


if __name__ == "__main__":
    # Allow running the tests directly, without a test runner.
    test_cmp_type()
    test_binary_broadcast()
    test_binary_op()
    test_binary_broadcast_op()