Unverified Commit fe74b37a by Tianqi Chen Committed by GitHub

Conditions updated to cover better user scenarios (#4951)

* Conditions updated to cover better user scenarios

* [1] New test case added

* [2] New test case added

* [3] Proper variable name used

* [4] Review Comments handled

* [5] Review comments handled

* [6] Review comments handled
parent 7a06bbed
......@@ -50,14 +50,14 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>()) return false;
if (lhs.same_as(rhs)) return true;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>()) return false;
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
using namespace tvm;
class TestAlphaEquals {
runtime::PackedFunc *_packed_func;
public:
TestAlphaEquals(const char* func_name) {
_packed_func = new runtime::PackedFunc();
TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
}
void UpdatePackedFunc(const char* func_name) {
TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
}
bool operator()(ObjectRef input_1, ObjectRef input_2) {
TVMRetValue rv;
std::vector<TVMValue> values(2);
std::vector<int> codes(2);
runtime::TVMArgsSetter setter(values.data(), codes.data());
setter(0, input_1);
setter(1, input_2);
_packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv);
return bool(rv);
};
};
TEST(Relay, AlphaTestEmptyTypeNodes) {
auto x = TypeVar("x", kTypeData);
auto y = TypeVar();
EXPECT_FALSE(relay::AlphaEqual(x, y));
TestAlphaEquals test_equals("relay._make._alpha_equal");
EXPECT_FALSE(test_equals(x, y));
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -28,6 +28,15 @@ def alpha_equal(x, y):
"""
return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
def alpha_equal_commutative(x, y):
"""
Check for commutative property of equality
"""
xy = analysis.alpha_equal(x, y)
yx = analysis.alpha_equal(y, x)
assert xy == yx
return xy
def test_tensor_type_alpha_equal():
t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32")
......@@ -219,6 +228,26 @@ def test_constant_alpha_equal():
assert not alpha_equal(x, y)
assert alpha_equal(x, relay.const(1))
def test_type_node_alpha_equal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)
v1 = relay.TypeVar('v1', 0)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)
assert alpha_equal_commutative(v1, v1)
def test_type_node_incompatible_alpha_equal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.Var("v2")
assert not alpha_equal_commutative(v1, v2)
def test_expr_node_incompatible_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.PatternVar(relay.Var("v2"))
assert not alpha_equal_commutative(v1, v2)
def test_var_alpha_equal():
v1 = relay.Var("v1")
......@@ -676,6 +705,9 @@ if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_node_alpha_equal()
test_type_node_incompatible_alpha_equal()
test_expr_node_incompatible_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment