Unverified Commit 81db03ab by Tianqi Chen Committed by GitHub

[RELAY] Refactor AlphaEqual to support deep comparison of Attrs. (#1958)

parent 28f1a9fd
...@@ -108,6 +108,90 @@ class AttrFieldInfoNode : public Node { ...@@ -108,6 +108,90 @@ class AttrFieldInfoNode : public Node {
/*! \brief AttrFieldInfo */ /*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
class AttrsHashHandler;
class AttrsEqualHandler;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const;
protected:
friend class AttrsEqualHandler;
/*! \brief internal handle. */
AttrsEqualHandler* handler_{nullptr};
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const NodeRef& value) const;
private:
friend class AttrsHashHandler;
/*! \brief internal handle. */
AttrsHashHandler* handler_{nullptr};
};
/*! /*!
* \brief Base class of all attribute class * \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly, * \note Do not subclass AttrBaseNode directly,
...@@ -153,14 +237,17 @@ class BaseAttrsNode : public Node { ...@@ -153,14 +237,17 @@ class BaseAttrsNode : public Node {
/*! /*!
* \brief Whether this attribute's content equals to another node. * \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node. * \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result. * \return The comparison result.
*/ */
TVM_DLL virtual bool ContentEqual(const Node* other) const = 0; TVM_DLL virtual bool ContentEqual(
const Node* other, AttrsEqual equal) const = 0;
/*! /*!
* \brief Content aware hash. * \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result. * \return the hash result.
*/ */
TVM_DLL virtual size_t ContentHash() const = 0; TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const char* _type_key = "Attrs"; static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node); TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
...@@ -209,92 +296,13 @@ class DictAttrsNode : public BaseAttrsNode { ...@@ -209,92 +296,13 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final; Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other) const final; bool ContentEqual(const Node* other, AttrsEqual equal) const final;
size_t ContentHash() const final; size_t ContentHash(AttrsHash hasher) const final;
// type info // type info
static constexpr const char* _type_key = "DictAttrs"; static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
}; };
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
return lhs == rhs;
}
bool operator()(const NodeRef& lhs, const NodeRef& rhs) const {
return AttrsEqual::Equal(lhs, rhs);
}
// comparator of NodeRef types.
static TVM_DLL bool Equal(const NodeRef& lhs, const NodeRef& rhs);
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
size_t operator()(const NodeRef& value) const {
return AttrsHash::Hash(value);
}
// hash function of the attribute and attribute fields.
static TVM_DLL size_t Hash(const NodeRef& lhs);
};
// Namespace containing detail implementations // Namespace containing detail implementations
namespace detail { namespace detail {
...@@ -342,8 +350,8 @@ class AttrsEqualVisitor { ...@@ -342,8 +350,8 @@ class AttrsEqualVisitor {
public: public:
bool result_{true}; bool result_{true};
// constructor // constructor
AttrsEqualVisitor(const Node* lhs, const Node* rhs) AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs) { : lhs_(lhs), rhs_(rhs), equal_(equal) {
} }
template<typename T> template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) { AttrNopEntry operator()(const char* key, T* lhs_value) {
...@@ -353,7 +361,7 @@ class AttrsEqualVisitor { ...@@ -353,7 +361,7 @@ class AttrsEqualVisitor {
reinterpret_cast<const char*>(rhs_) + reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) - (reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_))); reinterpret_cast<const char*>(lhs_)));
if (!AttrsEqual()(*lhs_value, *rhs_value)) { if (!equal_(*lhs_value, *rhs_value)) {
result_ = false; result_ = false;
} }
return AttrNopEntry(); return AttrNopEntry();
...@@ -362,17 +370,24 @@ class AttrsEqualVisitor { ...@@ -362,17 +370,24 @@ class AttrsEqualVisitor {
private: private:
const Node* lhs_; const Node* lhs_;
const Node* rhs_; const Node* rhs_;
const AttrsEqual& equal_;
}; };
class AttrsHashVisitor { class AttrsHashVisitor {
public: public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
: hasher_(hasher) {}
size_t result_{0}; size_t result_{0};
template<typename T> template<typename T>
AttrNopEntry operator()(const char* key, T* value) { AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, AttrsHash()(*value)); result_ = dmlc::HashCombine(result_, hasher_(*value));
return AttrNopEntry(); return AttrNopEntry();
} }
private:
const AttrsHash& hasher_;
}; };
// helper entry that does initialization, set default. // helper entry that does initialization, set default.
...@@ -793,18 +808,18 @@ class AttrsNode : public BaseAttrsNode { ...@@ -793,18 +808,18 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_; return visitor.fields_;
} }
bool ContentEqual(const Node* other) const final { bool ContentEqual(const Node* other, AttrsEqual equal) const final {
DerivedType* pself = self(); DerivedType* pself = self();
if (pself == other) return true; if (pself == other) return true;
if (other == nullptr) return false; if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false; if (pself->type_index() != other->type_index()) return false;
detail::AttrsEqualVisitor visitor(pself, other); detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor); self()->__VisitAttrs__(visitor);
return visitor.result_; return visitor.result_;
} }
size_t ContentHash() const final { size_t ContentHash(AttrsHash hasher) const final {
detail::AttrsHashVisitor visitor; detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = std::hash<std::string>()(this->type_key()); visitor.result_ = std::hash<std::string>()(this->type_key());
self()->__VisitAttrs__(visitor); self()->__VisitAttrs__(visitor);
return visitor.result_; return visitor.result_;
......
...@@ -68,10 +68,14 @@ TVM_REGISTER_API("ir_pass.Equal") ...@@ -68,10 +68,14 @@ TVM_REGISTER_API("ir_pass.Equal")
TVM_REGISTER_API("ir_pass.AttrsEqual") TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>(AttrsEqual::Equal); .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqual()(lhs, rhs);
});
TVM_REGISTER_API("ir_pass.AttrsHash") TVM_REGISTER_API("ir_pass.AttrsHash")
.set_body_typed<int64_t(const NodeRef&)>(AttrsHash::Hash); .set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
return AttrsHash()(node);
});
TVM_REGISTER_API("ir_pass.ExprUseVar") TVM_REGISTER_API("ir_pass.ExprUseVar")
......
...@@ -52,13 +52,33 @@ class AttrFunctor<R(const NodeRef& n, Args...)> { ...@@ -52,13 +52,33 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
} }
} }
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; // deep comparison of symbolic integer expressions.
virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;
private: private:
// initialize the vtable. // initialize the vtable.
...@@ -72,9 +92,111 @@ class AttrFunctor<R(const NodeRef& n, Args...)> { ...@@ -72,9 +92,111 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(UIntImm); ATTR_FUNCTOR_DISPATCH(UIntImm);
ATTR_FUNCTOR_DISPATCH(FloatImm); ATTR_FUNCTOR_DISPATCH(FloatImm);
ATTR_FUNCTOR_DISPATCH(StringImm); ATTR_FUNCTOR_DISPATCH(StringImm);
ATTR_FUNCTOR_DISPATCH(Variable);
ATTR_FUNCTOR_DISPATCH(Add);
ATTR_FUNCTOR_DISPATCH(Sub);
ATTR_FUNCTOR_DISPATCH(Mul);
ATTR_FUNCTOR_DISPATCH(Min);
ATTR_FUNCTOR_DISPATCH(Max);
ATTR_FUNCTOR_DISPATCH(GE);
ATTR_FUNCTOR_DISPATCH(GT);
ATTR_FUNCTOR_DISPATCH(LE);
ATTR_FUNCTOR_DISPATCH(LT);
ATTR_FUNCTOR_DISPATCH(EQ);
ATTR_FUNCTOR_DISPATCH(NE);
ATTR_FUNCTOR_DISPATCH(And);
ATTR_FUNCTOR_DISPATCH(Or);
ATTR_FUNCTOR_DISPATCH(Not);
ATTR_FUNCTOR_DISPATCH(Cast);
ATTR_FUNCTOR_DISPATCH(Call);
ATTR_FUNCTOR_DISPATCH(Select);
return vtable; return vtable;
} }
}; };
class AttrsEqualHandler :
protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
public:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool Equal(const NodeRef& lhs, const NodeRef& rhs);
protected:
bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final;
};
class AttrsHashHandler :
protected AttrFunctor<size_t(const NodeRef&)> {
public:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t Hash(const NodeRef& node) {
return this->VisitAttr(node);
}
protected:
size_t VisitAttrDefault_(const Node* lhs) final;
size_t VisitAttr_(const ir::IntImm* lhs) final;
size_t VisitAttr_(const ir::UIntImm* lhs) final;
size_t VisitAttr_(const ir::FloatImm* lhs) final;
size_t VisitAttr_(const ir::StringImm* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
size_t VisitAttr_(const ir::Add* op) final;
size_t VisitAttr_(const ir::Sub* op) final;
size_t VisitAttr_(const ir::Mul* op) final;
size_t VisitAttr_(const ir::Mod* op) final;
size_t VisitAttr_(const ir::Min* op) final;
size_t VisitAttr_(const ir::Max* op) final;
size_t VisitAttr_(const ir::GE* op) final;
size_t VisitAttr_(const ir::GT* op) final;
size_t VisitAttr_(const ir::LE* op) final;
size_t VisitAttr_(const ir::LT* op) final;
size_t VisitAttr_(const ir::EQ* op) final;
size_t VisitAttr_(const ir::NE* op) final;
size_t VisitAttr_(const ir::And* op) final;
size_t VisitAttr_(const ir::Or* op) final;
size_t VisitAttr_(const ir::Not* op) final;
size_t VisitAttr_(const ir::Cast* op) final;
size_t VisitAttr_(const ir::Call* op) final;
size_t VisitAttr_(const ir::Select* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static size_t Combine(size_t lhs, size_t rhs) {
return dmlc::HashCombine(lhs, rhs);
}
};
} // namespace tvm } // namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_ #endif // TVM_LANG_ATTR_FUNCTOR_H_
...@@ -51,156 +51,272 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); ...@@ -51,156 +51,272 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
using namespace ir; using namespace ir;
// Equal handler.
bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitAttr(lhs, rhs);
}
class AttrsEqualChecker : bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) {
public AttrFunctor<bool(const NodeRef&, const NodeRef&)> { if (lhs->derived_from<BaseAttrsNode>()) {
public: AttrsEqual equal;
bool Check(const NodeRef& lhs, const NodeRef& rhs) { equal.handler_ = this;
if (!equal_) return false; return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
if (lhs.same_as(rhs)) return true; other.get(), equal);
if (!lhs.defined() || !rhs.defined()) return false;
if (!this->VisitAttr(lhs, rhs)) {
equal_ = false;
}
return equal_;
} }
return lhs == other.get();
}
bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final { bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) {
if (lhs->derived_from<BaseAttrsNode>()) { if (const auto* rhs = other.as<IntImm>()) {
return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get()); return lhs->value == rhs->value;
}
return lhs == other.get();
} }
return false;
}
bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final { bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<IntImm>()) { if (const auto* rhs = other.as<UIntImm>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
}
return false;
} }
return false;
}
bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final { bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<UIntImm>()) { if (const auto* rhs = other.as<FloatImm>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
}
return false;
} }
return false;
}
bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final { bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<FloatImm>()) { if (const auto* rhs = other.as<StringImm>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
}
return false;
} }
return false;
}
bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final { bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<StringImm>()) { if (const auto* rhs = other.as<ArrayNode>()) {
return lhs->value == rhs->value; if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
} }
return false;
} }
return true;
}
bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final { bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<ArrayNode>()) { if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false; if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) { for (const auto& kv : lhs->data) {
if (!Check(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false; auto it = rhs->data.find(kv.first);
} if (it == rhs->data.end()) return false;
if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false;
} }
return true;
} }
return true;
}
bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final { #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
if (const auto* rhs = other.as<StrMapNode>()) { bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \
if (rhs->data.size() != lhs->data.size()) return false; if (const auto* rhs = other.as<NodeName>()) { \
for (const auto& kv : lhs->data) { if (!Equal(lhs->a, rhs->a)) return false; \
auto it = rhs->data.find(kv.first); if (!Equal(lhs->b, rhs->b)) return false; \
if (it == rhs->data.end()) return false; return true; \
if (!Check(NodeRef(kv.second), NodeRef(it->second))) return false; } else { \
} return false; \
} } \
return true; } \
TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<Not>()) {
return Equal(lhs->a, rhs->a);
} else {
return false;
} }
}
private: bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
bool equal_{true}; if (const auto* rhs = other.as<Cast>()) {
}; if (lhs->type != rhs->type) return false;
return Equal(lhs->value, rhs->value);
class AttrContentHasher : } else {
public AttrFunctor<void(const NodeRef&)> { return false;
public:
size_t result_{0};
void VisitAttrDefault_(const Node* value) final {
if (value->derived_from<BaseAttrsNode>()) {
Update(static_cast<const BaseAttrsNode*>(value)->ContentHash());
} else {
Update(NodeHash()(GetRef<NodeRef>(value)));
}
} }
}
void VisitAttr_(const IntImm* op) final { bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
Update(std::hash<int64_t>()(op->value)); if (const auto* rhs = other.as<Call>()) {
return
lhs->name == rhs->name &&
lhs->type == rhs->type &&
lhs->call_type == rhs->call_type &&
Equal(lhs->args, rhs->args);
} else {
return false;
} }
}
void VisitAttr_(const UIntImm* op) final { bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) {
Update(std::hash<uint64_t>()(op->value)); if (const auto* rhs = other.as<Select>()) {
return
Equal(lhs->condition, rhs->condition) &&
Equal(lhs->true_value, rhs->true_value) &&
Equal(lhs->false_value, rhs->false_value);
} else {
return false;
} }
}
void VisitAttr_(const FloatImm* op) final { // Hash Handler.
Update(std::hash<double>()(op->value)); size_t AttrsHashHandler::VisitAttrDefault_(const Node* value) {
if (value->derived_from<BaseAttrsNode>()) {
AttrsHash hasher;
hasher.handler_ = this;
return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
} else {
return NodeHash()(GetRef<NodeRef>(value));
} }
}
void VisitAttr_(const StringImm* op) final { size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
Update(std::hash<std::string>()(op->value)); return std::hash<int64_t>()(op->value);
} }
void VisitAttr_(const ArrayNode* op) final { size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
Update(op->data.size()); return std::hash<uint64_t>()(op->value);
for (size_t i = 0; i < op->data.size(); ++i) { }
this->VisitAttr(NodeRef(op->data[i]));
} size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
return std::hash<double>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
return std::hash<std::string>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
size_t result = op->data.size();
for (size_t i = 0; i < op->data.size(); ++i) {
result = Combine(result, this->Hash(NodeRef(op->data[i])));
} }
return result;
}
void VisitAttr_(const StrMapNode* lhs) final { size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
using Entry = std::pair<std::string, NodePtr<Node> >; using Entry = std::pair<std::string, NodePtr<Node> >;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end()); std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) { std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
return a.first < b.first; return a.first < b.first;
}); });
size_t result = 0;
for (const Entry& kv : data) { for (const Entry& kv : data) {
Update(std::hash<std::string>()(kv.first)); result = Combine(result, std::hash<std::string>()(kv.first));
this->VisitAttr(NodeRef(kv.second)); result = Combine(result, this->Hash(NodeRef(kv.second)));
} }
} return result;
}
void Update(size_t value) {
result_ = dmlc::HashCombine(result_, value);
}
};
bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) { #define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \
size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \
static size_t key = std::hash<std::string>()(NodeName::_type_key); \
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
TVM_DEFINE_ATTRS_BINOP_HASH(GT);
TVM_DEFINE_ATTRS_BINOP_HASH(LE);
TVM_DEFINE_ATTRS_BINOP_HASH(LT);
TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
TVM_DEFINE_ATTRS_BINOP_HASH(NE);
TVM_DEFINE_ATTRS_BINOP_HASH(And);
TVM_DEFINE_ATTRS_BINOP_HASH(Or);
size_t AttrsHashHandler::VisitAttr_(const Not* op) {
static size_t key = std::hash<std::string>()(Not::_type_key);
return Combine(key, Hash(op->a));
}
size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
static size_t key = std::hash<std::string>()(Cast::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->type));
res = Combine(res, Hash(op->value));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const Call* op) {
static size_t key = std::hash<std::string>()(Call::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->name));
res = Combine(res, hasher(op->type));
res = Combine(res, Hash(op->args));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const Select* op) {
static size_t key = std::hash<std::string>()(Select::_type_key);
size_t res = key;
res = Combine(res, Hash(op->condition));
res = Combine(res, Hash(op->true_value));
res = Combine(res, Hash(op->false_value));
return res;
}
// Default case
bool AttrsEqual::operator()(const NodeRef& lhs, const NodeRef& rhs) const {
if (lhs.same_as(rhs)) return true; if (lhs.same_as(rhs)) return true;
AttrsEqualChecker checker; if (handler_ == nullptr) {
return checker.Check(lhs, rhs); return AttrsEqualHandler().Equal(lhs, rhs);
} else {
return handler_->Equal(lhs, rhs);
}
} }
size_t AttrsHash::Hash(const NodeRef& node) { size_t AttrsHash::operator()(const NodeRef& node) const {
if (!node.defined()) return 0; if (!node.defined()) return 0;
AttrContentHasher hasher; if (handler_ == nullptr) {
hasher.VisitAttr(node); return AttrsHashHandler().Hash(node);
return hasher.result_; } else {
return handler_->Hash(node);
}
} }
size_t DictAttrsNode::ContentHash() const { size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
return AttrsHash()(this->dict); return hasher(this->dict);
} }
bool DictAttrsNode::ContentEqual(const Node* other) const { bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
if (this == other) return true; if (this == other) return true;
if (other == nullptr) return false; if (other == nullptr) return false;
if (this->type_index() != other->type_index()) return false; if (this->type_index() != other->type_index()) return false;
return AttrsEqual()(this->dict, static_cast<const DictAttrsNode*>(other)->dict); return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
} }
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
namespace relay {
// Alpha equal handler for relay.
class AlphaEqualHandler:
public AttrsEqualHandler,
public TypeFunctor<bool(const Type&, const Type&)>,
public ExprFunctor<bool(const Expr&, const Expr&)> {
public:
explicit AlphaEqualHandler(bool map_free_var)
: map_free_var_(map_free_var) {}
/*!
* Check equality of two nodes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs->derived_from<TypeNode>()) {
if (!rhs->derived_from<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->derived_from<ExprNode>()) {
if (!rhs->derived_from<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
return AttrEqual(lhs, rhs);
}
/*!
* Check equality of two attributes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqualHandler::Equal(lhs, rhs);
}
/*!
* Check equality of two types.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool TypeEqual(const Type& lhs, const Type& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
}
/*!
* Check equality of two expressions.
*
* \note We run graph structural equality checking when comparing two Exprs.
* This means that AlphaEqualHandler can only be used once for each pair.
* The equality checker checks data-flow equvalence of the Expr DAG.
* This function also runs faster as it memomizes equal_map.
*
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool ExprEqual(const Expr& lhs, const Expr& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
}
if (this->VisitExpr(lhs, rhs)) {
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
}
protected:
/*!
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
return lhs == rhs;
}
/*!
* \brief Check Equality of leaf node of the graph.
* if map_free_var_ is set to true, try to map via equal node.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
} else {
if (map_free_var_) {
if (lhs->type_index() != rhs->type_index()) return false;
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
}
}
using AttrsEqualHandler::VisitAttr_;
bool VisitAttr_(const Variable* lhs, const NodeRef& other) final {
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
}
// Type equality
bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
return (lhs->dtype == rhs->dtype &&
AttrEqual(lhs->shape, rhs->shape));
} else {
return false;
}
}
bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
}
bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
if (lhs->kind != rhs->kind) return false;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} else {
return false;
}
}
bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
if (lhs->type_params.size() != rhs->type_params.size()) return false;
if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
for (size_t i = 0; i < lhs->type_params.size(); ++i) {
if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
return false;
}
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
// set up type parameter equal
if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) {
// map variable
equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
}
}
for (size_t i = 0; i < lhs->arg_types.size(); i++) {
if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
}
if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
if (!TypeEqual(lhs->type_constraints[i],
rhs->type_constraints[i])) {
return false;
}
}
return true;
} else {
return false;
}
}
bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
if (!lhs->func.same_as(rhs->func)) return false;
if (lhs->num_inputs != rhs->num_inputs) return false;
if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
}
return true;
} else {
return false;
}
}
bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
if (lhs->fields.size() != rhs->fields.size()) return false;
for (size_t i = 0; i < lhs->fields.size(); ++i) {
if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
}
return true;
} else {
return false;
}
}
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
if (lhs.defined() != rhs.defined()) {
return false;
} else if (lhs.same_as(rhs)) {
return true;
} else {
auto ldt = lhs->dtype;
auto rdt = rhs->dtype;
CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(*lhs.operator->());
return std::memcmp(lhs->data, rhs->data, data_size) == 0;
} else {
return false;
}
}
}
// merge declaration of two variables together.
bool MergeVarDecl(const Var& lhs, const Var& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!TypeEqual(lhs->type_annotation,
rhs->type_annotation)) return false;
CHECK(!equal_map_.count(lhs))
<< "Duplicated declaration of variable " << lhs;
equal_map_[lhs] = rhs;
return true;
}
bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
if (const VarNode* rhs = other.as<VarNode>()) {
if (lhs->name_hint != rhs->name_hint) return false;
if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} else {
return false;
}
}
bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
// use name equality for global var for now.
if (lhs->name_hint != rhs->name_hint) return false;
return true;
} else {
return false;
}
}
bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
if (const TupleNode* rhs = other.as<TupleNode>()) {
if (lhs->fields.size() != rhs->fields.size()) return false;
for (size_t i = 0; i < lhs->fields.size(); ++i) {
if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
}
return true;
} else {
return false;
}
}
bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
if (const FunctionNode* rhs = other.as<FunctionNode>()) {
if (lhs->params.size() != rhs->params.size()) return false;
if (lhs->type_params.size() != rhs->type_params.size()) return false;
// map type parameter to be the same
for (size_t i = 0; i < lhs->type_params.size(); ++i) {
if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
}
// check parameter type annotations
for (size_t i = 0; i < lhs->params.size(); ++i) {
if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
}
// check return types.
if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
}
}
bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false;
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
}
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
return AttrEqual(lhs->attrs, rhs->attrs);
} else {
return false;
}
}
bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
if (const LetNode* rhs = other.as<LetNode>()) {
if (!ExprEqual(lhs->value, rhs->value)) return false;
if (!MergeVarDecl(lhs->var, rhs->var)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
}
}
bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
if (const IfNode* rhs = other.as<IfNode>()) {
return ExprEqual(lhs->cond, rhs->cond) &&
ExprEqual(lhs->true_branch, rhs->true_branch) &&
ExprEqual(lhs->false_branch, rhs->false_branch);
} else {
return false;
}
}
bool VisitExpr_(const OpNode* op, const Expr& other) final {
return op == other.get();
}
bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
if (const ConstantNode* rhs = other.as<ConstantNode>()) {
return NDArrayEqual(lhs->data, rhs->data);
} else {
return false;
}
}
bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
} else {
return false;
}
}
private:
// whether to map open terms.
bool map_free_var_{false};
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
};
bool AlphaEqual(const Type& lhs, const Type& rhs) {
return AlphaEqualHandler(false).TypeEqual(lhs, rhs);
}
bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false).ExprEqual(lhs, rhs);
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(false).Equal(args[0], args[1]);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]);
});
TVM_REGISTER_API("relay._make._graph_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(true).Equal(args[0], args[1]);
});
} // namespace relay
} // namespace tvm
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tvm/relay/environment.h> #include <tvm/relay/environment.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <sstream> #include <sstream>
#include "../pass/type_functor.h" #include "type_functor.h"
#include "../../lang/attr_functor.h" #include "../../lang/attr_functor.h"
namespace tvm { namespace tvm {
...@@ -245,6 +245,9 @@ class TextPrinter : ...@@ -245,6 +245,9 @@ class TextPrinter :
stream_ << ", "; stream_ << ", ";
} }
} }
if (fields.size() == 1) {
stream_ << ',';
}
stream_ << ')'; stream_ << ')';
this->PrintEndInst("\n"); this->PrintEndInst("\n");
return id; return id;
...@@ -648,7 +651,7 @@ class TextPrinter : ...@@ -648,7 +651,7 @@ class TextPrinter :
name = "%" + name; name = "%" + name;
} }
TextValue val(GetUniqueName(name)); TextValue val(GetUniqueName(name));
CHECK(!memo_.count(var)); CHECK(!memo_.count(var)) << "Duplicated variable " << var;
memo_[var] = val; memo_[var] = val;
return val; return val;
} }
......
...@@ -3,12 +3,13 @@ ...@@ -3,12 +3,13 @@
* \file type_functor.h * \file type_functor.h
* \brief A way to defined arbitrary function signature with dispatch on types. * \brief A way to defined arbitrary function signature with dispatch on types.
*/ */
#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #define TVM_RELAY_IR_TYPE_FUNCTOR_H_
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
#include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -89,6 +90,113 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -89,6 +90,113 @@ class TypeFunctor<R(const Type& n, Args...)> {
} }
}; };
/*!
* \brief A type visitor for vistiors which make use of internal
* mutable state.
*
* We recursively visit each type contained inside the visitor.
*/
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) {
this->VisitType(type_param, std::forward<Args>(args)...);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs, std::forward<Args>(args)...);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type, std::forward<Args>(args)...);
}
this->VisitType(op->ret_type, std::forward<Args>(args)...);
}
void VisitType_(const TensorTypeNode* op, Args... args) override {}
void VisitType_(const TupleTypeNode* op, Args... args) override {
for (const Type& t : op->fields) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const TypeRelationNode* op, Args... args) override {
for (const Type& t : op->args) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
};
// A functional visitor for rebuilding an AST in place.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override {
// TODO(@jroesch): maybe we should recursively visit
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
Type VisitType_(const TupleTypeNode* op) override {
std::vector<Type> new_fields;
for (const Type& t : op->fields) {
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const TypeRelationNode* type_rel) override {
std::vector<Type> new_args;
for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t));
}
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/alpha_eq.cc
* \brief Check that two type are syntactically equal up to alpha equivalence.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include "./type_visitor.h"
#include "tvm/relay/pass.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
if (lhs.defined() != rhs.defined()) {
return false;
} else if (lhs.same_as(rhs)) {
return true;
} else {
auto ldt = lhs->dtype;
auto rdt = rhs->dtype;
CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t s = GetDataSize(*lhs.operator->());
return memcmp(lhs->data, rhs->data, s) == 0;
} else {
return false;
}
}
}
struct TypeAlphaEq : TypeVisitor<const Type&> {
tvm::Map<TypeVar, TypeVar> eq_map;
bool equal;
TypeAlphaEq() : eq_map(), equal(true) {}
void DataTypeEqual(const DataType& dt1, const DataType& dt2) {
if (dt1 != dt2) {
equal = false;
}
}
void ShapeEqual(const Array<IndexExpr>& s1, const Array<IndexExpr>& s2) {
if (s1.size() != s2.size()) {
equal = false;
return;
}
for (size_t i = 0; i < s1.size(); ++i) {
if (!tvm::ir::Equal(s1[i], s2[i])) {
equal = false;
return;
}
}
}
void VisitType_(const TensorTypeNode* tt1, const Type& t2) final {
if (const TensorTypeNode* tt2 = t2.as<TensorTypeNode>()) {
DataTypeEqual(tt1->dtype, tt2->dtype);
ShapeEqual(tt1->shape, tt2->shape);
} else {
equal = false;
}
}
void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final {
if (const IncompleteTypeNode* bt2 = t2.as<IncompleteTypeNode>()) {
equal = equal && bt1 == bt2;
return;
} else {
equal = false;
}
}
void VisitType_(const TypeVarNode* ti1, const Type& t2) final {
if (const TypeVarNode* ti2 = t2.as<TypeVarNode>()) {
auto tid1 = GetRef<TypeVar>(ti1);
auto tid2 = GetRef<TypeVar>(ti2);
// We handle open terms with this rule assuming variables are identical.
//
// Not sure if we should do this.
if (tid1 == tid2) {
return;
}
// Check that they are same kind
if (tid1->kind != tid2->kind) {
equal = false;
return;
}
// Next we see if there is mapping for local1 into the rhs term.
// If there is we check to see if those are equal.
if (eq_map.find(tid1) != eq_map.end()) {
equal = equal && eq_map[tid1] == tid2;
} else {
equal = false;
}
} else {
equal = false;
}
}
void VisitType_(const FuncTypeNode* op, const Type& t2) final {
if (const FuncTypeNode* ta2 = t2.as<FuncTypeNode>()) {
if (op->arg_types.size() != ta2->arg_types.size()
|| op->type_params.size() != ta2->type_params.size()
|| op->type_constraints.size() != ta2->type_constraints.size()) {
equal = false;
return;
}
// must visit params first so they are appropriate entered
// into equality map
for (size_t i = 0; i < op->type_params.size(); i++) {
eq_map.Set(op->type_params[i], ta2->type_params[i]);
this->VisitType(op->type_params[i], ta2->type_params[i]);
if (!equal) {
return;
}
}
for (size_t i = 0; i < op->arg_types.size(); i++) {
this->VisitType(op->arg_types[i], ta2->arg_types[i]);
if (!equal) {
return;
}
}
this->VisitType(op->ret_type, ta2->ret_type);
if (!equal) {
return;
}
for (size_t i = 0; i < op->type_constraints.size(); i++) {
this->VisitType(op->type_constraints[i], ta2->type_constraints[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
}
void VisitType_(const TypeRelationNode* tr1, const Type& t2) final {
if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) {
if (tr1->func != tr2->func
|| tr1->num_inputs != tr2->num_inputs
|| tr1->attrs != tr2->attrs) {
equal = false;
return;
}
if (tr1->args.size() != tr2->args.size()) {
equal = false;
return;
}
for (size_t i = 0; i < tr1->args.size(); i++) {
this->VisitType(tr1->args[i], tr2->args[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
}
void VisitType_(const TupleTypeNode* op, const Type& t2) final {
if (const TupleTypeNode* pt = t2.as<TupleTypeNode>()) {
if (op->fields.size() != pt->fields.size()) {
equal = false;
return;
}
for (size_t i = 0U; i < op->fields.size(); i++) {
if (!equal) {
return;
}
this->VisitType(op->fields[i], pt->fields[i]);
}
} else {
equal = false;
}
}
};
bool AlphaEqual(const Type& t1, const Type& t2) {
if (t1.defined() != t2.defined()) {
return false;
}
if (!t1.defined()) {
return true;
}
TypeAlphaEq aeq;
aeq.VisitType(t1, t2);
return aeq.equal;
}
struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
public:
tvm::Map<Var, Var> eq_map;
bool equal;
AlphaEq() : eq_map(), equal(true) {}
void VisitExpr_(const VarNode* e1, const Expr& e2) final {
if (const VarNode* id2 = e2.as<VarNode>()) {
auto local1 = GetRef<Var>(e1);
auto local2 = GetRef<Var>(id2);
// We handle open terms with this rule assuming variables are identical.
if (local1 == local2) {
equal = true;
return;
}
// Next we see if there is mapping for local1 into the rhs term.
// If there is we check to see if those are equal.
if (eq_map.find(local1) != eq_map.end()) {
equal = equal && eq_map[local1] == local2;
} else {
equal = false;
}
} else {
equal = false;
}
}
void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final {
if (const GlobalVarNode* g2 = e2.as<GlobalVarNode>()) {
equal = equal && g1 == g2;
} else {
equal = false;
}
}
void VisitExpr_(const TupleNode* pl1, const Expr& e2) final {
Tuple prod1 = GetRef<Tuple>(pl1);
if (const TupleNode* pl2 = e2.as<TupleNode>()) {
Tuple prod2 = GetRef<Tuple>(pl2);
if (prod1->fields.size() != prod2->fields.size()) {
equal = false;
return;
}
for (size_t i = 0U; i < prod1->fields.size(); i++) {
this->VisitExpr(prod1->fields[i], prod2->fields[i]);
}
} else {
equal = false;
}
}
void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
if (func1->params.size() != func2->params.size()) {
equal = false;
return;
}
if (func1->type_params.size() != func2->type_params.size()) {
equal = false;
return;
}
for (size_t i = 0; i < func1->params.size(); ++i) {
MergeVarDecl(func1->params[i], func2->params[i]);
}
if (!equal) {
return;
}
for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
if (!equal) {
return;
}
}
equal = equal && AlphaEqual(func1->ret_type, func2->ret_type);
if (!equal) {
return;
}
this->VisitExpr(func1->body, func2->body);
} else {
equal = false;
}
}
void VisitExpr_(const CallNode* op, const Expr& e2) final {
if (const CallNode* call = e2.as<CallNode>()) {
this->VisitExpr(op->op, call->op);
if (op->args.size() != call->args.size()) {
equal = false;
return;
}
if (op->type_args.size() != call->type_args.size()) {
equal = false;
return;
}
// checking attrs by pointer equality for now
equal = equal && (op->attrs == call->attrs);
if (!equal) {
return;
}
for (size_t i = 0U; i < op->args.size(); i++) {
this->VisitExpr(op->args[i], call->args[i]);
}
for (size_t i = 0U; i < op->type_args.size(); i++) {
equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
}
void VisitExpr_(const LetNode* op, const Expr& e2) final {
if (const LetNode* let = e2.as<LetNode>()) {
MergeVarDecl(op->var, let->var);
this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body);
} else {
equal = false;
}
}
void VisitExpr_(const IfNode* op, const Expr& e2) final {
if (const IfNode* i = e2.as<IfNode>()) {
VisitExpr(op->cond, i->cond);
VisitExpr(op->true_branch, i->true_branch);
VisitExpr(op->false_branch, i->false_branch);
} else {
equal = false;
}
}
void VisitExpr_(const OpNode* op, const Expr& e2) final {
if (const OpNode* o = e2.as<OpNode>()) {
equal = equal && op->name == o->name;
} else {
equal = false;
}
}
void VisitExpr_(const ConstantNode* op, const Expr& e2) final {
if (const ConstantNode* c = e2.as<ConstantNode>()) {
if (AlphaEqual(op->tensor_type(), c->tensor_type())) {
equal = equal && SameNDArray(op->data, c->data);
} else {
equal = false;
}
} else {
equal = false;
}
}
void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
this->VisitExpr(op->tuple, proj->tuple);
equal = equal && (op->index == proj->index);
} else {
equal = false;
}
}
private:
void MergeVarDecl(const Var& var1, const Var& var2) {
equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation);
if (!equal) {
return;
}
eq_map.Set(var1, var2);
}
};
bool AlphaEqual(const Expr& e1, const Expr& e2) {
AlphaEq eq;
eq.VisitExpr(e1, e2);
return eq.equal;
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expr e1 = args[0];
Expr e2 = args[1];
*ret = AlphaEqual(e1, e2);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Type t1 = args[0];
Type t2 = args[1];
*ret = AlphaEqual(t1, t2);
});
} // namespace relay
} // namespace tvm
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
* contains a data type such as `int`, `float`, `uint`. * contains a data type such as `int`, `float`, `uint`.
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include "./type_visitor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -105,13 +105,13 @@ bool KindCheck(const Type& t, const Environment& env) { ...@@ -105,13 +105,13 @@ bool KindCheck(const Type& t, const Environment& env) {
} }
TVM_REGISTER_API("relay._ir_pass.check_kind") TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { if (args.size() == 1) {
*ret = KindCheck(args[0], EnvironmentNode::make({})); *ret = KindCheck(args[0], EnvironmentNode::make({}));
} else { } else {
*ret = KindCheck(args[0], args[1]); *ret = KindCheck(args[0], args[1]);
} }
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* \brief Function for substituting a concrete type in place of a type ID * \brief Function for substituting a concrete type in place of a type ID
*/ */
#include "./type_subst.h" #include "./type_subst.h"
#include "./type_visitor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_visitor.h
* \brief A wrapper around TypeFunctor for common use cases.
*/
#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_
#define TVM_RELAY_PASS_TYPE_VISITOR_H_
#include <vector>
#include "./type_functor.h"
namespace tvm {
namespace relay {
/*! \brief A type visitor for vistiors which make use of internal
* mutable state.
*
* We recursively visit each type contained inside the visitor.
*/
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) {
this->VisitType(type_param, std::forward<Args>(args)...);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs, std::forward<Args>(args)...);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type, std::forward<Args>(args)...);
}
this->VisitType(op->ret_type, std::forward<Args>(args)...);
}
void VisitType_(const TensorTypeNode* op, Args... args) override {}
void VisitType_(const TupleTypeNode* op, Args... args) override {
for (const Type& t : op->fields) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const TypeRelationNode* op, Args... args) override {
for (const Type& t : op->args) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
};
// A functional visitor for rebuilding an AST in place.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override {
// TODO(@jroesch): maybe we should recursively visit
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
Type VisitType_(const TupleTypeNode* op) override {
std::vector<Type> new_fields;
for (const Type& t : op->fields) {
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const TypeRelationNode* type_rel) override {
std::vector<Type> new_args;
for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t));
}
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "./type_visitor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -139,7 +139,8 @@ def test_type_relation_alpha_equal(): ...@@ -139,7 +139,8 @@ def test_type_relation_alpha_equal():
# attrs are also compared only by pointer equality # attrs are also compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
...@@ -147,6 +148,7 @@ def test_type_relation_alpha_equal(): ...@@ -147,6 +148,7 @@ def test_type_relation_alpha_equal():
diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1) diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1) diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2) diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
same_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1_same)
bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1) bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2) diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2)
...@@ -157,6 +159,7 @@ def test_type_relation_alpha_equal(): ...@@ -157,6 +159,7 @@ def test_type_relation_alpha_equal():
assert tr != diff_order assert tr != diff_order
assert tr != diff_args assert tr != diff_args
assert tr != diff_attr assert tr != diff_attr
assert tr == same_attr
assert tr != bigger assert tr != bigger
assert bigger != diff_num_inputs assert bigger != diff_num_inputs
...@@ -216,22 +219,26 @@ def test_global_var_alpha_equal(): ...@@ -216,22 +219,26 @@ def test_global_var_alpha_equal():
def test_tuple_alpha_equal(): def test_tuple_alpha_equal():
v0 = relay.Var("v0")
v1 = relay.Var("v1") v1 = relay.Var("v1")
v2 = relay.Var("v2") v2 = relay.Var("v2")
# unit value is a valid tuple # unit value is a valid tuple
assert alpha_equal(relay.Tuple([]), relay.Tuple([])) assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
tup = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
same = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
assert alpha_equal(tup, same) assert alpha_equal(tup, same)
# use the eq_map # use the eq_map
let_tup = relay.Let(v1, tup, v1) let_tup = relay.Let(v1, tup, v1)
let_mapped = relay.Let(v2, relay.Tuple([v2, relay.const(2), relay.const(3), let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)])]), relay.Tuple([relay.const(4)])]),
v2) v2)
assert alpha_equal(let_tup, let_mapped) assert alpha_equal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2]) more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
...@@ -340,7 +347,8 @@ def test_call_alpha_equal(): ...@@ -340,7 +347,8 @@ def test_call_alpha_equal():
# attrs are compared only by pointer equality # attrs are compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tt1 = relay.TensorType((1, 2, 3), "float32") tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((), "int8") tt2 = relay.TensorType((), "int8")
...@@ -375,6 +383,9 @@ def test_call_alpha_equal(): ...@@ -375,6 +383,9 @@ def test_call_alpha_equal():
different_attrs = relay.Call(v1, basic_args, attr2, [tt1]) different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
assert not alpha_equal(call, different_attrs) assert not alpha_equal(call, different_attrs)
same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
assert alpha_equal(call, same_attrs)
no_type_args = relay.Call(v1, basic_args, attr1) no_type_args = relay.Call(v1, basic_args, attr1)
assert not alpha_equal(call, no_type_args) assert not alpha_equal(call, no_type_args)
...@@ -445,6 +456,27 @@ def test_op_alpha_equal(): ...@@ -445,6 +456,27 @@ def test_op_alpha_equal():
assert not alpha_equal(op1, op3) assert not alpha_equal(op1, op3)
def test_graph_equal():
x = relay.var("x")
y0 = relay.add(x, x)
z0 = relay.add(y0, y0)
y1 = relay.add(x, x)
z1 = relay.add(y1, y1)
z3 = relay.add(relay.add(x, x), relay.add(x, x))
assert alpha_equal(z0, z1)
# z3's dataflow format is different from z0
# z0 is computed from a common y0 node
# Relay view them as different programs
# Check the difference in the text format.
assert not alpha_equal(z0, z3)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor_type_alpha_equal() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal() test_incomplete_type_alpha_equal()
...@@ -462,3 +494,4 @@ if __name__ == "__main__": ...@@ -462,3 +494,4 @@ if __name__ == "__main__":
test_if_alpha_equal() test_if_alpha_equal()
test_op_alpha_equal() test_op_alpha_equal()
test_var_alpha_equal() test_var_alpha_equal()
test_graph_equal()
...@@ -17,6 +17,12 @@ def test_attrs_equal(): ...@@ -17,6 +17,12 @@ def test_attrs_equal():
assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]}) assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]}) assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
n = tvm.var("n")
assert tvm.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
def test_attrs_hash(): def test_attrs_hash():
fhash = tvm.ir_pass.AttrsHash fhash = tvm.ir_pass.AttrsHash
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment