Unverified Commit 81db03ab by Tianqi Chen Committed by GitHub

[RELAY] Refactor AlphaEqual to support deep comparison of Attrs. (#1958)

parent 28f1a9fd
...@@ -108,6 +108,90 @@ class AttrFieldInfoNode : public Node { ...@@ -108,6 +108,90 @@ class AttrFieldInfoNode : public Node {
/*! \brief AttrFieldInfo */ /*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
class AttrsHashHandler;
class AttrsEqualHandler;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const;
protected:
friend class AttrsEqualHandler;
/*! \brief internal handle. */
AttrsEqualHandler* handler_{nullptr};
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const NodeRef& value) const;
private:
friend class AttrsHashHandler;
/*! \brief internal handle. */
AttrsHashHandler* handler_{nullptr};
};
/*! /*!
* \brief Base class of all attribute class * \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly, * \note Do not subclass AttrBaseNode directly,
...@@ -153,14 +237,17 @@ class BaseAttrsNode : public Node { ...@@ -153,14 +237,17 @@ class BaseAttrsNode : public Node {
/*! /*!
* \brief Whether this attribute's content equals to another node. * \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node. * \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result. * \return The comparison result.
*/ */
TVM_DLL virtual bool ContentEqual(const Node* other) const = 0; TVM_DLL virtual bool ContentEqual(
const Node* other, AttrsEqual equal) const = 0;
/*! /*!
* \brief Content aware hash. * \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result. * \return the hash result.
*/ */
TVM_DLL virtual size_t ContentHash() const = 0; TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const char* _type_key = "Attrs"; static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node); TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
...@@ -209,92 +296,13 @@ class DictAttrsNode : public BaseAttrsNode { ...@@ -209,92 +296,13 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final; Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other) const final; bool ContentEqual(const Node* other, AttrsEqual equal) const final;
size_t ContentHash() const final; size_t ContentHash(AttrsHash hasher) const final;
// type info // type info
static constexpr const char* _type_key = "DictAttrs"; static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
}; };
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
return lhs == rhs;
}
bool operator()(const NodeRef& lhs, const NodeRef& rhs) const {
return AttrsEqual::Equal(lhs, rhs);
}
// comparator of NodeRef types.
static TVM_DLL bool Equal(const NodeRef& lhs, const NodeRef& rhs);
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
size_t operator()(const NodeRef& value) const {
return AttrsHash::Hash(value);
}
// hash function of the attribute and attribute fields.
static TVM_DLL size_t Hash(const NodeRef& lhs);
};
// Namespace containing detail implementations // Namespace containing detail implementations
namespace detail { namespace detail {
...@@ -342,8 +350,8 @@ class AttrsEqualVisitor { ...@@ -342,8 +350,8 @@ class AttrsEqualVisitor {
public: public:
bool result_{true}; bool result_{true};
// constructor // constructor
AttrsEqualVisitor(const Node* lhs, const Node* rhs) AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs) { : lhs_(lhs), rhs_(rhs), equal_(equal) {
} }
template<typename T> template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) { AttrNopEntry operator()(const char* key, T* lhs_value) {
...@@ -353,7 +361,7 @@ class AttrsEqualVisitor { ...@@ -353,7 +361,7 @@ class AttrsEqualVisitor {
reinterpret_cast<const char*>(rhs_) + reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) - (reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_))); reinterpret_cast<const char*>(lhs_)));
if (!AttrsEqual()(*lhs_value, *rhs_value)) { if (!equal_(*lhs_value, *rhs_value)) {
result_ = false; result_ = false;
} }
return AttrNopEntry(); return AttrNopEntry();
...@@ -362,17 +370,24 @@ class AttrsEqualVisitor { ...@@ -362,17 +370,24 @@ class AttrsEqualVisitor {
private: private:
const Node* lhs_; const Node* lhs_;
const Node* rhs_; const Node* rhs_;
const AttrsEqual& equal_;
}; };
class AttrsHashVisitor { class AttrsHashVisitor {
public: public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
: hasher_(hasher) {}
size_t result_{0}; size_t result_{0};
template<typename T> template<typename T>
AttrNopEntry operator()(const char* key, T* value) { AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, AttrsHash()(*value)); result_ = dmlc::HashCombine(result_, hasher_(*value));
return AttrNopEntry(); return AttrNopEntry();
} }
private:
const AttrsHash& hasher_;
}; };
// helper entry that does initialization, set default. // helper entry that does initialization, set default.
...@@ -793,18 +808,18 @@ class AttrsNode : public BaseAttrsNode { ...@@ -793,18 +808,18 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_; return visitor.fields_;
} }
bool ContentEqual(const Node* other) const final { bool ContentEqual(const Node* other, AttrsEqual equal) const final {
DerivedType* pself = self(); DerivedType* pself = self();
if (pself == other) return true; if (pself == other) return true;
if (other == nullptr) return false; if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false; if (pself->type_index() != other->type_index()) return false;
detail::AttrsEqualVisitor visitor(pself, other); detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor); self()->__VisitAttrs__(visitor);
return visitor.result_; return visitor.result_;
} }
size_t ContentHash() const final { size_t ContentHash(AttrsHash hasher) const final {
detail::AttrsHashVisitor visitor; detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = std::hash<std::string>()(this->type_key()); visitor.result_ = std::hash<std::string>()(this->type_key());
self()->__VisitAttrs__(visitor); self()->__VisitAttrs__(visitor);
return visitor.result_; return visitor.result_;
......
...@@ -68,10 +68,14 @@ TVM_REGISTER_API("ir_pass.Equal") ...@@ -68,10 +68,14 @@ TVM_REGISTER_API("ir_pass.Equal")
TVM_REGISTER_API("ir_pass.AttrsEqual") TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>(AttrsEqual::Equal); .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
return AttrsEqual()(lhs, rhs);
});
TVM_REGISTER_API("ir_pass.AttrsHash") TVM_REGISTER_API("ir_pass.AttrsHash")
.set_body_typed<int64_t(const NodeRef&)>(AttrsHash::Hash); .set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
return AttrsHash()(node);
});
TVM_REGISTER_API("ir_pass.ExprUseVar") TVM_REGISTER_API("ir_pass.ExprUseVar")
......
...@@ -52,13 +52,33 @@ class AttrFunctor<R(const NodeRef& n, Args...)> { ...@@ -52,13 +52,33 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
} }
} }
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; // deep comparison of symbolic integer expressions.
virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;
private: private:
// initialize the vtable. // initialize the vtable.
...@@ -72,9 +92,111 @@ class AttrFunctor<R(const NodeRef& n, Args...)> { ...@@ -72,9 +92,111 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(UIntImm); ATTR_FUNCTOR_DISPATCH(UIntImm);
ATTR_FUNCTOR_DISPATCH(FloatImm); ATTR_FUNCTOR_DISPATCH(FloatImm);
ATTR_FUNCTOR_DISPATCH(StringImm); ATTR_FUNCTOR_DISPATCH(StringImm);
ATTR_FUNCTOR_DISPATCH(Variable);
ATTR_FUNCTOR_DISPATCH(Add);
ATTR_FUNCTOR_DISPATCH(Sub);
ATTR_FUNCTOR_DISPATCH(Mul);
ATTR_FUNCTOR_DISPATCH(Min);
ATTR_FUNCTOR_DISPATCH(Max);
ATTR_FUNCTOR_DISPATCH(GE);
ATTR_FUNCTOR_DISPATCH(GT);
ATTR_FUNCTOR_DISPATCH(LE);
ATTR_FUNCTOR_DISPATCH(LT);
ATTR_FUNCTOR_DISPATCH(EQ);
ATTR_FUNCTOR_DISPATCH(NE);
ATTR_FUNCTOR_DISPATCH(And);
ATTR_FUNCTOR_DISPATCH(Or);
ATTR_FUNCTOR_DISPATCH(Not);
ATTR_FUNCTOR_DISPATCH(Cast);
ATTR_FUNCTOR_DISPATCH(Call);
ATTR_FUNCTOR_DISPATCH(Select);
return vtable; return vtable;
} }
}; };
class AttrsEqualHandler :
protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
public:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool Equal(const NodeRef& lhs, const NodeRef& rhs);
protected:
bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final;
};
class AttrsHashHandler :
protected AttrFunctor<size_t(const NodeRef&)> {
public:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t Hash(const NodeRef& node) {
return this->VisitAttr(node);
}
protected:
size_t VisitAttrDefault_(const Node* lhs) final;
size_t VisitAttr_(const ir::IntImm* lhs) final;
size_t VisitAttr_(const ir::UIntImm* lhs) final;
size_t VisitAttr_(const ir::FloatImm* lhs) final;
size_t VisitAttr_(const ir::StringImm* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
size_t VisitAttr_(const ir::Add* op) final;
size_t VisitAttr_(const ir::Sub* op) final;
size_t VisitAttr_(const ir::Mul* op) final;
size_t VisitAttr_(const ir::Mod* op) final;
size_t VisitAttr_(const ir::Min* op) final;
size_t VisitAttr_(const ir::Max* op) final;
size_t VisitAttr_(const ir::GE* op) final;
size_t VisitAttr_(const ir::GT* op) final;
size_t VisitAttr_(const ir::LE* op) final;
size_t VisitAttr_(const ir::LT* op) final;
size_t VisitAttr_(const ir::EQ* op) final;
size_t VisitAttr_(const ir::NE* op) final;
size_t VisitAttr_(const ir::And* op) final;
size_t VisitAttr_(const ir::Or* op) final;
size_t VisitAttr_(const ir::Not* op) final;
size_t VisitAttr_(const ir::Cast* op) final;
size_t VisitAttr_(const ir::Call* op) final;
size_t VisitAttr_(const ir::Select* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static size_t Combine(size_t lhs, size_t rhs) {
return dmlc::HashCombine(lhs, rhs);
}
};
} // namespace tvm } // namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_ #endif // TVM_LANG_ATTR_FUNCTOR_H_
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tvm/relay/environment.h> #include <tvm/relay/environment.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <sstream> #include <sstream>
#include "../pass/type_functor.h" #include "type_functor.h"
#include "../../lang/attr_functor.h" #include "../../lang/attr_functor.h"
namespace tvm { namespace tvm {
...@@ -245,6 +245,9 @@ class TextPrinter : ...@@ -245,6 +245,9 @@ class TextPrinter :
stream_ << ", "; stream_ << ", ";
} }
} }
if (fields.size() == 1) {
stream_ << ',';
}
stream_ << ')'; stream_ << ')';
this->PrintEndInst("\n"); this->PrintEndInst("\n");
return id; return id;
...@@ -648,7 +651,7 @@ class TextPrinter : ...@@ -648,7 +651,7 @@ class TextPrinter :
name = "%" + name; name = "%" + name;
} }
TextValue val(GetUniqueName(name)); TextValue val(GetUniqueName(name));
CHECK(!memo_.count(var)); CHECK(!memo_.count(var)) << "Duplicated variable " << var;
memo_[var] = val; memo_[var] = val;
return val; return val;
} }
......
...@@ -3,12 +3,13 @@ ...@@ -3,12 +3,13 @@
* \file type_functor.h * \file type_functor.h
* \brief A way to defined arbitrary function signature with dispatch on types. * \brief A way to defined arbitrary function signature with dispatch on types.
*/ */
#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #define TVM_RELAY_IR_TYPE_FUNCTOR_H_
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
#include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -89,6 +90,113 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -89,6 +90,113 @@ class TypeFunctor<R(const Type& n, Args...)> {
} }
}; };
/*!
* \brief A type visitor for vistiors which make use of internal
* mutable state.
*
* We recursively visit each type contained inside the visitor.
*/
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) {
this->VisitType(type_param, std::forward<Args>(args)...);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs, std::forward<Args>(args)...);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type, std::forward<Args>(args)...);
}
this->VisitType(op->ret_type, std::forward<Args>(args)...);
}
void VisitType_(const TensorTypeNode* op, Args... args) override {}
void VisitType_(const TupleTypeNode* op, Args... args) override {
for (const Type& t : op->fields) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const TypeRelationNode* op, Args... args) override {
for (const Type& t : op->args) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
};
// A functional visitor for rebuilding an AST in place.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override {
// TODO(@jroesch): maybe we should recursively visit
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
Type VisitType_(const TupleTypeNode* op) override {
std::vector<Type> new_fields;
for (const Type& t : op->fields) {
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const TypeRelationNode* type_rel) override {
std::vector<Type> new_args;
for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t));
}
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
* contains a data type such as `int`, `float`, `uint`. * contains a data type such as `int`, `float`, `uint`.
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include "./type_visitor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -105,13 +105,13 @@ bool KindCheck(const Type& t, const Environment& env) { ...@@ -105,13 +105,13 @@ bool KindCheck(const Type& t, const Environment& env) {
} }
TVM_REGISTER_API("relay._ir_pass.check_kind") TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { if (args.size() == 1) {
*ret = KindCheck(args[0], EnvironmentNode::make({})); *ret = KindCheck(args[0], EnvironmentNode::make({}));
} else { } else {
*ret = KindCheck(args[0], args[1]); *ret = KindCheck(args[0], args[1]);
} }
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* \brief Function for substituting a concrete type in place of a type ID * \brief Function for substituting a concrete type in place of a type ID
*/ */
#include "./type_subst.h" #include "./type_subst.h"
#include "./type_visitor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
/*!
* Copyright (c) 2018 by Contributors
* \file type_visitor.h
* \brief A wrapper around TypeFunctor for common use cases.
*/
#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_
#define TVM_RELAY_PASS_TYPE_VISITOR_H_
#include <vector>
#include "./type_functor.h"
namespace tvm {
namespace relay {
/*! \brief A type visitor for vistiors which make use of internal
* mutable state.
*
* We recursively visit each type contained inside the visitor.
*/
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
void VisitType_(const TypeVarNode* op, Args... args) override {}
void VisitType_(const FuncTypeNode* op, Args... args) override {
for (auto type_param : op->type_params) {
this->VisitType(type_param, std::forward<Args>(args)...);
}
for (auto type_cs : op->type_constraints) {
this->VisitType(type_cs, std::forward<Args>(args)...);
}
for (auto arg_type : op->arg_types) {
this->VisitType(arg_type, std::forward<Args>(args)...);
}
this->VisitType(op->ret_type, std::forward<Args>(args)...);
}
void VisitType_(const TensorTypeNode* op, Args... args) override {}
void VisitType_(const TupleTypeNode* op, Args... args) override {
for (const Type& t : op->fields) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const TypeRelationNode* op, Args... args) override {
for (const Type& t : op->args) {
this->VisitType(t, std::forward<Args>(args)...);
}
}
void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
};
// A functional visitor for rebuilding an AST in place.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TensorTypeNode* op) override {
// TODO(@jroesch): maybe we should recursively visit
return TensorTypeNode::make(op->shape, op->dtype);
}
Type VisitType_(const TypeVarNode* op) override {
return GetRef<TypeVar>(op);
}
Type VisitType_(const FuncTypeNode* op) override {
Array<TypeVar> type_params;
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
} else {
CHECK(false) << new_type_param << std::endl;
}
}
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
}
}
std::vector<Type> args;
for (auto arg_type : op->arg_types) {
args.push_back(VisitType(arg_type));
}
return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
type_params, type_constraints);
}
Type VisitType_(const TupleTypeNode* op) override {
std::vector<Type> new_fields;
for (const Type& t : op->fields) {
new_fields.push_back(this->VisitType(t));
}
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const TypeRelationNode* type_rel) override {
std::vector<Type> new_args;
for (const Type& t : type_rel->args) {
new_args.push_back(this->VisitType(t));
}
return TypeRelationNode::make(type_rel->func,
new_args,
type_rel->num_inputs,
type_rel->attrs);
}
Type VisitType_(const IncompleteTypeNode* op) override {
return GetRef<Type>(op);
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include "./type_visitor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -139,7 +139,8 @@ def test_type_relation_alpha_equal(): ...@@ -139,7 +139,8 @@ def test_type_relation_alpha_equal():
# attrs are also compared only by pointer equality # attrs are also compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
...@@ -147,6 +148,7 @@ def test_type_relation_alpha_equal(): ...@@ -147,6 +148,7 @@ def test_type_relation_alpha_equal():
diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1) diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1) diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2) diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
same_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1_same)
bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1) bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2) diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2)
...@@ -157,6 +159,7 @@ def test_type_relation_alpha_equal(): ...@@ -157,6 +159,7 @@ def test_type_relation_alpha_equal():
assert tr != diff_order assert tr != diff_order
assert tr != diff_args assert tr != diff_args
assert tr != diff_attr assert tr != diff_attr
assert tr == same_attr
assert tr != bigger assert tr != bigger
assert bigger != diff_num_inputs assert bigger != diff_num_inputs
...@@ -216,22 +219,26 @@ def test_global_var_alpha_equal(): ...@@ -216,22 +219,26 @@ def test_global_var_alpha_equal():
def test_tuple_alpha_equal(): def test_tuple_alpha_equal():
v0 = relay.Var("v0")
v1 = relay.Var("v1") v1 = relay.Var("v1")
v2 = relay.Var("v2") v2 = relay.Var("v2")
# unit value is a valid tuple # unit value is a valid tuple
assert alpha_equal(relay.Tuple([]), relay.Tuple([])) assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
tup = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
same = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
assert alpha_equal(tup, same) assert alpha_equal(tup, same)
# use the eq_map # use the eq_map
let_tup = relay.Let(v1, tup, v1) let_tup = relay.Let(v1, tup, v1)
let_mapped = relay.Let(v2, relay.Tuple([v2, relay.const(2), relay.const(3), let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)])]), relay.Tuple([relay.const(4)])]),
v2) v2)
assert alpha_equal(let_tup, let_mapped) assert alpha_equal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2]) more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
...@@ -340,7 +347,8 @@ def test_call_alpha_equal(): ...@@ -340,7 +347,8 @@ def test_call_alpha_equal():
# attrs are compared only by pointer equality # attrs are compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
tt1 = relay.TensorType((1, 2, 3), "float32") tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((), "int8") tt2 = relay.TensorType((), "int8")
...@@ -375,6 +383,9 @@ def test_call_alpha_equal(): ...@@ -375,6 +383,9 @@ def test_call_alpha_equal():
different_attrs = relay.Call(v1, basic_args, attr2, [tt1]) different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
assert not alpha_equal(call, different_attrs) assert not alpha_equal(call, different_attrs)
same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
assert alpha_equal(call, same_attrs)
no_type_args = relay.Call(v1, basic_args, attr1) no_type_args = relay.Call(v1, basic_args, attr1)
assert not alpha_equal(call, no_type_args) assert not alpha_equal(call, no_type_args)
...@@ -445,6 +456,27 @@ def test_op_alpha_equal(): ...@@ -445,6 +456,27 @@ def test_op_alpha_equal():
assert not alpha_equal(op1, op3) assert not alpha_equal(op1, op3)
def test_graph_equal():
x = relay.var("x")
y0 = relay.add(x, x)
z0 = relay.add(y0, y0)
y1 = relay.add(x, x)
z1 = relay.add(y1, y1)
z3 = relay.add(relay.add(x, x), relay.add(x, x))
assert alpha_equal(z0, z1)
# z3's dataflow format is different from z0
# z0 is computed from a common y0 node
# Relay view them as different programs
# Check the difference in the text format.
assert not alpha_equal(z0, z3)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor_type_alpha_equal() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal() test_incomplete_type_alpha_equal()
...@@ -462,3 +494,4 @@ if __name__ == "__main__": ...@@ -462,3 +494,4 @@ if __name__ == "__main__":
test_if_alpha_equal() test_if_alpha_equal()
test_op_alpha_equal() test_op_alpha_equal()
test_var_alpha_equal() test_var_alpha_equal()
test_graph_equal()
...@@ -17,6 +17,12 @@ def test_attrs_equal(): ...@@ -17,6 +17,12 @@ def test_attrs_equal():
assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]}) assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]}) assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
n = tvm.var("n")
assert tvm.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
def test_attrs_hash(): def test_attrs_hash():
fhash = tvm.ir_pass.AttrsHash fhash = tvm.ir_pass.AttrsHash
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment