Unverified Commit 997a14ed by Tianqi Chen Committed by GitHub

[NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)

* [NODE][IR] Introduce StructuralEqual Infra for the Unified IR.

This PR introduces a new way to handle structural equality
for both TIR and relay nodes in an extensive way.

- Each object can now register an optional SEqualReduce function, which
  describes how to reduce its structural equality to another instance
  into equality of the children.
- Optionally, the object can choose to allow remapping of vars(e.g. function parameters)
  by calling DefEqual
- We implemented a non-recursive structural equality checker that
  recursively traverses the objects and does the structural equality checking.

This PR also fixes a few potential problems in previous relay's AlphaEqual.

- In particular, the new structural equality relation will be communicative.
- It is can be dangerous to use same_as relation to quickly check equality,
  demonstrated by the following case. (%x, %y) are shared vars between two functions.

- function0: fn (%x, %y) { %x + %y }
- function1: fn (%y, %x) { %x + %y }

The new structural equal is intented to supersede AlphaEqual and AttrsEqual.

Follow-up PRs should be performed to redirect the existing usages, and removes
the corresponding implementation.

* Update the rule to distinguish between graph node and non-graph nodes.

* Refactor the test cases to use structural equal.

* address comments

* Mark more relay::Expr as graph node, fix a testcase issue(was bug that was not caught by previous alpha equal)

* Remove unrelated comment

* Fix file comment

* Address review comment

* Relax condition to fit flaky case
parent 9c806621
...@@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object { ...@@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object {
v->Visit("max_value", &max_value); v->Visit("max_value", &max_value);
} }
bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
return equal(min_value, other->min_value) && equal(max_value, other->max_value);
}
/*! \brief Number to represent +inf */ /*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max(); static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*! /*!
...@@ -170,6 +174,10 @@ class ModularSetNode : public Object { ...@@ -170,6 +174,10 @@ class ModularSetNode : public Object {
v->Visit("base", &base); v->Visit("base", &base);
} }
bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
return equal(coeff, other->coeff) && equal(base, other->base);
}
static constexpr const char* _type_key = "arith.ModularSet"; static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
}; };
......
...@@ -59,6 +59,7 @@ enum SignType { ...@@ -59,6 +59,7 @@ enum SignType {
class IntSetNode : public Object { class IntSetNode : public Object {
public: public:
static constexpr const char* _type_key = "IntSet"; static constexpr const char* _type_key = "IntSet";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
}; };
......
...@@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode { ...@@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
// Use namehint for now to be consistent with the legacy relay impl
// TODO(tvm-team) revisit, need to check the type var.
return
equal(name_hint, other->name_hint) &&
equal(inputs, other->inputs);
}
static constexpr const char* _type_key = "relay.Constructor"; static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
}; };
...@@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode { ...@@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
return
equal.DefEqual(header, other->header) &&
equal.DefEqual(type_vars, other->type_vars) &&
equal(constructors, other->constructors);
}
static constexpr const char* _type_key = "relay.TypeData"; static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
}; };
......
...@@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object { ...@@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object {
v->Visit("type_info", &type_info); v->Visit("type_info", &type_info);
v->Visit("description", &description); v->Visit("description", &description);
} }
static constexpr const char* _type_key = "AttrFieldInfo"; static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
}; };
...@@ -278,6 +280,7 @@ class BaseAttrsNode : public Object { ...@@ -278,6 +280,7 @@ class BaseAttrsNode : public Object {
*/ */
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0; TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "Attrs"; static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
}; };
...@@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode { ...@@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode {
/*! \brief internal attrs map */ /*! \brief internal attrs map */
Map<std::string, ObjectRef> dict; Map<std::string, ObjectRef> dict;
bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
return equal(dict, other->dict);
}
// implementations // implementations
void VisitAttrs(AttrVisitor* v) final; void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final;
...@@ -401,6 +408,33 @@ class AttrsEqualVisitor { ...@@ -401,6 +408,33 @@ class AttrsEqualVisitor {
const AttrsEqual& equal_; const AttrsEqual& equal_;
}; };
class AttrsSEqualVisitor {
public:
bool result_{true};
// constructor
AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}
private:
const Object* lhs_;
const Object* rhs_;
const SEqualReducer& equal_;
};
class AttrsHashVisitor { class AttrsHashVisitor {
public: public:
explicit AttrsHashVisitor(const AttrsHash& hasher) explicit AttrsHashVisitor(const AttrsHash& hasher)
...@@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode { ...@@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode {
} }
} }
bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
DerivedType* pself = self();
::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
Array<AttrFieldInfo> ListFieldInfo() const final { Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor; ::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor); self()->__VisitAttrs__(visitor);
......
...@@ -51,7 +51,12 @@ class EnvFuncNode : public Object { ...@@ -51,7 +51,12 @@ class EnvFuncNode : public Object {
v->Visit("name", &name); v->Visit("name", &name);
} }
bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
return this == other;
}
static constexpr const char* _type_key = "EnvFunc"; static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
}; };
......
...@@ -43,6 +43,7 @@ namespace tvm { ...@@ -43,6 +43,7 @@ namespace tvm {
class BaseExprNode : public Object { class BaseExprNode : public Object {
public: public:
static constexpr const char* _type_key = "Expr"; static constexpr const char* _type_key = "Expr";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
}; };
...@@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode { ...@@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
// name matters for global var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "GlobalVar"; static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
}; };
...@@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode { ...@@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode {
v->Visit("value", &value); v->Visit("value", &value);
} }
bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
static constexpr const char* _type_key = "IntImm"; static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
}; };
...@@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode { ...@@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode {
v->Visit("value", &value); v->Visit("value", &value);
} }
bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
static constexpr const char* _type_key = "FloatImm"; static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
}; };
...@@ -353,7 +369,12 @@ class RangeNode : public Object { ...@@ -353,7 +369,12 @@ class RangeNode : public Object {
v->Visit("extent", &extent); v->Visit("extent", &extent);
} }
bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
return equal(min, other->min) && equal(extent, other->extent);
}
static constexpr const char* _type_key = "Range"; static constexpr const char* _type_key = "Range";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
}; };
......
...@@ -62,6 +62,8 @@ class IRModuleNode : public Object { ...@@ -62,6 +62,8 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_); v->Visit("global_type_var_map_", &global_type_var_map_);
} }
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
/*! /*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
* \param var The var of the global function. * \param var The var of the global function.
...@@ -235,6 +237,7 @@ class IRModuleNode : public Object { ...@@ -235,6 +237,7 @@ class IRModuleNode : public Object {
TVM_DLL std::unordered_set<std::string> Imports() const; TVM_DLL std::unordered_set<std::string> Imports() const;
static constexpr const char* _type_key = "IRModule"; static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private: private:
......
...@@ -101,6 +101,11 @@ class OpNode : public RelayExprNode { ...@@ -101,6 +101,11 @@ class OpNode : public RelayExprNode {
v->Visit("support_level", &support_level); v->Visit("support_level", &support_level);
} }
bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
// pointer equality is fine as there is only one op with the same name.
return this == other;
}
/*! /*!
* \brief Check that if current op is a "primtive operator". * \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single * That is the arguments are all type variables, and there is a single
......
...@@ -44,6 +44,10 @@ class SourceNameNode : public Object { ...@@ -44,6 +44,10 @@ class SourceNameNode : public Object {
// override attr visitor // override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const {
return equal(name, other->name);
}
static constexpr const char* _type_key = "SourceName"; static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
}; };
...@@ -87,6 +91,13 @@ class SpanNode : public Object { ...@@ -87,6 +91,13 @@ class SpanNode : public Object {
v->Visit("col_offset", &col_offset); v->Visit("col_offset", &col_offset);
} }
bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
return
equal(source, other->source) &&
equal(lineno, other->lineno) &&
equal(col_offset, other->col_offset);
}
TVM_DLL static Span make(SourceName source, int lineno, int col_offset); TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "Span"; static constexpr const char* _type_key = "Span";
......
...@@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode { ...@@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
return
equal(shape, other->shape) &&
equal(dtype, other->dtype);
}
/*! \brief Return product of elements in the shape. /*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero. * \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/ */
......
...@@ -111,6 +111,7 @@ class PassContextNode : public Object { ...@@ -111,6 +111,7 @@ class PassContextNode : public Object {
} }
static constexpr const char* _type_key = "transform.PassContext"; static constexpr const char* _type_key = "transform.PassContext";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
}; };
...@@ -207,6 +208,7 @@ class PassInfoNode : public Object { ...@@ -207,6 +208,7 @@ class PassInfoNode : public Object {
} }
static constexpr const char* _type_key = "transform.PassInfo"; static constexpr const char* _type_key = "transform.PassInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
}; };
......
...@@ -79,6 +79,7 @@ class TypeNode : public Object { ...@@ -79,6 +79,7 @@ class TypeNode : public Object {
mutable Span span; mutable Span span;
static constexpr const char* _type_key = "Type"; static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
}; };
...@@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode { ...@@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
} }
bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}
static constexpr const char* _type_key = "PrimType"; static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
}; };
...@@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode { ...@@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode {
v->Visit("element_type", &element_type); v->Visit("element_type", &element_type);
} }
bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
}
static constexpr const char* _type_key = "PointerType"; static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
}; };
...@@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode { ...@@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
return
equal(kind, other->kind) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "TypeVar"; static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
}; };
...@@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode { ...@@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind); v->Visit("kind", &kind);
} }
bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
// name matters for now in global type var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "GlobalTypeVar"; static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
}; };
...@@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode { ...@@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
return equal(fields, other->fields);
}
static constexpr const char* _type_key = "TupleType"; static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
}; };
...@@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode { ...@@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
// type params first as they defines type vars.
return
equal.DefEqual(type_params, other->type_params) &&
equal(arg_types, other->arg_types) &&
equal(ret_type, other->ret_type) &&
equal(type_constraints, other->type_constraints);
}
static constexpr const char* _type_key = "FuncType"; static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
}; };
...@@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode { ...@@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind);
}
static constexpr const char* _type_key = "IncompleteType"; static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
}; };
...@@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode { ...@@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
// Keep the relay prefix in the type as this type is specific // Keep the relay prefix in the type as this type is specific
// to the relay itself. // to the relay itself.
static constexpr const char* _type_key = "relay.RefType"; static constexpr const char* _type_key = "relay.RefType";
......
...@@ -50,6 +50,12 @@ class TypeCallNode : public TypeNode { ...@@ -50,6 +50,12 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(args, other->args);
}
static constexpr const char* _type_key = "TypeCall"; static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
}; };
...@@ -195,6 +201,14 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -195,6 +201,14 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(args, other->args) &&
equal(num_inputs, other->num_inputs) &&
equal(attrs, other->attrs);
}
static constexpr const char* _type_key = "TypeRelation"; static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
}; };
......
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
#ifndef TVM_NODE_CONTAINER_H_ #ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_
#include <tvm/node/node.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
...@@ -34,15 +36,19 @@ ...@@ -34,15 +36,19 @@
namespace tvm { namespace tvm {
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
using runtime::make_object;
using runtime::ObjectHash;
using runtime::ObjectEqual;
/*! \brief array node content in array */ /*! \brief array node content in array */
class ArrayNode : public Object { class ArrayNode : public Object {
public: public:
/*! \brief the data content */ /*! \brief the data content */
std::vector<ObjectRef> data; std::vector<ObjectRef> data;
void VisitAttrs(AttrVisitor* visitor) {
}
static constexpr const char* _type_key = "Array"; static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
}; };
...@@ -50,9 +56,6 @@ class ArrayNode : public Object { ...@@ -50,9 +56,6 @@ class ArrayNode : public Object {
/*! \brief map node content */ /*! \brief map node content */
class MapNode : public Object { class MapNode : public Object {
public: public:
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief The corresponding conatiner type */ /*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map< using ContainerType = std::unordered_map<
ObjectRef, ObjectRef,
...@@ -73,9 +76,6 @@ class StrMapNode : public Object { ...@@ -73,9 +76,6 @@ class StrMapNode : public Object {
/*! \brief The corresponding conatiner type */ /*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>; using ContainerType = std::unordered_map<std::string, ObjectRef>;
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief the data content */ /*! \brief the data content */
ContainerType data; ContainerType data;
......
...@@ -39,6 +39,8 @@ ...@@ -39,6 +39,8 @@
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h> #include <tvm/node/reflection.h>
#include <tvm/node/repr_printer.h> #include <tvm/node/repr_printer.h>
#include <tvm/node/container.h>
#include <tvm/node/structural_equal.h>
#include <string> #include <string>
#include <vector> #include <vector>
......
...@@ -29,13 +29,14 @@ ...@@ -29,13 +29,14 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h> #include <tvm/runtime/data_type.h>
#include <tvm/node/structural_equal.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <type_traits>
namespace tvm { namespace tvm {
// forward declaration
using runtime::Object; using runtime::Object;
using runtime::ObjectPtr; using runtime::ObjectPtr;
using runtime::ObjectRef; using runtime::ObjectRef;
...@@ -87,6 +88,13 @@ class ReflectionVTable { ...@@ -87,6 +88,13 @@ class ReflectionVTable {
*/ */
typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor); typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
/*! /*!
* \brief Equality comparison function.
* \note We use function pointer, instead of std::function
* to reduce the dispatch overhead as field visit
* does not need as much customization.
*/
typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
/*!
* \brief creator function. * \brief creator function.
* \param global_key Key that identifies a global single object. * \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object. * If this is not empty then FGlobalKey must be defined for the object.
...@@ -112,6 +120,14 @@ class ReflectionVTable { ...@@ -112,6 +120,14 @@ class ReflectionVTable {
*/ */
inline std::string GetGlobalKey(Object* self) const; inline std::string GetGlobalKey(Object* self) const;
/*! /*!
* \brief Dispatch the SEqualReduce function.
* \param self The pointer to the object.
* \param other The pointer to another object to be compared.
* \param equal The equality comparator.
* \return the result.
*/
bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
/*!
* \brief Create an initial object using default constructor * \brief Create an initial object using default constructor
* by type_key and global key. * by type_key and global key.
* *
...@@ -139,12 +155,14 @@ class ReflectionVTable { ...@@ -139,12 +155,14 @@ class ReflectionVTable {
TVM_DLL static ReflectionVTable* Global(); TVM_DLL static ReflectionVTable* Global();
class Registry; class Registry;
template<typename T> template<typename T, typename TraitName>
inline Registry Register(); inline Registry Register();
private: private:
/*! \brief Attribute visitor. */ /*! \brief Attribute visitor. */
std::vector<FVisitAttrs> fvisit_attrs_; std::vector<FVisitAttrs> fvisit_attrs_;
/*! \brief Structural equal function. */
std::vector<FSEqualReduce> fsequal_;
/*! \brief Creation function. */ /*! \brief Creation function. */
std::vector<FCreate> fcreate_; std::vector<FCreate> fcreate_;
/*! \brief Global key function. */ /*! \brief Global key function. */
...@@ -182,6 +200,44 @@ class ReflectionVTable::Registry { ...@@ -182,6 +200,44 @@ class ReflectionVTable::Registry {
uint32_t type_index_; uint32_t type_index_;
}; };
#define TVM_REFLECTION_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \
__make_reflectiion
/*!
* \brief Directly register reflection VTable.
* \param TypeName The name of the type.
* \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce.
*
* \code
*
* // Example SEQualReduce traits for runtime StringObj.
*
* struct StringObjTrait {
* static constexpr const std::nullptr_t VisitAttrs = nullptr;
*
* static bool SEqualReduce(const runtime::StringObj* lhs,
* const runtime::StringObj* rhs,
* SEqualReducer equal) {
* if (lhs == rhs) return true;
* if (lhs->size != rhs->size) return false;
* if (lhs->data != rhs->data) return true;
* return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
* }
* };
*
* TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
*
* \endcode
*
* \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE.
* And can be used to register the related reflection functions for runtime objects.
*/
#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>() \
/*! /*!
* \brief Register a node type to object registry and reflection registry. * \brief Register a node type to object registry and reflection registry.
* \param TypeName The name of the type. * \param TypeName The name of the type.
...@@ -189,15 +245,79 @@ class ReflectionVTable::Registry { ...@@ -189,15 +245,79 @@ class ReflectionVTable::Registry {
*/ */
#define TVM_REGISTER_NODE_TYPE(TypeName) \ #define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \ TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \ TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
__make_Node ## _ ## TypeName ## __ = \
::tvm::ReflectionVTable::Global()->Register<TypeName>() \
.set_creator([](const std::string&) -> ObjectPtr<Object> { \ .set_creator([](const std::string&) -> ObjectPtr<Object> { \
return ::tvm::runtime::make_object<TypeName>(); \ return ::tvm::runtime::make_object<TypeName>(); \
}) })
// Implementation details // Implementation details
namespace detail {
template<typename T,
bool = T::_type_has_method_visit_attrs>
struct ImplVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
template<typename T> template<typename T>
struct ImplVisitAttrs<T, true> {
static void VisitAttrs(T* self, AttrVisitor* v) {
self->VisitAttrs(v);
}
};
template<typename T,
bool = T::_type_has_method_sequal_reduce>
struct ImplSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
template<typename T>
struct ImplSEqualReduce<T, true> {
static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
return self->SEqualReduce(other, equal);
}
};
template<typename T>
struct ReflectionTrait :
public ImplVisitAttrs<T>,
public ImplSEqualReduce<T> {
};
template<typename T, typename TraitName,
bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
struct SelectVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
template<typename T, typename TraitName>
struct SelectVisitAttrs<T, TraitName, false> {
static void VisitAttrs(Object* self, AttrVisitor* v) {
TraitName::VisitAttrs(static_cast<T*>(self), v);
}
};
template<typename T, typename TraitName,
bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
struct SelectSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
template<typename T, typename TraitName>
struct SelectSEqualReduce<T, TraitName, false> {
static bool SEqualReduce(const Object* self,
const Object* other,
SEqualReducer equal) {
return TraitName::SEqualReduce(static_cast<const T*>(self),
static_cast<const T*>(other),
equal);
}
};
} // namespace detail
template<typename T, typename TraitName>
inline ReflectionVTable::Registry inline ReflectionVTable::Registry
ReflectionVTable::Register() { ReflectionVTable::Register() {
uint32_t tindex = T::RuntimeTypeIndex(); uint32_t tindex = T::RuntimeTypeIndex();
...@@ -205,15 +325,15 @@ ReflectionVTable::Register() { ...@@ -205,15 +325,15 @@ ReflectionVTable::Register() {
fvisit_attrs_.resize(tindex + 1, nullptr); fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr); fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr); fglobal_key_.resize(tindex + 1, nullptr);
fsequal_.resize(tindex + 1, nullptr);
} }
// functor that implemnts the redirection. // functor that implemnts the redirection.
struct Functor { fvisit_attrs_[tindex] =
static void VisitAttrs(Object* self, AttrVisitor* v) { ::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
static_cast<T*>(self)->VisitAttrs(v);
} fsequal_[tindex] =
}; ::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
fvisit_attrs_[tindex] = Functor::VisitAttrs;
return Registry(this, tindex); return Registry(this, tindex);
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/structural_equal.h
* \brief Structural equality comparison.
*/
#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
#define TVM_NODE_STRUCTURAL_EQUAL_H_
#include <tvm/runtime/data_type.h>
#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <string>
namespace tvm {
/*!
* \brief Equality definition of base value class.
*/
class BaseValueEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
return lhs == rhs;
}
};
/*!
* \brief Content-aware structural equality comparator for objects.
*
* The structural equality is recursively defined in the DAG of IR nodes via SEqual.
* There are two kinds of nodes:
*
* - Graph node: a graph node in lhs can only be mapped as equal to
* one and only one graph node in rhs.
* - Normal node: equality is recursively defined without the restriction
* of graph nodes.
*
* Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
* For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
* to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
*
* A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
* with the same type if one of the following condition holds:
*
* - They appear in a same definition point(e.g. function argument).
* - They points to the same VarNode via the same_as relation.
* - They appear in a same usage point, and map_free_vars is set to be True.
*/
class StructuralEqual : public BaseValueEqual {
public:
// inheritate operator()
using BaseValueEqual::operator();
/*!
* \brief Compare objects via strutural equal.
* \param lhs The left operand.
* \param rhs The right operand.
* \return The comparison result.
*/
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
};
/*!
* \brief A Reducer class to reduce the structural equality result of two objects.
*
* The reducer will call the SEqualReduce function of each objects recursively.
* Importantly, the reducer may not directly use recursive calls to resolve the
* equality checking. Instead, it can store the necessary equality conditions
* and check later via an internally managed stack.
*/
class SEqualReducer : public BaseValueEqual {
public:
/*! \brief Internal handler that defines custom behaviors.. */
class Handler {
public:
/*!
* \brief Reduce condition to equality of lhs and rhs.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether do we allow remap variables if possible.
*
* \return false if there is an immediate failure, true otherwise.
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
virtual bool SEqualReduce(const ObjectRef& lhs,
const ObjectRef& rhs,
bool map_free_vars) = 0;
/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
* This is an auxiliary method to check the Map<Var, Value> equality.
* \param lhs an lhs value.
*
* \return The corresponding rhs value if any, nullptr if not available.
*/
virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
/*!
* \brief Mark current comparison as graph node equal comparison.
*/
virtual void MarkGraphNode() = 0;
};
using BaseValueEqual::operator();
/*! \brief default constructor */
SEqualReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
* \param map_free_vars Whether or not to map free variables.
*/
explicit SEqualReducer(Handler* handler, bool map_free_vars)
: handler_(handler), map_free_vars_(map_free_vars) {}
/*!
* \brief Reduce condition to comparison of two objects.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
}
/*!
* \brief Reduce condition to comparison of two definitions,
* where free vars can be mapped.
*
* Call this function to compare definition points such as function params
* and var in a let-binding.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
return handler_->SEqualReduce(lhs, rhs, true);
}
/*!
* \brief Reduce condition to comparison of two arrays.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
template<typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
}
return true;
}
/*!
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
* \param lhs The left operand.
* \param rhs The right operand.
* \return the result.
*/
bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
// var need to be remapped, so it belongs to graph node.
handler_->MarkGraphNode();
// We only map free vars if they corresponds to the same address
// or map free_var option is set to be true.
return lhs == rhs || map_free_vars_;
}
/*! \return Get the internal handler. */
Handler* operator->() const {
return handler_;
}
private:
/*! \brief Internal class pointer. */
Handler* handler_;
/*! \brief Whether or not to map free vars. */
bool map_free_vars_;
};
} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_EQUAL_H_
...@@ -46,6 +46,7 @@ using TypeDataNode = tvm::TypeDataNode; ...@@ -46,6 +46,7 @@ using TypeDataNode = tvm::TypeDataNode;
class PatternNode : public RelayNode { class PatternNode : public RelayNode {
public: public:
static constexpr const char* _type_key = "relay.Pattern"; static constexpr const char* _type_key = "relay.Pattern";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
}; };
...@@ -74,6 +75,10 @@ class PatternWildcardNode : public PatternNode { ...@@ -74,6 +75,10 @@ class PatternWildcardNode : public PatternNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const {
return true;
}
static constexpr const char* _type_key = "relay.PatternWildcard"; static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
}; };
...@@ -118,6 +123,10 @@ class PatternVarNode : public PatternNode { ...@@ -118,6 +123,10 @@ class PatternVarNode : public PatternNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
return equal.DefEqual(var, other->var);
}
static constexpr const char* _type_key = "relay.PatternVar"; static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
}; };
...@@ -149,6 +158,12 @@ class PatternConstructorNode : public PatternNode { ...@@ -149,6 +158,12 @@ class PatternConstructorNode : public PatternNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
return
equal(constructor, other->constructor) &&
equal(patterns, other->patterns);
}
static constexpr const char* _type_key = "relay.PatternConstructor"; static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
}; };
...@@ -178,6 +193,10 @@ class PatternTupleNode : public PatternNode { ...@@ -178,6 +193,10 @@ class PatternTupleNode : public PatternNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
return equal(patterns, other->patterns);
}
static constexpr const char* _type_key = "relay.PatternTuple"; static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
}; };
...@@ -208,7 +227,12 @@ class ClauseNode : public Object { ...@@ -208,7 +227,12 @@ class ClauseNode : public Object {
v->Visit("rhs", &rhs); v->Visit("rhs", &rhs);
} }
bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
return equal(lhs, other->lhs) && equal(rhs, other->rhs);
}
static constexpr const char* _type_key = "relay.Clause"; static constexpr const char* _type_key = "relay.Clause";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
}; };
...@@ -248,6 +272,14 @@ class MatchNode : public ExprNode { ...@@ -248,6 +272,14 @@ class MatchNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(data, other->data) &&
equal(clauses, other->clauses) &&
equal(complete, other->complete);
}
static constexpr const char* _type_key = "relay.Match"; static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
}; };
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/ir/op.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <string> #include <string>
#include <functional> #include <functional>
...@@ -72,6 +73,10 @@ class ConstantNode : public ExprNode { ...@@ -72,6 +73,10 @@ class ConstantNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
return equal(data, other->data);
}
static constexpr const char* _type_key = "relay.Constant"; static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
}; };
...@@ -101,6 +106,16 @@ class TupleNode : public ExprNode { ...@@ -101,6 +106,16 @@ class TupleNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
// specially handle empty tuple as a constant is not a graph node.
if (fields.size() == other->fields.size() && fields.size() == 0) {
return true;
} else {
equal->MarkGraphNode();
return equal(fields, other->fields);
}
}
static constexpr const char* _type_key = "relay.Tuple"; static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
}; };
...@@ -157,6 +172,12 @@ class VarNode : public ExprNode { ...@@ -157,6 +172,12 @@ class VarNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
return
equal(type_annotation, other->type_annotation) &&
equal.FreeVarEqualImpl(this, other);
}
TVM_DLL static Var make(std::string name_hint, TVM_DLL static Var make(std::string name_hint,
Type type_annotation); Type type_annotation);
...@@ -238,6 +259,16 @@ class CallNode : public ExprNode { ...@@ -238,6 +259,16 @@ class CallNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
// skip type_args check for primitive ops.
equal->MarkGraphNode();
return
equal(op, other->op) &&
equal(args, other->args) &&
equal(attrs, other->attrs) &&
(IsPrimitiveOp(op) || equal(type_args, other->type_args));
}
static constexpr const char* _type_key = "relay.Call"; static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
}; };
...@@ -289,6 +320,14 @@ class LetNode : public ExprNode { ...@@ -289,6 +320,14 @@ class LetNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "relay.Let"; static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
}; };
...@@ -336,6 +375,14 @@ class IfNode : public ExprNode { ...@@ -336,6 +375,14 @@ class IfNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(cond, other->cond) &&
equal(true_branch, other->true_branch) &&
equal(false_branch, other->false_branch);
}
static constexpr const char* _type_key = "relay.If"; static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
}; };
...@@ -369,6 +416,12 @@ class TupleGetItemNode : public ExprNode { ...@@ -369,6 +416,12 @@ class TupleGetItemNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
return
equal(tuple, other->tuple) &&
equal(index, other->index);
}
static constexpr const char* _type_key = "relay.TupleGetItem"; static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
}; };
...@@ -398,6 +451,11 @@ class RefCreateNode : public ExprNode { ...@@ -398,6 +451,11 @@ class RefCreateNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(value, other->value);
}
static constexpr const char* _type_key = "relay.RefCreate"; static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
}; };
...@@ -426,6 +484,11 @@ class RefReadNode : public ExprNode { ...@@ -426,6 +484,11 @@ class RefReadNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(ref, other->ref);
}
static constexpr const char* _type_key = "relay.RefRead"; static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
}; };
...@@ -456,6 +519,13 @@ class RefWriteNode : public ExprNode { ...@@ -456,6 +519,13 @@ class RefWriteNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(ref, other->ref) &&
equal(value, other->value);
}
TVM_DLL static RefWrite make(Expr ref, Expr value); TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite"; static constexpr const char* _type_key = "relay.RefWrite";
...@@ -497,6 +567,7 @@ class TempExprNode : public ExprNode { ...@@ -497,6 +567,7 @@ class TempExprNode : public ExprNode {
virtual Expr Realize() const = 0; virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr"; static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode); TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
}; };
......
...@@ -68,6 +68,17 @@ class FunctionNode : public BaseFuncNode { ...@@ -68,6 +68,17 @@ class FunctionNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
// Important to make def equal first.
equal->MarkGraphNode();
return
equal.DefEqual(params, other->params) &&
equal.DefEqual(type_params, other->type_params) &&
equal(ret_type, other->ret_type) &&
equal(attrs, other->attrs) &&
equal(body, other->body);
}
/*! /*!
* \brief Return the derived function annotation of this expression. * \brief Return the derived function annotation of this expression.
* *
......
...@@ -65,6 +65,8 @@ class NDArray : public ObjectRef { ...@@ -65,6 +65,8 @@ class NDArray : public ObjectRef {
inline int use_count() const; inline int use_count() const;
/*! \return Pointer to content of DLTensor */ /*! \return Pointer to content of DLTensor */
inline const DLTensor* operator->() const; inline const DLTensor* operator->() const;
/*! \return Whether the tensor is contiguous */
inline bool IsContiguous() const;
/*! /*!
* \brief Copy data content from another array. * \brief Copy data content from another array.
* \param other The source array to be copied from. * \param other The source array to be copied from.
...@@ -313,6 +315,26 @@ inline size_t GetDataSize(const DLTensor& arr) { ...@@ -313,6 +315,26 @@ inline size_t GetDataSize(const DLTensor& arr) {
return size; return size;
} }
/*!
* \brief check if a DLTensor is contiguous.
* \param arr The input DLTensor.
* \return The check result.
*/
inline bool IsContiguous(const DLTensor& arr) {
if (arr.strides == nullptr) return true;
int64_t expected_stride = 1;
for (int32_t i = arr.ndim; i != 0; --i) {
int32_t k = i - 1;
if (arr.strides[k] != expected_stride) return false;
expected_stride *= arr.shape[k];
}
return true;
}
inline bool NDArray::IsContiguous() const {
return ::tvm::runtime::IsContiguous(get_mutable()->dl_tensor);
}
inline void NDArray::CopyFrom(const DLTensor* other) { inline void NDArray::CopyFrom(const DLTensor* other) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CopyFromTo(other, &(get_mutable()->dl_tensor)); CopyFromTo(other, &(get_mutable()->dl_tensor));
......
...@@ -211,11 +211,15 @@ class Object { ...@@ -211,11 +211,15 @@ class Object {
static constexpr bool _type_final = false; static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0; static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true; static constexpr bool _type_child_slots_can_overflow = true;
// member information
static constexpr bool _type_has_method_visit_attrs = true;
static constexpr bool _type_has_method_sequal_reduce = false;
// NOTE: the following field is not type index of Object // NOTE: the following field is not type index of Object
// but was intended to be used by sub-classes as default value. // but was intended to be used by sub-classes as default value.
// The type index of Object is TypeIndex::kRoot // The type index of Object is TypeIndex::kRoot
static constexpr uint32_t _type_index = TypeIndex::kDynamic; static constexpr uint32_t _type_index = TypeIndex::kDynamic;
// Default constructor and copy constructor // Default constructor and copy constructor
Object() {} Object() {}
// Override the copy and assign constructors to do nothing. // Override the copy and assign constructors to do nothing.
......
...@@ -150,6 +150,20 @@ class BufferNode : public Object { ...@@ -150,6 +150,20 @@ class BufferNode : public Object {
v->Visit("buffer_type", &buffer_type); v->Visit("buffer_type", &buffer_type);
} }
bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
// Use DefEqual as buffer can define variables
// in its semantics, skip name as name is not important.
return
equal.DefEqual(data, other->data) &&
equal(dtype, other->dtype) &&
equal.DefEqual(shape, other->shape) &&
equal.DefEqual(strides, other->strides) &&
equal.DefEqual(elem_offset, other->elem_offset) &&
equal(scope, other->scope) &&
equal(data_alignment, other->data_alignment) &&
equal(buffer_type, other->buffer_type);
}
/*! \return preferred index type for this buffer node */ /*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const { DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
...@@ -169,6 +183,7 @@ class BufferNode : public Object { ...@@ -169,6 +183,7 @@ class BufferNode : public Object {
BufferType buffer_type); BufferType buffer_type);
static constexpr const char* _type_key = "Buffer"; static constexpr const char* _type_key = "Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
}; };
......
...@@ -75,6 +75,12 @@ class VarNode : public PrimExprNode { ...@@ -75,6 +75,12 @@ class VarNode : public PrimExprNode {
v->Visit("type_annotation", &type_annotation); v->Visit("type_annotation", &type_annotation);
} }
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
if (!equal(dtype, other->dtype)) return false;
if (!equal(type_annotation, other->type_annotation)) return false;
return equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "tir.Var"; static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
}; };
...@@ -288,11 +294,20 @@ class IterVarNode : public Object { ...@@ -288,11 +294,20 @@ class IterVarNode : public Object {
v->Visit("thread_tag", &thread_tag); v->Visit("thread_tag", &thread_tag);
} }
bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
return
equal(dom, other->dom) &&
equal.DefEqual(var, other->var) &&
equal(iter_type, other->iter_type) &&
equal(thread_tag, other->thread_tag);
}
TVM_DLL static IterVar make(Range dom, Var var, TVM_DLL static IterVar make(Range dom, Var var,
IterVarType iter_type, IterVarType iter_type,
std::string thread_tag = ""); std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar"; static constexpr const char* _type_key = "IterVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
}; };
...@@ -334,6 +349,10 @@ class StringImmNode : public PrimExprNode { ...@@ -334,6 +349,10 @@ class StringImmNode : public PrimExprNode {
v->Visit("value", &value); v->Visit("value", &value);
} }
bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
TVM_DLL PrimExpr static make(std::string value); TVM_DLL PrimExpr static make(std::string value);
static constexpr const char* _type_key = "StringImm"; static constexpr const char* _type_key = "StringImm";
...@@ -359,6 +378,10 @@ class CastNode : public PrimExprNode { ...@@ -359,6 +378,10 @@ class CastNode : public PrimExprNode {
v->Visit("value", &value); v->Visit("value", &value);
} }
bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
TVM_DLL static PrimExpr make(DataType t, PrimExpr v); TVM_DLL static PrimExpr make(DataType t, PrimExpr v);
static constexpr const char* _type_key = "Cast"; static constexpr const char* _type_key = "Cast";
...@@ -383,6 +406,13 @@ class BinaryOpNode : public PrimExprNode { ...@@ -383,6 +406,13 @@ class BinaryOpNode : public PrimExprNode {
v->Visit("b", &b); v->Visit("b", &b);
} }
bool SEqualReduce(const T* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
static PrimExpr make(PrimExpr a, PrimExpr b) { static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n";
...@@ -475,6 +505,13 @@ class CmpOpNode : public PrimExprNode { ...@@ -475,6 +505,13 @@ class CmpOpNode : public PrimExprNode {
v->Visit("b", &b); v->Visit("b", &b);
} }
bool SEqualReduce(const T* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
static PrimExpr make(PrimExpr a, PrimExpr b) { static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n";
...@@ -539,6 +576,13 @@ class AndNode : public PrimExprNode { ...@@ -539,6 +576,13 @@ class AndNode : public PrimExprNode {
v->Visit("b", &b); v->Visit("b", &b);
} }
bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "And"; static constexpr const char* _type_key = "And";
...@@ -559,6 +603,13 @@ class OrNode : public PrimExprNode { ...@@ -559,6 +603,13 @@ class OrNode : public PrimExprNode {
v->Visit("b", &b); v->Visit("b", &b);
} }
bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "Or"; static constexpr const char* _type_key = "Or";
...@@ -576,6 +627,10 @@ class NotNode : public PrimExprNode { ...@@ -576,6 +627,10 @@ class NotNode : public PrimExprNode {
v->Visit("a", &a); v->Visit("a", &a);
} }
bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(a, other->a);
}
TVM_DLL static PrimExpr make(PrimExpr a); TVM_DLL static PrimExpr make(PrimExpr a);
static constexpr const char* _type_key = "Not"; static constexpr const char* _type_key = "Not";
...@@ -605,6 +660,14 @@ class SelectNode : public PrimExprNode { ...@@ -605,6 +660,14 @@ class SelectNode : public PrimExprNode {
v->Visit("false_value", &false_value); v->Visit("false_value", &false_value);
} }
bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(condition, other->condition) &&
equal(true_value, other->true_value) &&
equal(false_value, other->false_value);
}
TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
static constexpr const char* _type_key = "Select"; static constexpr const char* _type_key = "Select";
...@@ -642,6 +705,14 @@ class LoadNode : public PrimExprNode { ...@@ -642,6 +705,14 @@ class LoadNode : public PrimExprNode {
v->Visit("predicate", &predicate); v->Visit("predicate", &predicate);
} }
bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(buffer_var, other->buffer_var) &&
equal(index, other->index) &&
equal(predicate, other->predicate);
}
TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
static constexpr const char* _type_key = "Load"; static constexpr const char* _type_key = "Load";
...@@ -673,6 +744,14 @@ class RampNode : public PrimExprNode { ...@@ -673,6 +744,14 @@ class RampNode : public PrimExprNode {
v->Visit("lanes", &lanes); v->Visit("lanes", &lanes);
} }
bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(base, other->base) &&
equal(stride, other->stride) &&
equal(lanes, other->lanes);
}
TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes); TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes);
static constexpr const char* _type_key = "Ramp"; static constexpr const char* _type_key = "Ramp";
...@@ -693,6 +772,13 @@ class BroadcastNode : public PrimExprNode { ...@@ -693,6 +772,13 @@ class BroadcastNode : public PrimExprNode {
v->Visit("lanes", &lanes); v->Visit("lanes", &lanes);
} }
bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(value, other->value) &&
equal(lanes, other->lanes);
}
TVM_DLL static PrimExpr make(PrimExpr value, int lanes); TVM_DLL static PrimExpr make(PrimExpr value, int lanes);
static constexpr const char* _type_key = "Broadcast"; static constexpr const char* _type_key = "Broadcast";
...@@ -718,6 +804,14 @@ class LetNode : public PrimExprNode { ...@@ -718,6 +804,14 @@ class LetNode : public PrimExprNode {
v->Visit("body", &body); v->Visit("body", &body);
} }
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body); TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body);
static constexpr const char* _type_key = "Let"; static constexpr const char* _type_key = "Let";
...@@ -788,6 +882,16 @@ class CallNode : public PrimExprNode { ...@@ -788,6 +882,16 @@ class CallNode : public PrimExprNode {
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
} }
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(name, other->name) &&
equal(args, other->args) &&
equal(call_type, other->call_type) &&
equal(func, other->func) &&
equal(value_index, other->value_index);
}
TVM_DLL static PrimExpr make(DataType dtype, TVM_DLL static PrimExpr make(DataType dtype,
std::string name, std::string name,
Array<PrimExpr> args, Array<PrimExpr> args,
...@@ -856,6 +960,13 @@ class ShuffleNode : public PrimExprNode { ...@@ -856,6 +960,13 @@ class ShuffleNode : public PrimExprNode {
v->Visit("indices", &indices); v->Visit("indices", &indices);
} }
bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(vectors, other->vectors) &&
equal(indices, other->indices);
}
TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr> indices); TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr> indices);
TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors); TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors);
TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index); TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index);
...@@ -918,7 +1029,16 @@ class CommReducerNode : public Object { ...@@ -918,7 +1029,16 @@ class CommReducerNode : public Object {
v->Visit("identity_element", &identity_element); v->Visit("identity_element", &identity_element);
} }
bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
return
equal.DefEqual(lhs, other->lhs) &&
equal.DefEqual(rhs, other->rhs) &&
equal(result, other->result) &&
equal(identity_element, other->identity_element);
}
static constexpr const char* _type_key = "CommReducer"; static constexpr const char* _type_key = "CommReducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
}; };
...@@ -962,6 +1082,16 @@ class ReduceNode : public PrimExprNode { ...@@ -962,6 +1082,16 @@ class ReduceNode : public PrimExprNode {
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
} }
bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
// check axis first so IterVars can define the necessary variables.
return
equal(dtype, other->dtype) &&
equal(axis, other->axis) &&
equal(combiner, other->combiner) &&
equal(source, other->source) &&
equal(condition, other->condition) &&
equal(value_index, other->value_index);
}
static constexpr const char* _type_key = "Reduce"; static constexpr const char* _type_key = "Reduce";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
}; };
...@@ -970,6 +1100,11 @@ class ReduceNode : public PrimExprNode { ...@@ -970,6 +1100,11 @@ class ReduceNode : public PrimExprNode {
class AnyNode : public PrimExprNode { class AnyNode : public PrimExprNode {
public: public:
void VisitAttrs(AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
return true;
}
/*! \brief Convert to var. */ /*! \brief Convert to var. */
Var ToVar() const { Var ToVar() const {
return Var("any_dim", DataType::Int(32)); return Var("any_dim", DataType::Int(32));
......
...@@ -102,6 +102,16 @@ class PrimFuncNode : public BaseFuncNode { ...@@ -102,6 +102,16 @@ class PrimFuncNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
// visit params and buffer_map first as they contains defs.
return
equal.DefEqual(params, other->params) &&
equal(buffer_map, other->buffer_map) &&
equal(ret_type, other->ret_type) &&
equal(body, other->body) &&
equal(attrs, other->attrs);
}
/*! /*!
* \brief Return the derived function annotation of this function. * \brief Return the derived function annotation of this function.
* *
...@@ -112,6 +122,7 @@ class PrimFuncNode : public BaseFuncNode { ...@@ -112,6 +122,7 @@ class PrimFuncNode : public BaseFuncNode {
TVM_DLL FuncType func_type_annotation() const; TVM_DLL FuncType func_type_annotation() const;
static constexpr const char* _type_key = "tir.PrimFunc"; static constexpr const char* _type_key = "tir.PrimFunc";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
}; };
......
...@@ -38,6 +38,7 @@ namespace tir { ...@@ -38,6 +38,7 @@ namespace tir {
class StmtNode : public Object { class StmtNode : public Object {
public: public:
static constexpr const char* _type_key = "Stmt"; static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
}; };
...@@ -65,6 +66,13 @@ class LetStmtNode : public StmtNode { ...@@ -65,6 +66,13 @@ class LetStmtNode : public StmtNode {
v->Visit("body", &body); v->Visit("body", &body);
} }
bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
return
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body); TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
static constexpr const char* _type_key = "LetStmt"; static constexpr const char* _type_key = "LetStmt";
...@@ -99,6 +107,14 @@ class AttrStmtNode : public StmtNode { ...@@ -99,6 +107,14 @@ class AttrStmtNode : public StmtNode {
v->Visit("body", &body); v->Visit("body", &body);
} }
bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
return
equal(node, other->node) &&
equal(attr_key, other->attr_key) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(ObjectRef node, TVM_DLL static Stmt make(ObjectRef node,
std::string type_key, std::string type_key,
PrimExpr value, PrimExpr value,
...@@ -129,6 +145,13 @@ class AssertStmtNode : public StmtNode { ...@@ -129,6 +145,13 @@ class AssertStmtNode : public StmtNode {
v->Visit("body", &body); v->Visit("body", &body);
} }
bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
return
equal(condition, other->condition) &&
equal(message, other->message) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body); TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt"; static constexpr const char* _type_key = "AssertStmt";
...@@ -152,6 +175,13 @@ class ProducerConsumerNode : public StmtNode { ...@@ -152,6 +175,13 @@ class ProducerConsumerNode : public StmtNode {
v->Visit("body", &body); v->Visit("body", &body);
} }
bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(is_producer, other->is_producer) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer"; static constexpr const char* _type_key = "ProducerConsumer";
...@@ -194,6 +224,14 @@ class StoreNode : public StmtNode { ...@@ -194,6 +224,14 @@ class StoreNode : public StmtNode {
v->Visit("predicate", &predicate); v->Visit("predicate", &predicate);
} }
bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
return
equal(buffer_var, other->buffer_var) &&
equal(value, other->value) &&
equal(index, other->index) &&
equal(predicate, other->predicate);
}
TVM_DLL static Stmt make(Var buffer_var, TVM_DLL static Stmt make(Var buffer_var,
PrimExpr value, PrimExpr value,
PrimExpr index, PrimExpr index,
...@@ -224,6 +262,14 @@ class ProvideNode : public StmtNode { ...@@ -224,6 +262,14 @@ class ProvideNode : public StmtNode {
v->Visit("args", &args); v->Visit("args", &args);
} }
bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(value, other->value) &&
equal(args, other->args);
}
TVM_DLL static Stmt make(FunctionRef func, TVM_DLL static Stmt make(FunctionRef func,
int value_index, int value_index,
PrimExpr value, PrimExpr value,
...@@ -261,6 +307,15 @@ class AllocateNode : public StmtNode { ...@@ -261,6 +307,15 @@ class AllocateNode : public StmtNode {
v->Visit("body", &body); v->Visit("body", &body);
} }
bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
return
equal.DefEqual(buffer_var, other->buffer_var) &&
equal(dtype, other->dtype) &&
equal(extents, other->extents) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(Var buffer_var, TVM_DLL static Stmt make(Var buffer_var,
DataType dtype, DataType dtype,
Array<PrimExpr> extents, Array<PrimExpr> extents,
...@@ -300,6 +355,11 @@ class FreeNode : public StmtNode { ...@@ -300,6 +355,11 @@ class FreeNode : public StmtNode {
v->Visit("buffer_var", &buffer_var); v->Visit("buffer_var", &buffer_var);
} }
bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
return
equal(buffer_var, other->buffer_var);
}
TVM_DLL static Stmt make(Var buffer_var); TVM_DLL static Stmt make(Var buffer_var);
static constexpr const char* _type_key = "Free"; static constexpr const char* _type_key = "Free";
...@@ -341,6 +401,16 @@ class RealizeNode : public StmtNode { ...@@ -341,6 +401,16 @@ class RealizeNode : public StmtNode {
PrimExpr condition, PrimExpr condition,
Stmt body); Stmt body);
bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(bounds, other->bounds) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "Realize"; static constexpr const char* _type_key = "Realize";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode); TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
}; };
...@@ -369,6 +439,10 @@ class SeqStmtNode : public StmtNode { ...@@ -369,6 +439,10 @@ class SeqStmtNode : public StmtNode {
v->Visit("seq", &seq); v->Visit("seq", &seq);
} }
bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
return equal(seq, other->seq);
}
static constexpr const char* _type_key = "SeqStmt"; static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
}; };
...@@ -472,6 +546,13 @@ class IfThenElseNode : public StmtNode { ...@@ -472,6 +546,13 @@ class IfThenElseNode : public StmtNode {
v->Visit("else_case", &else_case); v->Visit("else_case", &else_case);
} }
bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
return
equal(condition, other->condition) &&
equal(then_case, other->then_case) &&
equal(else_case, other->else_case);
}
TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
static constexpr const char* _type_key = "IfThenElse"; static constexpr const char* _type_key = "IfThenElse";
...@@ -493,6 +574,10 @@ class EvaluateNode : public StmtNode { ...@@ -493,6 +574,10 @@ class EvaluateNode : public StmtNode {
v->Visit("value", &value); v->Visit("value", &value);
} }
bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
TVM_DLL static Stmt make(PrimExpr v); TVM_DLL static Stmt make(PrimExpr v);
static constexpr const char* _type_key = "Evaluate"; static constexpr const char* _type_key = "Evaluate";
...@@ -562,6 +647,16 @@ class ForNode : public StmtNode { ...@@ -562,6 +647,16 @@ class ForNode : public StmtNode {
v->Visit("body", &body); v->Visit("body", &body);
} }
bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return
equal.DefEqual(loop_var, other->loop_var) &&
equal(min, other->min) &&
equal(extent, other->extent) &&
equal(for_type, other->for_type) &&
equal(device_api, other->device_api) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "For"; static constexpr const char* _type_key = "For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
}; };
...@@ -587,6 +682,14 @@ class PrefetchNode : public StmtNode { ...@@ -587,6 +682,14 @@ class PrefetchNode : public StmtNode {
v->Visit("bounds", &bounds); v->Visit("bounds", &bounds);
} }
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(bounds, other->bounds);
}
TVM_DLL static Stmt make(FunctionRef func, TVM_DLL static Stmt make(FunctionRef func,
int value_index, int value_index,
DataType dtype, DataType dtype,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# pylint: disable=unused-import # pylint: disable=unused-import
"""Common data structures across all IR variants.""" """Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .base import structural_equal, assert_structural_equal
from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType from .tensor_type import TensorType
......
...@@ -149,3 +149,76 @@ def save_json(node): ...@@ -149,3 +149,76 @@ def save_json(node):
Saved json string. Saved json string.
""" """
return tvm.runtime._ffi_node_api.SaveJSON(node) return tvm.runtime._ffi_node_api.SaveJSON(node)
def structural_equal(lhs, rhs, map_free_vars=False):
"""Check structural equality of lhs and rhs.
The structural equality is recursively defined in the DAG of IRNodes.
There are two kinds of nodes:
- Graph node: a graph node in lhs can only be mapped as equal to
one and only one graph node in rhs.
- Normal node: equality is recursively defined without the restriction
of graph nodes.
Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
with the same type if one of the following condition holds:
- They appear in a same definition point(e.g. function argument).
- They points to the same VarNode via the same_as relation.
- They appear in a same usage point, and map_free_vars is set to be True.
The rules for var are used to remap variables occurs in function
arguments and let-bindings.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Return
------
result : bool
The comparison result.
"""
return tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars)
def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""Assert lhs and rhs are structurally equal to each other.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Raises
------
ValueError : if assertion does not hold.
See Also
--------
structural_equal
"""
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, True, map_free_vars)
...@@ -105,6 +105,7 @@ TVM_REGISTER_GLOBAL("ir.FloatImm") ...@@ -105,6 +105,7 @@ TVM_REGISTER_GLOBAL("ir.FloatImm")
TVM_REGISTER_NODE_TYPE(FloatImmNode); TVM_REGISTER_NODE_TYPE(FloatImmNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get()); auto* op = static_cast<const FloatImmNode*>(node.get());
...@@ -143,17 +144,14 @@ TVM_REGISTER_GLOBAL("ir.Range") ...@@ -143,17 +144,14 @@ TVM_REGISTER_GLOBAL("ir.Range")
*ret = Range(args[0], args[1]); *ret = Range(args[0], args[1]);
}); });
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get()); auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
}); });
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
GlobalVar::GlobalVar(std::string name_hint) { GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>(); ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
......
...@@ -65,6 +65,21 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions, ...@@ -65,6 +65,21 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
data_ = std::move(n); data_ = std::move(n);
} }
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}
if (type_definitions.size() != other->type_definitions.size()) return false;
for (const auto& kv : this->type_definitions) {
if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
return true;
}
bool IRModuleNode::ContainGlobalVar(const std::string& name) const { bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end(); return global_var_map_.find(name) != global_var_map_.end();
} }
...@@ -305,8 +320,8 @@ IRModule IRModule::FromExpr( ...@@ -305,8 +320,8 @@ IRModule IRModule::FromExpr(
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) { const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions); auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func; BaseFunc func;
if (auto* func_node = expr.as<relay::FunctionNode>()) { if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<relay::Function>(func_node); func = GetRef<BaseFunc>(func_node);
} else { } else {
func = relay::Function( func = relay::Function(
relay::FreeVars(expr), expr, Type(), relay::FreeVars(expr), expr, Type(),
......
...@@ -21,11 +21,98 @@ ...@@ -21,11 +21,98 @@
* \file src/node/container.cc * \file src/node/container.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <cstring>
namespace tvm { namespace tvm {
// SEQualReduce traits for runtime containers.
struct StringObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::StringObj* lhs,
const runtime::StringObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->size != rhs->size) return false;
if (lhs->data != rhs->data) return true;
return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::ADTObj* lhs,
const runtime::ADTObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->tag != rhs->tag) return false;
if (lhs->size != rhs->size) return false;
for (uint32_t i = 0; i < lhs->size; ++i) {
if (!equal((*lhs)[i], (*rhs)[i])) return false;
}
return true;
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
struct NDArrayContainerTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
auto ldt = lhs->dl_tensor.dtype;
auto rdt = rhs->dl_tensor.dtype;
CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK(runtime::IsContiguous(lhs->dl_tensor))
<< "Can only compare contiguous tensor";
CHECK(runtime::IsContiguous(rhs->dl_tensor))
<< "Can only compare contiguous tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
} else {
return false;
}
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
struct ArrayNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const ArrayNode* lhs,
const ArrayNode* rhs,
SEqualReducer equal) {
if (lhs->data.size() != rhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(ArrayNode);
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<ArrayNode>();
});
TVM_REGISTER_GLOBAL("node.Array") TVM_REGISTER_GLOBAL("node.Array")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<ObjectRef> data; std::vector<ObjectRef> data;
...@@ -62,6 +149,59 @@ TVM_REGISTER_GLOBAL("node.ArraySize") ...@@ -62,6 +149,59 @@ TVM_REGISTER_GLOBAL("node.ArraySize")
static_cast<const ArrayNode*>(ptr)->data.size()); static_cast<const ArrayNode*>(ptr)->data.size());
}); });
struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const MapNode* lhs,
const MapNode* rhs,
SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->data.find(rhs_key);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(MapNode);
TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<MapNode>();
});
struct StrMapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const StrMapNode* lhs,
const StrMapNode* rhs,
SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(StrMapNode);
TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<StrMapNode>();
});
TVM_REGISTER_GLOBAL("node.Map") TVM_REGISTER_GLOBAL("node.Map")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0); CHECK_EQ(args.size() % 2, 0);
......
...@@ -180,7 +180,7 @@ ObjectPtr<Object> ...@@ -180,7 +180,7 @@ ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key, ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const { const std::string& global_key) const {
uint32_t tindex = Object::TypeKey2Index(type_key); uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE"; << " is not registered via TVM_REGISTER_NODE_TYPE";
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/node/structural_equal.cc
*/
#include <tvm/node/structural_equal.h>
#include <tvm/node/reflection.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
namespace tvm {
// Define the dispatch functio here since primary user is in this file.
bool ReflectionVTable::
SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const {
uint32_t tindex = self->type_index();
if (tindex >= fsequal_.size() || fsequal_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fsequal_[tindex](self, other, equal);
}
/*!
* \brief A non recursive stack based SEqual handler that can remaps vars.
*
* This handler pushs the Object equality cases into a stack, and
* traverses the stack to expand the necessary children that need to be checked.
*
* The order of SEqual being called is the same as the order as if we
* eagerly do recursive calls in SEqualReduce.
*/
class RemapVarSEqualHandler :
public SEqualReducer::Handler {
public:
explicit RemapVarSEqualHandler(bool assert_mode)
: assert_mode_(assert_mode) {}
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
// We cannot use check lhs.same_as(rhs) to check equality.
// if we choose to enable var remapping.
//
// Counter example below (%x, %y) are shared vars
// between the two functions(possibly before/after rewriting).
//
// - function0: fn (%x, %y) { %x + %y }
// - function1. fn (%y, %x) { %x + %y }
//
// Because we choose to enable var remapping,
// %x is mapped to %y, and %y is mapped to %x,
// the body of the function no longer means the same thing.
//
// Take away: We can either choose only compare Var by address,
// in which case we can use same_as for quick checking,
// or we have to run deep comparison and avoid to use same_as checks.
auto run = [=]() {
if (!lhs.defined() && !rhs.defined()) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
// need to push to pending tasks in this case
pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars));
return true;
};
return CheckResult(run(), lhs, rhs);
}
void MarkGraphNode() final {
// need to push to pending tasks in this case
CHECK(!allow_push_to_stack_ && !task_stack_.empty());
task_stack_.back().graph_equal = true;
}
ObjectRef MapLhsToRhs(const ObjectRef& lhs) final {
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) return it->second;
return ObjectRef(nullptr);
}
// Function that implements actual equality check.
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
task_stack_.clear();
pending_tasks_.clear();
equal_map_lhs_.clear();
equal_map_rhs_.clear();
if (!SEqualReduce(lhs, rhs, map_free_vars)) return false;
CHECK_EQ(pending_tasks_.size(), 1U);
CHECK(allow_push_to_stack_);
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.clear();
return RunTasks();
}
protected:
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_ && !result) {
LOG(FATAL)
<< "ValueError: StructuralEqual check failed, caused by\n"
<< "lhs = " << lhs << "\nrhs = " << rhs;
}
return result;
}
/*!
* \brief Run tasks until the stack reaches the stack begin
* \param stack_begin The expected beginning of the stack.
* \return The checks we encountered throughout the process.
*/
bool RunTasks() {
while (task_stack_.size() != 0) {
// Caution: entry becomes invalid when the stack changes
auto& entry = task_stack_.back();
if (entry.children_expanded) {
// When all the children has expanded and visited.
// This means all the condition checks for
// the current entry has been passed
// We can safely mark lhs and rhs as equal to each other.
auto it = equal_map_lhs_.find(entry.lhs);
if (it != equal_map_lhs_.end()) {
CHECK(it->second.same_as(entry.rhs));
}
// create the map if the quality is graph equal.
if (entry.graph_equal) {
equal_map_lhs_[entry.lhs] = entry.rhs;
equal_map_rhs_[entry.rhs] = entry.lhs;
}
task_stack_.pop_back();
} else {
// mark before expand
// Important: because entry becomes invalid when stack changes.
entry.children_expanded = true;
// Expand the objects
// The SEqual of the object can call into this->SEqualReduce
// which populates the pending tasks.
CHECK_EQ(pending_tasks_.size(), 0U);
allow_push_to_stack_ = false;
if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false;
allow_push_to_stack_ = true;
// Push pending tasks in reverse order, so earlier tasks get to
// expand first in the stack
while (pending_tasks_.size() != 0) {
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.pop_back();
}
}
}
return true;
}
// The default equal as registered in the structural equal vtable.
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
auto compute = [=]() {
CHECK(lhs.defined() &&
rhs.defined() &&
lhs->type_index() == rhs->type_index());
// skip entries that already have equality maps.
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
// Run reduce check for free nodes.
return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars));
};
return CheckResult(compute(), lhs, rhs);
}
private:
/*! \brief Pending reduce tasks. */
struct Task {
/*! \brief The lhs operand to be compared. */
ObjectRef lhs;
/*! \brief The rhs operand to be compared. */
ObjectRef rhs;
/*! \brief The map free var argument. */
bool map_free_vars;
/*! \brief Whether the children has been expanded via SEqualReduce */
bool children_expanded{false};
/*! \brief whether the task is about graph equality(need remap). */
bool graph_equal{false};
Task() = default;
Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars)
: lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {}
};
// list of pending tasks to be pushed to the stack.
std::vector<Task> pending_tasks_;
// Internal task stack to executed the task.
std::vector<Task> task_stack_;
// Whether we allow push to stack.
bool allow_push_to_stack_{true};
// If in assert mode, must return true, and will throw error otherwise.
bool assert_mode_{false};
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_lhs_;
// map from rhs to lhs
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
};
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs,
const ObjectRef& rhs,
bool assert_mode,
bool map_free_vars) {
return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
});
bool StructuralEqual::operator()(const ObjectRef& lhs,
const ObjectRef& rhs) const {
return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
}
} // namespace tvm
...@@ -81,7 +81,8 @@ TVM_REGISTER_GLOBAL("tir.Var") ...@@ -81,7 +81,8 @@ TVM_REGISTER_GLOBAL("tir.Var")
TVM_REGISTER_GLOBAL("tir.SizeVar") TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t); return SizeVar(s, t);
}); });
IterVar IterVarNode::make(Range dom, IterVar IterVarNode::make(Range dom,
Var var, Var var,
...@@ -132,6 +133,7 @@ PrimExpr StringImmNode::make(std::string value) { ...@@ -132,6 +133,7 @@ PrimExpr StringImmNode::make(std::string value) {
TVM_REGISTER_GLOBAL("tir.StringImm") TVM_REGISTER_GLOBAL("tir.StringImm")
.set_body_typed(StringImmNode::make); .set_body_typed(StringImmNode::make);
PrimExpr CastNode::make(DataType t, PrimExpr value) { PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined()); CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes()); CHECK_EQ(t.lanes(), value.dtype().lanes());
...@@ -141,6 +143,7 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) { ...@@ -141,6 +143,7 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) {
return PrimExpr(node); return PrimExpr(node);
} }
PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(b.defined()) << "ValueError: b is undefined";
...@@ -169,6 +172,7 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { ...@@ -169,6 +172,7 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) {
return PrimExpr(node); return PrimExpr(node);
} }
PrimExpr NotNode::make(PrimExpr a) { PrimExpr NotNode::make(PrimExpr a) {
CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool()); CHECK(a.dtype().is_bool());
...@@ -179,6 +183,8 @@ PrimExpr NotNode::make(PrimExpr a) { ...@@ -179,6 +183,8 @@ PrimExpr NotNode::make(PrimExpr a) {
return PrimExpr(node); return PrimExpr(node);
} }
PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined";
......
...@@ -1114,7 +1114,7 @@ def test_read_variable_op(): ...@@ -1114,7 +1114,7 @@ def test_read_variable_op():
num_output=len(out_name)) num_output=len(out_name))
for i in range(len(tf_output)): for i in range(len(tf_output)):
tvm.testing.assert_allclose( tvm.testing.assert_allclose(
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
sess.close() sess.close()
......
...@@ -57,14 +57,14 @@ def run_opt_pass(expr, opt_pass): ...@@ -57,14 +57,14 @@ def run_opt_pass(expr, opt_pass):
def test_let(): def test_let():
orig = relay.Let(e.x, e.y, e.z) orig = relay.Let(e.x, e.y, e.z)
orig = run_opt_pass(orig, transform.DeadCodeElimination()) orig = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z)) assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
def test_used_let(): def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c) orig = relay.Let(e.c, e.one, e.c + e.c)
orig = run_opt_pass(orig, transform.DeadCodeElimination()) orig = run_opt_pass(orig, transform.DeadCodeElimination())
expected = relay.Let(e.c, e.one, e.c + e.c) expected = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(Function([e.c], orig), Function([e.c], expected)) assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
def test_inline(): def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
...@@ -75,7 +75,7 @@ def test_inline(): ...@@ -75,7 +75,7 @@ def test_inline():
def test_chain_unused_let(): def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
orig = run_opt_pass(orig, transform.DeadCodeElimination()) orig = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e)) assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
def use_f(func): def use_f(func):
...@@ -111,13 +111,13 @@ def test_recursion_dead(): ...@@ -111,13 +111,13 @@ def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three) x = relay.Let(e.a, e.one, e.three)
dced_f = lambda f: x dced_f = lambda f: x
dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination()) dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
assert alpha_equal(dced, e.three) assert tvm.ir.structural_equal(dced, e.three)
def test_op_let(): def test_op_let():
dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two),
transform.DeadCodeElimination()) transform.DeadCodeElimination())
assert alpha_equal(dced, add(e.three, e.two)) assert tvm.ir.structural_equal(dced, add(e.three, e.two))
def test_tuple_get_item(): def test_tuple_get_item():
...@@ -126,10 +126,10 @@ def test_tuple_get_item(): ...@@ -126,10 +126,10 @@ def test_tuple_get_item():
a = relay.Var('a') a = relay.Var('a')
g = relay.TupleGetItem(t, 0) g = relay.TupleGetItem(t, 0)
dced = run_opt_pass(g, transform.DeadCodeElimination()) dced = run_opt_pass(g, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0) orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
dced = run_opt_pass(orig, transform.DeadCodeElimination()) dced = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
......
...@@ -72,7 +72,7 @@ def test_tuple(): ...@@ -72,7 +72,7 @@ def test_tuple():
f = Function([x], body, None, [t]) f = Function([x], body, None, [t])
expected = relay.Function([x], x, None, [t]) expected = relay.Function([x], x, None, [t])
expected = run_opt_pass(expected, transform.InferType()) expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(dcpe(f), expected) assert tvm.ir.structural_equal(dcpe(f), expected)
def test_const_inline(): def test_const_inline():
...@@ -80,7 +80,7 @@ def test_const_inline(): ...@@ -80,7 +80,7 @@ def test_const_inline():
d = Var("d", t) d = Var("d", t)
double = Function([d], d + d) double = Function([d], d + d)
orig = double(const(4.0)) orig = double(const(4.0))
assert alpha_equal(dcpe(orig), const(8.0)) assert tvm.ir.structural_equal(dcpe(orig), const(8.0))
def test_ref(): def test_ref():
...@@ -93,7 +93,7 @@ def test_ref(): ...@@ -93,7 +93,7 @@ def test_ref():
body = Let(r, RefCreate(d), body) body = Let(r, RefCreate(d), body)
square = Function([d], body) square = Function([d], body)
expected = run_opt_pass(Function([d], d * d), transform.InferType()) expected = run_opt_pass(Function([d], d * d), transform.InferType())
assert alpha_equal(dcpe(square), expected) assert tvm.ir.structural_equal(dcpe(square), expected)
def test_empty_ad(): def test_empty_ad():
...@@ -105,7 +105,7 @@ def test_empty_ad(): ...@@ -105,7 +105,7 @@ def test_empty_ad():
g = dcpe(f, grad=True) g = dcpe(f, grad=True)
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
expected = run_opt_pass(expected, transform.InferType()) expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(g, expected) assert tvm.ir.structural_equal(g, expected)
def test_ad(): def test_ad():
...@@ -180,7 +180,7 @@ def test_head_cons(): ...@@ -180,7 +180,7 @@ def test_head_cons():
body = hd(p.cons(x, p.nil())) body = hd(p.cons(x, p.nil()))
f = Function([x], body, None, [t]) f = Function([x], body, None, [t])
res = dcpe(f, mod) res = dcpe(f, mod)
assert alpha_equal(res, Function([x], x, t, [t])) assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
def test_map(): def test_map():
...@@ -197,7 +197,7 @@ def test_map(): ...@@ -197,7 +197,7 @@ def test_map():
expected = mod["main"] expected = mod["main"]
orig = Function([], orig) orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, expected.body) assert tvm.ir.structural_equal(res.body, expected.body)
def test_loop(): def test_loop():
...@@ -211,7 +211,7 @@ def test_loop(): ...@@ -211,7 +211,7 @@ def test_loop():
expected = mod["main"].body expected = mod["main"].body
call = Function([], loop(const(1))) call = Function([], loop(const(1)))
res = dcpe(call, mod=mod) res = dcpe(call, mod=mod)
assert alpha_equal(res.body, expected) assert tvm.ir.structural_equal(res.body, expected)
def test_swap_loop(): def test_swap_loop():
...@@ -226,7 +226,7 @@ def test_swap_loop(): ...@@ -226,7 +226,7 @@ def test_swap_loop():
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2)) prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
res = Function([], prog) res = Function([], prog)
res = dcpe(res, mod=mod) res = dcpe(res, mod=mod)
assert alpha_equal(prog, res.body) assert tvm.ir.structural_equal(prog, res.body)
def test_abs_diff(): def test_abs_diff():
...@@ -248,7 +248,7 @@ def test_abs_diff(): ...@@ -248,7 +248,7 @@ def test_abs_diff():
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3)) orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
orig = Function([], orig) orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 4)) assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4))
def test_match_nat_id(): def test_match_nat_id():
...@@ -265,7 +265,7 @@ def test_match_nat_id(): ...@@ -265,7 +265,7 @@ def test_match_nat_id():
orig = nat_id(make_nat_expr(p, 3)) orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig) orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3)) assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_nat_id(): def test_nat_id():
...@@ -280,7 +280,7 @@ def test_nat_id(): ...@@ -280,7 +280,7 @@ def test_nat_id():
orig = nat_id(make_nat_expr(p, 3)) orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig) orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3)) assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_global_match_nat_id(): def test_global_match_nat_id():
...@@ -294,7 +294,7 @@ def test_global_match_nat_id(): ...@@ -294,7 +294,7 @@ def test_global_match_nat_id():
orig = Match(make_nat_expr(p, 3), [z_case, s_case]) orig = Match(make_nat_expr(p, 3), [z_case, s_case])
orig = Function([], orig) orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3)) assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_double(): def test_double():
...@@ -304,7 +304,7 @@ def test_double(): ...@@ -304,7 +304,7 @@ def test_double():
orig = p.double(make_nat_expr(p, 3)) orig = p.double(make_nat_expr(p, 3))
orig = Function([], orig) orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 6)) assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6))
def test_concat(): def test_concat():
......
...@@ -134,7 +134,7 @@ def test_qnn_legalize_qnn_conv2d(): ...@@ -134,7 +134,7 @@ def test_qnn_legalize_qnn_conv2d():
# Since same dtype, there should not be any transformation # Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod) assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################ ################################################################
# Check transformations for platforms without fast Int8 support. # Check transformations for platforms without fast Int8 support.
...@@ -157,7 +157,7 @@ def test_qnn_legalize_qnn_conv2d(): ...@@ -157,7 +157,7 @@ def test_qnn_legalize_qnn_conv2d():
# Check no transformation for Intel VNNI. # Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'): with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod) assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened. # ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
...@@ -221,7 +221,7 @@ def test_qnn_legalize_qnn_dense(): ...@@ -221,7 +221,7 @@ def test_qnn_legalize_qnn_dense():
# Since same dtype, there should not be any transformation # Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod) assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################ ################################################################
# Check transformations for platforms without fast Int8 support. # Check transformations for platforms without fast Int8 support.
...@@ -244,7 +244,7 @@ def test_qnn_legalize_qnn_dense(): ...@@ -244,7 +244,7 @@ def test_qnn_legalize_qnn_dense():
# Check no transformation for Intel VNNI. # Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'): with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod) assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened. # ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
......
...@@ -76,7 +76,7 @@ def test_order(): ...@@ -76,7 +76,7 @@ def test_order():
expected_output = relay.Let(b, y, expected_output) expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output) expected_output = relay.Let(a, x, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType()) expected_output = run_opt_pass(expected_output, transform.InferType())
assert alpha_equal(anf, expected_output) assert tvm.ir.structural_equal(anf, expected_output)
def test_if(): def test_if():
...@@ -93,7 +93,7 @@ def test_if(): ...@@ -93,7 +93,7 @@ def test_if():
expected_output = relay.Let(d, expected_output, d) expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output) expected_output = relay.Let(c, cond, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType()) expected_output = run_opt_pass(expected_output, transform.InferType())
assert alpha_equal(anf, expected_output) assert tvm.ir.structural_equal(anf, expected_output)
# make sure we dont infinite loop. # make sure we dont infinite loop.
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay.analysis import detect_feature
from tvm.relay.transform import to_cps, un_cps from tvm.relay.transform import to_cps, un_cps
from tvm.relay.analysis import Feature from tvm.relay.analysis import Feature
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
......
...@@ -21,7 +21,6 @@ import tvm ...@@ -21,7 +21,6 @@ import tvm
from tvm import te from tvm import te
from tvm import relay from tvm import relay
from tvm.relay import op, transform, analysis from tvm.relay import op, transform, analysis
from tvm.relay.analysis import assert_alpha_equal
def run_infer_type(expr, mod=None): def run_infer_type(expr, mod=None):
...@@ -360,7 +359,7 @@ def test_let_polymorphism(): ...@@ -360,7 +359,7 @@ def test_let_polymorphism():
body = relay.Let(id, relay.Function([x], x, xt, [xt]), body) body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
body = run_infer_type(body) body = run_infer_type(body)
int32 = relay.TensorType((), "int32") int32 = relay.TensorType((), "int32")
assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -25,7 +25,7 @@ def test_const_saveload_json(): ...@@ -25,7 +25,7 @@ def test_const_saveload_json():
z = z + z z = z + z
json_str = tvm.ir.save_json(z) json_str = tvm.ir.save_json(z)
zz = tvm.ir.load_json(json_str) zz = tvm.ir.load_json(json_str)
assert tvm.ir.save_json(zz) == tvm.ir.save_json(z) tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
def test_make_smap(): def test_make_smap():
...@@ -38,6 +38,7 @@ def test_make_smap(): ...@@ -38,6 +38,7 @@ def test_make_smap():
arr = tvm.ir.load_json(json_str) arr = tvm.ir.load_json(json_str)
assert len(arr) == 1 assert len(arr) == 1
assert arr[0]["z"].a == arr[0]["x"] assert arr[0]["z"].a == arr[0]["x"]
tvm.ir.assert_structural_equal(arr, [smap], map_free_vars=True)
def test_make_node(): def test_make_node():
...@@ -90,7 +91,6 @@ def test_env_func(): ...@@ -90,7 +91,6 @@ def test_env_func():
if __name__ == "__main__": if __name__ == "__main__":
test_env_func() test_env_func()
test_make_attrs()
test_make_node() test_make_node()
test_make_smap() test_make_smap()
test_const_saveload_json() test_const_saveload_json()
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import pytest
from tvm import te
def test_exprs():
# save load json
x = tvm.tir.const(1, "int32")
y = tvm.tir.const(10, "int32")
vx = te.var("x")
vy = te.var("y")
vz = te.var("z")
# test assert trigger.
with pytest.raises(ValueError):
tvm.ir.assert_structural_equal(x, y)
assert not tvm.ir.structural_equal(vx, vy)
assert tvm.ir.structural_equal(vx, vy, map_free_vars=True)
# corner case lhs:vx == rhs:vy, but cannot map it iteslf
assert not tvm.ir.structural_equal(vx + vx, vy + vx, map_free_vars=True)
# corner case lhs:vx == rhs:vy, lhs:vy == rhs:vx
assert tvm.ir.structural_equal(vx + vy, vy + vx, map_free_vars=True)
# corner case2: rolling remap.
assert tvm.ir.structural_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
assert not tvm.ir.structural_equal(vx + 1, vy + 1, map_free_vars=False)
# Defintition remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx - 1),
tvm.tir.Let(vy, 1, vy - 1))
# Default same address free var remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx // vz),
tvm.tir.Let(vy, 1, vy // vz))
zx = vx + vx
zy = vy + vy
assert tvm.ir.structural_equal(zx * zx, zx * zx)
assert tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=True)
assert not tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=False)
assert tvm.ir.structural_equal(zx * zx, (vx + vx) * (vx + vx),
map_free_vars=False)
def test_prim_func():
x = te.var('x')
y = te.var('y')
# counter example of same equality
func0 = tvm.tir.PrimFunc(
[x, y], tvm.tir.Evaluate(x + y))
func1 = tvm.tir.PrimFunc(
[x, y], tvm.tir.Evaluate(y + x))
assert not tvm.ir.structural_equal(func0, func1)
# new cases
b = tvm.tir.decl_buffer((x,), "float32")
stmt = tvm.tir.LetStmt(
x, 10, tvm.tir.Evaluate(x + 1))
func0 = tvm.tir.PrimFunc(
[x, y, b], stmt)
# easiest way to deep copy is via save/load
func1 = tvm.ir.load_json(tvm.ir.save_json(func0))
tvm.ir.assert_structural_equal(func0, func1)
data0 = tvm.nd.array([1, 2, 3])
data1 = tvm.nd.array([1, 2, 3])
# attributes and ndarrays
func0 = func0.with_attr("data", data0)
func1 = func1.with_attr("data", data1)
# IRModules
mod0 = tvm.IRModule.from_expr(func0)
mod1 = tvm.IRModule.from_expr(func1)
tvm.ir.assert_structural_equal(mod0, mod1)
def test_attrs():
x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
z = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx")
tvm.ir.assert_structural_equal(y, x)
assert not tvm.ir.structural_equal(y, z)
if __name__ == "__main__":
test_exprs()
test_prim_func()
test_attrs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment