# test_ir_nodes.py
""" test ir"""
import tvm
from tvm import relay
from tvm.expr import *
from tvm.relay import op
from tvm.relay.ir_pass import graph_equal


def check_json_roundtrip(node):
    """Serialize ``node`` with ``tvm.save_json``, load it back with
    ``tvm.load_json``, and assert the result is graph-equal to ``node``."""
    json_str = tvm.save_json(node)
    back = tvm.load_json(json_str)
    assert graph_equal(back, node)


def test_bad_constructor():
    """Build a TensorType from nonsense shape/dtype arguments.

    A ``tvm.TVMError`` is tolerated; any other exception type fails the test.
    """
    try:
        relay.ty.TensorType("xx", "xx")
    except tvm.TVMError:
        # NOTE(review): the test passes whether or not an error is raised;
        # if rejecting these arguments is the intended contract, assert it.
        pass


# Span
def test_span():
    """Span exposes its constructor fields and survives a JSON round trip."""
    span = relay.Span(None, 1, 1)
    assert span.source is None
    assert span.lineno == 1
    assert span.col_offset == 1
    assert span.same_as(span)
    assert span == span
    assert isinstance(span, relay.base.Span)
    str(span)  # smoke-test the printer; result is not checked

    # span is not a node so we can't use graph_equal
    # to test the round trip; compare the fields individually instead.
    back = tvm.load_json(tvm.save_json(span))
    assert back.source == span.source
    assert back.lineno == span.lineno
    assert back.col_offset == span.col_offset


# Types

def test_tensor_type():
    """TensorType keeps its shape and dtype, has no span, and round-trips."""
    shape = tvm.convert([1, 2, 3])
    dtype = 'float32'
    tt = relay.TensorType(shape, dtype)
    assert tt.dtype == dtype
    assert tt.shape == shape
    assert tt.span is None
    str(tt)  # smoke-test the printer
    check_json_roundtrip(tt)


def test_type_param():
    """A TypeVar records its kind and round-trips through JSON."""
    tp = relay.TypeVar('name', relay.Kind.Type)
    assert tp.kind == relay.Kind.Type
    # assert tp.span  # TODO allow us to set span
    str(tp)  # smoke-test the printer
    check_json_roundtrip(tp)


def test_func_type():
    """FuncType exposes its components, has no span, and round-trips."""
    type_params = tvm.convert([])
    type_constraints = tvm.convert([])  # TODO: fill me in
    arg_types = tvm.convert([])
    ret_type = relay.TensorType((1, 2, 3), 'float32')
    tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
    assert tf.type_params == type_params
    assert tf.type_constraints == type_constraints
    assert tf.arg_types == arg_types
    assert tf.ret_type == ret_type
    assert tf.span is None
    # TODO make sure we can set span
    str(tf)  # smoke-test the printer
    check_json_roundtrip(tf)


def test_tuple_type():
    """TupleType over mixed field types keeps its fields and round-trips."""
    tp = relay.TypeVar('tp', relay.Kind.Type)
    # FuncType with no arguments, ret_type=None, and empty params/constraints.
    tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
    tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
    fields = tvm.convert([tp, tf, tt])

    tup_ty = relay.TupleType(fields)
    assert tup_ty.fields == fields
    str(tup_ty)  # smoke-test the printer
    check_json_roundtrip(tup_ty)


def test_type_relation():
    """TypeRelation keeps its args and input count and round-trips."""
    tp = relay.TypeVar('tp', relay.Kind.Type)
    tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
    tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
    args = tvm.convert([tp, tf, tt])

    num_inputs = 2
    # Relation function looked up from the global env-func registry by name.
    func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3, 4))

    tr = relay.TypeRelation(func, args, num_inputs, attrs)
    assert tr.args == args
    assert tr.num_inputs == num_inputs
    str(tr)  # smoke-test the printer
    check_json_roundtrip(tr)


def test_constant():
    """Constant wraps an NDArray, has no span, and round-trips."""
    arr = tvm.nd.array(10)
    const = relay.Constant(arr)
    assert const.data == arr
    assert const.span is None
    str(const)  # smoke-test the printer
    check_json_roundtrip(const)


def test_tuple():
    """An empty Tuple expression keeps its fields, has no span, and round-trips."""
    fields = tvm.convert([])
    tup = relay.Tuple(fields)
    assert tup.fields == fields
    assert tup.span is None
    str(tup)  # smoke-test the printer
    check_json_roundtrip(tup)


def test_local_var():
    """Var keeps its name hint, with and without a type annotation."""
    name_hint = 's'
    lv = relay.Var(name_hint)
    assert lv.name_hint == name_hint
    assert lv.type_annotation is None
    # assert lv.span == None todo(@jroesch): what do we do about spans
    str(lv)  # smoke-test the printer
    check_json_roundtrip(lv)

    # Same name, now with an explicit scalar TensorType annotation.
    t1 = relay.ty.TensorType((), "float")
    lv = relay.Var(name_hint, t1)
    assert lv.name_hint == name_hint
    assert lv.type_annotation == t1


def test_global_var():
    """GlobalVar keeps its name hint and round-trips through JSON."""
    name_hint = 'g'
    gv = relay.GlobalVar(name_hint)
    assert gv.name_hint == name_hint
    # assert gv.span == None todo(@jroesch): what do we do about spans
    str(gv)  # smoke-test the printer
    check_json_roundtrip(gv)


def test_function():
    """Function exposes params/body/type_params, has no span, and round-trips."""
    param_names = ['a', 'b', 'c', 'd']
    params = tvm.convert([relay.Var(n) for n in param_names])
    ret_type = relay.TupleType(tvm.convert([]))
    body = relay.Tuple(tvm.convert([]))
    type_params = tvm.convert([])
    fn = relay.Function(params, body, ret_type, type_params)
    assert fn.params == params
    assert fn.body == body
    assert fn.type_params == type_params
    assert fn.span is None
    str(fn)  # smoke-test the printer
    check_json_roundtrip(fn)


def test_call():
    """Call keeps its callee and args, has no span, and round-trips."""
    # Named `callee` rather than `op` so the module-level `op` import
    # (tvm.relay.op) is not shadowed inside this test.
    callee = relay.Var('f')
    arg_names = ['a', 'b', 'c', 'd']
    args = tvm.convert([relay.Var(n) for n in arg_names])
    call = relay.Call(callee, args, None, None)
    assert call.op == callee
    assert call.args == args
    assert call.span is None
    str(call)  # smoke-test the printer
    check_json_roundtrip(call)


def test_let():
    """Let(x, const, x) exposes var/value/body, has no span, and round-trips."""
    lv = relay.Var('x')
    arr = tvm.nd.array(10)
    value = relay.Constant(arr)
    # I would prefer that the order of arguments
    # matches syntax let x: t = v in b
    let = relay.Let(lv, value, lv)
    assert let.var == lv
    assert let.value == value
    assert let.body == lv
    assert let.span is None
    str(let)  # smoke-test the printer
    check_json_roundtrip(let)


def test_if():
    """If exposes its condition and both branches, has no span, and round-trips."""
    cond = relay.Var('cond')
    left = relay.Var('left')
    right = relay.Var('right')
    ife = relay.If(cond, left, right)
    assert ife.cond == cond
    assert ife.true_branch == left
    assert ife.false_branch == right
    assert ife.span is None
    str(ife)  # smoke-test the printer
    check_json_roundtrip(ife)


def test_tuple_get_item():
    """TupleGetItem keeps its tuple expression and index and round-trips."""
    tup = relay.Var("tuple")
    get = relay.TupleGetItem(tup, 1)
    assert get.tuple_value == tup
    assert get.index == 1
    str(get)  # smoke-test the printer
    check_json_roundtrip(get)


def test_op():
    """The registered ``add`` operator survives a JSON round trip."""
    add_op = op.op.get("add")
    check_json_roundtrip(add_op)


def test_conv2d_attrs():
    """A conv2d call built with explicit strides/padding/channels/kernel_size
    round-trips through JSON."""
    data = relay.var('data', shape=(1, 3, 224, 224))
    param = relay.var('param', shape=(64, 3, 7, 7))
    out = op.nn.conv2d(
        data,
        param,
        strides=(2, 2),
        padding=(3, 3),
        channels=64,
        kernel_size=(7, 7))
    check_json_roundtrip(out)


if __name__ == "__main__":
    # Run every test in this module when executed directly (outside pytest).
    test_bad_constructor()
    test_span()
    test_tensor_type()
    test_type_param()
    test_func_type()
    test_tuple_type()
    test_type_relation()
    test_constant()
    test_tuple()
    test_local_var()
    test_global_var()
    test_function()
    test_call()
    test_let()
    test_if()
    test_tuple_get_item()
    test_op()
    test_conv2d_attrs()