# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm

def test_attrs_equal():
    """Check structural equality of attribute objects via ``AttrsEqual``.

    Covers ``attrs.TestAttrs`` nodes, ``DictAttrs`` nodes, plain dicts whose
    values are attrs nodes, lists nested inside dicts, and expressions built
    from a ``tvm.var``.
    """
    # Local alias, mirroring ``fhash`` in test_attrs_hash.
    fequal = tvm.ir_pass.AttrsEqual

    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
    y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
    z = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4, 1))
    # Separately built nodes with identical fields are expected to compare
    # equal; a different padding must not.
    assert fequal(x, y)
    assert not fequal(x, z)

    # A DictAttrs is not equal to a TestAttrs node, but two DictAttrs built
    # from the same keyword arguments are expected to be.
    dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
    assert not fequal(dattr, x)
    dattr2 = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
    assert fequal(dattr, dattr2)

    # Plain Python dicts holding equal attrs nodes compare equal.
    assert fequal({"x": x}, {"x": y})
    # Array-related checks: lists inside the dict are expected to compare
    # equal only when their corresponding elements do.
    assert fequal({"x": [x, x]}, {"x": [y, x]})
    assert not fequal({"x": [x, 1]}, {"x": [y, 2]})

    # Two independently constructed ``n + 1`` expressions over the same var
    # are expected to compare equal.
    n = tvm.var("n")
    assert fequal({"x": n + 1}, {"x": n + 1})





def test_attrs_hash():
    """Check that ``AttrsHash`` is consistent across identical inputs.

    Dicts holding separately built but identical ``attrs.TestAttrs`` nodes
    (alone or inside lists) are expected to hash identically. A dict whose
    value is a bare node is expected to hash differently from one whose
    value is a list.
    """
    attrs_hash = tvm.ir_pass.AttrsHash
    lhs = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
    rhs = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))

    assert attrs_hash({"x": lhs}) == attrs_hash({"x": rhs})
    assert attrs_hash({"x": lhs}) != attrs_hash({"x": [rhs, 1]})
    # Same check for each trailing list element, in the original order.
    for tail in (1, 2):
        assert attrs_hash({"x": [lhs, tail]}) == attrs_hash({"x": [rhs, tail]})


if __name__ == "__main__":
    # Run each test in file order when executed as a script.
    for _test_fn in (test_attrs_equal, test_attrs_hash):
        _test_fn()