# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm


def test_attrs_equal():
    """Exercise ``tvm.ir_pass.AttrsEqual`` on nodes, dicts, lists and exprs.

    The assertions below expect:

    * two ``attrs.TestAttrs`` nodes built from the same arguments to
      compare equal, and a node with a different ``padding`` to differ;
    * a ``DictAttrs`` node to differ from a ``attrs.TestAttrs`` node but to
      equal another ``DictAttrs`` built from the same arguments;
    * dicts (optionally holding lists) to compare equal exactly when
      their contents do;
    * two separately built ``n + 1`` expressions over the same variable
      to compare equal.
    """
    make_node = tvm.make.node
    attrs_equal = tvm.ir_pass.AttrsEqual

    # Typed attrs nodes: two identical ones and one with longer padding.
    first = make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
    twin = make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
    longer_pad = make_node("attrs.TestAttrs", name="xx", padding=(3, 4, 1))
    assert attrs_equal(first, twin)
    assert not attrs_equal(first, longer_pad)

    # Dictionary-style attrs nodes.
    dict_attrs = make_node(
        "DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
    assert not attrs_equal(dict_attrs, first)
    dict_attrs_twin = make_node(
        "DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
    assert attrs_equal(dict_attrs, dict_attrs_twin)

    # Plain Python dicts wrapping attrs nodes.
    assert attrs_equal({"x": first}, {"x": twin})

    # Dicts whose values are lists.
    assert attrs_equal({"x": [first, first]}, {"x": [twin, first]})
    assert not attrs_equal({"x": [first, 1]}, {"x": [twin, 2]})

    # Expressions: each side builds its own ``size + 1`` object.
    size = tvm.var("n")
    assert attrs_equal({"x": size + 1}, {"x": size + 1})


def test_attrs_hash():
    """Exercise ``tvm.ir_pass.AttrsHash`` on dicts of attrs nodes.

    Inputs that are built from the same arguments are expected to hash
    to the same value; a bare node and a list containing a node are
    expected to hash differently.
    """
    attrs_hash = tvm.ir_pass.AttrsHash
    make_node = tvm.make.node

    first = make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
    twin = make_node("attrs.TestAttrs", name="xx", padding=(3, 4))

    # Same content behind different node objects.
    assert attrs_hash({"x": first}) == attrs_hash({"x": twin})
    # A bare node versus a list holding a node.
    assert attrs_hash({"x": first}) != attrs_hash({"x": [twin, 1]})
    # Lists with matching content.
    assert attrs_hash({"x": [first, 1]}) == attrs_hash({"x": [twin, 1]})
    assert attrs_hash({"x": [first, 2]}) == attrs_hash({"x": [twin, 2]})


if __name__ == "__main__":
    test_attrs_equal()
    test_attrs_hash()