Unverified Commit 81db03ab by Tianqi Chen Committed by GitHub

[RELAY] Refactor AlphaEqual to support deep comparison of Attrs. (#1958)

parent 28f1a9fd
......@@ -108,6 +108,90 @@ class AttrFieldInfoNode : public Node {
/*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
class AttrsHashHandler;
class AttrsEqualHandler;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const;
protected:
friend class AttrsEqualHandler;
/*! \brief internal handle. */
AttrsEqualHandler* handler_{nullptr};
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const NodeRef& value) const;
private:
friend class AttrsHashHandler;
/*! \brief internal handle. */
AttrsHashHandler* handler_{nullptr};
};
/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
......@@ -153,14 +237,17 @@ class BaseAttrsNode : public Node {
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(const Node* other) const = 0;
TVM_DLL virtual bool ContentEqual(
const Node* other, AttrsEqual equal) const = 0;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result.
*/
TVM_DLL virtual size_t ContentHash() const = 0;
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
......@@ -209,92 +296,13 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other) const final;
size_t ContentHash() const final;
bool ContentEqual(const Node* other, AttrsEqual equal) const final;
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
};
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
return lhs == rhs;
}
bool operator()(const NodeRef& lhs, const NodeRef& rhs) const {
return AttrsEqual::Equal(lhs, rhs);
}
// comparator of NodeRef types.
static TVM_DLL bool Equal(const NodeRef& lhs, const NodeRef& rhs);
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
size_t operator()(const NodeRef& value) const {
return AttrsHash::Hash(value);
}
// hash function of the attribute and attribute fields.
static TVM_DLL size_t Hash(const NodeRef& lhs);
};
// Namespace containing detail implementations
namespace detail {
......@@ -342,8 +350,8 @@ class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Node* lhs, const Node* rhs)
: lhs_(lhs), rhs_(rhs) {
AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
......@@ -353,7 +361,7 @@ class AttrsEqualVisitor {
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!AttrsEqual()(*lhs_value, *rhs_value)) {
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
......@@ -362,17 +370,24 @@ class AttrsEqualVisitor {
private:
const Node* lhs_;
const Node* rhs_;
const AttrsEqual& equal_;
};
class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
: hasher_(hasher) {}
size_t result_{0};
template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, AttrsHash()(*value));
result_ = dmlc::HashCombine(result_, hasher_(*value));
return AttrNopEntry();
}
private:
const AttrsHash& hasher_;
};
// helper entry that does initialization, set default.
......@@ -793,18 +808,18 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}
bool ContentEqual(const Node* other) const final {
bool ContentEqual(const Node* other, AttrsEqual equal) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false;
detail::AttrsEqualVisitor visitor(pself, other);
detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
size_t ContentHash() const final {
detail::AttrsHashVisitor visitor;
size_t ContentHash(AttrsHash hasher) const final {
detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = std::hash<std::string>()(this->type_key());
self()->__VisitAttrs__(visitor);
return visitor.result_;
......
......@@ -68,10 +68,14 @@ TVM_REGISTER_API("ir_pass.Equal")
TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>(AttrsEqual::Equal);
.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqual()(lhs, rhs);
});
TVM_REGISTER_API("ir_pass.AttrsHash")
.set_body_typed<int64_t(const NodeRef&)>(AttrsHash::Hash);
.set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
return AttrsHash()(node);
});
TVM_REGISTER_API("ir_pass.ExprUseVar")
......
......@@ -52,13 +52,33 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
}
}
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
// deep comparison of symbolic integer expressions.
virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;
private:
// initialize the vtable.
......@@ -72,9 +92,111 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(UIntImm);
ATTR_FUNCTOR_DISPATCH(FloatImm);
ATTR_FUNCTOR_DISPATCH(StringImm);
ATTR_FUNCTOR_DISPATCH(Variable);
ATTR_FUNCTOR_DISPATCH(Add);
ATTR_FUNCTOR_DISPATCH(Sub);
ATTR_FUNCTOR_DISPATCH(Mul);
ATTR_FUNCTOR_DISPATCH(Min);
ATTR_FUNCTOR_DISPATCH(Max);
ATTR_FUNCTOR_DISPATCH(GE);
ATTR_FUNCTOR_DISPATCH(GT);
ATTR_FUNCTOR_DISPATCH(LE);
ATTR_FUNCTOR_DISPATCH(LT);
ATTR_FUNCTOR_DISPATCH(EQ);
ATTR_FUNCTOR_DISPATCH(NE);
ATTR_FUNCTOR_DISPATCH(And);
ATTR_FUNCTOR_DISPATCH(Or);
ATTR_FUNCTOR_DISPATCH(Not);
ATTR_FUNCTOR_DISPATCH(Cast);
ATTR_FUNCTOR_DISPATCH(Call);
ATTR_FUNCTOR_DISPATCH(Select);
return vtable;
}
};
class AttrsEqualHandler :
protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
public:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool Equal(const NodeRef& lhs, const NodeRef& rhs);
protected:
bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final;
};
class AttrsHashHandler :
protected AttrFunctor<size_t(const NodeRef&)> {
public:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t Hash(const NodeRef& node) {
return this->VisitAttr(node);
}
protected:
size_t VisitAttrDefault_(const Node* lhs) final;
size_t VisitAttr_(const ir::IntImm* lhs) final;
size_t VisitAttr_(const ir::UIntImm* lhs) final;
size_t VisitAttr_(const ir::FloatImm* lhs) final;
size_t VisitAttr_(const ir::StringImm* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
size_t VisitAttr_(const ir::Add* op) final;
size_t VisitAttr_(const ir::Sub* op) final;
size_t VisitAttr_(const ir::Mul* op) final;
size_t VisitAttr_(const ir::Mod* op) final;
size_t VisitAttr_(const ir::Min* op) final;
size_t VisitAttr_(const ir::Max* op) final;
size_t VisitAttr_(const ir::GE* op) final;
size_t VisitAttr_(const ir::GT* op) final;
size_t VisitAttr_(const ir::LE* op) final;
size_t VisitAttr_(const ir::LT* op) final;
size_t VisitAttr_(const ir::EQ* op) final;
size_t VisitAttr_(const ir::NE* op) final;
size_t VisitAttr_(const ir::And* op) final;
size_t VisitAttr_(const ir::Or* op) final;
size_t VisitAttr_(const ir::Not* op) final;
size_t VisitAttr_(const ir::Cast* op) final;
size_t VisitAttr_(const ir::Call* op) final;
size_t VisitAttr_(const ir::Select* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static size_t Combine(size_t lhs, size_t rhs) {
return dmlc::HashCombine(lhs, rhs);
}
};
} // namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_
......@@ -51,156 +51,272 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
using namespace ir;
class AttrsEqualChecker :
public AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
public:
bool Check(const NodeRef& lhs, const NodeRef& rhs) {
if (!equal_) return false;
// Equal handler.
bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!this->VisitAttr(lhs, rhs)) {
equal_ = false;
}
return equal_;
}
return this->VisitAttr(lhs, rhs);
}
bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final {
bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) {
if (lhs->derived_from<BaseAttrsNode>()) {
return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get());
AttrsEqual equal;
equal.handler_ = this;
return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
other.get(), equal);
}
return lhs == other.get();
}
}
bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final {
bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<IntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
}
bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final {
bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<UIntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
}
bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final {
bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<FloatImm>()) {
return lhs->value == rhs->value;
}
return false;
}
}
bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final {
bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<StringImm>()) {
return lhs->value == rhs->value;
}
return false;
}
}
bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final {
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Check(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
}
}
return true;
}
}
bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final {
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!Check(NodeRef(kv.second), NodeRef(it->second))) return false;
if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false;
}
}
return true;
}
private:
bool equal_{true};
};
class AttrContentHasher :
public AttrFunctor<void(const NodeRef&)> {
public:
size_t result_{0};
}
void VisitAttrDefault_(const Node* value) final {
if (value->derived_from<BaseAttrsNode>()) {
Update(static_cast<const BaseAttrsNode*>(value)->ContentHash());
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \
if (const auto* rhs = other.as<NodeName>()) { \
if (!Equal(lhs->a, rhs->a)) return false; \
if (!Equal(lhs->b, rhs->b)) return false; \
return true; \
} else { \
return false; \
} \
} \
TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<Not>()) {
return Equal(lhs->a, rhs->a);
} else {
Update(NodeHash()(GetRef<NodeRef>(value)));
}
return false;
}
}
void VisitAttr_(const IntImm* op) final {
Update(std::hash<int64_t>()(op->value));
bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<Cast>()) {
if (lhs->type != rhs->type) return false;
return Equal(lhs->value, rhs->value);
} else {
return false;
}
}
void VisitAttr_(const UIntImm* op) final {
Update(std::hash<uint64_t>()(op->value));
bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<Call>()) {
return
lhs->name == rhs->name &&
lhs->type == rhs->type &&
lhs->call_type == rhs->call_type &&
Equal(lhs->args, rhs->args);
} else {
return false;
}
}
void VisitAttr_(const FloatImm* op) final {
Update(std::hash<double>()(op->value));
bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) {
if (const auto* rhs = other.as<Select>()) {
return
Equal(lhs->condition, rhs->condition) &&
Equal(lhs->true_value, rhs->true_value) &&
Equal(lhs->false_value, rhs->false_value);
} else {
return false;
}
}
void VisitAttr_(const StringImm* op) final {
Update(std::hash<std::string>()(op->value));
// Hash Handler.
size_t AttrsHashHandler::VisitAttrDefault_(const Node* value) {
if (value->derived_from<BaseAttrsNode>()) {
AttrsHash hasher;
hasher.handler_ = this;
return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
} else {
return NodeHash()(GetRef<NodeRef>(value));
}
}
size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
return std::hash<int64_t>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
return std::hash<uint64_t>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
return std::hash<double>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
return std::hash<std::string>()(op->value);
}
void VisitAttr_(const ArrayNode* op) final {
Update(op->data.size());
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
size_t result = op->data.size();
for (size_t i = 0; i < op->data.size(); ++i) {
this->VisitAttr(NodeRef(op->data[i]));
}
result = Combine(result, this->Hash(NodeRef(op->data[i])));
}
return result;
}
void VisitAttr_(const StrMapNode* lhs) final {
size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
using Entry = std::pair<std::string, NodePtr<Node> >;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
return a.first < b.first;
});
size_t result = 0;
for (const Entry& kv : data) {
Update(std::hash<std::string>()(kv.first));
this->VisitAttr(NodeRef(kv.second));
}
result = Combine(result, std::hash<std::string>()(kv.first));
result = Combine(result, this->Hash(NodeRef(kv.second)));
}
return result;
}
void Update(size_t value) {
result_ = dmlc::HashCombine(result_, value);
}
};
bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) {
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \
size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \
static size_t key = std::hash<std::string>()(NodeName::_type_key); \
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
TVM_DEFINE_ATTRS_BINOP_HASH(GT);
TVM_DEFINE_ATTRS_BINOP_HASH(LE);
TVM_DEFINE_ATTRS_BINOP_HASH(LT);
TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
TVM_DEFINE_ATTRS_BINOP_HASH(NE);
TVM_DEFINE_ATTRS_BINOP_HASH(And);
TVM_DEFINE_ATTRS_BINOP_HASH(Or);
size_t AttrsHashHandler::VisitAttr_(const Not* op) {
static size_t key = std::hash<std::string>()(Not::_type_key);
return Combine(key, Hash(op->a));
}
size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
static size_t key = std::hash<std::string>()(Cast::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->type));
res = Combine(res, Hash(op->value));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const Call* op) {
static size_t key = std::hash<std::string>()(Call::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->name));
res = Combine(res, hasher(op->type));
res = Combine(res, Hash(op->args));
return res;
}
size_t AttrsHashHandler::VisitAttr_(const Select* op) {
static size_t key = std::hash<std::string>()(Select::_type_key);
size_t res = key;
res = Combine(res, Hash(op->condition));
res = Combine(res, Hash(op->true_value));
res = Combine(res, Hash(op->false_value));
return res;
}
// Default case
bool AttrsEqual::operator()(const NodeRef& lhs, const NodeRef& rhs) const {
if (lhs.same_as(rhs)) return true;
AttrsEqualChecker checker;
return checker.Check(lhs, rhs);
if (handler_ == nullptr) {
return AttrsEqualHandler().Equal(lhs, rhs);
} else {
return handler_->Equal(lhs, rhs);
}
}
size_t AttrsHash::Hash(const NodeRef& node) {
size_t AttrsHash::operator()(const NodeRef& node) const {
if (!node.defined()) return 0;
AttrContentHasher hasher;
hasher.VisitAttr(node);
return hasher.result_;
if (handler_ == nullptr) {
return AttrsHashHandler().Hash(node);
} else {
return handler_->Hash(node);
}
}
size_t DictAttrsNode::ContentHash() const {
return AttrsHash()(this->dict);
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
return hasher(this->dict);
}
bool DictAttrsNode::ContentEqual(const Node* other) const {
bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
if (this == other) return true;
if (other == nullptr) return false;
if (this->type_index() != other->type_index()) return false;
return AttrsEqual()(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
namespace relay {
// Alpha equal handler for relay.
class AlphaEqualHandler:
public AttrsEqualHandler,
public TypeFunctor<bool(const Type&, const Type&)>,
public ExprFunctor<bool(const Expr&, const Expr&)> {
public:
explicit AlphaEqualHandler(bool map_free_var)
: map_free_var_(map_free_var) {}
/*!
* Check equality of two nodes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs->derived_from<TypeNode>()) {
if (!rhs->derived_from<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->derived_from<ExprNode>()) {
if (!rhs->derived_from<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
return AttrEqual(lhs, rhs);
}
/*!
* Check equality of two attributes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqualHandler::Equal(lhs, rhs);
}
/*!
* Check equality of two types.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool TypeEqual(const Type& lhs, const Type& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
}
/*!
* Check equality of two expressions.
*
* \note We run graph structural equality checking when comparing two Exprs.
* This means that AlphaEqualHandler can only be used once for each pair.
* The equality checker checks data-flow equvalence of the Expr DAG.
* This function also runs faster as it memomizes equal_map.
*
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool ExprEqual(const Expr& lhs, const Expr& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
}
if (this->VisitExpr(lhs, rhs)) {
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
}
protected:
/*!
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
return lhs == rhs;
}
/*!
* \brief Check Equality of leaf node of the graph.
* if map_free_var_ is set to true, try to map via equal node.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return the compare result.
*/
bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
} else {
if (map_free_var_) {
if (lhs->type_index() != rhs->type_index()) return false;
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
}
}
using AttrsEqualHandler::VisitAttr_;
bool VisitAttr_(const Variable* lhs, const NodeRef& other) final {
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
}
// Type equality
bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
return (lhs->dtype == rhs->dtype &&
AttrEqual(lhs->shape, rhs->shape));
} else {
return false;
}
}
bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
}
bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
if (lhs->kind != rhs->kind) return false;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} else {
return false;
}
}
bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
if (lhs->type_params.size() != rhs->type_params.size()) return false;
if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
for (size_t i = 0; i < lhs->type_params.size(); ++i) {
if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
return false;
}
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
// set up type parameter equal
if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) {
// map variable
equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
}
}
for (size_t i = 0; i < lhs->arg_types.size(); i++) {
if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
}
if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
if (!TypeEqual(lhs->type_constraints[i],
rhs->type_constraints[i])) {
return false;
}
}
return true;
} else {
return false;
}
}
bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
if (!lhs->func.same_as(rhs->func)) return false;
if (lhs->num_inputs != rhs->num_inputs) return false;
if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
}
return true;
} else {
return false;
}
}
bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
if (lhs->fields.size() != rhs->fields.size()) return false;
for (size_t i = 0; i < lhs->fields.size(); ++i) {
if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
}
return true;
} else {
return false;
}
}
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
if (lhs.defined() != rhs.defined()) {
return false;
} else if (lhs.same_as(rhs)) {
return true;
} else {
auto ldt = lhs->dtype;
auto rdt = rhs->dtype;
CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(*lhs.operator->());
return std::memcmp(lhs->data, rhs->data, data_size) == 0;
} else {
return false;
}
}
}
// merge declaration of two variables together.
bool MergeVarDecl(const Var& lhs, const Var& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!TypeEqual(lhs->type_annotation,
rhs->type_annotation)) return false;
CHECK(!equal_map_.count(lhs))
<< "Duplicated declaration of variable " << lhs;
equal_map_[lhs] = rhs;
return true;
}
bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
if (const VarNode* rhs = other.as<VarNode>()) {
if (lhs->name_hint != rhs->name_hint) return false;
if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} else {
return false;
}
}
bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
// use name equality for global var for now.
if (lhs->name_hint != rhs->name_hint) return false;
return true;
} else {
return false;
}
}
bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
if (const TupleNode* rhs = other.as<TupleNode>()) {
if (lhs->fields.size() != rhs->fields.size()) return false;
for (size_t i = 0; i < lhs->fields.size(); ++i) {
if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
}
return true;
} else {
return false;
}
}
bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
if (const FunctionNode* rhs = other.as<FunctionNode>()) {
if (lhs->params.size() != rhs->params.size()) return false;
if (lhs->type_params.size() != rhs->type_params.size()) return false;
// map type parameter to be the same
for (size_t i = 0; i < lhs->type_params.size(); ++i) {
if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
}
// check parameter type annotations
for (size_t i = 0; i < lhs->params.size(); ++i) {
if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
}
// check return types.
if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
}
}
bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false;
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
}
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
return AttrEqual(lhs->attrs, rhs->attrs);
} else {
return false;
}
}
bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
if (const LetNode* rhs = other.as<LetNode>()) {
if (!ExprEqual(lhs->value, rhs->value)) return false;
if (!MergeVarDecl(lhs->var, rhs->var)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
}
}
bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
if (const IfNode* rhs = other.as<IfNode>()) {
return ExprEqual(lhs->cond, rhs->cond) &&
ExprEqual(lhs->true_branch, rhs->true_branch) &&
ExprEqual(lhs->false_branch, rhs->false_branch);
} else {
return false;
}
}
bool VisitExpr_(const OpNode* op, const Expr& other) final {
return op == other.get();
}
bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
if (const ConstantNode* rhs = other.as<ConstantNode>()) {
return NDArrayEqual(lhs->data, rhs->data);
} else {
return false;
}
}
bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
} else {
return false;
}
}
private:
// whether to map open terms.
bool map_free_var_{false};
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
};
bool AlphaEqual(const Type& lhs, const Type& rhs) {
return AlphaEqualHandler(false).TypeEqual(lhs, rhs);
}
bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false).ExprEqual(lhs, rhs);
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(false).Equal(args[0], args[1]);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]);
});
TVM_REGISTER_API("relay._make._graph_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(true).Equal(args[0], args[1]);
});
} // namespace relay
} // namespace tvm
......@@ -6,7 +6,7 @@
#include <tvm/relay/environment.h>
#include <tvm/relay/expr_functor.h>
#include <sstream>
#include "../pass/type_functor.h"
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
......@@ -245,6 +245,9 @@ class TextPrinter :
stream_ << ", ";
}
}
if (fields.size() == 1) {
stream_ << ',';
}
stream_ << ')';
this->PrintEndInst("\n");
return id;
......@@ -648,7 +651,7 @@ class TextPrinter :
name = "%" + name;
}
TextValue val(GetUniqueName(name));
CHECK(!memo_.count(var));
CHECK(!memo_.count(var)) << "Duplicated variable " << var;
memo_[var] = val;
return val;
}
......
......@@ -3,12 +3,13 @@
* \file type_functor.h
* \brief A way to defined arbitrary function signature with dispatch on types.
*/
#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
......@@ -89,6 +90,113 @@ class TypeFunctor<R(const Type& n, Args...)> {
}
};
/*!
* \brief A type visitor for vistiors which make use of internal
* mutable state.
*
* We recursively visit each type contained inside the visitor.
*/
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) {
this->VisitType(type_param, std::forward<Args>(args)...);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs, std::forward<Args>(args)...);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type, std::forward<Args>(args)...);
}
this->VisitType(op->ret_type, std::forward<Args>(args)...);
}
void VisitType_(const TensorTypeNode* op, Args... args) override {}
void VisitType_(const TupleTypeNode* op, Args... args) override {
for (const Type& t : op->fields) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const TypeRelationNode* op, Args... args) override {
for (const Type& t : op->args) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
};
// A functional visitor for rebuilding an AST in place.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override {
// TODO(@jroesch): maybe we should recursively visit
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
Type VisitType_(const TupleTypeNode* op) override {
std::vector<Type> new_fields;
for (const Type& t : op->fields) {
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const TypeRelationNode* type_rel) override {
std::vector<Type> new_args;
for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t));
}
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/alpha_eq.cc
* \brief Check that two type are syntactically equal up to alpha equivalence.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include "./type_visitor.h"
#include "tvm/relay/pass.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
if (lhs.defined() != rhs.defined()) {
return false;
} else if (lhs.same_as(rhs)) {
return true;
} else {
auto ldt = lhs->dtype;
auto rdt = rhs->dtype;
CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t s = GetDataSize(*lhs.operator->());
return memcmp(lhs->data, rhs->data, s) == 0;
} else {
return false;
}
}
}
struct TypeAlphaEq : TypeVisitor<const Type&> {
tvm::Map<TypeVar, TypeVar> eq_map;
bool equal;
TypeAlphaEq() : eq_map(), equal(true) {}
void DataTypeEqual(const DataType& dt1, const DataType& dt2) {
if (dt1 != dt2) {
equal = false;
}
}
void ShapeEqual(const Array<IndexExpr>& s1, const Array<IndexExpr>& s2) {
if (s1.size() != s2.size()) {
equal = false;
return;
}
for (size_t i = 0; i < s1.size(); ++i) {
if (!tvm::ir::Equal(s1[i], s2[i])) {
equal = false;
return;
}
}
}
void VisitType_(const TensorTypeNode* tt1, const Type& t2) final {
if (const TensorTypeNode* tt2 = t2.as<TensorTypeNode>()) {
DataTypeEqual(tt1->dtype, tt2->dtype);
ShapeEqual(tt1->shape, tt2->shape);
} else {
equal = false;
}
}
void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final {
if (const IncompleteTypeNode* bt2 = t2.as<IncompleteTypeNode>()) {
equal = equal && bt1 == bt2;
return;
} else {
equal = false;
}
}
void VisitType_(const TypeVarNode* ti1, const Type& t2) final {
if (const TypeVarNode* ti2 = t2.as<TypeVarNode>()) {
auto tid1 = GetRef<TypeVar>(ti1);
auto tid2 = GetRef<TypeVar>(ti2);
// We handle open terms with this rule assuming variables are identical.
//
// Not sure if we should do this.
if (tid1 == tid2) {
return;
}
// Check that they are same kind
if (tid1->kind != tid2->kind) {
equal = false;
return;
}
// Next we see if there is mapping for local1 into the rhs term.
// If there is we check to see if those are equal.
if (eq_map.find(tid1) != eq_map.end()) {
equal = equal && eq_map[tid1] == tid2;
} else {
equal = false;
}
} else {
equal = false;
}
}
void VisitType_(const FuncTypeNode* op, const Type& t2) final {
if (const FuncTypeNode* ta2 = t2.as<FuncTypeNode>()) {
if (op->arg_types.size() != ta2->arg_types.size()
|| op->type_params.size() != ta2->type_params.size()
|| op->type_constraints.size() != ta2->type_constraints.size()) {
equal = false;
return;
}
// must visit params first so they are appropriate entered
// into equality map
for (size_t i = 0; i < op->type_params.size(); i++) {
eq_map.Set(op->type_params[i], ta2->type_params[i]);
this->VisitType(op->type_params[i], ta2->type_params[i]);
if (!equal) {
return;
}
}
for (size_t i = 0; i < op->arg_types.size(); i++) {
this->VisitType(op->arg_types[i], ta2->arg_types[i]);
if (!equal) {
return;
}
}
this->VisitType(op->ret_type, ta2->ret_type);
if (!equal) {
return;
}
for (size_t i = 0; i < op->type_constraints.size(); i++) {
this->VisitType(op->type_constraints[i], ta2->type_constraints[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
}
void VisitType_(const TypeRelationNode* tr1, const Type& t2) final {
if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) {
if (tr1->func != tr2->func
|| tr1->num_inputs != tr2->num_inputs
|| tr1->attrs != tr2->attrs) {
equal = false;
return;
}
if (tr1->args.size() != tr2->args.size()) {
equal = false;
return;
}
for (size_t i = 0; i < tr1->args.size(); i++) {
this->VisitType(tr1->args[i], tr2->args[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
}
void VisitType_(const TupleTypeNode* op, const Type& t2) final {
if (const TupleTypeNode* pt = t2.as<TupleTypeNode>()) {
if (op->fields.size() != pt->fields.size()) {
equal = false;
return;
}
for (size_t i = 0U; i < op->fields.size(); i++) {
if (!equal) {
return;
}
this->VisitType(op->fields[i], pt->fields[i]);
}
} else {
equal = false;
}
}
};
bool AlphaEqual(const Type& t1, const Type& t2) {
if (t1.defined() != t2.defined()) {
return false;
}
if (!t1.defined()) {
return true;
}
TypeAlphaEq aeq;
aeq.VisitType(t1, t2);
return aeq.equal;
}
struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
public:
tvm::Map<Var, Var> eq_map;
bool equal;
AlphaEq() : eq_map(), equal(true) {}
void VisitExpr_(const VarNode* e1, const Expr& e2) final {
if (const VarNode* id2 = e2.as<VarNode>()) {
auto local1 = GetRef<Var>(e1);
auto local2 = GetRef<Var>(id2);
// We handle open terms with this rule assuming variables are identical.
if (local1 == local2) {
equal = true;
return;
}
// Next we see if there is mapping for local1 into the rhs term.
// If there is we check to see if those are equal.
if (eq_map.find(local1) != eq_map.end()) {
equal = equal && eq_map[local1] == local2;
} else {
equal = false;
}
} else {
equal = false;
}
}
void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final {
if (const GlobalVarNode* g2 = e2.as<GlobalVarNode>()) {
equal = equal && g1 == g2;
} else {
equal = false;
}
}
void VisitExpr_(const TupleNode* pl1, const Expr& e2) final {
Tuple prod1 = GetRef<Tuple>(pl1);
if (const TupleNode* pl2 = e2.as<TupleNode>()) {
Tuple prod2 = GetRef<Tuple>(pl2);
if (prod1->fields.size() != prod2->fields.size()) {
equal = false;
return;
}
for (size_t i = 0U; i < prod1->fields.size(); i++) {
this->VisitExpr(prod1->fields[i], prod2->fields[i]);
}
} else {
equal = false;
}
}
void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
if (func1->params.size() != func2->params.size()) {
equal = false;
return;
}
if (func1->type_params.size() != func2->type_params.size()) {
equal = false;
return;
}
for (size_t i = 0; i < func1->params.size(); ++i) {
MergeVarDecl(func1->params[i], func2->params[i]);
}
if (!equal) {
return;
}
for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
if (!equal) {
return;
}
}
equal = equal && AlphaEqual(func1->ret_type, func2->ret_type);
if (!equal) {
return;
}
this->VisitExpr(func1->body, func2->body);
} else {
equal = false;
}
}
void VisitExpr_(const CallNode* op, const Expr& e2) final {
if (const CallNode* call = e2.as<CallNode>()) {
this->VisitExpr(op->op, call->op);
if (op->args.size() != call->args.size()) {
equal = false;
return;
}
if (op->type_args.size() != call->type_args.size()) {
equal = false;
return;
}
// checking attrs by pointer equality for now
equal = equal && (op->attrs == call->attrs);
if (!equal) {
return;
}
for (size_t i = 0U; i < op->args.size(); i++) {
this->VisitExpr(op->args[i], call->args[i]);
}
for (size_t i = 0U; i < op->type_args.size(); i++) {
equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]);
if (!equal) {
return;
}
}
} else {
equal = false;
}
}
void VisitExpr_(const LetNode* op, const Expr& e2) final {
if (const LetNode* let = e2.as<LetNode>()) {
MergeVarDecl(op->var, let->var);
this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body);
} else {
equal = false;
}
}
void VisitExpr_(const IfNode* op, const Expr& e2) final {
if (const IfNode* i = e2.as<IfNode>()) {
VisitExpr(op->cond, i->cond);
VisitExpr(op->true_branch, i->true_branch);
VisitExpr(op->false_branch, i->false_branch);
} else {
equal = false;
}
}
void VisitExpr_(const OpNode* op, const Expr& e2) final {
if (const OpNode* o = e2.as<OpNode>()) {
equal = equal && op->name == o->name;
} else {
equal = false;
}
}
void VisitExpr_(const ConstantNode* op, const Expr& e2) final {
if (const ConstantNode* c = e2.as<ConstantNode>()) {
if (AlphaEqual(op->tensor_type(), c->tensor_type())) {
equal = equal && SameNDArray(op->data, c->data);
} else {
equal = false;
}
} else {
equal = false;
}
}
void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
this->VisitExpr(op->tuple, proj->tuple);
equal = equal && (op->index == proj->index);
} else {
equal = false;
}
}
private:
void MergeVarDecl(const Var& var1, const Var& var2) {
equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation);
if (!equal) {
return;
}
eq_map.Set(var1, var2);
}
};
bool AlphaEqual(const Expr& e1, const Expr& e2) {
AlphaEq eq;
eq.VisitExpr(e1, e2);
return eq.equal;
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expr e1 = args[0];
Expr e2 = args[1];
*ret = AlphaEqual(e1, e2);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Type t1 = args[0];
Type t2 = args[1];
*ret = AlphaEqual(t1, t2);
});
} // namespace relay
} // namespace tvm
......@@ -14,7 +14,7 @@
* contains a data type such as `int`, `float`, `uint`.
*/
#include <tvm/relay/pass.h>
#include "./type_visitor.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......@@ -105,7 +105,7 @@ bool KindCheck(const Type& t, const Environment& env) {
}
TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = KindCheck(args[0], EnvironmentNode::make({}));
} else {
......
......@@ -4,7 +4,7 @@
* \brief Function for substituting a concrete type in place of a type ID
*/
#include "./type_subst.h"
#include "./type_visitor.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_visitor.h
* \brief A wrapper around TypeFunctor for common use cases.
*/
#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_
#define TVM_RELAY_PASS_TYPE_VISITOR_H_
#include <vector>
#include "./type_functor.h"
namespace tvm {
namespace relay {
/*! \brief A type visitor for vistiors which make use of internal
* mutable state.
*
* We recursively visit each type contained inside the visitor.
*/
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) {
this->VisitType(type_param, std::forward<Args>(args)...);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs, std::forward<Args>(args)...);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type, std::forward<Args>(args)...);
}
this->VisitType(op->ret_type, std::forward<Args>(args)...);
}
void VisitType_(const TensorTypeNode* op, Args... args) override {}
void VisitType_(const TupleTypeNode* op, Args... args) override {
for (const Type& t : op->fields) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const TypeRelationNode* op, Args... args) override {
for (const Type& t : op->args) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
};
// A functional visitor for rebuilding an AST in place.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override {
// TODO(@jroesch): maybe we should recursively visit
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
Type VisitType_(const TupleTypeNode* op) override {
std::vector<Type> new_fields;
for (const Type& t : op->fields) {
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const TypeRelationNode* type_rel) override {
std::vector<Type> new_args;
for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t));
}
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_
......@@ -7,7 +7,7 @@
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "./type_visitor.h"
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
......
......@@ -139,7 +139,8 @@ def test_type_relation_alpha_equal():
# attrs are also compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
......@@ -147,6 +148,7 @@ def test_type_relation_alpha_equal():
diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
same_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1_same)
bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2)
......@@ -157,6 +159,7 @@ def test_type_relation_alpha_equal():
assert tr != diff_order
assert tr != diff_args
assert tr != diff_attr
assert tr == same_attr
assert tr != bigger
assert bigger != diff_num_inputs
......@@ -216,22 +219,26 @@ def test_global_var_alpha_equal():
def test_tuple_alpha_equal():
v0 = relay.Var("v0")
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# unit value is a valid tuple
assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
tup = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
same = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
assert alpha_equal(tup, same)
# use the eq_map
let_tup = relay.Let(v1, tup, v1)
let_mapped = relay.Let(v2, relay.Tuple([v2, relay.const(2), relay.const(3),
let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)])]),
v2)
assert alpha_equal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
......@@ -340,7 +347,8 @@ def test_call_alpha_equal():
# attrs are compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((), "int8")
......@@ -375,6 +383,9 @@ def test_call_alpha_equal():
different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
assert not alpha_equal(call, different_attrs)
same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
assert alpha_equal(call, same_attrs)
no_type_args = relay.Call(v1, basic_args, attr1)
assert not alpha_equal(call, no_type_args)
......@@ -445,6 +456,27 @@ def test_op_alpha_equal():
assert not alpha_equal(op1, op3)
def test_graph_equal():
x = relay.var("x")
y0 = relay.add(x, x)
z0 = relay.add(y0, y0)
y1 = relay.add(x, x)
z1 = relay.add(y1, y1)
z3 = relay.add(relay.add(x, x), relay.add(x, x))
assert alpha_equal(z0, z1)
# z3's dataflow format is different from z0
# z0 is computed from a common y0 node
# Relay view them as different programs
# Check the difference in the text format.
assert not alpha_equal(z0, z3)
if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
......@@ -462,3 +494,4 @@ if __name__ == "__main__":
test_if_alpha_equal()
test_op_alpha_equal()
test_var_alpha_equal()
test_graph_equal()
......@@ -17,6 +17,12 @@ def test_attrs_equal():
assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
n = tvm.var("n")
assert tvm.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
def test_attrs_hash():
fhash = tvm.ir_pass.AttrsHash
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment