Unverified commit e8138f7d by Tianqi Chen, committed by GitHub

[TIR] Remove ProducerConsumer and AllocateNode::new_expr (#5333)

* [TIR] Remove ProducerConsumer and AllocateNode::new_expr

This PR removes two deprecated legacy constructs from TIR.

The ProducerConsumer node only serves as a hint markup and may no longer be
informative after the extensive transformations performed by the passes.
If necessary, the same information can be attached via an AttrStmt.
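As a minimal sketch of that AttrStmt alternative (the helper name and the
"pragma_produce" attribute key below are hypothetical, not existing tir::attr
constants), a pass that still wants a produce hint could wrap the body itself:

#include <tvm/tir/stmt.h>

namespace tvm {
namespace tir {

// Sketch only: annotate a producer body with an AttrStmt instead of the
// removed ProducerConsumerNode. The key "pragma_produce" is illustrative.
Stmt AnnotateProducer(ObjectRef op, Stmt producer_body) {
  return AttrStmtNode::make(op, "pragma_produce", 1, producer_body);
}

}  // namespace tir
}  // namespace tvm

Existing annotations such as the double_buffer_scope AttrStmt emitted in
MakePipeline are unaffected by the removal.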

The new_expr field in AllocateNode is deprecated, since it can simply be
replaced by a LetStmt that binds the buffer variable to the backing expression.
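For instance (a sketch only, with an illustrative helper name, mirroring the
StorageAccessInfoLower change further down), a buffer that used to be
materialized through Allocate's new_expr can instead be bound with a LetStmt:

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

namespace tvm {
namespace tir {

Stmt BindBufferAddress(Var buffer_var, PrimExpr head_address, Stmt body) {
  // Before: AllocateNode::make(buffer_var, dtype, extents, condition,
  //                            body, /*new_expr=*/head_address, "nop");
  // After: bind the buffer variable directly to the backing address.
  return LetStmtNode::make(buffer_var, head_address, body);
}

}  // namespace tir
}  // namespace tvm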

- Remove dependencies of passes on ProducerConsumer.
- Remove ProducerConsumer from the IR.
- Remove the deprecated fields (new_expr, free_function) from AllocateNode.

* Fix additional test cases
parent f1438813
......@@ -178,42 +178,6 @@ class AssertStmtNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};
// TODO(tvm-team): consider consolidating with AttrStmt.
/*! \brief annotation node of producer/consumer relation. */
class ProducerConsumerNode : public StmtNode {
public:
/*! \brief The corresponding tensor. */
FunctionRef func;
/*! \brief Whether the relation is producer. */
bool is_producer;
/*! \brief Body to be executed. */
Stmt body;
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("is_producer", &is_producer);
v->Visit("body", &body);
}
bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(is_producer, other->is_producer) &&
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(is_producer);
hash_reduce(body);
}
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode);
};
/*!
* \brief Store value to the buffer.
*
......@@ -385,10 +349,6 @@ class AllocateNode : public StmtNode {
PrimExpr condition;
/*! \brief The body to be executed. */
Stmt body;
// The following two fields are deprecated
// kept for backward compatibility and will be refactored later.
PrimExpr new_expr;
std::string free_function;
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
......@@ -419,9 +379,7 @@ class AllocateNode : public StmtNode {
DataType dtype,
Array<PrimExpr> extents,
PrimExpr condition,
Stmt body,
PrimExpr new_expr = PrimExpr(),
std::string free_function = std::string());
Stmt body);
/*!
* \brief If the buffer size is constant, return the size.
......@@ -589,8 +547,6 @@ class SeqStmt : public Stmt {
*
* - When an argument is nullptr, it will be ignored.
* - When an argument is an array or a SeqStmt, it will be flattened recursively.
* - When an argument is a consumer block in ProducerConsumer, the consumer
* tag will be dropped as such information is not useful in lowering.
* - A normal Stmt will be appended to the end of the sequence.
*
* \note This function can directly return an element
......@@ -618,13 +574,6 @@ class SeqStmt : public Stmt {
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
} else if (auto* op = stmt.as<ProducerConsumerNode>()) {
// NOTE: The consumer block annotation was not as useful and can be safely dropped.
if (!op->is_producer) {
operator()(0, op->body);
} else {
seq_->push_back(stmt);
}
} else {
seq_->push_back(stmt);
}
......
......@@ -94,7 +94,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
......@@ -117,7 +116,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(FreeNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode);
IR_STMT_FUNCTOR_DISPATCH(ProvideNode);
IR_STMT_FUNCTOR_DISPATCH(RealizeNode);
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
......@@ -158,7 +156,6 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
void VisitStmt_(const ProvideNode* op) override;
void VisitStmt_(const RealizeNode* op) override;
void VisitStmt_(const PrefetchNode* op) override;
......@@ -253,7 +250,6 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProducerConsumerNode* op) override;
Stmt VisitStmt_(const ProvideNode* op) override;
Stmt VisitStmt_(const RealizeNode* op) override;
Stmt VisitStmt_(const PrefetchNode* op) override;
......
......@@ -27,7 +27,7 @@ from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
......
......@@ -77,26 +77,6 @@ class AssertStmt(Stmt):
@tvm._ffi.register_object
class ProducerConsumer(Stmt):
"""ProducerConsumer node.
Parameters
----------
func : Operation
The Operation.
is_producer : bool
Whether the node is a producer.
body : Stmt
The body statement.
"""
def __init__(self, func, is_producer, body):
self.__init_handle_by_constructor__(
_ffi_api.ProducerConsumer, func, is_producer, body)
@tvm._ffi.register_object
class For(Stmt):
"""For node.
......@@ -425,6 +405,4 @@ def stmt_list(stmt):
for x in stmt:
res += stmt_list(x)
return res
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]
......@@ -399,10 +399,6 @@ void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
stream << str << "\n";
}
void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) {
PrintStmt(op->body);
}
void CodeGenHybrid::PrintIndent() {
stream << std::string(indent_, ' ');
}
......
......@@ -131,7 +131,6 @@ class CodeGenHybrid :
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
/*!
* \brief Print Type representation of type t.
* \param t The type representation.
......
......@@ -71,55 +71,53 @@ class CodeGenAMDGPU : public CodeGenLLVM {
void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr);
} else {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the AMD devices
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the AMD devices
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
alloca->setAlignment(llvm::Align(info.alignment));
alloca->setAlignment(llvm::Align(info.alignment));
#else
alloca->setAlignment(info.alignment);
alloca->setAlignment(info.alignment);
#endif
}
buf = alloca;
} else {
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
}
buf = alloca;
} else {
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(info.alignment));
global->setAlignment(llvm::Align(info.alignment));
#else
global->setAlignment(info.alignment);
global->setAlignment(info.alignment);
#endif
buf = global;
}
buf = global;
}
buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
......
......@@ -1268,10 +1268,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr);
} else {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation";
......@@ -1296,7 +1293,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
}
info.alignment = alloca->getAlignment();
buf = alloca;
}
buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
......@@ -1359,9 +1356,6 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
MakeValue(op->value);
}
void CodeGenLLVM::VisitStmt_(const ProducerConsumerNode* op) {
this->VisitStmt(op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......@@ -150,7 +150,6 @@ class CodeGenLLVM :
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
protected:
/*! \brief The storage information */
......
......@@ -48,55 +48,53 @@ class CodeGenNVPTX : public CodeGenLLVM {
void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr);
} else {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
alloca->setAlignment(llvm::Align(info.alignment));
alloca->setAlignment(llvm::Align(info.alignment));
#else
alloca->setAlignment(info.alignment);
alloca->setAlignment(info.alignment);
#endif
}
buf = alloca;
} else {
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
}
buf = alloca;
} else {
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(info.alignment));
global->setAlignment(llvm::Align(info.alignment));
#else
global->setAlignment(info.alignment);
global->setAlignment(info.alignment);
#endif
buf = global;
}
buf = global;
}
buf = builder_->CreatePointerCast(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
......
......@@ -814,14 +814,7 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) {
void CodeGenC::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
PrintType(op->dtype, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
......@@ -833,7 +826,7 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) {
PrintType(op->dtype, stream);
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
......@@ -942,10 +935,6 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) {
}
}
void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) {
PrintStmt(op->body);
}
void CodeGenC::PrintVecElemLoadExpr(
DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1);
......
......@@ -153,7 +153,6 @@ class CodeGenC :
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
/*!
* Print Type representation of type t.
* \param t The type representation.
......
......@@ -514,51 +514,44 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
PrintType(op->dtype, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8) ||
op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Float(32) ||
op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->dtype, buffer, stream);
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8) ||
op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream);
}
if ((op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) && scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Float(32) ||
op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream);
}
if ((op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) && scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
......
......@@ -586,7 +586,6 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
CHECK(!op->new_expr.defined());
CHECK(!op->dtype.is_handle());
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
......@@ -659,9 +658,5 @@ void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) {
MakeValue(op->value);
}
void CodeGenSPIRV::VisitStmt_(const ProducerConsumerNode* op) {
this->VisitStmt(op->body);
}
} // namespace codegen
} // namespace tvm
......@@ -99,7 +99,6 @@ class CodeGenSPIRV:
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
protected:
/*! \brief The storage information */
......
......@@ -156,16 +156,7 @@ void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
}
void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
int vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
this->Push(op->new_expr);
this->PushOp(StackVM::STORE_HEAP, vid);
} else {
LOG(FATAL) << "Dynamic allocation not supported";
}
LOG(FATAL) << "Dynamic allocation not supported";
}
void CodeGenStackVM::VisitExpr_(const CallNode* op) {
......@@ -410,10 +401,6 @@ void CodeGenStackVM::VisitExpr_(const NotNode* op) {
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitStmt_(const ProducerConsumerNode* op) {
this->Push(op->body);
}
void CodeGenStackVM::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min));
int vid = this->AllocVarID(op->loop_var.get());
......
......@@ -147,7 +147,6 @@ class CodeGenStackVM
void VisitStmt_(const AssertStmtNode* op) final;
void VisitStmt_(const EvaluateNode* op) final;
void VisitStmt_(const SeqStmtNode* op) final;
void VisitStmt_(const ProducerConsumerNode* op) final;
private:
bool debug_{false};
......
......@@ -43,9 +43,6 @@ Stmt MakePipeline(const Stage& s,
Stmt consumer,
bool debug_keep_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
if (producer.defined()) {
producer = ProducerConsumerNode::make(s->op, true, producer);
}
if (s->double_buffer) {
producer = AttrStmtNode::make(
s->op, tir::attr::double_buffer_scope, 1, producer);
......@@ -53,7 +50,6 @@ Stmt MakePipeline(const Stage& s,
Stmt pipeline = producer;
if (consumer.defined() && !is_no_op(consumer)) {
consumer = ProducerConsumerNode::make(s->op, false, consumer);
pipeline = SeqStmt({producer, consumer});
}
pipeline = s->op->BuildRealize(s, dom_map, pipeline);
......@@ -163,20 +159,6 @@ class InjectScanStep : public StmtMutator {
// Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public StmtExprMutator {
public:
Stmt VisitStmt_(const ProducerConsumerNode* op) final {
auto it = replace_op_.find(op->func.get());
if (it != replace_op_.end()) {
Stmt body = this->VisitStmt(op->body);
if (it->second.defined()) {
return ProducerConsumerNode::make(
it->second, op->is_producer, body);
} else {
return body;
}
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const LetStmtNode* op) final {
if (!HasSideEffect(op->value)) {
var_value_[op->var.get()] = this->VisitExpr(op->value);
......
......@@ -39,8 +39,8 @@ namespace {
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal.
*
* This pass performs such verification by checking if all Producer/Consumer
* with memory accesses are bound with threads when device type is GPU.
* This pass performs such verification by checking if all
* memory accesses are bound with threads when device type is GPU.
*/
class MemoryAccessVerifier final : protected StmtExprVisitor {
public:
......@@ -94,12 +94,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
}
}
void VisitStmt_(const ProducerConsumerNode* op) final {
EnterProducerConsumer(op);
StmtExprVisitor::VisitStmt_(op);
ExitProducerConsumer();
}
void VisitExpr_(const LoadNode* op) final {
HandleLoadStoreToVariable(op->buffer_var);
return StmtExprVisitor::VisitExpr_(op);
......@@ -138,11 +132,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
// We skip the access within thread env.
if (InThreadEnv()) return;
// We only check access within a producer/consumer.
// Because for load/store out side of producer/consumer,
// they don't have to be in thread env to stay legal (e.g. Load of args).
if (!InProducerConsumer()) return;
// We only handle the variable from function argument.
// If it does not come from args, then it could be allocated internally,
// it may possibly be in host or device address space.
......@@ -158,10 +147,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
bool InThreadEnv() const { return in_thread_env_; }
void EnterThreadEnv() { in_thread_env_ = true; }
void ExitThreadEnv() { in_thread_env_ = false; }
bool InProducerConsumer() const { return pc_ != nullptr; }
const ProducerConsumerNode *GetCurrentProducerConsumer() const { return pc_; }
void EnterProducerConsumer(const ProducerConsumerNode *pc) { this->pc_ = pc; }
void ExitProducerConsumer() { pc_ = nullptr; }
void SetFailure() { failure_ = true; }
//@}
......@@ -180,7 +165,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
/// Status of visitor
//@{
bool in_thread_env_{false};
const ProducerConsumerNode *pc_{nullptr};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
//@}
tir::PrimFunc func_{nullptr}; ///< Function to be verified.
......@@ -197,13 +181,18 @@ void VerifyMemory(const IRModule& mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute";
MemoryAccessVerifier v(func, target.value()->device_type);
v.Run();
if (v.Failed()) {
LOG(FATAL)
if (func->GetAttr<Integer>(
tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) == CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->device_type);
v.Run();
if (v.Failed()) {
LOG(FATAL)
<< "ValueError: Direct host side access to device memory is detected."
<< " Did you forget to bind?\n"
<< func;
}
}
}
}
......
......@@ -82,20 +82,6 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt")
}
});
Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
ObjectPtr<ProducerConsumerNode> node = make_object<ProducerConsumerNode>();
node->func = std::move(func);
node->is_producer = is_producer;
node->body = std::move(body);
return Stmt(node);
}
TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
.set_body_typed(ProducerConsumerNode::make);
Stmt ForNode::make(Var loop_var,
PrimExpr min,
PrimExpr extent,
......@@ -184,9 +170,7 @@ Stmt AllocateNode::make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
PrimExpr condition,
Stmt body,
PrimExpr new_expr,
std::string free_function) {
Stmt body) {
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].dtype().is_scalar());
......@@ -201,8 +185,6 @@ Stmt AllocateNode::make(Var buffer_var,
node->extents = std::move(extents);
node->condition = std::move(condition);
node->body = std::move(body);
node->new_expr = std::move(new_expr);
node->free_function = std::move(free_function);
return Stmt(node);
}
......@@ -381,22 +363,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProducerConsumerNode*>(node.get());
if (op->is_producer) {
p->PrintIndent();
p->stream << "produce " << op->func->func_name() << " {\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
} else {
p->Print(op->body);
}
});
std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
switch (type) {
case ForType::Serial:
......@@ -615,7 +581,6 @@ TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_NODE_TYPE(LetStmtNode);
TVM_REGISTER_NODE_TYPE(AssertStmtNode);
TVM_REGISTER_NODE_TYPE(ProducerConsumerNode);
TVM_REGISTER_NODE_TYPE(ForNode);
TVM_REGISTER_NODE_TYPE(StoreNode);
TVM_REGISTER_NODE_TYPE(ProvideNode);
......
......@@ -149,9 +149,6 @@ void StmtVisitor::VisitStmt_(const AllocateNode* op) {
VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitStmt(op->body);
this->VisitExpr(op->condition);
if (op->new_expr.defined()) {
this->VisitExpr(op->new_expr);
}
}
void StmtVisitor::VisitStmt_(const StoreNode* op) {
......@@ -180,10 +177,6 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) {
this->VisitStmt(op->body);
}
void StmtVisitor::VisitStmt_(const ProvideNode* op) {
VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitExpr(op->value);
......@@ -291,21 +284,16 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
Stmt body = this->VisitStmt(op->body);
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr new_expr;
if (op->new_expr.defined()) {
new_expr = this->VisitExpr(op->new_expr);
}
if (extents.same_as(op->extents) &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
condition.same_as(op->condition)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->extents = std::move(extents);
n->body = std::move(body);
n->condition = std::move(condition);
n->new_expr = std::move(new_expr);
return Stmt(n);
}
}
......@@ -475,17 +463,6 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
}
}
Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) {
Stmt body = this->VisitStmt(op->body);
if (body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
......
......@@ -129,9 +129,6 @@ class VarTouchedAnalysis : public StmtVisitor {
tc(op->extents[i]);
}
tc.VisitExpr(op->condition);
if (op->new_expr.defined()) {
tc(op->new_expr);
}
Record(op->buffer_var.get(), tc);
this->VisitStmt(op->body);
}
......@@ -371,9 +368,6 @@ class VTInjector : public StmtExprMutator {
}
// Allocate
Stmt VisitStmt_(const AllocateNode* op) final {
if (op->new_expr.defined() && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
PrimExpr condition = this->VisitExpr(op->condition);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
......@@ -419,8 +413,7 @@ class VTInjector : public StmtExprMutator {
} else {
return AllocateNode::make(
op->buffer_var, op->dtype,
extents, condition, body,
op->new_expr, op->free_function);
extents, condition, body);
}
}
......
......@@ -58,8 +58,7 @@ class AttrScopeLifter : public StmtMutator {
attr_value_ = PrimExpr();
return AllocateNode::make(
op->buffer_var, op->dtype,
op->extents, op->condition, body,
op->new_expr, op->free_function);
op->extents, op->condition, body);
} else {
return stmt;
}
......
......@@ -79,11 +79,7 @@ class NoOpRemover : public StmtMutator {
op = stmt.as<AllocateNode>();
return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
}
Stmt VisitStmt_(const ProducerConsumerNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ProducerConsumerNode>();
return is_no_op(op->body) ? op->body : stmt;
}
Stmt VisitStmt_(const RealizeNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<RealizeNode>();
......
......@@ -158,7 +158,7 @@ class IRConvertSSA final : public StmtExprMutator {
op = stmt.as<AllocateNode>();
return AllocateNode::make(
new_var, op->dtype, op->extents, op->condition,
op->body, op->new_expr, op->free_function);
op->body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
......
......@@ -403,10 +403,6 @@ class Vectorizer : public StmtExprMutator {
}
// Allocate
Stmt VisitStmt_(const AllocateNode* op) final {
if (op->new_expr.defined()) {
LOG(WARNING) << "Cannot vectorize with new expr";
return Scalarize(GetRef<Stmt>(op));
}
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc ";
......@@ -429,8 +425,7 @@ class Vectorizer : public StmtExprMutator {
body = this->VisitStmt(body);
return AllocateNode::make(
op->buffer_var, op->dtype,
extents, condition, body,
op->new_expr, op->free_function);
extents, condition, body);
}
// scalarize the statment
Stmt Scalarize(Stmt stmt) {
......
......@@ -56,29 +56,6 @@ class GPUCodeVerifier : public StmtVisitor {
return valid_;
}
void VisitStmt_(const ProducerConsumerNode* op) final {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
}
if (op->is_producer) {
nest_level_++;
StmtVisitor::VisitStmt_(op);
nest_level_--;
} else {
StmtVisitor::VisitStmt_(op);
}
if (nest_level_ == 0) {
// exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_threads_per_block_;
valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
}
}
void VisitStmt_(const AllocateNode* op) final {
StmtVisitor::VisitStmt_(op);
// visit an allocation of a buffer in shared memory, record its size
......@@ -99,7 +76,13 @@ class GPUCodeVerifier : public StmtVisitor {
} else if (op_value == "shared") {
visited_shared_buffers_.insert(op->node.as<VarNode>());
}
StmtVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::thread_extent) {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
}
Var var = op->node.as<IterVarNode>()->var;
const auto *extent = op->value.as<IntImmNode>();
CHECK(extent);
......@@ -133,8 +116,21 @@ class GPUCodeVerifier : public StmtVisitor {
}
}
}
nest_level_++;
StmtVisitor::VisitStmt_(op);
nest_level_--;
if (nest_level_ == 0) {
// exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_threads_per_block_;
valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
}
} else {
StmtVisitor::VisitStmt_(op);
}
StmtVisitor::VisitStmt_(op);
}
private:
......
......@@ -79,9 +79,9 @@ class CustomDatatypesLowerer : public StmtExprMutator {
if (toBeLowered) {
auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents,
allocate->condition, allocate->body, allocate->new_expr,
allocate->free_function);
return AllocateNode::make(
allocate->buffer_var, new_allocate_type, allocate->extents,
allocate->condition, allocate->body);
}
return stmt;
}
......
......@@ -51,12 +51,13 @@ class StorageAccessInfoLower : public StmtExprMutator {
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
return AllocateNode::make(
op->buffer_var, op->dtype, op->extents, op->condition,
op->body, info->head_address, "nop");
return LetStmtNode::make(
op->buffer_var, info->head_address, op->body);
} else {
return op->body;
}
return op->body;
} else {
return stmt;
}
......@@ -110,10 +111,10 @@ class StorageAccessInfoLower : public StmtExprMutator {
}
PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
Var buffer_var,
DataType dtype,
PrimExpr offset,
const MemoryInfo& info) {
Var buffer_var,
DataType dtype,
PrimExpr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
CHECK(info->head_address.defined())
<< buffer_var << " is not addressable.";
......
......@@ -93,7 +93,6 @@ class BuiltinLower : public StmtExprMutator {
// Lower allocate to device allocate when needed.
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
if (op->new_expr.defined()) return stmt;
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->dtype);
......
......@@ -49,8 +49,8 @@ def test_split_uneven_unique_likely():
sch = te.create_schedule(c.op)
xo, xi = sch[c].split(x, 5)
stmt = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(stmt.body.body.body.body, tvm.tir.stmt.IfThenElse)
assert str(stmt.body.body.body.body).count("likely") == 1
assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse)
assert str(stmt.body.body.body).count("likely") == 1
if __name__ == "__main__":
test_lower_rfactor()
......
......@@ -366,7 +366,7 @@ def test_bind():
c = foo(a)
s = te.create_schedule(c.op)
ir = tvm.lower(s, [a, c], simple_mode=True)
assert not isinstance(ir, tvm.tir.AttrStmt)
func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
......@@ -731,8 +731,6 @@ def test_schedule():
sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.tir.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
......@@ -754,8 +752,6 @@ def test_schedule():
sch = te.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1])
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.tir.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
......
......@@ -284,10 +284,10 @@ def test_tensor_intrin_scalar_params():
C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, C], simple_mode=True)
assert isinstance(stmt.body.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.body.value.args) == 5
assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
assert isinstance(stmt.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.value.args) == 5
assert str(stmt.body.body.value.args[3]) == "(i*i)"
assert str(stmt.body.body.value.args[4]) == "(i + j)"
if __name__ == "__main__":
test_singleton()
......
......@@ -72,7 +72,6 @@ def test_schedule_scan():
s = te.create_schedule(res.op)
s = s.normalize()
ir = tvm.lower(s, [s_state], simple_mode=True)
assert not hasattr(ir.body.body.body.body[1].body.body[1].body, "condition")
bounds = tvm.te.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
......@@ -557,7 +556,6 @@ def test_local_stage_predicate2():
return ret
def visit_stmt(op):
print(op)
if (isinstance(op, tvm.tir.Allocate)):
return op.extents[0].value == 97
return False
......@@ -593,4 +591,4 @@ if __name__ == "__main__":
test_reduction_and_dummy_fuse_split()
test_schedule_compute_inline()
test_local_stage_predicate()
test_local_stage_predicate2()
\ No newline at end of file
test_local_stage_predicate2()
......@@ -327,8 +327,8 @@ def test_tensorize_tensor_compute_op():
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
# The loop that we tried to tensorize still exists in the code
# That means tensorize didn't work as expected
assert isinstance(stmt.body.body.body, tvm.tir.For)
assert stmt.body.body.body.loop_var.name == C.op.axis[0].var.name
assert isinstance(stmt.body.body, tvm.tir.For)
assert stmt.body.body.loop_var.name == C.op.axis[0].var.name
......
......@@ -129,7 +129,7 @@ def test_tensor_compute1():
s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body, tvm.tir.Evaluate)
assert isinstance(stmt.body, tvm.tir.Evaluate)
def test_tensor_compute2():
M = 2048
......@@ -172,8 +172,8 @@ def test_tensor_compute2():
s = te.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body.body[1].body, tvm.tir.Evaluate)
assert isinstance(stmt.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate)
def test_tensor_scan():
m = te.size_var("m")
......
......@@ -148,10 +148,6 @@ def test_stmt_constructor():
assert isinstance(x, tvm.tir.AssertStmt)
assert x.body == nop
x = tvm.tir.ProducerConsumer(None, True, nop)
assert isinstance(x, tvm.tir.ProducerConsumer)
assert x.body == nop
x = tvm.tir.For(te.var("x"), 0, 10, 0, 0, nop)
assert isinstance(x, tvm.tir.For)
assert x.min.value == 0
......
......@@ -23,14 +23,6 @@ def collect_visit(stmt, f):
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
return ret
def find_top_produce(stmt):
def f(x, ret):
if isinstance(x, tvm.tir.ProducerConsumer):
ret.append(x)
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : f(x, ret))
return ret[-1]
def lower(sch, args):
binds = {}
arg_list = []
......@@ -65,8 +57,8 @@ def test_basic():
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body[0]))
assert('if' in str(stmt.body.body.body[1]))
assert('if' not in str(stmt.body.body[0]))
assert('if' in str(stmt.body.body[1]))
def test_const_loop():
n = 21
......@@ -81,7 +73,7 @@ def test_const_loop():
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body[0]))
assert('if' not in str(stmt.body.body[0]))
def test_multi_loop():
ib = tvm.tir.ir_builder.create()
......@@ -136,7 +128,7 @@ def test_thread_axis():
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body[0]))
assert('if' not in str(stmt.body.body[0]))
def test_vectorize():
n = te.size_var('n')
......@@ -156,7 +148,7 @@ def test_vectorize():
s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].vectorize(x)
stmt = lower(s, [A, B])
body = stmt.body.body.body.body.body
body = stmt.body.body.body.body
assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))))
......@@ -199,7 +191,7 @@ def test_thread_axis2():
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
stmt = lower(s, [A, B])
for_body = stmt.body.body.body.body.body[0]
for_body = stmt.body.body.body.body[0]
assert('threadIdx' not in str(for_body.extent))
def test_everything_during_deduction():
......@@ -405,9 +397,7 @@ def test_double_splitting_with_indivisible_factors():
f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False)
func = tvm.build(f, target=target)
# Find the beginning of the Halide IR corresponding to kernel code
# and make sure it doesn't have any if statements left
top_produce = find_top_produce(f["fadd1"].body)
top_produce = f["fadd1"].body
assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
# check functional correctness of generated code
......
......@@ -148,7 +148,7 @@ def test_reduce():
B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
s = te.create_schedule(B.op)
stmt = lower_sch(s, [A, B], target_bits)
assert stmt.body[1].loop_var.dtype == target_dtype
assert stmt[1].loop_var.dtype == target_dtype
# i32 -> i32
check(const(64, dtype='int32'), 32, 'int32')
......
......@@ -60,11 +60,8 @@ def fold_uop_loop(stmt_in):
def _fold_outermost_loop(body):
stmt = body
while not isinstance(stmt, tvm.tir.For):
if isinstance(stmt, (tvm.tir.ProducerConsumer,)):
stmt = stmt.body
else:
return None, body, None
if not isinstance(stmt, tvm.tir.For):
return None, body, None
loop_var = stmt.loop_var
gemm_offsets = [None, None, None]
......