Unverified Commit b7a8af8d by Tianqi Chen Committed by GitHub

[LANG][ATTRS] Enable deep equality comparison and hash of Attrs (#1903)

parent b64f3f1c
......@@ -27,8 +27,10 @@
#ifndef TVM_ATTRS_H_
#define TVM_ATTRS_H_
#include <dmlc/common.h>
#include <unordered_map>
#include <vector>
#include <functional>
#include <type_traits>
#include <string>
#include "ir.h"
......@@ -129,8 +131,8 @@ class BaseAttrsNode : public Node {
*/
inline void PrintDocString(std::ostream &os) const; // NOLINT(*)
/*!
* \brief Get the field information about the
* \note This function throws when the required a field is not present.
* \brief Get the field information
* \return The fields in the Attrs.
*/
TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
/*!
......@@ -138,9 +140,20 @@ class BaseAttrsNode : public Node {
* \param kwargs The key value pairs for initialization.
* [key0, value0, key1, value1, ..., key_n, value_n]
* \param allow_unknown Whether allow additional unknown fields.
* \note This function throws when the required a field is not present.
* \note This function throws when the required field is not present.
*/
TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(const Node* other) const = 0;
/*!
* \brief Content aware hash.
* \return the hash result.
*/
TVM_DLL virtual size_t ContentHash() const = 0;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
......@@ -188,11 +201,93 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other) const final;
size_t ContentHash() const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
};
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
return lhs == rhs;
}
bool operator()(const NodeRef& lhs, const NodeRef& rhs) const {
return AttrsEqual::Equal(lhs, rhs);
}
// comparator of NodeRef types.
static TVM_DLL bool Equal(const NodeRef& lhs, const NodeRef& rhs);
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
size_t operator()(const NodeRef& value) const {
return AttrsHash::Hash(value);
}
// hash function of the attribute and attribute fields.
static TVM_DLL size_t Hash(const NodeRef& lhs);
};
// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
......@@ -234,6 +329,44 @@ class AttrNormalVisitor {
AttrVisitor* visitor_;
};
// Wrapper for normal visitor.
class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Node* lhs, const Node* rhs)
: lhs_(lhs), rhs_(rhs) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!AttrsEqual()(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}
private:
const Node* lhs_;
const Node* rhs_;
};
class AttrsHashVisitor {
public:
size_t result_{0};
template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, AttrsHash()(*value));
return AttrNopEntry();
}
};
// helper entry that does initialization, set default.
template<typename T>
struct AttrInitEntry {
......@@ -596,6 +729,23 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}
bool ContentEqual(const Node* other) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false;
detail::AttrsEqualVisitor visitor(pself, other);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
size_t ContentHash() const final {
detail::AttrsHashVisitor visitor;
visitor.result_ = std::hash<std::string>()(this->type_key());
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
private:
DerivedType* self() const {
return const_cast<DerivedType*>(
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
......@@ -65,6 +66,14 @@ TVM_REGISTER_API("ir_pass.Equal")
}
});
TVM_REGISTER_API("ir_pass.AttrsEqual")
.set_body_typed<bool(const NodeRef&, const NodeRef&)>(AttrsEqual::Equal);
TVM_REGISTER_API("ir_pass.AttrsHash")
.set_body_typed<int64_t(const NodeRef&)>(AttrsHash::Hash);
TVM_REGISTER_API("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
......
/*!
* Copyright (c) 2018 by Contributors
* \file attr_functor.h
* \brief A way to define arbitrary function signature
* with dispatch on common attributes.
*
* Common attributes include:
* - int, float, str constants
* - array of attributes
* - map of attributes
*/
#ifndef TVM_LANG_ATTR_FUNCTOR_H_
#define TVM_LANG_ATTR_FUNCTOR_H_
namespace tvm {
template <typename FType>
class AttrFunctor;
#define ATTR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->Visit_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \
// A functor for common attribute information.
template <typename R, typename... Args>
class AttrFunctor<R(const NodeRef& n, Args...)> {
private:
using TSelf = AttrFunctor<R(const NodeRef& n, Args...)>;
using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R Visit(const NodeRef& n, Args... args) {
static FType vtable = InitVTable();
if (vtable.can_dispatch(n)) {
return vtable(n, this, std::forward<Args>(args)...);
} else {
return VisitDefault_(n, std::forward<Args>(args)...);
}
}
virtual R Visit_(const ArrayNode* op, Args... args) = 0;
virtual R Visit_(const StrMapNode* op, Args... args) = 0;
virtual R Visit_(const ir::IntImm* op, Args... args) = 0;
virtual R Visit_(const ir::UIntImm* op, Args... args) = 0;
virtual R Visit_(const ir::FloatImm* op, Args... args) = 0;
virtual R Visit_(const ir::StringImm* op, Args... args) = 0;
virtual R VisitDefault_(const NodeRef& n, Args... args) = 0;
private:
// initialize the vtable.
static FType InitVTable() {
using namespace ir;
FType vtable;
// Set dispatch
ATTR_FUNCTOR_DISPATCH(StrMapNode);
ATTR_FUNCTOR_DISPATCH(ArrayNode);
ATTR_FUNCTOR_DISPATCH(IntImm);
ATTR_FUNCTOR_DISPATCH(UIntImm);
ATTR_FUNCTOR_DISPATCH(FloatImm);
ATTR_FUNCTOR_DISPATCH(StringImm);
return vtable;
}
};
} // namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_
......@@ -3,6 +3,7 @@
* \file attrs.cc
*/
#include <tvm/attrs.h>
#include "attr_functor.h"
namespace tvm {
......@@ -44,4 +45,158 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode);
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
using namespace ir;
class AttrsEqualChecker :
public AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
public:
bool Check(const NodeRef& lhs, const NodeRef& rhs) {
if (!equal_) return false;
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!this->Visit(lhs, rhs)) {
equal_ = false;
}
return equal_;
}
bool VisitDefault_(const NodeRef& lhs, const NodeRef& other) final {
if (lhs->derived_from<BaseAttrsNode>()) {
return static_cast<const BaseAttrsNode*>(lhs.get())->ContentEqual(other.get());
}
return lhs.same_as(other);
}
bool Visit_(const IntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<IntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const UIntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<UIntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const FloatImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<FloatImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const StringImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StringImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const ArrayNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Check(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
}
}
return true;
}
bool Visit_(const StrMapNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!Check(NodeRef(kv.second), NodeRef(it->second))) return false;
}
}
return true;
}
private:
bool equal_{true};
};
class AttrContentHasher :
public AttrFunctor<void(const NodeRef&)> {
public:
size_t result_{0};
void VisitDefault_(const NodeRef& value) final {
if (value->derived_from<BaseAttrsNode>()) {
Update(static_cast<const BaseAttrsNode*>(value.get())->ContentHash());
} else {
Update(NodeHash()(value));
}
}
void Visit_(const IntImm* op) final {
Update(std::hash<int64_t>()(op->value));
}
void Visit_(const UIntImm* op) final {
Update(std::hash<uint64_t>()(op->value));
}
void Visit_(const FloatImm* op) final {
Update(std::hash<double>()(op->value));
}
void Visit_(const StringImm* op) final {
Update(std::hash<std::string>()(op->value));
}
void Visit_(const ArrayNode* op) final {
Update(op->data.size());
for (size_t i = 0; i < op->data.size(); ++i) {
this->Visit(NodeRef(op->data[i]));
}
}
void Visit_(const StrMapNode* lhs) final {
using Entry = std::pair<std::string, NodePtr<Node> >;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
return a.first < b.first;
});
for (const Entry& kv : data) {
Update(std::hash<std::string>()(kv.first));
this->Visit(NodeRef(kv.second));
}
}
void Update(size_t value) {
result_ = dmlc::HashCombine(result_, value);
}
};
bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) {
if (lhs.same_as(rhs)) return true;
AttrsEqualChecker checker;
return checker.Check(lhs, rhs);
}
size_t AttrsHash::Hash(const NodeRef& node) {
if (!node.defined()) return 0;
AttrContentHasher hasher;
hasher.Visit(node);
return hasher.result_;
}
size_t DictAttrsNode::ContentHash() const {
return AttrsHash()(this->dict);
}
bool DictAttrsNode::ContentEqual(const Node* other) const {
if (this == other) return true;
if (other == nullptr) return false;
if (this->type_index() != other->type_index()) return false;
return AttrsEqual()(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}
} // namespace tvm
......@@ -124,38 +124,6 @@ def test_binary_broadcast():
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
def test_multibox_prior():
sizes = (0.3, 1.5, 0.7)
ratios = (1.3, 2.4)
steps = (2.0, 1.5)
offsets = (0.2, 0.3)
clip = True
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 3, 56, 56
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x.var, sizes, ratios,
steps, offsets, clip))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32")
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 24, 32, 32
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w, 4), "float32")
def test_where():
ib = relay.ir_builder.IRBuilder()
cond = ib.param("cond", relay.TensorType((3, 4), "float32"))
......
......@@ -25,5 +25,41 @@ def test_resize_infer_type():
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((n, c, 100, 200), "int8")
def test_multibox_prior():
sizes = (0.3, 1.5, 0.7)
ratios = (1.3, 2.4)
steps = (2.0, 1.5)
offsets = (0.2, 0.3)
clip = True
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 3, 56, 56
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x, sizes, ratios,
steps, offsets, clip))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32")
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 24, 32, 32
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
ib.ret(relay.vision.multibox_prior(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(1, h * w, 4), "float32")
if __name__ == "__main__":
test_resize_infer_type()
test_multibox_prior()
import tvm
def test_attrs_equal():
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
z = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4,1))
assert tvm.ir_pass.AttrsEqual(x, y)
assert not tvm.ir_pass.AttrsEqual(x, z)
dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert not tvm.ir_pass.AttrsEqual(dattr, x)
dattr2 = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert tvm.ir_pass.AttrsEqual(dattr, dattr2)
assert tvm.ir_pass.AttrsEqual({"x": x}, {"x": y})
# array related checks
assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
def test_attrs_hash():
fhash = tvm.ir_pass.AttrsHash
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
assert fhash({"x": x}) == fhash({"x": y})
assert fhash({"x": x}) != fhash({"x": [y, 1]})
assert fhash({"x": [x, 1]}) == fhash({"x": [y, 1]})
assert fhash({"x": [x, 2]}) == fhash({"x": [y, 2]})
if __name__ == "__main__":
test_attrs_equal()
test_attrs_hash()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment