/*! * Copyright (c) 2018 by Contributors * \file src/tvm/relay/ir/alpha_equal.cc * \brief Alpha equality check by deep comparing two nodes. */ #include <tvm/ir_pass.h> #include <tvm/relay/expr_functor.h> #include <tvm/runtime/ndarray.h> #include <tvm/relay/pass.h> #include "type_functor.h" #include "../../lang/attr_functor.h" namespace tvm { namespace relay { // Alpha equal handler for relay. class AlphaEqualHandler: public AttrsEqualHandler, public TypeFunctor<bool(const Type&, const Type&)>, public ExprFunctor<bool(const Expr&, const Expr&)> { public: explicit AlphaEqualHandler(bool map_free_var) : map_free_var_(map_free_var) {} /*! * Check equality of two nodes. * \param lhs The left hand operand. * \param rhs The right hand operand. * \return the compare result. */ bool Equal(const NodeRef& lhs, const NodeRef& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; if (lhs->derived_from<TypeNode>()) { if (!rhs->derived_from<TypeNode>()) return false; return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs)); } if (lhs->derived_from<ExprNode>()) { if (!rhs->derived_from<ExprNode>()) return false; return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs)); } return AttrEqual(lhs, rhs); } /*! * Check equality of two attributes. * \param lhs The left hand operand. * \param rhs The right hand operand. * \return the compare result. */ bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { return AttrsEqualHandler::Equal(lhs, rhs); } /*! * Check equality of two types. * \param lhs The left hand operand. * \param rhs The right hand operand. * \return the compare result. */ bool TypeEqual(const Type& lhs, const Type& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; return this->VisitType(lhs, rhs); } /*! * Check equality of two expressions. * * \note We run graph structural equality checking when comparing two Exprs. * This means that AlphaEqualHandler can only be used once for each pair. * The equality checker checks data-flow equvalence of the Expr DAG. * This function also runs faster as it memomizes equal_map. * * \param lhs The left hand operand. * \param rhs The right hand operand. * \return the compare result. */ bool ExprEqual(const Expr& lhs, const Expr& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; auto it = equal_map_.find(lhs); if (it != equal_map_.end()) { return it->second.same_as(rhs); } if (this->VisitExpr(lhs, rhs)) { equal_map_[lhs] = rhs; return true; } else { return false; } } protected: /*! * \brief Check if data type equals each other. * \param lhs The left hand operand. * \param rhs The right hand operand. * \return the compare result. */ bool DataTypeEqual(const DataType& lhs, const DataType& rhs) { return lhs == rhs; } /*! * \brief Check Equality of leaf node of the graph. * if map_free_var_ is set to true, try to map via equal node. * \param lhs The left hand operand. * \param rhs The right hand operand. * \return the compare result. */ bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) { if (lhs.same_as(rhs)) return true; auto it = equal_map_.find(lhs); if (it != equal_map_.end()) { return it->second.same_as(rhs); } else { if (map_free_var_) { if (lhs->type_index() != rhs->type_index()) return false; equal_map_[lhs] = rhs; return true; } else { return false; } } } using AttrsEqualHandler::VisitAttr_; bool VisitAttr_(const Variable* lhs, const NodeRef& other) final { return LeafNodeEqual(GetRef<NodeRef>(lhs), other); } // Type equality bool VisitType_(const TensorTypeNode* lhs, const Type& other) final { if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) { return (lhs->dtype == rhs->dtype && AttrEqual(lhs->shape, rhs->shape)); } else { return false; } } bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final { return LeafNodeEqual(GetRef<NodeRef>(lhs), other); } bool VisitType_(const TypeVarNode* lhs, const Type& other) final { if (const TypeVarNode* rhs = other.as<TypeVarNode>()) { if (lhs->kind != rhs->kind) return false; return LeafNodeEqual(GetRef<NodeRef>(lhs), other); } else { return false; } } bool VisitType_(const FuncTypeNode* lhs, const Type& other) final { if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) { if (lhs->arg_types.size() != rhs->arg_types.size()) return false; if (lhs->type_params.size() != rhs->type_params.size()) return false; if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false; for (size_t i = 0; i < lhs->type_params.size(); ++i) { if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) { return false; } equal_map_[lhs->type_params[i]] = rhs->type_params[i]; // set up type parameter equal if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) { // map variable equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var; } } for (size_t i = 0; i < lhs->arg_types.size(); i++) { if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false; } if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false; for (size_t i = 0; i < lhs->type_constraints.size(); i++) { if (!TypeEqual(lhs->type_constraints[i], rhs->type_constraints[i])) { return false; } } return true; } else { return false; } } bool VisitType_(const TypeRelationNode* lhs, const Type& other) final { if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) { if (lhs->func->name != rhs->func->name) return false; if (lhs->num_inputs != rhs->num_inputs) return false; if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false; if (lhs->args.size() != rhs->args.size()) return false; for (size_t i = 0; i < lhs->args.size(); ++i) { if (!TypeEqual(lhs->args[i], rhs->args[i])) return false; } return true; } else { return false; } } bool VisitType_(const TupleTypeNode* lhs, const Type& other) final { if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) { if (lhs->fields.size() != rhs->fields.size()) return false; for (size_t i = 0; i < lhs->fields.size(); ++i) { if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false; } return true; } else { return false; } } // Expr equal checking. bool NDArrayEqual(const runtime::NDArray& lhs, const runtime::NDArray& rhs) { if (lhs.defined() != rhs.defined()) { return false; } else if (lhs.same_as(rhs)) { return true; } else { auto ldt = lhs->dtype; auto rdt = rhs->dtype; CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { size_t data_size = runtime::GetDataSize(*lhs.operator->()); return std::memcmp(lhs->data, rhs->data, data_size) == 0; } else { return false; } } } // merge declaration of two variables together. bool MergeVarDecl(const Var& lhs, const Var& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false; CHECK(!equal_map_.count(lhs)) << "Duplicated declaration of variable " << lhs; equal_map_[lhs] = rhs; return true; } bool VisitExpr_(const VarNode* lhs, const Expr& other) final { if (const VarNode* rhs = other.as<VarNode>()) { if (lhs->name_hint != rhs->name_hint) return false; if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false; return LeafNodeEqual(GetRef<NodeRef>(lhs), other); } else { return false; } } bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final { if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) { // use name equality for global var for now. if (lhs->name_hint != rhs->name_hint) return false; return true; } else { return false; } } bool VisitExpr_(const TupleNode* lhs, const Expr& other) final { if (const TupleNode* rhs = other.as<TupleNode>()) { if (lhs->fields.size() != rhs->fields.size()) return false; for (size_t i = 0; i < lhs->fields.size(); ++i) { if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false; } return true; } else { return false; } } bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final { if (const FunctionNode* rhs = other.as<FunctionNode>()) { if (lhs->params.size() != rhs->params.size()) return false; if (lhs->type_params.size() != rhs->type_params.size()) return false; // map type parameter to be the same for (size_t i = 0; i < lhs->type_params.size(); ++i) { if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false; equal_map_[lhs->type_params[i]] = rhs->type_params[i]; } // check parameter type annotations for (size_t i = 0; i < lhs->params.size(); ++i) { if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false; } // check return types. if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false; return ExprEqual(lhs->body, rhs->body); } else { return false; } } bool VisitExpr_(const CallNode* lhs, const Expr& other) final { if (const CallNode* rhs = other.as<CallNode>()) { if (!ExprEqual(lhs->op, rhs->op)) return false; if (lhs->args.size() != rhs->args.size()) return false; // skip type_args check for primitive ops. bool is_primitive = IsPrimitiveOp(lhs->op); if (!is_primitive) { if (lhs->type_args.size() != rhs->type_args.size()) { return false; } } for (size_t i = 0; i < lhs->args.size(); ++i) { if (!ExprEqual(lhs->args[i], rhs->args[i])) { return false; } } if (!is_primitive) { for (size_t i = 0; i < lhs->type_args.size(); ++i) { if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false; } } return AttrEqual(lhs->attrs, rhs->attrs); } else { return false; } } bool VisitExpr_(const LetNode* lhs, const Expr& other) final { if (const LetNode* rhs = other.as<LetNode>()) { if (!ExprEqual(lhs->value, rhs->value)) return false; if (!MergeVarDecl(lhs->var, rhs->var)) return false; return ExprEqual(lhs->body, rhs->body); } else { return false; } } bool VisitExpr_(const IfNode* lhs, const Expr& other) final { if (const IfNode* rhs = other.as<IfNode>()) { return ExprEqual(lhs->cond, rhs->cond) && ExprEqual(lhs->true_branch, rhs->true_branch) && ExprEqual(lhs->false_branch, rhs->false_branch); } else { return false; } } bool VisitExpr_(const OpNode* op, const Expr& other) final { return op == other.get(); } bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final { if (const ConstantNode* rhs = other.as<ConstantNode>()) { return NDArrayEqual(lhs->data, rhs->data); } else { return false; } } bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final { if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) { return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index; } else { return false; } } private: // whether to map open terms. bool map_free_var_{false}; // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_; }; bool AlphaEqual(const Type& lhs, const Type& rhs) { return AlphaEqualHandler(false).TypeEqual(lhs, rhs); } bool AlphaEqual(const Expr& lhs, const Expr& rhs) { return AlphaEqualHandler(false).ExprEqual(lhs, rhs); } // TODO(@jroesch): move to correct namespace? TVM_REGISTER_API("relay._make._alpha_equal") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = AlphaEqualHandler(false).Equal(args[0], args[1]); }); TVM_REGISTER_API("relay._make._type_alpha_equal") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]); }); TVM_REGISTER_API("relay._make._graph_equal") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = AlphaEqualHandler(true).Equal(args[0], args[1]); }); } // namespace relay } // namespace tvm