Unverified Commit 6cb5b882 by Tianqi Chen Committed by GitHub

[TIR] Enhance Substitute, python bindings for Substitute/PostOrderVisit/IRTransform. (#5400)

Substitute now takes a std::function to customize more replacing behaviors.
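For illustration (not part of the commit text itself), a minimal sketch of what the callback form enables, assuming the declarations this patch adds to include/tvm/tir/stmt_functor.h; the function and variable names here are invented:

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

// Replace one specific variable while leaving all others untouched.
// Returning NullOpt from the callback means "no remapping for this var".
tvm::PrimExpr ReplaceVar(tvm::PrimExpr expr,
                         tvm::tir::Var target, tvm::PrimExpr repl) {
  return tvm::tir::Substitute(expr,
      [&](const tvm::tir::Var& var) -> tvm::Optional<tvm::PrimExpr> {
        if (var.same_as(target)) return repl;
        return tvm::NullOpt;
      });
}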

Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>

parent 8c0f7790
...@@ -38,3 +38,10 @@ tvm.tir.analysis
   :members:
   :imported-members:
   :autosummary:

tvm.tir.stmt_functor
--------------------
.. automodule:: tvm.tir.stmt_functor
   :members:
   :autosummary:
...@@ -611,6 +611,10 @@ struct PackedFuncValueConverter<::tvm::runtime::String> {
  }
};

/*! \brief Helper to represent nullptr for optional. */
struct NullOptType {
};

/*!
 * \brief Optional container that represents a nullable variant of T.
 * \tparam T The original ObjectRef.
...@@ -642,6 +646,8 @@ class Optional : public ObjectRef {
   * \param ptr
   */
  explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
  /*! \brief Nullopt handling */
  Optional(NullOptType) {}  // NOLINT(*)
  // nullptr handling.
  // disallow implicit conversion as 0 can be implicitly converted to nullptr_t
  explicit Optional(std::nullptr_t) {}
...@@ -751,6 +757,7 @@ struct PackedFuncValueConverter<Optional<T>> {
// expose the functions to the root namespace.
using runtime::String;
using runtime::Optional;
constexpr runtime::NullOptType NullOpt{};
}  // namespace tvm

namespace std {
......
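A hedged usage sketch for the NullOptType/NullOpt additions above; LookupFlag and its values are invented for illustration, only Optional, NullOpt, and String come from TVM:

#include <tvm/runtime/container.h>

// NullOpt implicitly converts to any Optional<T> via the new
// Optional(NullOptType) constructor, mirroring std::nullopt.
tvm::Optional<tvm::runtime::String> LookupFlag(bool present) {
  if (present) return tvm::runtime::String("enabled");
  return tvm::NullOpt;
}

void Demo() {
  auto flag = LookupFlag(false);
  // Call value() only after checking defined().
  tvm::runtime::String result =
      flag.defined() ? flag.value() : tvm::runtime::String("default");
  (void)result;
}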
...@@ -82,40 +82,6 @@ bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vse
TVM_DLL Stmt ConvertSSA(Stmt stmt);

/*!
 * \brief Substitute the var specified in key->var to be value.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
Stmt Substitute(Stmt stmt,
                const std::unordered_map<const VarNode*, PrimExpr>& value_map);

/*!
 * \brief Substitute the var specified in key->var to be value.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
PrimExpr Substitute(PrimExpr expr,
                    const std::unordered_map<const VarNode*, PrimExpr>& value_map);

/*!
 * \brief Substitute the var specified in key->var to be value.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);

/*!
 * \brief Substitute the var specified in key->var to be value.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);

/*!
 * \brief Verify if there is any argument bound to compact buffer.
 *
 * \param stmt The stmt to be verified.
......
...@@ -20,17 +20,20 @@
/*!
 * \file tvm/tir/stmt_functor.h
 *
 * \brief Functors for tir stmts
 *        utility functions to call common functors.
 */
#ifndef TVM_TIR_STMT_FUNCTOR_H_
#define TVM_TIR_STMT_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/expr_functor.h>

#include <utility>
#include <unordered_map>

namespace tvm {
namespace tir {
...@@ -318,9 +321,9 @@ class StmtExprMutator :
};

/*!
 * \brief Recursively visit the ir nodes in post DFS order, and transform them.
 *
 * \param stmt The ir to be transformed.
 * \param preorder The function called before the recursive mutation.
 *          If preorder returns None, then the transform will proceed to the recursive call.
 *          If preorder returns a non-None Stmt/Expr, the transformer will simply return it and
...@@ -328,23 +331,76 @@ class StmtExprMutator :
 * \param postorder The function called after the recursive mutation.
 *          The recursive mutation result is passed to postorder for further mutation.
 * \param only_enable List of runtime::String.
 *          If it is null, all IRNodes will call preorder/postorder.
 *          If it is not null, preorder/postorder will only be called
 *          when the IRNode's type key is in the list.
 */
TVM_DLL Stmt IRTransform(Stmt stmt,
                         const runtime::PackedFunc& preorder,
                         const runtime::PackedFunc& postorder,
                         Optional<Array<String>> only_enable = NullOpt);
/*!
 * \brief Recursively visit the ir nodes in post DFS order, and apply fvisit.
 *        Each node is guaranteed to be visited only once.
 * \param node The ir to be visited.
 * \param fvisit The visitor function to be applied.
 */
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
/*!
 * \brief Substitute the var specified by vmap.
 * \param stmt The source statement to be substituted.
 * \param vmap The function that returns the new value if re-mapping is needed,
 *             otherwise returns nullptr.
 * \return The converted form.
 */
TVM_DLL Stmt Substitute(Stmt stmt,
                        std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
 * \brief Substitute the var specified by vmap.
 * \param expr The source expression to be substituted.
 * \param vmap The function that returns the new value if re-mapping is needed,
 *             otherwise returns nullptr.
 * \return The result.
 */
TVM_DLL PrimExpr Substitute(PrimExpr expr,
                            std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
 * \brief Sugar for substitute via a given map.
 * \param input The input to be updated.
 * \param value_map The map of new values.
 * \return The result.
 * \tparam T the input type, can be PrimExpr or Stmt.
 */
template<typename T>
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
    auto it = value_map.find(var);
    if (it != value_map.end()) return (*it).second;
    return Optional<PrimExpr>(nullptr);
  };
  return Substitute(std::move(input), vmap);
}

/*!
 * \brief Sugar for substitute via a given map.
 * \param input The input to be updated.
 * \param value_map The map of new values.
 * \return The result.
 * \tparam T the input type, can be PrimExpr or Stmt.
 */
template<typename T>
inline T Substitute(T input,
                    const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
  auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
    auto it = value_map.find(var.get());
    if (it != value_map.end()) return (*it).second;
    return Optional<PrimExpr>(nullptr);
  };
  return Substitute(std::move(input), vmap);
}
}  // namespace tir
}  // namespace tvm
......
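A short sketch of the convenience overloads declared above, using invented variables; the Map overload forwards to the new std::function overload, and IRTransform now defaults only_enable to NullOpt, meaning every node type is visited:

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

tvm::PrimExpr Demo() {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x", DataType::Int(32));
  Var y("y", DataType::Int(32));
  PrimExpr body = x + y;

  // Map-based sugar: builds a vmap lambda and calls the callback overload.
  Map<Var, PrimExpr> value_map;
  value_map.Set(x, make_const(DataType::Int(32), 10));
  return Substitute(body, value_map);  // yields 10 + y (before simplification)
}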
...@@ -72,7 +72,7 @@ def _pruned_source(func):
def replace_io(body, rmap):
    """Replacing tensors usage according to the dict given"""
    # pylint: disable=import-outside-toplevel
    from tvm.tir import stmt_functor

    def replace(op):
        if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
...@@ -84,7 +84,7 @@ def replace_io(body, rmap):
                              _expr.Call.Halide, buf.op, buf.value_index)
        return None

    return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])


def _is_tvm_arg_types(args):
......
...@@ -48,3 +48,4 @@ from . import ir_builder
from . import ir_pass
from . import transform
from . import analysis
from . import stmt_functor
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Statement functor utilities for IR transformations"""
from . import _ffi_api
def ir_transform(stmt, preorder, postorder, only_enable=None):
    """Recursively visit and transform ir nodes in post DFS order.

    Parameters
    ----------
    stmt : Stmt
        The input to be transformed.

    preorder : function
        The function called before recursive mutation.
        If preorder returns None, then the transform will proceed to the recursive call.
        If preorder returns a non-None Stmt/Expr, the transformer will simply return it and
        won't do further recursion.

    postorder : function
        The function called after recursive mutation.

    only_enable : Optional[List[str]]
        List of types that we only enable.

    Returns
    -------
    result : Stmt
        The result.
    """
    return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable)
def post_order_visit(stmt, fvisit):
    """Recursively visit the ir nodes in post DFS order, and apply fvisit.
    Each node is guaranteed to be visited only once.

    Parameters
    ----------
    stmt : Stmt
        The input ir to be visited.

    fvisit : function
        The visitor function.
    """
    return _ffi_api.PostOrderVisit(stmt, fvisit)
def substitute(node, vmap):
    """Substitute the var specified by vmap.

    Parameters
    ----------
    node : ObjectRef
        The input.

    vmap : Dict[Var, PrimExpr]
        The variable mapping.

    Returns
    -------
    result : Stmt
        The result.
    """
    return _ffi_api.Substitute(node, vmap)
...@@ -26,9 +26,10 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/util.h>
#include <tvm/tir/op.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/data_type.h>

namespace tvm {
...@@ -130,10 +131,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
      (*S)[i][j] = new_i_j;
    }
    // We have to do the same with rhs
    PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
    PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
    PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
    PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
    PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
    PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
    (*y)[index] = new_index_rhs;
...@@ -190,10 +191,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
      (*V)[i][j] = new_i_j;
    }
    // And apply reverse transformations to new_to_old.
    PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
    PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
    PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
    PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
    PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
    PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
    (*x)[index] = new_index;
...@@ -369,7 +370,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
        IntConstraints(
            /*variables=*/{},
            /*ranges=*/{},
            /*relations=*/{tir::make_zero(DataType::Bool())}),
        {}, {});
  } else if (!tir::is_const_int(new_relation, 1)) {
    new_relations.push_back(new_relation);
...@@ -403,13 +404,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
      // The j-th variable is just a single value, don't create a tvm variable
      // S^{-1}_{nxm} Uy_{mxn}
      if (S[j][j] >= 0) {
        PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
        solution_for_V_inv_x.push_back(
            analyzer_problem.Simplify(floordiv(Uy[j], a)));
      } else {
        // This is required because some simplifiers
        // have problems with dividing by negative numbers
        PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
        solution_for_V_inv_x.push_back(
            analyzer_problem.Simplify(floordiv(-Uy[j], a)));
      }
...@@ -418,9 +419,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
  // V V^{-1} x = x
  for (size_t i = 0; i < num_vars; ++i) {
    PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
    for (size_t j = 0; j < num_vars; ++j) {
      e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
    }
    e = analyzer_problem.Simplify(e);
    old_to_new_map.Set(system_to_solve->variables[i], e);
......
...@@ -22,7 +22,7 @@
 * \brief Utility for tensor-level auto-differentiation.
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include "ad_util.h"
......
...@@ -26,7 +26,6 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <unordered_set>
......
...@@ -79,7 +79,7 @@ Stmt ReplaceTensor(Stmt stmt,
 * \param replace The replacement rule.
 */
PrimExpr ReplaceTensor(PrimExpr expr,
                       const std::unordered_map<Tensor, Tensor>& replace);

/*!
 * \brief Substitute the variables of stmt by value map.
......
...@@ -25,8 +25,9 @@
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>

#include "./op_util.h"
#include "./compute_op.h"
#include "../../arith/compute_expr.h"
......
...@@ -23,7 +23,7 @@
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h>
#include <cctype>
......
...@@ -24,7 +24,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <memory>
#include <limits>
#include "../pass/ir_util.h"
...@@ -363,8 +363,8 @@ Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b
    value_map.Set(rhs[i], b[i]);
  }
  return UpdateArray(result, [&value_map] (const PrimExpr& e) {
      return Substitute(e, value_map);
    });
}

TVM_REGISTER_GLOBAL("tir.CommReducer")
......
...@@ -19,116 +19,14 @@
/*!
 * \file stmt_functor.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <functional>
#include "functor_common.h"

namespace tvm {
namespace tir {
// visitor to implement apply
class IRApplyVisit :
    public StmtExprVisitor {
 public:
  explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}

  void VisitExpr(const PrimExpr& node) final {
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
    ExprVisitor::VisitExpr(node);
    f_(node);
  }

  void VisitStmt(const Stmt& node) final {
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
    StmtVisitor::VisitStmt(node);
    f_(node);
  }

 private:
  std::function<void(const ObjectRef&)> f_;
  std::unordered_set<const Object*> visited_;
};

void PostOrderVisit(const ObjectRef& node,
                    std::function<void(const ObjectRef&)> fvisit) {
  if (node.as<StmtNode>()) {
    IRApplyVisit visitor(fvisit);
    visitor(Downcast<Stmt>(node));
  } else {
    IRApplyVisit visitor(fvisit);
    visitor(Downcast<PrimExpr>(node));
  }
}

class IRTransformer final :
    public StmtExprMutator {
 public:
  IRTransformer(const runtime::PackedFunc& f_preorder,
                const runtime::PackedFunc& f_postorder,
                const std::unordered_set<uint32_t>& only_enable)
      : f_preorder_(f_preorder),
        f_postorder_(f_postorder),
        only_enable_(only_enable) {
  }

  Stmt VisitStmt(const Stmt& stmt) final {
    return MutateInternal<Stmt>(stmt, [this](const Stmt& s) {
      return this->BaseVisitStmt(s);
    });
  }
  PrimExpr VisitExpr(const PrimExpr& expr) final {
    return MutateInternal<PrimExpr>(expr, [this](const PrimExpr& e) {
      return this->BaseVisitExpr(e);
    });
  }

 private:
  // NOTE: redirect to parent's call
  // This is used to get around limitation of gcc-4.8
  Stmt BaseVisitStmt(const Stmt& s) {
    return StmtMutator::VisitStmt(s);
  }
  PrimExpr BaseVisitExpr(const PrimExpr& e) {
    return ExprMutator::VisitExpr(e);
  }

  template <typename T, typename F>
  T MutateInternal(const T& node, F fmutate) {
    if (only_enable_.size() &&
        !only_enable_.count(node->type_index())) {
      return fmutate(node);
    }
    if (f_preorder_ != nullptr) {
      T pre = f_preorder_(node);
      if (pre.defined()) return pre;
    }
    T new_node = fmutate(node);
    if (f_postorder_ != nullptr) {
      T post = f_postorder_(new_node);
      if (post.defined()) return post;
    }
    return new_node;
  }
  // The functions
  const runtime::PackedFunc& f_preorder_;
  const runtime::PackedFunc& f_postorder_;
  // type indices enabled.
  const std::unordered_set<uint32_t>& only_enable_;
};

Stmt IRTransform(Stmt ir_node,
                 const runtime::PackedFunc& f_preorder,
                 const runtime::PackedFunc& f_postorder,
                 const Array<runtime::String>& only_enable) {
  std::unordered_set<uint32_t> only_type_index;
  for (auto s : only_enable) {
    only_type_index.insert(Object::TypeKey2Index(s.c_str()));
  }
  IRTransformer transform(f_preorder, f_postorder, only_type_index);
  return transform(std::move(ir_node));
}
void StmtVisitor::VisitStmt_(const LetStmtNode* op) {
  this->VisitExpr(op->value);
  this->VisitStmt(op->body);
...@@ -511,6 +409,183 @@ Stmt StmtMutator::VisitStmt_(const FreeNode* op) {
}
// Implementations of IRTransform, PostOrderVisit and Substitute
class IRApplyVisit :
    public StmtExprVisitor {
 public:
  explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}

  void VisitExpr(const PrimExpr& node) final {
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
    ExprVisitor::VisitExpr(node);
    f_(node);
  }

  void VisitStmt(const Stmt& node) final {
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
    StmtVisitor::VisitStmt(node);
    f_(node);
  }

 private:
  std::function<void(const ObjectRef&)> f_;
  std::unordered_set<const Object*> visited_;
};

void PostOrderVisit(const ObjectRef& node,
                    std::function<void(const ObjectRef&)> fvisit) {
  if (node.as<StmtNode>()) {
    IRApplyVisit visitor(fvisit);
    visitor(Downcast<Stmt>(node));
  } else {
    IRApplyVisit visitor(fvisit);
    visitor(Downcast<PrimExpr>(node));
  }
}

class IRTransformer final :
    public StmtExprMutator {
 public:
  IRTransformer(const runtime::PackedFunc& f_preorder,
                const runtime::PackedFunc& f_postorder,
                const std::unordered_set<uint32_t>& only_enable)
      : f_preorder_(f_preorder),
        f_postorder_(f_postorder),
        only_enable_(only_enable) {
  }

  Stmt VisitStmt(const Stmt& stmt) final {
    return MutateInternal<Stmt>(stmt, [this](const Stmt& s) {
      return this->BaseVisitStmt(s);
    });
  }
  PrimExpr VisitExpr(const PrimExpr& expr) final {
    return MutateInternal<PrimExpr>(expr, [this](const PrimExpr& e) {
      return this->BaseVisitExpr(e);
    });
  }

 private:
  // NOTE: redirect to parent's call
  // This is used to get around limitation of gcc-4.8
  Stmt BaseVisitStmt(const Stmt& s) {
    return StmtMutator::VisitStmt(s);
  }
  PrimExpr BaseVisitExpr(const PrimExpr& e) {
    return ExprMutator::VisitExpr(e);
  }

  template <typename T, typename F>
  T MutateInternal(const T& node, F fmutate) {
    if (only_enable_.size() &&
        !only_enable_.count(node->type_index())) {
      return fmutate(node);
    }
    if (f_preorder_ != nullptr) {
      T pre = f_preorder_(node);
      if (pre.defined()) return pre;
    }
    T new_node = fmutate(node);
    if (f_postorder_ != nullptr) {
      T post = f_postorder_(new_node);
      if (post.defined()) return post;
    }
    return new_node;
  }
  // The functions
  const runtime::PackedFunc& f_preorder_;
  const runtime::PackedFunc& f_postorder_;
  // type indices enabled.
  const std::unordered_set<uint32_t>& only_enable_;
};

Stmt IRTransform(Stmt ir_node,
                 const runtime::PackedFunc& f_preorder,
                 const runtime::PackedFunc& f_postorder,
                 Optional<Array<String>> only_enable) {
  std::unordered_set<uint32_t> only_type_index;
  if (only_enable.defined()) {
    for (auto s : only_enable.value()) {
      only_type_index.insert(Object::TypeKey2Index(s.c_str()));
    }
  }
  IRTransformer transform(f_preorder, f_postorder, only_type_index);
  return transform(std::move(ir_node));
}

class IRSubstitue : public StmtExprMutator {
 public:
  explicit IRSubstitue(std::function<Optional<PrimExpr>(const Var&)> vmap)
      : vmap_(vmap) {
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    auto ret = vmap_(var);
    if (ret.defined()) return ret.value();
    return std::move(var);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    // NOTE: we do not explicitly recursively mutate op->buffer_var
    PrimExpr ret = StmtExprMutator::VisitExpr_(op);
    op = ret.as<LoadNode>();
    if (auto mapped_var = vmap_(op->buffer_var)) {
      return LoadNode::make(
          op->dtype, Downcast<Var>(mapped_var.value()), op->index, op->predicate);
    } else {
      return ret;
    }
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    // NOTE: we do not explicitly recursively mutate op->buffer_var
    Stmt ret = StmtExprMutator::VisitStmt_(op);
    op = ret.as<StoreNode>();
    if (auto mapped_var = vmap_(op->buffer_var)) {
      return StoreNode::make(
          Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
    } else {
      return ret;
    }
  }

 private:
  std::function<Optional<PrimExpr>(const Var&)> vmap_;
};

Stmt Substitute(Stmt stmt,
                std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstitue(vmap)(std::move(stmt));
}

PrimExpr Substitute(PrimExpr expr,
                    std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstitue(vmap)(std::move(expr));
}

TVM_REGISTER_GLOBAL("tir.IRTransform")
.set_body_typed(IRTransform);

TVM_REGISTER_GLOBAL("tir.PostOrderVisit")
.set_body_typed([](ObjectRef node, PackedFunc f) {
  tir::PostOrderVisit(node, [f](const ObjectRef& n) {
    f(n);
  });
});

TVM_REGISTER_GLOBAL("tir.Substitute")
.set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
  if (node->IsInstance<StmtNode>()) {
    return Substitute(Downcast<Stmt>(node), vmap);
  } else {
    return Substitute(Downcast<PrimExpr>(node), vmap);
  }
});
}  // namespace tir
}  // namespace tvm
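As a hedged illustration of why the Load/Store visitors above consult vmap on buffer_var (names invented, API as declared in this patch): remapping a buffer Var also rewrites the buffer_var of every Load/Store that touches it.

#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

tvm::tir::Stmt RenameBuffer(tvm::tir::Stmt body,
                            tvm::tir::Var old_buf, tvm::tir::Var new_buf) {
  return tvm::tir::Substitute(body,
      [&](const tvm::tir::Var& v) -> tvm::Optional<tvm::PrimExpr> {
        if (v.same_as(old_buf)) return new_buf;
        return tvm::NullOpt;  // leave every other variable untouched
      });
}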
...@@ -32,28 +32,13 @@
namespace tvm {
namespace tir {
TVM_REGISTER_GLOBAL("ir_pass.Substitute")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<Stmt>()) {
*ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, PrimExpr>());
} else {
*ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map<Var, PrimExpr>());
}
});
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
}); });
TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
tir::PostOrderVisit(args[0], [f](const ObjectRef& n) {
f(n);
});
});
// make from two arguments
#define REGISTER_PASS(PassName) \
...@@ -63,7 +48,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(IRTransform);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(VerifyCompactBuffer);
......
...@@ -159,7 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
    }
  });
  return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"For"});
}

// Remove IfThenElse node from a For node.
...@@ -185,9 +185,9 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
    }
  });
  then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"IfThenElse"});
  if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
    else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"IfThenElse"});
  }

  return std::make_pair(then_for, else_for);
...@@ -408,7 +408,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
      *ret = new_for;
    }
  });
  return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"For"});
}

Stmt HoistIfThenElse(Stmt stmt) {
......
...@@ -52,79 +52,7 @@ bool HasSideEffect(const PrimExpr& e) {
  return v.has_side_effect_;
}
class IRSubstitue : public StmtExprMutator {
 public:
  explicit IRSubstitue(
      const std::unordered_map<const VarNode*, PrimExpr>& smap)
      : smap_(smap) {
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = smap_.find(op);
    if (it != smap_.end()) {
      return it->second;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    // NOTE: we do not explicitly recursively mutate op->buffer_var
    PrimExpr ret = StmtExprMutator::VisitExpr_(op);
    op = ret.as<LoadNode>();
    auto it = smap_.find(op->buffer_var.get());
    if (it != smap_.end()) {
      return LoadNode::make(
          op->dtype, Downcast<Var>(it->second), op->index, op->predicate);
    } else {
      return ret;
    }
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    // NOTE: we do not explicitly recursively mutate op->buffer_var
    Stmt ret = StmtExprMutator::VisitStmt_(op);
    op = ret.as<StoreNode>();
    auto it = smap_.find(op->buffer_var.get());
    if (it != smap_.end()) {
      return StoreNode::make(
          Downcast<Var>(it->second), op->value, op->index, op->predicate);
    } else {
      return ret;
    }
  }

 private:
  const std::unordered_map<const VarNode*, PrimExpr>& smap_;
};

Stmt Substitute(Stmt stmt,
                const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
  if (value_map.size() == 0) return stmt;
  return IRSubstitue(value_map)(std::move(stmt));
}

PrimExpr Substitute(PrimExpr expr,
                    const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
  if (value_map.size() == 0) return expr;
  return IRSubstitue(value_map)(std::move(expr));
}

Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map) {
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  for (const auto& kv : value_map) {
    vmap[kv.first.get()] = kv.second;
  }
  return Substitute(stmt, vmap);
}

PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map) {
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  for (const auto& kv : value_map) {
    vmap[kv.first.get()] = kv.second;
  }
  return Substitute(expr, vmap);
}
class VarTouchVisitor : public ExprVisitor {
 public:
......
...@@ -29,7 +29,7 @@ def run_expr(expr, vranges):
    """
    def _compute_body(*us):
        vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
        return tir.stmt_functor.substitute(expr, vmap)

    A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
    args = [tvm.nd.empty(A.shape, A.dtype)]
...@@ -69,17 +69,17 @@ def check_solution(solution, vranges={}):
    cond_on_vars = tir.const(1, 'bool')
    for v in constraints1.variables:
        # variable mapping is consistent
        v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap))
        cond_on_vars = te.all(cond_on_vars, v == v_back)
    # Also we have to check that the new relations are true when old relations are true
    cond_subst = tir.stmt_functor.substitute(
        te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
    # We have to include relations from vranges too
    for v in constraints2.variables:
        if v in constraints2.ranges:
            r = constraints2.ranges[v]
            range_cond = te.all(v >= r.min, v < r.min + r.extent)
            range_cond = tir.stmt_functor.substitute(range_cond, backvarmap)
            cond_subst = te.all(cond_subst, range_cond)
    cond_subst = ana.simplify(cond_subst)
    check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
......
...@@ -201,7 +201,7 @@ def test_cuda_shuffle():
        def _transform(f, *_):
            return f.with_body(
                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")

    with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
......
...@@ -685,7 +685,7 @@ def test_llvm_shuffle():
        def _transform(f, *_):
            return f.with_body(
                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
......
...@@ -24,7 +24,7 @@ from tvm.te.hybrid.runtime import HYBRID_GLOBALS
@pytest.mark.skip
def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
    def tvm_val_2_py_val(val):
        val = tvm.tir.stmt_functor.substitute(val, var_dict)
        val = tvm.arith.Analyzer().simplify(val)
        assert isinstance(val, (tvm.tir.IntImm,))
        return val.value
......
...@@ -148,8 +148,8 @@ def test_bound_fusesplit1():
            for k in range(1, 6):
                vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")})
                tvm.testing.assert_prim_expr_equal(
                    tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars),
                    tvm.tir.stmt_functor.substitute(expected_extent, vars)
                )

    tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l)
...@@ -170,10 +170,10 @@ def test_bound_fusesplit2():
    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")})
    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].min, vars), 2)
    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].min, vars), 3)
    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), 1)
    tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].extent, vars), 3)

def test_bound_warp():
......
...@@ -155,7 +155,7 @@ def test_inline_mixed():
    def check(x):
        if isinstance(x, tvm.tir.Call):
            assert x.func != A2
    tvm.tir.stmt_functor.post_order_visit(s[C].op.body[0], check)
def test_scan_inline1():
...@@ -517,7 +517,7 @@ def test_local_stage_predicate():
    def collect_visit(stmt, f):
        ret = []
        tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
        return ret
    # local vs. threadIdx
    s = schedule(tx, "local")
...@@ -563,7 +563,7 @@ def test_local_stage_predicate2():
    def collect_visit(stmt, f):
        ret = []
        tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
        return ret

    def visit_stmt(op):
......
...@@ -264,7 +264,7 @@ def test_tuple_with_different_deps():
                x.func == B1.op and x.value_index == 1:
            ret.append(x)

    ret = []
    tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize)

    assert stmt.node == C.op and len(ret) == 1
......
...@@ -32,7 +32,7 @@ def verify_structure(stmt, expected_struct):
            key = op
            if isinstance(op, tvm.tir.IfThenElse):
                global var_list
                tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars)
                val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))]
                var_list.clear()
            elif isinstance(op, tvm.tir.For):
...@@ -43,7 +43,7 @@ def verify_structure(stmt, expected_struct):
                return
            node_dict[key] = val

    tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
    for key, val in node_dict.items():
        struct[val[1]] = tuple(node_dict[child][1] if child in node_dict
                               else None for child in val[0])
......
...@@ -37,7 +37,7 @@ def test_ir_transform():
        if op.name == "TestA":
            return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1)
        return op
    body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["Call"])
    stmt_list = tvm.tir.stmt_list(body.body.body)
    assert stmt_list[0].value.args[0].name == "TestB"
    assert stmt_list[1].value.value == 0
......
...@@ -54,7 +54,7 @@ def test_double_buffer():
    def count_sync(op):
        if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
            count[0] += 1
    tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
    assert count[0] == 4
......
...@@ -21,7 +21,7 @@ import numpy as np
def collect_visit(stmt, f):
    ret = []
    tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
    return ret
......
...@@ -20,7 +20,7 @@ import numpy
def collect_visit(stmt, f):
    ret = []
    tvm.tir.stmt_functor.post_order_visit(stmt, lambda x : ret.append(f(x)))
    return ret
......
...@@ -123,7 +123,7 @@ def test_flatten_double_buffer():
    def count_sync(op):
        if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
            count[0] += 1
    tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
    assert count[0] == 4

if __name__ == "__main__":
......
...@@ -45,7 +45,7 @@ def test_storage_share():
    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 1

def register_mem(scope_tb, max_bits):
...@@ -84,7 +84,7 @@ def test_alloc_seq():
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 200
    tvm.tir.stmt_functor.post_order_visit(body, verify)
    assert num_alloc[0] == 1

def test_alloc_different_dtypes():
...@@ -139,7 +139,7 @@ def test_alloc_different_dtypes():
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body))
    body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
    tvm.tir.stmt_functor.post_order_visit(body, verify)

    length = 1024
    dtype_list = ["float16", "int32", "uint16", "int8"]
...@@ -181,7 +181,7 @@ def test_inplace_rule():
    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 2
...@@ -214,7 +214,7 @@ def test_storage_combine():
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
            assert (n.extents[0].value == 16)
    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 1
...@@ -250,7 +250,7 @@ def test_storage_share_gpu():
        if isinstance(n, tvm.tir.AttrStmt):
            if n.attr_key == "storage_scope":
                alloc_stats[n.value.value] += 1
    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert alloc_stats["global"] == 2
    assert alloc_stats["shared"] == num_stage
...@@ -318,7 +318,7 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
    assert num_alloc[0] == 2

def test_exceed_mem():
...@@ -407,7 +407,7 @@ def test_inplace_rule3():
    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            assert n.extents[0].value == 70
    tvm.tir.stmt_functor.post_order_visit(stmt, verify)

def test_alloc_seq_type():
    ib = tvm.tir.ir_builder.create()
...@@ -437,7 +437,7 @@ def test_alloc_seq_type():
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 500
    tvm.tir.stmt_functor.post_order_visit(body, verify)
    assert num_alloc[0] == 1

def test_alloc_seq_type2():
...@@ -469,7 +469,7 @@ def test_alloc_seq_type2():
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 200
    tvm.tir.stmt_functor.post_order_visit(body, verify)
    assert num_alloc[0] == 1
...@@ -502,7 +502,7 @@ def test_reuse_small_buffer():
        if isinstance(n, tvm.tir.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 800
    tvm.tir.stmt_functor.post_order_visit(body, verify)
    assert num_alloc[0] == 1

def test_replace_dataflow():
...@@ -540,7 +540,7 @@ def test_large_input():
    def verify(n):
        if isinstance(n, tvm.tir.Allocate):
            assert n.extents[0].value == 268435456
    tvm.tir.stmt_functor.post_order_visit(stmt, verify)

if __name__ == "__main__":
......
...@@ -70,7 +70,7 @@ print(ir)
#
# IR Visitor
# ~~~~~~~~~~
# We can use ``tvm.tir.stmt_functor.post_order_visit(stmt, func)`` to gather information from the Halide IR.
# ``func`` is a function callback. This function will be called before exiting the current IR node,
# i.e. post-order visit. Then we leverage side effects to store the result of IR visit, because the
# return value of ``func`` will be ignored.
...@@ -111,7 +111,7 @@ def vectorize8(op):
        extent = op.extent.value
        name = op.loop_var.name
        lo, li = te.var(name + '.outer'), te.var(name + '.inner')
        body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
        body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body)
        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body)
        return body
...@@ -121,7 +121,7 @@ def vectorize8(op):
def vectorize(f, mod, ctx):
    global loops

    tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)

    if not loops:
        return sf
...@@ -129,7 +129,7 @@ def vectorize(f, mod, ctx):
    # The last list argument indicates what kinds of nodes will be transformed.
    # Thus, in this case only `For` nodes will call `vectorize8`
    return f.with_body(
        tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['For']))

#####################################################################
...@@ -161,8 +161,8 @@ with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg:
# Quick View
# ----------
# This tutorial gives a quick view of writing a customized IR transformation pass:
# - Use ``tvm.tir.stmt_functor.post_order_visit`` to gather information on each IR node.
# - Use ``tvm.tir.stmt_functor.ir_transform`` to transform IR nodes.
# - Wrap up the two above to write an IR-transformation function.
# - Use ``tvm.target.build_config`` to put this function to TVM lowering pass
#
...@@ -86,14 +86,14 @@ def FoldUopLoop():
                raise RuntimeError("unexpected op %s" % op)
            return op

        ret = tvm.tir.stmt_functor.ir_transform(
            stmt.body, None, _post_order, ["Call"])

        if not fail[0] and all(x is not None for x in gemm_offsets):
            def _visit(op):
                if op.same_as(loop_var):
                    fail[0] = True
            tvm.tir.stmt_functor.post_order_visit(ret, _visit)
            if not fail[0]:
                begin = tvm.tir.call_extern(
                    "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
...@@ -131,7 +131,7 @@ def FoldUopLoop():
        return None

    def _ftransform(f, mod, ctx):
        return f.with_body(tvm.tir.stmt_functor.ir_transform(
            f.body, _do_fold, None, ["AttrStmt"]))

    return tvm.tir.transform.prim_func_pass(
...@@ -187,7 +187,7 @@ def CPUAccessRewrite():
            raise RuntimeError("not reached")

        stmt_in = f.body
        stmt = tvm.tir.stmt_functor.ir_transform(
            stmt_in, None, _post_order, ["Allocate", "Load", "Store"])

        for buffer_var, new_var in rw_info.items():
...@@ -253,7 +253,7 @@ def LiftAllocToScopeBegin():
                return _merge_block(lift_stmt.pop() + [op], op.body)
            raise RuntimeError("not reached")

        stmt_in = f.body
        stmt = tvm.tir.stmt_functor.ir_transform(
            stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
        assert len(lift_stmt) == 1
        return f.with_body(_merge_block(lift_stmt[0], stmt))
...@@ -276,7 +276,7 @@ def InjectSkipCopy():
        return None

    def _ftransform(f, mod, ctx):
        return f.with_body(tvm.tir.stmt_functor.ir_transform(
            f.body, _do_fold, None, ["AttrStmt"]))

    return tvm.tir.transform.prim_func_pass(
...@@ -306,7 +306,7 @@ def InjectCoProcSync():
                    op.loop_var, op.min, 2, op.for_type,
                    op.device_api, op.body)
            return None
        return f.with_body(tvm.tir.stmt_functor.ir_transform(
            f.body, None, _do_fold, ["AttrStmt"]))
    return tvm.transform.Sequential(
        [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
...@@ -635,7 +635,7 @@ def InjectConv2DTransposeSkip():
    def _do_fold(op):
        if _match_pragma(op, "conv2d_transpose_gemm"):
            is_init = ".init" in str(op)
            tvm.tir.stmt_functor.post_order_visit(op, _find_basics)

            if is_init:
                # create inner most block
return inner return inner
return None return None
return func.with_body(tvm.tir.ir_pass.IRTransform( return func.with_body(tvm.tir.stmt_functor.ir_transform(
func.body, _do_fold, None, ["AttrStmt"])) func.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass( return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip") _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
...@@ -736,7 +736,7 @@ def AnnotateALUCoProcScope():
                return tvm.tir.Evaluate(0)
            return stmt

        return func.with_body(tvm.tir.stmt_functor.ir_transform(
            func.body, None, _do_fold, ["AttrStmt"]))

    return tvm.tir.transform.prim_func_pass(
        _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
...@@ -955,7 +955,7 @@ def InjectALUIntrin():
                return irb.get()
            return stmt

        return func.with_body(tvm.tir.stmt_functor.ir_transform(
            func.body, None, _do_fold, ["AttrStmt"]))

    return tvm.tir.transform.prim_func_pass(
......