# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm

def test_array():
    """Check length, negative indexing and slicing of a converted list."""
    a = tvm.convert([1, 2, 3])
    assert len(a) == 3
    assert a[-1].value == 3
    # Negative slice bounds select the first two of the three elements.
    a_slice = a[-3:-1]
    assert len(a_slice) == 2
    assert (a_slice[0].value, a_slice[1].value) == (1, 2)


def test_array_save_load_json():
    """Round-trip a converted list through JSON and check its contents."""
    a = tvm.convert([1, 2, 3])
    json_str = tvm.save_json(a)
    a_loaded = tvm.load_json(json_str)
    # The round-trip must keep both the size and the element values.
    assert len(a_loaded) == len(a)
    assert a_loaded[1].value == 2


def test_map():
    """Check membership, length and items() of a map keyed by TVM variables."""
    a = tvm.var('a')
    b = tvm.var('b')
    amap = tvm.convert({a: 2,
                        b: 3})
    assert a in amap
    assert len(amap) == 2
    dd = dict(amap.items())
    assert a in dd
    assert b in dd
    # `a + 1` is a new expression built from `a`; it is not itself a key.
    assert a + 1 not in amap


def test_str_map():
    """Check a map converted from a dict with Python string keys."""
    amap = tvm.convert({'a': 2, 'b': 3})
    assert 'a' in amap
    assert len(amap) == 2
    dd = dict(amap.items())
    # Plain Python strings can be used directly as lookup keys.
    assert amap['a'].value == 2
    assert 'a' in dd
    assert 'b' in dd


def test_map_save_load_json():
    """Round-trip a variable-keyed map through JSON and compare by key name."""
    a = tvm.var('a')
    b = tvm.var('b')
    amap = tvm.convert({a: 2,
                        b: 3})
    json_str = tvm.save_json(amap)
    amap = tvm.load_json(json_str)
    assert len(amap) == 2
    # Compare by variable name rather than against the original `a`/`b` objects.
    dd = {key.name: value.value for key, value in amap.items()}
    assert dd == {"a": 2, "b": 3}


def test_in_container():
    """Check `in` on an array of strings, using both str and StringImm operands."""
    arr = tvm.convert(['a', 'b', 'c'])
    assert 'a' in arr
    # An explicit StringImm node with the same text must also be found.
    assert tvm.make.StringImm('a') in arr
    assert 'd' not in arr


if __name__ == "__main__":
    # Allow running the tests directly without a test runner.
    test_str_map()
    test_array()
    test_map()
    test_array_save_load_json()
    test_map_save_load_json()
    test_in_container()