Unverified Commit b8efe27f by Tianqi Chen Committed by GitHub

[REFACTOR][TE] Inline -> te/schedule/operation_inline.h (#5386)

Rationale: inline is a transformation used in te to
rewrite its internal expressions. It is not a formal IRModule->IRModule transform pass.

Also removed the python test as the test is covered by stage.compute_inline.
parent 3f03869e
......@@ -149,22 +149,6 @@ Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt Inline(Stmt stmt,
FunctionRef f,
Array<Var> args,
PrimExpr body);
/*!
* \brief Verify if there is any argument bound to compact buffer.
*
* \param stmt The stmt to be verified.
......
......@@ -18,29 +18,31 @@
*/
/*!
* \file inline.cc
* \file operation_inline.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <utility>
#include "operation_inline.h"
namespace tvm {
namespace tir {
namespace te {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline final : public StmtExprMutator {
class OperationInliner final : public StmtExprMutator {
public:
IRInline(FunctionRef f, Array<Var> args, PrimExpr body)
: f_(f), args_(args), body_(body) {}
OperationInliner(Operation op, Array<Var> args, PrimExpr body)
: operation_(op), args_(args), body_(body) {}
PrimExpr VisitExpr_(const CallNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
if (op->func == f_) {
if (op->func.same_as(operation_)) {
CHECK_EQ(op->value_index, 0);
expr = body_;
CHECK_EQ(args_.size(), op->args.size());
......@@ -68,20 +70,20 @@ class IRInline final : public StmtExprMutator {
}
private:
FunctionRef f_;
Operation operation_;
Array<Var> args_;
PrimExpr body_;
};
Stmt Inline(Stmt stmt,
FunctionRef f,
Operation f,
Array<Var> args,
PrimExpr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
Stmt ret = IRInline(f, args, body)(std::move(stmt));
Stmt ret = OperationInliner(f, args, body)(std::move(stmt));
if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret);
}
} // namespace tir
} // namespace te
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file operation_inline.h
*/
#ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_
#define TVM_TE_SCHEDULE_OPERATION_INLINE_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor.h>
namespace tvm {
namespace te {
/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param op The op to be inlined.
* \param args The arguments variable of the function.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt Inline(Stmt stmt,
Operation op,
Array<Var> args,
PrimExpr body);
} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_OPERATION_INLINE_H_
......@@ -26,6 +26,8 @@
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "message_passing.h"
#include "operation_inline.h"
#include "../../tir/pass/ir_util.h"
#include "../../arith/compute_expr.h"
......@@ -583,7 +585,7 @@ void InjectInline(ScheduleNode* sch) {
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][0]),
PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]),
stage->op, args, body).as<tir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
......@@ -599,7 +601,7 @@ void InjectInline(ScheduleNode* sch) {
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][k]),
PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]),
stage->op, args, body).as<tir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
......@@ -611,7 +613,7 @@ void InjectInline(ScheduleNode* sch) {
if (!new_hybrid_body[j].defined()) {
new_hybrid_body[j] = hybrid->body;
}
Stmt new_stmt = tir::Inline(new_hybrid_body[j], stage->op, args, body);
Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
if (!new_stmt.same_as(new_hybrid_body[j])) {
new_hybrid_body[j] = new_stmt;
hybrid_changed[j] = true;
......
......@@ -97,7 +97,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_inline():
m = te.size_var('m')
A = te.placeholder((m,), name='A')
T = te.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.tir.Evaluate(T[10] + 11 * T[100])
stmt = tvm.tir.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
print(stmt)
assert(tvm.tir.ir_pass.VerifySSA(stmt))
try:
# pass in int array(wrong argument type)
# must raise an error
stmt = tvm.tir.ir_pass.Inline(
T.op, [1,2,3], T.op.body, stmt)
assert False
except tvm.error.TVMError:
pass
def test_inline2():
m = te.size_var('m')
A = te.placeholder((m,), name='A')
T = te.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.tir.Evaluate(te.exp(T[10]) + 11 * T[100])
stmt = tvm.tir.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
def check(op):
if isinstance(op, tvm.tir.Call):
assert op.func != T.op
tvm.tir.ir_pass.PostOrderVisit(stmt, check)
if __name__ == "__main__":
test_inline2()
test_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment