Commit 919ae889 by Zhi (unverified signature, committed via GitHub)

[REFACTOR][IR] alpha_equal to structural_equal (#5161)

parent 07ac7712
@@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode {
   }
   bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
-    return equal(kind, other->kind);
+    return
+        equal(kind, other->kind) &&
+        equal.FreeVarEqualImpl(this, other);
   }
   void SHashReduce(SHashReducer hash_reduce) const {
...
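The added `FreeVarEqualImpl` term makes an unresolved `IncompleteType` behave like a free variable: two distinct instances match only when the reducer is allowed to map free variables onto each other. A minimal sketch of that free-variable semantics, using ordinary Relay vars rather than a hand-built `IncompleteType`:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,), dtype="float32")
    y = relay.var("y", shape=(1,), dtype="float32")

    # Distinct free variables are unequal by default...
    assert not tvm.ir.structural_equal(x, y)
    # ...but may be mapped onto each other when the caller opts in.
    assert tvm.ir.structural_equal(x, y, map_free_vars=True)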
@@ -65,61 +65,6 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
 TVM_DLL bool ConstantCheck(const Expr& e);
 /*!
- * \brief Compare two expressions for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * expressions without regard to variable choice.
- *
- * For example: `let x = 1 in x` is equal to `let y = 1 in y`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param e1 The left hand expression.
- * \param e2 The right hand expression.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
-
-/*!
- * \brief Compare two types for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * expressions without regard to variable choice.
- *
- * For example: `forall s, Tensor[f32, s]` is equal to
- * `forall w, Tensor[f32, w]`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param t1 The left hand type.
- * \param t2 The right hand type.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
-
-/*!
- * \brief Compare two patterns for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * patterns without regard to variable choice.
- *
- * For example: `A(x, _, y)` is equal to `A(z, _, a)`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param t1 The left hand pattern.
- * \param t2 The right hand pattern.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
-
-/*!
  * \brief Check that each Var is only bound once.
  *
  * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
...
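C++ call sites of the removed overloads switch to the `tvm::StructuralEqual()` functor from `<tvm/node/structural_equal.h>`, as the hunks below show. The scoping behaviour described in the deleted docs carries over; a small sketch of the `let x = 1 in x` example through the new Python entry point:

    import tvm
    from tvm import relay

    x, y = relay.var("x"), relay.var("y")
    # `let x = 1 in x` is still equal to `let y = 1 in y`:
    assert tvm.ir.structural_equal(
        relay.Let(x, relay.const(1, "float32"), x),
        relay.Let(y, relay.const(1, "float32"), y))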
@@ -16,6 +16,7 @@
 # under the License.
 """Unified type system in the project."""
 from enum import IntEnum
+import tvm
 import tvm._ffi
 from .base import Node
@@ -26,7 +27,7 @@ class Type(Node):
     """The base class of all types."""
     def __eq__(self, other):
         """Compare two types for structural equivalence."""
-        return bool(_ffi_api.type_alpha_equal(self, other))
+        return bool(tvm.ir.structural_equal(self, other))
     def __ne__(self, other):
         return not self.__eq__(other)
...
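With `__eq__` routed through `tvm.ir.structural_equal`, plain `==` on types keeps the old alpha-style behaviour, e.g.:

    from tvm import relay

    assert relay.TensorType((1, 4), "float32") == relay.TensorType((1, 4), "float32")
    assert relay.TensorType((1, 4), "float32") != relay.TensorType((1, 4), "int32")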
@@ -33,7 +33,6 @@ from . import parser
 from . import transform
 from . import analysis
-from .analysis import alpha_equal
 from .build_module import build, create_executor, optimize
 from .transform import build_config
 from . import debug
...
@@ -220,78 +220,6 @@ def all_type_vars(expr, mod=None):
     return _ffi_api.all_type_vars(expr, use_mod)
-
-def alpha_equal(lhs, rhs):
-    """Compare two Relay expr for structural equivalence (alpha equivalence).
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-        One of the input Expression.
-
-    rhs : tvm.relay.Expr
-        One of the input Expression.
-
-    Returns
-    -------
-    result : bool
-        True iff lhs is alpha equal to rhs.
-    """
-    return bool(_ffi_api._alpha_equal(lhs, rhs))
-
-
-def assert_alpha_equal(lhs, rhs):
-    """Assert that two Relay expr is structurally equivalent. (alpha equivalence).
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-        One of the input Expression.
-
-    rhs : tvm.relay.Expr
-        One of the input Expression.
-    """
-    _ffi_api._assert_alpha_equal(lhs, rhs)
-
-
-def graph_equal(lhs, rhs):
-    """Compare two Relay expr for data-flow equivalence.
-    The difference between this and alpha-equality is that
-    variables are not expected to match between lhs and rhs;
-    they are treated as sources and are mapped between each other.
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-        One of the input Expression.
-
-    rhs : tvm.relay.Expr
-        One of the input Expression.
-
-    Returns
-    -------
-    result : bool
-        True iff lhs is data-flow equivalent to rhs.
-    """
-    return bool(_ffi_api._graph_equal(lhs, rhs))
-
-
-def assert_graph_equal(lhs, rhs):
-    """Compare two Relay expr for data-flow equivalence.
-    The difference between this and alpha-equality is that
-    variables are not expected to match between lhs and rhs;
-    they are treated as sources and are mapped between each other.
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-        One of the input Expression.
-
-    rhs : tvm.relay.Expr
-        One of the input Expression.
-    """
-    _ffi_api._assert_graph_equal(lhs, rhs)
-
-
 def collect_device_info(expr):
     """Collect the device allocation map for the given expression. The device
     ids are propagated from the `device_copy` operators.
...
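The four removed helpers map directly onto the unified API: `alpha_equal` becomes `tvm.ir.structural_equal`, `graph_equal` becomes `tvm.ir.structural_equal(..., map_free_vars=True)`, and the `assert_*` variants become `tvm.ir.assert_structural_equal`, optionally with the same flag; the test updates later in this commit follow exactly this mapping. A quick sketch:

    import tvm
    from tvm import relay

    a = relay.var("a", shape=(10,), dtype="float32")
    lhs, rhs = relay.add(a, a), relay.add(a, a)

    assert tvm.ir.structural_equal(lhs, rhs)                      # was alpha_equal
    assert tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)  # was graph_equal
    tvm.ir.assert_structural_equal(lhs, rhs)                      # was assert_alpha_equal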
@@ -23,6 +23,7 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/module.h>
+#include <tvm/node/structural_equal.h>
 // NOTE: reverse dependency on relay.
 // These dependencies do not happen at the interface-level,
 // and are only used in minimum cases where they are clearly marked.
@@ -194,12 +195,11 @@ relay::Function RunTypeCheck(const IRModule& mod,
         << AsText(func, false)
         << std::endl;
   }
-  func =
-      relay::Function(concat(func->params, fv),
-                      func->body,
-                      func->ret_type,
-                      concat(func->type_params, ftv),
-                      func->attrs);
+  func = relay::Function(concat(func->params, fv),
+                         func->body,
+                         func->ret_type,
+                         concat(func->type_params, ftv),
+                         func->attrs);
   // Type check the item before we add it to the module.
   relay::Function checked_func = InferType(func, mod, var);
   return checked_func;
@@ -222,7 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var,
     CHECK(update)
         << "Already have definition for " << var->name_hint;
     auto old_type = functions[var]->checked_type();
-    CHECK(relay::AlphaEqual(type, old_type))
+    CHECK(tvm::StructuralEqual()(type, old_type))
         << "Module#update changes type, not possible in this mode.";
   }
   var->checked_type_ = type;
@@ -353,9 +353,8 @@ IRModule IRModule::FromExpr(
   if (auto* func_node = expr.as<BaseFuncNode>()) {
     func = GetRef<BaseFunc>(func_node);
   } else {
-    func = relay::Function(
-        relay::FreeVars(expr), expr, Type(),
-        relay::FreeTypeVars(expr, mod), {});
+    func = relay::Function(relay::FreeVars(expr), expr, Type(),
+                           relay::FreeTypeVars(expr, mod), {});
   }
   auto main_gv = GlobalVar("main");
   mod->Add(main_gv, func);
...
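`IRModule::FromExpr` still wraps a bare expression into a `main` function over its free variables and free type variables; the Python-side mirror is roughly:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2,), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.exp(x))  # x becomes a parameter of main
    print(mod["main"])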
@@ -21,6 +21,7 @@
  * \file type_solver.cc
  * \brief Type solver implementations.
  */
+#include <tvm/node/structural_equal.h>
 #include <tvm/ir/type_functor.h>
 #include <tvm/tir/op.h>
 #include <string>
@@ -151,11 +152,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     return rc.Check(t);
   }
-  // default: unify only if alpha-equal
+  // default: unify only if structural-equal
   Type VisitTypeDefault_(const Object* op, const Type& tn) final {
     ObjectRef nr = GetRef<ObjectRef>(op);
     Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
-    if (!AlphaEqual(t1, tn)) {
+    if (!tvm::StructuralEqual()(t1, tn)) {
       return Type(nullptr);
     }
     return t1;
@@ -216,7 +217,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     auto tt1 = GetRef<TensorType>(op);
     auto tt2 = GetRef<TensorType>(tt_node);
-    if (AlphaEqual(tt1, tt2)) {
+    if (tvm::StructuralEqual()(tt1, tt2)) {
       return std::move(tt1);
     }
...
@@ -25,6 +25,7 @@
 #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
 #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
+#include <tvm/node/structural_equal.h>
 #include <tvm/tir/lowered_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/relay/analysis.h>
@@ -268,7 +269,7 @@ inline bool CCacheKeyNode::Equal(
     const CCacheKeyNode* other) const {
   if (Hash() != other->Hash()) return false;
   return this->target->str() == other->target->str() &&
-      AlphaEqual(this->source_func, other->source_func);
+      tvm::StructuralEqual()(this->source_func, other->source_func);
 }
 }  // namespace relay
...
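`CCacheKeyNode::Equal` keeps the cheap hash comparison in front of the structural check. The companion `tvm.ir.structural_hash` follows the same convention (structurally equal nodes hash equally), so the hash can serve as a fast negative filter; a sketch:

    import tvm
    from tvm import relay

    def make(name):
        v = relay.var(name, shape=(1,), dtype="float32")
        return relay.Function([v], relay.exp(v))

    f, g = make("p"), make("q")
    # Bound parameter names matter for neither the hash nor the equality.
    assert tvm.ir.structural_hash(f) == tvm.ir.structural_hash(g)
    assert tvm.ir.structural_equal(f, g)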
@@ -22,6 +22,7 @@
  * \brief Lift all local functions into global functions.
  */
+#include <tvm/node/structural_equal.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/support/logging.h>
@@ -161,7 +162,8 @@ class LambdaLifter : public ExprMutator {
     if (module_->ContainGlobalVar(name)) {
       const auto existing_func = module_->Lookup(name);
-      CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
+      CHECK(tvm::StructuralEqual()(lifted_func, existing_func))
+          << "lifted function hash collision";
       // If an identical function already exists, use its global var.
       global = module_->GetGlobalVar(name);
     } else {
...
@@ -2142,7 +2142,12 @@ Expr MakeSplit(Expr data,
 TVM_REGISTER_GLOBAL("relay.op._make.split")
 .set_body([](const TVMArgs& args, TVMRetValue* rv) {
     if (args.type_codes[1] == kDLInt) {
-      *rv = MakeSplit(args[0], tir::make_const(DataType::Int(64), int64_t(args[1])), args[2]);
+      // Note: we change it from Int(64) to Int(32) for now as
+      // combine_parallel_dense will transform the graph with Int(32).
+      // More investigation is needed to check which one we should use.
+      *rv = MakeSplit(args[0],
+                      tir::make_const(DataType::Int(32), static_cast<int>(args[1])),
+                      args[2]);
     } else {
       *rv = MakeSplit(args[0], args[1], args[2]);
     }
...
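From Python, an integer `indices_or_sections` reaches exactly this `kDLInt` branch, so it is now materialized as an `int32` constant; for example:

    from tvm import relay

    x = relay.var("x", shape=(4, 32), dtype="float32")
    parts = relay.split(x, indices_or_sections=2, axis=1)  # int -> Int(32) const
    lhs, rhs = parts[0], parts[1]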
@@ -59,6 +59,7 @@
  * Thus, it is necessary to wrap this outer function so that the input/output types remain the same
  */
+#include <tvm/node/structural_equal.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/ir/type_functor.h>
@@ -93,7 +94,7 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
   Expr WrapExpr(const Expr expr, const Type& type) {
     if (type.as<TensorTypeNode>()) {
       return Call(module_->GetConstructor("GradCell", "Raw"),
-                  {expr}, Attrs(), {type});
+                 {expr}, Attrs(), {type});
     } else if (auto* type_anno = type.as<TupleTypeNode>()) {
       tvm::Array<Expr> fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
@@ -185,7 +186,7 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
   Expr VisitExpr_(const ConstantNode* op) final {
     return Call(module_->GetConstructor("GradCell", "Raw"),
-                {GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
+               {GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
   }
   Expr VisitExpr_(const CallNode* call_node) final {
@@ -207,26 +208,25 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
       // call appropriate GradCell constructor
       std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
       return Call(module_->GetConstructor("GradCell", constructor_name),
-                  {func}, Attrs(), {call_node->checked_type()});
+                 {func}, Attrs(), {call_node->checked_type()});
     }
     if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
       // ones_like and zeros_like need TensorType input
       Expr result = CallPrimitiveOp(call_node);
       // fn() -> T, function returns result of operation
-      Expr func = Function({}, result,
-                           {call_node->checked_type()}, Array<TypeVar>());
+      Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
       // call appropriate GradCell constructor
       std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
       return Call(module_->GetConstructor("GradCell", "One"),
-                  {func}, Attrs(), {call_node->checked_type()});
+                 {func}, Attrs(), {call_node->checked_type()});
     }
     // handle all other ops
     Expr result = CallPrimitiveOp(call_node);
     // wrap result with Raw constructor
     return Call(module_->GetConstructor("GradCell", "Raw"), {result},
-                Attrs(), {call_node->checked_type()});
+               Attrs(), {call_node->checked_type()});
   }
   // not an op
   return ExprMutator::VisitExpr_(call_node);
@@ -253,10 +253,11 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
   Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
     // can only use overloaded functions if 2 arguments of same type
     if (call_node->args.size() != 2 ||
-        !AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
+        !tvm::StructuralEqual()(call_node->args[0]->checked_type(),
+                                call_node->args[1]->checked_type())) {
       Expr result = CallPrimitiveOp(call_node);
       return Call(module_->GetConstructor("GradCell", "Raw"), {result},
-                  Attrs(), {call_node->checked_type()});
+                 Attrs(), {call_node->checked_type()});
     }
     tvm::Array<Expr> args;
@@ -266,8 +267,7 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
                           Var("rhs", paramType)};
     // use primitive op in this case
     Expr callOp = Call(call_node->op, {params[0], params[1]});
-    Expr func = Function(params, callOp, paramType,
-                         Array<TypeVar>());
+    Expr func = Function(params, callOp, paramType, Array<TypeVar>());
     // pass "fallback" function and tensors as arguments
     args.push_back(func);
...
@@ -27,6 +27,7 @@
 #define TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
 #include <builtin_fp16.h>
+#include <tvm/node/structural_equal.h>
 #include <tvm/tir/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
@@ -300,7 +301,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
   if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
     return false;
   }
-  return AlphaEqual(a, b);
+  return tvm::StructuralEqual()(a, b);
 }
 inline Expr GetField(Expr t, size_t i) {
...
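`IsEqualScalar` now defers to structural equality, which compares constant payloads by value:

    import tvm
    from tvm import relay

    assert tvm.ir.structural_equal(relay.const(1.0), relay.const(1.0))
    assert not tvm.ir.structural_equal(relay.const(1.0), relay.const(2.0))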
@@ -353,7 +353,7 @@ Function UnCPS(const Function& f) {
   auto answer_type = new_type_params.back();
   new_type_params.pop_back();
   // TODO(@M.K.): make alphaequal work on free term
-  // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
+  // CHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, answer_type)));
   auto x = Var("x", new_ret_type);
   auto cont = Function({x}, x, new_ret_type, {}, {});
   tvm::Array<Expr> args;
...
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <gtest/gtest.h>
-#include <tvm/te/operation.h>
-#include <tvm/relay/expr.h>
-#include <tvm/relay/type.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
-
-using namespace tvm;
-
-class TestAlphaEquals {
-  runtime::PackedFunc *_packed_func;
-
- public:
-  TestAlphaEquals(const char* func_name) {
-    _packed_func = new runtime::PackedFunc();
-    TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
-  }
-
-  void UpdatePackedFunc(const char* func_name) {
-    TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
-  }
-
-  bool operator()(ObjectRef input_1, ObjectRef input_2) {
-    TVMRetValue rv;
-    std::vector<TVMValue> values(2);
-    std::vector<int> codes(2);
-    runtime::TVMArgsSetter setter(values.data(), codes.data());
-    setter(0, input_1);
-    setter(1, input_2);
-    _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv);
-    return bool(rv);
-  };
-};
-
-TEST(Relay, AlphaTestEmptyTypeNodes) {
-  auto x = TypeVar("x", kTypeData);
-  auto y = TypeVar();
-  EXPECT_FALSE(relay::AlphaEqual(x, y));
-
-  TestAlphaEquals test_equals("relay._make._alpha_equal");
-  EXPECT_FALSE(test_equals(x, y));
-}
-
-int main(int argc, char ** argv) {
-  testing::InitGoogleTest(&argc, argv);
-  testing::FLAGS_gtest_death_test_style = "threadsafe";
-  return RUN_ALL_TESTS();
-}
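The deleted harness above existed mainly to exercise the `relay._make._alpha_equal` packed function, which this commit removes. The kind-mismatch case it covered can still be expressed against the new API; a rough Python analogue, assuming the `tvm.ir.TypeVar` and `tvm.ir.TypeKind` bindings:

    import tvm

    x = tvm.ir.TypeVar("x", tvm.ir.TypeKind.TypeData)
    y = tvm.ir.TypeVar("y", tvm.ir.TypeKind.Type)
    # Mismatched kinds keep the vars distinct even with free-var mapping.
    assert not tvm.ir.structural_equal(x, y, map_free_vars=True)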
@@ -18,6 +18,7 @@
  */
 #include <gtest/gtest.h>
+#include <tvm/node/structural_equal.h>
 #include <tvm/te/operation.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
@@ -38,7 +39,7 @@ TEST(Relay, SelfReference) {
   auto type_fx = mod->Lookup("main");
   auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
-  CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
+  CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected));
 }
 int main(int argc, char ** argv) {
...
@@ -19,6 +19,7 @@
 #include <gtest/gtest.h>
 #include <topi/generic/injective.h>
+#include <tvm/node/structural_equal.h>
 #include <tvm/driver/driver_api.h>
 #include <tvm/relay/expr.h>
 #include <tvm/ir/module.h>
@@ -102,7 +103,7 @@ TEST(Relay, Sequential) {
   auto mod1 = IRModule::FromExpr(expected_func);
   mod1 = relay::transform::InferType()(mod1);
   auto expected = mod1->Lookup("main");
-  CHECK(relay::AlphaEqual(f, expected));
+  CHECK(tvm::StructuralEqual()(f, expected));
 }
 int main(int argc, char** argv) {
...
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test graph equality of caffe2 models."""
+import tvm
 from tvm import relay
 from tvm.relay import transform
 from model_zoo import c2_squeezenet, relay_squeezenet
@@ -23,7 +24,7 @@ from model_zoo import c2_squeezenet, relay_squeezenet
 def compare_graph(lhs_mod, rhs_mod):
     lhs_mod = transform.InferType()(lhs_mod)
     rhs_mod = transform.InferType()(rhs_mod)
-    assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
+    assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
 def test_squeeze_net():
...
@@ -25,7 +25,7 @@ import model_zoo
 def compare_graph(lhs_mod, rhs_mod):
     lhs_mod = transform.InferType()(lhs_mod)
     rhs_mod = transform.InferType()(rhs_mod)
-    assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
+    assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
 def test_mlp():
     shape = {"data": (1, 1, 28, 28)}
...
@@ -77,7 +77,7 @@ def test_extract_identity():
     mod["main"] = mod["main"].with_attr(
         "Primitive", tvm.tir.IntImm("int32", 1))
-    relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])
+    tvm.ir.structural_equal(list(items.values())[0], mod["main"])
 def test_extract_conv_net():
...
@@ -136,7 +136,7 @@ def test_extern_dnnl():
     mod = annotated(dtype, ishape, w1shape)
     mod = transform.AnnotateTarget("dnnl")(mod)
     ref_mod = expected(dtype, ishape, w1shape)
-    assert relay.analysis.alpha_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod)
 def test_run():
     if not tvm.get_global_func("relay.ext.dnnl", True):
...
@@ -27,7 +27,7 @@ def test_callgraph_construct():
     mod["g1"] = relay.Function([x, y], x + y)
     call_graph = relay.analysis.CallGraph(mod)
     assert "g1" in str(call_graph)
-    assert relay.alpha_equal(mod, call_graph.module)
+    assert tvm.ir.structural_equal(mod, call_graph.module)
 def test_print_element():
...
@@ -29,11 +29,11 @@ def test_bind_params():
     fexpected = relay.Function(
         [y],
         relay.add(relay.const(1, "float32"), y))
-    assert relay.analysis.alpha_equal(fbinded, fexpected)
+    assert tvm.ir.structural_equal(fbinded, fexpected)
     zbinded = relay.bind(z, {y: x})
     zexpected = relay.add(x, x)
-    assert relay.analysis.alpha_equal(zbinded, zexpected)
+    assert tvm.ir.structural_equal(zbinded, zexpected)
 if __name__ == "__main__":
...
@@ -21,13 +21,12 @@ from tvm import te
 from tvm import relay
 from tvm.tir.expr import *
 from tvm.relay import op
-from tvm.relay.analysis import graph_equal
 import numpy as np
 def check_json_roundtrip(node):
     json_str = tvm.ir.save_json(node)
     back = tvm.ir.load_json(json_str)
-    assert graph_equal(back, node)
+    assert tvm.ir.structural_equal(back, node, map_free_vars=True)
 # Span
...
@@ -107,7 +107,7 @@ def test_func_type_sequal():
     ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                         tvm.runtime.convert([tp1, tp3]),
                         tvm.runtime.convert([tr1]))
-    translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
+    translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2,
                                     tvm.runtime.convert([tp2, tp4]),
                                     tvm.runtime.convert([tr2]))
     assert ft == translate_vars
...
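The renamed type parameters in `translate_vars` still compare equal because bound type vars are matched by position, not by name; in isolation:

    from tvm import relay

    tp1, tp2 = relay.TypeVar("tp1"), relay.TypeVar("tp2")
    t = relay.TensorType((1,), "float32")
    # Bound type parameters may be renamed freely.
    assert relay.FuncType([t], t, [tp1], []) == relay.FuncType([t], t, [tp2], [])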
@@ -20,7 +20,7 @@ from tvm import relay
 import tvm.relay.testing
 import numpy as np
 from tvm.relay import Expr
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars
+from tvm.relay.analysis import free_vars
 do_print = [False]
@@ -32,9 +32,9 @@ def astext(p, unify_free_vars=False):
         return txt
     x = relay.fromtext(txt)
     if unify_free_vars:
-        assert_graph_equal(x, p)
+        tvm.ir.assert_structural_equal(x, p, map_free_vars=True)
     else:
-        assert_alpha_equal(x, p)
+        tvm.ir.assert_structural_equal(x, p)
     return txt
 def show(text):
...
@@ -99,7 +99,7 @@ def test_checkpoint_alpha_equal():
     """
     )
-    relay.analysis.assert_alpha_equal(df, df_parsed)
+    tvm.ir.assert_structural_equal(df, df_parsed)
 def test_checkpoint_alpha_equal_tuple():
     xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
@@ -146,7 +146,7 @@ def test_checkpoint_alpha_equal_tuple():
     """
     )
-    relay.analysis.assert_alpha_equal(df, df_parsed)
+    tvm.ir.assert_structural_equal(df, df_parsed)
 def test_collapse_sum_like():
     shape = (3, 4, 5, 6)
...
@@ -66,7 +66,7 @@ def test_alter_op():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_return_none():
@@ -88,7 +88,7 @@ def test_alter_return_none():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(before(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     assert(called[0])
@@ -151,7 +151,7 @@ def test_alter_layout():
                       transform.AlterOpLayout()])
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_dual_path():
@@ -214,7 +214,7 @@ def test_alter_layout_dual_path():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_resnet():
     """Test alternating the layout of a residual block
@@ -271,7 +271,7 @@ def test_alter_layout_resnet():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_broadcast_op():
@@ -318,7 +318,7 @@ def test_alter_layout_broadcast_op():
                       transform.AlterOpLayout()])
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_broadcast_scalar_op():
@@ -381,7 +381,7 @@ def test_alter_layout_broadcast_scalar_op():
                       transform.AlterOpLayout()])
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_scalar():
@@ -424,7 +424,7 @@ def test_alter_layout_scalar():
                       transform.AlterOpLayout()])
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_concatenate():
@@ -478,7 +478,7 @@ def test_alter_layout_concatenate():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nchw(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     # NHWC layout transformation.
     def before_nhwc():
@@ -524,7 +524,7 @@ def test_alter_layout_concatenate():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nhwc(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_nchw_upsamping_op():
@@ -561,7 +561,7 @@ def test_alter_layout_nchw_upsamping_op():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_strided_slice():
@@ -597,7 +597,7 @@ def test_alter_layout_strided_slice():
                       transform.AlterOpLayout()])
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_depthwise_conv2d():
     """Test depthwise_conv2d operator"""
@@ -632,7 +632,7 @@ def test_alter_layout_depthwise_conv2d():
                       transform.AlterOpLayout()])
     b = run_opt_pass(expected(), transform.InferType())
-    assert(analysis.alpha_equal(a, b))
+    assert(tvm.ir.structural_equal(a, b))
 def test_alter_layout_prelu():
     """Test PRelu operator"""
@@ -672,7 +672,7 @@ def test_alter_layout_prelu():
     a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
     b = run_opt_pass(expected(), transform.InferType())
-    assert(analysis.alpha_equal(a, b))
+    assert(tvm.ir.structural_equal(a, b))
 def test_alter_layout_pad():
@@ -715,7 +715,7 @@ def test_alter_layout_pad():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nchw(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     # Check NHWC conversion.
     def before_nhwc():
@@ -749,7 +749,7 @@ def test_alter_layout_pad():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nhwc(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     # Check that conversion does not happen when padding along split axis.
     def before():
@@ -782,7 +782,7 @@ def test_alter_layout_pad():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_pool():
@@ -825,7 +825,7 @@ def test_alter_layout_pool():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nchw(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     # Check NHWC conversion.
     def before_nhwc():
@@ -859,7 +859,7 @@ def test_alter_layout_pool():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nhwc(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_layout_sum():
@@ -902,7 +902,7 @@ def test_alter_layout_sum():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nchw(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     # Check NHWC conversion.
     def before_nhwc():
@@ -937,7 +937,7 @@ def test_alter_layout_sum():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nhwc(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 # TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the
@@ -999,7 +999,7 @@ def test_alter_layout_nhwc_nchw_arm():
     a = run_opt_pass(a, transform.AlterOpLayout())
     b = run_opt_pass(expected_nhwc(), transform.InferType())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_alter_op_with_global_var():
     """Test directly replacing an operator with a new one"""
@@ -1041,7 +1041,7 @@ def test_alter_op_with_global_var():
     a = transform.AlterOpLayout()(a)
     b = transform.InferType()(expected())
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a)
 if __name__ == "__main__":
     test_alter_op()
...
@@ -64,7 +64,7 @@ def test_redundant_annotation():
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.alpha_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 def test_annotate_expr():
@@ -91,7 +91,7 @@ def test_annotate_expr():
     annotated_expr = annotated()
     expected_expr = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_expr, expected_expr)
+    assert tvm.ir.structural_equal(annotated_expr, expected_expr)
 def test_annotate_all():
@@ -120,7 +120,7 @@ def test_annotate_all():
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 def test_annotate_none():
@@ -146,13 +146,13 @@ def test_annotate_none():
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 def check_annotated_graph(annotated_func, expected_func):
     annotated_func = run_opt_pass(annotated_func, transform.InferType())
     expected_func = run_opt_pass(expected_func, transform.InferType())
-    assert relay.analysis.alpha_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 def test_conv_network():
@@ -596,7 +596,7 @@ def test_tuple_get_item():
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 if __name__ == "__main__":
...
@@ -64,7 +64,7 @@ def test_canonicalize_cast():
         mod[gv] = y_expected
         mod = _transform.InferType()(mod)
         y_expected = mod["expected"]
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected)
     check((1, 16, 7, 7))
...
@@ -72,7 +72,7 @@ def test_combine_parallel_conv2d():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
     check((1, 4, 16, 16), 4, 4, 4, 4)
     check((1, 4, 16, 16), 4, 8, 4, 7)
@@ -118,7 +118,7 @@ def test_combine_parallel_conv2d_scale_relu():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
     check((1, 4, 16, 16), 4, 8)
@@ -157,7 +157,7 @@ def test_combine_parallel_conv2d_scale():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
     check((1, 4, 16, 16), 4, 8)
@@ -193,7 +193,7 @@ def test_combine_parallel_conv2d_multiple_blocks():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w, out_c, repeat)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
     check((1, 4, 16, 16), 4)
...
@@ -75,7 +75,7 @@ def test_combine_parallel_dense():
                          transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3, w4)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
     check(3, 5, 4)
     check(100, 200, 300)
@@ -127,7 +127,7 @@ def test_combine_parallel_dense_biasadd():
                          transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, b1, b2, is_2d_bias)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
     check(3, 5, 4, False)
     check(100, 200, 300, False)
@@ -184,7 +184,7 @@ def test_combine_parallel_dense_biasadd_scale_reshape():
                          transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
     check(3, 5, 4, 0.5, 0.25, (1, 1, 15))
     check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
...
...@@ -52,7 +52,7 @@ def test_no_convert_layout(): ...@@ -52,7 +52,7 @@ def test_no_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW')) a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType()) b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_convert_layout(): def test_conv_convert_layout():
...@@ -87,7 +87,7 @@ def test_conv_convert_layout(): ...@@ -87,7 +87,7 @@ def test_conv_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW')) a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType()) b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_bias_pool_convert_layout(): def test_conv_bias_pool_convert_layout():
...@@ -132,7 +132,7 @@ def test_conv_bias_pool_convert_layout(): ...@@ -132,7 +132,7 @@ def test_conv_bias_pool_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW')) a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType()) b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_concat_convert_layout(): def test_conv_concat_convert_layout():
...@@ -180,7 +180,7 @@ def test_conv_concat_convert_layout(): ...@@ -180,7 +180,7 @@ def test_conv_concat_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW')) a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType()) b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_dual_path_convert_layout():
@@ -235,7 +235,7 @@ def test_dual_path_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_bn_convert_layout():
@@ -315,7 +315,7 @@ def test_resnet_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_scalar_convert_layout():
@@ -347,7 +347,7 @@ def test_scalar_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_bn_convert_layout():
@@ -395,7 +395,7 @@ def test_conv_bn_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_requantize_convert_layout():
@@ -451,7 +451,7 @@ def test_qnn_conv_requantize_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_concat_convert_layout():
@@ -529,7 +529,7 @@ def test_qnn_conv_concat_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_add_convert_layout():
@@ -609,7 +609,7 @@ def test_qnn_conv_add_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
......
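The pattern in the hunks above repeats through every test file in this commit: the boolean check `relay.analysis.alpha_equal(a, b)` becomes `tvm.ir.structural_equal(a, b)`. A minimal sketch of the new call, assuming a TVM build with the Relay frontend available; the expressions are illustrative and not taken from the tests:

    import tvm
    from tvm import relay

    # Two functions that differ only in the names of their bound variables.
    x = relay.var("x", shape=(1,), dtype="float32")
    y = relay.var("y", shape=(1,), dtype="float32")
    f1 = relay.Function([x], x)
    f2 = relay.Function([y], y)

    # Bound variables are matched positionally, so the choice of name does
    # not affect the result (the alpha-equivalence the old API was named for).
    assert tvm.ir.structural_equal(f1, f2)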
@@ -18,7 +18,7 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import Function, transform
-from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal
+from tvm.relay.analysis import free_vars
from tvm.relay.op import log, add, equal, subtract
from tvm.relay.testing import inception_v3
@@ -69,7 +69,7 @@ def test_used_let():
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
-assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
+tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
def test_chain_unused_let():
@@ -105,7 +105,7 @@ def test_recursion():
orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
dced = run_opt_pass(orig, transform.DeadCodeElimination())
orig = run_opt_pass(orig, transform.InferType())
-assert_alpha_equal(dced, orig)
+tvm.ir.assert_structural_equal(dced, orig)
def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
......
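The dead-code tests above use the asserting variant rather than the boolean one. `tvm.ir.assert_structural_equal` raises on mismatch instead of returning False, which surfaces a structured error in pytest output. A hedged sketch (the exact exception type is an implementation detail, so it is caught broadly here):

    import tvm
    from tvm import relay

    a = relay.const(1, "int32")
    b = relay.const(2, "int32")

    try:
        tvm.ir.assert_structural_equal(a, b)
    except Exception as err:  # implementation-defined exception type
        # The message identifies the first mismatching field.
        print("mismatch reported:", err)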
@@ -52,7 +52,7 @@ def test_simple():
z = before()
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
-assert analysis.alpha_equal(z, expected())
+assert tvm.ir.structural_equal(z, expected())
def test_callback():
@@ -82,7 +82,7 @@ def test_callback():
z = before()
z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
-assert analysis.alpha_equal(z, expected())
+assert tvm.ir.structural_equal(z, expected())
if __name__ == "__main__":
......
@@ -47,7 +47,8 @@ def test_eta_expand_global_var():
}
}
""")
-relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+tvm.ir.assert_structural_equal(mod['main'], expected['main'],
+map_free_vars=True)
def test_eta_expand_constructor():
@@ -76,7 +77,8 @@ def test_eta_expand_constructor():
}
}
""")
-relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+tvm.ir.assert_structural_equal(mod['main'], expected['main'],
+map_free_vars=True)
if __name__ == '__main__':
......
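The eta-expansion hunks are the first to pass `map_free_vars=True`. By default two distinct free (unbound) variables never compare equal; the flag lets the comparison map them onto each other, which the old `assert_graph_equal` did implicitly. A small sketch under the same assumptions as above:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,), dtype="float32")
    y = relay.var("y", shape=(1,), dtype="float32")

    # As free variables, x and y are distinct nodes by default...
    assert not tvm.ir.structural_equal(x, y)
    # ...unless the comparison may map free vars onto each other.
    assert tvm.ir.structural_equal(x, y, map_free_vars=True)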
@@ -59,7 +59,7 @@ def test_fold_const():
with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.alpha_equal(zz, zexpected)
+assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_let():
@@ -84,7 +84,7 @@ def test_fold_let():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.graph_equal(zz, zexpected)
+assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_tuple():
@@ -106,7 +106,7 @@ def test_fold_tuple():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.graph_equal(zz, zexpected)
+assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_concat():
@@ -125,7 +125,7 @@ def test_fold_concat():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.graph_equal(zz, zexpected)
+assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_shape_of():
@@ -146,7 +146,7 @@ def test_fold_shape_of():
for dtype in ["int32", "float32"]:
zz = run_opt_pass(before(dtype), transform.FoldConstant())
zexpected = run_opt_pass(expected(dtype), transform.InferType())
-assert relay.analysis.graph_equal(zz, zexpected)
+assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_full():
@@ -161,7 +161,7 @@ def test_fold_full():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.graph_equal(zz, zexpected)
+assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_batch_norm():
@@ -202,7 +202,7 @@ def test_fold_batch_norm():
mod = remove_bn_pass(mod)
expect = run_infer_type(expected())
-assert relay.analysis.graph_equal(mod["main"], expect)
+assert tvm.ir.structural_equal(mod["main"], expect)
if __name__ == "__main__":
......
@@ -79,7 +79,7 @@ def test_fold_fwd_simple():
y1_folded = run_opt_pass(y1_folded, transform.InferType())
y1_expected = run_opt_pass(y1_expected, transform.InferType())
-assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2)
@@ -148,7 +148,7 @@ def test_fold_fwd_dual_path():
weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
-assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3)
@@ -177,7 +177,7 @@ def test_fold_fwd_fail():
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
-assert relay.analysis.alpha_equal(y1, y1_folded)
+assert tvm.ir.structural_equal(y1, y1_folded)
check((2, 11, 10, 4), 4)
@@ -205,7 +205,7 @@ def test_fold_fwd_relu_fail():
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
-assert relay.analysis.alpha_equal(y1, y1_folded)
+assert tvm.ir.structural_equal(y1, y1_folded)
in_scale = relay.var("in_scale", shape=(4,))
check((2, 11, 10, 4), 4, in_scale)
@@ -249,7 +249,7 @@ def test_fold_fwd_negative_scale():
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
y1_expected = expected(x, weight, in_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
-assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
@@ -300,7 +300,7 @@ def test_fold_bwd_simple():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
-assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
@@ -359,7 +359,7 @@ def test_fold_bwd_dual_path():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
-assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
@@ -431,7 +431,7 @@ def test_fold_bwd_dual_consumer():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
-assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
@@ -480,7 +480,7 @@ def test_fold_bwd_fail():
y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-assert relay.analysis.alpha_equal(y1_folded, y1)
+assert tvm.ir.structural_equal(y1_folded, y1)
check((4, 4, 10, 10), 4, fail1)
check((4, 4, 10, 10), 4, fail2)
@@ -505,7 +505,7 @@ def test_fold_bwd_relu_fail():
y1 = before(x, weight, out_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-assert relay.analysis.alpha_equal(y1, y1_folded)
+assert tvm.ir.structural_equal(y1, y1_folded)
out_scale = relay.var("in_scale", shape=(4, 1, 1))
check((4, 4, 10, 10), 4, out_scale)
@@ -547,7 +547,7 @@ def test_fold_bwd_negative_scale():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
-assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
......
@@ -45,7 +45,7 @@ def test_fuse_simple():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_conv2d_fuse():
@@ -127,7 +127,7 @@ def test_conv2d_fuse():
z = before(dshape)
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
after = run_opt_pass(expected(dshape), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_concatenate():
@@ -167,7 +167,7 @@ def test_concatenate():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dshape), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_tuple_root():
@@ -204,7 +204,7 @@ def test_tuple_root():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dshape), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_stop_fusion():
@@ -235,7 +235,7 @@ def test_stop_fusion():
z = before(dshape)
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(dshape), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_fuse_myia_regression():
@@ -271,7 +271,7 @@ def test_fuse_myia_regression():
f = before(dshape, dtype)
zz = run_opt_pass(f, transform.FuseOps())
after = run_opt_pass(expected(dshape, dtype), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_fuse_tuple_get_elemwise():
@@ -309,7 +309,7 @@ def test_fuse_tuple_get_elemwise():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dim), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_tuple_get_root():
@@ -346,7 +346,7 @@ def test_tuple_get_root():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dim), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
@@ -379,7 +379,7 @@ def test_tuple_intermediate():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(x), transform.InferType())
-assert relay.analysis.alpha_equal(m["main"], after)
+assert tvm.ir.structural_equal(m["main"], after)
def test_tuple_consecutive():
@@ -437,7 +437,7 @@ def test_tuple_consecutive():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType())
-assert relay.analysis.alpha_equal(m["main"], after)
+assert tvm.ir.structural_equal(m["main"], after)
def test_inception_like():
@@ -510,7 +510,7 @@ def test_inception_like():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType())
-assert relay.analysis.alpha_equal(m["main"], after)
+assert tvm.ir.structural_equal(m["main"], after)
def test_fuse_parallel_injective():
@@ -541,7 +541,7 @@ def test_fuse_parallel_injective():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
def test_immutable():
@@ -570,8 +570,8 @@ def test_immutable():
mod = before()
new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
-assert relay.analysis.alpha_equal(mod, before())
-assert relay.analysis.alpha_equal(new_mod, expected())
+assert tvm.ir.structural_equal(mod, before())
+assert tvm.ir.structural_equal(new_mod, expected())
def test_split():
@@ -619,7 +619,7 @@ def test_fuse_max():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
-assert relay.analysis.alpha_equal(zz, after)
+assert tvm.ir.structural_equal(zz, after)
if __name__ == "__main__":
test_fuse_simple()
......
@@ -19,7 +19,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
+from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
@@ -292,7 +292,7 @@ def test_concat():
func = relay.Function([x], y)
func = run_infer_type(func)
back_func = run_infer_type(gradient(func))
-assert_alpha_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
+tvm.ir.assert_structural_equal(back_func.checked_type, relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
# no value validation as concatenate has dummy gradient right now.
......
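Note that the gradient test above compares a type (`back_func.checked_type`) with the same entry point that elsewhere compares expressions and whole modules; the old API needed separate overloads for each. An illustrative sketch with hand-built types, under the same assumptions as the earlier examples:

    import tvm
    from tvm import relay

    t1 = relay.FuncType([relay.TensorType((1,), "float32")],
                        relay.TensorType((1,), "float32"))
    t2 = relay.FuncType([relay.TensorType((1,), "float32")],
                        relay.TensorType((1,), "float32"))

    # One API for types, expressions, and modules alike.
    tvm.ir.assert_structural_equal(t1, t2)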
@@ -115,7 +115,7 @@ def test_call_chain_inline_leaf():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels():
@@ -188,7 +188,7 @@ def test_call_chain_inline_multiple_levels():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels_extern_compiler():
@@ -266,7 +266,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_recursive_call_with_global():
@@ -321,7 +321,7 @@ def test_recursive_call_with_global():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_recursive_called():
@@ -330,7 +330,7 @@ def test_recursive_called():
mod["main"] = relay.Function([iarg], sum_up(iarg))
ref_mod = mod
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, ref_mod)
+assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_recursive_not_called():
@@ -356,7 +356,7 @@ def test_recursive_not_called():
mod = get_mod()
mod = relay.transform.Inline()(mod)
ref_mod = expected()
-assert relay.analysis.alpha_equal(mod, ref_mod)
+assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_recursive_not_called_extern_compiler():
@@ -387,7 +387,7 @@ def test_recursive_not_called_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
ref_mod = expected()
-assert relay.analysis.alpha_equal(mod, ref_mod)
+assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_globalvar_as_call_arg():
@@ -434,7 +434,7 @@ def test_globalvar_as_call_arg():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_globalvar_as_call_arg_extern_compiler():
@@ -500,7 +500,7 @@ def test_globalvar_as_call_arg_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_inline_globalvar_without_args():
@@ -531,7 +531,7 @@ def test_inline_globalvar_without_args():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_inline_globalvar_without_args_extern_compiler():
@@ -566,7 +566,7 @@ def test_inline_globalvar_without_args_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_globalvar_called_by_multiple_functions():
@@ -644,7 +644,7 @@ def test_globalvar_called_by_multiple_functions():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_entry_with_inline():
@@ -674,7 +674,7 @@ def test_entry_with_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, get_mod())
+assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
def test_callee_not_inline():
@@ -707,7 +707,7 @@ def test_callee_not_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, get_mod())
+assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
def test_callee_not_inline_leaf_inline():
@@ -765,7 +765,7 @@ def test_callee_not_inline_leaf_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_callee_not_inline_leaf_inline_extern_compiler():
@@ -830,7 +830,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
-assert relay.analysis.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
if __name__ == '__main__':
......
@@ -68,7 +68,7 @@ def test_legalize():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_legalize_none():
"""Test doing nothing by returning 'None' """
@@ -89,7 +89,7 @@ def test_legalize_none():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
def test_legalize_multiple_ops():
@@ -134,7 +134,7 @@ def test_legalize_multiple_ops():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_legalize_multi_input():
@@ -170,7 +170,7 @@ def test_legalize_multi_input():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
......
@@ -111,7 +111,7 @@ def get_rand(shape, dtype='float32'):
def check_func(func, ref_func):
func = run_infer_type(func)
ref_func = run_infer_type(ref_func)
-assert analysis.graph_equal(func, ref_func)
+assert tvm.ir.structural_equal(func, ref_func)
def test_module_pass():
@@ -211,7 +211,7 @@ def test_function_class_pass():
mod = fpass(mod)
# wrap in expr
mod2 = tvm.IRModule.from_expr(f1)
-assert relay.alpha_equal(mod["main"], mod2["main"])
+assert tvm.ir.structural_equal(mod["main"], mod2["main"])
def test_function_pass():
@@ -496,7 +496,7 @@ def test_sequential_with_scoping():
zz = mod["main"]
zexpected = run_infer_type(expected())
-assert analysis.alpha_equal(zz, zexpected)
+assert tvm.ir.structural_equal(zz, zexpected)
def test_print_ir(capfd):
......
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for merge composite."""
+import tvm
from tvm import relay
from tvm import tir
from tvm.relay.testing import run_opt_pass
@@ -192,7 +193,7 @@ def test_simple_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_branch_merge():
@@ -270,7 +271,7 @@ def test_branch_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_reuse_call_merge():
@@ -329,7 +330,7 @@ def test_reuse_call_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_multiple_patterns():
@@ -422,7 +423,7 @@ def test_multiple_patterns():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_merge_order():
@@ -494,7 +495,7 @@ def test_merge_order():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check B highest priority
pattern_table = [
@@ -505,7 +506,7 @@ def test_merge_order():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check C highest priority
pattern_table = [
@@ -516,7 +517,7 @@ def test_merge_order():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_parallel_merge():
@@ -563,7 +564,7 @@ def test_parallel_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_multiple_input_subgraphs():
@@ -676,13 +677,13 @@ def test_multiple_input_subgraphs():
result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check case 'B'
result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_B(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_tuple_get_item_merge():
@@ -728,7 +729,7 @@ def test_tuple_get_item_merge():
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
-assert relay.analysis.alpha_equal(result, expected)
+assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
if __name__ == "__main__":
......
@@ -19,7 +19,6 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
@@ -124,7 +123,7 @@ def test_ad():
body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body))
expected = run_opt_pass(expected, transform.InferType())
-assert_alpha_equal(g, expected)
+tvm.ir.assert_structural_equal(g, expected)
def test_if_ref():
@@ -312,7 +311,7 @@ def test_concat():
x = Var("x", t)
y = Var("x", t)
orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
-assert_alpha_equal(dcpe(orig), orig)
+tvm.ir.assert_structural_equal(dcpe(orig), orig)
def test_triangle_number():
@@ -321,7 +320,7 @@ def test_triangle_number():
f_var = Var("f")
f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - const(1))))
orig = run_infer_type(Let(f_var, f, f_var(const(10))))
-assert_alpha_equal(dcpe(orig), const(55))
+tvm.ir.assert_structural_equal(dcpe(orig), const(55))
def test_nat_update():
@@ -337,7 +336,7 @@ def test_tuple_match():
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
-assert_alpha_equal(dcpe(x), const(2))
+tvm.ir.assert_structural_equal(dcpe(x), const(2))
if __name__ == '__main__':
......
@@ -339,7 +339,7 @@ def test_extern_ccompiler_default_ops():
fused_mod = transform.FuseOps(2)(mod)
expected_mod = expected()
-assert relay.alpha_equal(fused_mod, expected_mod)
+assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
x_data = np.random.rand(8, 8).astype('float32')
y_data = np.random.rand(8, 8).astype('float32')
@@ -427,7 +427,7 @@ def test_extern_dnnl():
mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
mod = transform.PartitionGraph()(mod)
-assert relay.alpha_equal(mod, expected())
+assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
ref_mod = tvm.IRModule()
ref_mod["main"] = get_func()
@@ -561,7 +561,7 @@ def test_function_lifting():
partitioned = partition()
ref_mod = expected()
-assert relay.analysis.alpha_equal(partitioned, ref_mod)
+assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_function_lifting_inline():
@@ -631,7 +631,7 @@ def test_function_lifting_inline():
partitioned = partition()
ref_mod = expected()
-assert relay.analysis.alpha_equal(partitioned, ref_mod)
+assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_constant_propagation():
@@ -671,7 +671,7 @@ def test_constant_propagation():
mod = transform.PartitionGraph()(mod)
expected_mod = expected()
-assert relay.alpha_equal(mod, expected_mod)
+assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
y_data = np.random.rand(8, 8).astype('float32')
np_add = ones + y_data
......
@@ -31,7 +31,7 @@ def alpha_equal(x, y):
"""
x = x['main']
y = y['main']
-return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
+return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
@@ -85,12 +85,12 @@ def test_qnn_legalize():
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
-assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_legalize_qnn_conv2d():
......
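The local `alpha_equal` helper in the QNN legalize tests above cross-checks `structural_hash`: the invariant under test is that structurally equal graphs must hash identically. A hedged sketch, assuming `structural_hash` is still exported from `tvm.relay.analysis` at this commit (as the unchanged half of the hunk suggests):

    import tvm
    from tvm import relay
    from tvm.relay import analysis

    a = relay.add(relay.const(1.0), relay.const(2.0))
    b = relay.add(relay.const(1.0), relay.const(2.0))

    assert tvm.ir.structural_equal(a, b)
    # Equal structures should never hash differently.
    assert analysis.structural_hash(a) == analysis.structural_hash(b)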
@@ -110,7 +110,7 @@ def test_call_globalvar_without_args():
mod = get_mod()
ref_mod = get_mod()
mod = relay.transform.RemoveUnusedFunctions()(mod)
-assert relay.alpha_equal(mod, ref_mod)
+assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
if __name__ == '__main__':
......
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from tvm.ir import IRModule
+from tvm.ir import IRModule, structural_equal
from tvm import relay as rly
from tvm.relay.transform import SimplifyInference
@@ -56,7 +56,7 @@ def test_simplify_batchnorm(dtype='float32'):
mod = simplify(mod)
y1 = mod["main"].body
-assert rly.analysis.graph_equal(y1, y2)
+assert structural_equal(y1, y2, map_free_vars=True)
check(2, 1, 1)
check(4, 1, 1)
......
@@ -18,7 +18,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.analysis import alpha_equal, detect_feature
+from tvm.relay.analysis import detect_feature
from tvm.relay import op, create_executor, transform
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count
......
@@ -18,7 +18,6 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor
-from tvm.relay.analysis import assert_graph_equal
from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType,
TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
from tvm.relay.adt import TypeData
@@ -34,7 +33,8 @@ def check_visit(typ):
ev = TypeVisitor()
ev.visit(typ)
-assert_graph_equal(TypeMutator().visit(typ), typ)
+tvm.ir.assert_structural_equal(TypeMutator().visit(typ), typ,
+map_free_vars=True)
def test_type_var():
......
@@ -18,10 +18,9 @@
import tvm
def check_json_roundtrip(node):
-from tvm.relay.analysis import graph_equal
json_str = tvm.ir.save_json(node)
back = tvm.ir.load_json(json_str)
-assert graph_equal(back, node)
+assert tvm.ir.structural_equal(back, node, map_free_vars=True)
def test_prim_type():
......
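The roundtrip check above generalizes to any IR node: deserialization creates fresh variable nodes, so free vars must be mapped when comparing. A minimal usage sketch under the same assumptions as the earlier examples:

    import tvm
    from tvm import relay

    node = relay.TensorType((1, 3), "float32")
    back = tvm.ir.load_json(tvm.ir.save_json(node))
    assert tvm.ir.structural_equal(back, node, map_free_vars=True)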