# test_ir_op.py - tests for Relay operator attributes and support levels.
from tvm import relay

def test_op_attr():
    """Check that an attribute registered on one op does not appear on another.

    A callable is attached to the ``exp`` op under the key ``"ftest"``; the
    ``log`` op must still report one input and no ``"ftest"`` attribute,
    while ``exp`` must expose the callable that was registered.

    NOTE(review): the registration mutates the global op registry and is not
    undone here; running this test twice in one process may conflict with
    the existing registration - confirm against TVM's re-registration rules.
    """
    op_log = relay.op.get("log")

    def add_one(value):
        return value + 1

    # Equivalent to decorating ``add_one`` with relay.op.register(...).
    relay.op.register("exp", "ftest")(add_one)

    assert op_log.num_inputs == 1
    assert op_log.get_attr("ftest") is None

    registered = relay.op.get("exp").get_attr("ftest")
    assert registered(1) == 2

def test_op_level1():
    """Check the support level and wiring of the level-1 unary ops.

    Each ``relay.<op_name>`` constructor is applied to the same variable; the
    resulting expression must reference the op of that name, report
    ``support_level == 1`` and carry the variable as its first argument.
    """
    x = relay.Var("x")

    for op_name in ["log", "exp", "sqrt", "tanh"]:
        y = getattr(relay, op_name)(x)
        # Messages name the op so a failure inside the loop is identifiable.
        assert y.op.name == op_name, op_name
        assert y.op.support_level == 1, op_name
        assert y.args[0] == x, op_name


def test_op_level3():
    """Check the support level and wiring of the level-3 unary ops.

    Each ``relay.<op_name>`` constructor is applied to the same variable; the
    resulting expression must reference the op of that name, report
    ``support_level == 3`` and carry the variable as its first argument.
    """
    x = relay.Var("x")

    for op_name in ["ceil", "floor", "trunc", "round", "abs", "negative"]:
        y = getattr(relay, op_name)(x)
        # Messages name the op so a failure inside the loop is identifiable.
        assert y.op.name == op_name, op_name
        assert y.op.support_level == 3, op_name
        assert y.args[0] == x, op_name

if __name__ == "__main__":
    # Allow running the tests directly without a test runner.
    test_op_attr()
    test_op_level1()
    test_op_level3()