Commit 6cb5b882 by Tianqi Chen, committed by GitHub

[TIR] Enhance Substitute, python bindings for Substitute/PostOrderVisit/IRTransform. (#5400)

Substitute now takes a std::function, so callers can customize the replacement behavior.
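For illustration, a minimal sketch of the new functional form (the helper function and variable names here are hypothetical, not part of this commit):

    // Sketch: remap one variable, leave everything else untouched.
    #include <tvm/tir/stmt_functor.h>

    tvm::tir::Stmt ReplaceVar(tvm::tir::Stmt stmt, tvm::tir::Var target,
                              tvm::PrimExpr repl) {
      return tvm::tir::Substitute(
          stmt, [&](const tvm::tir::Var& var) -> tvm::Optional<tvm::PrimExpr> {
            if (var.same_as(target)) return repl;  // remap this var
            return tvm::NullOpt;                   // no change for other vars
          });
    }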

Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>

parent 8c0f7790
......@@ -38,3 +38,10 @@ tvm.tir.analysis
   :members:
   :imported-members:
   :autosummary:

tvm.tir.stmt_functor
--------------------
.. automodule:: tvm.tir.stmt_functor
   :members:
   :autosummary:
......@@ -611,6 +611,10 @@ struct PackedFuncValueConverter<::tvm::runtime::String> {
}
};
/*! \brief Helper to represent nullptr for optional. */
struct NullOptType {
};
/*!
 * \brief Optional container that represents a nullable variant of T.
* \tparam T The original ObjectRef.
......@@ -642,6 +646,8 @@ class Optional : public ObjectRef {
* \param ptr
*/
explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief Nullopt handling */
Optional(NullOptType) {} // NOLINT(*)
// nullptr handling.
// disallow implicit conversion as 0 can be implicitly converted to nullptr_t
explicit Optional(std::nullptr_t) {}
......@@ -751,6 +757,7 @@ struct PackedFuncValueConverter<Optional<T>> {
// expose the functions to the root namespace.
using runtime::String;
using runtime::Optional;
constexpr runtime::NullOptType NullOpt{};
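A small usage sketch of the container with the new NullOpt value (the header path and the availability of the CHECK macro are assumptions based on this era of the codebase):

    #include <tvm/runtime/container.h>  // assumed header for Optional/NullOpt here
    #include <string>

    void OptionalDemo() {
      // NullOpt constructs an empty Optional; no object is held.
      tvm::Optional<tvm::runtime::String> empty = tvm::NullOpt;
      CHECK(!empty.defined());
      // Optional<T> converts implicitly from T.
      tvm::Optional<tvm::runtime::String> full = tvm::runtime::String("hello");
      std::string s = full.value();  // value() asserts the Optional is defined
      CHECK(s == "hello");
    }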
} // namespace tvm
namespace std {
......
......@@ -82,40 +82,6 @@ bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vse
TVM_DLL Stmt ConvertSSA(Stmt stmt);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<const VarNode*, PrimExpr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
PrimExpr Substitute(PrimExpr expr,
const std::unordered_map<const VarNode*, PrimExpr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
/*!
* \brief Verify if there is any argument bound to compact buffer.
*
* \param stmt The stmt to be verified.
......
......@@ -20,17 +20,20 @@
/*!
* \file tvm/tir/stmt_functor.h
*
* \brief Functors for tir stmts.
 * \brief Functors for tir stmts, and
 * utility functions to call common functors.
*/
#ifndef TVM_TIR_STMT_FUNCTOR_H_
#define TVM_TIR_STMT_FUNCTOR_H_
#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/expr_functor.h>
#include <utility>
#include <unordered_map>
namespace tvm {
namespace tir {
......@@ -318,9 +321,9 @@ class StmtExprMutator :
};
/*!
* \brief recursively visit the ir in post DFS order node, and transform it
 * \brief recursively visit the ir nodes in post DFS order, and transform them
*
* \param node The ir to be transformed.
* \param stmt The ir to be transformed.
 * \param preorder The function called before recursive mutation.
If preorder returns None, the transform proceeds to the recursive call.
If preorder returns a non-None Stmt/Expr, the transformer simply returns it and
......@@ -328,23 +331,76 @@ class StmtExprMutator :
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of runtime::String.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
If it is null, preorder/postorder are called on every IR node
* If it is not null, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
TVM_DLL Stmt IRTransform(Stmt stmt,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<runtime::String>& only_enable = {});
Optional<Array<String>> only_enable = NullOpt);
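The updated signature still accepts a braced Array for only_enable, and a null preorder. For example (a sketch; the counting helper is illustrative only):

    // Sketch: run a postorder hook on For nodes only.
    int CountForLoops(tvm::tir::Stmt stmt) {
      int count = 0;
      tvm::runtime::PackedFunc postorder =
          tvm::runtime::TypedPackedFunc<tvm::tir::Stmt(tvm::tir::Stmt)>(
              [&](tvm::tir::Stmt s) { ++count; return s; });
      tvm::tir::IRTransform(stmt, nullptr, postorder,
                            tvm::Array<tvm::runtime::String>{"For"});
      return count;
    }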
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
 * \brief Recursively visit the ir nodes in post DFS order, and apply fvisit.
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
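A typical use is gathering nodes of interest (a sketch; the collector function is illustrative):

    // Sketch: collect every distinct Var referenced by an expression.
    #include <tvm/tir/stmt_functor.h>
    #include <unordered_set>

    std::unordered_set<const tvm::tir::VarNode*> CollectVars(const tvm::PrimExpr& e) {
      std::unordered_set<const tvm::tir::VarNode*> vars;
      tvm::tir::PostOrderVisit(e, [&](const tvm::ObjectRef& n) {
        if (const auto* v = n.as<tvm::tir::VarNode>()) vars.insert(v);
      });
      return vars;
    }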
/*!
* \brief Substitute the var specified by vmap.
 * \param stmt The source statement to be substituted.
 * \param vmap The function that returns a new value if re-mapping is needed, and NullOpt otherwise.
* \return The converted form.
*/
TVM_DLL Stmt Substitute(Stmt stmt,
std::function<Optional<PrimExpr>(const Var& var)> vmap);
/*!
* \brief Substitute the var specified by vmap.
 * \param expr The source expression to be substituted.
 * \param vmap The function that returns a new value if re-mapping is needed, and NullOpt otherwise.
* \return The result.
*/
TVM_DLL PrimExpr Substitute(PrimExpr expr,
std::function<Optional<PrimExpr>(const Var& var)> vmap);
/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
* \param value_map The map of new values.
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template<typename T>
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var);
if (it != value_map.end()) return (*it).second;
return Optional<PrimExpr>(nullptr);
};
return Substitute(std::move(input), vmap);
}
/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
* \param value_map The map of new values.
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template<typename T>
inline T Substitute(T input,
const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var.get());
if (it != value_map.end()) return (*it).second;
return Optional<PrimExpr>(nullptr);
};
return Substitute(std::move(input), vmap);
}
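Both sugar overloads route through the functional Substitute above, so existing Map-based call sites keep working. A usage sketch (names are hypothetical):

    #include <tvm/tir/op.h>             // arithmetic operators on PrimExpr
    #include <tvm/tir/stmt_functor.h>

    // Sketch: replace v with v + 1 throughout an expression.
    tvm::PrimExpr ShiftVar(tvm::PrimExpr expr, tvm::tir::Var v) {
      tvm::Map<tvm::tir::Var, tvm::PrimExpr> vmap;
      vmap.Set(v, v + 1);
      return tvm::tir::Substitute(expr, vmap);
    }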
} // namespace tir
} // namespace tvm
......
......@@ -72,7 +72,7 @@ def _pruned_source(func):
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
# pylint: disable=import-outside-toplevel
from tvm.tir import ir_pass
from tvm.tir import stmt_functor
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
......@@ -84,7 +84,7 @@ def replace_io(body, rmap):
_expr.Call.Halide, buf.op, buf.value_index)
return None
return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])
def _is_tvm_arg_types(args):
......
......@@ -48,3 +48,4 @@ from . import ir_builder
from . import ir_pass
from . import transform
from . import analysis
from . import stmt_functor
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Statement functor utilities for IR transformations"""
from . import _ffi_api
def ir_transform(stmt, preorder, postorder, only_enable=None):
    """Recursively visit and transform ir nodes in post DFS order.

    Parameters
    ----------
    stmt : Stmt
        The input to be transformed.

    preorder : function
        The function called before recursive mutation.
        If preorder returns None, the transform proceeds to the recursive call.
        If preorder returns a non-None Stmt/Expr, the transformer simply
        returns it and does not recurse further.

    postorder : function
        The function called after recursive mutation.

    only_enable : Optional[List[str]]
        List of node type keys; when given, preorder/postorder are called
        only on nodes whose type key is in the list.

    Returns
    -------
    result : Stmt
        The result.
    """
    return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable)
def post_order_visit(stmt, fvisit):
    """Recursively visit the ir nodes in post DFS order, applying fvisit.
    Each node is guaranteed to be visited only once.

    Parameters
    ----------
    stmt : ObjectRef
        The input stmt or expression to be visited.

    fvisit : function
        The visitor function.
    """
    return _ffi_api.PostOrderVisit(stmt, fvisit)
def substitute(node, vmap):
    """Substitute the vars specified by vmap.

    Parameters
    ----------
    node : ObjectRef
        The input stmt or expression.

    vmap : Dict[Var, PrimExpr]
        The variable mapping.

    Returns
    -------
    result : Stmt or PrimExpr
        The result, of the same kind as the input.
    """
    return _ffi_api.Substitute(node, vmap)
......@@ -26,9 +26,10 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/util.h>
#include <tvm/tir/op.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/data_type.h>
namespace tvm {
......@@ -130,10 +131,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
(*S)[i][j] = new_i_j;
}
// We have to do the same with rhs
PrimExpr ea = te::make_const((*y)[index].dtype(), a);
PrimExpr eb = te::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
(*y)[index] = new_index_rhs;
......@@ -190,10 +191,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
(*V)[i][j] = new_i_j;
}
// And apply reverse transformations to new_to_old.
PrimExpr ea = te::make_const((*x)[j].dtype(), a);
PrimExpr eb = te::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
(*x)[index] = new_index;
......@@ -369,7 +370,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
IntConstraints(
/*variables=*/{},
/*ranges=*/{},
/*relations=*/{te::make_zero(DataType::Bool())}),
/*relations=*/{tir::make_zero(DataType::Bool())}),
{}, {});
} else if (!tir::is_const_int(new_relation, 1)) {
new_relations.push_back(new_relation);
......@@ -403,13 +404,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
// The j-th variable is just a single value, don't create a tvm variable
// S^{-1}_{nxm} Uy_{mxn}
if (S[j][j] >= 0) {
PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]);
PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(Uy[j], a)));
} else {
// This is required because some simplifiers
// have problems with dividing by negative numbers
PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]);
PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(-Uy[j], a)));
}
......@@ -418,9 +419,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
// V V^{-1} x = x
for (size_t i = 0; i < num_vars; ++i) {
PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype());
PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
for (size_t j = 0; j < num_vars; ++j) {
e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
}
e = analyzer_problem.Simplify(e);
old_to_new_map.Set(system_to_solve->variables[i], e);
......
......@@ -22,7 +22,7 @@
* \brief Utility for tensor-level auto-differentiation.
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include "ad_util.h"
......
......@@ -26,7 +26,6 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <unordered_set>
......
......@@ -25,8 +25,9 @@
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>
#include "./op_util.h"
#include "./compute_op.h"
#include "../../arith/compute_expr.h"
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h>
#include <cctype>
......
......@@ -24,7 +24,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <memory>
#include <limits>
#include "../pass/ir_util.h"
......
......@@ -19,116 +19,14 @@
/*!
* \file stmt_functor.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <functional>
#include "functor_common.h"
namespace tvm {
namespace tir {
// visitor to implement apply
class IRApplyVisit :
public StmtExprVisitor {
public:
explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
void VisitExpr(const PrimExpr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
ExprVisitor::VisitExpr(node);
f_(node);
}
void VisitStmt(const Stmt& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
StmtVisitor::VisitStmt(node);
f_(node);
}
private:
std::function<void(const ObjectRef&)> f_;
std::unordered_set<const Object*> visited_;
};
void PostOrderVisit(const ObjectRef& node,
std::function<void(const ObjectRef&)> fvisit) {
if (node.as<StmtNode>()) {
IRApplyVisit visitor(fvisit);
visitor(Downcast<Stmt>(node));
} else {
IRApplyVisit visitor(fvisit);
visitor(Downcast<PrimExpr>(node));
}
}
class IRTransformer final :
public StmtExprMutator {
public:
IRTransformer(const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const std::unordered_set<uint32_t>& only_enable)
: f_preorder_(f_preorder),
f_postorder_(f_postorder),
only_enable_(only_enable) {
}
Stmt VisitStmt(const Stmt& stmt) final {
return MutateInternal<Stmt>(stmt, [this](const Stmt& s) {
return this->BaseVisitStmt(s);
});
}
PrimExpr VisitExpr(const PrimExpr& expr) final {
return MutateInternal<PrimExpr>(expr, [this](const PrimExpr& e) {
return this->BaseVisitExpr(e);
});
}
private:
// NOTE: redirect to parent's call
// This is used to get around limitation of gcc-4.8
Stmt BaseVisitStmt(const Stmt& s) {
return StmtMutator::VisitStmt(s);
}
PrimExpr BaseVisitExpr(const PrimExpr& e) {
return ExprMutator::VisitExpr(e);
}
template <typename T, typename F>
T MutateInternal(const T& node, F fmutate) {
if (only_enable_.size() &&
!only_enable_.count(node->type_index())) {
return fmutate(node);
}
if (f_preorder_ != nullptr) {
T pre = f_preorder_(node);
if (pre.defined()) return pre;
}
T new_node = fmutate(node);
if (f_postorder_ != nullptr) {
T post = f_postorder_(new_node);
if (post.defined()) return post;
}
return new_node;
}
// The functions
const runtime::PackedFunc& f_preorder_;
const runtime::PackedFunc& f_postorder_;
// type indices enabled.
const std::unordered_set<uint32_t>& only_enable_;
};
Stmt IRTransform(Stmt ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const Array<runtime::String>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
for (auto s : only_enable) {
only_type_index.insert(Object::TypeKey2Index(s.c_str()));
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
return transform(std::move(ir_node));
}
void StmtVisitor::VisitStmt_(const LetStmtNode* op) {
this->VisitExpr(op->value);
this->VisitStmt(op->body);
......@@ -511,6 +409,183 @@ Stmt StmtMutator::VisitStmt_(const FreeNode* op) {
}
// Implementations of IRTransform, PostOrderVisit and Substitute
class IRApplyVisit :
public StmtExprVisitor {
public:
explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
void VisitExpr(const PrimExpr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
ExprVisitor::VisitExpr(node);
f_(node);
}
void VisitStmt(const Stmt& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
StmtVisitor::VisitStmt(node);
f_(node);
}
private:
std::function<void(const ObjectRef&)> f_;
std::unordered_set<const Object*> visited_;
};
void PostOrderVisit(const ObjectRef& node,
std::function<void(const ObjectRef&)> fvisit) {
if (node.as<StmtNode>()) {
IRApplyVisit visitor(fvisit);
visitor(Downcast<Stmt>(node));
} else {
IRApplyVisit visitor(fvisit);
visitor(Downcast<PrimExpr>(node));
}
}
class IRTransformer final :
public StmtExprMutator {
public:
IRTransformer(const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const std::unordered_set<uint32_t>& only_enable)
: f_preorder_(f_preorder),
f_postorder_(f_postorder),
only_enable_(only_enable) {
}
Stmt VisitStmt(const Stmt& stmt) final {
return MutateInternal<Stmt>(stmt, [this](const Stmt& s) {
return this->BaseVisitStmt(s);
});
}
PrimExpr VisitExpr(const PrimExpr& expr) final {
return MutateInternal<PrimExpr>(expr, [this](const PrimExpr& e) {
return this->BaseVisitExpr(e);
});
}
private:
// NOTE: redirect to parent's call
// This is used to get around limitation of gcc-4.8
Stmt BaseVisitStmt(const Stmt& s) {
return StmtMutator::VisitStmt(s);
}
PrimExpr BaseVisitExpr(const PrimExpr& e) {
return ExprMutator::VisitExpr(e);
}
template <typename T, typename F>
T MutateInternal(const T& node, F fmutate) {
if (only_enable_.size() &&
!only_enable_.count(node->type_index())) {
return fmutate(node);
}
if (f_preorder_ != nullptr) {
T pre = f_preorder_(node);
if (pre.defined()) return pre;
}
T new_node = fmutate(node);
if (f_postorder_ != nullptr) {
T post = f_postorder_(new_node);
if (post.defined()) return post;
}
return new_node;
}
// The functions
const runtime::PackedFunc& f_preorder_;
const runtime::PackedFunc& f_postorder_;
// type indices enabled.
const std::unordered_set<uint32_t>& only_enable_;
};
Stmt IRTransform(Stmt ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
Optional<Array<String>> only_enable) {
std::unordered_set<uint32_t> only_type_index;
if (only_enable.defined()) {
for (auto s : only_enable.value()) {
only_type_index.insert(Object::TypeKey2Index(s.c_str()));
}
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
return transform(std::move(ir_node));
}
class IRSubstitute : public StmtExprMutator {
public:
explicit IRSubstitute(std::function<Optional<PrimExpr>(const Var&)> vmap)
: vmap_(vmap) {
}
PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
auto ret = vmap_(var);
if (ret.defined()) return ret.value();
return std::move(var);
}
PrimExpr VisitExpr_(const LoadNode* op) final {
// NOTE: we do not explicitly recursively mutate op->buffer_var
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<LoadNode>();
if (auto mapped_var = vmap_(op->buffer_var)) {
return LoadNode::make(
op->dtype, Downcast<Var>(mapped_var.value()), op->index, op->predicate);
} else {
return ret;
}
}
Stmt VisitStmt_(const StoreNode* op) final {
// NOTE: we do not explicitly recursively mutate op->buffer_var
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();
if (auto mapped_var = vmap_(op->buffer_var)) {
return StoreNode::make(
Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
} else {
return ret;
}
}
private:
std::function<Optional<PrimExpr>(const Var&)> vmap_;
};
Stmt Substitute(Stmt stmt,
std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstitute(vmap)(std::move(stmt));
}
PrimExpr Substitute(PrimExpr expr,
std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstitute(vmap)(std::move(expr));
}
TVM_REGISTER_GLOBAL("tir.IRTransform")
.set_body_typed(IRTransform);
TVM_REGISTER_GLOBAL("tir.PostOrderVisit")
.set_body_typed([](ObjectRef node, PackedFunc f) {
tir::PostOrderVisit(node, [f](const ObjectRef& n) {
f(n);
});
});
TVM_REGISTER_GLOBAL("tir.Substitute")
.set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef{
if (node->IsInstance<StmtNode>()) {
return Substitute(Downcast<Stmt>(node), vmap);
} else {
return Substitute(Downcast<PrimExpr>(node), vmap);
}
});
} // namespace tir
} // namespace tvm
......@@ -32,28 +32,13 @@
namespace tvm {
namespace tir {
TVM_REGISTER_GLOBAL("ir_pass.Substitute")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<Stmt>()) {
*ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, PrimExpr>());
} else {
*ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map<Var, PrimExpr>());
}
});
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
});
TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
tir::PostOrderVisit(args[0], [f](const ObjectRef& n) {
f(n);
});
});
// make from two arguments
#define REGISTER_PASS(PassName) \
......@@ -63,7 +48,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(IRTransform);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(VerifyCompactBuffer);
......
......@@ -159,7 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
}
});
return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"});
return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"For"});
}
// Remove IfThenElse node from a For node.
......@@ -185,9 +185,9 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
}
});
then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"});
then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"});
else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"IfThenElse"});
}
return std::make_pair(then_for, else_for);
......@@ -408,7 +408,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
*ret = new_for;
}
});
return IRTransform(stmt, nullptr, replace_top_for, {runtime::String("For")});
return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"For"});
}
Stmt HoistIfThenElse(Stmt stmt) {
......
......@@ -52,79 +52,7 @@ bool HasSideEffect(const PrimExpr& e) {
return v.has_side_effect_;
}
class IRSubstitue : public StmtExprMutator {
public:
explicit IRSubstitue(
const std::unordered_map<const VarNode*, PrimExpr>& smap)
: smap_(smap) {
}
PrimExpr VisitExpr_(const VarNode* op) final {
auto it = smap_.find(op);
if (it != smap_.end()) {
return it->second;
} else {
return GetRef<PrimExpr>(op);
}
}
PrimExpr VisitExpr_(const LoadNode* op) final {
// NOTE: we do not explicit recursivly mutate op->buffer_var
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<LoadNode>();
auto it = smap_.find(op->buffer_var.get());
if (it != smap_.end()) {
return LoadNode::make(
op->dtype, Downcast<Var>(it->second), op->index, op->predicate);
} else {
return ret;
}
}
Stmt VisitStmt_(const StoreNode* op) final {
// NOTE: we do not explicit recursivly mutate op->buffer_var
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();
auto it = smap_.find(op->buffer_var.get());
if (it != smap_.end()) {
return StoreNode::make(
Downcast<Var>(it->second), op->value, op->index, op->predicate);
} else {
return ret;
}
}
private:
const std::unordered_map<const VarNode*, PrimExpr>& smap_;
};
Stmt Substitute(Stmt stmt,
const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
if (value_map.size() == 0) return stmt;
return IRSubstitue(value_map)(std::move(stmt));
}
PrimExpr Substitute(PrimExpr expr,
const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
if (value_map.size() == 0) return expr;
return IRSubstitue(value_map)(std::move(expr));
}
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map) {
std::unordered_map<const VarNode*, PrimExpr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
return Substitute(stmt, vmap);
}
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map) {
std::unordered_map<const VarNode*, PrimExpr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
return Substitute(expr, vmap);
}
class VarTouchVisitor : public ExprVisitor {
public:
......
......@@ -29,7 +29,7 @@ def run_expr(expr, vranges):
"""
def _compute_body(*us):
vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
return tir.ir_pass.Substitute(expr, vmap)
return tir.stmt_functor.substitute(expr, vmap)
A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
args = [tvm.nd.empty(A.shape, A.dtype)]
......@@ -69,17 +69,17 @@ def check_solution(solution, vranges={}):
cond_on_vars = tir.const(1, 'bool')
for v in constraints1.variables:
# variable mapping is consistent
v_back = ana.simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap))
cond_on_vars = te.all(cond_on_vars, v == v_back)
# Also we have to check that the new relations are true when old relations are true
cond_subst = tir.ir_pass.Substitute(
cond_subst = tir.stmt_functor.substitute(
te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
# We have to include relations from vranges too
for v in constraints2.variables:
if v in constraints2.ranges:
r = constraints2.ranges[v]
range_cond = te.all(v >= r.min, v < r.min + r.extent)
range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
range_cond = tir.stmt_functor.substitute(range_cond, backvarmap)
cond_subst = te.all(cond_subst, range_cond)
cond_subst = ana.simplify(cond_subst)
check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
......
......@@ -201,7 +201,7 @@ def test_cuda_shuffle():
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
......
......@@ -685,7 +685,7 @@ def test_llvm_shuffle():
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
......
......@@ -24,7 +24,7 @@ from tvm.te.hybrid.runtime import HYBRID_GLOBALS
@pytest.mark.skip
def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
def tvm_val_2_py_val(val):
val = tvm.tir.ir_pass.Substitute(val, var_dict)
val = tvm.tir.stmt_functor.substitute(val, var_dict)
val = tvm.arith.Analyzer().simplify(val)
assert isinstance(val, (tvm.tir.IntImm,))
return val.value
......
......@@ -148,8 +148,8 @@ def test_bound_fusesplit1():
for k in range(1, 6):
vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")})
tvm.testing.assert_prim_expr_equal(
tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars),
tvm.tir.ir_pass.Substitute(expected_extent, vars)
tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars),
tvm.tir.stmt_functor.substitute(expected_extent, vars)
)
tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l)
......@@ -170,10 +170,10 @@ def test_bound_fusesplit2():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")})
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars), 2)
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars), 3)
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), 1)
tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars), 3)
tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].min, vars), 2)
tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].min, vars), 3)
tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), 1)
tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].extent, vars), 3)
def test_bound_warp():
......
......@@ -155,7 +155,7 @@ def test_inline_mixed():
def check(x):
if isinstance(x, tvm.tir.Call):
assert x.func != A2
tvm.tir.ir_pass.PostOrderVisit(s[C].op.body[0], check)
tvm.tir.stmt_functor.post_order_visit(s[C].op.body[0], check)
def test_scan_inline1():
......@@ -517,7 +517,7 @@ def test_local_stage_predicate():
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
return ret
# local vs. threadIdx
s = schedule(tx, "local")
......@@ -563,7 +563,7 @@ def test_local_stage_predicate2():
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
return ret
def visit_stmt(op):
......
......@@ -264,7 +264,7 @@ def test_tuple_with_different_deps():
x.func == B1.op and x.value_index == 1:
ret.append(x)
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, get_B1_realize)
tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize)
assert stmt.node == C.op and len(ret) == 1
......
......@@ -32,7 +32,7 @@ def verify_structure(stmt, expected_struct):
key = op
if isinstance(op, tvm.tir.IfThenElse):
global var_list
tvm.tir.ir_pass.PostOrderVisit(op.condition, _extract_vars)
tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars)
val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))]
var_list.clear()
elif isinstance(op, tvm.tir.For):
......@@ -43,7 +43,7 @@ def verify_structure(stmt, expected_struct):
return
node_dict[key] = val
tvm.tir.ir_pass.PostOrderVisit(stmt, _visit)
tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
for key, val in node_dict.items():
struct[val[1]] = tuple(node_dict[child][1] if child in node_dict
else None for child in val[0])
......
......@@ -37,7 +37,7 @@ def test_ir_transform():
if op.name == "TestA":
return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1)
return op
body = tvm.tir.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["Call"])
stmt_list = tvm.tir.stmt_list(body.body.body)
assert stmt_list[0].value.args[0].name == "TestB"
assert stmt_list[1].value.value == 0
......
......@@ -54,7 +54,7 @@ def test_double_buffer():
def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
count[0] += 1
tvm.tir.ir_pass.PostOrderVisit(f.body, count_sync)
tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
assert count[0] == 4
......
......@@ -21,7 +21,7 @@ import numpy as np
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x)))
return ret
......
......@@ -20,7 +20,7 @@ import numpy
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
tvm.tir.stmt_functor.post_order_visit(stmt, lambda x : ret.append(f(x)))
return ret
......
......@@ -123,7 +123,7 @@ def test_flatten_double_buffer():
def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
count[0] += 1
tvm.tir.ir_pass.PostOrderVisit(f.body, count_sync)
tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
assert count[0] == 4
if __name__ == "__main__":
......
......@@ -45,7 +45,7 @@ def test_storage_share():
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
tvm.tir.ir_pass.PostOrderVisit(stmt, verify)
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 1
def register_mem(scope_tb, max_bits):
......@@ -84,7 +84,7 @@ def test_alloc_seq():
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 200
tvm.tir.ir_pass.PostOrderVisit(body, verify)
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
def test_alloc_different_dtypes():
......@@ -139,7 +139,7 @@ def test_alloc_different_dtypes():
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
tvm.tir.ir_pass.PostOrderVisit(body, verify)
tvm.tir.stmt_functor.post_order_visit(body, verify)
length = 1024
dtype_list = ["float16", "int32", "uint16", "int8"]
......@@ -181,7 +181,7 @@ def test_inplace_rule():
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
tvm.tir.ir_pass.PostOrderVisit(stmt, verify)
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 2
......@@ -214,7 +214,7 @@ def test_storage_combine():
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert (n.extents[0].value == 16)
tvm.tir.ir_pass.PostOrderVisit(stmt, verify)
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 1
......@@ -250,7 +250,7 @@ def test_storage_share_gpu():
if isinstance(n, tvm.tir.AttrStmt):
if n.attr_key == "storage_scope":
alloc_stats[n.value.value] += 1
tvm.tir.ir_pass.PostOrderVisit(stmt, verify)
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert alloc_stats["global"] == 2
assert alloc_stats["shared"] == num_stage
......@@ -318,7 +318,7 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
def verify(n):
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
tvm.tir.ir_pass.PostOrderVisit(stmt, verify)
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
assert num_alloc[0] == 2
def test_exceed_mem():
......@@ -407,7 +407,7 @@ def test_inplace_rule3():
def verify(n):
if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == 70
tvm.tir.ir_pass.PostOrderVisit(stmt, verify)
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
def test_alloc_seq_type():
ib = tvm.tir.ir_builder.create()
......@@ -437,7 +437,7 @@ def test_alloc_seq_type():
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 500
tvm.tir.ir_pass.PostOrderVisit(body, verify)
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
def test_alloc_seq_type2():
......@@ -469,7 +469,7 @@ def test_alloc_seq_type2():
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 200
tvm.tir.ir_pass.PostOrderVisit(body, verify)
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
......@@ -502,7 +502,7 @@ def test_reuse_small_buffer():
if isinstance(n, tvm.tir.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 800
tvm.tir.ir_pass.PostOrderVisit(body, verify)
tvm.tir.stmt_functor.post_order_visit(body, verify)
assert num_alloc[0] == 1
def test_replace_dataflow():
......@@ -540,7 +540,7 @@ def test_large_input():
def verify(n):
if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == 268435456
tvm.tir.ir_pass.PostOrderVisit(stmt, verify)
tvm.tir.stmt_functor.post_order_visit(stmt, verify)
if __name__ == "__main__":
......
......@@ -70,7 +70,7 @@ print(ir)
#
# IR Visitor
# ~~~~~~~~~~
# We can use ``tvm.tir.ir_pass.PostOrderVisit(stmt, func)`` to gather information from the Halide IR.
# We can use ``tvm.tir.stmt_functor.post_order_visit(stmt, func)`` to gather information from the Halide IR.
# ``func`` is a function callback. This function will be called before exiting the current IR node,
# i.e. post-order visit. Then we leverage side effects to store the result of IR visit, because the
# return value of ``func`` will be ignored.
......@@ -111,7 +111,7 @@ def vectorize8(op):
extent = op.extent.value
name = op.loop_var.name
lo, li = te.var(name + '.outer'), te.var(name + '.inner')
body = tvm.tir.ir_pass.Substitute(op.body, {op.loop_var: lo * 8 + li})
body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body)
body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body)
return body
......@@ -121,7 +121,7 @@ def vectorize8(op):
def vectorize(f, mod, ctx):
global loops
tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)
if not loops:
return sf
......@@ -129,7 +129,7 @@ def vectorize(f, mod, ctx):
# The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['For']))
#####################################################################
......@@ -161,8 +161,8 @@ with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg:
# Quick View
# ----------
# This tutorial gives a quick view of writing a customized IR transformation pass:
# - Use ``tvm.tir.ir_pass.PostOrderVisit`` to gather information on each IR nodes.
# - Use ``tvm.tir.ir_pass.IRTransform`` to transform IR nodes.
# - Use ``tvm.tir.stmt_functor.post_order_visit`` to gather information on each IR node.
# - Use ``tvm.tir.stmt_functor.ir_transform`` to transform IR nodes.
# - Wrap the two above to write an IR-transformation function.
# - Use ``tvm.target.build_config`` to put this function into the TVM lowering pass
#
......@@ -86,14 +86,14 @@ def FoldUopLoop():
raise RuntimeError("unexpected op %s" % op)
return op
ret = tvm.tir.ir_pass.IRTransform(
ret = tvm.tir.stmt_functor.ir_transform(
stmt.body, None, _post_order, ["Call"])
if not fail[0] and all(x is not None for x in gemm_offsets):
def _visit(op):
if op.same_as(loop_var):
fail[0] = True
tvm.tir.ir_pass.PostOrderVisit(ret, _visit)
tvm.tir.stmt_functor.post_order_visit(ret, _visit)
if not fail[0]:
begin = tvm.tir.call_extern(
"int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
......@@ -131,7 +131,7 @@ def FoldUopLoop():
return None
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
return f.with_body(tvm.tir.stmt_functor.ir_transform(
f.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
......@@ -187,7 +187,7 @@ def CPUAccessRewrite():
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform(
stmt = tvm.tir.stmt_functor.ir_transform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
......@@ -253,7 +253,7 @@ def LiftAllocToScopeBegin():
return _merge_block(lift_stmt.pop() + [op], op.body)
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.ir_pass.IRTransform(
stmt = tvm.tir.stmt_functor.ir_transform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1
return f.with_body(_merge_block(lift_stmt[0], stmt))
......@@ -276,7 +276,7 @@ def InjectSkipCopy():
return None
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.ir_pass.IRTransform(
return f.with_body(tvm.tir.stmt_functor.ir_transform(
f.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
......@@ -306,7 +306,7 @@ def InjectCoProcSync():
op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body)
return None
return f.with_body(tvm.tir.ir_pass.IRTransform(
return f.with_body(tvm.tir.stmt_functor.ir_transform(
f.body, None, _do_fold, ["AttrStmt"]))
return tvm.transform.Sequential(
[tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
......@@ -635,7 +635,7 @@ def InjectConv2DTransposeSkip():
def _do_fold(op):
if _match_pragma(op, "conv2d_transpose_gemm"):
is_init = ".init" in str(op)
tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
tvm.tir.stmt_functor.post_order_visit(op, _find_basics)
if is_init:
# create inner most block
......@@ -707,7 +707,7 @@ def InjectConv2DTransposeSkip():
return inner
return None
return func.with_body(tvm.tir.ir_pass.IRTransform(
return func.with_body(tvm.tir.stmt_functor.ir_transform(
func.body, _do_fold, None, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
......@@ -736,7 +736,7 @@ def AnnotateALUCoProcScope():
return tvm.tir.Evaluate(0)
return stmt
return func.with_body(tvm.tir.ir_pass.IRTransform(
return func.with_body(tvm.tir.stmt_functor.ir_transform(
func.body, None, _do_fold, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
......@@ -955,7 +955,7 @@ def InjectALUIntrin():
return irb.get()
return stmt
return func.with_body(tvm.tir.ir_pass.IRTransform(
return func.with_body(tvm.tir.stmt_functor.ir_transform(
func.body, None, _do_fold, ["AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
......