Commit 919ae889 by Zhi, committed by GitHub

[REFACTOR][IR] alpha_equal to structural_equal (#5161)

parent 07ac7712
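At the user level, this refactor replaces the Relay-specific alpha-equality helpers with the unified structural equality API. A minimal sketch of the replacement calls, assuming a TVM build that ships tvm.ir.structural_equal (variable names are illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
f1 = relay.Function([x], relay.add(x, x))
f2 = relay.Function([y], relay.add(y, y))

# Previously relay.analysis.alpha_equal(f1, f2): bound vars may differ by name.
assert tvm.ir.structural_equal(f1, f2)
# Previously relay.analysis.assert_alpha_equal(f1, f2): raises on mismatch.
tvm.ir.assert_structural_equal(f1, f2)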
......@@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode {
}
bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind);
return
equal(kind, other->kind) &&
equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
......
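The added FreeVarEqualImpl call makes an unresolved IncompleteType behave like a free variable: two distinct instances only match when the comparison is allowed to map free variables. A rough Python-level analogue using free type variables (illustrative, assuming the default map_free_vars=False):

import tvm
from tvm import relay

a = relay.TypeVar("a")
b = relay.TypeVar("b")
assert not tvm.ir.structural_equal(a, b)                  # free vars are not mapped by default
assert tvm.ir.structural_equal(a, b, map_free_vars=True)  # mapped only when requested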
......@@ -65,61 +65,6 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
TVM_DLL bool ConstantCheck(const Expr& e);
/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `let x = 1 in x` is equal to `let y = 1 in y`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param e1 The left hand expression.
* \param e2 The right hand expression.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
/*!
* \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
 * types without regard to variable choice.
*
* For example: `forall s, Tensor[f32, s]` is equal to
* `forall w, Tensor[f32, w]`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand type.
* \param t2 The right hand type.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
/*!
* \brief Compare two patterns for structural equivalence.
*
* This comparison operator respects scoping and compares
* patterns without regard to variable choice.
*
* For example: `A(x, _, y)` is equal to `A(z, _, a)`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand pattern.
* \param t2 The right hand pattern.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
/*!
* \brief Check that each Var is only bound once.
*
 * For example, the expression `let x = 1 in let x = 2 in 3` binds x twice.
......
......@@ -16,6 +16,7 @@
# under the License.
"""Unified type system in the project."""
from enum import IntEnum
import tvm
import tvm._ffi
from .base import Node
......@@ -26,7 +27,7 @@ class Type(Node):
"""The base class of all types."""
def __eq__(self, other):
"""Compare two types for structural equivalence."""
return bool(_ffi_api.type_alpha_equal(self, other))
return bool(tvm.ir.structural_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
......
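With __eq__ now delegating to structural_equal, comparing types with == goes through the unified checker. A quick sanity check (illustrative):

import tvm
from tvm import relay

assert relay.TensorType((1, 16), "float32") == relay.TensorType((1, 16), "float32")
assert relay.TensorType((1, 16), "float32") != relay.TensorType((1, 16), "int32")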
......@@ -33,7 +33,6 @@ from . import parser
from . import transform
from . import analysis
from .analysis import alpha_equal
from .build_module import build, create_executor, optimize
from .transform import build_config
from . import debug
......
......@@ -220,78 +220,6 @@ def all_type_vars(expr, mod=None):
return _ffi_api.all_type_vars(expr, use_mod)
def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_ffi_api._alpha_equal(lhs, rhs))
def assert_alpha_equal(lhs, rhs):
"""Assert that two Relay expr is structurally equivalent. (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_ffi_api._assert_alpha_equal(lhs, rhs)
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_ffi_api._graph_equal(lhs, rhs))
def assert_graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_ffi_api._assert_graph_equal(lhs, rhs)
def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
......
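The removed graph_equal/assert_graph_equal helpers correspond to structural equality with free-variable mapping: free variables are treated as sources and matched between lhs and rhs. A hedged sketch of the distinction (variable names illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(3,))
y = relay.var("y", shape=(3,))

# alpha_equal-style: free variables must be the same objects.
assert not tvm.ir.structural_equal(relay.exp(x), relay.exp(y))
# graph_equal-style: free variables are mapped between the two sides.
assert tvm.ir.structural_equal(relay.exp(x), relay.exp(y), map_free_vars=True)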
......@@ -23,6 +23,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
......@@ -194,12 +195,11 @@ relay::Function RunTypeCheck(const IRModule& mod,
<< AsText(func, false)
<< std::endl;
}
func =
relay::Function(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
func = relay::Function(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
// Type check the item before we add it to the module.
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
......@@ -222,7 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var,
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var]->checked_type();
CHECK(relay::AlphaEqual(type, old_type))
CHECK(tvm::StructuralEqual()(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
var->checked_type_ = type;
......@@ -353,9 +353,8 @@ IRModule IRModule::FromExpr(
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
} else {
func = relay::Function(
relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
mod->Add(main_gv, func);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/analysis/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {
// Alpha Equal handler for Relay.
class AlphaEqualHandler:
public AttrsEqualHandler,
public TypeFunctor<bool(const Type&, const Type&)>,
public ExprFunctor<bool(const Expr&, const Expr&)>,
public PatternFunctor<bool(const Pattern&, const Pattern&)> {
public:
explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
: map_free_var_(map_free_var), assert_mode_(assert_mode) { }
/*!
* Check equality of two nodes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
return VisitAttr(lhs, rhs);
}
/*!
* Check equality of two attributes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return The comparison result.
*/
bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
auto compute = [&]() {
return VisitAttr(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
/*!
* Check equality of two types.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the comparison result.
*/
bool TypeEqual(const Type& lhs, const Type& rhs) {
auto compute = [&]() {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
bool Compare(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_) {
CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
}
return result;
}
/*!
* Check equality of two expressions.
*
* \note We run graph structural equality checking when comparing two Exprs.
* This means that AlphaEqualHandler can only be used once for each pair.
 * The equality checker checks data-flow equivalence of the Expr DAG.
 * This function also runs faster as it memoizes equal_map_.
*
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return The comparison result.
*/
bool ExprEqual(const Expr& lhs, const Expr& rhs) {
auto compute = [&]() {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
}
if (this->VisitExpr(lhs, rhs)) {
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
};
return Compare(compute(), lhs, rhs);
}
protected:
// Override so that the new definition of equality in Relay can be handled directly.
// Specifically, if a DictAttrs value contains a Relay AST,
// we want to be able to recursively check equality inside that attribute.
bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
auto rhsm = rhs.as<IRModuleNode>();
if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) {
if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false;
}
}
return true;
}
// Fall back to the object equal case.
return AttrsEqualHandler::VisitAttr(lhs, rhs);
}
/*!
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return The compare result.
*/
bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
return lhs == rhs;
}
/*!
 * \brief Check equality of a leaf node of the graph.
 * If map_free_var_ is set to true, try to map unmatched nodes via equal_map_.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return The compare result.
*/
bool LeafObjectEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
} else {
if (map_free_var_) {
if (lhs->type_index() != rhs->type_index()) return false;
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
}
}
using AttrsEqualHandler::VisitAttr_;
bool VisitAttr_(const tvm::tir::VarNode* lhs, const ObjectRef& other) final {
return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
}
// Type equality
bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
return (lhs->dtype == rhs->dtype &&
AttrEqual(lhs->shape, rhs->shape));
} else {
return false;
}
}
bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
}
bool VisitType_(const PrimTypeNode* lhs, const Type& other) final {
if (const PrimTypeNode* rhs = other.as<PrimTypeNode>()) {
return lhs->dtype == rhs->dtype;
} else {
return false;
}
}
bool VisitType_(const PointerTypeNode* lhs, const Type& other) final {
if (const PointerTypeNode* rhs = other.as<PointerTypeNode>()) {
return TypeEqual(lhs->element_type, rhs->element_type);
} else {
return false;
}
}
bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
if (lhs->kind != rhs->kind) return false;
return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
} else {
return false;
}
}
bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
if (lhs->type_params.size() != rhs->type_params.size()) return false;
if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
for (size_t i = 0; i < lhs->type_params.size(); ++i) {
if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
return false;
}
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
}
for (size_t i = 0; i < lhs->arg_types.size(); i++) {
if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
}
if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
if (!TypeEqual(lhs->type_constraints[i],
rhs->type_constraints[i])) {
return false;
}
}
return true;
} else {
return false;
}
}
bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
if (lhs->func->name != rhs->func->name) return false;
if (lhs->num_inputs != rhs->num_inputs) return false;
if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
}
return true;
} else {
return false;
}
}
bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
if (lhs->fields.size() != rhs->fields.size()) return false;
for (size_t i = 0; i < lhs->fields.size(); ++i) {
if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
}
return true;
} else {
return false;
}
}
bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final {
if (const RelayRefTypeNode* rhs = other.as<RelayRefTypeNode>()) {
return TypeEqual(lhs->value, rhs->value);
}
return false;
}
bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
}
bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
const TypeCallNode* rhs = other.as<TypeCallNode>();
if (rhs == nullptr
|| lhs->args.size() != rhs->args.size()
|| !TypeEqual(lhs->func, rhs->func)) {
return false;
}
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!TypeEqual(lhs->args[i], rhs->args[i])) {
return false;
}
}
return true;
}
bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
const TypeDataNode* rhs = other.as<TypeDataNode>();
if (rhs == nullptr
|| lhs->type_vars.size() != rhs->type_vars.size()
|| !TypeEqual(lhs->header, rhs->header)) {
return false;
}
for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->constructors.size(); ++i) {
if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
return false;
}
}
return true;
}
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
if (lhs.defined() != rhs.defined()) {
return false;
} else if (lhs.same_as(rhs)) {
return true;
} else {
auto ldt = lhs->dtype;
auto rdt = rhs->dtype;
CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(*lhs.operator->());
return std::memcmp(lhs->data, rhs->data, data_size) == 0;
} else {
return false;
}
}
}
// merge declaration of two variables together.
bool MergeVarDecl(const Var& lhs, const Var& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!TypeEqual(lhs->type_annotation,
rhs->type_annotation)) return false;
CHECK(!equal_map_.count(lhs))
<< "Duplicated declaration of variable " << lhs;
equal_map_[lhs] = rhs;
return true;
}
bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
// This function will only be triggered if we are matching free variables.
if (const VarNode* rhs = other.as<VarNode>()) {
if (lhs->name_hint() != rhs->name_hint()) return false;
if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
} else {
return false;
}
}
bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
// use name equality for global var for now.
return lhs->name_hint == rhs->name_hint;
}
return false;
}
bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
if (const TupleNode* rhs = other.as<TupleNode>()) {
if (lhs->fields.size() != rhs->fields.size()) return false;
for (size_t i = 0; i < lhs->fields.size(); ++i) {
if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
}
return true;
} else {
return false;
}
}
bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
if (const FunctionNode* rhs = other.as<FunctionNode>()) {
if (lhs->params.size() != rhs->params.size()) return false;
if (lhs->type_params.size() != rhs->type_params.size()) return false;
// map type parameter to be the same
for (size_t i = 0; i < lhs->type_params.size(); ++i) {
if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
}
// check parameter type annotations
for (size_t i = 0; i < lhs->params.size(); ++i) {
if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
}
// check return types.
if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
if (!AttrEqual(lhs->attrs, rhs->attrs)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
}
}
bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
// skip type_args check for primitive ops.
bool is_primitive = IsPrimitiveOp(lhs->op);
if (!is_primitive) {
if (lhs->type_args.size() != rhs->type_args.size()) {
return false;
}
}
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) {
return false;
}
}
if (!is_primitive) {
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
}
return AttrEqual(lhs->attrs, rhs->attrs);
} else {
return false;
}
}
bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
if (const LetNode* rhs = other.as<LetNode>()) {
if (!MergeVarDecl(lhs->var, rhs->var)) return false;
if (!ExprEqual(lhs->value, rhs->value)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
}
}
bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
if (const IfNode* rhs = other.as<IfNode>()) {
return ExprEqual(lhs->cond, rhs->cond) &&
ExprEqual(lhs->true_branch, rhs->true_branch) &&
ExprEqual(lhs->false_branch, rhs->false_branch);
} else {
return false;
}
}
bool VisitExpr_(const OpNode* lhs, const Expr& other) final {
return lhs == other.get();
}
bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
if (const ConstantNode* rhs = other.as<ConstantNode>()) {
return NDArrayEqual(lhs->data, rhs->data);
} else {
return false;
}
}
bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
} else {
return false;
}
}
bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final {
if (const RefCreateNode* rhs = other.as<RefCreateNode>()) {
return ExprEqual(lhs->value, rhs->value);
} else {
return false;
}
}
bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final {
if (const RefReadNode* rhs = other.as<RefReadNode>()) {
return ExprEqual(lhs->ref, rhs->ref);
} else {
return false;
}
}
bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final {
if (const RefWriteNode* rhs = other.as<RefWriteNode>()) {
return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value, rhs->value);
} else {
return false;
}
}
bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
return lhs->name_hint == rhs->name_hint;
}
return false;
}
bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
}
bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
return Compare(VisitPattern(lhs, rhs), lhs, rhs);
}
bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
return other.as<PatternWildcardNode>();
}
bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final {
if (const auto* rhs = other.as<PatternVarNode>()) {
return MergeVarDecl(lhs->var, rhs->var);
}
return false;
}
bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other) final {
const auto* rhs = other.as<PatternConstructorNode>();
if (rhs == nullptr
|| !ExprEqual(lhs->constructor, rhs->constructor)
|| lhs->patterns.size() != rhs->patterns.size()) {
return false;
}
for (size_t i = 0; i < lhs->patterns.size(); i++) {
if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
return false;
}
}
return true;
}
bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
const auto* rhs = other.as<PatternTupleNode>();
if (rhs == nullptr
|| lhs->patterns.size() != rhs->patterns.size()) {
return false;
}
for (size_t i = 0; i < lhs->patterns.size(); i++) {
if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
return false;
}
}
return true;
}
bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
const MatchNode* rhs = other.as<MatchNode>();
if (rhs == nullptr
|| !ExprEqual(lhs->data, rhs->data)
|| lhs->clauses.size() != rhs->clauses.size()
|| lhs->complete != rhs->complete) {
return false;
}
for (size_t i = 0; i < lhs->clauses.size(); ++i) {
if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) {
return false;
}
}
return true;
}
private:
// whether to map open terms.
bool map_free_var_;
// if in assert mode, the check must return true and throws an error otherwise.
bool assert_mode_;
// mapping between ObjectRefs recording that two nodes are equal to each other
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_;
};
bool AlphaEqual(const Type& lhs, const Type& rhs) {
return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
}
bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
}
TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(false, false).Equal(a, b);
});
TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
.set_body_typed([](Type a, Type b) {
return AlphaEqual(a, b);
});
TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});
TVM_REGISTER_GLOBAL("relay.analysis._graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(true, false).Equal(a, b);
});
TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
});
} // namespace relay
} // namespace tvm
......@@ -21,6 +21,7 @@
* \file type_solver.cc
* \brief Type solver implementations.
*/
#include <tvm/node/structural_equal.h>
#include <tvm/ir/type_functor.h>
#include <tvm/tir/op.h>
#include <string>
......@@ -151,11 +152,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return rc.Check(t);
}
// default: unify only if alpha-equal
// default: unify only if structural-equal
Type VisitTypeDefault_(const Object* op, const Type& tn) final {
ObjectRef nr = GetRef<ObjectRef>(op);
Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
if (!AlphaEqual(t1, tn)) {
if (!tvm::StructuralEqual()(t1, tn)) {
return Type(nullptr);
}
return t1;
......@@ -216,7 +217,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
auto tt1 = GetRef<TensorType>(op);
auto tt2 = GetRef<TensorType>(tt_node);
if (AlphaEqual(tt1, tt2)) {
if (tvm::StructuralEqual()(tt1, tt2)) {
return std::move(tt1);
}
......
......@@ -25,6 +25,7 @@
#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#include <tvm/node/structural_equal.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
......@@ -268,7 +269,7 @@ inline bool CCacheKeyNode::Equal(
const CCacheKeyNode* other) const {
if (Hash() != other->Hash()) return false;
return this->target->str() == other->target->str() &&
AlphaEqual(this->source_func, other->source_func);
tvm::StructuralEqual()(this->source_func, other->source_func);
}
} // namespace relay
......
......@@ -22,6 +22,7 @@
* \brief Lift all local functions into global functions.
*/
#include <tvm/node/structural_equal.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h>
......@@ -161,7 +162,8 @@ class LambdaLifter : public ExprMutator {
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
CHECK(tvm::StructuralEqual()(lifted_func, existing_func))
<< "lifted function hash collision";
// If an identical function already exists, use its global var.
global = module_->GetGlobalVar(name);
} else {
......
......@@ -2142,7 +2142,12 @@ Expr MakeSplit(Expr data,
TVM_REGISTER_GLOBAL("relay.op._make.split")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
if (args.type_codes[1] == kDLInt) {
*rv = MakeSplit(args[0], tir::make_const(DataType::Int(64), int64_t(args[1])), args[2]);
// Note: we change it from Int(64) to Int(32) for now as
// combine_parallel_dense will transform the graph with Int(32).
// More investigation is needed to decide which one we should use.
*rv = MakeSplit(args[0],
tir::make_const(DataType::Int(32), static_cast<int>(args[1])),
args[2]);
} else {
*rv = MakeSplit(args[0], args[1], args[2]);
}
......
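For reference, the kDLInt branch above is reached when the Python frontend passes an integer number of sections; a typical call looks like the sketch below (shapes illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8))
# An integer `indices_or_sections` now becomes an int32 constant on the C++ side.
parts = relay.split(x, indices_or_sections=2, axis=1)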
......@@ -59,6 +59,7 @@
* Thus, it is necessary to wrap this outer function so that the input/output types remain the same
*/
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/ir/type_functor.h>
......@@ -93,7 +94,7 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
Expr WrapExpr(const Expr expr, const Type& type) {
if (type.as<TensorTypeNode>()) {
return Call(module_->GetConstructor("GradCell", "Raw"),
{expr}, Attrs(), {type});
{expr}, Attrs(), {type});
} else if (auto* type_anno = type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < type_anno->fields.size(); i++) {
......@@ -185,7 +186,7 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
Expr VisitExpr_(const ConstantNode* op) final {
return Call(module_->GetConstructor("GradCell", "Raw"),
{GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
{GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
}
Expr VisitExpr_(const CallNode* call_node) final {
......@@ -207,26 +208,25 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
return Call(module_->GetConstructor("GradCell", constructor_name),
{func}, Attrs(), {call_node->checked_type()});
{func}, Attrs(), {call_node->checked_type()});
}
if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
// ones_like and zeros_like need TensorType input
Expr result = CallPrimitiveOp(call_node);
// fn() -> T, function returns result of operation
Expr func = Function({}, result,
{call_node->checked_type()}, Array<TypeVar>());
Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
return Call(module_->GetConstructor("GradCell", "One"),
{func}, Attrs(), {call_node->checked_type()});
{func}, Attrs(), {call_node->checked_type()});
}
// handle all other ops
Expr result = CallPrimitiveOp(call_node);
// wrap result with Raw constructor
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
Attrs(), {call_node->checked_type()});
Attrs(), {call_node->checked_type()});
}
// not an op
return ExprMutator::VisitExpr_(call_node);
......@@ -253,10 +253,11 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
// can only use overloaded functions if 2 arguments of same type
if (call_node->args.size() != 2 ||
!AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
!tvm::StructuralEqual()(call_node->args[0]->checked_type(),
call_node->args[1]->checked_type())) {
Expr result = CallPrimitiveOp(call_node);
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
Attrs(), {call_node->checked_type()});
Attrs(), {call_node->checked_type()});
}
tvm::Array<Expr> args;
......@@ -266,8 +267,7 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
Var("rhs", paramType)};
// use primitive op in this case
Expr callOp = Call(call_node->op, {params[0], params[1]});
Expr func = Function(params, callOp, paramType,
Array<TypeVar>());
Expr func = Function(params, callOp, paramType, Array<TypeVar>());
// pass "fallback" function and tensors as arguments
args.push_back(func);
......
......@@ -27,6 +27,7 @@
#define TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
#include <builtin_fp16.h>
#include <tvm/node/structural_equal.h>
#include <tvm/tir/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
......@@ -300,7 +301,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
return false;
}
return AlphaEqual(a, b);
return tvm::StructuralEqual()(a, b);
}
inline Expr GetField(Expr t, size_t i) {
......
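IsEqualScalar now defers to the generic structural comparator for the two constant nodes; at the Python level the same check behaves roughly as follows (illustrative):

import tvm
from tvm import relay

assert tvm.ir.structural_equal(relay.const(1.0), relay.const(1.0))
assert not tvm.ir.structural_equal(relay.const(1.0), relay.const(2.0))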
......@@ -353,7 +353,7 @@ Function UnCPS(const Function& f) {
auto answer_type = new_type_params.back();
new_type_params.pop_back();
// TODO(@M.K.): make structural equality work on free terms
// CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
// CHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, answer_type)));
auto x = Var("x", new_ret_type);
auto cont = Function({x}, x, new_ret_type, {}, {});
tvm::Array<Expr> args;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
using namespace tvm;
class TestAlphaEquals {
runtime::PackedFunc *_packed_func;
public:
TestAlphaEquals(const char* func_name) {
_packed_func = new runtime::PackedFunc();
TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
}
void UpdatePackedFunc(const char* func_name) {
TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
}
bool operator()(ObjectRef input_1, ObjectRef input_2) {
TVMRetValue rv;
std::vector<TVMValue> values(2);
std::vector<int> codes(2);
runtime::TVMArgsSetter setter(values.data(), codes.data());
setter(0, input_1);
setter(1, input_2);
_packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv);
return bool(rv);
};
};
TEST(Relay, AlphaTestEmptyTypeNodes) {
auto x = TypeVar("x", kTypeData);
auto y = TypeVar();
EXPECT_FALSE(relay::AlphaEqual(x, y));
TestAlphaEquals test_equals("relay.analysis._alpha_equal");
EXPECT_FALSE(test_equals(x, y));
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/node/structural_equal.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
......@@ -38,7 +39,7 @@ TEST(Relay, SelfReference) {
auto type_fx = mod->Lookup("main");
auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected));
}
int main(int argc, char ** argv) {
......
......@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/node/structural_equal.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
......@@ -102,7 +103,7 @@ TEST(Relay, Sequential) {
auto mod1 = IRModule::FromExpr(expected_func);
mod1 = relay::transform::InferType()(mod1);
auto expected = mod1->Lookup("main");
CHECK(relay::AlphaEqual(f, expected));
CHECK(tvm::StructuralEqual()(f, expected));
}
int main(int argc, char** argv) {
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Test graph equality of caffe2 models."""
import tvm
from tvm import relay
from tvm.relay import transform
from model_zoo import c2_squeezenet, relay_squeezenet
......@@ -23,7 +24,7 @@ from model_zoo import c2_squeezenet, relay_squeezenet
def compare_graph(lhs_mod, rhs_mod):
lhs_mod = transform.InferType()(lhs_mod)
rhs_mod = transform.InferType()(rhs_mod)
assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
def test_squeeze_net():
......
......@@ -25,7 +25,7 @@ import model_zoo
def compare_graph(lhs_mod, rhs_mod):
lhs_mod = transform.InferType()(lhs_mod)
rhs_mod = transform.InferType()(rhs_mod)
assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
def test_mlp():
shape = {"data": (1, 1, 28, 28)}
......
......@@ -77,7 +77,7 @@ def test_extract_identity():
mod["main"] = mod["main"].with_attr(
"Primitive", tvm.tir.IntImm("int32", 1))
relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])
tvm.ir.assert_structural_equal(list(items.values())[0], mod["main"])
def test_extract_conv_net():
......
......@@ -136,7 +136,7 @@ def test_extern_dnnl():
mod = annotated(dtype, ishape, w1shape)
mod = transform.AnnotateTarget("dnnl")(mod)
ref_mod = expected(dtype, ishape, w1shape)
assert relay.analysis.alpha_equal(mod, ref_mod)
assert tvm.ir.structural_equal(mod, ref_mod)
def test_run():
if not tvm.get_global_func("relay.ext.dnnl", True):
......
......@@ -27,7 +27,7 @@ def test_callgraph_construct():
mod["g1"] = relay.Function([x, y], x + y)
call_graph = relay.analysis.CallGraph(mod)
assert "g1" in str(call_graph)
assert relay.alpha_equal(mod, call_graph.module)
assert tvm.ir.structural_equal(mod, call_graph.module)
def test_print_element():
......
......@@ -29,11 +29,11 @@ def test_bind_params():
fexpected = relay.Function(
[y],
relay.add(relay.const(1, "float32"), y))
assert relay.analysis.alpha_equal(fbinded, fexpected)
assert tvm.ir.structural_equal(fbinded, fexpected)
zbinded = relay.bind(z, {y: x})
zexpected = relay.add(x, x)
assert relay.analysis.alpha_equal(zbinded, zexpected)
assert tvm.ir.structural_equal(zbinded, zexpected)
if __name__ == "__main__":
......
......@@ -21,13 +21,12 @@ from tvm import te
from tvm import relay
from tvm.tir.expr import *
from tvm.relay import op
from tvm.relay.analysis import graph_equal
import numpy as np
def check_json_roundtrip(node):
json_str = tvm.ir.save_json(node)
back = tvm.ir.load_json(json_str)
assert graph_equal(back, node)
assert tvm.ir.structural_equal(back, node, map_free_vars=True)
# Span
......
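A usage sketch of the updated round-trip check: serialize a small expression and require structural equality with free-variable mapping (the helper is repeated so the snippet is self-contained; names illustrative):

import tvm
from tvm import relay

def check_json_roundtrip(node):
    json_str = tvm.ir.save_json(node)
    back = tvm.ir.load_json(json_str)
    assert tvm.ir.structural_equal(back, node, map_free_vars=True)

x = relay.var("x", shape=(4,))
check_json_roundtrip(relay.add(x, x))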
......@@ -107,7 +107,7 @@ def test_func_type_sequal():
ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
tvm.runtime.convert([tp1, tp3]),
tvm.runtime.convert([tr1]))
translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2,
tvm.runtime.convert([tp2, tp4]),
tvm.runtime.convert([tr2]))
assert ft == translate_vars
......
......@@ -20,7 +20,7 @@ from tvm import relay
import tvm.relay.testing
import numpy as np
from tvm.relay import Expr
from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars
from tvm.relay.analysis import free_vars
do_print = [False]
......@@ -32,9 +32,9 @@ def astext(p, unify_free_vars=False):
return txt
x = relay.fromtext(txt)
if unify_free_vars:
assert_graph_equal(x, p)
tvm.ir.assert_structural_equal(x, p, map_free_vars=True)
else:
assert_alpha_equal(x, p)
tvm.ir.assert_structural_equal(x, p)
return txt
def show(text):
......
......@@ -99,7 +99,7 @@ def test_checkpoint_alpha_equal():
"""
)
relay.analysis.assert_alpha_equal(df, df_parsed)
tvm.ir.assert_structural_equal(df, df_parsed)
def test_checkpoint_alpha_equal_tuple():
xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
......@@ -146,7 +146,7 @@ def test_checkpoint_alpha_equal_tuple():
"""
)
relay.analysis.assert_alpha_equal(df, df_parsed)
tvm.ir.assert_structural_equal(df, df_parsed)
def test_collapse_sum_like():
shape = (3, 4, 5, 6)
......
......@@ -66,7 +66,7 @@ def test_alter_op():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_return_none():
......@@ -88,7 +88,7 @@ def test_alter_return_none():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
......@@ -151,7 +151,7 @@ def test_alter_layout():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_dual_path():
......@@ -214,7 +214,7 @@ def test_alter_layout_dual_path():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_resnet():
"""Test alternating the layout of a residual block
......@@ -271,7 +271,7 @@ def test_alter_layout_resnet():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_broadcast_op():
......@@ -318,7 +318,7 @@ def test_alter_layout_broadcast_op():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_broadcast_scalar_op():
......@@ -381,7 +381,7 @@ def test_alter_layout_broadcast_scalar_op():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_scalar():
......@@ -424,7 +424,7 @@ def test_alter_layout_scalar():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_concatenate():
......@@ -478,7 +478,7 @@ def test_alter_layout_concatenate():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# NHWC layout transformation.
def before_nhwc():
......@@ -524,7 +524,7 @@ def test_alter_layout_concatenate():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_nchw_upsamping_op():
......@@ -561,7 +561,7 @@ def test_alter_layout_nchw_upsamping_op():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_strided_slice():
......@@ -597,7 +597,7 @@ def test_alter_layout_strided_slice():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_depthwise_conv2d():
"""Test depthwise_conv2d operator"""
......@@ -632,7 +632,7 @@ def test_alter_layout_depthwise_conv2d():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert(analysis.alpha_equal(a, b))
assert(tvm.ir.structural_equal(a, b))
def test_alter_layout_prelu():
"""Test PRelu operator"""
......@@ -672,7 +672,7 @@ def test_alter_layout_prelu():
a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
assert(analysis.alpha_equal(a, b))
assert(tvm.ir.structural_equal(a, b))
def test_alter_layout_pad():
......@@ -715,7 +715,7 @@ def test_alter_layout_pad():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
......@@ -749,7 +749,7 @@ def test_alter_layout_pad():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check that conversion does not happen when padding along split axis.
def before():
......@@ -782,7 +782,7 @@ def test_alter_layout_pad():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_pool():
......@@ -825,7 +825,7 @@ def test_alter_layout_pool():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
......@@ -859,7 +859,7 @@ def test_alter_layout_pool():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_sum():
......@@ -902,7 +902,7 @@ def test_alter_layout_sum():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
......@@ -937,7 +937,7 @@ def test_alter_layout_sum():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the
......@@ -999,7 +999,7 @@ def test_alter_layout_nhwc_nchw_arm():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_op_with_global_var():
"""Test directly replacing an operator with a new one"""
......@@ -1041,7 +1041,7 @@ def test_alter_op_with_global_var():
a = transform.AlterOpLayout()(a)
b = transform.InferType()(expected())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a)
if __name__ == "__main__":
test_alter_op()
......
......@@ -64,7 +64,7 @@ def test_redundant_annotation():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.alpha_equal(annotated_func, expected_func)
assert tvm.ir.structural_equal(annotated_func, expected_func)
def test_annotate_expr():
......@@ -91,7 +91,7 @@ def test_annotate_expr():
annotated_expr = annotated()
expected_expr = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(annotated_expr, expected_expr)
assert tvm.ir.structural_equal(annotated_expr, expected_expr)
def test_annotate_all():
......@@ -120,7 +120,7 @@ def test_annotate_all():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(annotated_func, expected_func)
assert tvm.ir.structural_equal(annotated_func, expected_func)
def test_annotate_none():
......@@ -146,13 +146,13 @@ def test_annotate_none():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(annotated_func, expected_func)
assert tvm.ir.structural_equal(annotated_func, expected_func)
def check_annotated_graph(annotated_func, expected_func):
annotated_func = run_opt_pass(annotated_func, transform.InferType())
expected_func = run_opt_pass(expected_func, transform.InferType())
assert relay.analysis.alpha_equal(annotated_func, expected_func)
assert tvm.ir.structural_equal(annotated_func, expected_func)
def test_conv_network():
......@@ -596,7 +596,7 @@ def test_tuple_get_item():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(annotated_func, expected_func)
assert tvm.ir.structural_equal(annotated_func, expected_func)
if __name__ == "__main__":
......
......@@ -64,7 +64,7 @@ def test_canonicalize_cast():
mod[gv] = y_expected
mod = _transform.InferType()(mod)
y_expected = mod["expected"]
assert relay.analysis.alpha_equal(y, y_expected)
assert tvm.ir.structural_equal(y, y_expected)
check((1, 16, 7, 7))
......
......@@ -72,7 +72,7 @@ def test_combine_parallel_conv2d():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4, 4, 4, 4)
check((1, 4, 16, 16), 4, 8, 4, 7)
......@@ -118,7 +118,7 @@ def test_combine_parallel_conv2d_scale_relu():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4, 8)
......@@ -157,7 +157,7 @@ def test_combine_parallel_conv2d_scale():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4, 8)
......@@ -193,7 +193,7 @@ def test_combine_parallel_conv2d_multiple_blocks():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w, out_c, repeat)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4)
......
......@@ -75,7 +75,7 @@ def test_combine_parallel_dense():
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, w3, w4)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
check(3, 5, 4)
check(100, 200, 300)
......@@ -127,7 +127,7 @@ def test_combine_parallel_dense_biasadd():
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, b1, b2, is_2d_bias)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
check(3, 5, 4, False)
check(100, 200, 300, False)
......@@ -184,7 +184,7 @@ def test_combine_parallel_dense_biasadd_scale_reshape():
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
y_expected = run_opt_pass(y_expected, transform.InferType())
assert relay.analysis.alpha_equal(y, y_expected)
tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
check(3, 5, 4, 0.5, 0.25, (1, 1, 15))
check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
......
......@@ -52,7 +52,7 @@ def test_no_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_convert_layout():
......@@ -87,7 +87,7 @@ def test_conv_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_bias_pool_convert_layout():
......@@ -132,7 +132,7 @@ def test_conv_bias_pool_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_concat_convert_layout():
......@@ -180,7 +180,7 @@ def test_conv_concat_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_dual_path_convert_layout():
......@@ -235,7 +235,7 @@ def test_dual_path_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_bn_convert_layout():
......@@ -315,7 +315,7 @@ def test_resnet_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_scalar_convert_layout():
......@@ -347,7 +347,7 @@ def test_scalar_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_bn_convert_layout():
......@@ -395,7 +395,7 @@ def test_conv_bn_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_requantize_convert_layout():
......@@ -451,7 +451,7 @@ def test_qnn_conv_requantize_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_concat_convert_layout():
......@@ -529,7 +529,7 @@ def test_qnn_conv_concat_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_add_convert_layout():
......@@ -609,7 +609,7 @@ def test_qnn_conv_add_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
......
......@@ -18,7 +18,7 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import Function, transform
from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal
from tvm.relay.analysis import free_vars
from tvm.relay.op import log, add, equal, subtract
from tvm.relay.testing import inception_v3
......@@ -69,7 +69,7 @@ def test_used_let():
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
def test_chain_unused_let():
......@@ -105,7 +105,7 @@ def test_recursion():
orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
dced = run_opt_pass(orig, transform.DeadCodeElimination())
orig = run_opt_pass(orig, transform.InferType())
assert_alpha_equal(dced, orig)
tvm.ir.assert_structural_equal(dced, orig)
def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
......
......@@ -52,7 +52,7 @@ def test_simple():
z = before()
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
assert analysis.alpha_equal(z, expected())
assert tvm.ir.structural_equal(z, expected())
def test_callback():
......@@ -82,7 +82,7 @@ def test_callback():
z = before()
z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
assert analysis.alpha_equal(z, expected())
assert tvm.ir.structural_equal(z, expected())
if __name__ == "__main__":
......
......@@ -47,7 +47,8 @@ def test_eta_expand_global_var():
}
}
""")
relay.analysis.assert_graph_equal(mod['main'], expected['main'])
tvm.ir.assert_structural_equal(mod['main'], expected['main'],
map_free_vars=True)
def test_eta_expand_constructor():
......@@ -76,7 +77,8 @@ def test_eta_expand_constructor():
}
}
""")
relay.analysis.assert_graph_equal(mod['main'], expected['main'])
tvm.ir.assert_structural_equal(mod['main'], expected['main'],
map_free_vars=True)
if __name__ == '__main__':
......
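The map_free_vars=True argument added in the hunks above relaxes the comparison so that free variables are mapped by occurrence instead of requiring identity. A brief sketch of the assumed behavior (illustrative only, not part of this commit):
import tvm
from tvm import relay

# Two distinct free variables with the same type annotation.
a = relay.var("a", shape=(8,), dtype="float32")
b = relay.var("b", shape=(8,), dtype="float32")

# Without mapping, distinct free variables do not compare equal.
assert not tvm.ir.structural_equal(a, b)
# With map_free_vars=True, they are matched positionally.
assert tvm.ir.structural_equal(a, b, map_free_vars=True)

# assert_structural_equal raises on mismatch instead of returning False.
tvm.ir.assert_structural_equal(a, b, map_free_vars=True)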
......@@ -59,7 +59,7 @@ def test_fold_const():
with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.alpha_equal(zz, zexpected)
assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_let():
......@@ -84,7 +84,7 @@ def test_fold_let():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(zz, zexpected)
assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_tuple():
......@@ -106,7 +106,7 @@ def test_fold_tuple():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(zz, zexpected)
assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_concat():
......@@ -125,7 +125,7 @@ def test_fold_concat():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(zz, zexpected)
assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_shape_of():
......@@ -146,7 +146,7 @@ def test_fold_shape_of():
for dtype in ["int32", "float32"]:
zz = run_opt_pass(before(dtype), transform.FoldConstant())
zexpected = run_opt_pass(expected(dtype), transform.InferType())
assert relay.analysis.graph_equal(zz, zexpected)
assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_full():
......@@ -161,7 +161,7 @@ def test_fold_full():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(zz, zexpected)
assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_batch_norm():
......@@ -202,7 +202,7 @@ def test_fold_batch_norm():
mod = remove_bn_pass(mod)
expect = run_infer_type(expected())
assert relay.analysis.graph_equal(mod["main"], expect)
assert tvm.ir.structural_equal(mod["main"], expect)
if __name__ == "__main__":
......
......@@ -79,7 +79,7 @@ def test_fold_fwd_simple():
y1_folded = run_opt_pass(y1_folded, transform.InferType())
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert relay.analysis.alpha_equal(y1_folded, y1_expected)
assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2)
......@@ -148,7 +148,7 @@ def test_fold_fwd_dual_path():
weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert relay.analysis.alpha_equal(y1_folded, y1_expected)
assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3)
......@@ -177,7 +177,7 @@ def test_fold_fwd_fail():
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
assert relay.analysis.alpha_equal(y1, y1_folded)
assert tvm.ir.structural_equal(y1, y1_folded)
check((2, 11, 10, 4), 4)
......@@ -205,7 +205,7 @@ def test_fold_fwd_relu_fail():
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
assert relay.analysis.alpha_equal(y1, y1_folded)
assert tvm.ir.structural_equal(y1, y1_folded)
in_scale = relay.var("in_scale", shape=(4,))
check((2, 11, 10, 4), 4, in_scale)
......@@ -249,7 +249,7 @@ def test_fold_fwd_negative_scale():
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
y1_expected = expected(x, weight, in_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert relay.analysis.alpha_equal(y1_folded, y1_expected)
assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
......@@ -300,7 +300,7 @@ def test_fold_bwd_simple():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert relay.analysis.alpha_equal(y1_folded, y1_expected)
assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
......@@ -359,7 +359,7 @@ def test_fold_bwd_dual_path():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert relay.analysis.alpha_equal(y1_folded, y1_expected)
assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
......@@ -431,7 +431,7 @@ def test_fold_bwd_dual_consumer():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert relay.analysis.alpha_equal(y1_folded, y1_expected)
assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
......@@ -480,7 +480,7 @@ def test_fold_bwd_fail():
y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
assert relay.analysis.alpha_equal(y1_folded, y1)
assert tvm.ir.structural_equal(y1_folded, y1)
check((4, 4, 10, 10), 4, fail1)
check((4, 4, 10, 10), 4, fail2)
......@@ -505,7 +505,7 @@ def test_fold_bwd_relu_fail():
y1 = before(x, weight, out_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
assert relay.analysis.alpha_equal(y1, y1_folded)
assert tvm.ir.structural_equal(y1, y1_folded)
out_scale = relay.var("in_scale", shape=(4, 1, 1))
check((4, 4, 10, 10), 4, out_scale)
......@@ -547,7 +547,7 @@ def test_fold_bwd_negative_scale():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
assert relay.analysis.alpha_equal(y1_folded, y1_expected)
assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
......
......@@ -45,7 +45,7 @@ def test_fuse_simple():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_conv2d_fuse():
......@@ -127,7 +127,7 @@ def test_conv2d_fuse():
z = before(dshape)
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_concatenate():
......@@ -167,7 +167,7 @@ def test_concatenate():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_tuple_root():
......@@ -204,7 +204,7 @@ def test_tuple_root():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_stop_fusion():
......@@ -235,7 +235,7 @@ def test_stop_fusion():
z = before(dshape)
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_fuse_myia_regression():
......@@ -271,7 +271,7 @@ def test_fuse_myia_regression():
f = before(dshape, dtype)
zz = run_opt_pass(f, transform.FuseOps())
after = run_opt_pass(expected(dshape, dtype), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_fuse_tuple_get_elemwise():
......@@ -309,7 +309,7 @@ def test_fuse_tuple_get_elemwise():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dim), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_tuple_get_root():
......@@ -346,7 +346,7 @@ def test_tuple_get_root():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dim), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
......@@ -379,7 +379,7 @@ def test_tuple_intermediate():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(x), transform.InferType())
assert relay.analysis.alpha_equal(m["main"], after)
assert tvm.ir.structural_equal(m["main"], after)
def test_tuple_consecutive():
......@@ -437,7 +437,7 @@ def test_tuple_consecutive():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(m["main"], after)
assert tvm.ir.structural_equal(m["main"], after)
def test_inception_like():
......@@ -510,7 +510,7 @@ def test_inception_like():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(m["main"], after)
assert tvm.ir.structural_equal(m["main"], after)
def test_fuse_parallel_injective():
......@@ -541,7 +541,7 @@ def test_fuse_parallel_injective():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
def test_immutable():
......@@ -570,8 +570,8 @@ def test_immutable():
mod = before()
new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
assert relay.analysis.alpha_equal(mod, before())
assert relay.analysis.alpha_equal(new_mod, expected())
assert tvm.ir.structural_equal(mod, before())
assert tvm.ir.structural_equal(new_mod, expected())
def test_split():
......@@ -619,7 +619,7 @@ def test_fuse_max():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.alpha_equal(zz, after)
assert tvm.ir.structural_equal(zz, after)
if __name__ == "__main__":
test_fuse_simple()
......
......@@ -19,7 +19,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
......@@ -292,7 +292,7 @@ def test_concat():
func = relay.Function([x], y)
func = run_infer_type(func)
back_func = run_infer_type(gradient(func))
assert_alpha_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
tvm.ir.assert_structural_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
# no value validation as concatenate has dummy gradient right now.
......
......@@ -115,7 +115,7 @@ def test_call_chain_inline_leaf():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels():
......@@ -188,7 +188,7 @@ def test_call_chain_inline_multiple_levels():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels_extern_compiler():
......@@ -266,7 +266,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_recursive_call_with_global():
......@@ -321,7 +321,7 @@ def test_recursive_call_with_global():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_recursive_called():
......@@ -330,7 +330,7 @@ def test_recursive_called():
mod["main"] = relay.Function([iarg], sum_up(iarg))
ref_mod = mod
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, ref_mod)
assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_recursive_not_called():
......@@ -356,7 +356,7 @@ def test_recursive_not_called():
mod = get_mod()
mod = relay.transform.Inline()(mod)
ref_mod = expected()
assert relay.analysis.alpha_equal(mod, ref_mod)
assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_recursive_not_called_extern_compiler():
......@@ -387,7 +387,7 @@ def test_recursive_not_called_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
ref_mod = expected()
assert relay.analysis.alpha_equal(mod, ref_mod)
assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_globalvar_as_call_arg():
......@@ -434,7 +434,7 @@ def test_globalvar_as_call_arg():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_globalvar_as_call_arg_extern_compiler():
......@@ -500,7 +500,7 @@ def test_globalvar_as_call_arg_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_inline_globalvar_without_args():
......@@ -531,7 +531,7 @@ def test_inline_globalvar_without_args():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_inline_globalvar_without_args_extern_compiler():
......@@ -566,7 +566,7 @@ def test_inline_globalvar_without_args_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_globalvar_called_by_multiple_functions():
......@@ -644,7 +644,7 @@ def test_globalvar_called_by_multiple_functions():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_entry_with_inline():
......@@ -674,7 +674,7 @@ def test_entry_with_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, get_mod())
assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
def test_callee_not_inline():
......@@ -707,7 +707,7 @@ def test_callee_not_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, get_mod())
assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
def test_callee_not_inline_leaf_inline():
......@@ -765,7 +765,7 @@ def test_callee_not_inline_leaf_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_callee_not_inline_leaf_inline_extern_compiler():
......@@ -830,7 +830,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
assert relay.analysis.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
if __name__ == '__main__':
......
......@@ -68,7 +68,7 @@ def test_legalize():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_legalize_none():
"""Test doing nothing by returning 'None' """
......@@ -89,7 +89,7 @@ def test_legalize_none():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
def test_legalize_multiple_ops():
......@@ -134,7 +134,7 @@ def test_legalize_multiple_ops():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_legalize_multi_input():
......@@ -170,7 +170,7 @@ def test_legalize_multi_input():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
......
......@@ -111,7 +111,7 @@ def get_rand(shape, dtype='float32'):
def check_func(func, ref_func):
func = run_infer_type(func)
ref_func = run_infer_type(ref_func)
assert analysis.graph_equal(func, ref_func)
assert tvm.ir.structural_equal(func, ref_func)
def test_module_pass():
......@@ -211,7 +211,7 @@ def test_function_class_pass():
mod = fpass(mod)
# wrap in expr
mod2 = tvm.IRModule.from_expr(f1)
assert relay.alpha_equal(mod["main"], mod2["main"])
assert tvm.ir.structural_equal(mod["main"], mod2["main"])
def test_function_pass():
......@@ -496,7 +496,7 @@ def test_sequential_with_scoping():
zz = mod["main"]
zexpected = run_infer_type(expected())
assert analysis.alpha_equal(zz, zexpected)
assert tvm.ir.structural_equal(zz, zexpected)
def test_print_ir(capfd):
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for merge composite."""
import tvm
from tvm import relay
from tvm import tir
from tvm.relay.testing import run_opt_pass
......@@ -192,7 +193,7 @@ def test_simple_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_branch_merge():
......@@ -270,7 +271,7 @@ def test_branch_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_reuse_call_merge():
......@@ -329,7 +330,7 @@ def test_reuse_call_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_multiple_patterns():
......@@ -422,7 +423,7 @@ def test_multiple_patterns():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_merge_order():
......@@ -494,7 +495,7 @@ def test_merge_order():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check B highest priority
pattern_table = [
......@@ -505,7 +506,7 @@ def test_merge_order():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check C highest priority
pattern_table = [
......@@ -516,7 +517,7 @@ def test_merge_order():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_parallel_merge():
......@@ -563,7 +564,7 @@ def test_parallel_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_multiple_input_subgraphs():
......@@ -676,13 +677,13 @@ def test_multiple_input_subgraphs():
result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check case 'B'
result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_B(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_tuple_get_item_merge():
......@@ -728,7 +729,7 @@ def test_tuple_get_item_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
if __name__ == "__main__":
......
......@@ -19,7 +19,6 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
......@@ -124,7 +123,7 @@ def test_ad():
body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body))
expected = run_opt_pass(expected, transform.InferType())
assert_alpha_equal(g, expected)
tvm.ir.assert_structural_equal(g, expected)
def test_if_ref():
......@@ -312,7 +311,7 @@ def test_concat():
x = Var("x", t)
y = Var("x", t)
orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
assert_alpha_equal(dcpe(orig), orig)
tvm.ir.assert_structural_equal(dcpe(orig), orig)
def test_triangle_number():
......@@ -321,7 +320,7 @@ def test_triangle_number():
f_var = Var("f")
f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - const(1))))
orig = run_infer_type(Let(f_var, f, f_var(const(10))))
assert_alpha_equal(dcpe(orig), const(55))
tvm.ir.assert_structural_equal(dcpe(orig), const(55))
def test_nat_update():
......@@ -337,7 +336,7 @@ def test_tuple_match():
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert_alpha_equal(dcpe(x), const(2))
tvm.ir.assert_structural_equal(dcpe(x), const(2))
if __name__ == '__main__':
......
......@@ -339,7 +339,7 @@ def test_extern_ccompiler_default_ops():
fused_mod = transform.FuseOps(2)(mod)
expected_mod = expected()
assert relay.alpha_equal(fused_mod, expected_mod)
assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
x_data = np.random.rand(8, 8).astype('float32')
y_data = np.random.rand(8, 8).astype('float32')
......@@ -427,7 +427,7 @@ def test_extern_dnnl():
mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
mod = transform.PartitionGraph()(mod)
assert relay.alpha_equal(mod, expected())
assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
ref_mod = tvm.IRModule()
ref_mod["main"] = get_func()
......@@ -561,7 +561,7 @@ def test_function_lifting():
partitioned = partition()
ref_mod = expected()
assert relay.analysis.alpha_equal(partitioned, ref_mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_function_lifting_inline():
......@@ -631,7 +631,7 @@ def test_function_lifting_inline():
partitioned = partition()
ref_mod = expected()
assert relay.analysis.alpha_equal(partitioned, ref_mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_constant_propagation():
......@@ -671,7 +671,7 @@ def test_constant_propagation():
mod = transform.PartitionGraph()(mod)
expected_mod = expected()
assert relay.alpha_equal(mod, expected_mod)
assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
y_data = np.random.rand(8, 8).astype('float32')
np_add = ones + y_data
......
......@@ -31,7 +31,7 @@ def alpha_equal(x, y):
"""
x = x['main']
y = y['main']
return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
......@@ -85,12 +85,12 @@ def test_qnn_legalize():
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_legalize_qnn_conv2d():
......
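The helper in the hunk above pairs the new tvm.ir.structural_equal with the existing relay.analysis.structural_hash. A short sketch of the consistency that pairing relies on (illustrative names, assumed behavior, not part of this commit): structurally equal expressions are expected to produce identical structural hashes.
import tvm
from tvm import relay
from tvm.relay import analysis

x = relay.var("x", shape=(2, 2), dtype="float32")
y = relay.var("y", shape=(2, 2), dtype="float32")
f1 = relay.Function([x], relay.nn.relu(x))
f2 = relay.Function([y], relay.nn.relu(y))

# Equality and hashing are expected to agree: equal structures hash alike.
if tvm.ir.structural_equal(f1, f2):
    assert analysis.structural_hash(f1) == analysis.structural_hash(f2)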
......@@ -110,7 +110,7 @@ def test_call_globalvar_without_args():
mod = get_mod()
ref_mod = get_mod()
mod = relay.transform.RemoveUnusedFunctions()(mod)
assert relay.alpha_equal(mod, ref_mod)
assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
if __name__ == '__main__':
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm.ir import IRModule
from tvm.ir import IRModule, structural_equal
from tvm import relay as rly
from tvm.relay.transform import SimplifyInference
......@@ -56,7 +56,7 @@ def test_simplify_batchnorm(dtype='float32'):
mod = simplify(mod)
y1 = mod["main"].body
assert rly.analysis.graph_equal(y1, y2)
assert structural_equal(y1, y2, map_free_vars=True)
check(2, 1, 1)
check(4, 1, 1)
......
......@@ -18,7 +18,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.analysis import detect_feature
from tvm.relay import op, create_executor, transform
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count
......
......@@ -18,7 +18,6 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor
from tvm.relay.analysis import assert_graph_equal
from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType,
TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
from tvm.relay.adt import TypeData
......@@ -34,7 +33,8 @@ def check_visit(typ):
ev = TypeVisitor()
ev.visit(typ)
assert_graph_equal(TypeMutator().visit(typ), typ)
tvm.ir.assert_structural_equal(TypeMutator().visit(typ), typ,
map_free_vars=True)
def test_type_var():
......
......@@ -18,10 +18,9 @@
import tvm
def check_json_roundtrip(node):
from tvm.relay.analysis import graph_equal
json_str = tvm.ir.save_json(node)
back = tvm.ir.load_json(json_str)
assert graph_equal(back, node)
assert tvm.ir.structural_equal(back, node, map_free_vars=True)
def test_prim_type():
......