Unverified commit 6536b356 by Zhi, committed by GitHub

remove AttrsEqual and AttrsHash related code (#5169)

parent a2edd01b
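The functors removed by this commit are superseded by the generic structural equality and hashing utilities. As a minimal sketch of the replacement usage (the helper functions below are hypothetical; tvm::StructuralEqual and tvm::StructuralHash come from the headers this diff adds to attrs.h):

    #include <tvm/node/structural_equal.h>
    #include <tvm/node/structural_hash.h>
    #include <tvm/runtime/object.h>

    // Compare two attribute objects by content, as the rewritten passes below do.
    bool SameAttrs(const tvm::ObjectRef& a, const tvm::ObjectRef& b) {
      tvm::StructuralEqual equal;  // takes over the role of the removed AttrsEqual
      return equal(a, b);
    }

    // Content-aware hash, the counterpart of the removed AttrsHash.
    size_t AttrsDigest(const tvm::ObjectRef& a) {
      tvm::StructuralHash hasher;
      return hasher(a);
    }

On the Python side the same functionality is exposed as tvm.ir.structural_equal and tvm.ir.structural_hash, which the updated tests at the end of this diff use.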
...@@ -46,6 +46,8 @@
#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map>
...@@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
class AttrsHashHandler;
class AttrsEqualHandler;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
protected:
friend class AttrsEqualHandler;
/*! \brief internal handle. */
AttrsEqualHandler* handler_{nullptr};
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const DataType& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const ObjectRef& value) const;
private:
friend class AttrsHashHandler;
/*! \brief internal handle. */
AttrsHashHandler* handler_{nullptr};
};
/*!
 * \brief Base class of all attribute class
 * \note Do not subclass AttrBaseNode directly,
...@@ -266,20 +179,6 @@ class BaseAttrsNode : public Object {
 * \note This function throws when the required field is not present.
 */
TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(
const Object* other, AttrsEqual equal) const = 0;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result.
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
...@@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Object* other, AttrsEqual equal) const final;
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
...@@ -386,34 +283,6 @@ class AttrNormalVisitor {
AttrVisitor* visitor_;
};
// Wrapper for normal visitor.
class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}
private:
const Object* lhs_;
const Object* rhs_;
const AttrsEqual& equal_;
};
class AttrsSEqualVisitor {
public:
bool result_{true};
...@@ -441,23 +310,6 @@ class AttrsSEqualVisitor {
const SEqualReducer& equal_;
};
class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
: hasher_(hasher) {}
size_t result_{0};
template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, hasher_(*value));
return AttrNopEntry();
}
private:
const AttrsHash& hasher_;
};
class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
...@@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry {
return *this;
}
TSelf& set_default(const T& value) {
if (AttrsEqual()(value, *data_)) {
if (tvm::StructuralEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
...@@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}
bool ContentEqual(const Object* other, AttrsEqual equal) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false;
::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
size_t ContentHash(AttrsHash hasher) const final {
::tvm::detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = this->GetTypeKeyHash();
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
private:
DerivedType* self() const {
return const_cast<DerivedType*>(
......
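The per-node ContentEqual/ContentHash hooks removed from BaseAttrsNode and AttrsNode above are subsumed by the SEqualReduce/SHashReduce protocol that StructuralEqual and StructuralHash dispatch on; for attrs classes, the AttrsSEqualVisitor and AttrsSHashVisitor kept above generate those reductions from the field visitor. As a rough sketch of that protocol for an ordinary node (the class name, fields, and type key are hypothetical):

    #include <tvm/node/structural_equal.h>
    #include <tvm/node/structural_hash.h>
    #include <tvm/runtime/object.h>
    #include <string>

    // Sketch only: a node participates in structural equality/hashing by
    // reducing its fields, instead of overriding ContentEqual/ContentHash.
    class MyNode : public tvm::runtime::Object {
     public:
      double alpha;
      std::string layout;

      bool SEqualReduce(const MyNode* other, tvm::SEqualReducer equal) const {
        return equal(alpha, other->alpha) && equal(layout, other->layout);
      }
      void SHashReduce(tvm::SHashReducer hash_reduce) const {
        hash_reduce(alpha);
        hash_reduce(layout);
      }

      static constexpr const bool _type_has_method_sequal_reduce = true;
      static constexpr const bool _type_has_method_shash_reduce = true;
      static constexpr const char* _type_key = "test.MyNode";
      TVM_DECLARE_FINAL_OBJECT_INFO(MyNode, tvm::runtime::Object);
    };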
...@@ -147,94 +147,5 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
}
};
class AttrsEqualHandler :
protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> {
public:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs);
protected:
bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final;
};
class AttrsHashHandler :
protected AttrFunctor<size_t(const ObjectRef&)> {
public:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t Hash(const ObjectRef& node) {
if (!node.defined()) return 0;
return this->VisitAttr(node);
}
protected:
size_t VisitAttrDefault_(const Object* lhs) final;
size_t VisitAttr_(const tir::IntImmNode* lhs) final;
size_t VisitAttr_(const tir::FloatImmNode* lhs) final;
size_t VisitAttr_(const tir::StringImmNode* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
size_t VisitAttr_(const tir::AddNode* op) final;
size_t VisitAttr_(const tir::SubNode* op) final;
size_t VisitAttr_(const tir::MulNode* op) final;
size_t VisitAttr_(const tir::DivNode* op) final;
size_t VisitAttr_(const tir::ModNode* op) final;
size_t VisitAttr_(const tir::FloorDivNode* op) final;
size_t VisitAttr_(const tir::FloorModNode* op) final;
size_t VisitAttr_(const tir::MinNode* op) final;
size_t VisitAttr_(const tir::MaxNode* op) final;
size_t VisitAttr_(const tir::GENode* op) final;
size_t VisitAttr_(const tir::GTNode* op) final;
size_t VisitAttr_(const tir::LENode* op) final;
size_t VisitAttr_(const tir::LTNode* op) final;
size_t VisitAttr_(const tir::EQNode* op) final;
size_t VisitAttr_(const tir::NENode* op) final;
size_t VisitAttr_(const tir::AndNode* op) final;
size_t VisitAttr_(const tir::OrNode* op) final;
size_t VisitAttr_(const tir::NotNode* op) final;
size_t VisitAttr_(const tir::CastNode* op) final;
size_t VisitAttr_(const tir::CallNode* op) final;
size_t VisitAttr_(const tir::SelectNode* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static size_t Combine(size_t lhs, size_t rhs) {
return dmlc::HashCombine(lhs, rhs);
}
};
}  // namespace tvm
#endif  // TVM_IR_ATTR_FUNCTOR_H_
...@@ -74,287 +74,9 @@ TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
return attrs->dict;
});
using namespace tir;
// Equal handler.
bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
return this->VisitAttr(lhs, rhs);
}
bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) {
if (lhs->IsInstance<BaseAttrsNode>()) {
AttrsEqual equal;
equal.handler_ = this;
return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
other.get(), equal);
}
return lhs == other.get();
}
bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<IntImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StringImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!Equal(kv.second, it->second)) return false;
}
return true;
} else {
return false;
}
}
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
if (const auto* rhs = other.as<NodeName>()) { \
if (!Equal(lhs->a, rhs->a)) return false; \
if (!Equal(lhs->b, rhs->b)) return false; \
return true; \
} else { \
return false; \
} \
} \
TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<NotNode>()) {
return Equal(lhs->a, rhs->a);
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<CastNode>()) {
if (lhs->dtype != rhs->dtype) return false;
return Equal(lhs->value, rhs->value);
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<CallNode>()) {
return
lhs->name == rhs->name &&
lhs->dtype == rhs->dtype &&
lhs->call_type == rhs->call_type &&
Equal(lhs->args, rhs->args);
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<SelectNode>()) {
return
Equal(lhs->condition, rhs->condition) &&
Equal(lhs->true_value, rhs->true_value) &&
Equal(lhs->false_value, rhs->false_value);
} else {
return false;
}
}
// Hash Handler.
size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
if (value->IsInstance<BaseAttrsNode>()) {
AttrsHash hasher;
hasher.handler_ = this;
return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
} else {
return ObjectHash()(GetRef<ObjectRef>(value));
}
}
size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
return std::hash<int64_t>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
return std::hash<double>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
return std::hash<std::string>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
size_t result = op->data.size();
for (size_t i = 0; i < op->data.size(); ++i) {
result = Combine(result, this->Hash(op->data[i]));
}
return result;
}
size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
using Entry = std::pair<std::string, ObjectRef>;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
return a.first < b.first;
});
size_t result = 0;
for (const Entry& kv : data) {
result = Combine(result, std::hash<std::string>()(kv.first));
result = Combine(result, this->Hash(kv.second));
}
return result;
}
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \
size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \
static size_t key = std::hash<std::string>()(NodeName::_type_key); \
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
static size_t key = std::hash<std::string>()(NotNode::_type_key);
return Combine(key, Hash(op->a));
}
size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
static size_t key = std::hash<std::string>()(CastNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->dtype));
res = Combine(res, Hash(op->value));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
static size_t key = std::hash<std::string>()(CallNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->name));
res = Combine(res, hasher(op->dtype));
res = Combine(res, Hash(op->args));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
static size_t key = std::hash<std::string>()(SelectNode::_type_key);
size_t res = key;
res = Combine(res, Hash(op->condition));
res = Combine(res, Hash(op->true_value));
res = Combine(res, Hash(op->false_value));
return res;
}
// Default case
bool AttrsEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
if (lhs.same_as(rhs)) return true;
if (handler_ == nullptr) {
return AttrsEqualHandler().Equal(lhs, rhs);
} else {
return handler_->Equal(lhs, rhs);
}
}
size_t AttrsHash::operator()(const ObjectRef& node) const {
if (!node.defined()) return 0;
if (handler_ == nullptr) {
return AttrsHashHandler().Hash(node);
} else {
return handler_->Hash(node);
}
}
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
return hasher(this->dict);
}
bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
if (this == other) return true;
if (other == nullptr) return false;
if (this->type_index() != other->type_index()) return false;
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}
TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body_typed([](Attrs attrs) { .set_body_typed([](Attrs attrs) {
return attrs->ListFieldInfo(); return attrs->ListFieldInfo();
}); });
TVM_REGISTER_GLOBAL("ir.AttrsEqual")
.set_body_typed([](ObjectRef lhs, ObjectRef rhs) {
return AttrsEqual()(lhs, rhs);
});
}  // namespace tvm
...@@ -103,6 +103,7 @@ class RemapVarSEqualHandler :
// Function that implements actual equality check.
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
if (!lhs.defined() && !rhs.defined()) return true;
task_stack_.clear();
pending_tasks_.clear();
equal_map_lhs_.clear();
......
...@@ -59,7 +59,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
StructuralEqual eq;
const Layout kOIHW("OIHW");
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
...@@ -112,7 +112,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
......
...@@ -54,7 +54,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
protected:
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
CHECK(attrs_a);
......
...@@ -23,6 +23,7 @@
 * \brief Abstract class to combine parallel ops and their successive element-wise ops.
 */
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
...@@ -155,7 +156,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) {
bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
tvm::StructuralEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
......
...@@ -76,7 +76,7 @@ bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode
return false;
}
AttrsEqual eq;
StructuralEqual eq;
for (size_t i = 0; i < a->args.size(); i++) {
auto ta = a->args[i]->type_as<TensorTypeNode>();
auto tb = b->args[i]->type_as<TensorTypeNode>();
...@@ -112,7 +112,7 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
}
bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
......
...@@ -45,7 +45,7 @@ class CommonSubexprEliminator : public ExprMutator {
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
const OpNode* op = new_call->op.as<OpNode>();
AttrsEqual attrs_equal;
StructuralEqual attrs_equal;
if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
return new_expr;
......
...@@ -765,7 +765,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AttrsEqual equal;
StructuralEqual equal;
if (in_messages[0].defined() &&
MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
return in_messages[0];
...@@ -795,7 +795,7 @@ Expr AddSubBackwardTransform(const Call& call,
}
Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
AttrsEqual equal;
StructuralEqual equal;
if (lhs_message.defined() && rhs_message.defined()) {
CHECK(equal(lhs_message->axes, rhs_message->axes));
......
...@@ -162,7 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// The output.
IndexedForwardGraph graph_;
// attribute equal comparator
AttrsEqual attr_equal_;
StructuralEqual attr_equal_;
// Update the message stored at the node.
void Update(const Expr& node,
IndexedForwardGraph::Node* parent,
......
...@@ -104,7 +104,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const Array<Integer>& lhs_axes,
Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
AttrsEqual equal;
StructuralEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;
......
...@@ -101,18 +101,6 @@ TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});
TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
.set_body_typed(
[](const ObjectRef& lhs, const ObjectRef& rhs) {
return AttrsEqual()(lhs, rhs);
});
TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
.set_body_typed([](const ObjectRef &node) -> int64_t {
return AttrsHash()(node);
});
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
......
...@@ -106,7 +106,6 @@ def test_function():
check_json_roundtrip(fn)
@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
def test_function_attrs():
param_names = ['a', 'b', 'c', 'd']
params = tvm.runtime.convert([relay.var(n, shape=(5, 2)) for n in param_names])
......
...@@ -51,14 +51,13 @@ def test_dict_attrs():
def test_attrs_equal():
attr_equal = tvm.ir._ffi_api.AttrsEqual
dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
assert attr_equal(dattr0, dattr1)
assert tvm.ir.structural_equal(dattr0, dattr1)
assert not attr_equal(dattr0, dattr2)
assert not tvm.ir.structural_equal(dattr0, dattr2)
assert not attr_equal({"x": 1}, tvm.runtime.convert(1))
assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1))
assert not attr_equal([1, 2], tvm.runtime.convert(1))
assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1))
......
...@@ -21,28 +21,28 @@ def test_attrs_equal():
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1))
assert tvm.tir.ir_pass.AttrsEqual(x, y)
assert tvm.ir.structural_equal(x, y)
assert not tvm.tir.ir_pass.AttrsEqual(x, z)
assert not tvm.ir.structural_equal(x, z)
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert not tvm.tir.ir_pass.AttrsEqual(dattr, x)
assert not tvm.ir.structural_equal(dattr, x)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert tvm.tir.ir_pass.AttrsEqual(dattr, dattr2)
assert tvm.ir.structural_equal(dattr, dattr2)
assert tvm.tir.ir_pass.AttrsEqual({"x": x}, {"x": y})
assert tvm.ir.structural_equal({"x": x}, {"x": y})
# array related checks
assert tvm.tir.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
assert tvm.ir.structural_equal({"x": [x, x]}, {"x": [y, x]})
assert not tvm.tir.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
assert not tvm.ir.structural_equal({"x": [x, 1]}, {"x": [y, 2]})
n = te.var("n")
assert tvm.tir.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
assert tvm.ir.structural_equal({"x": n+1}, {"x": n+1})
def test_attrs_hash():
fhash = tvm.tir.ir_pass.AttrsHash
fhash = tvm.ir.structural_hash
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
assert fhash({"x": x}) == fhash({"x": y})
......