# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te

def test_domain_touched():
    """Check DomainTouched on the read, write and read/write regions of a nest.

    The statement built below is the two-level loop nest::

        for i in [0, n):        # n == 100
            for j in [0, m):
                a[i, j] = b[i - 1, j + 1] + a[i - 1, j - 1]

    ``tvm.arith._ffi_api.DomainTouched(stmt, tensor, reads, writes)`` is then
    queried for ``a`` with reads only, writes only, and both. For ``b`` it is
    queried with reads only and writes only. The last two arguments select
    which accesses are counted, as the expected regions below show. Each
    result is indexed per tensor dimension and exposes ``min`` and ``extent``.
    """
    i = te.var('i')
    j = te.var('j')
    n = tvm.runtime.convert(100)
    m = te.var('m')
    a = te.placeholder((n, m), name='a')
    b = te.placeholder((n, m), name='b')

    # Provide stores into ``a`` at (i, j).
    # The two Calls load ``b`` at (i - 1, j + 1) and ``a`` at (i - 1, j - 1).
    ir = tvm.tir.For(
        i, 0, n, 0, 0,
        tvm.tir.For(
            j, 0, m, 0, 0,
            tvm.tir.Provide(
                a.op,
                0,
                tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
                tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
                [i, j]
            )
        )
    )

    # Reads of a, via a[i - 1, j - 1]:
    #   rows [-1, 99) -> min -1, extent 100
    #   cols [-1, m - 1) -> min -1, extent m
    a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)
    assert a_domain_r[0].min.value == -1
    assert a_domain_r[0].extent.value == 100
    assert a_domain_r[1].min.value == -1
    assert a_domain_r[1].extent.name == 'm'

    # Writes of a, via a[i, j]:
    #   rows [0, 100) -> min 0, extent 100
    #   cols [0, m) -> min 0, extent m
    a_domain_w = tvm.arith._ffi_api.DomainTouched(ir, a, False, True)
    assert a_domain_w[0].min.value == 0
    assert a_domain_w[0].extent.value == 100
    assert a_domain_w[1].min.value == 0
    assert a_domain_w[1].extent.name == 'm'

    # Union of the reads and writes of a:
    #   rows [-1, 100) -> min -1, extent 101
    #   cols [-1, m) -> min -1, extent m + 1 (kept as a symbolic Add)
    a_domain_rw = tvm.arith._ffi_api.DomainTouched(ir, a, True, True)
    assert a_domain_rw[0].min.value == -1
    assert a_domain_rw[0].extent.value == 101
    assert a_domain_rw[1].min.value == -1
    assert isinstance(a_domain_rw[1].extent, tvm.tir.Add)
    assert a_domain_rw[1].extent.a.name == 'm'
    assert a_domain_rw[1].extent.b.value == 1

    # Reads of b, via b[i - 1, j + 1]:
    #   rows [-1, 99) -> min -1, extent 100
    #   cols [1, m + 1) -> min 1, extent m
    b_domain_r = tvm.arith._ffi_api.DomainTouched(ir, b, True, False)
    assert b_domain_r
    assert b_domain_r[0].min.value == -1
    assert b_domain_r[0].extent.value == 100
    assert b_domain_r[1].min.value == 1
    assert b_domain_r[1].extent.name == 'm'

    # b is never stored to, so its write region is an empty array.
    b_domain_w = tvm.arith._ffi_api.DomainTouched(ir, b, False, True)
    assert isinstance(b_domain_w, tvm.container.Array)
    assert len(b_domain_w) == 0

if __name__ == "__main__":
    # Allow running this file directly (e.g. `python <file>`) without pytest.
    test_domain_touched()