Unverified Commit 497d01d3 by Tianqi Chen Committed by GitHub

[NODE][IR] Introduce StructuralHash for the Unified IR. (#5160)

* [NODE][IR] Introduce StructuralHash for the Unified IR.

This PR introduces a new way to handle structural hash for the unified IR.

- Each object can now register an optional SEqualHash function, which
  describes how to reduce its structural equality to sequence of hash values.
- Optionally, the object can choose to allow labeling of vars(e.g. function parameters)
  by calling DefHash
- We implemented a non-recursive structural hasher that maintains its own stack
  to traverse te IR.

This PR also improves the hash value property from the previous relay's hash utility.
In particular, the graph node mode hashs a DAG differently from a tree
by attaching an unique occurence index to each graph node.

In all of the test cases so far, structural_hash is consistent with structural_equal.
- if structrual(x, y) then structural_hash(x) == structural_hash(y)
- if structural_hash(x) == structural_hash(y) then highly likely structural_equal(x, y)
  - hash no collison is found in our testcases.

Ideally we should work on automatically generating these functions in the future.

* Fix cases for EnvFunc and Array dims

* fix testcase

* Update src/node/structural_hash.cc

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>

Co-authored-by: 雾雨魔理沙 <lolisa@marisa.moe>
parent 997a14ed
......@@ -71,6 +71,11 @@ class ConstructorNode : public RelayExprNode {
equal(inputs, other->inputs);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce(inputs);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
......@@ -123,6 +128,12 @@ class TypeDataNode : public TypeNode {
equal(constructors, other->constructors);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(header);
hash_reduce.DefHash(type_vars);
hash_reduce(constructors);
}
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
......
......@@ -121,6 +121,7 @@ class AttrFieldInfoNode : public Object {
static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
......@@ -281,6 +282,7 @@ class BaseAttrsNode : public Object {
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
......@@ -309,6 +311,10 @@ class DictAttrsNode : public BaseAttrsNode {
return equal(dict, other->dict);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dict);
}
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
......@@ -452,6 +458,21 @@ class AttrsHashVisitor {
const AttrsHash& hasher_;
};
class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
: hash_reducer_(hash_reducer) {}
template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
hash_reducer_(*value);
return AttrNopEntry();
}
private:
const SHashReducer& hash_reducer_;
};
// helper entry that does initialization, set default.
template<typename T>
struct AttrInitEntry {
......@@ -858,6 +879,11 @@ class AttrsNode : public BaseAttrsNode {
return visitor.result_;
}
void SHashReduce(SHashReducer hash_reducer) const {
::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
self()->__VisitAttrs__(visitor);
}
Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
......
......@@ -52,11 +52,18 @@ class EnvFuncNode : public Object {
}
bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
return this == other;
// name uniquely identifies the env function.
return name == other->name;
}
void SHashReduce(SHashReducer hash_reduce) const {
// Name uniquely identifies the env function.
hash_reduce(name);
}
static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};
......
......@@ -44,6 +44,7 @@ class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
......@@ -205,6 +206,11 @@ class GlobalVarNode : public RelayExprNode {
equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce.FreeVarHashImpl(this);
}
static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
......@@ -240,6 +246,11 @@ class IntImmNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
......@@ -279,6 +290,11 @@ class FloatImmNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
......@@ -373,8 +389,14 @@ class RangeNode : public Object {
return equal(min, other->min) && equal(extent, other->extent);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(min);
hash_reduce(extent);
}
static constexpr const char* _type_key = "Range";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};
......
......@@ -64,6 +64,8 @@ class IRModuleNode : public Object {
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;
/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
......@@ -238,6 +240,7 @@ class IRModuleNode : public Object {
static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private:
......
......@@ -106,6 +106,11 @@ class OpNode : public RelayExprNode {
return this == other;
}
void SHashReduce(SHashReducer hash_reduce) const {
// Name uniquely identifies an Op.
hash_reduce(name);
}
/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
......
......@@ -79,6 +79,11 @@ class TensorTypeNode : public BaseTensorTypeNode {
equal(dtype, other->dtype);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(shape);
hash_reduce(dtype);
}
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
......
......@@ -80,6 +80,7 @@ class TypeNode : public Object {
static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
......@@ -115,6 +116,10 @@ class PrimTypeNode : public TypeNode {
return equal(dtype, other->dtype);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
}
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
......@@ -161,6 +166,10 @@ class PointerTypeNode : public TypeNode {
return equal(element_type, other->element_type);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(element_type);
}
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
......@@ -233,6 +242,11 @@ class TypeVarNode : public TypeNode {
equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(kind);
hash_reduce.FreeVarHashImpl(this);
}
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
......@@ -280,6 +294,11 @@ class GlobalTypeVarNode : public TypeNode {
equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce.FreeVarHashImpl(this);
}
static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
......@@ -320,6 +339,10 @@ class TupleTypeNode : public TypeNode {
return equal(fields, other->fields);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(fields);
}
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
......@@ -421,6 +444,13 @@ class FuncTypeNode : public TypeNode {
equal(type_constraints, other->type_constraints);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(type_params);
hash_reduce(arg_types);
hash_reduce(ret_type);
hash_reduce(type_constraints);
}
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
......@@ -471,6 +501,10 @@ class IncompleteTypeNode : public TypeNode {
return equal(kind, other->kind);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(kind);
}
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
......@@ -512,6 +546,10 @@ class RelayRefTypeNode : public TypeNode {
return equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(value);
}
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
......
......@@ -56,6 +56,11 @@ class TypeCallNode : public TypeNode {
equal(args, other->args);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(args);
}
static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
......@@ -209,6 +214,13 @@ class TypeRelationNode : public TypeConstraintNode {
equal(attrs, other->attrs);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(args);
hash_reduce(num_inputs);
hash_reduce(attrs);
}
static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
......
......@@ -41,6 +41,7 @@
#include <tvm/node/repr_printer.h>
#include <tvm/node/container.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <string>
#include <vector>
......
......@@ -30,6 +30,7 @@
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <vector>
#include <string>
......@@ -89,12 +90,13 @@ class ReflectionVTable {
typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
/*!
* \brief Equality comparison function.
* \note We use function pointer, instead of std::function
* to reduce the dispatch overhead as field visit
* does not need as much customization.
*/
typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
/*!
* \brief Structural hash reduction function.
*/
typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object.
......@@ -128,6 +130,13 @@ class ReflectionVTable {
*/
bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
/*!
* \brief Dispatch the SHashReduce function.
* \param self The pointer to the object.
* \param hash_reduce The hash reducer.
* \return the result.
*/
void SHashReduce(const Object* self, SHashReducer hash_reduce) const;
/*!
* \brief Create an initial object using default constructor
* by type_key and global key.
*
......@@ -162,7 +171,9 @@ class ReflectionVTable {
/*! \brief Attribute visitor. */
std::vector<FVisitAttrs> fvisit_attrs_;
/*! \brief Structural equal function. */
std::vector<FSEqualReduce> fsequal_;
std::vector<FSEqualReduce> fsequal_reduce_;
/*! \brief Structural hash function. */
std::vector<FSHashReduce> fshash_reduce_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
......@@ -280,10 +291,24 @@ struct ImplSEqualReduce<T, true> {
}
};
template<typename T,
bool = T::_type_has_method_shash_reduce>
struct ImplSHashReduce {
static constexpr const std::nullptr_t SHashReduce = nullptr;
};
template<typename T>
struct ImplSHashReduce<T, true> {
static void SHashReduce(const T* self, SHashReducer hash_reduce) {
self->SHashReduce(hash_reduce);
}
};
template<typename T>
struct ReflectionTrait :
public ImplVisitAttrs<T>,
public ImplSEqualReduce<T> {
public ImplSEqualReduce<T>,
public ImplSHashReduce<T> {
};
template<typename T, typename TraitName,
......@@ -315,6 +340,22 @@ struct SelectSEqualReduce<T, TraitName, false> {
equal);
}
};
template<typename T, typename TraitName,
bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
struct SelectSHashReduce {
static constexpr const std::nullptr_t SHashReduce = nullptr;
};
template<typename T, typename TraitName>
struct SelectSHashReduce<T, TraitName, false> {
static void SHashReduce(const Object* self,
SHashReducer hash_reduce) {
return TraitName::SHashReduce(static_cast<const T*>(self),
hash_reduce);
}
};
} // namespace detail
template<typename T, typename TraitName>
......@@ -325,15 +366,19 @@ ReflectionVTable::Register() {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
fsequal_.resize(tindex + 1, nullptr);
fsequal_reduce_.resize(tindex + 1, nullptr);
fshash_reduce_.resize(tindex + 1, nullptr);
}
// functor that implemnts the redirection.
fvisit_attrs_[tindex] =
::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
fsequal_[tindex] =
fsequal_reduce_[tindex] =
::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
fshash_reduce_[tindex] =
::tvm::detail::SelectSHashReduce<T, TraitName>::SHashReduce;
return Registry(this, tindex);
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/structural_equal.h
* \brief Structural hash class.
*/
#ifndef TVM_NODE_STRUCTURAL_HASH_H_
#define TVM_NODE_STRUCTURAL_HASH_H_
#include <tvm/runtime/data_type.h>
#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <string>
#include <functional>
namespace tvm {
/*!
* \brief Hash definition of base value classes.
*/
class BaseValueHash {
public:
size_t operator()(const double& key) const {
return std::hash<double>()(key);
}
size_t operator()(const int64_t& key) const {
return std::hash<int64_t>()(key);
}
size_t operator()(const uint64_t& key) const {
return std::hash<uint64_t>()(key);
}
size_t operator()(const int& key) const {
return std::hash<int>()(key);
}
size_t operator()(const bool& key) const {
return std::hash<bool>()(key);
}
size_t operator()(const std::string& key) const {
return std::hash<std::string>()(key);
}
size_t operator()(const runtime::DataType& key) const {
return std::hash<int32_t>()(
static_cast<int32_t>(key.code()) |
(static_cast<int32_t>(key.bits()) << 8) |
(static_cast<int32_t>(key.lanes()) << 16));
}
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& key) const {
return std::hash<size_t>()(static_cast<size_t>(key));
}
};
/*!
* \brief Content-aware structural hasing.
*
* The structural hash value is recursively defined in the DAG of IRNodes.
* There are two kinds of nodes:
*
* - Normal node: the hash value is defined by its content and type only.
* - Graph node: each graph node will be assigned a unique index ordered by the
* first occurence during the visit. The hash value of a graph node is
* combined from the hash values of its contents and the index.
*/
class StructuralHash : public BaseValueHash {
public:
// inheritate operator()
using BaseValueHash::operator();
/*!
* \brief Compute structural hashing value for an object.
* \param key The left operand.
* \return The hash value.
*/
TVM_DLL size_t operator()(const ObjectRef& key) const;
};
/*!
* \brief A Reducer class to reduce the structural hash value.
*
* The reducer will call the SEqualHash function of each objects recursively.
*
* A SEqualHash function will make a sequence of calls to the reducer to
* indicate a sequence of child hash values that the reducer need to combine
* inorder to obtain the hash value of the hash value of the parent object.
*
* Importantly, the reducer may not directly use recursive calls
* to compute the hash values of child objects directly.
*
* Instead, it can store the necessary hash computing task into a stack
* and reduce the result later.
*/
class SHashReducer {
public:
/*! \brief Internal handler that defines custom behaviors. */
class Handler {
public:
/*!
* \brief Append hashed_value to the current sequence of hashes.
*
* \param hashed_value The hashed value
*/
virtual void SHashReduceHashedValue(size_t hashed_value) = 0;
/*!
* \brief Append hash value of key to the current sequence of hashes.
*
* \param key The object to compute hash from.
* \param map_free_vars Whether to map free variables by their occurence number.
*/
virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0;
/*!
* \brief Apppend a hash value of free variable to the current sequence of hashes.
*
* \param var The var of interest.
* \param map_free_vars Whether to map free variables by their occurence number.
*
* \note If map_free_vars is set to be true,
* internally the handler can maintain a counter to encode free variables
* by their order of occurence. This helps to resolve variable
* mapping of function parameters and let binding variables.
*
* If map_free_vars is set to be false, the address of the variable will be used.
*/
virtual void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) = 0;
/*!
* \brief Lookup a hash value for key
*
* \param key The hash key.
* \param hashed_value the result hash value
*
* \return Whether there is already a pre-computed hash value.
*/
virtual bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) = 0;
/*!
* \brief Mark current comparison as graph node in hashing.
* Graph node hash will depends on the graph structure.
*/
virtual void MarkGraphNode() = 0;
};
/*! \brief default constructor */
SHashReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
* \param map_free_vars Whether to map free variables.
*/
explicit SHashReducer(Handler* handler, bool map_free_vars)
: handler_(handler), map_free_vars_(map_free_vars) {}
/*!
* \brief Push hash of key to the current sequence of hash values.
* \param key The key to be hashed.
*/
template<typename T,
typename = typename std::enable_if<
!std::is_base_of<ObjectRef, T>::value>::type>
void operator()(const T& key) const {
// handle normal values.
handler_->SHashReduceHashedValue(BaseValueHash()(key));
}
/*!
* \brief Push hash of key to the current sequence of hash values.
* \param key The key to be hashed.
*/
void operator()(const ObjectRef& key) const {
return handler_->SHashReduce(key, map_free_vars_);
}
/*!
* \brief Push hash of key to the current sequence of hash values.
* \param key The key to be hashed.
* \note This function indicate key could contain var defintions.
*/
void DefHash(const ObjectRef& key) const {
return handler_->SHashReduce(key, true);
}
/*!
* \brief Implementation for hash for a free var.
* \param var The variable.
* \return the result.
*/
void FreeVarHashImpl(const runtime::Object* var) const {
handler_->SHashReduceFreeVar(var, map_free_vars_);
}
/*! \return Get the internal handler. */
Handler* operator->() const {
return handler_;
}
private:
/*! \brief Internal class pointer. */
Handler* handler_;
/*!
* \brief Whether or not to map free variables by their occurence
* If the flag is false, then free variables will be mapped
* by their in-memory address.
*/
bool map_free_vars_;
};
} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_HASH_H_
......@@ -47,6 +47,7 @@ class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
};
......@@ -79,6 +80,9 @@ class PatternWildcardNode : public PatternNode {
return true;
}
void SHashReduce(SHashReducer hash_reduce) const {
}
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
};
......@@ -127,6 +131,10 @@ class PatternVarNode : public PatternNode {
return equal.DefEqual(var, other->var);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(var);
}
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
};
......@@ -164,6 +172,11 @@ class PatternConstructorNode : public PatternNode {
equal(patterns, other->patterns);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(constructor);
hash_reduce(patterns);
}
static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
};
......@@ -197,6 +210,10 @@ class PatternTupleNode : public PatternNode {
return equal(patterns, other->patterns);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(patterns);
}
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
};
......@@ -231,8 +248,14 @@ class ClauseNode : public Object {
return equal(lhs, other->lhs) && equal(rhs, other->rhs);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(lhs);
hash_reduce(rhs);
}
static constexpr const char* _type_key = "relay.Clause";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};
......@@ -280,6 +303,13 @@ class MatchNode : public ExprNode {
equal(complete, other->complete);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(data);
hash_reduce(clauses);
hash_reduce(complete);
}
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};
......
......@@ -77,6 +77,10 @@ class ConstantNode : public ExprNode {
return equal(data, other->data);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(data);
}
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
......@@ -116,6 +120,13 @@ class TupleNode : public ExprNode {
}
}
void SHashReduce(SHashReducer hash_reduce) const {
if (fields.size() != 0) {
hash_reduce->MarkGraphNode();
hash_reduce(fields);
}
}
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
......@@ -178,6 +189,11 @@ class VarNode : public ExprNode {
equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(type_annotation);
hash_reduce.FreeVarHashImpl(this);
}
TVM_DLL static Var make(std::string name_hint,
Type type_annotation);
......@@ -269,6 +285,16 @@ class CallNode : public ExprNode {
(IsPrimitiveOp(op) || equal(type_args, other->type_args));
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(op);
hash_reduce(args);
hash_reduce(attrs);
if (!IsPrimitiveOp(op)) {
hash_reduce(type_args);
}
}
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
};
......@@ -328,6 +354,13 @@ class LetNode : public ExprNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce.DefHash(var);
hash_reduce(value);
hash_reduce(body);
}
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
......@@ -383,6 +416,13 @@ class IfNode : public ExprNode {
equal(false_branch, other->false_branch);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(cond);
hash_reduce(true_branch);
hash_reduce(false_branch);
}
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
......@@ -422,6 +462,11 @@ class TupleGetItemNode : public ExprNode {
equal(index, other->index);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(tuple);
hash_reduce(index);
}
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
......@@ -456,6 +501,11 @@ class RefCreateNode : public ExprNode {
return equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(value);
}
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
......@@ -489,6 +539,11 @@ class RefReadNode : public ExprNode {
return equal(ref, other->ref);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(ref);
}
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
......@@ -526,6 +581,12 @@ class RefWriteNode : public ExprNode {
equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(ref);
hash_reduce(value);
}
TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite";
......@@ -568,6 +629,7 @@ class TempExprNode : public ExprNode {
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};
......
......@@ -79,6 +79,15 @@ class FunctionNode : public BaseFuncNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce.DefHash(params);
hash_reduce.DefHash(type_params);
hash_reduce(ret_type);
hash_reduce(attrs);
hash_reduce(body);
}
/*!
* \brief Return the derived function annotation of this expression.
*
......
......@@ -492,6 +492,26 @@ class String : public ObjectRef {
*/
operator std::string() const { return std::string{get()->data, size()}; }
/*!
* \brief Hash the binary bytes
* \param data The data pointer
* \param size The size of the bytes.
* \return the hash value.
*/
static size_t HashBytes(const char* data, size_t size) {
// This function falls back to string copy with c++11 compiler and is
// recommended to be compiled with c++14
#if TVM_USE_CXX17_STRING_VIEW_HASH
return std::hash<std::string_view>()(
std::string_view(data, size));
#elif TVM_USE_CXX14_STRING_VIEW_HASH
return std::hash<std::experimental::string_view>()(
std::experimental::string_view(data, size));
#else
return std::hash<std::string>()(std::string(data, size));
#endif
}
TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
private:
......@@ -570,17 +590,7 @@ namespace std {
template <>
struct hash<::tvm::runtime::String> {
std::size_t operator()(const ::tvm::runtime::String& str) const {
// This function falls back to string copy with c++11 compiler and is
// recommended to be compiled with c++14
#if TVM_USE_CXX17_STRING_VIEW_HASH
return std::hash<std::string_view>{}(
std::string_view{str.data(), str.size()});
#elif TVM_USE_CXX14_STRING_VIEW_HASH
return std::hash<std::experimental::string_view>{}(
std::experimental::string_view{str.data(), str.size()});
#else
return std::hash<std::string>()(str.operator std::string());
#endif
return ::tvm::runtime::String::HashBytes(str.data(), str.size());
}
};
} // namespace std
......
......@@ -214,6 +214,7 @@ class Object {
// member information
static constexpr bool _type_has_method_visit_attrs = true;
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
// NOTE: the following field is not type index of Object
// but was intended to be used by sub-classes as default value.
// The type index of Object is TypeIndex::kRoot
......
......@@ -164,6 +164,17 @@ class BufferNode : public Object {
equal(buffer_type, other->buffer_type);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(data);
hash_reduce(dtype);
hash_reduce.DefHash(shape);
hash_reduce.DefHash(strides);
hash_reduce.DefHash(elem_offset);
hash_reduce(scope);
hash_reduce(data_alignment);
hash_reduce(buffer_type);
}
/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
......@@ -184,6 +195,7 @@ class BufferNode : public Object {
static constexpr const char* _type_key = "Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
};
......
......@@ -81,6 +81,12 @@ class VarNode : public PrimExprNode {
return equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(type_annotation);
hash_reduce.FreeVarHashImpl(this);
}
static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};
......@@ -302,12 +308,20 @@ class IterVarNode : public Object {
equal(thread_tag, other->thread_tag);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dom);
hash_reduce.DefHash(var);
hash_reduce(iter_type);
hash_reduce(thread_tag);
}
TVM_DLL static IterVar make(Range dom, Var var,
IterVarType iter_type,
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
};
......@@ -353,6 +367,10 @@ class StringImmNode : public PrimExprNode {
return equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(value);
}
TVM_DLL PrimExpr static make(std::string value);
static constexpr const char* _type_key = "StringImm";
......@@ -382,6 +400,11 @@ class CastNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}
TVM_DLL static PrimExpr make(DataType t, PrimExpr v);
static constexpr const char* _type_key = "Cast";
......@@ -413,6 +436,12 @@ class BinaryOpNode : public PrimExprNode {
equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(a);
hash_reduce(b);
}
static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
......@@ -512,6 +541,12 @@ class CmpOpNode : public PrimExprNode {
equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(a);
hash_reduce(b);
}
static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
......@@ -583,6 +618,12 @@ class AndNode : public PrimExprNode {
equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(a);
hash_reduce(b);
}
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "And";
......@@ -610,6 +651,12 @@ class OrNode : public PrimExprNode {
equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(a);
hash_reduce(b);
}
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "Or";
......@@ -631,6 +678,11 @@ class NotNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(a, other->a);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(a);
}
TVM_DLL static PrimExpr make(PrimExpr a);
static constexpr const char* _type_key = "Not";
......@@ -668,6 +720,13 @@ class SelectNode : public PrimExprNode {
equal(false_value, other->false_value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(condition);
hash_reduce(true_value);
hash_reduce(false_value);
}
TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
static constexpr const char* _type_key = "Select";
......@@ -713,6 +772,13 @@ class LoadNode : public PrimExprNode {
equal(predicate, other->predicate);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(buffer_var);
hash_reduce(index);
hash_reduce(predicate);
}
TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
static constexpr const char* _type_key = "Load";
......@@ -752,6 +818,13 @@ class RampNode : public PrimExprNode {
equal(lanes, other->lanes);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(base);
hash_reduce(stride);
hash_reduce(lanes);
}
TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes);
static constexpr const char* _type_key = "Ramp";
......@@ -779,6 +852,12 @@ class BroadcastNode : public PrimExprNode {
equal(lanes, other->lanes);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
hash_reduce(lanes);
}
TVM_DLL static PrimExpr make(PrimExpr value, int lanes);
static constexpr const char* _type_key = "Broadcast";
......@@ -812,6 +891,13 @@ class LetNode : public PrimExprNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce.DefHash(var);
hash_reduce(value);
hash_reduce(body);
}
TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body);
static constexpr const char* _type_key = "Let";
......@@ -892,6 +978,15 @@ class CallNode : public PrimExprNode {
equal(value_index, other->value_index);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(name);
hash_reduce(args);
hash_reduce(call_type);
hash_reduce(func);
hash_reduce(value_index);
}
TVM_DLL static PrimExpr make(DataType dtype,
std::string name,
Array<PrimExpr> args,
......@@ -967,6 +1062,12 @@ class ShuffleNode : public PrimExprNode {
equal(indices, other->indices);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(vectors);
hash_reduce(indices);
}
TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr> indices);
TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors);
TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index);
......@@ -1037,8 +1138,16 @@ class CommReducerNode : public Object {
equal(identity_element, other->identity_element);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(lhs);
hash_reduce.DefHash(rhs);
hash_reduce(result);
hash_reduce(identity_element);
}
static constexpr const char* _type_key = "CommReducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
};
......@@ -1092,6 +1201,16 @@ class ReduceNode : public PrimExprNode {
equal(condition, other->condition) &&
equal(value_index, other->value_index);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(axis);
hash_reduce(combiner);
hash_reduce(source);
hash_reduce(condition);
hash_reduce(value_index);
}
static constexpr const char* _type_key = "Reduce";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
};
......@@ -1105,6 +1224,9 @@ class AnyNode : public PrimExprNode {
return true;
}
void SHashReduce(SHashReducer hash_reduce) const {
}
/*! \brief Convert to var. */
Var ToVar() const {
return Var("any_dim", DataType::Int(32));
......
......@@ -112,6 +112,13 @@ class PrimFuncNode : public BaseFuncNode {
equal(attrs, other->attrs);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(params);
hash_reduce(buffer_map);
hash_reduce(ret_type);
hash_reduce(body);
hash_reduce(attrs);
}
/*!
* \brief Return the derived function annotation of this function.
*
......@@ -122,7 +129,6 @@ class PrimFuncNode : public BaseFuncNode {
TVM_DLL FuncType func_type_annotation() const;
static constexpr const char* _type_key = "tir.PrimFunc";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};
......
......@@ -39,6 +39,7 @@ class StmtNode : public Object {
public:
static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};
......@@ -73,6 +74,12 @@ class LetStmtNode : public StmtNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(var);
hash_reduce(value);
hash_reduce(body);
}
TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
static constexpr const char* _type_key = "LetStmt";
......@@ -115,6 +122,13 @@ class AttrStmtNode : public StmtNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(node);
hash_reduce(attr_key);
hash_reduce(value);
hash_reduce(body);
}
TVM_DLL static Stmt make(ObjectRef node,
std::string type_key,
PrimExpr value,
......@@ -152,6 +166,12 @@ class AssertStmtNode : public StmtNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(condition);
hash_reduce(message);
hash_reduce(body);
}
TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt";
......@@ -182,6 +202,12 @@ class ProducerConsumerNode : public StmtNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(is_producer);
hash_reduce(body);
}
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer";
......@@ -232,6 +258,13 @@ class StoreNode : public StmtNode {
equal(predicate, other->predicate);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer_var);
hash_reduce(value);
hash_reduce(index);
hash_reduce(predicate);
}
TVM_DLL static Stmt make(Var buffer_var,
PrimExpr value,
PrimExpr index,
......@@ -270,6 +303,13 @@ class ProvideNode : public StmtNode {
equal(args, other->args);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(value_index);
hash_reduce(value);
hash_reduce(args);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
PrimExpr value,
......@@ -316,6 +356,14 @@ class AllocateNode : public StmtNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(buffer_var);
hash_reduce(dtype);
hash_reduce(extents);
hash_reduce(condition);
hash_reduce(body);
}
TVM_DLL static Stmt make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
......@@ -360,6 +408,10 @@ class FreeNode : public StmtNode {
equal(buffer_var, other->buffer_var);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer_var);
}
TVM_DLL static Stmt make(Var buffer_var);
static constexpr const char* _type_key = "Free";
......@@ -411,6 +463,15 @@ class RealizeNode : public StmtNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(value_index);
hash_reduce(dtype);
hash_reduce(bounds);
hash_reduce(condition);
hash_reduce(body);
}
static constexpr const char* _type_key = "Realize";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
};
......@@ -443,6 +504,10 @@ class SeqStmtNode : public StmtNode {
return equal(seq, other->seq);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(seq);
}
static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};
......@@ -553,6 +618,12 @@ class IfThenElseNode : public StmtNode {
equal(else_case, other->else_case);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(condition);
hash_reduce(then_case);
hash_reduce(else_case);
}
TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
static constexpr const char* _type_key = "IfThenElse";
......@@ -578,6 +649,10 @@ class EvaluateNode : public StmtNode {
return equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(value);
}
TVM_DLL static Stmt make(PrimExpr v);
static constexpr const char* _type_key = "Evaluate";
......@@ -657,6 +732,16 @@ class ForNode : public StmtNode {
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(loop_var);
hash_reduce(min);
hash_reduce(extent);
hash_reduce(for_type);
hash_reduce(device_api);
hash_reduce(body);
}
static constexpr const char* _type_key = "For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
......@@ -690,6 +775,13 @@ class PrefetchNode : public StmtNode {
equal(bounds, other->bounds);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(value_index);
hash_reduce(dtype);
hash_reduce(bounds);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType dtype,
......
......@@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .base import structural_equal, assert_structural_equal
from .base import structural_equal, assert_structural_equal, structural_hash
from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
......
......@@ -192,9 +192,14 @@ def structural_equal(lhs, rhs, map_free_vars=False):
------
result : bool
The comparison result.
See Also
--------
structural_hash
assert_strucural_equal
"""
return tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars)
return bool(tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars))
def assert_structural_equal(lhs, rhs, map_free_vars=False):
......@@ -222,3 +227,45 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, True, map_free_vars)
def structural_hash(node, map_free_vars=False):
"""Compute structural hash of node
The structural hash value is recursively defined in the DAG of IRNodes.
There are two kinds of nodes:
- Normal node: the hash value is defined by its content and type only.
- Graph node: each graph node will be assigned a unique index ordered by the
first occurence during the visit. The hash value of a graph node is
combined from the hash values of its contents and the index.
structural_hash is made to be concistent with structural_equal.
If two nodes are structurally equal to each other,
then their structural hash (with the same map_free_vars option)
should be equal to each other as well.
If the structural hash of two nodes equals to each other,
then it is highly likely(except for rare hash value collison cases)
that the two nodes are structurally equal to each other.
Parameters
----------
node : Object
The input to be hashed.
map_free_vars : bool
If map_free_vars is set to true, we will hash free variables
by the order of their occurences. Otherwise, we will hash by
their in-memory pointer address.
Return
------
result : int
The hash result
See Also
--------
structrual_equal
"""
return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars)
......@@ -65,7 +65,6 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
data_ = std::move(n);
}
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
for (const auto& kv : this->functions) {
......@@ -80,6 +79,37 @@ bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal)
return true;
}
void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
using KV = std::pair<std::string, ObjectRef>;
// hash the functions.
std::vector<KV> temp;
auto reduce_temp = [&]() {
// sort by the hash key of the keys.
std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) {
return lhs.first < rhs.first;
});
hash_reduce(static_cast<uint64_t>(temp.size()));
// hash the content
for (size_t i = 0; i < temp.size(); ++i) {
hash_reduce(temp[i].first);
hash_reduce(temp[i].second);
}
};
for (const auto& kv : this->functions) {
temp.emplace_back(kv.first->name_hint, kv.second);
}
reduce_temp();
temp.clear();
for (const auto& kv : this->type_definitions) {
temp.emplace_back(kv.first->name_hint, kv.second);
}
reduce_temp();
}
bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
......
......@@ -32,6 +32,12 @@ namespace tvm {
struct StringObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const runtime::StringObj* key,
SHashReducer hash_reduce) {
hash_reduce->SHashReduceHashedValue(
runtime::String::HashBytes(key->data, key->size));
}
static bool SEqualReduce(const runtime::StringObj* lhs,
const runtime::StringObj* rhs,
SEqualReducer equal) {
......@@ -47,6 +53,15 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const runtime::ADTObj* key,
SHashReducer hash_reduce) {
hash_reduce(key->tag);
hash_reduce(static_cast<uint64_t>(key->size));
for (uint32_t i = 0; i < key->size; ++i) {
hash_reduce((*key)[i]);
}
}
static bool SEqualReduce(const runtime::ADTObj* lhs,
const runtime::ADTObj* rhs,
SEqualReducer equal) {
......@@ -67,6 +82,22 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
struct NDArrayContainerTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const runtime::NDArray::Container* key,
SHashReducer hash_reduce) {
CHECK_EQ(key->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK(runtime::IsContiguous(key->dl_tensor))
<< "Can only hash contiguous tensor";
hash_reduce(runtime::DataType(key->dl_tensor.dtype));
hash_reduce(key->dl_tensor.ndim);
for (int i = 0; i < key->dl_tensor.ndim; ++i) {
hash_reduce(key->dl_tensor.shape[i]);
}
hash_reduce->SHashReduceHashedValue(
runtime::String::HashBytes(
static_cast<const char*>(key->dl_tensor.data),
runtime::GetDataSize(key->dl_tensor)));
}
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs,
SEqualReducer equal) {
......@@ -80,6 +111,11 @@ struct NDArrayContainerTrait {
<< "Can only compare contiguous tensor";
CHECK(runtime::IsContiguous(rhs->dl_tensor))
<< "Can only compare contiguous tensor";
if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false;
}
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
......@@ -95,6 +131,14 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrai
struct ArrayNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const ArrayNode* key,
SHashReducer hash_reduce) {
hash_reduce(static_cast<uint64_t>(key->data.size()));
for (size_t i = 0; i < key->data.size(); ++i) {
hash_reduce(key->data[i]);
}
}
static bool SEqualReduce(const ArrayNode* lhs,
const ArrayNode* rhs,
SEqualReducer equal) {
......@@ -153,6 +197,40 @@ TVM_REGISTER_GLOBAL("node.ArraySize")
struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const MapNode* key,
SHashReducer hash_reduce) {
// SHash's var handling depends on the determinism of traversal.
// NOTE: only book-keep the mapped hash keys.
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
using KV = std::pair<size_t, ObjectRef>;
std::vector<KV> temp;
for (const auto& kv : key->data) {
size_t hashed_value;
if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) {
temp.emplace_back(hashed_value, kv.second);
}
}
// sort by the hash key of the keys.
std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) {
return lhs.first < rhs.first;
});
// add size to the hash
hash_reduce(static_cast<uint64_t>(key->data.size()));
// hash the content
for (size_t i = 0; i < temp.size();) {
size_t k = i + 1;
for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {}
// ties are rare, but we need to skip them to make the hash determinsitic
if (k == i + 1) {
hash_reduce->SHashReduceHashedValue(temp[i].first);
hash_reduce(temp[i].second);
}
i = k;
}
}
static bool SEqualReduce(const MapNode* lhs,
const MapNode* rhs,
SEqualReducer equal) {
......@@ -182,6 +260,28 @@ TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
struct StrMapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const StrMapNode* key,
SHashReducer hash_reduce) {
// NOTE: only book-keep the mapped hash keys.
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
using KV = std::pair<std::string, ObjectRef>;
std::vector<KV> temp(key->data.begin(), key->data.end());
// sort by the hash key of the keys.
std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) {
return lhs.first < rhs.first;
});
// NOTE: we won't have ties
// add size to the hash after sorting.
hash_reduce(static_cast<uint64_t>(key->data.size()));
// hash the content
for (size_t i = 0; i < temp.size(); ++i) {
hash_reduce(temp[i].first);
hash_reduce(temp[i].second);
}
}
static bool SEqualReduce(const StrMapNode* lhs,
const StrMapNode* rhs,
SEqualReducer equal) {
......
......@@ -33,11 +33,11 @@ namespace tvm {
bool ReflectionVTable::
SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const {
uint32_t tindex = self->type_index();
if (tindex >= fsequal_.size() || fsequal_[tindex] == nullptr) {
if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fsequal_[tindex](self, other, equal);
return fsequal_reduce_[tindex](self, other, equal);
}
/*!
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/node/structural_hash.cc
*/
#include <tvm/node/structural_hash.h>
#include <tvm/node/reflection.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
#include <algorithm>
namespace tvm {
// Define the dispatch functio here since primary user is in this file.
void ReflectionVTable::
SHashReduce(const Object* self, SHashReducer reducer) const {
uint32_t tindex = self->type_index();
if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey()
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
fshash_reduce_[tindex](self, reducer);
}
// Hash handler that handles free vars
// by assigning an unique counter in the order of their ocurrence.
//
// This algorithm depends on the determinism of the traversal of SHash function.
// In particular, when we traverse unordered_map, we should first sort
// the entries by keys(or hash of keys) before traversing.
class VarCountingSHashHandler :
public SHashReducer::Handler {
public:
/*! \brief Pending reduce tasks. */
struct Task {
/*!
* \brief The object operand to be hashed.
* If the object is nullptr, then the reduced hash is already set
* the correct value.
*/
ObjectRef object;
/*! \biref The partially reduce hash value.*/
size_t reduced_hash;
/*! \brief The expected location in the result stack. */
size_t result_stack_index = std::numeric_limits<size_t>::max();
/*! \brief Whether the children has been expanded via SEqualReduce */
bool children_expanded{false};
/*! \brief Whether the node is graph node. */
bool graph_node_hash{false};
/*! \brief whether to map the free variables. */
bool map_free_vars;
Task() = default;
explicit Task(ObjectRef object, size_t reduced_hash, bool map_free_vars)
: object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {}
};
VarCountingSHashHandler() {}
void MarkGraphNode() final {
// need to push to pending tasks in this case
CHECK(!allow_push_to_stack_ && !task_stack_.empty());
task_stack_.back().graph_node_hash = true;
}
bool LookupHashedValue(const ObjectRef& key, size_t* hash_value) final {
auto it = hash_memo_.find(key);
if (it != hash_memo_.end()) {
hash_value[0] = it->second;
return true;
}
return false;
}
void SHashReduceHashedValue(size_t hashed_value) final {
pending_tasks_.emplace_back(
Task(ObjectRef(nullptr), hashed_value, false));
}
void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) final {
CHECK(!hash_memo_.count(GetRef<ObjectRef>(var)));
if (map_free_vars) {
// use counter value.
size_t value = std::hash<size_t>()(free_var_counter_++);
pending_tasks_.emplace_back(
Task(ObjectRef(nullptr), value, false));
} else {
// use pointer hash
size_t value = std::hash<const runtime::Object*>()(var);
pending_tasks_.emplace_back(
Task(ObjectRef(nullptr), value, false));
}
}
void SHashReduce(const ObjectRef& object, bool map_free_vars) final {
// Directly push the result
// Note: it is still important to push the result to pendng tasks
// so that the reduction order of hash values stays the same.
if (!object.defined()) {
pending_tasks_.emplace_back(Task(ObjectRef(nullptr), 0, false));
return;
}
auto it = hash_memo_.find(object);
if (it != hash_memo_.end()) {
pending_tasks_.emplace_back(
Task(ObjectRef(nullptr), it->second, false));
} else {
// Push a pending task with initial value.
pending_tasks_.emplace_back(
Task(object, object->GetTypeKeyHash(), map_free_vars));
}
}
size_t Hash(const ObjectRef& object, bool map_free_vars) {
CHECK_EQ(task_stack_.size(), 0U);
CHECK_EQ(pending_tasks_.size(), 0U);
CHECK_EQ(result_stack_.size(), 0U);
this->SHashReduce(object, map_free_vars);
CHECK_EQ(pending_tasks_.size(), 1U);
CHECK(allow_push_to_stack_);
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.clear();
this->RunTasks();
CHECK_EQ(result_stack_.size(), 1U);
size_t ret = result_stack_.back();
result_stack_.pop_back();
return ret;
}
protected:
/*!
* \brief Pop the top entry of the task stack and push the hash into the result stack.
*/
void PopTaskStack() {
const auto& entry = task_stack_.back();
result_stack_.push_back(entry.reduced_hash);
task_stack_.pop_back();
}
/*!
* \brief Compute the reduced hash value for the task.
* \param task The indicated task.
*/
size_t ReduceHash(const Task& task) {
size_t stack_begin = task.result_stack_index;
CHECK_LE(stack_begin, result_stack_.size());
// combine in the reverse order of the stack.
size_t reduced_hash = task.reduced_hash;
for (size_t i = result_stack_.size(); i != stack_begin; --i) {
reduced_hash = HashCombine(reduced_hash, result_stack_[i - 1]);
}
result_stack_.resize(stack_begin);
return reduced_hash;
}
// run the tasks.
void RunTasks() {
while (task_stack_.size() != 0) {
// Caution: entry becomes invalid when the stack changes
auto& entry = task_stack_.back();
if (entry.children_expanded) {
// reduce hash
entry.reduced_hash = ReduceHash(entry);
// When all the children has expanded and visited.
// entry.reduced_hash contains the reduced hash result.
auto it = hash_memo_.find(entry.object);
if (it != hash_memo_.end()) {
// use the pre-computed hash for the object.
entry.reduced_hash = it->second;
} else {
// Append the graph node counter to the hash
// so that we can distinguish DAG from trees.
if (entry.graph_node_hash) {
entry.reduced_hash = HashCombine(
entry.reduced_hash,
std::hash<size_t>()(graph_node_counter_++));
}
hash_memo_[entry.object] = entry.reduced_hash;
}
// send value to parent.
this->PopTaskStack();
} else if (!entry.object.defined()) {
// Directly send value to parent
this->PopTaskStack();
} else {
// check if there are already hash for object.
auto it = hash_memo_.find(entry.object);
if (it != hash_memo_.end()) {
entry.reduced_hash = it->second;
this->PopTaskStack();
} else {
// NOTE: important to modify entry before visit.
// as entry becomes invalid after we change the stack.
entry.children_expanded = true;
entry.result_stack_index = result_stack_.size();
CHECK_EQ(pending_tasks_.size(), 0U);
allow_push_to_stack_ = false;
// dispatch hash, reduce to the current slot.
this->DispatchSHash(entry.object, entry.map_free_vars);
allow_push_to_stack_ = true;
// Move pending tasks to the stack until the marked point.
while (pending_tasks_.size() != 0) {
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.pop_back();
}
}
}
}
}
// The default equal as registered in the structural equal vtable.
void DispatchSHash(const ObjectRef& object, bool map_free_vars) {
CHECK(object.defined());
vtable_->SHashReduce(object.get(), SHashReducer(this, map_free_vars));
}
/*!
* \brief Combine two hash values into a single one.
* \param key The left operand.
* \param value The right operand.
* \return the combined result.
*/
size_t HashCombine(size_t key, size_t value) {
return key ^ (value + 0x9e3779b9 + (key << 6) + (key >> 2));
}
private:
// free var counter.
size_t free_var_counter_{0};
// graph node counter.
size_t graph_node_counter_{0};
// record current stack top
bool allow_push_to_stack_{true};
// list of pending tasks to be pushed to the stack.
std::vector<Task> pending_tasks_;
// Internal task stack to executed the task
std::vector<Task> task_stack_;
// Internal stack to store the result poped from the task stack.
std::vector<size_t> result_stack_;
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
std::unordered_map<ObjectRef, size_t, ObjectHash, ObjectEqual> hash_memo_;
};
TVM_REGISTER_GLOBAL("node.StructuralHash")
.set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t {
size_t hashed_value =
VarCountingSHashHandler().Hash(object, map_free_vars);
return static_cast<int64_t>(hashed_value);
});
size_t StructuralHash::operator()(const ObjectRef& object) const {
return VarCountingSHashHandler().Hash(object, false);
}
} // namespace tvm
......@@ -15,10 +15,32 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
import pytest
from tvm import te
def consistent_equal(x, y, map_free_vars=False):
struct_equal0 = tvm.ir.structural_equal(x, y, map_free_vars)
struct_equal1 = tvm.ir.structural_equal(y, x, map_free_vars)
xhash = tvm.ir.structural_hash(x, map_free_vars)
yhash = tvm.ir.structural_hash(y, map_free_vars)
if struct_equal0 != struct_equal1:
raise ValueError(
"Non-communicative {} vs {}, sequal0={}, sequal1={}".format(
x, y, struct_equal0, struct_equal1))
# NOTE: hash colision can happen but should be rare.
# we can confirm that hash colison doesn't happen for our testcases
if struct_equal0 != (xhash == yhash):
raise ValueError(
"Inconsistent {} vs {}, sequal={}, xhash={}, yhash={}".format(
x, y, struct_equal0, xhash, yhash))
return struct_equal0
def test_exprs():
# save load json
x = tvm.tir.const(1, "int32")
......@@ -26,34 +48,35 @@ def test_exprs():
vx = te.var("x")
vy = te.var("y")
vz = te.var("z")
zx = vx + vx
zy = vy + vy
assert consistent_equal(zx * zx, (vx + vx) * (vx + vx),
map_free_vars=False)
# test assert trigger.
with pytest.raises(ValueError):
tvm.ir.assert_structural_equal(x, y)
assert not tvm.ir.structural_equal(vx, vy)
assert tvm.ir.structural_equal(vx, vy, map_free_vars=True)
assert not consistent_equal(vx, vy)
assert consistent_equal(vx, vy, map_free_vars=True)
# corner case lhs:vx == rhs:vy, but cannot map it iteslf
assert not tvm.ir.structural_equal(vx + vx, vy + vx, map_free_vars=True)
assert not consistent_equal(vx + vx, vy + vx, map_free_vars=True)
# corner case lhs:vx == rhs:vy, lhs:vy == rhs:vx
assert tvm.ir.structural_equal(vx + vy, vy + vx, map_free_vars=True)
assert consistent_equal(vx + vy, vy + vx, map_free_vars=True)
# corner case2: rolling remap.
assert tvm.ir.structural_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
assert not tvm.ir.structural_equal(vx + 1, vy + 1, map_free_vars=False)
assert consistent_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
assert not consistent_equal(vx + 1, vy + 1, map_free_vars=False)
# Defintition remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx - 1),
tvm.tir.Let(vy, 1, vy - 1))
assert consistent_equal(tvm.tir.Let(vx, 1, vx - 1),
tvm.tir.Let(vy, 1, vy - 1))
# Default same address free var remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx // vz),
tvm.tir.Let(vy, 1, vy // vz))
assert consistent_equal(tvm.tir.Let(vx, 1, vx // vz),
tvm.tir.Let(vy, 1, vy // vz))
zx = vx + vx
zy = vy + vy
assert tvm.ir.structural_equal(zx * zx, zx * zx)
assert tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=True)
assert not tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=False)
assert tvm.ir.structural_equal(zx * zx, (vx + vx) * (vx + vx),
map_free_vars=False)
assert consistent_equal(zx * zx, zx * zx)
assert consistent_equal(zx * zx, zy * zy, map_free_vars=True)
assert not consistent_equal(zx * zx, zy * zy, map_free_vars=False)
def test_prim_func():
......@@ -64,7 +87,7 @@ def test_prim_func():
[x, y], tvm.tir.Evaluate(x + y))
func1 = tvm.tir.PrimFunc(
[x, y], tvm.tir.Evaluate(y + x))
assert not tvm.ir.structural_equal(func0, func1)
assert not consistent_equal(func0, func1)
# new cases
b = tvm.tir.decl_buffer((x,), "float32")
......@@ -86,17 +109,42 @@ def test_prim_func():
mod1 = tvm.IRModule.from_expr(func1)
tvm.ir.assert_structural_equal(mod0, mod1)
def test_array():
x = np.arange(10)
nx = tvm.nd.array(x)
ny = tvm.nd.array(x)
nz = tvm.nd.array(x.reshape(2, 5))
assert consistent_equal(nx, ny)
assert not consistent_equal(nx, nz)
def test_env_func():
@tvm.register_func("test.sequal.env_func")
def test(x):
return x + 1
x = tvm.ir.EnvFunc.get("test.sequal.env_func")
y = tvm.ir.EnvFunc.get("test.sequal.env_func")
assert consistent_equal(y, x)
def test_attrs():
x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
z = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx")
tvm.ir.assert_structural_equal(y, x)
assert not tvm.ir.structural_equal(y, z)
assert not consistent_equal(y, z)
x = tvm.runtime.convert({"x": [1, 2, 3], "y": 2})
y = tvm.runtime.convert({"y": 2, "x": [1, 2, 3]})
z = tvm.runtime.convert({"y": 2, "x": [1, 2, 3, 4]})
assert consistent_equal(y, x)
assert not consistent_equal(y, z)
if __name__ == "__main__":
test_exprs()
test_prim_func()
test_attrs()
test_array()
test_env_func()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment