Unverified Commit b8efe27f by Tianqi Chen Committed by GitHub

[REFACTOR][TE] Inline -> te/schedule/operation_inline.h (#5386)

Rationale: inline is a transformation used in te to
rewrite its internal expressions. It is not a formal IRModule->IRModule transform pass.

Also removed the python test as the test is covered by stage.compute_inline.
parent 3f03869e
...@@ -149,22 +149,6 @@ Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map); ...@@ -149,22 +149,6 @@ Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map); PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
/*! /*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt Inline(Stmt stmt,
FunctionRef f,
Array<Var> args,
PrimExpr body);
/*!
* \brief Verify if there is any argument bound to compact buffer. * \brief Verify if there is any argument bound to compact buffer.
* *
* \param stmt The stmt to be verified. * \param stmt The stmt to be verified.
......
...@@ -18,29 +18,31 @@ ...@@ -18,29 +18,31 @@
*/ */
/*! /*!
* \file inline.cc * \file operation_inline.cc
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <utility>
#include "operation_inline.h"
namespace tvm { namespace tvm {
namespace tir { namespace te {
// inliner to inline a function // inliner to inline a function
// the result may not be SSA, // the result may not be SSA,
// ConvertSSA need to be applied after this pass // ConvertSSA need to be applied after this pass
class IRInline final : public StmtExprMutator { class OperationInliner final : public StmtExprMutator {
public: public:
IRInline(FunctionRef f, Array<Var> args, PrimExpr body) OperationInliner(Operation op, Array<Var> args, PrimExpr body)
: f_(f), args_(args), body_(body) {} : operation_(op), args_(args), body_(body) {}
PrimExpr VisitExpr_(const CallNode* op) final { PrimExpr VisitExpr_(const CallNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op); PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>(); op = expr.as<CallNode>();
if (op->func == f_) { if (op->func.same_as(operation_)) {
CHECK_EQ(op->value_index, 0); CHECK_EQ(op->value_index, 0);
expr = body_; expr = body_;
CHECK_EQ(args_.size(), op->args.size()); CHECK_EQ(args_.size(), op->args.size());
...@@ -68,20 +70,20 @@ class IRInline final : public StmtExprMutator { ...@@ -68,20 +70,20 @@ class IRInline final : public StmtExprMutator {
} }
private: private:
FunctionRef f_; Operation operation_;
Array<Var> args_; Array<Var> args_;
PrimExpr body_; PrimExpr body_;
}; };
Stmt Inline(Stmt stmt, Stmt Inline(Stmt stmt,
FunctionRef f, Operation f,
Array<Var> args, Array<Var> args,
PrimExpr body) { PrimExpr body) {
CHECK_EQ(f->num_outputs(), 1) CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation"; << "can only inline output single value operation";
Stmt ret = IRInline(f, args, body)(std::move(stmt)); Stmt ret = OperationInliner(f, args, body)(std::move(stmt));
if (ret.same_as(stmt)) return ret; if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret); return ConvertSSA(ret);
} }
} // namespace tir } // namespace te
} // namespace tvm } // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file operation_inline.h
*/
#ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_
#define TVM_TE_SCHEDULE_OPERATION_INLINE_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor.h>
namespace tvm {
namespace te {
/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param op The op to be inlined.
* \param args The arguments variable of the function.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt Inline(Stmt stmt,
Operation op,
Array<Var> args,
PrimExpr body);
} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_OPERATION_INLINE_H_
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <unordered_set> #include <unordered_set>
#include "message_passing.h" #include "message_passing.h"
#include "operation_inline.h"
#include "../../tir/pass/ir_util.h" #include "../../tir/pass/ir_util.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
...@@ -583,7 +585,7 @@ void InjectInline(ScheduleNode* sch) { ...@@ -583,7 +585,7 @@ void InjectInline(ScheduleNode* sch) {
<< "The Reduce inputs of ComputeOp should " << "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index"; << "have the same attribute except value_index";
} }
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][0]), PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]),
stage->op, args, body).as<tir::EvaluateNode>()->value; stage->op, args, body).as<tir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][0])) { if (!new_value.same_as(new_body[j][0])) {
changed[j] = true; changed[j] = true;
...@@ -599,7 +601,7 @@ void InjectInline(ScheduleNode* sch) { ...@@ -599,7 +601,7 @@ void InjectInline(ScheduleNode* sch) {
} }
} else { } else {
for (size_t k = 0; k < new_body[j].size(); ++k) { for (size_t k = 0; k < new_body[j].size(); ++k) {
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][k]), PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]),
stage->op, args, body).as<tir::EvaluateNode>()->value; stage->op, args, body).as<tir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][k])) { if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value); new_body[j].Set(k, new_value);
...@@ -611,7 +613,7 @@ void InjectInline(ScheduleNode* sch) { ...@@ -611,7 +613,7 @@ void InjectInline(ScheduleNode* sch) {
if (!new_hybrid_body[j].defined()) { if (!new_hybrid_body[j].defined()) {
new_hybrid_body[j] = hybrid->body; new_hybrid_body[j] = hybrid->body;
} }
Stmt new_stmt = tir::Inline(new_hybrid_body[j], stage->op, args, body); Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
if (!new_stmt.same_as(new_hybrid_body[j])) { if (!new_stmt.same_as(new_hybrid_body[j])) {
new_hybrid_body[j] = new_stmt; new_hybrid_body[j] = new_stmt;
hybrid_changed[j] = true; hybrid_changed[j] = true;
......
...@@ -97,7 +97,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") ...@@ -97,7 +97,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
REGISTER_PASS(ConvertSSA); REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA); REGISTER_PASS(VerifySSA);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform); REGISTER_PASS(IRTransform);
REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(DecorateDeviceScope);
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_inline():
m = te.size_var('m')
A = te.placeholder((m,), name='A')
T = te.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.tir.Evaluate(T[10] + 11 * T[100])
stmt = tvm.tir.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
print(stmt)
assert(tvm.tir.ir_pass.VerifySSA(stmt))
try:
# pass in int array(wrong argument type)
# must raise an error
stmt = tvm.tir.ir_pass.Inline(
T.op, [1,2,3], T.op.body, stmt)
assert False
except tvm.error.TVMError:
pass
def test_inline2():
m = te.size_var('m')
A = te.placeholder((m,), name='A')
T = te.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.tir.Evaluate(te.exp(T[10]) + 11 * T[100])
stmt = tvm.tir.ir_pass.Inline(
stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
def check(op):
if isinstance(op, tvm.tir.Call):
assert op.func != T.op
tvm.tir.ir_pass.PostOrderVisit(stmt, check)
if __name__ == "__main__":
test_inline2()
test_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment