Unverified Commit 6536b356 by Zhi Committed by GitHub

remove AttrsEqual and AttrsHash related code (#5169)

parent a2edd01b
......@@ -46,6 +46,8 @@
#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map>
......@@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
class AttrsHashHandler;
class AttrsEqualHandler;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
protected:
friend class AttrsEqualHandler;
/*! \brief internal handle. */
AttrsEqualHandler* handler_{nullptr};
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const DataType& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const ObjectRef& value) const;
private:
friend class AttrsHashHandler;
/*! \brief internal handle. */
AttrsHashHandler* handler_{nullptr};
};
/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
......@@ -266,20 +179,6 @@ class BaseAttrsNode : public Object {
* \note This function throws when the required field is not present.
*/
TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(
const Object* other, AttrsEqual equal) const = 0;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result.
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
......@@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Object* other, AttrsEqual equal) const final;
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
......@@ -386,34 +283,6 @@ class AttrNormalVisitor {
AttrVisitor* visitor_;
};
// Wrapper for normal visitor.
class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}
private:
const Object* lhs_;
const Object* rhs_;
const AttrsEqual& equal_;
};
class AttrsSEqualVisitor {
public:
bool result_{true};
......@@ -441,23 +310,6 @@ class AttrsSEqualVisitor {
const SEqualReducer& equal_;
};
class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
: hasher_(hasher) {}
size_t result_{0};
template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, hasher_(*value));
return AttrNopEntry();
}
private:
const AttrsHash& hasher_;
};
class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
......@@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry {
return *this;
}
TSelf& set_default(const T& value) {
if (AttrsEqual()(value, *data_)) {
if (tvm::StructuralEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
......@@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}
bool ContentEqual(const Object* other, AttrsEqual equal) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false;
::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
size_t ContentHash(AttrsHash hasher) const final {
::tvm::detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = this->GetTypeKeyHash();
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
private:
DerivedType* self() const {
return const_cast<DerivedType*>(
......
......@@ -147,94 +147,5 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
}
};
class AttrsEqualHandler :
protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> {
public:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs);
protected:
bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final;
};
class AttrsHashHandler :
protected AttrFunctor<size_t(const ObjectRef&)> {
public:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t Hash(const ObjectRef& node) {
if (!node.defined()) return 0;
return this->VisitAttr(node);
}
protected:
size_t VisitAttrDefault_(const Object* lhs) final;
size_t VisitAttr_(const tir::IntImmNode* lhs) final;
size_t VisitAttr_(const tir::FloatImmNode* lhs) final;
size_t VisitAttr_(const tir::StringImmNode* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
size_t VisitAttr_(const tir::AddNode* op) final;
size_t VisitAttr_(const tir::SubNode* op) final;
size_t VisitAttr_(const tir::MulNode* op) final;
size_t VisitAttr_(const tir::DivNode* op) final;
size_t VisitAttr_(const tir::ModNode* op) final;
size_t VisitAttr_(const tir::FloorDivNode* op) final;
size_t VisitAttr_(const tir::FloorModNode* op) final;
size_t VisitAttr_(const tir::MinNode* op) final;
size_t VisitAttr_(const tir::MaxNode* op) final;
size_t VisitAttr_(const tir::GENode* op) final;
size_t VisitAttr_(const tir::GTNode* op) final;
size_t VisitAttr_(const tir::LENode* op) final;
size_t VisitAttr_(const tir::LTNode* op) final;
size_t VisitAttr_(const tir::EQNode* op) final;
size_t VisitAttr_(const tir::NENode* op) final;
size_t VisitAttr_(const tir::AndNode* op) final;
size_t VisitAttr_(const tir::OrNode* op) final;
size_t VisitAttr_(const tir::NotNode* op) final;
size_t VisitAttr_(const tir::CastNode* op) final;
size_t VisitAttr_(const tir::CallNode* op) final;
size_t VisitAttr_(const tir::SelectNode* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static size_t Combine(size_t lhs, size_t rhs) {
return dmlc::HashCombine(lhs, rhs);
}
};
} // namespace tvm
#endif // TVM_IR_ATTR_FUNCTOR_H_
......@@ -74,287 +74,9 @@ TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
return attrs->dict;
});
using namespace tir;
// Equal handler.
bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
return this->VisitAttr(lhs, rhs);
}
bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) {
if (lhs->IsInstance<BaseAttrsNode>()) {
AttrsEqual equal;
equal.handler_ = this;
return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
other.get(), equal);
}
return lhs == other.get();
}
bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<IntImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StringImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!Equal(kv.second, it->second)) return false;
}
return true;
} else {
return false;
}
}
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
if (const auto* rhs = other.as<NodeName>()) { \
if (!Equal(lhs->a, rhs->a)) return false; \
if (!Equal(lhs->b, rhs->b)) return false; \
return true; \
} else { \
return false; \
} \
} \
TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<NotNode>()) {
return Equal(lhs->a, rhs->a);
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<CastNode>()) {
if (lhs->dtype != rhs->dtype) return false;
return Equal(lhs->value, rhs->value);
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<CallNode>()) {
return
lhs->name == rhs->name &&
lhs->dtype == rhs->dtype &&
lhs->call_type == rhs->call_type &&
Equal(lhs->args, rhs->args);
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<SelectNode>()) {
return
Equal(lhs->condition, rhs->condition) &&
Equal(lhs->true_value, rhs->true_value) &&
Equal(lhs->false_value, rhs->false_value);
} else {
return false;
}
}
// Hash Handler.
size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
if (value->IsInstance<BaseAttrsNode>()) {
AttrsHash hasher;
hasher.handler_ = this;
return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
} else {
return ObjectHash()(GetRef<ObjectRef>(value));
}
}
size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
return std::hash<int64_t>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
return std::hash<double>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
return std::hash<std::string>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
size_t result = op->data.size();
for (size_t i = 0; i < op->data.size(); ++i) {
result = Combine(result, this->Hash(op->data[i]));
}
return result;
}
size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
using Entry = std::pair<std::string, ObjectRef>;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
return a.first < b.first;
});
size_t result = 0;
for (const Entry& kv : data) {
result = Combine(result, std::hash<std::string>()(kv.first));
result = Combine(result, this->Hash(kv.second));
}
return result;
}
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \
size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \
static size_t key = std::hash<std::string>()(NodeName::_type_key); \
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
static size_t key = std::hash<std::string>()(NotNode::_type_key);
return Combine(key, Hash(op->a));
}
size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
static size_t key = std::hash<std::string>()(CastNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->dtype));
res = Combine(res, Hash(op->value));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
static size_t key = std::hash<std::string>()(CallNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->name));
res = Combine(res, hasher(op->dtype));
res = Combine(res, Hash(op->args));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
static size_t key = std::hash<std::string>()(SelectNode::_type_key);
size_t res = key;
res = Combine(res, Hash(op->condition));
res = Combine(res, Hash(op->true_value));
res = Combine(res, Hash(op->false_value));
return res;
}
// Default case
bool AttrsEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
if (lhs.same_as(rhs)) return true;
if (handler_ == nullptr) {
return AttrsEqualHandler().Equal(lhs, rhs);
} else {
return handler_->Equal(lhs, rhs);
}
}
size_t AttrsHash::operator()(const ObjectRef& node) const {
if (!node.defined()) return 0;
if (handler_ == nullptr) {
return AttrsHashHandler().Hash(node);
} else {
return handler_->Hash(node);
}
}
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
return hasher(this->dict);
}
bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
if (this == other) return true;
if (other == nullptr) return false;
if (this->type_index() != other->type_index()) return false;
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}
TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body_typed([](Attrs attrs) {
return attrs->ListFieldInfo();
});
TVM_REGISTER_GLOBAL("ir.AttrsEqual")
.set_body_typed([](ObjectRef lhs, ObjectRef rhs) {
return AttrsEqual()(lhs, rhs);
});
} // namespace tvm
......@@ -103,6 +103,7 @@ class RemapVarSEqualHandler :
// Function that implements actual equality check.
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
if (!lhs.defined() && !rhs.defined()) return true;
task_stack_.clear();
pending_tasks_.clear();
equal_map_lhs_.clear();
......
......@@ -59,7 +59,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
StructuralEqual eq;
const Layout kOIHW("OIHW");
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
......@@ -112,7 +112,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
......
......@@ -54,7 +54,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
protected:
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
CHECK(attrs_a);
......
......@@ -23,6 +23,7 @@
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
*/
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
......@@ -155,7 +156,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) {
bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
tvm::StructuralEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
......
......@@ -76,7 +76,7 @@ bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode
return false;
}
AttrsEqual eq;
StructuralEqual eq;
for (size_t i = 0; i < a->args.size(); i++) {
auto ta = a->args[i]->type_as<TensorTypeNode>();
auto tb = b->args[i]->type_as<TensorTypeNode>();
......@@ -112,7 +112,7 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
}
bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
AttrsEqual eq;
StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
......
......@@ -45,7 +45,7 @@ class CommonSubexprEliminator : public ExprMutator {
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
const OpNode* op = new_call->op.as<OpNode>();
AttrsEqual attrs_equal;
StructuralEqual attrs_equal;
if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
return new_expr;
......
......@@ -765,7 +765,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
AttrsEqual equal;
StructuralEqual equal;
if (in_messages[0].defined() &&
MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
return in_messages[0];
......@@ -795,7 +795,7 @@ Expr AddSubBackwardTransform(const Call& call,
}
Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
AttrsEqual equal;
StructuralEqual equal;
if (lhs_message.defined() && rhs_message.defined()) {
CHECK(equal(lhs_message->axes, rhs_message->axes));
......
......@@ -162,7 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// The output.
IndexedForwardGraph graph_;
// attribute equal comparator
AttrsEqual attr_equal_;
StructuralEqual attr_equal_;
// Update the message stored at the node.
void Update(const Expr& node,
IndexedForwardGraph::Node* parent,
......
......@@ -104,7 +104,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const Array<Integer>& lhs_axes,
Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
AttrsEqual equal;
StructuralEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;
......
......@@ -101,18 +101,6 @@ TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});
TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
.set_body_typed(
[](const ObjectRef& lhs, const ObjectRef& rhs) {
return AttrsEqual()(lhs, rhs);
});
TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
.set_body_typed([](const ObjectRef &node) -> int64_t {
return AttrsHash()(node);
});
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
......
......@@ -106,7 +106,6 @@ def test_function():
check_json_roundtrip(fn)
@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
def test_function_attrs():
param_names = ['a', 'b', 'c', 'd']
params = tvm.runtime.convert([relay.var(n, shape=(5, 2)) for n in param_names])
......
......@@ -51,14 +51,13 @@ def test_dict_attrs():
def test_attrs_equal():
attr_equal = tvm.ir._ffi_api.AttrsEqual
dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
assert attr_equal(dattr0, dattr1)
assert not attr_equal(dattr0, dattr2)
assert not attr_equal({"x": 1}, tvm.runtime.convert(1))
assert not attr_equal([1, 2], tvm.runtime.convert(1))
assert tvm.ir.structural_equal(dattr0, dattr1)
assert not tvm.ir.structural_equal(dattr0, dattr2)
assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1))
assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1))
......
......@@ -21,28 +21,28 @@ def test_attrs_equal():
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1))
assert tvm.tir.ir_pass.AttrsEqual(x, y)
assert not tvm.tir.ir_pass.AttrsEqual(x, z)
assert tvm.ir.structural_equal(x, y)
assert not tvm.ir.structural_equal(x, z)
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert not tvm.tir.ir_pass.AttrsEqual(dattr, x)
assert not tvm.ir.structural_equal(dattr, x)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert tvm.tir.ir_pass.AttrsEqual(dattr, dattr2)
assert tvm.ir.structural_equal(dattr, dattr2)
assert tvm.tir.ir_pass.AttrsEqual({"x": x}, {"x": y})
assert tvm.ir.structural_equal({"x": x}, {"x": y})
# array related checks
assert tvm.tir.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
assert not tvm.tir.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
assert tvm.ir.structural_equal({"x": [x, x]}, {"x": [y, x]})
assert not tvm.ir.structural_equal({"x": [x, 1]}, {"x": [y, 2]})
n = te.var("n")
assert tvm.tir.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
assert tvm.ir.structural_equal({"x": n+1}, {"x": n+1})
def test_attrs_hash():
fhash = tvm.tir.ir_pass.AttrsHash
fhash = tvm.ir.structural_hash
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
assert fhash({"x": x}) == fhash({"x": y})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment