Unverified Commit 997a14ed by Tianqi Chen Committed by GitHub

[NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)

* [NODE][IR] Introduce StructuralEqual Infra for the Unified IR.

This PR introduces a new way to handle structural equality
for both TIR and relay nodes in an extensive way.

- Each object can now register an optional SEqualReduce function, which
  describes how to reduce its structural equality to another instance
  into equality of the children.
- Optionally, the object can choose to allow remapping of vars(e.g. function parameters)
  by calling DefEqual
- We implemented a non-recursive structural equality checker that
  recursively traverses the objects and does the structural equality checking.

This PR also fixes a few potential problems in previous relay's AlphaEqual.

- In particular, the new structural equality relation will be communicative.
- It is can be dangerous to use same_as relation to quickly check equality,
  demonstrated by the following case. (%x, %y) are shared vars between two functions.

- function0: fn (%x, %y) { %x + %y }
- function1: fn (%y, %x) { %x + %y }

The new structural equal is intented to supersede AlphaEqual and AttrsEqual.

Follow-up PRs should be performed to redirect the existing usages, and removes
the corresponding implementation.

* Update the rule to distinguish between graph node and non-graph nodes.

* Refactor the test cases to use structural equal.

* address comments

* Mark more relay::Expr as graph node, fix a testcase issue(was bug that was not caught by previous alpha equal)

* Remove unrelated comment

* Fix file comment

* Address review comment

* Relax condition to fit flaky case
parent 9c806621
......@@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object {
v->Visit("max_value", &max_value);
}
bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
return equal(min_value, other->min_value) && equal(max_value, other->max_value);
}
/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*!
......@@ -170,6 +174,10 @@ class ModularSetNode : public Object {
v->Visit("base", &base);
}
bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
return equal(coeff, other->coeff) && equal(base, other->base);
}
static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};
......
......@@ -59,6 +59,7 @@ enum SignType {
class IntSetNode : public Object {
public:
static constexpr const char* _type_key = "IntSet";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};
......
......@@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
// Use namehint for now to be consistent with the legacy relay impl
// TODO(tvm-team) revisit, need to check the type var.
return
equal(name_hint, other->name_hint) &&
equal(inputs, other->inputs);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
......@@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
return
equal.DefEqual(header, other->header) &&
equal.DefEqual(type_vars, other->type_vars) &&
equal(constructors, other->constructors);
}
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
......
......@@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object {
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
......@@ -278,6 +280,7 @@ class BaseAttrsNode : public Object {
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
......@@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode {
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
return equal(dict, other->dict);
}
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
......@@ -401,6 +408,33 @@ class AttrsEqualVisitor {
const AttrsEqual& equal_;
};
class AttrsSEqualVisitor {
public:
bool result_{true};
// constructor
AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}
private:
const Object* lhs_;
const Object* rhs_;
const SEqualReducer& equal_;
};
class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
......@@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode {
}
}
bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
DerivedType* pself = self();
::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
......
......@@ -51,7 +51,12 @@ class EnvFuncNode : public Object {
v->Visit("name", &name);
}
bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
return this == other;
}
static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};
......
......@@ -43,6 +43,7 @@ namespace tvm {
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
......@@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
// name matters for global var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
......@@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode {
v->Visit("value", &value);
}
bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
......@@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode {
v->Visit("value", &value);
}
bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
......@@ -353,7 +369,12 @@ class RangeNode : public Object {
v->Visit("extent", &extent);
}
bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
return equal(min, other->min) && equal(extent, other->extent);
}
static constexpr const char* _type_key = "Range";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};
......
......@@ -62,6 +62,8 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
}
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
......@@ -235,6 +237,7 @@ class IRModuleNode : public Object {
TVM_DLL std::unordered_set<std::string> Imports() const;
static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private:
......
......@@ -101,6 +101,11 @@ class OpNode : public RelayExprNode {
v->Visit("support_level", &support_level);
}
bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
// pointer equality is fine as there is only one op with the same name.
return this == other;
}
/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
......
......@@ -44,6 +44,10 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const {
return equal(name, other->name);
}
static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
......@@ -87,6 +91,13 @@ class SpanNode : public Object {
v->Visit("col_offset", &col_offset);
}
bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
return
equal(source, other->source) &&
equal(lineno, other->lineno) &&
equal(col_offset, other->col_offset);
}
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "Span";
......
......@@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
return
equal(shape, other->shape) &&
equal(dtype, other->dtype);
}
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
......
......@@ -111,6 +111,7 @@ class PassContextNode : public Object {
}
static constexpr const char* _type_key = "transform.PassContext";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
......@@ -207,6 +208,7 @@ class PassInfoNode : public Object {
}
static constexpr const char* _type_key = "transform.PassInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};
......
......@@ -79,6 +79,7 @@ class TypeNode : public Object {
mutable Span span;
static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
......@@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}
bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
......@@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode {
v->Visit("element_type", &element_type);
}
bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
}
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
......@@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
return
equal(kind, other->kind) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
......@@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}
bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
// name matters for now in global type var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
......@@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
return equal(fields, other->fields);
}
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
......@@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
// type params first as they defines type vars.
return
equal.DefEqual(type_params, other->type_params) &&
equal(arg_types, other->arg_types) &&
equal(ret_type, other->ret_type) &&
equal(type_constraints, other->type_constraints);
}
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
......@@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind);
}
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
......@@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
......
......@@ -50,6 +50,12 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(args, other->args);
}
static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
......@@ -195,6 +201,14 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}
bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(args, other->args) &&
equal(num_inputs, other->num_inputs) &&
equal(attrs, other->attrs);
}
static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
......
......@@ -23,7 +23,9 @@
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
#include <tvm/node/node.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <type_traits>
#include <vector>
......@@ -34,15 +36,19 @@
namespace tvm {
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
using runtime::make_object;
using runtime::ObjectHash;
using runtime::ObjectEqual;
/*! \brief array node content in array */
class ArrayNode : public Object {
public:
/*! \brief the data content */
std::vector<ObjectRef> data;
void VisitAttrs(AttrVisitor* visitor) {
}
static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
};
......@@ -50,9 +56,6 @@ class ArrayNode : public Object {
/*! \brief map node content */
class MapNode : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
ObjectRef,
......@@ -73,9 +76,6 @@ class StrMapNode : public Object {
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>;
void VisitAttrs(AttrVisitor* visitor) {
}
/*! \brief the data content */
ContainerType data;
......
......@@ -39,6 +39,8 @@
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/container.h>
#include <tvm/node/structural_equal.h>
#include <string>
#include <vector>
......
......@@ -29,13 +29,14 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
#include <tvm/node/structural_equal.h>
#include <vector>
#include <string>
#include <type_traits>
namespace tvm {
// forward declaration
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
......@@ -87,6 +88,13 @@ class ReflectionVTable {
*/
typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
/*!
* \brief Equality comparison function.
* \note We use function pointer, instead of std::function
* to reduce the dispatch overhead as field visit
* does not need as much customization.
*/
typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object.
......@@ -112,6 +120,14 @@ class ReflectionVTable {
*/
inline std::string GetGlobalKey(Object* self) const;
/*!
* \brief Dispatch the SEqualReduce function.
* \param self The pointer to the object.
* \param other The pointer to another object to be compared.
* \param equal The equality comparator.
* \return the result.
*/
bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
/*!
* \brief Create an initial object using default constructor
* by type_key and global key.
*
......@@ -139,12 +155,14 @@ class ReflectionVTable {
TVM_DLL static ReflectionVTable* Global();
class Registry;
template<typename T>
template<typename T, typename TraitName>
inline Registry Register();
private:
/*! \brief Attribute visitor. */
std::vector<FVisitAttrs> fvisit_attrs_;
/*! \brief Structural equal function. */
std::vector<FSEqualReduce> fsequal_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
......@@ -182,6 +200,44 @@ class ReflectionVTable::Registry {
uint32_t type_index_;
};
#define TVM_REFLECTION_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \
__make_reflectiion
/*!
* \brief Directly register reflection VTable.
* \param TypeName The name of the type.
* \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce.
*
* \code
*
* // Example SEQualReduce traits for runtime StringObj.
*
* struct StringObjTrait {
* static constexpr const std::nullptr_t VisitAttrs = nullptr;
*
* static bool SEqualReduce(const runtime::StringObj* lhs,
* const runtime::StringObj* rhs,
* SEqualReducer equal) {
* if (lhs == rhs) return true;
* if (lhs->size != rhs->size) return false;
* if (lhs->data != rhs->data) return true;
* return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
* }
* };
*
* TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
*
* \endcode
*
* \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE.
* And can be used to register the related reflection functions for runtime objects.
*/
#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>() \
/*!
* \brief Register a node type to object registry and reflection registry.
* \param TypeName The name of the type.
......@@ -189,15 +245,79 @@ class ReflectionVTable::Registry {
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \
__make_Node ## _ ## TypeName ## __ = \
::tvm::ReflectionVTable::Global()->Register<TypeName>() \
TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
.set_creator([](const std::string&) -> ObjectPtr<Object> { \
return ::tvm::runtime::make_object<TypeName>(); \
})
// Implementation details
namespace detail {
template<typename T,
bool = T::_type_has_method_visit_attrs>
struct ImplVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
template<typename T>
struct ImplVisitAttrs<T, true> {
static void VisitAttrs(T* self, AttrVisitor* v) {
self->VisitAttrs(v);
}
};
template<typename T,
bool = T::_type_has_method_sequal_reduce>
struct ImplSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
template<typename T>
struct ImplSEqualReduce<T, true> {
static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
return self->SEqualReduce(other, equal);
}
};
template<typename T>
struct ReflectionTrait :
public ImplVisitAttrs<T>,
public ImplSEqualReduce<T> {
};
template<typename T, typename TraitName,
bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
struct SelectVisitAttrs {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
};
template<typename T, typename TraitName>
struct SelectVisitAttrs<T, TraitName, false> {
static void VisitAttrs(Object* self, AttrVisitor* v) {
TraitName::VisitAttrs(static_cast<T*>(self), v);
}
};
template<typename T, typename TraitName,
bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
struct SelectSEqualReduce {
static constexpr const std::nullptr_t SEqualReduce = nullptr;
};
template<typename T, typename TraitName>
struct SelectSEqualReduce<T, TraitName, false> {
static bool SEqualReduce(const Object* self,
const Object* other,
SEqualReducer equal) {
return TraitName::SEqualReduce(static_cast<const T*>(self),
static_cast<const T*>(other),
equal);
}
};
} // namespace detail
template<typename T, typename TraitName>
inline ReflectionVTable::Registry
ReflectionVTable::Register() {
uint32_t tindex = T::RuntimeTypeIndex();
......@@ -205,15 +325,15 @@ ReflectionVTable::Register() {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
fsequal_.resize(tindex + 1, nullptr);
}
// functor that implemnts the redirection.
struct Functor {
static void VisitAttrs(Object* self, AttrVisitor* v) {
static_cast<T*>(self)->VisitAttrs(v);
}
};
fvisit_attrs_[tindex] =
::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
fsequal_[tindex] =
::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
fvisit_attrs_[tindex] = Functor::VisitAttrs;
return Registry(this, tindex);
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/structural_equal.h
* \brief Structural equality comparison.
*/
#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
#define TVM_NODE_STRUCTURAL_EQUAL_H_
#include <tvm/runtime/data_type.h>
#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <string>
namespace tvm {
/*!
* \brief Equality definition of base value class.
*/
class BaseValueEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
return lhs == rhs;
}
};
/*!
* \brief Content-aware structural equality comparator for objects.
*
* The structural equality is recursively defined in the DAG of IR nodes via SEqual.
* There are two kinds of nodes:
*
* - Graph node: a graph node in lhs can only be mapped as equal to
* one and only one graph node in rhs.
* - Normal node: equality is recursively defined without the restriction
* of graph nodes.
*
* Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
* For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
* to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
*
* A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
* with the same type if one of the following condition holds:
*
* - They appear in a same definition point(e.g. function argument).
* - They points to the same VarNode via the same_as relation.
* - They appear in a same usage point, and map_free_vars is set to be True.
*/
class StructuralEqual : public BaseValueEqual {
public:
// inheritate operator()
using BaseValueEqual::operator();
/*!
* \brief Compare objects via strutural equal.
* \param lhs The left operand.
* \param rhs The right operand.
* \return The comparison result.
*/
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
};
/*!
* \brief A Reducer class to reduce the structural equality result of two objects.
*
* The reducer will call the SEqualReduce function of each objects recursively.
* Importantly, the reducer may not directly use recursive calls to resolve the
* equality checking. Instead, it can store the necessary equality conditions
* and check later via an internally managed stack.
*/
class SEqualReducer : public BaseValueEqual {
public:
/*! \brief Internal handler that defines custom behaviors.. */
class Handler {
public:
/*!
* \brief Reduce condition to equality of lhs and rhs.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether do we allow remap variables if possible.
*
* \return false if there is an immediate failure, true otherwise.
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
virtual bool SEqualReduce(const ObjectRef& lhs,
const ObjectRef& rhs,
bool map_free_vars) = 0;
/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
* This is an auxiliary method to check the Map<Var, Value> equality.
* \param lhs an lhs value.
*
* \return The corresponding rhs value if any, nullptr if not available.
*/
virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
/*!
* \brief Mark current comparison as graph node equal comparison.
*/
virtual void MarkGraphNode() = 0;
};
using BaseValueEqual::operator();
/*! \brief default constructor */
SEqualReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
* \param map_free_vars Whether or not to map free variables.
*/
explicit SEqualReducer(Handler* handler, bool map_free_vars)
: handler_(handler), map_free_vars_(map_free_vars) {}
/*!
* \brief Reduce condition to comparison of two objects.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
}
/*!
* \brief Reduce condition to comparison of two definitions,
* where free vars can be mapped.
*
* Call this function to compare definition points such as function params
* and var in a let-binding.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
return handler_->SEqualReduce(lhs, rhs, true);
}
/*!
* \brief Reduce condition to comparison of two arrays.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
template<typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
}
return true;
}
/*!
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
* \param lhs The left operand.
* \param rhs The right operand.
* \return the result.
*/
bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
// var need to be remapped, so it belongs to graph node.
handler_->MarkGraphNode();
// We only map free vars if they corresponds to the same address
// or map free_var option is set to be true.
return lhs == rhs || map_free_vars_;
}
/*! \return Get the internal handler. */
Handler* operator->() const {
return handler_;
}
private:
/*! \brief Internal class pointer. */
Handler* handler_;
/*! \brief Whether or not to map free vars. */
bool map_free_vars_;
};
} // namespace tvm
#endif // TVM_NODE_STRUCTURAL_EQUAL_H_
......@@ -46,6 +46,7 @@ using TypeDataNode = tvm::TypeDataNode;
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
};
......@@ -74,6 +75,10 @@ class PatternWildcardNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const {
return true;
}
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
};
......@@ -118,6 +123,10 @@ class PatternVarNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
return equal.DefEqual(var, other->var);
}
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
};
......@@ -149,6 +158,12 @@ class PatternConstructorNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
return
equal(constructor, other->constructor) &&
equal(patterns, other->patterns);
}
static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
};
......@@ -178,6 +193,10 @@ class PatternTupleNode : public PatternNode {
v->Visit("span", &span);
}
bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
return equal(patterns, other->patterns);
}
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
};
......@@ -208,7 +227,12 @@ class ClauseNode : public Object {
v->Visit("rhs", &rhs);
}
bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
return equal(lhs, other->lhs) && equal(rhs, other->rhs);
}
static constexpr const char* _type_key = "relay.Clause";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};
......@@ -248,6 +272,14 @@ class MatchNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(data, other->data) &&
equal(clauses, other->clauses) &&
equal(complete, other->complete);
}
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};
......
......@@ -26,6 +26,7 @@
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/op.h>
#include <tvm/ir/module.h>
#include <string>
#include <functional>
......@@ -72,6 +73,10 @@ class ConstantNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
return equal(data, other->data);
}
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
......@@ -101,6 +106,16 @@ class TupleNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
// specially handle empty tuple as a constant is not a graph node.
if (fields.size() == other->fields.size() && fields.size() == 0) {
return true;
} else {
equal->MarkGraphNode();
return equal(fields, other->fields);
}
}
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
......@@ -157,6 +172,12 @@ class VarNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
return
equal(type_annotation, other->type_annotation) &&
equal.FreeVarEqualImpl(this, other);
}
TVM_DLL static Var make(std::string name_hint,
Type type_annotation);
......@@ -238,6 +259,16 @@ class CallNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
// skip type_args check for primitive ops.
equal->MarkGraphNode();
return
equal(op, other->op) &&
equal(args, other->args) &&
equal(attrs, other->attrs) &&
(IsPrimitiveOp(op) || equal(type_args, other->type_args));
}
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
};
......@@ -289,6 +320,14 @@ class LetNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
......@@ -336,6 +375,14 @@ class IfNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(cond, other->cond) &&
equal(true_branch, other->true_branch) &&
equal(false_branch, other->false_branch);
}
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
......@@ -369,6 +416,12 @@ class TupleGetItemNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
return
equal(tuple, other->tuple) &&
equal(index, other->index);
}
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
......@@ -398,6 +451,11 @@ class RefCreateNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(value, other->value);
}
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
......@@ -426,6 +484,11 @@ class RefReadNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(ref, other->ref);
}
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
......@@ -456,6 +519,13 @@ class RefWriteNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return
equal(ref, other->ref) &&
equal(value, other->value);
}
TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite";
......@@ -497,6 +567,7 @@ class TempExprNode : public ExprNode {
virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};
......
......@@ -68,6 +68,17 @@ class FunctionNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
// Important to make def equal first.
equal->MarkGraphNode();
return
equal.DefEqual(params, other->params) &&
equal.DefEqual(type_params, other->type_params) &&
equal(ret_type, other->ret_type) &&
equal(attrs, other->attrs) &&
equal(body, other->body);
}
/*!
* \brief Return the derived function annotation of this expression.
*
......
......@@ -65,6 +65,8 @@ class NDArray : public ObjectRef {
inline int use_count() const;
/*! \return Pointer to content of DLTensor */
inline const DLTensor* operator->() const;
/*! \return Whether the tensor is contiguous */
inline bool IsContiguous() const;
/*!
* \brief Copy data content from another array.
* \param other The source array to be copied from.
......@@ -313,6 +315,26 @@ inline size_t GetDataSize(const DLTensor& arr) {
return size;
}
/*!
* \brief check if a DLTensor is contiguous.
* \param arr The input DLTensor.
* \return The check result.
*/
inline bool IsContiguous(const DLTensor& arr) {
if (arr.strides == nullptr) return true;
int64_t expected_stride = 1;
for (int32_t i = arr.ndim; i != 0; --i) {
int32_t k = i - 1;
if (arr.strides[k] != expected_stride) return false;
expected_stride *= arr.shape[k];
}
return true;
}
inline bool NDArray::IsContiguous() const {
return ::tvm::runtime::IsContiguous(get_mutable()->dl_tensor);
}
inline void NDArray::CopyFrom(const DLTensor* other) {
CHECK(data_ != nullptr);
CopyFromTo(other, &(get_mutable()->dl_tensor));
......
......@@ -211,11 +211,15 @@ class Object {
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
// member information
static constexpr bool _type_has_method_visit_attrs = true;
static constexpr bool _type_has_method_sequal_reduce = false;
// NOTE: the following field is not type index of Object
// but was intended to be used by sub-classes as default value.
// The type index of Object is TypeIndex::kRoot
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
// Default constructor and copy constructor
Object() {}
// Override the copy and assign constructors to do nothing.
......
......@@ -150,6 +150,20 @@ class BufferNode : public Object {
v->Visit("buffer_type", &buffer_type);
}
bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
// Use DefEqual as buffer can define variables
// in its semantics, skip name as name is not important.
return
equal.DefEqual(data, other->data) &&
equal(dtype, other->dtype) &&
equal.DefEqual(shape, other->shape) &&
equal.DefEqual(strides, other->strides) &&
equal.DefEqual(elem_offset, other->elem_offset) &&
equal(scope, other->scope) &&
equal(data_alignment, other->data_alignment) &&
equal(buffer_type, other->buffer_type);
}
/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
......@@ -169,6 +183,7 @@ class BufferNode : public Object {
BufferType buffer_type);
static constexpr const char* _type_key = "Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
};
......
......@@ -75,6 +75,12 @@ class VarNode : public PrimExprNode {
v->Visit("type_annotation", &type_annotation);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
if (!equal(dtype, other->dtype)) return false;
if (!equal(type_annotation, other->type_annotation)) return false;
return equal.FreeVarEqualImpl(this, other);
}
static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};
......@@ -288,11 +294,20 @@ class IterVarNode : public Object {
v->Visit("thread_tag", &thread_tag);
}
bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
return
equal(dom, other->dom) &&
equal.DefEqual(var, other->var) &&
equal(iter_type, other->iter_type) &&
equal(thread_tag, other->thread_tag);
}
TVM_DLL static IterVar make(Range dom, Var var,
IterVarType iter_type,
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
};
......@@ -334,6 +349,10 @@ class StringImmNode : public PrimExprNode {
v->Visit("value", &value);
}
bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
TVM_DLL PrimExpr static make(std::string value);
static constexpr const char* _type_key = "StringImm";
......@@ -359,6 +378,10 @@ class CastNode : public PrimExprNode {
v->Visit("value", &value);
}
bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}
TVM_DLL static PrimExpr make(DataType t, PrimExpr v);
static constexpr const char* _type_key = "Cast";
......@@ -383,6 +406,13 @@ class BinaryOpNode : public PrimExprNode {
v->Visit("b", &b);
}
bool SEqualReduce(const T* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
......@@ -475,6 +505,13 @@ class CmpOpNode : public PrimExprNode {
v->Visit("b", &b);
}
bool SEqualReduce(const T* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
......@@ -539,6 +576,13 @@ class AndNode : public PrimExprNode {
v->Visit("b", &b);
}
bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "And";
......@@ -559,6 +603,13 @@ class OrNode : public PrimExprNode {
v->Visit("b", &b);
}
bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(a, other->a) &&
equal(b, other->b);
}
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "Or";
......@@ -576,6 +627,10 @@ class NotNode : public PrimExprNode {
v->Visit("a", &a);
}
bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(a, other->a);
}
TVM_DLL static PrimExpr make(PrimExpr a);
static constexpr const char* _type_key = "Not";
......@@ -605,6 +660,14 @@ class SelectNode : public PrimExprNode {
v->Visit("false_value", &false_value);
}
bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(condition, other->condition) &&
equal(true_value, other->true_value) &&
equal(false_value, other->false_value);
}
TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
static constexpr const char* _type_key = "Select";
......@@ -642,6 +705,14 @@ class LoadNode : public PrimExprNode {
v->Visit("predicate", &predicate);
}
bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(buffer_var, other->buffer_var) &&
equal(index, other->index) &&
equal(predicate, other->predicate);
}
TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
static constexpr const char* _type_key = "Load";
......@@ -673,6 +744,14 @@ class RampNode : public PrimExprNode {
v->Visit("lanes", &lanes);
}
bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(base, other->base) &&
equal(stride, other->stride) &&
equal(lanes, other->lanes);
}
TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes);
static constexpr const char* _type_key = "Ramp";
......@@ -693,6 +772,13 @@ class BroadcastNode : public PrimExprNode {
v->Visit("lanes", &lanes);
}
bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(value, other->value) &&
equal(lanes, other->lanes);
}
TVM_DLL static PrimExpr make(PrimExpr value, int lanes);
static constexpr const char* _type_key = "Broadcast";
......@@ -718,6 +804,14 @@ class LetNode : public PrimExprNode {
v->Visit("body", &body);
}
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body);
static constexpr const char* _type_key = "Let";
......@@ -788,6 +882,16 @@ class CallNode : public PrimExprNode {
v->Visit("value_index", &value_index);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(name, other->name) &&
equal(args, other->args) &&
equal(call_type, other->call_type) &&
equal(func, other->func) &&
equal(value_index, other->value_index);
}
TVM_DLL static PrimExpr make(DataType dtype,
std::string name,
Array<PrimExpr> args,
......@@ -856,6 +960,13 @@ class ShuffleNode : public PrimExprNode {
v->Visit("indices", &indices);
}
bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
return
equal(dtype, other->dtype) &&
equal(vectors, other->vectors) &&
equal(indices, other->indices);
}
TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr> indices);
TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors);
TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index);
......@@ -918,7 +1029,16 @@ class CommReducerNode : public Object {
v->Visit("identity_element", &identity_element);
}
bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
return
equal.DefEqual(lhs, other->lhs) &&
equal.DefEqual(rhs, other->rhs) &&
equal(result, other->result) &&
equal(identity_element, other->identity_element);
}
static constexpr const char* _type_key = "CommReducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
};
......@@ -962,6 +1082,16 @@ class ReduceNode : public PrimExprNode {
v->Visit("value_index", &value_index);
}
bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
// check axis first so IterVars can define the necessary variables.
return
equal(dtype, other->dtype) &&
equal(axis, other->axis) &&
equal(combiner, other->combiner) &&
equal(source, other->source) &&
equal(condition, other->condition) &&
equal(value_index, other->value_index);
}
static constexpr const char* _type_key = "Reduce";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
};
......@@ -970,6 +1100,11 @@ class ReduceNode : public PrimExprNode {
class AnyNode : public PrimExprNode {
public:
void VisitAttrs(AttrVisitor* v) {}
bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
return true;
}
/*! \brief Convert to var. */
Var ToVar() const {
return Var("any_dim", DataType::Int(32));
......
......@@ -102,6 +102,16 @@ class PrimFuncNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
// visit params and buffer_map first as they contains defs.
return
equal.DefEqual(params, other->params) &&
equal(buffer_map, other->buffer_map) &&
equal(ret_type, other->ret_type) &&
equal(body, other->body) &&
equal(attrs, other->attrs);
}
/*!
* \brief Return the derived function annotation of this function.
*
......@@ -112,6 +122,7 @@ class PrimFuncNode : public BaseFuncNode {
TVM_DLL FuncType func_type_annotation() const;
static constexpr const char* _type_key = "tir.PrimFunc";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};
......
......@@ -38,6 +38,7 @@ namespace tir {
class StmtNode : public Object {
public:
static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};
......@@ -65,6 +66,13 @@ class LetStmtNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
return
equal.DefEqual(var, other->var) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
static constexpr const char* _type_key = "LetStmt";
......@@ -99,6 +107,14 @@ class AttrStmtNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
return
equal(node, other->node) &&
equal(attr_key, other->attr_key) &&
equal(value, other->value) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(ObjectRef node,
std::string type_key,
PrimExpr value,
......@@ -129,6 +145,13 @@ class AssertStmtNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
return
equal(condition, other->condition) &&
equal(message, other->message) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt";
......@@ -152,6 +175,13 @@ class ProducerConsumerNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(is_producer, other->is_producer) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer";
......@@ -194,6 +224,14 @@ class StoreNode : public StmtNode {
v->Visit("predicate", &predicate);
}
bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
return
equal(buffer_var, other->buffer_var) &&
equal(value, other->value) &&
equal(index, other->index) &&
equal(predicate, other->predicate);
}
TVM_DLL static Stmt make(Var buffer_var,
PrimExpr value,
PrimExpr index,
......@@ -224,6 +262,14 @@ class ProvideNode : public StmtNode {
v->Visit("args", &args);
}
bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(value, other->value) &&
equal(args, other->args);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
PrimExpr value,
......@@ -261,6 +307,15 @@ class AllocateNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
return
equal.DefEqual(buffer_var, other->buffer_var) &&
equal(dtype, other->dtype) &&
equal(extents, other->extents) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
TVM_DLL static Stmt make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
......@@ -300,6 +355,11 @@ class FreeNode : public StmtNode {
v->Visit("buffer_var", &buffer_var);
}
bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
return
equal(buffer_var, other->buffer_var);
}
TVM_DLL static Stmt make(Var buffer_var);
static constexpr const char* _type_key = "Free";
......@@ -341,6 +401,16 @@ class RealizeNode : public StmtNode {
PrimExpr condition,
Stmt body);
bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(bounds, other->bounds) &&
equal(condition, other->condition) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "Realize";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
};
......@@ -369,6 +439,10 @@ class SeqStmtNode : public StmtNode {
v->Visit("seq", &seq);
}
bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
return equal(seq, other->seq);
}
static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};
......@@ -472,6 +546,13 @@ class IfThenElseNode : public StmtNode {
v->Visit("else_case", &else_case);
}
bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
return
equal(condition, other->condition) &&
equal(then_case, other->then_case) &&
equal(else_case, other->else_case);
}
TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
static constexpr const char* _type_key = "IfThenElse";
......@@ -493,6 +574,10 @@ class EvaluateNode : public StmtNode {
v->Visit("value", &value);
}
bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
TVM_DLL static Stmt make(PrimExpr v);
static constexpr const char* _type_key = "Evaluate";
......@@ -562,6 +647,16 @@ class ForNode : public StmtNode {
v->Visit("body", &body);
}
bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return
equal.DefEqual(loop_var, other->loop_var) &&
equal(min, other->min) &&
equal(extent, other->extent) &&
equal(for_type, other->for_type) &&
equal(device_api, other->device_api) &&
equal(body, other->body);
}
static constexpr const char* _type_key = "For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
......@@ -587,6 +682,14 @@ class PrefetchNode : public StmtNode {
v->Visit("bounds", &bounds);
}
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(value_index, other->value_index) &&
equal(dtype, other->dtype) &&
equal(bounds, other->bounds);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType dtype,
......
......@@ -17,6 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .base import structural_equal, assert_structural_equal
from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
......
......@@ -149,3 +149,76 @@ def save_json(node):
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)
def structural_equal(lhs, rhs, map_free_vars=False):
"""Check structural equality of lhs and rhs.
The structural equality is recursively defined in the DAG of IRNodes.
There are two kinds of nodes:
- Graph node: a graph node in lhs can only be mapped as equal to
one and only one graph node in rhs.
- Normal node: equality is recursively defined without the restriction
of graph nodes.
Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
with the same type if one of the following condition holds:
- They appear in a same definition point(e.g. function argument).
- They points to the same VarNode via the same_as relation.
- They appear in a same usage point, and map_free_vars is set to be True.
The rules for var are used to remap variables occurs in function
arguments and let-bindings.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Return
------
result : bool
The comparison result.
"""
return tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars)
def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""Assert lhs and rhs are structurally equal to each other.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Raises
------
ValueError : if assertion does not hold.
See Also
--------
structural_equal
"""
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, True, map_free_vars)
......@@ -105,6 +105,7 @@ TVM_REGISTER_GLOBAL("ir.FloatImm")
TVM_REGISTER_NODE_TYPE(FloatImmNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
......@@ -143,17 +144,14 @@ TVM_REGISTER_GLOBAL("ir.Range")
*ret = Range(args[0], args[1]);
});
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
......
......@@ -65,6 +65,21 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
data_ = std::move(n);
}
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}
if (type_definitions.size() != other->type_definitions.size()) return false;
for (const auto& kv : this->type_definitions) {
if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
return true;
}
bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
......@@ -305,8 +320,8 @@ IRModule IRModule::FromExpr(
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
if (auto* func_node = expr.as<relay::FunctionNode>()) {
func = GetRef<relay::Function>(func_node);
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
} else {
func = relay::Function(
relay::FreeVars(expr), expr, Type(),
......
......@@ -21,11 +21,98 @@
* \file src/node/container.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <cstring>
namespace tvm {
// SEQualReduce traits for runtime containers.
struct StringObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::StringObj* lhs,
const runtime::StringObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->size != rhs->size) return false;
if (lhs->data != rhs->data) return true;
return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::ADTObj* lhs,
const runtime::ADTObj* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->tag != rhs->tag) return false;
if (lhs->size != rhs->size) return false;
for (uint32_t i = 0; i < lhs->size; ++i) {
if (!equal((*lhs)[i], (*rhs)[i])) return false;
}
return true;
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
struct NDArrayContainerTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
auto ldt = lhs->dl_tensor.dtype;
auto rdt = rhs->dl_tensor.dtype;
CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
CHECK(runtime::IsContiguous(lhs->dl_tensor))
<< "Can only compare contiguous tensor";
CHECK(runtime::IsContiguous(rhs->dl_tensor))
<< "Can only compare contiguous tensor";
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
} else {
return false;
}
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
struct ArrayNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const ArrayNode* lhs,
const ArrayNode* rhs,
SEqualReducer equal) {
if (lhs->data.size() != rhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(ArrayNode);
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<ArrayNode>();
});
TVM_REGISTER_GLOBAL("node.Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<ObjectRef> data;
......@@ -62,6 +149,59 @@ TVM_REGISTER_GLOBAL("node.ArraySize")
static_cast<const ArrayNode*>(ptr)->data.size());
});
struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const MapNode* lhs,
const MapNode* rhs,
SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->data.find(rhs_key);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(MapNode);
TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<MapNode>();
});
struct StrMapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static bool SEqualReduce(const StrMapNode* lhs,
const StrMapNode* rhs,
SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
};
TVM_REGISTER_OBJECT_TYPE(StrMapNode);
TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
return ::tvm::runtime::make_object<StrMapNode>();
});
TVM_REGISTER_GLOBAL("node.Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
......
......@@ -180,7 +180,7 @@ ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/node/structural_equal.cc
*/
#include <tvm/node/structural_equal.h>
#include <tvm/node/reflection.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
namespace tvm {
// Define the dispatch functio here since primary user is in this file.
bool ReflectionVTable::
SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const {
uint32_t tindex = self->type_index();
if (tindex >= fsequal_.size() || fsequal_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fsequal_[tindex](self, other, equal);
}
/*!
* \brief A non recursive stack based SEqual handler that can remaps vars.
*
* This handler pushs the Object equality cases into a stack, and
* traverses the stack to expand the necessary children that need to be checked.
*
* The order of SEqual being called is the same as the order as if we
* eagerly do recursive calls in SEqualReduce.
*/
class RemapVarSEqualHandler :
public SEqualReducer::Handler {
public:
explicit RemapVarSEqualHandler(bool assert_mode)
: assert_mode_(assert_mode) {}
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
// We cannot use check lhs.same_as(rhs) to check equality.
// if we choose to enable var remapping.
//
// Counter example below (%x, %y) are shared vars
// between the two functions(possibly before/after rewriting).
//
// - function0: fn (%x, %y) { %x + %y }
// - function1. fn (%y, %x) { %x + %y }
//
// Because we choose to enable var remapping,
// %x is mapped to %y, and %y is mapped to %x,
// the body of the function no longer means the same thing.
//
// Take away: We can either choose only compare Var by address,
// in which case we can use same_as for quick checking,
// or we have to run deep comparison and avoid to use same_as checks.
auto run = [=]() {
if (!lhs.defined() && !rhs.defined()) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
// need to push to pending tasks in this case
pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars));
return true;
};
return CheckResult(run(), lhs, rhs);
}
void MarkGraphNode() final {
// need to push to pending tasks in this case
CHECK(!allow_push_to_stack_ && !task_stack_.empty());
task_stack_.back().graph_equal = true;
}
ObjectRef MapLhsToRhs(const ObjectRef& lhs) final {
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) return it->second;
return ObjectRef(nullptr);
}
// Function that implements actual equality check.
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
task_stack_.clear();
pending_tasks_.clear();
equal_map_lhs_.clear();
equal_map_rhs_.clear();
if (!SEqualReduce(lhs, rhs, map_free_vars)) return false;
CHECK_EQ(pending_tasks_.size(), 1U);
CHECK(allow_push_to_stack_);
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.clear();
return RunTasks();
}
protected:
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_ && !result) {
LOG(FATAL)
<< "ValueError: StructuralEqual check failed, caused by\n"
<< "lhs = " << lhs << "\nrhs = " << rhs;
}
return result;
}
/*!
* \brief Run tasks until the stack reaches the stack begin
* \param stack_begin The expected beginning of the stack.
* \return The checks we encountered throughout the process.
*/
bool RunTasks() {
while (task_stack_.size() != 0) {
// Caution: entry becomes invalid when the stack changes
auto& entry = task_stack_.back();
if (entry.children_expanded) {
// When all the children has expanded and visited.
// This means all the condition checks for
// the current entry has been passed
// We can safely mark lhs and rhs as equal to each other.
auto it = equal_map_lhs_.find(entry.lhs);
if (it != equal_map_lhs_.end()) {
CHECK(it->second.same_as(entry.rhs));
}
// create the map if the quality is graph equal.
if (entry.graph_equal) {
equal_map_lhs_[entry.lhs] = entry.rhs;
equal_map_rhs_[entry.rhs] = entry.lhs;
}
task_stack_.pop_back();
} else {
// mark before expand
// Important: because entry becomes invalid when stack changes.
entry.children_expanded = true;
// Expand the objects
// The SEqual of the object can call into this->SEqualReduce
// which populates the pending tasks.
CHECK_EQ(pending_tasks_.size(), 0U);
allow_push_to_stack_ = false;
if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false;
allow_push_to_stack_ = true;
// Push pending tasks in reverse order, so earlier tasks get to
// expand first in the stack
while (pending_tasks_.size() != 0) {
task_stack_.emplace_back(std::move(pending_tasks_.back()));
pending_tasks_.pop_back();
}
}
}
return true;
}
// The default equal as registered in the structural equal vtable.
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
auto compute = [=]() {
CHECK(lhs.defined() &&
rhs.defined() &&
lhs->type_index() == rhs->type_index());
// skip entries that already have equality maps.
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
// Run reduce check for free nodes.
return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars));
};
return CheckResult(compute(), lhs, rhs);
}
private:
/*! \brief Pending reduce tasks. */
struct Task {
/*! \brief The lhs operand to be compared. */
ObjectRef lhs;
/*! \brief The rhs operand to be compared. */
ObjectRef rhs;
/*! \brief The map free var argument. */
bool map_free_vars;
/*! \brief Whether the children has been expanded via SEqualReduce */
bool children_expanded{false};
/*! \brief whether the task is about graph equality(need remap). */
bool graph_equal{false};
Task() = default;
Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars)
: lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {}
};
// list of pending tasks to be pushed to the stack.
std::vector<Task> pending_tasks_;
// Internal task stack to executed the task.
std::vector<Task> task_stack_;
// Whether we allow push to stack.
bool allow_push_to_stack_{true};
// If in assert mode, must return true, and will throw error otherwise.
bool assert_mode_{false};
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_lhs_;
// map from rhs to lhs
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
};
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs,
const ObjectRef& rhs,
bool assert_mode,
bool map_free_vars) {
return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
});
bool StructuralEqual::operator()(const ObjectRef& lhs,
const ObjectRef& rhs) const {
return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
}
} // namespace tvm
......@@ -81,7 +81,8 @@ TVM_REGISTER_GLOBAL("tir.Var")
TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
});
});
IterVar IterVarNode::make(Range dom,
Var var,
......@@ -132,6 +133,7 @@ PrimExpr StringImmNode::make(std::string value) {
TVM_REGISTER_GLOBAL("tir.StringImm")
.set_body_typed(StringImmNode::make);
PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
......@@ -141,6 +143,7 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) {
return PrimExpr(node);
}
PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
......@@ -169,6 +172,7 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) {
return PrimExpr(node);
}
PrimExpr NotNode::make(PrimExpr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool());
......@@ -179,6 +183,8 @@ PrimExpr NotNode::make(PrimExpr a) {
return PrimExpr(node);
}
PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
......
......@@ -1114,7 +1114,7 @@ def test_read_variable_op():
num_output=len(out_name))
for i in range(len(tf_output)):
tvm.testing.assert_allclose(
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
sess.close()
......
......@@ -17,8 +17,6 @@
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import graph_equal, assert_graph_equal
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
import pytest
from numpy import isclose
from typing import Union
......@@ -69,6 +67,13 @@ type List[A] {
}
"""
def assert_graph_equal(lhs, rhs):
tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
def graph_equal(lhs, rhs):
return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)
def roundtrip(expr):
x = relay.fromtext(expr.astext())
assert_graph_equal(x, expr)
......@@ -86,6 +91,12 @@ def parses_as(code, expr):
result = graph_equal(parsed, expr)
return result
def assert_parses_as(code, expr):
parsed = parse_text(code)
assert_graph_equal(parsed, expr)
def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool])
return x.data.asnumpy().item()
......@@ -102,7 +113,7 @@ UNIT = relay.Tuple([])
def test_comments():
assert parses_as(
assert_parses_as(
"""
// This is a line comment!
()
......@@ -110,7 +121,7 @@ def test_comments():
UNIT
)
assert parses_as(
assert_parses_as(
"""
/* This is a block comment!
This is still a block comment!
......@@ -120,7 +131,7 @@ def test_comments():
UNIT
)
assert parses_as(
assert_parses_as(
"""
/* This is a block comment!
/*Block comment is recursive!*/
......@@ -172,7 +183,7 @@ def test_negative():
def test_bin_op():
for bin_op in BINARY_OPS.keys():
assert parses_as(
assert_parses_as(
"1 {} 1".format(bin_op),
BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
)
......@@ -213,7 +224,7 @@ def test_vars():
def test_let():
assert parses_as(
assert_parses_as(
"let %x = 1; ()",
relay.Let(
X,
......@@ -222,7 +233,7 @@ def test_let():
)
)
assert parses_as(
assert_parses_as(
"""
let %x = 1;
let %y = 2;
......@@ -241,7 +252,7 @@ def test_let():
def test_seq():
assert parses_as(
assert_parses_as(
"();; ()",
relay.Let(
_,
......@@ -249,7 +260,7 @@ def test_seq():
UNIT)
)
assert parses_as(
assert_parses_as(
"let %_ = 1; ()",
relay.Let(
X,
......@@ -261,14 +272,10 @@ def test_seq():
def test_graph():
code = "%0 = (); %1 = 1; (%0, %0, %1)"
assert parses_as(
assert_parses_as(
code,
relay.Tuple([UNIT, UNIT, relay.const(1)])
)
assert not parses_as(
code,
relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
)
@raises_parse_error
......@@ -287,18 +294,18 @@ def test_let_op():
def test_tuple():
assert parses_as("()", relay.Tuple([]))
assert_parses_as("()", relay.Tuple([]))
assert parses_as("(0,)", relay.Tuple([relay.const(0)]))
assert_parses_as("(0,)", relay.Tuple([relay.const(0)]))
assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
def test_func():
# 0 args
assert parses_as(
assert_parses_as(
"fn () { 0 }",
relay.Function(
[],
......@@ -309,7 +316,7 @@ def test_func():
)
# 1 arg
assert parses_as(
assert_parses_as(
"fn (%x) { %x }",
relay.Function(
[X],
......@@ -320,7 +327,7 @@ def test_func():
)
# 2 args
assert parses_as(
assert_parses_as(
"fn (%x, %y) { %x + %y }",
relay.Function(
[X, Y],
......@@ -331,7 +338,7 @@ def test_func():
)
# annotations
assert parses_as(
assert_parses_as(
"fn (%x: int32) -> int32 { %x }",
relay.Function(
[X_ANNO],
......@@ -342,7 +349,7 @@ def test_func():
)
# attributes
assert parses_as(
assert_parses_as(
"fn (n=5) { () }",
relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
)
......@@ -370,7 +377,7 @@ def test_recursive_call():
def test_ifelse():
assert parses_as(
assert_parses_as(
"""
if (True) {
0
......@@ -403,7 +410,7 @@ def test_ifelse_scope():
def test_call():
# select right function to call: simple ident case
id_func = relay.Var("id")
assert parses_as(
assert_parses_as(
"""
let %id = fn (%x) { %x };
10 * %id(10)
......@@ -417,7 +424,7 @@ def test_call():
# 0 args
constant = relay.Var("constant")
assert parses_as(
assert_parses_as(
"""
let %constant = fn () { 0 };
%constant()
......@@ -431,7 +438,7 @@ def test_call():
# 1 arg
id_var = relay.Var("id")
assert parses_as(
assert_parses_as(
"""
let %id = fn (%x) { %x };
%id(1)
......@@ -445,7 +452,7 @@ def test_call():
# 2 args
multiply = relay.Var("multiply")
assert parses_as(
assert_parses_as(
"""
let %multiply = fn (%x, %y) { %x * %y };
%multiply(0, 0)
......@@ -463,7 +470,7 @@ def test_call():
)
# anonymous function
assert parses_as(
assert_parses_as(
"""
(fn (%x) { %x })(0)
""",
......@@ -483,7 +490,7 @@ def test_call():
# TODO(@jmp): re-enable after sequence parsing improvements
# curried function
# curried_mult = relay.Var("curried_mult")
# assert parses_as(
# assert_parses_as(
# """
# let %curried_mult =
# fn (%x) {
......@@ -516,7 +523,7 @@ def test_call():
# )
# op
assert parses_as(
assert_parses_as(
"abs(1)",
relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
)
......@@ -525,7 +532,7 @@ def test_call():
def test_incomplete_type():
assert parses_as(
assert_parses_as(
"let %_ : _ = (); ()",
relay.Let(
_,
......@@ -541,7 +548,7 @@ def test_builtin_types():
def test_tensor_type():
assert parses_as(
assert_parses_as(
"let %_ : Tensor[(), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((), "float32")),
......@@ -550,7 +557,7 @@ def test_tensor_type():
)
)
assert parses_as(
assert_parses_as(
"let %_ : Tensor[(1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1,), "float32")),
......@@ -559,7 +566,7 @@ def test_tensor_type():
)
)
assert parses_as(
assert_parses_as(
"let %_ : Tensor[(1, 1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1, 1), "float32")),
......@@ -570,7 +577,7 @@ def test_tensor_type():
def test_function_type():
assert parses_as(
assert_parses_as(
"""
let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
""",
......@@ -581,7 +588,7 @@ def test_function_type():
)
)
assert parses_as(
assert_parses_as(
"""
let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
""",
......@@ -592,7 +599,7 @@ def test_function_type():
)
)
assert parses_as(
assert_parses_as(
"""
let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
""",
......@@ -605,7 +612,7 @@ def test_function_type():
def test_tuple_type():
assert parses_as(
assert_parses_as(
"""
let %_: () = (); ()
""",
......@@ -616,7 +623,7 @@ def test_tuple_type():
)
)
assert parses_as(
assert_parses_as(
"""
let %_: (int32,) = (0,); ()
""",
......@@ -627,7 +634,7 @@ def test_tuple_type():
)
)
assert parses_as(
assert_parses_as(
"""
let %_: (int32, int32) = (0, 1); ()
""",
......@@ -648,7 +655,7 @@ def test_adt_defn():
[],
[relay.Constructor("Nil", [], glob_typ_var)])
mod[glob_typ_var] = prog
assert parses_as(
assert_parses_as(
"""
type Ayy { Nil }
""",
......@@ -662,7 +669,7 @@ def test_empty_adt_defn():
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(glob_typ_var, [], [])
mod[glob_typ_var] = prog
assert parses_as(
assert_parses_as(
"""
type Ayy { }
""",
......@@ -683,7 +690,7 @@ def test_multiple_cons_defn():
relay.Constructor("Nil", [], list_var),
])
mod[list_var] = prog
assert parses_as(LIST_DEFN, mod)
assert_parses_as(LIST_DEFN, mod)
def test_multiple_type_param_defn():
......@@ -699,7 +706,7 @@ def test_multiple_type_param_defn():
])
mod = tvm.IRModule()
mod[glob_typ_var] = prog
assert parses_as(
assert_parses_as(
"""
type Either[A, B] {
Left(A),
......@@ -755,7 +762,7 @@ def test_match():
)
mod[length_var] = length_func
assert parses_as(
assert_parses_as(
"""
%s
......@@ -796,7 +803,7 @@ def test_adt_cons_expr():
)
mod[make_singleton_var] = make_singleton_func
assert parses_as(
assert_parses_as(
"""
%s
......@@ -861,7 +868,7 @@ def test_extern_adt_defn():
extern_def = relay.TypeData(extern_var, [typ_var], [])
mod[extern_var] = extern_def
assert parses_as(
assert_parses_as(
"""
extern type T[A]
""",
......@@ -872,6 +879,7 @@ def test_import_grad():
mod.import_from_std("gradient.rly")
if __name__ == "__main__":
test_graph()
test_comments()
test_int_literal()
test_float_literal()
......@@ -882,7 +890,6 @@ if __name__ == "__main__":
test_op_assoc()
test_let()
test_seq()
test_graph()
test_tuple()
test_func()
test_defn()
......
......@@ -21,23 +21,24 @@ from tvm import relay
from tvm.relay import analysis
from tvm.relay.testing import run_opt_pass
def alpha_equal(x, y):
def sequal(x, y):
"""
Wrapper around alpha equality which ensures that
the hash function respects equality.
"""
return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
return (tvm.ir.structural_equal(x, y) and
analysis.structural_hash(x) == analysis.structural_hash(y))
def alpha_equal_commutative(x, y):
def sequal_commutative(x, y):
"""
Check for commutative property of equality
"""
xy = analysis.alpha_equal(x, y)
yx = analysis.alpha_equal(y, x)
xy = tvm.ir.structural_equal(x, y)
yx = tvm.ir.structural_equal(y, x)
assert xy == yx
return xy
def test_tensor_type_alpha_equal():
def test_tensor_type_sequal():
t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32")
t3 = relay.TensorType((3, 4, 5), "float32")
......@@ -49,7 +50,7 @@ def test_tensor_type_alpha_equal():
assert t1 == t2
def test_incomplete_type_alpha_equal():
def test_incomplete_type_sequal():
t1 = relay.IncompleteType(relay.TypeKind.ShapeVar)
t2 = relay.IncompleteType(relay.TypeKind.Type)
t3 = relay.IncompleteType(relay.TypeKind.Type)
......@@ -61,7 +62,7 @@ def test_incomplete_type_alpha_equal():
assert t2 != t3
def test_type_param_alpha_equal():
def test_type_param_sequal():
t1 = relay.TypeVar("v1", relay.TypeKind.Type)
t2 = relay.TypeVar("v2", relay.TypeKind.ShapeVar)
t3 = relay.TypeVar("v3", relay.TypeKind.Type)
......@@ -83,7 +84,7 @@ def test_type_param_alpha_equal():
assert ft1 != ft3 # kinds still do not match
def test_func_type_alpha_equal():
def test_func_type_sequal():
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
......@@ -143,7 +144,7 @@ def test_func_type_alpha_equal():
assert ft != more_rels
def test_tuple_type_alpha_equal():
def test_tuple_type_sequal():
t1 = relay.TensorType((1, 2, 3), "float32")
t2 = relay.TensorType((1, 2, 3, 4), "float32")
tp1 = relay.TypeVar("v1", relay.TypeKind.Type)
......@@ -161,7 +162,7 @@ def test_tuple_type_alpha_equal():
assert tup1 != tup4
def test_type_relation_alpha_equal():
def test_type_relation_sequal():
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
t3 = relay.TensorType((1, 2, 3, 4), "float32")
......@@ -197,7 +198,7 @@ def test_type_relation_alpha_equal():
assert bigger != diff_num_inputs
def test_type_call_alpha_equal():
def test_type_call_sequal():
h1 = relay.GlobalTypeVar("h1")
h2 = relay.GlobalTypeVar("h2")
t1 = relay.TensorType((1, 2), "float32")
......@@ -221,49 +222,49 @@ def test_type_call_alpha_equal():
assert tc != different_order_args
def test_constant_alpha_equal():
def test_constant_sequal():
x = relay.const(1)
y = relay.const(2)
assert alpha_equal(x, x)
assert not alpha_equal(x, y)
assert alpha_equal(x, relay.const(1))
assert sequal(x, x)
assert not sequal(x, y)
assert sequal(x, relay.const(1))
def test_type_node_alpha_equal():
def test_type_node_sequal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)
assert not sequal(v1, v2)
v1 = relay.TypeVar('v1', 0)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)
assert not sequal(v1, v2)
assert alpha_equal_commutative(v1, v1)
assert sequal_commutative(v1, v1)
def test_type_node_incompatible_alpha_equal():
def test_type_node_incompatible_sequal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.Var("v2")
assert not alpha_equal_commutative(v1, v2)
assert not sequal_commutative(v1, v2)
def test_expr_node_incompatible_alpha_equal():
def test_expr_node_incompatible_sequal():
v1 = relay.Var("v1")
v2 = relay.PatternVar(relay.Var("v2"))
assert not alpha_equal_commutative(v1, v2)
assert not sequal_commutative(v1, v2)
def test_var_alpha_equal():
def test_var_sequal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# normally only pointer equality
assert alpha_equal(v1, v1)
assert not alpha_equal(v1, v2)
assert sequal(v1, v1)
assert not sequal(v1, v2)
# let node allows for setting the eq_map
l1 = relay.Let(v1, relay.const(1), v1)
l2 = relay.Let(v2, relay.const(1), v2)
l3 = relay.Let(v1, relay.const(1), v2)
assert alpha_equal(l1, l2)
assert not alpha_equal(l1, l3)
assert sequal(l1, l2)
assert not sequal(l1, l3)
# type annotations
tt1 = relay.TensorType([], "int32")
......@@ -278,34 +279,34 @@ def test_var_alpha_equal():
l6 = relay.Let(v5, relay.const(1), v5)
# same annotations
assert alpha_equal(l4, l5)
assert sequal(l4, l5)
# different annotations
assert not alpha_equal(l4, l6)
assert not sequal(l4, l6)
# one null annotation
assert not alpha_equal(l1, l4)
assert not sequal(l1, l4)
def test_global_var_alpha_equal():
def test_global_var_sequal():
v1 = relay.GlobalVar("v1")
v2 = relay.GlobalVar("v2")
# only pointer equality suffices (smoke test)
assert alpha_equal(v1, v1)
assert not alpha_equal(v1, v2)
assert sequal(v1, v1)
assert not sequal(v1, v2)
def test_tuple_alpha_equal():
def test_tuple_sequal():
v0 = relay.Var("v0")
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# unit value is a valid tuple
assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
assert sequal(relay.Tuple([]), relay.Tuple([]))
tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
assert alpha_equal(tup, same)
assert sequal(tup, same)
# use the eq_map
......@@ -315,33 +316,33 @@ def test_tuple_alpha_equal():
relay.Tuple([relay.const(4)])]),
v2)
assert alpha_equal(let_tup, let_mapped)
assert sequal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
assert not alpha_equal(tup, more_fields)
assert not sequal(tup, more_fields)
fewer_fields = relay.Tuple([v1, relay.const(2), relay.const(3)])
assert not alpha_equal(tup, fewer_fields)
assert not sequal(tup, fewer_fields)
different_end = relay.Tuple([v1, relay.const(2), relay.const(3),
relay.Tuple([relay.const(5)])])
assert not alpha_equal(tup, different_end)
assert not sequal(tup, different_end)
different_start = relay.Tuple([v2, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)])])
assert not alpha_equal(tup, different_start)
assert not sequal(tup, different_start)
longer_at_end = relay.Tuple([v1, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4), relay.const(5)])])
assert not alpha_equal(tup, longer_at_end)
assert not sequal(tup, longer_at_end)
def test_tuple_get_item_alpha_equal():
def test_tuple_get_item_sequal():
x = relay.Var('x')
y = relay.Var('y')
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
assert sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_function_attr():
......@@ -364,10 +365,10 @@ def test_function_attr():
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b"))
assert not alpha_equal(func0, func1)
assert not sequal(func0, func1)
def test_function_alpha_equal():
def test_function_sequal():
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
tt3 = relay.TupleType([tt1, tt2])
......@@ -389,58 +390,58 @@ def test_function_alpha_equal():
func = relay.Function([v1, v2], v1,
tt2, basic_tps)
mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
assert alpha_equal(func, mapped)
assert sequal(func, mapped)
fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, fewer_params)
assert not sequal(func, fewer_params)
more_params = relay.Function([relay.Var("v3", tt1),
relay.Var("v4", tt2),
relay.Var("v2", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, more_params)
assert not sequal(func, more_params)
params_unordered = relay.Function([v2, v1], v1,
tt2, basic_tps)
assert not alpha_equal(func, params_unordered)
assert not sequal(func, params_unordered)
params_mismatch = relay.Function([v1, v3], v1,
tt2, basic_tps)
assert not alpha_equal(func, params_mismatch)
assert not sequal(func, params_mismatch)
# also would not typecheck
ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps)
assert not alpha_equal(func, ret_type_mismatch)
assert not sequal(func, ret_type_mismatch)
# also mis-typed
different_body = relay.Function(basic_args, v3, tt2, basic_tps)
assert not alpha_equal(func, different_body)
assert not sequal(func, different_body)
fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1])
assert not alpha_equal(func, fewer_type_params)
assert not sequal(func, fewer_type_params)
more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3])
assert not alpha_equal(func, more_type_params)
assert not sequal(func, more_type_params)
type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
assert not alpha_equal(func, type_params_unordered)
assert not sequal(func, type_params_unordered)
different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4])
assert not alpha_equal(func, different_type_params)
assert not sequal(func, different_type_params)
# a well-typed example that also differs in body, ret type, and type params
tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
assert not alpha_equal(func, tupled_example)
assert not sequal(func, tupled_example)
# nullable
no_ret_type = relay.Function(basic_args, v4, None, [tp1, tp2])
# both null
assert alpha_equal(no_ret_type, no_ret_type)
assert sequal(no_ret_type, no_ret_type)
# one null
assert not alpha_equal(func, no_ret_type)
assert not alpha_equal(no_ret_type, func)
assert not sequal(func, no_ret_type)
assert not sequal(no_ret_type, func)
def test_call_alpha_equal():
def test_call_sequal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
......@@ -458,43 +459,43 @@ def test_call_alpha_equal():
call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])],
attr1, [tt1])
same = relay.Call(v1, basic_args, attr1, [tt1])
assert alpha_equal(call, same)
assert sequal(call, same)
different_fn = relay.Call(v2, basic_args, attr1, [tt1])
assert not alpha_equal(call, different_fn)
assert not sequal(call, different_fn)
fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1])
assert not alpha_equal(call, fewer_args)
assert not sequal(call, fewer_args)
reordered_args = relay.Call(v1, [relay.const(2), relay.const(1),
relay.Tuple([]), v2], attr1, [tt1])
assert not alpha_equal(call, reordered_args)
assert not sequal(call, reordered_args)
different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)],
attr1, [tt1])
assert not alpha_equal(call, different_args)
assert not sequal(call, different_args)
more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]),
relay.const(3), relay.const(4)], attr1, [tt1])
assert not alpha_equal(call, more_args)
assert not sequal(call, more_args)
different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
assert not alpha_equal(call, different_attrs)
assert not sequal(call, different_attrs)
same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
assert alpha_equal(call, same_attrs)
assert sequal(call, same_attrs)
no_type_args = relay.Call(v1, basic_args, attr1)
assert not alpha_equal(call, no_type_args)
assert not sequal(call, no_type_args)
more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2])
assert not alpha_equal(call, more_type_args)
assert not sequal(call, more_type_args)
different_type_arg = relay.Call(v1, basic_args, attr1, [tt2])
assert not alpha_equal(call, different_type_arg)
assert not sequal(call, different_type_arg)
def test_let_alpha_equal():
def test_let_sequal():
tt1 = relay.TensorType((), "float32")
tt2 = relay.TensorType((), "int8")
v1 = relay.Var("v1")
......@@ -504,57 +505,57 @@ def test_let_alpha_equal():
let = relay.Let(v1, relay.const(2), v1)
mapped = relay.Let(v2, relay.const(2), v2)
assert alpha_equal(let, mapped)
assert sequal(let, mapped)
mismatched_var = relay.Let(v2, relay.const(2), v3)
assert not alpha_equal(let, mismatched_var)
assert not sequal(let, mismatched_var)
different_value = relay.Let(v2, relay.const(3), v2)
assert not alpha_equal(let, different_value)
assert not sequal(let, different_value)
different_body = relay.Let(v2, relay.const(3), relay.const(12))
assert not alpha_equal(let, different_body)
assert not sequal(let, different_body)
# specified types must match
let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
same_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
assert alpha_equal(let_with_type, same_type)
assert not alpha_equal(let, let_with_type)
assert sequal(let_with_type, same_type)
assert not sequal(let, let_with_type)
v2 = relay.Var("v1", tt2)
different_type = relay.Let(v2, relay.const(2), v2)
assert not alpha_equal(let_with_type, different_type)
assert not sequal(let_with_type, different_type)
def test_if_alpha_equal():
def test_if_sequal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
assert alpha_equal(if_sample, same)
assert sequal(if_sample, same)
different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
assert not alpha_equal(if_sample, different_cond)
assert not sequal(if_sample, different_cond)
different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2), relay.const(3)]))
assert not alpha_equal(if_sample, different_true)
assert not sequal(if_sample, different_true)
different_false = relay.If(v1, relay.const(1), relay.Tuple([]))
assert not alpha_equal(if_sample, different_false)
assert not sequal(if_sample, different_false)
def test_constructor_alpha_equal():
def test_constructor_sequal():
# smoke test: it should be pointer equality
mod = tvm.IRModule()
p = relay.prelude.Prelude(mod)
assert alpha_equal(p.nil, p.nil)
assert alpha_equal(p.cons, p.cons)
assert not alpha_equal(p.nil, p.cons)
assert sequal(p.nil, p.nil)
assert sequal(p.cons, p.cons)
assert not sequal(p.nil, p.cons)
def test_match_alpha_equal():
def test_match_sequal():
mod = tvm.IRModule()
p = relay.prelude.Prelude(mod)
......@@ -604,27 +605,28 @@ def test_match_alpha_equal():
p.cons(x, p.nil()))
])
assert alpha_equal(match, match)
assert alpha_equal(match, equivalent)
assert not alpha_equal(match, no_cons)
assert not alpha_equal(match, no_nil)
assert not alpha_equal(match, empty)
assert not alpha_equal(match, different_data)
assert not alpha_equal(match, different_order)
assert not alpha_equal(match, different_nil)
assert not alpha_equal(match, different_cons)
assert not alpha_equal(match, another_case)
assert not alpha_equal(match, wrong_constructors)
def test_op_alpha_equal():
tvm.ir.assert_structural_equal(match, match)
assert sequal(match, match)
assert sequal(match, equivalent)
assert not sequal(match, no_cons)
assert not sequal(match, no_nil)
assert not sequal(match, empty)
assert not sequal(match, different_data)
assert not sequal(match, different_order)
assert not sequal(match, different_nil)
assert not sequal(match, different_cons)
assert not sequal(match, another_case)
assert not sequal(match, wrong_constructors)
def test_op_sequal():
# only checks names
op1 = relay.op.get("add")
op2 = relay.op.get("add")
assert alpha_equal(op1, op2)
assert sequal(op1, op2)
op3 = relay.op.get("take")
assert not alpha_equal(op1, op3)
assert not sequal(op1, op3)
def test_graph_equal():
......@@ -638,14 +640,14 @@ def test_graph_equal():
z3 = relay.add(relay.add(x, x), relay.add(x, x))
assert alpha_equal(z0, z1)
assert alpha_equal(z0, z1)
assert sequal(z0, z1)
assert sequal(z0, z1)
# z3's dataflow format is different from z0
# z0 is computed from a common y0 node
# Relay view them as different programs
# Check the difference in the text format.
assert not alpha_equal(z0, z3)
assert not sequal(z0, z3)
def test_hash_unequal():
x1 = relay.var("x1", shape=(10, 10), dtype="float32")
......@@ -677,7 +679,7 @@ def test_tuple_match():
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert analysis.alpha_equal(x, y)
assert sequal(x, y)
assert analysis.structural_hash(x) == analysis.structural_hash(y)
......@@ -697,34 +699,34 @@ def test_fn_attribute():
add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
assert not relay.analysis.alpha_equal(add_fn, add_1_fn)
assert not sequal(add_1_fn, add_fn)
assert not sequal(add_fn, add_1_fn)
if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_node_alpha_equal()
test_type_node_incompatible_alpha_equal()
test_expr_node_incompatible_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
test_type_call_alpha_equal()
test_constant_alpha_equal()
test_global_var_alpha_equal()
test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal()
test_function_alpha_equal()
test_tensor_type_sequal()
test_incomplete_type_sequal()
test_constant_sequal()
test_type_node_sequal()
test_type_node_incompatible_sequal()
test_expr_node_incompatible_sequal()
test_func_type_sequal()
test_tuple_type_sequal()
test_type_relation_sequal()
test_type_call_sequal()
test_constant_sequal()
test_global_var_sequal()
test_tuple_sequal()
test_tuple_get_item_sequal()
test_function_sequal()
test_function_attr()
test_call_alpha_equal()
test_let_alpha_equal()
test_if_alpha_equal()
test_constructor_alpha_equal()
test_match_alpha_equal()
test_op_alpha_equal()
test_var_alpha_equal()
test_call_sequal()
test_let_sequal()
test_if_sequal()
test_constructor_sequal()
test_match_sequal()
test_op_sequal()
test_var_sequal()
test_graph_equal()
test_hash_unequal()
test_fn_attribute()
......@@ -57,14 +57,14 @@ def run_opt_pass(expr, opt_pass):
def test_let():
orig = relay.Let(e.x, e.y, e.z)
orig = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c)
orig = run_opt_pass(orig, transform.DeadCodeElimination())
expected = relay.Let(e.c, e.one, e.c + e.c)
assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
......@@ -75,7 +75,7 @@ def test_inline():
def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
orig = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
def use_f(func):
......@@ -111,13 +111,13 @@ def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
dced_f = lambda f: x
dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
assert alpha_equal(dced, e.three)
assert tvm.ir.structural_equal(dced, e.three)
def test_op_let():
dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two),
transform.DeadCodeElimination())
assert alpha_equal(dced, add(e.three, e.two))
assert tvm.ir.structural_equal(dced, add(e.three, e.two))
def test_tuple_get_item():
......@@ -126,10 +126,10 @@ def test_tuple_get_item():
a = relay.Var('a')
g = relay.TupleGetItem(t, 0)
dced = run_opt_pass(g, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
dced = run_opt_pass(orig, transform.DeadCodeElimination())
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
@pytest.mark.timeout(timeout=10, method="thread")
......
......@@ -72,7 +72,7 @@ def test_tuple():
f = Function([x], body, None, [t])
expected = relay.Function([x], x, None, [t])
expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(dcpe(f), expected)
assert tvm.ir.structural_equal(dcpe(f), expected)
def test_const_inline():
......@@ -80,7 +80,7 @@ def test_const_inline():
d = Var("d", t)
double = Function([d], d + d)
orig = double(const(4.0))
assert alpha_equal(dcpe(orig), const(8.0))
assert tvm.ir.structural_equal(dcpe(orig), const(8.0))
def test_ref():
......@@ -93,7 +93,7 @@ def test_ref():
body = Let(r, RefCreate(d), body)
square = Function([d], body)
expected = run_opt_pass(Function([d], d * d), transform.InferType())
assert alpha_equal(dcpe(square), expected)
assert tvm.ir.structural_equal(dcpe(square), expected)
def test_empty_ad():
......@@ -105,7 +105,7 @@ def test_empty_ad():
g = dcpe(f, grad=True)
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
expected = run_opt_pass(expected, transform.InferType())
assert alpha_equal(g, expected)
assert tvm.ir.structural_equal(g, expected)
def test_ad():
......@@ -180,7 +180,7 @@ def test_head_cons():
body = hd(p.cons(x, p.nil()))
f = Function([x], body, None, [t])
res = dcpe(f, mod)
assert alpha_equal(res, Function([x], x, t, [t]))
assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
def test_map():
......@@ -197,7 +197,7 @@ def test_map():
expected = mod["main"]
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, expected.body)
assert tvm.ir.structural_equal(res.body, expected.body)
def test_loop():
......@@ -211,7 +211,7 @@ def test_loop():
expected = mod["main"].body
call = Function([], loop(const(1)))
res = dcpe(call, mod=mod)
assert alpha_equal(res.body, expected)
assert tvm.ir.structural_equal(res.body, expected)
def test_swap_loop():
......@@ -226,7 +226,7 @@ def test_swap_loop():
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
res = Function([], prog)
res = dcpe(res, mod=mod)
assert alpha_equal(prog, res.body)
assert tvm.ir.structural_equal(prog, res.body)
def test_abs_diff():
......@@ -248,7 +248,7 @@ def test_abs_diff():
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 4))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4))
def test_match_nat_id():
......@@ -265,7 +265,7 @@ def test_match_nat_id():
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_nat_id():
......@@ -280,7 +280,7 @@ def test_nat_id():
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_global_match_nat_id():
......@@ -294,7 +294,7 @@ def test_global_match_nat_id():
orig = Match(make_nat_expr(p, 3), [z_case, s_case])
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 3))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_double():
......@@ -304,7 +304,7 @@ def test_double():
orig = p.double(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, make_nat_expr(p, 6))
assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6))
def test_concat():
......
......@@ -134,7 +134,7 @@ def test_qnn_legalize_qnn_conv2d():
# Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################
# Check transformations for platforms without fast Int8 support.
......@@ -157,7 +157,7 @@ def test_qnn_legalize_qnn_conv2d():
# Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
......@@ -221,7 +221,7 @@ def test_qnn_legalize_qnn_dense():
# Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################
# Check transformations for platforms without fast Int8 support.
......@@ -244,7 +244,7 @@ def test_qnn_legalize_qnn_dense():
# Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
......
......@@ -76,7 +76,7 @@ def test_order():
expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType())
assert alpha_equal(anf, expected_output)
assert tvm.ir.structural_equal(anf, expected_output)
def test_if():
......@@ -93,7 +93,7 @@ def test_if():
expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType())
assert alpha_equal(anf, expected_output)
assert tvm.ir.structural_equal(anf, expected_output)
# make sure we dont infinite loop.
......
......@@ -17,7 +17,7 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.analysis import detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.analysis import Feature
from tvm.relay.prelude import Prelude
......
......@@ -21,7 +21,6 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import op, transform, analysis
from tvm.relay.analysis import assert_alpha_equal
def run_infer_type(expr, mod=None):
......@@ -360,7 +359,7 @@ def test_let_polymorphism():
body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
body = run_infer_type(body)
int32 = relay.TensorType((), "int32")
assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
if __name__ == "__main__":
......
......@@ -25,7 +25,7 @@ def test_const_saveload_json():
z = z + z
json_str = tvm.ir.save_json(z)
zz = tvm.ir.load_json(json_str)
assert tvm.ir.save_json(zz) == tvm.ir.save_json(z)
tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
def test_make_smap():
......@@ -38,6 +38,7 @@ def test_make_smap():
arr = tvm.ir.load_json(json_str)
assert len(arr) == 1
assert arr[0]["z"].a == arr[0]["x"]
tvm.ir.assert_structural_equal(arr, [smap], map_free_vars=True)
def test_make_node():
......@@ -90,7 +91,6 @@ def test_env_func():
if __name__ == "__main__":
test_env_func()
test_make_attrs()
test_make_node()
test_make_smap()
test_const_saveload_json()
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import pytest
from tvm import te
def test_exprs():
# save load json
x = tvm.tir.const(1, "int32")
y = tvm.tir.const(10, "int32")
vx = te.var("x")
vy = te.var("y")
vz = te.var("z")
# test assert trigger.
with pytest.raises(ValueError):
tvm.ir.assert_structural_equal(x, y)
assert not tvm.ir.structural_equal(vx, vy)
assert tvm.ir.structural_equal(vx, vy, map_free_vars=True)
# corner case lhs:vx == rhs:vy, but cannot map it iteslf
assert not tvm.ir.structural_equal(vx + vx, vy + vx, map_free_vars=True)
# corner case lhs:vx == rhs:vy, lhs:vy == rhs:vx
assert tvm.ir.structural_equal(vx + vy, vy + vx, map_free_vars=True)
# corner case2: rolling remap.
assert tvm.ir.structural_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
assert not tvm.ir.structural_equal(vx + 1, vy + 1, map_free_vars=False)
# Defintition remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx - 1),
tvm.tir.Let(vy, 1, vy - 1))
# Default same address free var remap
assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx // vz),
tvm.tir.Let(vy, 1, vy // vz))
zx = vx + vx
zy = vy + vy
assert tvm.ir.structural_equal(zx * zx, zx * zx)
assert tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=True)
assert not tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=False)
assert tvm.ir.structural_equal(zx * zx, (vx + vx) * (vx + vx),
map_free_vars=False)
def test_prim_func():
x = te.var('x')
y = te.var('y')
# counter example of same equality
func0 = tvm.tir.PrimFunc(
[x, y], tvm.tir.Evaluate(x + y))
func1 = tvm.tir.PrimFunc(
[x, y], tvm.tir.Evaluate(y + x))
assert not tvm.ir.structural_equal(func0, func1)
# new cases
b = tvm.tir.decl_buffer((x,), "float32")
stmt = tvm.tir.LetStmt(
x, 10, tvm.tir.Evaluate(x + 1))
func0 = tvm.tir.PrimFunc(
[x, y, b], stmt)
# easiest way to deep copy is via save/load
func1 = tvm.ir.load_json(tvm.ir.save_json(func0))
tvm.ir.assert_structural_equal(func0, func1)
data0 = tvm.nd.array([1, 2, 3])
data1 = tvm.nd.array([1, 2, 3])
# attributes and ndarrays
func0 = func0.with_attr("data", data0)
func1 = func1.with_attr("data", data1)
# IRModules
mod0 = tvm.IRModule.from_expr(func0)
mod1 = tvm.IRModule.from_expr(func1)
tvm.ir.assert_structural_equal(mod0, mod1)
def test_attrs():
x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
z = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx")
tvm.ir.assert_structural_equal(y, x)
assert not tvm.ir.structural_equal(y, z)
if __name__ == "__main__":
test_exprs()
test_prim_func()
test_attrs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment