Commit b19e01bf by Tianqi Chen Committed by GitHub

[PASS] RemoveNoOp. (#68)

parent 88338826
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file buffer.h * \file buffer.h
......
/*!
* Copyright (c) 2017 by Contributors
* \file channel.h
* \brief Channel object for pipeline.
*/
#ifndef TVM_CHANNEL_H_
#define TVM_CHANNEL_H_
#include <tvm/expr.h>
namespace tvm {
// Node container of channel
struct ChannelNode;
/*! \brief The data channel. */
class Channel : public NodeRef {
public:
/*! \brief default constructor */
Channel() {}
explicit Channel(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ChannelNode* operator->() const;
};
/*!
* \brief Generalized FIFO channel.
*/
struct ChannelNode : public Node {
/*! \brief Variable to channel handle */
Var handle_var;
/*! \brief default data type in read/write */
Type dtype;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype);
}
static Channel make(Var handle_var, Type dtype);
static constexpr const char* _type_key = "Channel";
TVM_DECLARE_NODE_TYPE_INFO(ChannelNode, Node);
};
// Inline implementations
inline const ChannelNode* Channel::operator->() const {
return static_cast<const ChannelNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_CHANNEL_H_
...@@ -39,6 +39,7 @@ using Halide::Internal::as_const_int; ...@@ -39,6 +39,7 @@ using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint; using Halide::Internal::as_const_uint;
using Halide::Internal::const_true; using Halide::Internal::const_true;
using Halide::Internal::const_false; using Halide::Internal::const_false;
using Halide::Internal::is_no_op;
inline Type TVMType2Type(TVMType t) { inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes); return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
......
...@@ -90,9 +90,7 @@ constexpr const char* virtual_thread = "virtual_thread"; ...@@ -90,9 +90,7 @@ constexpr const char* virtual_thread = "virtual_thread";
* \brief Mark storage scope of buffers * \brief Mark storage scope of buffers
*/ */
constexpr const char* storage_scope = "storage_scope"; constexpr const char* storage_scope = "storage_scope";
/*! /*! \brief Mark storage scope of realization */
* \brief Mark storage scope of realizations
*/
constexpr const char* realize_scope = "realize_scope"; constexpr const char* realize_scope = "realize_scope";
/*! \brief Mark of loop scope */ /*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope"; constexpr const char* loop_scope = "loop_scope";
...@@ -100,6 +98,13 @@ constexpr const char* loop_scope = "loop_scope"; ...@@ -100,6 +98,13 @@ constexpr const char* loop_scope = "loop_scope";
constexpr const char* scan_update_scope = "scan_update_scope"; constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */ /*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope"; constexpr const char* scan_init_scope = "scan_init_scope";
// Pipeline related attributes
/*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope";
/*! \brief channel write scope */
constexpr const char* channel_write_scope = "channel_write_scope";
/*! \brief pipeline module scope */
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
} // namespace attr } // namespace attr
/*! \brief namespace of TVM Intrinsic functions */ /*! \brief namespace of TVM Intrinsic functions */
......
...@@ -106,6 +106,20 @@ Stmt StorageFlatten(Stmt stmt, ...@@ -106,6 +106,20 @@ Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer); Map<Tensor, Buffer> extern_buffer);
/*! /*!
* \brief Remove No Op from the Stmt.
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt RemoveNoOp(Stmt stmt);
/*!
* \brief Split statement into pipeine stages.
* \param stmt The stmt to be splitted
* \return Transformed stmt.
*/
Stmt SplitPipeline(Stmt stmt);
/*!
* \brief unroll the constant loops * \brief unroll the constant loops
* \param stmt The statment to be unrolled. * \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling. * \param max_auto_step The maximum step to stop performing automatic unrolling.
......
...@@ -70,6 +70,8 @@ REGISTER_PASS1(SplitHostDevice); ...@@ -70,6 +70,8 @@ REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate); REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread); REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition); REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS1(SplitPipeline);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file channel.cc
*/
#include <tvm/channel.h>
namespace tvm {
Channel ChannelNode::make(Var handle_var, Type dtype) {
auto n = std::make_shared<ChannelNode>();
n->handle_var = handle_var;
n->dtype = dtype;
return Channel(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ChannelNode>([](const ChannelNode *op, IRPrinter *p) {
p->stream << "channel(" << op->handle_var << ", " << op->dtype << ")";
});
TVM_REGISTER_NODE_TYPE(ChannelNode);
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file remove_no_op.cc
* \brief Remove no op from the stmt
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
namespace tvm {
namespace ir {
// Mark the statment of each stage.
class NoOpRemover : public IRMutator {
public:
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<LetStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AttrStmt>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<IfThenElse>();
if (op->else_case.defined()) {
if (is_no_op(op->else_case)) {
if (is_no_op(op->then_case)) {
return MakeEvaluate(op->condition);
} else {
return IfThenElse::make(op->condition, op->then_case);
}
} else {
return stmt;
}
} else {
if (is_no_op(op->then_case)) {
return MakeEvaluate(op->condition);
} else {
return stmt;
}
}
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<For>();
return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<ProducerConsumer>();
return is_no_op(op->body) ? op->body : stmt;
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Realize>();
return is_no_op(op->body) ? op->body : stmt;
}
Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
if (HasSideEffect(op->value)) return s;
return Evaluate::make(0);
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Block>();
if (is_no_op(op->first)) {
return op->rest;
} else if (is_no_op(op->rest)) {
return op->first;
} else {
return stmt;
}
}
private:
Stmt MakeEvaluate(Expr value) {
if (HasSideEffect(value)) {
return Evaluate::make(value);
} else {
return Evaluate::make(0);
}
}
Stmt MakeEvaluate(const Array<Expr>& values) {
Stmt stmt;
for (Expr e : values) {
if (HasSideEffect(e)) {
if (stmt.defined()) {
stmt = Block::make(stmt, Evaluate::make(e));
} else {
stmt = Evaluate::make(e);
}
}
}
return stmt.defined() ? stmt : Evaluate::make(0);
}
};
Stmt RemoveNoOp(Stmt stmt) {
return NoOpRemover().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
...@@ -48,6 +48,7 @@ class IRSubstitue : public IRMutator { ...@@ -48,6 +48,7 @@ class IRSubstitue : public IRMutator {
}; };
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) { Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
if (value_map.size() == 0) return stmt;
IRSubstitue m; IRSubstitue m;
for (auto kv : value_map) { for (auto kv : value_map) {
m.smap[kv.first.get()] = kv.second; m.smap[kv.first.get()] = kv.second;
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/lowered_func.h> #include <tvm/lowered_func.h>
#include <tvm/channel.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
...@@ -17,7 +18,7 @@ namespace ir { ...@@ -17,7 +18,7 @@ namespace ir {
class IRUseDefAnalysis : public IRMutator { class IRUseDefAnalysis : public IRMutator {
public: public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == "thread_extent") { if (op->type_key == attr::thread_extent) {
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times // thread_extent can appear multiple times
...@@ -35,6 +36,13 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -35,6 +36,13 @@ class IRUseDefAnalysis : public IRMutator {
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (value.same_as(value) && body.same_as(body)) return s; if (value.same_as(value) && body.same_as(body)) return s;
return AttrStmt::make(op->node, op->type_key, value, body); return AttrStmt::make(op->node, op->type_key, value, body);
} else if (op->type_key == attr::channel_write_scope ||
op->type_key == attr::channel_read_scope) {
Channel ch(op->node.node_);
if (!use_count_.count(ch->handle_var.get())) {
this->HandleDef(ch->handle_var.get());
}
return IRMutator::Mutate_(op, s);
} else { } else {
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
......
/*!
* Copyright (c) 2017 by Contributors
* \file split_pipeline.cc
* \brief Split statement into pipeline stage modules.
*/
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/channel.h>
#include <unordered_map>
#include "./ir_util.h"
namespace tvm {
namespace ir {
class MarkChannelAccess : public IRMutator {
public:
MarkChannelAccess(
const std::unordered_map<const Variable*, Channel>& cmap)
: cmap_(cmap) {}
Expr Mutate_(const Load *op, const Expr& e) final {
auto it = rmap_.find(op->buffer_var.get());
if (it != rmap_.end()) {
++it->second.read_count;
}
return IRMutator::Mutate_(op, e);
}
Stmt Mutate_(const Store *op, const Stmt& s) final {
auto it = rmap_.find(op->buffer_var.get());
if (it != rmap_.end()) {
++it->second.write_count;
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
if (cmap_.count(op->buffer_var.get())) {
CHECK(!rmap_.count(op->buffer_var.get()));
rmap_[op->buffer_var.get()] = Entry();
Stmt body = Mutate(op->body);
body = CreateChannelAccess(op, body);
rmap_.erase(op->buffer_var.get());
return body;
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == ir::attr::storage_scope) {
Var buf_var(op->node.node_);
if (cmap_.count(buf_var.get())) return Mutate(op->body);
}
return IRMutator::Mutate_(op, s);
}
private:
// Create channel access wrap
Stmt CreateChannelAccess(const Allocate* op, Stmt body) {
const Entry& rw = rmap_.at(op->buffer_var.get());
CHECK(rw.write_count == 0 || rw.read_count == 0)
<< "Cannot read/write to the same channel " << op->buffer_var
<< " body:" << body;
if (rw.write_count == 0 && rw.read_count == 0) {
return body;
}
const Channel& ch = cmap_.at(op->buffer_var.get());
int32_t csize = op->constant_allocation_size();
Expr alloc_size;
if (csize > 0) {
alloc_size = IntImm::make(Int(32), csize);
} else {
alloc_size = op->extents[0];
for (size_t i = 1; i < op->extents.size(); ++i) {
alloc_size *= op->extents[i];
}
alloc_size = ir::Simplify(alloc_size);
}
if (rw.write_count) {
return AttrStmt::make(
ch, ir::attr::channel_write_scope, alloc_size, body);
} else {
CHECK(rw.read_count);
return AttrStmt::make(
ch, ir::attr::channel_read_scope, alloc_size, body);
}
}
struct Entry {
int read_count{0};
int write_count{0};
};
// The channels of each allocation.
const std::unordered_map<const Variable*, Channel>& cmap_;
// the result.
std::unordered_map<const Variable*, Entry> rmap_;
};
// Mark the statment of each stage.
class StageSplitter : public IRMutator {
public:
Stmt Mutate(Stmt stmt) final {
nest_.push_back(stmt);
Stmt ret = IRMutator::Mutate(stmt);
nest_.pop_back();
return ret;
}
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) {
if (!op->is_producer) return IRMutator::Mutate_(op, s);
Stmt body = Mutate(op->body);
stages_.emplace_back(BuildStage(body, op->func));
return Evaluate::make(0);
}
Stmt Split(Stmt stmt) {
stmt = Mutate(stmt);
stmt = RemoveNoOp(stmt);
CHECK(is_no_op(stmt));
CHECK_NE(stages_.size(), 0);
stmt = stages_.back();
for (size_t i = stages_.size() - 1; i != 0; --i) {
stmt = Block::make(stages_[i - 1], stmt);
}
stmt = MarkChannelAccess(cmap_).Mutate(stmt);
return RemoveNoOp(stmt);
}
private:
// Build the stage.
Stmt BuildStage(Stmt body, NodeRef target) {
int stage_index = static_cast<size_t>(stages_.size());
std::string stage_suffix = "." + std::to_string(stage_index);
// The Substitute
Map<Var, Expr> subst;
std::vector<Stmt> nest;
Stmt no_op = Evaluate::make(0);
for (const Stmt& s : nest_) {
if (const For* op = s.as<For>()) {
Var loop_var(op->loop_var);
Var new_var = loop_var.copy_with_suffix(stage_suffix);
subst.Set(loop_var, new_var);
nest.emplace_back(For::make(
new_var, op->min, op->extent,
op->for_type, op->device_api, no_op));
} else if (const LetStmt* op = s.as<LetStmt>()) {
Var var(op->var);
Var new_var = var.copy_with_suffix(stage_suffix);
subst.Set(var, new_var);
nest.emplace_back(LetStmt::make(new_var, op->value, no_op));
} else if (const IfThenElse* op = s.as<IfThenElse>()) {
CHECK(!op->else_case.defined());
nest.emplace_back(IfThenElse::make(op->condition, no_op));
} else if (const AttrStmt* op = s.as<AttrStmt>()) {
nest.emplace_back(AttrStmt::make(
op->node, op->type_key, op->value, no_op));
} else if (s.as<ProducerConsumer>()) {
} else if (s.as<Block>()) {
} else if (const Allocate* op = s.as<Allocate>()) {
nest.emplace_back(Allocate::make(
op->buffer_var, op->type, op->extents,
op->condition, no_op, op->new_expr, op->free_function));
MarkChannel(op);
} else {
LOG(FATAL) << "not supported nest type " << s->type_key();
}
}
body = Substitute(MergeNest(nest, body), subst);
return AttrStmt::make(
target, ir::attr::pipeline_stage_scope,
make_const(Int(32), stage_index), body);
}
void MarkChannel(const Allocate* op) {
if (!cmap_.count(op->buffer_var.get())) {
Channel ch = ChannelNode::make(Var(op->buffer_var), op->type);
cmap_[op->buffer_var.get()] = ch;
}
}
// The stack
std::vector<Stmt> nest_;
// The stages
std::vector<Stmt> stages_;
// channel map
std::unordered_map<const Variable*, Channel> cmap_;
};
Stmt SplitPipeline(Stmt stmt) {
return StageSplitter().Split(stmt);
}
} // namespace ir
} // namespace tvm
...@@ -26,12 +26,8 @@ Stmt MakePipeline(const Stage& s, ...@@ -26,12 +26,8 @@ Stmt MakePipeline(const Stage& s,
producer = ProducerConsumer::make(s->op, true, producer); producer = ProducerConsumer::make(s->op, true, producer);
} }
Stmt pipeline = producer; Stmt pipeline = producer;
// check if consumer is nop.
bool is_no_op{false};
const Evaluate* ev = consumer.as<Evaluate>();
if (ev && ev->value.as<IntImm>()) is_no_op = true;
if (consumer.defined() && !is_no_op) { if (consumer.defined() && !is_no_op(consumer)) {
consumer = ProducerConsumer::make(s->op, false, consumer); consumer = ProducerConsumer::make(s->op, false, consumer);
pipeline = Block::make(producer, consumer); pipeline = Block::make(producer, consumer);
} }
......
import tvm
def test_remove_no_op():
i = tvm.Var('i')
j = tvm.Var('j')
k = tvm.Var('k')
m = tvm.Var('m')
n = tvm.Var('n')
dtype = 'int64'
Ab = tvm.Buffer((n, ), dtype)
stmt = tvm.make.For(
i, 0, 4, 0, 0,
tvm.make.For(
j, 0, n, 0, 0,
tvm.make.For(
k, 0, m, 0, 0,
tvm.make.IfThenElse(
(i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)))))
ret = tvm.ir_pass.RemoveNoOp(stmt)
assert(isinstance(ret, tvm.stmt.Evaluate))
store = tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1)
stmt2 = tvm.make.Block(stmt, store)
assert(tvm.ir_pass.RemoveNoOp(stmt2) == store)
if __name__ == "__main__":
test_remove_no_op()
import tvm
def test_basic_pipeline():
n = tvm.convert(128)
A = tvm.placeholder((n,), name='A')
stages = []
num_stage = 3
B = A
for k in range(num_stage):
stages.append(B)
B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
s = tvm.Schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=4)
for S in stages:
s[S].compute_at(s[B], xo)
# Lowering
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb})
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.SplitPipeline(stmt)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
if __name__ == "__main__":
test_basic_pipeline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment