# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm


def test_domain_touched():
    """Check the regions of ``a`` and ``b`` reported by ``tvm.arith.DomainTouched``.

    The statement under test is a two-level loop nest over ``i`` and ``j``
    whose body stores ``b[i - 1, j + 1] + a[i - 1, j - 1]`` into ``a[i, j]``.
    Both tensors have shape ``(100, m)`` with ``m`` symbolic.

    Judging by the ``_r`` / ``_w`` / ``_rw`` naming and the expected values,
    the two boolean arguments of ``DomainTouched`` select whether reads
    and/or writes of the tensor are taken into account.
    """
    row = tvm.var('i')
    col = tvm.var('j')
    n_rows = tvm.convert(100)
    n_cols = tvm.var('m')
    a = tvm.placeholder((n_rows, n_cols), name='a')
    b = tvm.placeholder((n_rows, n_cols), name='b')

    # Build the IR bottom-up: the two loads, the store, then the loops.
    # NOTE(review): the literal 3 in Call and the trailing "0, 0" in For are
    # passed through unchanged; presumably a call-type code and the
    # for_type/device_api codes -- confirm against the tvm.make signatures.
    load_b = tvm.make.Call(b.dtype, 'b', [row - 1, col + 1], 3, b.op, 0)
    load_a = tvm.make.Call(a.dtype, 'a', [row - 1, col - 1], 3, a.op, 0)
    store_a = tvm.make.Provide(a.op, 0, load_b + load_a, [row, col])
    inner_loop = tvm.make.For(col, 0, n_cols, 0, 0, store_a)
    ir = tvm.make.For(row, 0, n_rows, 0, 0, inner_loop)

    # Reads of `a` happen at [i - 1, j - 1], so both minimums shift to -1.
    a_read = tvm.arith.DomainTouched(ir, a, True, False)
    dim = a_read[0]
    assert dim.min.value == -1
    assert dim.extent.value == 100
    dim = a_read[1]
    assert dim.min.value == -1
    assert dim.extent.name == 'm'

    # Writes of `a` happen at [i, j], i.e. exactly the loop ranges.
    a_write = tvm.arith.DomainTouched(ir, a, False, True)
    dim = a_write[0]
    assert dim.min.value == 0
    assert dim.extent.value == 100
    dim = a_write[1]
    assert dim.min.value == 0
    assert dim.extent.name == 'm'

    # Reads and writes combined: each dimension grows by one element,
    # giving extents 101 and the symbolic sum m + 1.
    a_read_write = tvm.arith.DomainTouched(ir, a, True, True)
    dim = a_read_write[0]
    assert dim.min.value == -1
    assert dim.extent.value == 101
    dim = a_read_write[1]
    assert dim.min.value == -1
    symbolic_extent = dim.extent
    assert isinstance(symbolic_extent, tvm.expr.Add)
    assert symbolic_extent.a.name == 'm'
    assert symbolic_extent.b.value == 1

    # `b` is only read, at [i - 1, j + 1].
    b_read = tvm.arith.DomainTouched(ir, b, True, False)
    assert b_read
    dim = b_read[0]
    assert dim.min.value == -1
    assert dim.extent.value == 100
    dim = b_read[1]
    assert dim.min.value == 1
    assert dim.extent.name == 'm'

    # `b` is never written, so the write domain is an empty Array.
    b_write = tvm.arith.DomainTouched(ir, b, False, True)
    assert isinstance(b_write, tvm.container.Array)
    assert len(b_write) == 0


if __name__ == "__main__":
    test_domain_touched()