Unverified Commit 997a14ed by Tianqi Chen Committed by GitHub

[NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)

* [NODE][IR] Introduce StructuralEqual Infra for the Unified IR.

This PR introduces a new way to handle structural equality
for both TIR and relay nodes in an extensive way.

- Each object can now register an optional SEqualReduce function, which
  describes how to reduce its structural equality to another instance
  into equality of the children.
- Optionally, the object can choose to allow remapping of vars(e.g. function parameters)
  by calling DefEqual
- We implemented a non-recursive structural equality checker that
  recursively traverses the objects and does the structural equality checking.

This PR also fixes a few potential problems in previous relay's AlphaEqual.

- In particular, the new structural equality relation will be communicative.
- It is can be dangerous to use same_as relation to quickly check equality,
  demonstrated by the following case. (%x, %y) are shared vars between two functions.

- function0: fn (%x, %y) { %x + %y }
- function1: fn (%y, %x) { %x + %y }

The new structural equal is intented to supersede AlphaEqual and AttrsEqual.

Follow-up PRs should be performed to redirect the existing usages, and removes
the corresponding implementation.

* Update the rule to distinguish between graph node and non-graph nodes.

* Refactor the test cases to use structural equal.

* address comments

* Mark more relay::Expr as graph node, fix a testcase issue(was bug that was not caught by previous alpha equal)

* Remove unrelated comment

* Fix file comment

* Address review comment

* Relax condition to fit flaky case
parent 9c806621
......@@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object {
v->Visit("max_value", &max_value);
}
bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
return equal(min_value, other->min_value) && equal(max_value, other->max_value);
}
/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*!
......@@ -170,6 +174,10 @@ class ModularSetNode : public Object {
v->Visit("base", &base);
}
bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
return equal(coeff, other->coeff) && equal(base, other->base);
}
static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};
......
......@@ -59,6 +59,7 @@ enum SignType {
class IntSetNode : public Object {
public:
static constexpr const char* _type_key = "IntSet";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};
......
......@@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
// Use namehint for now to be consistent with the legacy relay impl
// TODO(tvm-team) revisit, need to check the type var.
return
equal(name_hint, other->name_hint) &&
equal(inputs, other->inputs);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
......@@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
return
equal.DefEqual(header, other->header) &&
equal.DefEqual(type_vars, other->type_vars) &&
equal(constructors, other->constructors);
}
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
......
......@@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object {
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
......@@ -278,6 +280,7 @@ class BaseAttrsNode : public Object {
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
......@@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode {
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
return equal(dict, other->dict);
}
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
......@@ -401,6 +408,33 @@ class AttrsEqualVisitor {
const AttrsEqual& equal_;
};
class AttrsSEqualVisitor {
public:
bool result_{true};
// constructor
AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}
private:
const Object* lhs_;
const Object* rhs_;
const SEqualReducer& equal_;
};
class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
......@@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode {
}
}
bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
DerivedType* pself = self();
::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
......
......@@ -51,7 +51,12 @@ class EnvFuncNode : public Object {
v->Visit("name", &name);
}
bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
return this == other;
}
static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};
......
......@@ -43,6 +43,7 @@ namespace tvm {
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
......@@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
// name matters for global var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
......@@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode {
v->Visit("value", &value);
}
bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
......@@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode {
v->Visit("value", &value);
}
bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
......@@ -353,7 +369,12 @@ class RangeNode : public Object {
v->Visit("extent", &extent);
}
bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
return equal(min, other->min) && equal(extent, other->extent);
}
static constexpr const char* _type_key = "Range";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};
......
......@@ -62,6 +62,8 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
}
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
......@@ -235,6 +237,7 @@ class IRModuleNode : public Object {
TVM_DLL std::unordered_set<std::string> Imports() const;
static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private:
......
......@@ -101,6 +101,11 @@ class OpNode : public RelayExprNode {
v->Visit("support_level", &support_level);
}
bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
// pointer equality is fine as there is only one op with the same name.
return this == other;
}
/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
......
......@@ -44,6 +44,10 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const {
return equal(name, other->name);
}
static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
......@@ -87,6 +91,13 @@ class SpanNode : public Object {
v->Visit("col_offset", &col_offset);
}
bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
return
equal(source, other->source) &&
equal(lineno, other->lineno) &&
equal(col_offset, other->col_offset);
}
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "Span";
......
......@@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
return
equal(shape, other->shape) &&
equal(dtype, other->dtype);
}
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
......
......@@ -111,6 +111,7 @@ class PassContextNode : public Object {
}
static constexpr const char* _type_key = "transform.PassContext";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
......@@ -207,6 +208,7 @@ class PassInfoNode : public Object {
}
static constexpr const char* _type_key = "transform.PassInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};
......
......@@ -79,6 +79,7 @@ class TypeNode : public Object {
mutable Span span;
static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
......@@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}
bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
......@@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode {
v->Visit("element_type", &element_type);
}
bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
}
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
......@@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
return
equal(kind, other->kind) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
......@@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}
bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
// name matters for now in global type var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
......@@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
return equal(fields, other->fields);
}
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
......@@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
// type params first as they defines type vars.
return
equal.DefEqual(type_params, other->type_params) &&
equal(arg_types, other->arg_types) &&
equal(ret_type, other->ret_type) &&
equal(type_constraints, other->type_constraints);
}
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
......@@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind);
}
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
......@@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
......
......@@ -50,6 +50,12 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(args, other->args);
}
static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
......@@ -195,6 +201,14 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(args, other->args) &&
equal(num_inputs, other->num_inputs) &&
equal(attrs, other->attrs);
}
static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
......
......@@ -23,7 +23,9 @@
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
#include <tvm/node/node.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <type_traits>
#include <vector>
......@@ -34,15 +36,19 @@
namespace tvm {
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
using runtime::make_object;
using runtime::ObjectHash;
using runtime::ObjectEqual;
/*! \brief array node content in array */
class ArrayNode : public Object {
public:
/*! \brief the data content */
std::vector<ObjectRef> data;
void VisitAttrs(AttrVisitor* visitor) {
}
static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
};
......@@ -50,9 +56,6 @@ class ArrayNode : public Object {
/*! \brief map node content */
class MapNode : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
ObjectRef,
......@@ -73,9 +76,6 @@ class StrMapNode : public Object {
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>;
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief the data content */
ContainerType data;
......
......@@ -39,6 +39,8 @@
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/container.h>
#include <tvm/node/structural_equal.h>
#include <string>
#include <vector>
......
......@@ -29,13 +29,14 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
#include <tvm/node/structural_equal.h>
#include <vector>
#include <string>
#include <type_traits>
namespace tvm {
// forward declaration
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
......@@ -87,6 +88,13 @@ class ReflectionVTable {
*/
typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
/*!
* \brief Equality comparison function.
* \note We use function pointer, instead of std::function
* to reduce the dispatch overhead as field visit
* does not need as much customization.
*/
typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object.
......@@ -112,6 +120,14 @@ class ReflectionVTable {
*/
inline std::string GetGlobalKey(Object* self) const;
/*!
* \brief Dispatch the SEqualReduce function.
* \param self The pointer to the object.
* \param other The pointer to another object to be compared.
* \param equal The equality comparator.
* \return the result.
*/
bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
/*!
* \brief Create an initial object using default constructor
* by type_key and global key.
*
......@@ -139,12 +155,14 @@ class ReflectionVTable {
TVM_DLL static ReflectionVTable* Global();
class Registry;
template<typename T>
template<typename T, typename TraitName>
inline Registry Register();
private:
/*! \brief Attribute visitor. */
std::vector<FVisitAttrs> fvisit_attrs_;
/*! \brief Structural equal function. */
std::vector<FSEqualReduce> fsequal_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
......@@ -182,6 +200,44 @@ class ReflectionVTable::Registry {
uint32_t type_index_;
};
#define TVM_REFLECTION_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \
__make_reflectiion
/*!
* \brief Directly register reflection VTable.
* \param TypeName The name of the type.
* \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce.
*
* \code
*
* // Example SEQualReduce traits for runtime StringObj.
*
* struct StringObjTrait {
* static constexpr const std::nullptr_t VisitAttrs = nullptr;
*
* static bool SEqualReduce(const runtime::StringObj* lhs,
* const runtime::StringObj* rhs,
* SEqualReducer equal) {
* if (lhs == rhs) return true;
* if (lhs->size != rhs->size) return false;
* if (lhs->data != rhs->data) return true;
* return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
* }
* };
*
* TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
*
* \endcode
*
* \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE.
* And can be used to register the related reflection functions for runtime objects.
*/
#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>() \
/*!
* \brief Register a node type to object registry and reflection registry.
* \param TypeName The name of the type.
......@@ -189,15 +245,79 @@ class ReflectionVTable::Registry {
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \
__make_Node ## _ ## TypeName ## __ = \
::tvm::ReflectionVTable::Global()->Register<TypeName>() \
.set_creator([](const std::string&) -> ObjectPtr<Object> { \
return ::tvm::runtime::make_object<TypeName>(); \
})
TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
.set_creator([](const std::string&) -> ObjectPtr<Object> { \
return ::tvm::runtime::make_object<TypeName>(); \
})
// Implementation details
namespace detail {
template<typename T,
bool = T::_type_has_method_visit_attrs>
struct ImplVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
template<typename T>
struct ImplVisitAttrs<T, true> {
static void VisitAttrs(T* self, AttrVisitor* v) {
self->VisitAttrs(v);
}
};
template<typename T,
bool = T::_type_has_method_sequal_reduce>
struct ImplSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
template<typename T>
struct ImplSEqualReduce<T, true> {
static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
return self->SEqualReduce(other, equal);
}
};
template<typename T>
struct ReflectionTrait :
public ImplVisitAttrs<T>,
public ImplSEqualReduce<T> {
};
template<typename T, typename TraitName,
bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
struct SelectVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
template<typename T, typename TraitName>
struct SelectVisitAttrs<T, TraitName, false> {
static void VisitAttrs(Object* self, AttrVisitor* v) {
TraitName::VisitAttrs(static_cast<T*>(self), v);
}
};
template<typename T, typename TraitName,
bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
struct SelectSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
template<typename T, typename TraitName>
struct SelectSEqualReduce<T, TraitName, false> {
static bool SEqualReduce(const Object* self,
const Object* other,
SEqualReducer equal) {
return TraitName::SEqualReduce(static_cast<const T*>(self),
static_cast<const T*>(other),
equal);
}
};
} // namespace detail
template<typename T, typename TraitName>
inline ReflectionVTable::Registry
ReflectionVTable::Register() {
uint32_t tindex = T::RuntimeTypeIndex();
......@@ -205,15 +325,15 @@ ReflectionVTable::Register() {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
fsequal_.resize(tindex + 1, nullptr);
}
// functor that implemnts the redirection.
struct Functor {
static void VisitAttrs(Object* self, AttrVisitor* v) {
static_cast<T*>(self)->VisitAttrs(v);
}
};
fvisit_attrs_[tindex] =
::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
fsequal_[tindex] =
::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
fvisit_attrs_[tindex] = Functor::VisitAttrs;
return Registry(this, tindex);
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/structural_equal.h
* \brief Structural equality comparison.
*/
#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
#define TVM_NODE_STRUCTURAL_EQUAL_H_
#include <tvm/runtime/data_type.h>
#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <string>
namespace tvm {
/*!
* \brief Equality definition of base value class.
*/
class BaseValueEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
return lhs == rhs;
}
};
/*!
* \brief Content-aware structural equality comparator for objects.
*
* The structural equality is recursively defined in the DAG of IR nodes via SEqual.
* There are two kinds of nodes:
*
* - Graph node: a graph node in lhs can only be mapped as equal to
* one and only one graph node in rhs.
* - Normal node: equality is recursively defined without the restriction
* of graph nodes.
*
* Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
* For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
* to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
*
* A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
* with the same type if one of the following condition holds:
*
* - They appear in a same definition point(e.g. function argument).
* - They points to the same VarNode via the same_as relation.
* - They appear in a same usage point, and map_free_vars is set to be True.
*/
class StructuralEqual : public BaseValueEqual {
public:
// inheritate operator()
using BaseValueEqual::operator();
/*!
* \brief Compare objects via strutural equal.
* \param lhs The left operand.
* \param rhs The right operand.
* \return The comparison result.
*/
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
};
/*!
* \brief A Reducer class to reduce the structural equality result of two objects.
*
* The reducer will call the SEqualReduce function of each objects recursively.
* Importantly, the reducer may not directly use recursive calls to resolve the
* equality checking. Instead, it can store the necessary equality conditions
* and check later via an internally managed stack.
*/
class SEqualReducer : public BaseValueEqual {
public:
/*! \brief Internal handler that defines custom behaviors.. */
class Handler {
public:
/*!
* \brief Reduce condition to equality of lhs and rhs.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether do we allow remap variables if possible.
*
* \return false if there is an immediate failure, true otherwise.
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
virtual bool SEqualReduce(const ObjectRef& lhs,
const ObjectRef& rhs,
bool map_free_vars) = 0;
/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
* This is an auxiliary method to check the Map<Var, Value> equality.
* \param lhs an lhs value.
*
* \return The corresponding rhs value if any, nullptr if not available.
*/
virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
/*!
* \brief Mark current comparison as graph node equal comparison.
*/
virtual void MarkGraphNode() = 0;
};
using BaseValueEqual::operator();
/*! \brief default constructor */
SEqualReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
* \param map_free_vars Whether or not to map free variables.
*/
explicit SEqualReducer(Handler* handler, bool map_free_vars)
: handler_(handler), map_free_vars_(map_free_vars) {}
/*!
* \brief Reduce condition to comparison of two objects.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
}
/*!
* \brief Reduce condition to comparison of two definitions,
* where free vars can be mapped.
*
* Call this function to compare definition points such as function params
* and var in a let-binding.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
return handler_->SEqualReduce(lhs, rhs, true);
}
/*!
* \brief Reduce condition to comparison of two arrays.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
template<typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
}
return true;
}
/*!
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
* \param lhs The left operand.
* \param rhs The right operand.
* \return the result.
*/
bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
// var need to be remapped, so it belongs to graph node.
handler_->MarkGraphNode();
// We only map free vars if they corresponds to the same address
// or map free_var option is set to be true.
return lhs == rhs || map_free_vars_;
}
/*! \return Get the internal handler. */
Handler* operator->() const {
return handler_;
}
private:
/*! \brief Internal class pointer. */
Handler* handler_;
/*! \brief Whether or not to map free vars. */
bool map_free_vars_;
};
} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_EQUAL_H_
......@@ -46,6 +46,7 @@ using TypeDataNode = tvm::TypeDataNode;
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
};
......@@ -74,6 +75,10 @@ class PatternWildcardNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const {
return true;
}
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
};
......@@ -118,6 +123,10 @@ class PatternVarNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
return equal.DefEqual(var, other->var);
}
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
};
......@@ -149,6 +158,12 @@ class PatternConstructorNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
return
equal(constructor, other->constructor) &&
equal(patterns, other->patterns);
}
static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
};
......@@ -178,6 +193,10 @@ class PatternTupleNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
return equal(patterns, other->patterns);
}
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
};
......@@ -208,7 +227,12 @@ class ClauseNode : public Object {
v->Visit("rhs", &rhs);
}
bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
return equal(lhs, other->lhs) && equal(rhs, other->rhs);
}
static constexpr const char* _type_key = "relay.Clause";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};
......@@ -248,6 +272,14 @@ class MatchNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(data, other->data) &&
equal(clauses, other->clauses) &&
equal(complete, other->complete);
}
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};
......
......@@ -26,6 +26,7 @@
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/op.h>
#include <tvm/ir/module.h>
#include <string>
#include <functional>
......@@ -72,6 +73,10 @@ class ConstantNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
return equal(data, other->data);
}
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
......@@ -101,6 +106,16 @@ class TupleNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
// specially handle empty tuple as a constant is not a graph node.
if (fields.size() == other->fields.size() && fields.size() == 0) {
return true;
} else {
equal->MarkGraphNode();
return equal(fields, other->fields);
}
}
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
......@@ -157,6 +172,12 @@ class VarNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
return
equal(type_annotation, other->type_annotation) &&
equal.FreeVarEqualImpl(this, other);
}
TVM_DLL static Var make(std::string name_hint,
Type type_annotation);
......@@ -238,6 +259,16 @@ class CallNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
// skip type_args check for primitive ops.
equal->MarkGraphNode();
return
equal(op, other->op) &&
equal(args, other->args) &&
equal(attrs, other->attrs) &&
(IsPrimitiveOp(op) || equal(type_args, other->type_args));
}
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
};
......@@ -289,6 +320,14 @@ class LetNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
......@@ -336,6 +375,14 @@ class IfNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(cond, other->cond) &&
equal(true_branch, other->true_branch) &&
equal(false_branch, other->false_branch);
}
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
......@@ -369,6 +416,12 @@ class TupleGetItemNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
return
equal(tuple, other->tuple) &&
equal(index, other->index);
}
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
......@@ -398,6 +451,11 @@ class RefCreateNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(value, other->value);
}
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
......@@ -426,6 +484,11 @@ class RefReadNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(ref, other->ref);
}
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
......@@ -456,6 +519,13 @@ class RefWriteNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(ref, other->ref) &&
equal(value, other->value);
}
TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite";
......@@ -497,6 +567,7 @@ class TempExprNode : public ExprNode {
virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};
......
......@@ -68,6 +68,17 @@ class FunctionNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
// Important to make def equal first.
equal->MarkGraphNode();
return
equal.DefEqual(params, other->params) &&
equal.DefEqual(type_params, other->type_params) &&
equal(ret_type, other->ret_type) &&
equal(attrs, other->attrs) &&
equal(body, other->body);
}
/*!
* \brief Return the derived function annotation of this expression.
*
......
......@@ -65,6 +65,8 @@ class NDArray : public ObjectRef {
inline int use_count() const;
/*! \return Pointer to content of DLTensor */
inline const DLTensor* operator->() const;
/*! \return Whether the tensor is contiguous */
inline bool IsContiguous() const;
/*!
* \brief Copy data content from another array.
* \param other The source array to be copied from.
......@@ -313,6 +315,26 @@ inline size_t GetDataSize(const DLTensor& arr) {
return size;
}
/*!
* \brief check if a DLTensor is contiguous.
* \param arr The input DLTensor.
* \return The check result.
*/
inline bool IsContiguous(const DLTensor& arr) {
if (arr.strides == nullptr) return true;
int64_t expected_stride = 1;
for (int32_t i = arr.ndim; i != 0; --i) {
int32_t k = i - 1;
if (arr.strides[k] != expected_stride) return false;
expected_stride *= arr.shape[k];
}
return true;
}
inline bool NDArray::IsContiguous() const {
return ::tvm::runtime::IsContiguous(get_mutable()->dl_tensor);
}
inline void NDArray::CopyFrom(const DLTensor* other) {
CHECK(data_ != nullptr);
CopyFromTo(other, &(get_mutable()->dl_tensor));
......
......@@ -211,11 +211,15 @@ class Object {
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
// member information
static constexpr bool _type_has_method_visit_attrs = true;
static constexpr bool _type_has_method_sequal_reduce = false;
// NOTE: the following field is not type index of Object
// but was intended to be used by sub-classes as default value.
// The type index of Object is TypeIndex::kRoot
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
// Default constructor and copy constructor
Object() {}
// Override the copy and assign constructors to do nothing.
......
......@@ -150,6 +150,20 @@ class BufferNode : public Object {
v->Visit("buffer_type", &buffer_type);
}
bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
// Use DefEqual as buffer can define variables
// in its semantics, skip name as name is not important.
return
equal.DefEqual(data, other->data) &&
equal(dtype, other->dtype) &&
equal.DefEqual(shape, other->shape) &&
equal.DefEqual(strides, other->strides) &&
equal.DefEqual(elem_offset, other->elem_offset) &&
equal(scope, other->scope) &&
equal(data_alignment, other->data_alignment) &&
equal(buffer_type, other->buffer_type);
}
/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
......@@ -169,6 +183,7 @@ class BufferNode : public Object {
BufferType buffer_type);
static constexpr const char* _type_key = "Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
};
......
......@@ -102,6 +102,16 @@ class PrimFuncNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
// visit params and buffer_map first as they contains defs.
return
equal.DefEqual(params, other->params) &&
equal(buffer_map, other->buffer_map) &&
equal(ret_type, other->ret_type) &&
equal(body, other->body) &&
equal(attrs, other->attrs);
}
/*!
* \brief Return the derived function annotation of this function.
*
......@@ -112,6 +122,7 @@ class PrimFuncNode : public BaseFuncNode {
TVM_DLL FuncType func_type_annotation() const;
static constexpr const char* _type_key = "tir.PrimFunc";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};
......
......@@ -38,6 +38,7 @@ namespace tir {
class StmtNode : public Object {
public:
static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};
......@@ -65,6 +66,13 @@ class LetStmtNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
return
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
static constexpr const char* _type_key = "LetStmt";
......@@ -99,6 +107,14 @@ class AttrStmtNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
return
equal(node, other->node) &&
equal(attr_key, other->attr_key) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(ObjectRef node,
std::string type_key,
PrimExpr value,
......@@ -129,6 +145,13 @@ class AssertStmtNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
return
equal(condition, other->condition) &&
equal(message, other->message) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt";
......@@ -152,6 +175,13 @@ class ProducerConsumerNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(is_producer, other->is_producer) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer";
......@@ -194,6 +224,14 @@ class StoreNode : public StmtNode {
v->Visit("predicate", &predicate);
}
bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
return
equal(buffer_var, other->buffer_var) &&
equal(value, other->value) &&
equal(index, other->index) &&
equal(predicate, other->predicate);
}
TVM_DLL static Stmt make(Var buffer_var,
PrimExpr value,
PrimExpr index,
......@@ -224,6 +262,14 @@ class ProvideNode : public StmtNode {
v->Visit("args", &args);
}
bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(value, other->value) &&
equal(args, other->args);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
PrimExpr value,
......@@ -261,6 +307,15 @@ class AllocateNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
return
equal.DefEqual(buffer_var, other->buffer_var) &&
equal(dtype, other->dtype) &&
equal(extents, other->extents) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
......@@ -300,6 +355,11 @@ class FreeNode : public StmtNode {
v->Visit("buffer_var", &buffer_var);
}
bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
return
equal(buffer_var, other->buffer_var);
}
TVM_DLL static Stmt make(Var buffer_var);
static constexpr const char* _type_key = "Free";
......@@ -341,6 +401,16 @@ class RealizeNode : public StmtNode {
PrimExpr condition,
Stmt body);
bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(bounds, other->bounds) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "Realize";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
};
......@@ -369,6 +439,10 @@ class SeqStmtNode : public StmtNode {
v->Visit("seq", &seq);
}
bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
return equal(seq, other->seq);
}
static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};
......@@ -472,6 +546,13 @@ class IfThenElseNode : public StmtNode {
v->Visit("else_case", &else_case);
}
bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
return
equal(condition, other->condition) &&
equal(then_case, other->then_case) &&
equal(else_case, other->else_case);
}
TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
static constexpr const char* _type_key = "IfThenElse";
......@@ -493,6 +574,10 @@ class EvaluateNode : public StmtNode {
v->Visit("value", &value);
}
bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
TVM_DLL static Stmt make(PrimExpr v);
static constexpr const char* _type_key = "Evaluate";
......@@ -562,6 +647,16 @@ class ForNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return
equal.DefEqual(loop_var, other->loop_var) &&
equal(min, other->min) &&
equal(extent, other->extent) &&
equal(for_type, other->for_type) &&
equal(device_api, other->device_api) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
......@@ -587,6 +682,14 @@ class PrefetchNode : public StmtNode {
v->Visit("bounds", &bounds);
}
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(bounds, other->bounds);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType dtype,
......
......@@ -17,6 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .base import structural_equal, assert_structural_equal
from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
......
......@@ -149,3 +149,76 @@ def save_json(node):
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)
def structural_equal(lhs, rhs, map_free_vars=False):
"""Check structural equality of lhs and rhs.
The structural equality is recursively defined in the DAG of IRNodes.
There are two kinds of nodes:
- Graph node: a graph node in lhs can only be mapped as equal to
one and only one graph node in rhs.
- Normal node: equality is recursively defined without the restriction
of graph nodes.
Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
with the same type if one of the following condition holds:
- They appear in a same definition point(e.g. function argument).
- They points to the same VarNode via the same_as relation.
- They appear in a same usage point, and map_free_vars is set to be True.
The rules for var are used to remap variables occurs in function
arguments and let-bindings.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Return
------
result : bool
The comparison result.
"""
return tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars)
def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""Assert lhs and rhs are structurally equal to each other.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Raises
------
ValueError : if assertion does not hold.
See Also
--------
structural_equal
"""
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, True, map_free_vars)
......@@ -45,8 +45,8 @@ class AttrFunctor;
#define ATTR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitAttr_(static_cast<const OP*>(n.get()), \
[](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitAttr_(static_cast<const OP*>(n.get()), \
std::forward<Args>(args)...); \
}); \
......
......@@ -105,6 +105,7 @@ TVM_REGISTER_GLOBAL("ir.FloatImm")
TVM_REGISTER_NODE_TYPE(FloatImmNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
......@@ -143,17 +144,14 @@ TVM_REGISTER_GLOBAL("ir.Range")
*ret = Range(args[0], args[1]);
});
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
......
......@@ -65,6 +65,21 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
data_ = std::move(n);
}
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}
if (type_definitions.size() != other->type_definitions.size()) return false;
for (const auto& kv : this->type_definitions) {
if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
return true;
}
bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
......@@ -305,8 +320,8 @@ IRModule IRModule::FromExpr(
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
if (auto* func_node = expr.as<relay::FunctionNode>()) {
func = GetRef<relay::Function>(func_node);
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
} else {
func = relay::Function(
relay::FreeVars(expr), expr, Type(),
......
......@@ -21,11 +21,98 @@
* \file src/node/container.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <cstring>
namespace tvm {
// SEQualReduce traits for runtime containers.
struct StringObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::StringObj* lhs,
const runtime::StringObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->size != rhs->size) return false;
if (lhs->data != rhs->data) return true;
return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::ADTObj* lhs,
const runtime::ADTObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->tag != rhs->tag) return false;
if (lhs->size != rhs->size) return false;
for (uint32_t i = 0; i < lhs->size; ++i) {
if (!equal((*lhs)[i], (*rhs)[i])) return false;
}
return true;
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
struct NDArrayContainerTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
auto ldt = lhs->dl_tensor.dtype;
auto rdt = rhs->dl_tensor.dtype;
CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK(runtime::IsContiguous(lhs->dl_tensor))
<< "Can only compare contiguous tensor";
CHECK(runtime::IsContiguous(rhs->dl_tensor))
<< "Can only compare contiguous tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
} else {
return false;
}
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
struct ArrayNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const ArrayNode* lhs,
const ArrayNode* rhs,
SEqualReducer equal) {
if (lhs->data.size() != rhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(ArrayNode);
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<ArrayNode>();
});
TVM_REGISTER_GLOBAL("node.Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<ObjectRef> data;
......@@ -62,6 +149,59 @@ TVM_REGISTER_GLOBAL("node.ArraySize")
static_cast<const ArrayNode*>(ptr)->data.size());
});
struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const MapNode* lhs,
const MapNode* rhs,
SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->data.find(rhs_key);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(MapNode);
TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<MapNode>();
});
struct StrMapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const StrMapNode* lhs,
const StrMapNode* rhs,
SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(StrMapNode);
TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<StrMapNode>();
});
TVM_REGISTER_GLOBAL("node.Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
......
......@@ -180,7 +180,7 @@ ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/node/structural_equal.cc
*/
#include <tvm/node/structural_equal.h>
#include <tvm/node/reflection.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
namespace tvm {
// Define the dispatch functio here since primary user is in this file.
bool ReflectionVTable::
SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const {
uint32_t tindex = self->type_index();
if (tindex >= fsequal_.size() || fsequal_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fsequal_[tindex](self, other, equal);
}
/*!
* \brief A non recursive stack based SEqual handler that can remaps vars.
*
* This handler pushs the Object equality cases into a stack, and
* traverses the stack to expand the necessary children that need to be checked.
*
* The order of SEqual being called is the same as the order as if we
* eagerly do recursive calls in SEqualReduce.
*/
class RemapVarSEqualHandler :
public SEqualReducer::Handler {
public:
explicit RemapVarSEqualHandler(bool assert_mode)
: assert_mode_(assert_mode) {}
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
// We cannot use check lhs.same_as(rhs) to check equality.
// if we choose to enable var remapping.
//
// Counter example below (%x, %y) are shared vars
// between the two functions(possibly before/after rewriting).
//
// - function0: fn (%x, %y) { %x + %y }
// - function1. fn (%y, %x) { %x + %y }
//
// Because we choose to enable var remapping,
// %x is mapped to %y, and %y is mapped to %x,
// the body of the function no longer means the same thing.
//
// Take away: We can either choose only compare Var by address,
// in which case we can use same_as for quick checking,
// or we have to run deep comparison and avoid to use same_as checks.
auto run = [=]() {
if (!lhs.defined() && !rhs.defined()) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
// need to push to pending tasks in this case
pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars));
return true;
};
return CheckResult(run(), lhs, rhs);
}
void MarkGraphNode() final {
// need to push to pending tasks in this case
CHECK(!allow_push_to_stack_ && !task_stack_.empty());
task_stack_.back().graph_equal = true;
}
ObjectRef MapLhsToRhs(const ObjectRef& lhs) final {
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) return it->second;
return ObjectRef(nullptr);
}
// Function that implements actual equality check.
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
task_stack_.clear();
pending_tasks_.clear();
equal_map_lhs_.clear();
equal_map_rhs_.clear();
if (!SEqualReduce(lhs, rhs, map_free_vars)) return false;
CHECK_EQ(pending_tasks_.size(), 1U);
CHECK(allow_push_to_stack_);
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.clear();
return RunTasks();
}
protected:
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_ && !result) {
LOG(FATAL)
<< "ValueError: StructuralEqual check failed, caused by\n"
<< "lhs = " << lhs << "\nrhs = " << rhs;
}
return result;
}
/*!
* \brief Run tasks until the stack reaches the stack begin
* \param stack_begin The expected beginning of the stack.
* \return The checks we encountered throughout the process.
*/
bool RunTasks() {
while (task_stack_.size() != 0) {
// Caution: entry becomes invalid when the stack changes
auto& entry = task_stack_.back();
if (entry.children_expanded) {
// When all the children has expanded and visited.
// This means all the condition checks for
// the current entry has been passed
// We can safely mark lhs and rhs as equal to each other.
auto it = equal_map_lhs_.find(entry.lhs);
if (it != equal_map_lhs_.end()) {
CHECK(it->second.same_as(entry.rhs));
}
// create the map if the quality is graph equal.
if (entry.graph_equal) {
equal_map_lhs_[entry.lhs] = entry.rhs;
equal_map_rhs_[entry.rhs] = entry.lhs;
}
task_stack_.pop_back();
} else {
// mark before expand
// Important: because entry becomes invalid when stack changes.
entry.children_expanded = true;
// Expand the objects
// The SEqual of the object can call into this->SEqualReduce
// which populates the pending tasks.
CHECK_EQ(pending_tasks_.size(), 0U);
allow_push_to_stack_ = false;
if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false;
allow_push_to_stack_ = true;
// Push pending tasks in reverse order, so earlier tasks get to
// expand first in the stack
while (pending_tasks_.size() != 0) {
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.pop_back();
}
}
}
return true;
}
// The default equal as registered in the structural equal vtable.
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
auto compute = [=]() {
CHECK(lhs.defined() &&
rhs.defined() &&
lhs->type_index() == rhs->type_index());
// skip entries that already have equality maps.
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
// Run reduce check for free nodes.
return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars));
};
return CheckResult(compute(), lhs, rhs);
}
private:
/*! \brief Pending reduce tasks. */
struct Task {
/*! \brief The lhs operand to be compared. */
ObjectRef lhs;
/*! \brief The rhs operand to be compared. */
ObjectRef rhs;
/*! \brief The map free var argument. */
bool map_free_vars;
/*! \brief Whether the children has been expanded via SEqualReduce */
bool children_expanded{false};
/*! \brief whether the task is about graph equality(need remap). */
bool graph_equal{false};
Task() = default;
Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars)
: lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {}
};
// list of pending tasks to be pushed to the stack.
std::vector<Task> pending_tasks_;
// Internal task stack to executed the task.
std::vector<Task> task_stack_;
// Whether we allow push to stack.
bool allow_push_to_stack_{true};
// If in assert mode, must return true, and will throw error otherwise.
bool assert_mode_{false};
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_lhs_;
// map from rhs to lhs
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
};
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs,
const ObjectRef& rhs,
bool assert_mode,
bool map_free_vars) {
return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
});
bool StructuralEqual::operator()(const ObjectRef& lhs,
const ObjectRef& rhs) const {
return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
}
} // namespace tvm
......@@ -81,7 +81,8 @@ TVM_REGISTER_GLOBAL("tir.Var")
TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
});
});
IterVar IterVarNode::make(Range dom,
Var var,
......@@ -132,6 +133,7 @@ PrimExpr StringImmNode::make(std::string value) {
TVM_REGISTER_GLOBAL("tir.StringImm")
.set_body_typed(StringImmNode::make);
PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
......@@ -141,6 +143,7 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) {
return PrimExpr(node);
}
PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
......@@ -169,6 +172,7 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) {
return PrimExpr(node);
}
PrimExpr NotNode::make(PrimExpr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool());
......@@ -179,6 +183,8 @@ PrimExpr NotNode::make(PrimExpr a) {
return PrimExpr(node);
}
PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
......@@ -270,11 +276,11 @@ bool CallNode::is_vectorizable() const {
}
PrimExpr CallNode::make(DataType dtype,
std::string name,
Array<PrimExpr> args,
CallType call_type,
FunctionRef func,
int value_index) {
std::string name,
Array<PrimExpr> args,
CallType call_type,
FunctionRef func,
int value_index) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined());
}
......
......@@ -1114,7 +1114,7 @@ def test_read_variable_op():
num_output=len(out_name))
for i in range(len(tf_output)):
tvm.testing.assert_allclose(
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
sess.close()
......
......@@ -57,14 +57,14 @@ def run_opt_pass(expr, opt_pass):
def test_let():
orig = relay.Let(e.x, e.y, e.z)
orig = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c)
orig = run_opt_pass(orig, transform.DeadCodeElimination())
expected = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
......@@ -75,7 +75,7 @@ def test_inline():
def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
orig = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
def use_f(func):
......@@ -111,13 +111,13 @@ def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
dced_f = lambda f: x
dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
assert alpha_equal(dced, e.three)
assert tvm.ir.structural_equal(dced, e.three)
def test_op_let():
dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two),
transform.DeadCodeElimination())
assert alpha_equal(dced, add(e.three, e.two))
assert tvm.ir.structural_equal(dced, add(e.three, e.two))
def test_tuple_get_item():
......@@ -126,10 +126,10 @@ def test_tuple_get_item():
a = relay.Var('a')
g = relay.TupleGetItem(t, 0)
dced = run_opt_pass(g, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
dced = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
@pytest.mark.timeout(timeout=10, method="thread")
......
......@@ -72,7 +72,7 @@ def test_tuple():
f = Function([x], body, None, [t])
expected = relay.Function([x], x, None, [t])
expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(dcpe(f), expected)
assert tvm.ir.structural_equal(dcpe(f), expected)
def test_const_inline():
......@@ -80,7 +80,7 @@ def test_const_inline():
d = Var("d", t)
double = Function([d], d + d)
orig = double(const(4.0))
assert alpha_equal(dcpe(orig), const(8.0))
assert tvm.ir.structural_equal(dcpe(orig), const(8.0))
def test_ref():
......@@ -93,7 +93,7 @@ def test_ref():
body = Let(r, RefCreate(d), body)
square = Function([d], body)
expected = run_opt_pass(Function([d], d * d), transform.InferType())
assert alpha_equal(dcpe(square), expected)
assert tvm.ir.structural_equal(dcpe(square), expected)
def test_empty_ad():
......@@ -105,7 +105,7 @@ def test_empty_ad():
g = dcpe(f, grad=True)
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(g, expected)
assert tvm.ir.structural_equal(g, expected)
def test_ad():
......@@ -180,7 +180,7 @@ def test_head_cons():
body = hd(p.cons(x, p.nil()))
f = Function([x], body, None, [t])
res = dcpe(f, mod)
assert alpha_equal(res, Function([x], x, t, [t]))
assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
def test_map():
......@@ -197,7 +197,7 @@ def test_map():
expected = mod["main"]
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, expected.body)
assert tvm.ir.structural_equal(res.body, expected.body)
def test_loop():
......@@ -211,7 +211,7 @@ def test_loop():
expected = mod["main"].body
call = Function([], loop(const(1)))
res = dcpe(call, mod=mod)
assert alpha_equal(res.body, expected)
assert tvm.ir.structural_equal(res.body, expected)
def test_swap_loop():
......@@ -226,7 +226,7 @@ def test_swap_loop():
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
res = Function([], prog)
res = dcpe(res, mod=mod)
assert alpha_equal(prog, res.body)
assert tvm.ir.structural_equal(prog, res.body)
def test_abs_diff():
......@@ -248,7 +248,7 @@ def test_abs_diff():
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 4))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4))
def test_match_nat_id():
......@@ -265,7 +265,7 @@ def test_match_nat_id():
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_nat_id():
......@@ -280,7 +280,7 @@ def test_nat_id():
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_global_match_nat_id():
......@@ -294,7 +294,7 @@ def test_global_match_nat_id():
orig = Match(make_nat_expr(p, 3), [z_case, s_case])
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_double():
......@@ -304,7 +304,7 @@ def test_double():
orig = p.double(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 6))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6))
def test_concat():
......
......@@ -134,7 +134,7 @@ def test_qnn_legalize_qnn_conv2d():
# Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################
# Check transformations for platforms without fast Int8 support.
......@@ -157,7 +157,7 @@ def test_qnn_legalize_qnn_conv2d():
# Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
......@@ -221,7 +221,7 @@ def test_qnn_legalize_qnn_dense():
# Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################
# Check transformations for platforms without fast Int8 support.
......@@ -244,7 +244,7 @@ def test_qnn_legalize_qnn_dense():
# Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
......
......@@ -76,7 +76,7 @@ def test_order():
expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType())
assert alpha_equal(anf, expected_output)
assert tvm.ir.structural_equal(anf, expected_output)
def test_if():
......@@ -93,7 +93,7 @@ def test_if():
expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType())
assert alpha_equal(anf, expected_output)
assert tvm.ir.structural_equal(anf, expected_output)
# make sure we dont infinite loop.
......
......@@ -17,7 +17,7 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.analysis import detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.analysis import Feature
from tvm.relay.prelude import Prelude
......
......@@ -21,7 +21,6 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import op, transform, analysis
from tvm.relay.analysis import assert_alpha_equal
def run_infer_type(expr, mod=None):
......@@ -360,7 +359,7 @@ def test_let_polymorphism():
body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
body = run_infer_type(body)
int32 = relay.TensorType((), "int32")
assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
if __name__ == "__main__":
......
......@@ -25,7 +25,7 @@ def test_const_saveload_json():
z = z + z
json_str = tvm.ir.save_json(z)
zz = tvm.ir.load_json(json_str)
assert tvm.ir.save_json(zz) == tvm.ir.save_json(z)
tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
def test_make_smap():
......@@ -38,6 +38,7 @@ def test_make_smap():
arr = tvm.ir.load_json(json_str)
assert len(arr) == 1
assert arr[0]["z"].a == arr[0]["x"]
tvm.ir.assert_structural_equal(arr, [smap], map_free_vars=True)
def test_make_node():
......@@ -90,7 +91,6 @@ def test_env_func():
if __name__ == "__main__":
test_env_func()
test_make_attrs()
test_make_node()
test_make_smap()
test_const_saveload_json()
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import pytest
from tvm import te
def test_exprs():
# save load json
x = tvm.tir.const(1, "int32")
y = tvm.tir.const(10, "int32")
vx = te.var("x")
vy = te.var("y")
vz = te.var("z")
# test assert trigger.
with pytest.raises(ValueError):
tvm.ir.assert_structural_equal(x, y)
assert not tvm.ir.structural_equal(vx, vy)
assert tvm.ir.structural_equal(vx, vy, map_free_vars=True)
# corner case lhs:vx == rhs:vy, but cannot map it iteslf
assert not tvm.ir.structural_equal(vx + vx, vy + vx, map_free_vars=True)
# corner case lhs:vx == rhs:vy, lhs:vy == rhs:vx
assert tvm.ir.structural_equal(vx + vy, vy + vx, map_free_vars=True)
# corner case2: rolling remap.
assert tvm.ir.structural_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
assert not tvm.ir.structural_equal(vx + 1, vy + 1, map_free_vars=False)
# Defintition remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx - 1),
tvm.tir.Let(vy, 1, vy - 1))
# Default same address free var remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx // vz),
tvm.tir.Let(vy, 1, vy // vz))
zx = vx + vx
zy = vy + vy
assert tvm.ir.structural_equal(zx * zx, zx * zx)
assert tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=True)
assert not tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=False)
assert tvm.ir.structural_equal(zx * zx, (vx + vx) * (vx + vx),
map_free_vars=False)
def test_prim_func():
x = te.var('x')
y = te.var('y')
# counter example of same equality
func0 = tvm.tir.PrimFunc(
[x, y], tvm.tir.Evaluate(x + y))
func1 = tvm.tir.PrimFunc(
[x, y], tvm.tir.Evaluate(y + x))
assert not tvm.ir.structural_equal(func0, func1)
# new cases
b = tvm.tir.decl_buffer((x,), "float32")
stmt = tvm.tir.LetStmt(
x, 10, tvm.tir.Evaluate(x + 1))
func0 = tvm.tir.PrimFunc(
[x, y, b], stmt)
# easiest way to deep copy is via save/load
func1 = tvm.ir.load_json(tvm.ir.save_json(func0))
tvm.ir.assert_structural_equal(func0, func1)
data0 = tvm.nd.array([1, 2, 3])
data1 = tvm.nd.array([1, 2, 3])
# attributes and ndarrays
func0 = func0.with_attr("data", data0)
func1 = func1.with_attr("data", data1)
# IRModules
mod0 = tvm.IRModule.from_expr(func0)
mod1 = tvm.IRModule.from_expr(func1)
tvm.ir.assert_structural_equal(mod0, mod1)
def test_attrs():
x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
z = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx")
tvm.ir.assert_structural_equal(y, x)
assert not tvm.ir.structural_equal(y, z)
if __name__ == "__main__":
test_exprs()
test_prim_func()
test_attrs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment