Commit e8138f7d by Tianqi Chen (committed via GitHub)

[TIR] Remove ProducerConsumer and AllocateNode::new_expr (#5333)

* [TIR] Remove ProducerConsumer and AllocateNode::new_expr

This PR removes two legacy IR constructs in TIR that have been deprecated.

The ProducerConsumer node only serves as hint markup and may no longer be
informative after the extensive transformations performed by the lowering passes.
If the information is still needed, it can be attached via an AttrStmt instead.
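
As a rough sketch of that alternative (not part of this change; the attribute key
"pragma_producer" below is hypothetical and used only for illustration), a pass could
wrap the producer body in an AttrStmt keyed on the operation:

    # Sketch only: marks a region with a hypothetical producer attribute via AttrStmt.
    from tvm import te, tir

    A = te.placeholder((16,), name="A")
    body = tir.Evaluate(0)  # stand-in for the real producer body
    marked = tir.AttrStmt(A.op, "pragma_producer", tir.StringImm(A.op.name), body)
    print(marked)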

The new_expr field in AllocateNode is deprecated: an allocation backed by an
externally provided address can instead be expressed by a LetStmt that binds the
buffer variable to that address.
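
A minimal sketch of that replacement (assuming some precomputed address expression;
the extern name "device_head_address" is a made-up placeholder):

    # Sketch only: bind the buffer variable to a precomputed address with LetStmt,
    # instead of the removed Allocate(..., new_expr=<address>) form.
    from tvm import te, tir

    buf = te.var("buf", dtype="handle")
    head_addr = tir.call_extern("handle", "device_head_address")  # placeholder address expr
    body = tir.Evaluate(0)  # stand-in for code that accesses the buffer through buf
    stmt = tir.LetStmt(buf, head_addr, body)
    print(stmt)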

- Remove dependencies of passes on ProducerConsumer.
- Remove ProducerConsumer from the IR.
- Remove the deprecated fields (new_expr, free_function) from AllocateNode.

* Fix additional testcases
parent f1438813
@@ -178,42 +178,6 @@ class AssertStmtNode : public StmtNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
 };
-// TODO(tvm-team): consider consolidate with AttrStmt.
-/*! \brief annotation node of producer/consumer relation. */
-class ProducerConsumerNode : public StmtNode {
- public:
-  /*! \brief The corresponding tensor. */
-  FunctionRef func;
-  /*! \brief Whether the relation is producer. */
-  bool is_producer;
-  /*! \brief Body to be executed. */
-  Stmt body;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("func", &func);
-    v->Visit("is_producer", &is_producer);
-    v->Visit("body", &body);
-  }
-
-  bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
-    return
-        equal(func, other->func) &&
-        equal(is_producer, other->is_producer) &&
-        equal(body, other->body);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(func);
-    hash_reduce(is_producer);
-    hash_reduce(body);
-  }
-
-  TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
-
-  static constexpr const char* _type_key = "ProducerConsumer";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode);
-};
-
 /*!
  * \brief Store value to the buffer.
  *
@@ -385,10 +349,6 @@ class AllocateNode : public StmtNode {
   PrimExpr condition;
   /*! \brief The body to be executed. */
   Stmt body;
-  // The following two fields are deprecated
-  // kept for backward compatibility and will be refactored later.
-  PrimExpr new_expr;
-  std::string free_function;
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
@@ -419,9 +379,7 @@ class AllocateNode : public StmtNode {
                           DataType dtype,
                           Array<PrimExpr> extents,
                           PrimExpr condition,
-                          Stmt body,
-                          PrimExpr new_expr = PrimExpr(),
-                          std::string free_function = std::string());
+                          Stmt body);
   /*!
   * \brief If the buffer size is constant, return the size.
@@ -589,8 +547,6 @@ class SeqStmt : public Stmt {
  *
  * - When an argument is nullptr, it will be ignored.
  * - When an argument is an array or a SeqStmt, it will be flattened recursively.
- * - When an argument is a consumer block in ProducerConsumer, the consumer
- *   tag will be dropped as such information is not useful in lowering.
  * - A normal Stmt will be appended to the end of the sequence.
  *
  * \note This function can directly return an element
@@ -618,13 +574,6 @@ class SeqStmt : public Stmt {
     if (!stmt.defined()) return;
     if (auto* op = stmt.as<SeqStmtNode>()) {
       operator()(0, op->seq);
-    } else if (auto* op = stmt.as<ProducerConsumerNode>()) {
-      // NOTE: The consumer block annotation was not as useful and can be safely dropped.
-      if (!op->is_producer) {
-        operator()(0, op->body);
-      } else {
-        seq_->push_back(stmt);
-      }
     } else {
       seq_->push_back(stmt);
     }
...
@@ -94,7 +94,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
   virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -117,7 +116,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     IR_STMT_FUNCTOR_DISPATCH(StoreNode);
     IR_STMT_FUNCTOR_DISPATCH(FreeNode);
     IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
-    IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode);
     IR_STMT_FUNCTOR_DISPATCH(ProvideNode);
     IR_STMT_FUNCTOR_DISPATCH(RealizeNode);
     IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
@@ -158,7 +156,6 @@ class TVM_DLL StmtVisitor :
   void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const FreeNode* op) override;
   void VisitStmt_(const AssertStmtNode* op) override;
-  void VisitStmt_(const ProducerConsumerNode* op) override;
   void VisitStmt_(const ProvideNode* op) override;
   void VisitStmt_(const RealizeNode* op) override;
   void VisitStmt_(const PrefetchNode* op) override;
@@ -253,7 +250,6 @@ class TVM_DLL StmtMutator :
   Stmt VisitStmt_(const BufferStoreNode* op) override;
   Stmt VisitStmt_(const FreeNode* op) override;
   Stmt VisitStmt_(const AssertStmtNode* op) override;
-  Stmt VisitStmt_(const ProducerConsumerNode* op) override;
   Stmt VisitStmt_(const ProvideNode* op) override;
   Stmt VisitStmt_(const RealizeNode* op) override;
   Stmt VisitStmt_(const PrefetchNode* op) override;
...
@@ -27,7 +27,7 @@ from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
 from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
 from .expr import IterVar, Any
-from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
+from .stmt import Stmt, LetStmt, AssertStmt, For
 from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
 from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
...
@@ -77,26 +77,6 @@ class AssertStmt(Stmt):
 @tvm._ffi.register_object
-class ProducerConsumer(Stmt):
-    """ProducerConsumer node.
-
-    Parameters
-    ----------
-    func : Operation
-        The Operation.
-
-    is_producer : bool
-        Whether if the node is producer.
-
-    body : Stmt
-        The body statement.
-    """
-    def __init__(self, func, is_producer, body):
-        self.__init_handle_by_constructor__(
-            _ffi_api.ProducerConsumer, func, is_producer, body)
-
-
-@tvm._ffi.register_object
 class For(Stmt):
     """For node.
@@ -425,6 +405,4 @@ def stmt_list(stmt):
         for x in stmt:
             res += stmt_list(x)
         return res
-    if isinstance(stmt, ProducerConsumer):
-        return stmt_list(stmt.body)
     return [stmt]
@@ -399,10 +399,6 @@ void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
   stream << str << "\n";
 }
-void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) {
-  PrintStmt(op->body);
-}
 void CodeGenHybrid::PrintIndent() {
   stream << std::string(indent_, ' ');
 }
...
@@ -131,7 +131,6 @@ class CodeGenHybrid :
   void VisitStmt_(const AssertStmtNode* op) override;
   void VisitStmt_(const EvaluateNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
-  void VisitStmt_(const ProducerConsumerNode* op) override;
   /*!
    * \brief Print Type represetnation of type t.
    * \param t The type representation.
...
@@ -71,55 +71,53 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   void VisitStmt_(const AllocateNode* op) final {
     CHECK(!is_zero(op->condition));
     llvm::Value* buf = nullptr;
-    if (op->new_expr.defined()) {
-      CHECK_EQ(op->free_function, "nop");
-      buf = MakeValue(op->new_expr);
-    } else {
     int32_t constant_size = op->constant_allocation_size();
     CHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation in GPU";
     StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
     if (constant_size % 4 == 0 && info.alignment == 0) {
       info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
     }
     // maximum necessary alignment in the AMD devices
     if (info.alignment > 16) {
       info.alignment = 16;
     }
     if (info.scope.rank == runtime::StorageRank::kLocal) {
       // const int local_address_space = 5;
       // TODO(tqchen): for higher version of LLVM, local address space can be set.
       llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
           return builder_->CreateAlloca(
               DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
         });
       if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
 #if TVM_LLVM_VERSION >= 100
         alloca->setAlignment(llvm::Align(info.alignment));
 #else
         alloca->setAlignment(info.alignment);
 #endif
       }
       buf = alloca;
     } else {
       CHECK(info.scope.rank == runtime::StorageRank::kShared)
           << "Can only allocate shared or local memory inside kernel";
       // Shared memory: address space == 3
       const unsigned shared_address_space = 3;
       llvm::Type* type = llvm::ArrayType::get(
           DTypeToLLVMType(op->dtype), constant_size);
       // Allocate shared memory in global, address_space = 3
       llvm::GlobalVariable *global = new llvm::GlobalVariable(
           *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
           nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
 #if TVM_LLVM_VERSION >= 100
       global->setAlignment(llvm::Align(info.alignment));
 #else
       global->setAlignment(info.alignment);
 #endif
       buf = global;
     }
-    }
     buf = builder_->CreatePointerCast(
         buf, DTypeToLLVMType(op->dtype)->getPointerTo(
             buf->getType()->getPointerAddressSpace()));
...
@@ -1268,10 +1268,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
 void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   llvm::Value* buf = nullptr;
-  if (op->new_expr.defined()) {
-    CHECK_EQ(op->free_function, "nop");
-    buf = MakeValue(op->new_expr);
-  } else {
   int32_t constant_size = op->constant_allocation_size();
   CHECK_GT(constant_size, 0)
       << "Can only handle constant size stack allocation";
@@ -1296,7 +1293,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
   }
   info.alignment = alloca->getAlignment();
   buf = alloca;
-  }
   buf = builder_->CreatePointerCast(
       buf, DTypeToLLVMType(op->dtype)->getPointerTo(
           buf->getType()->getPointerAddressSpace()));
@@ -1359,9 +1356,6 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
   MakeValue(op->value);
 }
-void CodeGenLLVM::VisitStmt_(const ProducerConsumerNode* op) {
-  this->VisitStmt(op->body);
-}
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_LLVM_VERSION
@@ -150,7 +150,6 @@ class CodeGenLLVM :
   void VisitStmt_(const LetStmtNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
   void VisitStmt_(const EvaluateNode* op) override;
-  void VisitStmt_(const ProducerConsumerNode* op) override;
  protected:
   /*! \brief The storage information */
...
@@ -48,55 +48,53 @@ class CodeGenNVPTX : public CodeGenLLVM {
   void VisitStmt_(const AllocateNode* op) final {
     CHECK(!is_zero(op->condition));
     llvm::Value* buf = nullptr;
-    if (op->new_expr.defined()) {
-      CHECK_EQ(op->free_function, "nop");
-      buf = MakeValue(op->new_expr);
-    } else {
     int32_t constant_size = op->constant_allocation_size();
     CHECK_GT(constant_size, 0)
         << "Can only handle constant size stack allocation in GPU";
     StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
     if (constant_size % 4 == 0 && info.alignment == 0) {
       info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
     }
     // maximum necessary alignment in the NV devices
     if (info.alignment > 16) {
       info.alignment = 16;
     }
     if (info.scope.rank == runtime::StorageRank::kLocal) {
       // const int local_address_space = 5;
       // TODO(tqchen): for higher version of LLVM, local address space can be set.
       llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
           return builder_->CreateAlloca(
               DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
         });
       if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
 #if TVM_LLVM_VERSION >= 100
         alloca->setAlignment(llvm::Align(info.alignment));
 #else
         alloca->setAlignment(info.alignment);
 #endif
       }
       buf = alloca;
     } else {
       CHECK(info.scope.rank == runtime::StorageRank::kShared)
           << "Can only allocate shared or local memory inside kernel";
       // Shared memory: address space == 3
       const unsigned shared_address_space = 3;
       llvm::Type* type = llvm::ArrayType::get(
           DTypeToLLVMType(op->dtype), constant_size);
       // Allocate shared memory in global, address_space = 3
       llvm::GlobalVariable *global = new llvm::GlobalVariable(
           *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
           nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
 #if TVM_LLVM_VERSION >= 100
       global->setAlignment(llvm::Align(info.alignment));
 #else
       global->setAlignment(info.alignment);
 #endif
       buf = global;
     }
-    }
     buf = builder_->CreatePointerCast(
         buf, DTypeToLLVMType(op->dtype)->getPointerTo(
             buf->getType()->getPointerAddressSpace()));
...
@@ -814,14 +814,7 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) {
 void CodeGenC::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   std::string vid = AllocVarID(op->buffer_var.get());
-  if (op->new_expr.defined()) {
-    // Prefer global static allocation for the program
-    CHECK_EQ(op->free_function, "nop");
-    std::string new_data = PrintExpr(op->new_expr);
-    this->PrintIndent();
-    PrintType(op->dtype, stream);
-    stream << "* "<< vid << '=' << new_data << ";\n";
-  } else {
   this->PrintIndent();
   int32_t constant_size = op->constant_allocation_size();
   CHECK_GT(constant_size, 0)
@@ -833,7 +826,7 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) {
   PrintType(op->dtype, stream);
   stream << ' '<< vid << '['
          << constant_size << "];\n";
-  }
   RegisterHandleType(op->buffer_var.get(), op->dtype);
   this->PrintStmt(op->body);
 }
@@ -942,10 +935,6 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) {
   }
 }
-void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) {
-  PrintStmt(op->body);
-}
 void CodeGenC::PrintVecElemLoadExpr(
     DataType t, int i, const std::string& value, std::ostream& os) {
   CHECK_GT(t.lanes(), 1);
...
@@ -153,7 +153,6 @@ class CodeGenC :
   void VisitStmt_(const AssertStmtNode* op) override;
   void VisitStmt_(const EvaluateNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
-  void VisitStmt_(const ProducerConsumerNode* op) override;
   /*!
    * Print Type represetnation of type t.
    * \param t The type representation.
...
@@ -514,51 +514,44 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
 void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   std::string vid = AllocVarID(op->buffer_var.get());
-  if (op->new_expr.defined()) {
-    // Prefer global static allocation for the program
-    CHECK_EQ(op->free_function, "nop");
-    std::string new_data = PrintExpr(op->new_expr);
-    this->PrintIndent();
-    PrintType(op->dtype, stream);
-    stream << "* "<< vid << '=' << new_data << ";\n";
-  } else {
   this->PrintIndent();
   int32_t constant_size = op->constant_allocation_size();
   CHECK_GT(constant_size, 0)
       << "Can only handle constant size stack allocation for now";
   const VarNode* buffer = op->buffer_var.as<VarNode>();
   std::string scope = alloc_storage_scope_.at(buffer);
   if (scope.find("wmma.") == 0) {
     if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
       CHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Int(8) ||
             op->dtype == DataType::UInt(8) ||
             op->dtype == DataType::Int(4) ||
             op->dtype == DataType::UInt(4) ||
             op->dtype == DataType::Int(1))
         << "Matrix_a and matrix_b only support half or char or unsigned char "
         << "or uint4 or int4 or int1 type for now";
     } else {
       CHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Float(32) ||
             op->dtype == DataType::Int(32))
         << "Accumulator only support half, float and int type for now";
     }
     constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
     PrintWmmaScope(scope, op->dtype, buffer, stream);
   } else {
     PrintStorageScope(scope, stream);
     stream << ' ';
     PrintType(op->dtype, stream);
   }
   if ((op->dtype == DataType::Int(4) ||
       op->dtype == DataType::UInt(4) ||
       op->dtype == DataType::Int(1)) && scope == "shared") {
     constant_size = constant_size / (32 / op->dtype.bits());
   }
   stream << ' '<< vid << '['
          << constant_size << "];\n";
-  }
   RegisterHandleType(op->buffer_var.get(), op->dtype);
   this->PrintStmt(op->body);
 }
...
@@ -586,7 +586,6 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
 void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
-  CHECK(!op->new_expr.defined());
   CHECK(!op->dtype.is_handle());
   int32_t constant_size = op->constant_allocation_size();
   CHECK_GT(constant_size, 0)
@@ -659,9 +658,5 @@ void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) {
   MakeValue(op->value);
 }
-void CodeGenSPIRV::VisitStmt_(const ProducerConsumerNode* op) {
-  this->VisitStmt(op->body);
-}
 }  // namespace codegen
 }  // namespace tvm
@@ -99,7 +99,6 @@ class CodeGenSPIRV:
   void VisitStmt_(const LetStmtNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
   void VisitStmt_(const EvaluateNode* op) override;
-  void VisitStmt_(const ProducerConsumerNode* op) override;
  protected:
   /*! \brief The storage information */
...
@@ -156,16 +156,7 @@ void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
 }
 void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
-  CHECK(!is_zero(op->condition));
-  int vid = AllocVarID(op->buffer_var.get());
-  if (op->new_expr.defined()) {
-    // Prefer global static allocation for the program
-    CHECK_EQ(op->free_function, "nop");
-    this->Push(op->new_expr);
-    this->PushOp(StackVM::STORE_HEAP, vid);
-  } else {
-    LOG(FATAL) << "Dynamic allocation not supported";
-  }
+  LOG(FATAL) << "Dynamic allocation not supported";
 }
 void CodeGenStackVM::VisitExpr_(const CallNode* op) {
@@ -410,10 +401,6 @@ void CodeGenStackVM::VisitExpr_(const NotNode* op) {
   this->PushOp(StackVM::NOT);
 }
-void CodeGenStackVM::VisitStmt_(const ProducerConsumerNode* op) {
-  this->Push(op->body);
-}
 void CodeGenStackVM::VisitStmt_(const ForNode* op) {
   CHECK(is_zero(op->min));
   int vid = this->AllocVarID(op->loop_var.get());
...
@@ -147,7 +147,6 @@ class CodeGenStackVM
   void VisitStmt_(const AssertStmtNode* op) final;
   void VisitStmt_(const EvaluateNode* op) final;
   void VisitStmt_(const SeqStmtNode* op) final;
-  void VisitStmt_(const ProducerConsumerNode* op) final;
  private:
   bool debug_{false};
...
@@ -43,9 +43,6 @@ Stmt MakePipeline(const Stage& s,
                   Stmt consumer,
                   bool debug_keep_trivial_loop) {
   Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
-  if (producer.defined()) {
-    producer = ProducerConsumerNode::make(s->op, true, producer);
-  }
   if (s->double_buffer) {
     producer = AttrStmtNode::make(
         s->op, tir::attr::double_buffer_scope, 1, producer);
@@ -53,7 +50,6 @@ Stmt MakePipeline(const Stage& s,
   Stmt pipeline = producer;
   if (consumer.defined() && !is_no_op(consumer)) {
-    consumer = ProducerConsumerNode::make(s->op, false, consumer);
     pipeline = SeqStmt({producer, consumer});
   }
   pipeline = s->op->BuildRealize(s, dom_map, pipeline);
@@ -163,20 +159,6 @@ class InjectScanStep : public StmtMutator {
 // Replace the init and update's expression by scan's buffer.
 class SchedulePostProc : public StmtExprMutator {
  public:
-  Stmt VisitStmt_(const ProducerConsumerNode* op) final {
-    auto it = replace_op_.find(op->func.get());
-    if (it != replace_op_.end()) {
-      Stmt body = this->VisitStmt(op->body);
-      if (it->second.defined()) {
-        return ProducerConsumerNode::make(
-            it->second, op->is_producer, body);
-      } else {
-        return body;
-      }
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  }
   Stmt VisitStmt_(const LetStmtNode* op) final {
     if (!HasSideEffect(op->value)) {
       var_value_[op->var.get()] = this->VisitExpr(op->value);
...
@@ -39,8 +39,8 @@ namespace {
  * threads, CPU code is generated that tries to access GPU memory,
  * which is illegal.
  *
- * This pass performs such verification by checking if all Producer/Consumer
- * with memory accesses are bound with threads when device type is GPU.
+ * This pass performs such verification by checking if all
+ * memory accesses are bound with threads when device type is GPU.
  */
 class MemoryAccessVerifier final : protected StmtExprVisitor {
  public:
@@ -94,12 +94,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
     }
   }
-  void VisitStmt_(const ProducerConsumerNode* op) final {
-    EnterProducerConsumer(op);
-    StmtExprVisitor::VisitStmt_(op);
-    ExitProducerConsumer();
-  }
   void VisitExpr_(const LoadNode* op) final {
     HandleLoadStoreToVariable(op->buffer_var);
     return StmtExprVisitor::VisitExpr_(op);
@@ -138,11 +132,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
     // We skip the access within thread env.
     if (InThreadEnv()) return;
-    // We only check access within a producer/consumer.
-    // Because for load/store out side of producer/consumer,
-    // they don't have to be in thread env to stay legal (e.g. Load of args).
-    if (!InProducerConsumer()) return;
     // We only handle the variable from function argument.
     // If it does not come from args, then it could be allocated internally,
     // it may possibly be in host or device address space.
@@ -158,10 +147,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   bool InThreadEnv() const { return in_thread_env_; }
   void EnterThreadEnv() { in_thread_env_ = true; }
   void ExitThreadEnv() { in_thread_env_ = false; }
-  bool InProducerConsumer() const { return pc_ != nullptr; }
-  const ProducerConsumerNode *GetCurrentProducerConsumer() const { return pc_; }
-  void EnterProducerConsumer(const ProducerConsumerNode *pc) { this->pc_ = pc; }
-  void ExitProducerConsumer() { pc_ = nullptr; }
   void SetFailure() { failure_ = true; }
   //@}
@@ -180,7 +165,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   /// Status of visitor
   //@{
   bool in_thread_env_{false};
-  const ProducerConsumerNode *pc_{nullptr};
   bool failure_{false};  ///< If the verification fails (i.e. has illegal access)
   //@}
   tir::PrimFunc func_{nullptr};  ///< Function to be verified.
@@ -197,13 +181,18 @@ void VerifyMemory(const IRModule& mod) {
     auto target = func->GetAttr<Target>(tvm::attr::kTarget);
     CHECK(target.defined())
         << "LowerWarpMemory: Require the target attribute";
-    MemoryAccessVerifier v(func, target.value()->device_type);
-    v.Run();
-    if (v.Failed()) {
-      LOG(FATAL)
+    if (func->GetAttr<Integer>(
+            tvm::attr::kCallingConv,
+            Integer(CallingConv::kDefault)) == CallingConv::kDefault) {
+      MemoryAccessVerifier v(func, target.value()->device_type);
+      v.Run();
+      if (v.Failed()) {
+        LOG(FATAL)
           << "ValueError: Direct host side access to device memory is detected."
           << " Did you forget to bind?\n"
          << func;
+      }
     }
   }
 }
...
@@ -82,20 +82,6 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt")
   }
 });
-Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
-  CHECK(body.defined());
-  ObjectPtr<ProducerConsumerNode> node = make_object<ProducerConsumerNode>();
-  node->func = std::move(func);
-  node->is_producer = is_producer;
-  node->body = std::move(body);
-  return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
-.set_body_typed(ProducerConsumerNode::make);
 Stmt ForNode::make(Var loop_var,
                    PrimExpr min,
                    PrimExpr extent,
@@ -184,9 +170,7 @@ Stmt AllocateNode::make(Var buffer_var,
                         DataType dtype,
                         Array<PrimExpr> extents,
                         PrimExpr condition,
-                        Stmt body,
-                        PrimExpr new_expr,
-                        std::string free_function) {
+                        Stmt body) {
   for (size_t i = 0; i < extents.size(); ++i) {
     CHECK(extents[i].defined());
     CHECK(extents[i].dtype().is_scalar());
@@ -201,8 +185,6 @@ Stmt AllocateNode::make(Var buffer_var,
   node->extents = std::move(extents);
   node->condition = std::move(condition);
   node->body = std::move(body);
-  node->new_expr = std::move(new_expr);
-  node->free_function = std::move(free_function);
   return Stmt(node);
 }
@@ -381,22 +363,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     p->Print(op->body);
   });
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, ReprPrinter* p) {
-    auto* op = static_cast<const ProducerConsumerNode*>(node.get());
-    if (op->is_producer) {
-      p->PrintIndent();
-      p->stream << "produce " << op->func->func_name() << " {\n";
-      p->indent += 2;
-      p->Print(op->body);
-      p->indent -= 2;
-      p->PrintIndent();
-      p->stream << "}\n";
-    } else {
-      p->Print(op->body);
-    }
-  });
 std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
   switch (type) {
     case ForType::Serial:
@@ -615,7 +581,6 @@ TVM_REGISTER_NODE_TYPE(CallNode);
 TVM_REGISTER_NODE_TYPE(LetNode);
 TVM_REGISTER_NODE_TYPE(LetStmtNode);
 TVM_REGISTER_NODE_TYPE(AssertStmtNode);
-TVM_REGISTER_NODE_TYPE(ProducerConsumerNode);
 TVM_REGISTER_NODE_TYPE(ForNode);
 TVM_REGISTER_NODE_TYPE(StoreNode);
 TVM_REGISTER_NODE_TYPE(ProvideNode);
...
@@ -149,9 +149,6 @@ void StmtVisitor::VisitStmt_(const AllocateNode* op) {
   VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
   this->VisitStmt(op->body);
   this->VisitExpr(op->condition);
-  if (op->new_expr.defined()) {
-    this->VisitExpr(op->new_expr);
-  }
 }
 void StmtVisitor::VisitStmt_(const StoreNode* op) {
@@ -180,10 +177,6 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
   this->VisitStmt(op->body);
 }
-void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) {
-  this->VisitStmt(op->body);
-}
 void StmtVisitor::VisitStmt_(const ProvideNode* op) {
   VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
   this->VisitExpr(op->value);
@@ -291,21 +284,16 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
   Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
   Stmt body = this->VisitStmt(op->body);
   PrimExpr condition = this->VisitExpr(op->condition);
-  PrimExpr new_expr;
-  if (op->new_expr.defined()) {
-    new_expr = this->VisitExpr(op->new_expr);
-  }
   if (extents.same_as(op->extents) &&
       body.same_as(op->body) &&
-      condition.same_as(op->condition) &&
-      new_expr.same_as(op->new_expr)) {
+      condition.same_as(op->condition)) {
     return GetRef<Stmt>(op);
   } else {
     auto n = CopyOnWrite(op);
     n->extents = std::move(extents);
     n->body = std::move(body);
     n->condition = std::move(condition);
-    n->new_expr = std::move(new_expr);
     return Stmt(n);
   }
 }
@@ -475,17 +463,6 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
   }
 }
-Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) {
-  Stmt body = this->VisitStmt(op->body);
-  if (body.same_as(op->body)) {
-    return GetRef<Stmt>(op);
-  } else {
-    auto n = CopyOnWrite(op);
-    n->body = std::move(body);
-    return Stmt(n);
-  }
-}
 Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
   if (value.same_as(op->value)) {
...
@@ -129,9 +129,6 @@ class VarTouchedAnalysis : public StmtVisitor {
       tc(op->extents[i]);
     }
     tc.VisitExpr(op->condition);
-    if (op->new_expr.defined()) {
-      tc(op->new_expr);
-    }
     Record(op->buffer_var.get(), tc);
     this->VisitStmt(op->body);
   }
@@ -371,9 +368,6 @@ class VTInjector : public StmtExprMutator {
   }
   // Allocate
   Stmt VisitStmt_(const AllocateNode* op) final {
-    if (op->new_expr.defined() && !vt_loop_injected_) {
-      return InjectVTLoop(GetRef<Stmt>(op), true);
-    }
     PrimExpr condition = this->VisitExpr(op->condition);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
@@ -419,8 +413,7 @@ class VTInjector : public StmtExprMutator {
     } else {
       return AllocateNode::make(
           op->buffer_var, op->dtype,
-          extents, condition, body,
-          op->new_expr, op->free_function);
+          extents, condition, body);
     }
   }
...
@@ -58,8 +58,7 @@ class AttrScopeLifter : public StmtMutator {
       attr_value_ = PrimExpr();
       return AllocateNode::make(
           op->buffer_var, op->dtype,
-          op->extents, op->condition, body,
-          op->new_expr, op->free_function);
+          op->extents, op->condition, body);
     } else {
       return stmt;
     }
...
@@ -79,11 +79,7 @@ class NoOpRemover : public StmtMutator {
     op = stmt.as<AllocateNode>();
     return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
   }
-  Stmt VisitStmt_(const ProducerConsumerNode* op) final {
-    Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<ProducerConsumerNode>();
-    return is_no_op(op->body) ? op->body : stmt;
-  }
   Stmt VisitStmt_(const RealizeNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<RealizeNode>();
...
@@ -158,7 +158,7 @@ class IRConvertSSA final : public StmtExprMutator {
       op = stmt.as<AllocateNode>();
       return AllocateNode::make(
           new_var, op->dtype, op->extents, op->condition,
-          op->body, op->new_expr, op->free_function);
+          op->body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
...
@@ -403,10 +403,6 @@ class Vectorizer : public StmtExprMutator {
   }
   // Allocate
   Stmt VisitStmt_(const AllocateNode* op) final {
-    if (op->new_expr.defined()) {
-      LOG(WARNING) << "Cannot vectorize with new expr";
-      return Scalarize(GetRef<Stmt>(op));
-    }
     PrimExpr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_vector()) {
       LOG(WARNING) << "Cannot handle vector extent in alloc ";
@@ -429,8 +425,7 @@ class Vectorizer : public StmtExprMutator {
     body = this->VisitStmt(body);
     return AllocateNode::make(
         op->buffer_var, op->dtype,
-        extents, condition, body,
-        op->new_expr, op->free_function);
+        extents, condition, body);
   }
   // scalarize the statment
   Stmt Scalarize(Stmt stmt) {
...
@@ -56,29 +56,6 @@ class GPUCodeVerifier : public StmtVisitor {
     return valid_;
   }
-  void VisitStmt_(const ProducerConsumerNode* op) final {
-    if (nest_level_ == 0) {
-      // enter a new kernel, reset statistics
-      Reset_();
-    }
-    if (op->is_producer) {
-      nest_level_++;
-      StmtVisitor::VisitStmt_(op);
-      nest_level_--;
-    } else {
-      StmtVisitor::VisitStmt_(op);
-    }
-    if (nest_level_ == 0) {
-      // exit a kernel, check the validity
-      valid_ &= thread_per_block_ <= max_threads_per_block_;
-      valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
-      valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
-    }
-  }
   void VisitStmt_(const AllocateNode* op) final {
     StmtVisitor::VisitStmt_(op);
     // visit an allocation of a buffer in shared memory, record its size
@@ -99,7 +76,13 @@ class GPUCodeVerifier : public StmtVisitor {
       } else if (op_value == "shared") {
         visited_shared_buffers_.insert(op->node.as<VarNode>());
       }
+      StmtVisitor::VisitStmt_(op);
     } else if (op->attr_key == attr::thread_extent) {
+      if (nest_level_ == 0) {
+        // enter a new kernel, reset statistics
+        Reset_();
+      }
       Var var = op->node.as<IterVarNode>()->var;
       const auto *extent = op->value.as<IntImmNode>();
       CHECK(extent);
@@ -133,8 +116,21 @@ class GPUCodeVerifier : public StmtVisitor {
         }
       }
     }
+      nest_level_++;
+      StmtVisitor::VisitStmt_(op);
+      nest_level_--;
+      if (nest_level_ == 0) {
+        // exit a kernel, check the validity
+        valid_ &= thread_per_block_ <= max_threads_per_block_;
+        valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
+        valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(op);
     }
-    StmtVisitor::VisitStmt_(op);
   }
  private:
...
@@ -79,9 +79,9 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     if (toBeLowered) {
       auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
-      return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents,
-                                allocate->condition, allocate->body, allocate->new_expr,
-                                allocate->free_function);
+      return AllocateNode::make(
+          allocate->buffer_var, new_allocate_type, allocate->extents,
+          allocate->condition, allocate->body);
     }
     return stmt;
   }
...
@@ -51,12 +51,13 @@ class StorageAccessInfoLower : public StmtExprMutator {
       ++it->second.alloc_count;
       CHECK_LE(it->second.alloc_count, 1)
           << "Double allocation of " << it->second.scope.to_string();
       if (info->head_address.defined()) {
-        return AllocateNode::make(
-            op->buffer_var, op->dtype, op->extents, op->condition,
-            op->body, info->head_address, "nop");
+        return LetStmtNode::make(
+            op->buffer_var, info->head_address, op->body);
+      } else {
+        return op->body;
       }
-      return op->body;
     } else {
       return stmt;
     }
@@ -110,10 +111,10 @@ class StorageAccessInfoLower : public StmtExprMutator {
   }
   PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
                                Var buffer_var,
                                DataType dtype,
                                PrimExpr offset,
                                const MemoryInfo& info) {
     if (ptr_type.is_handle()) {
       CHECK(info->head_address.defined())
           << buffer_var << " is not adddressable.";
...
@@ -93,7 +93,6 @@ class BuiltinLower : public StmtExprMutator {
     // Lower allocate to device allocate when needed.
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<AllocateNode>();
-    if (op->new_expr.defined()) return stmt;
     // Get constant allocation bound.
     int64_t dev_type;
     int64_t nbytes = GetVectorBytes(op->dtype);
...
@@ -49,8 +49,8 @@ def test_split_uneven_unique_likely():
     sch = te.create_schedule(c.op)
     xo, xi = sch[c].split(x, 5)
     stmt = tvm.lower(sch, [a, b, c], simple_mode=True)
-    assert isinstance(stmt.body.body.body.body, tvm.tir.stmt.IfThenElse)
-    assert str(stmt.body.body.body.body).count("likely") == 1
+    assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse)
+    assert str(stmt.body.body.body).count("likely") == 1
 if __name__ == "__main__":
     test_lower_rfactor()
...
@@ -366,7 +366,7 @@ def test_bind():
     c = foo(a)
     s = te.create_schedule(c.op)
     ir = tvm.lower(s, [a, c], simple_mode=True)
-    assert not isinstance(ir, tvm.tir.AttrStmt)
     func, ins, outs = run_and_check(foo, [a], target='cuda')
     run_and_check(func, ins, outs=outs, target='cuda')
@@ -731,8 +731,6 @@ def test_schedule():
     sch[c].vectorize(ji)
     sch[c].reorder(ii, io, joo, joi, ji)
     ir = tvm.lower(sch, [a, b, c], simple_mode=True)
-    assert isinstance(ir, tvm.tir.ProducerConsumer)
-    ir = ir.body
     assert isinstance(ir, tvm.tir.AttrStmt)
     ir = ir.body
     assert isinstance(ir, tvm.tir.For)
@@ -754,8 +752,6 @@ def test_schedule():
     sch = te.create_schedule(c.op)
     sch[c].fuse(c.op.axis[0], c.op.axis[1])
     ir = tvm.lower(sch, [a, b, c], simple_mode=True)
-    assert isinstance(ir, tvm.tir.ProducerConsumer)
-    ir = ir.body
     assert isinstance(ir, tvm.tir.AttrStmt)
     ir = ir.body
     assert isinstance(ir, tvm.tir.For)
...
@@ -284,10 +284,10 @@ def test_tensor_intrin_scalar_params():
     C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
     s = te.create_schedule(C.op)
     stmt = tvm.lower(s, [A, C], simple_mode=True)
-    assert isinstance(stmt.body.body.body, tvm.tir.Evaluate)
-    assert len(stmt.body.body.body.value.args) == 5
-    assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
-    assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
+    assert isinstance(stmt.body.body, tvm.tir.Evaluate)
+    assert len(stmt.body.body.value.args) == 5
+    assert str(stmt.body.body.value.args[3]) == "(i*i)"
+    assert str(stmt.body.body.value.args[4]) == "(i + j)"
 if __name__ == "__main__":
     test_singleton()
...
@@ -72,7 +72,6 @@ def test_schedule_scan():
     s = te.create_schedule(res.op)
     s = s.normalize()
     ir = tvm.lower(s, [s_state], simple_mode=True)
-    assert not hasattr(ir.body.body.body.body[1].body.body[1].body, "condition")
     bounds = tvm.te.schedule.InferBound(s)
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
@@ -557,7 +556,6 @@ def test_local_stage_predicate2():
         return ret
     def visit_stmt(op):
-        print(op)
         if (isinstance(op, tvm.tir.Allocate)):
             return op.extents[0].value == 97
         return False
@@ -593,4 +591,4 @@ if __name__ == "__main__":
     test_reduction_and_dummy_fuse_split()
     test_schedule_compute_inline()
     test_local_stage_predicate()
-    test_local_stage_predicate2()
\ No newline at end of file
+    test_local_stage_predicate2()
@@ -327,8 +327,8 @@ def test_tensorize_tensor_compute_op():
     stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
     # The loop that we tried to tensorize still exists in the code
     # That means tensorize didn't work as expected
-    assert isinstance(stmt.body.body.body, tvm.tir.For)
-    assert stmt.body.body.body.loop_var.name == C.op.axis[0].var.name
+    assert isinstance(stmt.body.body, tvm.tir.For)
+    assert stmt.body.body.loop_var.name == C.op.axis[0].var.name
...
@@ -129,7 +129,7 @@ def test_tensor_compute1():
     s = te.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C], simple_mode=True)
-    assert isinstance(stmt.body.body, tvm.tir.Evaluate)
+    assert isinstance(stmt.body, tvm.tir.Evaluate)
 def test_tensor_compute2():
     M = 2048
@@ -172,8 +172,8 @@ def test_tensor_compute2():
     s = te.create_schedule(C.op)
     stmt = tvm.lower(s, [A, B, C], simple_mode=True)
-    assert isinstance(stmt.body.body.body[0], tvm.tir.Evaluate)
-    assert isinstance(stmt.body.body.body[1].body, tvm.tir.Evaluate)
+    assert isinstance(stmt.body.body[0], tvm.tir.Evaluate)
+    assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate)
 def test_tensor_scan():
     m = te.size_var("m")
...
@@ -148,10 +148,6 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
-    x = tvm.tir.ProducerConsumer(None, True, nop)
-    assert isinstance(x, tvm.tir.ProducerConsumer)
-    assert x.body == nop
     x = tvm.tir.For(te.var("x"), 0, 10, 0, 0, nop)
     assert isinstance(x, tvm.tir.For)
     assert x.min.value == 0
...
@@ -23,14 +23,6 @@ def collect_visit(stmt, f):
     tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
     return ret
-def find_top_produce(stmt):
-    def f(x, ret):
-        if isinstance(x, tvm.tir.ProducerConsumer):
-            ret.append(x)
-    ret = []
-    tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : f(x, ret))
-    return ret[-1]
 def lower(sch, args):
     binds = {}
     arg_list = []
@@ -65,8 +57,8 @@ def test_basic():
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
     stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.tir.ir_pass.Simplify(stmt)
-    assert('if' not in str(stmt.body.body.body[0]))
-    assert('if' in str(stmt.body.body.body[1]))
+    assert('if' not in str(stmt.body.body[0]))
+    assert('if' in str(stmt.body.body[1]))
 def test_const_loop():
     n = 21
@@ -81,7 +73,7 @@ def test_const_loop():
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, True) stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
stmt = tvm.tir.ir_pass.Simplify(stmt) stmt = tvm.tir.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body[0])) assert('if' not in str(stmt.body.body[0]))
def test_multi_loop(): def test_multi_loop():
ib = tvm.tir.ir_builder.create() ib = tvm.tir.ir_builder.create()
...@@ -136,7 +128,7 @@ def test_thread_axis(): ...@@ -136,7 +128,7 @@ def test_thread_axis():
stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.Simplify(stmt) stmt = tvm.tir.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body[0])) assert('if' not in str(stmt.body.body[0]))
def test_vectorize(): def test_vectorize():
n = te.size_var('n') n = te.size_var('n')
...@@ -156,7 +148,7 @@ def test_vectorize(): ...@@ -156,7 +148,7 @@ def test_vectorize():
s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].vectorize(x) s[C].vectorize(x)
stmt = lower(s, [A, B]) stmt = lower(s, [A, B])
body = stmt.body.body.body.body.body body = stmt.body.body.body.body
assert(x.var.name not in str(body.condition)) assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp)))) assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))))
...@@ -199,7 +191,7 @@ def test_thread_axis2(): ...@@ -199,7 +191,7 @@ def test_thread_axis2():
s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x"))
stmt = lower(s, [A, B]) stmt = lower(s, [A, B])
for_body = stmt.body.body.body.body.body[0] for_body = stmt.body.body.body.body[0]
assert('threadIdx' not in str(for_body.extent)) assert('threadIdx' not in str(for_body.extent))
def test_everything_during_deduction(): def test_everything_during_deduction():
...@@ -405,9 +397,7 @@ def test_double_splitting_with_indivisible_factors(): ...@@ -405,9 +397,7 @@ def test_double_splitting_with_indivisible_factors():
f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False) f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
func = tvm.build(f, target=target) func = tvm.build(f, target=target)
# Find the beginning of the Halide IR corresponding to kernel code top_produce = f["fadd1"].body
# and make sure it doesn't have an if statements left
top_produce = find_top_produce(f["fadd1"].body)
assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse)))) assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
# check functional correctness of generated code # check functional correctness of generated code
......
...@@ -148,7 +148,7 @@ def test_reduce(): ...@@ -148,7 +148,7 @@ def test_reduce():
B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B') B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
s = te.create_schedule(B.op) s = te.create_schedule(B.op)
stmt = lower_sch(s, [A, B], target_bits) stmt = lower_sch(s, [A, B], target_bits)
assert stmt.body[1].loop_var.dtype == target_dtype assert stmt[1].loop_var.dtype == target_dtype
# i32 -> i32 # i32 -> i32
check(const(64, dtype='int32'), 32, 'int32') check(const(64, dtype='int32'), 32, 'int32')
......
...@@ -60,11 +60,8 @@ def fold_uop_loop(stmt_in): ...@@ -60,11 +60,8 @@ def fold_uop_loop(stmt_in):
def _fold_outermost_loop(body): def _fold_outermost_loop(body):
stmt = body stmt = body
while not isinstance(stmt, tvm.tir.For): if not isinstance(stmt, tvm.tir.For):
if isinstance(stmt, (tvm.tir.ProducerConsumer,)): return None, body, None
stmt = stmt.body
else:
return None, body, None
loop_var = stmt.loop_var loop_var = stmt.loop_var
gemm_offsets = [None, None, None] gemm_offsets = [None, None, None]
......