Unverified Commit e60003c2 by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Introduce ExprDeepEqual, Remove IRDeepCompare (#5206)

* [REFACTOR][TIR] Introduce ExprDeepEqual, Remove IRDeepCompare

This PR introduces ExprDeepEqual which reuses the StructuralEqual infra.
We migrated the usecases of ir_pass::Equal to ExprDeepEqual and StructuralEqual.

* Address comments
parent 04499665
......@@ -24,10 +24,17 @@ tvm.tir
:autosummary:
tvm.tir.transform
-----------------
.. automodule:: tvm.tir.transform
:members:
:imported-members:
:autosummary:
tvm.tir.analysis
----------------
.. automodule:: tvm.tir.analysis
:members:
:imported-members:
:autosummary:
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/analysis.h
* \brief Analysis utilities and passes for TIR.
*/
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
namespace tvm {
namespace tir {
/*!
* \brief Compare two expressions recursively and check if they are equal
* to each other without var remapping.
*
* This function does not remap variable bindings, it will not
* return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
*
* Use StructuralEqual for such cases.
*
* Due to the restriction of not remapping variables, this function can run
* faster than StructuralEqual and can be used as a utility function during arithmetic
* simplifications.
*
* \sa StructuralEqual
*/
struct ExprDeepEqual {
  /*!
   * \brief Recursively compare two expressions without remapping variables.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \return Whether lhs and rhs are equal to each other.
   */
  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
......@@ -920,6 +920,17 @@ class FunctionBaseNode : public Object {
virtual const std::string& func_name() const = 0;
/*! \return the number of outputs of this function */
virtual int num_outputs() const = 0;
// fall back to pointer equality now before refactor.
/*!
 * \brief Equality reduction: two function nodes compare equal only when
 *        they are the very same object.
 * \param other The node to compare against.
 * \param equal The equality reducer passed by the caller; not used here.
 * \return true if other points to this node, false otherwise.
 */
bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const {
return this == other;
}
/*!
 * \brief Hash reduction: currently feeds nothing into the reducer.
 * \param hash_reduce The hash reducer passed by the caller; not used here.
 */
void SHashReduce(SHashReducer hash_reduce) const {
}
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
};
/*! \brief reference to a function */
......
......@@ -77,35 +77,6 @@ TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
TVM_DLL bool Equal(const PrimExpr& lhs, const PrimExpr& rhs);
/*!
* \brief Deep compare lhs and rhs
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Stmt& lhs, const Stmt& rhs);
/*!
* \brief Deep compare lhs and rhs.
*
* If you only want equality comparison, use Equal
* which will also tie definitions. The compare mode
* will give order of expression in total order.
*
* \param lhs The left operand
* \param rhs The right operand
* \return The comparison result.
*/
int Compare(const PrimExpr& lhs, const PrimExpr& rhs);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
......
......@@ -22,7 +22,6 @@ import tvm.te
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
from tvm.tir import ir_pass
from tvm.tir import call_pure_intrin
from tvm.tir.stmt import For
......@@ -47,7 +46,7 @@ def _range(annotation, args):
else:
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = args[0], args[1]
if not ir_pass.Equal(low, const(0, dtype='int32')):
if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
iter_var = None
......
......@@ -56,7 +56,7 @@ def concat_list_to_block(lst):
def visit_list_to_block(visit, lst):
"""Visit and concatenate a list of Python IR nodes to HalideIR Block"""
lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
lst = [stmt for stmt in lst if not tvm.ir.structural_equal(stmt, util.make_nop())]
if not lst:
return util.make_nop()
return concat_list_to_block(lst)
......@@ -178,7 +178,7 @@ class HybridParser(ast.NodeVisitor):
self.binds[val.var.name] = val
return
val_ = self.binds[val.var.name]
_internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent),
_internal_assert(tvm.tir.analysis.expr_deep_equal(val_.dom.extent, val.dom.extent),
"Thread extents should be uniform!")
self.symbols[key] = ty, val_
......@@ -525,7 +525,7 @@ class HybridParser(ast.NodeVisitor):
if iter_var is None:
_internal_assert(for_type is not None, "The loop iterating function parse error!")
offset = iter_var = tvm.te.var(_name)
if not _ir_pass.Equal(low, tvm.runtime.const(0, 'int32')):
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, 'int32')):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body)
......
......@@ -198,6 +198,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
structural_hash
assert_structural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
return bool(tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars))
......@@ -225,6 +227,8 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
--------
structural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, True, map_free_vars)
......
......@@ -46,3 +46,4 @@ from .op import comm_reducer, min, max, sum
from . import ir_builder
from . import ir_pass
from . import transform
from . import analysis
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of all TIR analysis utils."""
# pylint: disable=wildcard-import, invalid-name
from .analysis import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir.analysis"""
import tvm._ffi
tvm._ffi._init_api("tir.analysis", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping existing analysis utils."""
# pylint: disable=invalid-name
from . import _ffi_api
def expr_deep_equal(lhs, rhs):
    """Check two nested expressions for deep equality without remapping variables.

    Parameters
    ----------
    lhs : PrimExpr
        The left operand.

    rhs : PrimExpr
        The right operand.

    Returns
    -------
    result : bool
        The comparison result.

    Note
    ----
    Variable bindings are not remapped, so
    (let x = 1 in x + 1) and (let y = 1 in y + 1) do not compare equal
    unless x.same_as(y).
    Use :py:func:`tvm.ir.structural_equal` when structural variable
    remapping is needed.

    Because no remapping is performed, this function can run faster than
    StructuralEqual and can serve as a utility during arithmetic
    simplifications.

    Always consider :py:func:`tvm.ir.structural_equal` first, since it
    handles the structural remapping.

    See Also
    --------
    tvm.ir.structural_equal
    """
    is_equal = _ffi_api.expr_deep_equal(lhs, rhs)
    return is_equal
......@@ -23,6 +23,8 @@
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/analysis.h>
#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"
......@@ -157,7 +159,7 @@ class SplitExpr : public PrimExpr {
inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
if (index.same_as(other->index)) return true;
return tir::Equal(index, other->index);
return tir::ExprDeepEqual()(index, other->index);
}
inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const {
......
......@@ -138,10 +138,11 @@ class ConstIntBoundAnalyzer::Impl :
Entry VisitExpr(const PrimExpr& expr) final {
Entry res = ExprFunctor::VisitExpr(expr);
tir::ExprDeepEqual equal;
// a linear search over additional info
// assume we won't have a lot of conditions
for (const BoundInfo& info : additional_info_) {
if (tir::Equal(expr, info.expr)) {
if (equal(expr, info.expr)) {
res = Intersect(res, info.bound);
}
}
......
......@@ -66,6 +66,7 @@
#define TVM_ARITH_PATTERN_MATCH_H_
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tuple>
#include "const_fold.h"
......@@ -135,7 +136,7 @@ class PEqualChecker<PrimExpr> {
public:
bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
if (lhs.same_as(rhs)) return true;
return tir::Equal(lhs, rhs);
return tir::ExprDeepEqual()(lhs, rhs);
}
};
......
......@@ -101,11 +101,11 @@ TryCompare(const PrimExpr& x, int64_t val) {
}
void RewriteSimplifier::Impl::
Update(const Var& var, const PrimExpr& info, bool override) {
if (!override) {
Update(const Var& var, const PrimExpr& info, bool can_override) {
if (!can_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(Equal(it->second, info))
CHECK(ExprDeepEqual()(it->second, info))
<< "Trying to update var \'" << var << "\'"
<< " with a different value: "
<< "original=" << it->second
......@@ -1716,10 +1716,11 @@ VisitExpr_(const CallNode* op) {
return op->args[0] & op->args[1];
}
}
ExprDeepEqual expr_equal;
if (op->is_intrinsic(CallNode::likely)) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (Equal(constraint, op->args[0])) {
if (expr_equal(constraint, op->args[0])) {
return make_const(op->dtype, true);
}
}
......
......@@ -23,7 +23,9 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/arith/analyzer.h>
#include "ir_mutator_with_analyzer.h"
......@@ -83,7 +85,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
op = stmt.as<StoreNode>();
if (const LoadNode* load = op->value.as<LoadNode>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) {
tir::ExprDeepEqual()(load->index, op->index)) {
return EvaluateNode::make(0);
}
}
......
......@@ -225,7 +225,6 @@ class RemapVarSEqualHandler :
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
};
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs,
const ObjectRef& rhs,
......
......@@ -25,6 +25,8 @@
#define TVM_RELAY_OP_NN_CONVOLUTION_H_
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <string>
#include <utility>
......@@ -158,8 +160,8 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(weight && weight->shape.defined()) <<
"Weight shape must be specified when groups is greater than 1.";
Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape);
if (tvm::tir::Equal(param->groups, dshape_nchw[1]) &&
tvm::tir::Equal(param->groups, wshape_oihw[0])) {
if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) &&
tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) {
is_depthwise = true;
}
}
......@@ -279,8 +281,9 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK_EQ(param->kernel_size.size(), 3);
CHECK_EQ(param->dilation.size(), 3);
Array<IndexExpr> wshape;
tvm::tir::ExprDeepEqual expr_equal;
if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) {
if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) {
// infer weight's shape for depthwise convolution
wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0],
param->kernel_size[1], param->kernel_size[2]}};
......
......@@ -27,6 +27,8 @@
#include <tvm/relay/op.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/analysis.h>
#include "../../op/nn/convolution.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"
......@@ -86,7 +88,8 @@ Array<Array<Layout>> QnnConvInferCorrectLayout(const Attrs& attrs,
}
bool is_depthwise(const Conv2DAttrs* param) {
return param->channels.defined() && tvm::tir::Equal(param->channels, param->groups) &&
return param->channels.defined() &&
tvm::tir::ExprDeepEqual()(param->channels, param->groups) &&
param->groups != 1;
}
......
......@@ -27,6 +27,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <unordered_set>
#include <string>
......@@ -338,12 +339,14 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
Stmt VisitStmt_(const ForNode *op) final {
tir::ExprDeepEqual expr_equal;
if (op->loop_var.get() == var) {
if (attr->bind_thread.defined()) {
const auto &iter_var = attr->bind_thread;
if (iter_var->dom.defined()) {
CHECK(is_const_int(iter_var->dom->min, 0));
CHECK(Equal(iter_var->dom->extent, op->extent))
CHECK(expr_equal(iter_var->dom->extent, op->extent))
<< "Thread extent and loop extent mismatch!\n";
}
std::unordered_map<const VarNode *, PrimExpr> rmap;
......
......@@ -24,6 +24,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/runtime/registry.h>
#include "op_util.h"
......@@ -330,6 +331,7 @@ void VerifyTensorizeBody(
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
StructuralEqual expr_equal;
Map<Var, Range> compute_intrin_iter_space;
Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
&compute_intrin_iter_space);
......@@ -349,7 +351,7 @@ void VerifyTensorizeBody(
<< " provided=" << lhs.dtype()
<< ", intrin=" << rhs.dtype();
}
CHECK(Equal(lhs, rhs))
CHECK(expr_equal(lhs, rhs))
<< "Failed to match the compute with TensorIntrin "
<< intrin->name << "'s declaration "
<< " provided= " << lhs
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tir/analysis/deep_equal.cc
* \brief Deep equality checking.
*/
#include <tvm/node/structural_equal.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
namespace tvm {
namespace tir {
/*!
 * \brief Equality handler that compares by direct recursion and keeps
 *        no mapping between lhs and rhs objects.
 */
class DeepCmpSEqualHandler : public SEqualReducer::Handler {
 public:
  // Recurse immediately through the reflection vtable.
  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
    // Same reference on both sides: nothing more to compare.
    if (lhs.same_as(rhs)) return true;
    // Exactly one side is undefined.
    if (lhs.defined() != rhs.defined()) return false;
    // Only nodes of the same runtime type are handed to the vtable.
    const bool same_type = lhs->type_index() == rhs->type_index();
    return same_type &&
           vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false));
  }

  // No lhs-to-rhs mapping is recorded, so a null reference is returned.
  ObjectRef MapLhsToRhs(const ObjectRef& lhs) final {
    return ObjectRef(nullptr);
  }

  // Nothing to record for graph nodes.
  void MarkGraphNode() final {}

 private:
  // Global reflection vtable used to dispatch per-type equality.
  ReflectionVTable* vtable_ = ReflectionVTable::Global();
};
/*!
 * \brief Deeply compare two expressions without remapping variables.
 * \param lhs The left operand.
 * \param rhs The right operand.
 * \return Whether the two expressions are equal.
 */
bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
  // Cheap checks first: same reference, one-sided undefined, type mismatch.
  if (lhs.same_as(rhs)) return true;
  if (lhs.defined() != rhs.defined()) return false;
  if (lhs->type_index() != rhs->type_index()) return false;

  // Integer immediates are compared field by field, skipping the handler.
  if (const auto* lhs_imm = lhs.as<IntImmNode>()) {
    // Type indices already match, so rhs is an IntImmNode as well.
    const auto* rhs_imm = rhs.as<IntImmNode>();
    const bool same_dtype = lhs_imm->dtype == rhs_imm->dtype;
    return same_dtype && lhs_imm->value == rhs_imm->value;
  }

  // Everything else goes through the recursive handler.
  DeepCmpSEqualHandler handler;
  return handler.SEqualReduce(lhs, rhs, false);
}
// Register ExprDeepEqual as the global function "tir.analysis.expr_deep_equal".
TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal")
.set_body_typed([](const PrimExpr& a, const PrimExpr& b) {
    ExprDeepEqual deep_equal;
    return deep_equal(a, b);
  });
} // namespace tir
} // namespace tvm
......@@ -24,7 +24,9 @@
#include <tvm/tir/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/ir_pass.h>
#include <iterator>
#include <stack>
#include "../../arith/compute_expr.h"
......@@ -112,6 +114,8 @@ inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr &mult_expr,
const PrimExpr* search_ptr = inner;
PrimExpr mult_inner; // The inner multiplication factor
PrimExpr no_opt_sum; // Sum of the exprs that cannot be optimized
tir::ExprDeepEqual expr_equal;
while (true) {
auto inner_div_ptr = search_ptr->as<IndexDiv>();
auto inner_mult_ptr = search_ptr->as<MulNode>();
......@@ -120,9 +124,9 @@ inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr &mult_expr,
return std::make_pair(false, PrimExpr());
} else if (inner_div_ptr) {
PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
if (Equal(overall_mult, inner_div_ptr->b)
&& Equal(overall_mult, mod_r_expr)
&& Equal(inner_div_ptr->a, mod_l_expr)) {
if (expr_equal(overall_mult, inner_div_ptr->b)
&& expr_equal(overall_mult, mod_r_expr)
&& expr_equal(inner_div_ptr->a, mod_l_expr)) {
// Found!
PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
return std::make_pair(true, ret);
......
......@@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
}
});
TVM_REGISTER_GLOBAL("ir_pass.Equal")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<Stmt>()) {
*ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
} else {
*ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr());
}
});
TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() <= 3) {
......
......@@ -22,6 +22,7 @@
*/
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/buffer.h>
#include <tvm/runtime/device_api.h>
......@@ -255,7 +256,7 @@ class DeviceTypeBinder: public StmtExprMutator {
// eager check NE for device check
PrimExpr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NENode>();
if (tir::Equal(op->a, op->b)) {
if (tir::ExprDeepEqual()(op->a, op->b)) {
return make_const(op->dtype, false);
}
return res;
......
......@@ -25,6 +25,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/target_info.h>
#include <map>
......@@ -311,7 +312,7 @@ class InplaceOpVerifier : public StmtExprVisitor {
if (src_ == buf) {
if (store_ == nullptr ||
store_->value.dtype() != op->dtype ||
!tir::Equal(store_->index, op->index)) {
!tir::ExprDeepEqual()(store_->index, op->index)) {
result_ = false; return;
}
}
......
......@@ -22,6 +22,7 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
#include <unordered_set>
......@@ -179,7 +180,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
// TODO(tqchen) more standard set based testing.
if (e.touched.is_single_point() &&
x.touched.is_single_point()) {
if (Equal(e.touched.point_value(),
if (ExprDeepEqual()(e.touched.point_value(),
x.touched.point_value())) continue;
}
if (x.double_buffer_write &&
......
......@@ -26,11 +26,13 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
#include <map>
#include <unordered_map>
namespace tvm {
namespace tir {
......@@ -39,12 +41,6 @@ namespace tir {
// This information is needed during codegen.
class ContextCallCombiner final : public StmtExprMutator {
public:
struct CompareExpr {
bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
return Compare(lhs, rhs) < 0;
}
};
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
CHECK_EQ(op->args.size(), 1U);
......@@ -73,7 +69,7 @@ class ContextCallCombiner final : public StmtExprMutator {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable
std::map<PrimExpr, Var, CompareExpr> temp;
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> temp;
std::swap(temp, ctx_map_);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
......@@ -86,7 +82,7 @@ class ContextCallCombiner final : public StmtExprMutator {
Stmt VisitStmt_(const ForNode* op) final {
if (op->for_type == ForType::Parallel) {
// Map of comparison expression to variable
std::map<PrimExpr, Var, CompareExpr> temp;
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> temp;
std::swap(temp, ctx_map_);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
......@@ -101,7 +97,8 @@ class ContextCallCombiner final : public StmtExprMutator {
}
private:
static Stmt BuildContext(const std::map<PrimExpr, Var, CompareExpr>& cmap,
static Stmt BuildContext(
const std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual>& cmap,
Stmt body) {
for (const auto& kv : cmap) {
body = LetStmtNode::make(kv.second, kv.first, body);
......@@ -109,7 +106,7 @@ class ContextCallCombiner final : public StmtExprMutator {
return body;
}
// Map of comparison expression to variable
std::map<PrimExpr, Var, CompareExpr> ctx_map_;
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> ctx_map_;
};
LoweredFunc CombineContextCall(LoweredFunc f) {
......
......@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/tir/analysis.h>
#include "../src/arith/pattern_match.h"
TEST(Pattern, Basic) {
......@@ -39,12 +40,13 @@ TEST(Pattern, Basic) {
{
CHECK((px + (py + px)).Match(r));
auto rr = (px + py).Eval();
CHECK(tir::Equal(rr, 1 + y));
CHECK(tir::Equal(px.Eval() + py.Eval(), 1 + y));
CHECK(tir::ExprDeepEqual()(rr, 1 + y));
CHECK(tir::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y));
}
{
CHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK(tir::Equal(px.Eval(), x + 1));
CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1));
}
CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
CHECK((px + min(py, px)).Match(z + min(y, z)));
......@@ -64,7 +66,7 @@ TEST(Pattern, Basic) {
{
CHECK(select(px >= pz, py, py + pz).Match(
tir::SelectNode::make((x + 1) >= 1, y, y + 1)));
CHECK(tir::Equal(px.Eval(), x + 1));
CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1));
}
// bit intrinsics
{
......@@ -90,7 +92,7 @@ TEST(Pattern, Basic) {
{
CHECK(select(px, py, pz).Match(
tir::SelectNode::make(x > 2, y, y + 1)));
CHECK(tir::Equal(pz.Eval(), y + 1));
CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1));
}
// if_then_else
{
......
......@@ -23,7 +23,9 @@ class CanonicalChecker:
def verify(self, data, expected):
res = self.analyzer.canonical_simplify(data)
assert tvm.tir.ir_pass.Equal(res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected)
expected = tvm.runtime.convert(expected)
assert tvm.ir.structural_equal(
res, expected), "\ndata={}\nres={}\nexpected={}".format(data, res, expected)
def test_mul_sum_simplify():
......@@ -197,7 +199,7 @@ def test_reduce_combiner_simplify():
# Check that the remaining components are the expected ones.
for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
assert tvm.tir.ir_pass.Equal(lhs, rhs)
assert tvm.ir.structural_equal(lhs, rhs)
# Test that components with side effects are not removed
side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0)
......
......@@ -45,7 +45,7 @@ def test_multivariate():
v = [te.var("v%d" % i) for i in range(4)]
b = te.var("b")
m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
assert(tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5))
assert(tvm.tir.analysis.expr_deep_equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5))
assert(m[1].value == 8)
m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
......
......@@ -28,7 +28,7 @@ class IntSetChecker:
return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
def equal(x, y):
res = self.analyzer.canonical_simplify(x - y)
return tvm.tir.ir_pass.Equal(res, 0)
return tvm.tir.analysis.expr_deep_equal(res, 0)
assert equal(res.min_value, expected[0]), err_msg()
assert equal(res.max_value, expected[1]), err_msg()
......
......@@ -23,7 +23,7 @@ class RewriteChecker:
def verify(self, data, expected):
res = self.analyzer.rewrite_simplify(data)
assert tvm.tir.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
assert tvm.ir.structural_equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
def test_vector_simplify():
......
......@@ -182,7 +182,7 @@ def test_fanout():
assert isinstance(ir, tvm.tir.For)
assert ir.loop_var.name == 'i'
assert ir.min.value == 0
assert tvm.tir.ir_pass.Equal(ir.extent, n - 3)
assert tvm.ir.structural_equal(ir.extent, n - 3)
#Check loopbody
ibody = ir.body
assert isinstance(ibody, tvm.tir.AttrStmt)
......@@ -215,7 +215,7 @@ def test_fanout():
assert value.a.args[0].value == 0
assert value.b.name == 'a'
assert len(value.b.args) == 1
assert tvm.tir.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
assert tvm.ir.structural_equal(value.b.args[0], ir.loop_var + jloop.loop_var)
divide= rbody[2]
assert isinstance(divide, tvm.tir.Provide)
assert len(divide.args) == 1
......
......@@ -100,12 +100,13 @@ def test_tensorize_vadd():
dom_map = tvm.te.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[z], dom_map)
assert tvm.tir.ir_pass.Equal(out_dom[z.op.axis[0]].extent, factor)
assert tvm.tir.ir_pass.Equal(out_dom[z.op.axis[0]].min, xo * factor)
assert tvm.tir.ir_pass.Equal(in_dom.items()[0][1][0].extent, factor)
assert tvm.ir.structural_equal(out_dom[z.op.axis[0]].extent, factor)
assert tvm.ir.structural_equal(out_dom[z.op.axis[0]].min, xo * factor)
assert tvm.ir.structural_equal(in_dom.items()[0][1][0].extent, factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[z], out_dom, in_dom, vadd)
assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]),
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(vadd.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [x, y, z])
......@@ -133,12 +134,13 @@ def test_tensorize_matmul():
dom_map = tvm.te.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor)
assert tvm.ir.structural_equal(out_dom[x].extent, 1)
assert tvm.ir.structural_equal(out_dom[y].extent, factor)
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]),
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......@@ -157,12 +159,13 @@ def test_tensorize_matmul():
dom_map = tvm.te.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor)
assert tvm.ir.structural_equal(out_dom[x].extent, 1)
assert tvm.ir.structural_equal(out_dom[y].extent, factor)
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]),
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......@@ -180,12 +183,13 @@ def test_tensorize_matmul():
dom_map = tvm.te.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor)
assert tvm.ir.structural_equal(out_dom[x].extent, 1)
assert tvm.ir.structural_equal(out_dom[y].extent, factor)
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]),
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......@@ -204,12 +208,13 @@ def test_tensorize_matmul():
dom_map = tvm.te.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[C], dom_map)
assert tvm.tir.ir_pass.Equal(out_dom[x].extent, 1)
assert tvm.tir.ir_pass.Equal(out_dom[y].extent, factor)
assert tvm.tir.ir_pass.Equal(out_dom[y].min, yo * factor)
assert tvm.ir.structural_equal(out_dom[x].extent, 1)
assert tvm.ir.structural_equal(out_dom[y].extent, factor)
assert tvm.ir.structural_equal(out_dom[y].min, yo * factor)
fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[C], out_dom, in_dom, gemv)
assert tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.CanonicalSimplify(body[0]),
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.CanonicalSimplify(body[0]),
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0]))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
......
......@@ -27,40 +27,10 @@ def test_equal_expr():
def func2():
return te.exp(tvm.tir.truncdiv((x + y + 1) * y, 4))
assert tvm.tir.ir_pass.Equal(func1(), func1())
assert tvm.tir.ir_pass.Equal(func2(), func2())
assert not tvm.tir.ir_pass.Equal(func2(), func1())
def test_equal_compute():
x = te.var('x')
y = te.var('y')
n = 128
A = te.placeholder((n, n), name='A')
B = te.placeholder((n, n), name='B')
ii = te.var('i')
jj = te.var('j')
def func1():
k = te.reduce_axis((0, n), name='k')
return te.sum(A[ii, k] * B[jj, k], axis=k)
Ab = tvm.tir.decl_buffer((n,), name='A')
n = te.var("n")
def func2():
ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
A[j] = A[j] + 2
A[j] = A[j] + 2
return ib.get()
assert tvm.tir.ir_pass.Equal(func1(), func1())
assert tvm.tir.ir_pass.Equal(func2(), func2())
assert tvm.tir.analysis.expr_deep_equal(func1(), func1())
assert tvm.tir.analysis.expr_deep_equal(func2(), func2())
assert not tvm.tir.analysis.expr_deep_equal(func2(), func1())
if __name__ == "__main__":
test_equal_expr()
test_equal_compute()
......@@ -36,7 +36,7 @@ def test_buffer_access_ptr():
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1 , 1])
aptr = Ab.access_ptr("rw")
assert tvm.tir.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m)
assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m)
assert aptr.args[0].dtype == Ab.dtype
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
aptr = Ab.access_ptr("w")
......@@ -49,16 +49,16 @@ def test_buffer_access_ptr_offset():
Ab = tvm.tir.decl_buffer((m, n), "float32")
aptr = Ab.access_ptr("rw", offset=100)
offset = tvm.tir.ir_pass.Simplify(aptr.args[2])
assert tvm.tir.ir_pass.Equal(offset, 100)
assert tvm.ir.structural_equal(offset, 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
v = te.size_var('int32')
aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
offset = tvm.tir.ir_pass.Simplify(aptr.args[2])
assert tvm.tir.ir_pass.Equal(offset, 200 + v)
assert tvm.ir.structural_equal(offset, 200 + v)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern('int32', "test_call", 100 + 100 + v))
offset = tvm.tir.ir_pass.Simplify(aptr.args[2])
assert tvm.tir.ir_pass.Equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v))
assert tvm.ir.structural_equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
......@@ -67,12 +67,12 @@ def test_buffer_access_ptr_extent():
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((m, n), "float32")
aptr = Ab.access_ptr("rw")
assert tvm.tir.ir_pass.Equal(aptr.args[3], m * n)
assert tvm.ir.structural_equal(aptr.args[3], m * n)
aptr = Ab.access_ptr("rw", offset=100)
assert tvm.tir.ir_pass.Equal(aptr.args[3], m * n - 100)
assert tvm.ir.structural_equal(aptr.args[3], m * n - 100)
Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1 , 1])
aptr = Ab.access_ptr("rw", offset=100)
assert tvm.tir.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100)
assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)
def test_buffer_vload():
......@@ -81,7 +81,7 @@ def test_buffer_vload():
Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
load = Ab.vload([2, 3])
offset = tvm.tir.ir_pass.Simplify(load.index)
assert tvm.tir.ir_pass.Equal(offset, n * 2 + 103)
assert tvm.ir.structural_equal(offset, n * 2 + 103)
def test_buffer_index_merge_mult_mod():
......@@ -93,7 +93,7 @@ def test_buffer_index_merge_mult_mod():
A = tvm.tir.decl_buffer((m, n), "float32")
A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
def assert_simplified_equal(index_simplified, index_direct):
assert tvm.tir.ir_pass.Equal(index_simplified, index_direct),\
assert tvm.ir.structural_equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod
......
......@@ -71,7 +71,7 @@ def test_const_fold3():
for tvm_func, py_func in [(tvm.tir.all, lambda a, b: a and b), (tvm.tir.any, lambda a, b: a or b)]:
for v1 in [0, 1]:
for v2 in [0, 1]:
assert tvm.tir.ir_pass.Equal(tvm_func(tvm.tir.const(v1, 'uint1'), tvm.tir.const(v2, 'uint1')),
assert tvm.ir.structural_equal(tvm_func(tvm.tir.const(v1, 'uint1'), tvm.tir.const(v2, 'uint1')),
tvm.tir.const(py_func(v1, v2), 'uint1'))
x = te.var("x", 'uint1')
......@@ -170,13 +170,13 @@ def test_if_then_else():
out = tvm.tir.if_then_else(cond, lhs, rhs)
out2 = tvm.tir.if_then_else(not cond, rhs, lhs)
out3 = tvm.tir.if_then_else(not cond, lhs, rhs)
assert tvm.tir.ir_pass.Equal(out, out2) == 1
assert tvm.ir.structural_equal(out, out2) == 1
if cond:
assert tvm.tir.ir_pass.Equal(out, lhs.astype(out_dtype)) == 1
assert tvm.tir.ir_pass.Equal(out3, rhs.astype(out_dtype)) == 1
assert tvm.ir.structural_equal(out, lhs.astype(out_dtype)) == 1
assert tvm.ir.structural_equal(out3, rhs.astype(out_dtype)) == 1
else:
assert tvm.tir.ir_pass.Equal(out, rhs.astype(out_dtype)) == 1
assert tvm.tir.ir_pass.Equal(out3, lhs.astype(out_dtype)) == 1
assert tvm.ir.structural_equal(out, rhs.astype(out_dtype)) == 1
assert tvm.ir.structural_equal(out3, lhs.astype(out_dtype)) == 1
elif cond.dtype == 'bool':
out = tvm.tir.if_then_else(cond, lhs, rhs)
assert out.dtype == out_dtype
......
......@@ -22,11 +22,11 @@ def test_simplify():
tmod = tvm.tir.truncmod
x = te.var('x')
e1 = tvm.tir.ir_pass.Simplify(x + 2 + 1)
assert(tvm.tir.ir_pass.Equal(e1, x + 3))
assert(tvm.ir.structural_equal(e1, x + 3))
e2 = tvm.tir.ir_pass.Simplify(x * 3 + 5 * x)
assert(tvm.tir.ir_pass.Equal(e2, x * 8))
assert(tvm.ir.structural_equal(e2, x * 8))
e3 = tvm.tir.ir_pass.Simplify(x - tdiv(x, 3) * 3)
assert(tvm.tir.ir_pass.Equal(e3, tmod(x, 3)))
assert(tvm.ir.structural_equal(e3, tmod(x, 3)))
def test_verify_ssa():
......
......@@ -444,7 +444,7 @@ def test_simple_rfactor():
stmt2 = tvm.tir.ir_pass.Simplify(stmt2)
#make sure loop partition actually did something
assert not tvm.tir.ir_pass.Equal(stmt1.body, stmt2.body)
assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)
if __name__ == "__main__":
......
......@@ -142,9 +142,34 @@ def test_attrs():
assert not consistent_equal(y, z)
def test_stmt():
    """Check that building the same statement twice yields consistent-equal IR."""
    # The buffer has a fixed extent of 128; the outer loop bound is a symbolic var.
    Ab = tvm.tir.decl_buffer((128,), name='A')
    n = te.var("n")

    def func2():
        # Build a small loop nest over the buffer with the IR builder.
        ib = tvm.tir.ir_builder.create()
        A = ib.buffer_ptr(Ab)
        with ib.for_range(0, n, name="i") as i:
            A[i] = A[i] + 1
            # NOTE(review): nesting of the "j" loop inside the "i" loop is
            # reconstructed; confirm against the upstream test.
            with ib.for_range(0, 10, name="j") as j:
                A[j] = A[j] + 2
                A[j] = A[j] + 2
        return ib.get()

    # Two independent builds of the same program must compare equal.
    assert consistent_equal(func2(), func2())
if __name__ == "__main__":
    # Run the tests directly when the file is executed as a script.
    test_exprs()
    test_prim_func()
    test_attrs()
    test_array()
    test_env_func()
    test_stmt()
......@@ -43,7 +43,7 @@ def test_prim_func_pass():
mod = tvm.IRModule({"main": func})
mod = TestReplaceFunc(new_func)(mod)
assert tvm.tir.ir_pass.Equal(mod["main"].body, new_func.body)
assert tvm.ir.structural_equal(mod["main"].body, new_func.body)
if __name__ == "__main__":
......
......@@ -26,6 +26,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <string>
#include <vector>
......@@ -114,10 +115,11 @@ inline std::vector<int64_t> GetConstInt64Values(
* \return result True if both expressions are equal, else false
*/
inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
  // Compare the two expressions as written first.
  tvm::tir::ExprDeepEqual expr_equal;
  if (expr_equal(lhs, rhs)) {
    return true;
  }
  // Otherwise treat them as equal when their canonically simplified
  // difference compares equal to the constant zero.
  PrimExpr zero(0);
  return expr_equal(tvm::tir::CanonicalSimplify(lhs - rhs), zero);
}
......
......@@ -83,7 +83,7 @@ def fold_uop_loop(stmt_in):
fail[0] = True
return op
if gemm_offsets[i] is not None:
if not tvm.tir.ir_pass.Equal(m[0], gemm_offsets[i]):
if not tvm.ir.structural_equal(m[0], gemm_offsets[i]):
fail[0] = True
return op
args.append(m[1])
......@@ -775,7 +775,7 @@ def inject_alu_intrin(stmt_in):
def _do_fold(stmt):
def _equal(x, y):
return tvm.tir.ir_pass.Equal(tvm.tir.ir_pass.Simplify(x - y), 0)
return tvm.ir.structural_equal(tvm.tir.ir_pass.Simplify(x - y), 0)
def _flatten_loop(src_coeff, dst_coeff, extents):
src_coeff = list(src_coeff)
......@@ -895,9 +895,9 @@ def inject_alu_intrin(stmt_in):
lhs_equal = True
rhs_equal = True
for i, coef in enumerate(dst_coeff):
if not tvm.tir.ir_pass.Equal(coef, src_lhs_coeff[i]):
if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
lhs_equal = False
if not tvm.tir.ir_pass.Equal(coef, src_rhs_coeff[i]):
if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
rhs_equal = False
# Make sure at least one of the source is identical to the
# destination (in-place computation)
......@@ -916,20 +916,20 @@ def inject_alu_intrin(stmt_in):
assert len(src_coeff) > 1
assert len(dst_coeff) > 1
assert len(extents) != 0
assert tvm.tir.ir_pass.Equal(
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.Simplify(
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.tir.ir_pass.Equal(
assert tvm.ir.structural_equal(
tvm.tir.ir_pass.Simplify(
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.tir.ir_pass.Equal(src_coeff[-2], 1)
assert tvm.tir.ir_pass.Equal(dst_coeff[-2], 1)
assert tvm.ir.structural_equal(src_coeff[-2], 1)
assert tvm.ir.structural_equal(dst_coeff[-2], 1)
if env.BATCH > 1:
assert len(src_coeff) > 2
assert len(dst_coeff) > 2
assert len(extents) > 1
assert tvm.tir.ir_pass.Equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.tir.ir_pass.Equal(dst_coeff[-3], env.BLOCK_OUT)
assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
# Apply tensorization of the loop coefficients
src_offset = src_coeff[-1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment