Unverified Commit 3595cbe0 by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Introduce SeqStmt to replace ir::Block (#4627)

* [REFACTOR][IR] Introduce SeqStmt to replace Block

ir::Block was used to represent a sequence of Stmts in the original low-level IR.
The nested ir::Block structure is not really friendly for recursive visits,
especially when the statements are unrolled.

This PR introduces a SeqStmt that directly stores a sequence of statements in an Array container.
The new SeqStmt will be used as a replacement of the original Block structure.

* [REFACTOR] Migrate use of Block to SeqStmt.

* [REFACTOR] Remove Block

* Add more comments per yizhi's comment
parent 34b98eb7
......@@ -667,6 +667,9 @@ inline bool is_no_op(const Stmt& stmt) {
if (const auto* op = stmt.as<ir::Evaluate>()) {
return is_const(op->value);
}
if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
return op->seq.size() == 0;
}
return false;
}
......
......@@ -1022,25 +1022,112 @@ class Realize : public StmtNode {
};
/*!
 * \brief The container of seq statement.
 *        Represent a sequence of statements.
 */
class SeqStmtNode : public StmtNode {
 public:
  /*! \brief internal sequence content. */
  Array<Stmt> seq;

  /*! \return get the size of the sequence */
  size_t size() const {
    return seq.size();
  }
  /*!
   * \brief Get the index-th element in the sequence.
   * \param index The index of the element to fetch.
   * \return The index-th statement.
   */
  Stmt operator[](size_t index) const {
    return seq[index];
  }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("seq", &seq);
  }

  static constexpr const char* _type_key = "SeqStmt";
  TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};

/*! \brief Sequence statement. */
class SeqStmt : public Stmt {
 public:
  /*!
   * \brief Construct SeqStmt.
   * \param seq The sequence.
   */
  TVM_DLL explicit SeqStmt(Array<Stmt> seq);

  /*! \return get the size of the sequence */
  size_t size() const {
    return operator->()->size();
  }
  /*!
   * \brief Get the index-th element in the sequence.
   * \param index The index of the element to fetch.
   * \return The index-th statement.
   */
  Stmt operator[](size_t index) const {
    return (*(operator->()))[index];
  }
  /*!
   * \brief Construct a sequence statement by flattening
   *        all the arrays and sequences in the arguments
   *        recursively.
   *
   * - When an argument is nullptr, it will be ignored.
   * - When an argument is an array or a SeqStmt, it will be flattened recursively.
   * - When an argument is a consumer block in ProducerConsumer, the consumer
   *   tag will be dropped as such information is not useful in lowering.
   * - A normal Stmt will be appended to the end of the sequence.
   *
   * \note This function can directly return an element
   *       if it is the only element in the sequence.
   *
   * \param seq_args The list of arguments to be flattened.
   * \tparam Args arguments
   * \return The constructed statement
   */
  template<typename ...Args>
  static Stmt Flatten(Args&&... seq_args) {
    Array<Stmt> seq;
    runtime::detail::for_each(
        Flattener(&seq), std::forward<Args>(seq_args)...);
    // Collapse a singleton sequence to its only element.
    if (seq.size() == 1) return seq[0];
    return SeqStmt(seq);
  }
  /*! \brief Helper class to flatten sequence of arguments into Array. */
  class Flattener {
   public:
    explicit Flattener(Array<Stmt>* seq)
        : seq_(seq) {}

    void operator()(size_t i, const Stmt& stmt) const {
      if (!stmt.defined()) return;
      if (auto* op = stmt.as<SeqStmtNode>()) {
        operator()(0, op->seq);
      } else if (auto* op = stmt.as<ProducerConsumer>()) {
        // NOTE: The consumer block annotation was not as useful and can be safely dropped.
        if (!op->is_producer) {
          operator()(0, op->body);
        } else {
          seq_->push_back(stmt);
        }
      } else {
        seq_->push_back(stmt);
      }
    }

    template<typename T>
    void operator()(size_t i, const T& seq) const {
      for (auto v : seq) {
        this->operator()(0, v);
      }
    }

   private:
    // Non-owning pointer to the output array being filled.
    Array<Stmt>* seq_;
  };

  TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
};
/*!
......
......@@ -253,7 +253,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
......@@ -276,7 +276,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
}
......@@ -408,7 +408,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const Prefetch* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
};
......@@ -502,8 +502,23 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const Provide* op) override;
Stmt VisitStmt_(const Realize* op) override;
Stmt VisitStmt_(const Prefetch* op) override;
Stmt VisitStmt_(const Block* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const Evaluate* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
* This function can be called when a child class override
* VisitStmt_(const SeqStmtNode*) to introduce
* the special behavior to visit
*
* \param op The sequence.
* \param flatten_before_visit Whether to flatten the sequence before visit.
* \param fmutate The mutate function, can be nullptr, which defaults to Visit.
* \return The mutated result.
*/
Stmt VisitSeqStmt_(const SeqStmtNode* op,
bool flatten_before_visit,
std::function<Stmt(const Stmt&)> fmutate = nullptr);
// internal helper.
class Internal;
};
......
......@@ -272,6 +272,14 @@ class Array : public ObjectRef {
n->data.push_back(item);
}
/*!
* \brief Resize the array.
* \param size The new size.
*/
inline void resize(size_t size) {
  // CopyOnWrite guarantees we mutate a uniquely-owned copy of the backing store.
  this->CopyOnWrite()->data.resize(size);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param value The value to be setted.
......
......@@ -37,6 +37,8 @@ from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import make as _make
from .. import stmt as _stmt
from .. import api as _api
from .. import ir_pass as _ir_pass
......@@ -48,11 +50,7 @@ def concat_list_to_block(lst):
n = len(lst)
if n == 1:
return lst[0]
body = lst[n - 1]
for i in range(1, n):
stmt = lst[n - 1 - i]
body = _make.Block(stmt, body)
return body
return _stmt.SeqStmt(lst)
def visit_list_to_block(visit, lst):
......
......@@ -120,14 +120,16 @@ class IRBuilder(object):
seq = self._seq_stack.pop()
if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0))
stmt = seq[-1]
seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
ret_seq = [seq[-1]]
for s in reversed(seq[:-1]):
if callable(s):
stmt = s(stmt)
ret_seq = [s(seqwrap(ret_seq))]
else:
assert isinstance(s, _stmt.Stmt)
stmt = _make.Block(s, stmt)
return stmt
ret_seq.append(s)
return seqwrap(ret_seq)
def emit(self, stmt):
"""Emit a statement to the end of current scope.
......
......@@ -289,20 +289,23 @@ class Realize(Stmt):
@register_node
class SeqStmt(Stmt):
    """Sequence of statements.

    Parameters
    ----------
    seq : List[Stmt]
        The statements
    """
    def __init__(self, seq):
        self.__init_handle_by_constructor__(
            _make.SeqStmt, seq)

    def __getitem__(self, i):
        # Index into the underlying statement array.
        return self.seq[i]

    def __len__(self):
        # Number of statements in the sequence.
        return len(self.seq)
@register_node
......@@ -375,12 +378,14 @@ def stmt_seq(*args):
stmt : Stmt
The combined statement.
"""
ret = None
ret = []
for value in args:
if not isinstance(value, Stmt):
value = Evaluate(value)
ret = value if ret is None else Block(ret, value)
return ret if ret else Evaluate(0)
ret.append(value)
if len(ret) == 1:
return ret[0]
return SeqStmt(ret)
def stmt_list(stmt):
......@@ -395,12 +400,14 @@ def stmt_list(stmt):
stmt_list : list of Stmt
The unpacked list of statements
"""
if isinstance(stmt, Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest)
if isinstance(stmt, SeqStmt):
res = []
for x in stmt:
res += stmt_list(x)
return res
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]
_make.stmt_list = stmt_list
_make.stmt_seq = stmt_seq
......@@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("make._cast")
TVM_REGISTER_GLOBAL("make._range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);
// Expose SeqStmt construction to the frontend (reached as _make.SeqStmt from Python).
TVM_REGISTER_GLOBAL("make.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
  return SeqStmt(std::move(seq));
});
TVM_REGISTER_GLOBAL("make.For")
.set_body_typed([](
VarExpr loop_var, Expr min, Expr extent,
......@@ -163,9 +169,6 @@ REGISTER_MAKE(IfThenElse);
REGISTER_MAKE(Evaluate);
// overloaded, needs special handling
TVM_REGISTER_GLOBAL("make.Block")
.set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));
// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed([](
......
......@@ -405,22 +405,22 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
}
}
void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenC::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenC::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenC::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*)
os << "\"" << op->value << "\"";
}
template<typename T>
inline void PrintBinaryExpr(const T* op,
const char *opstr,
const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
if (op->dtype.lanes() == 1) {
......@@ -443,7 +443,7 @@ inline void PrintBinaryExpr(const T* op,
}
inline void PrintBinaryIntrinsic(const Call* op,
const char *opstr,
const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
if (op->dtype.lanes() == 1) {
......@@ -457,65 +457,65 @@ inline void PrintBinaryIntrinsic(const Call* op,
p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os);
}
}
void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*)
std::stringstream value;
this->PrintExpr(op->value, value);
os << CastFromTo(value.str(), op->value.dtype(), op->dtype);
}
void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
}
void CodeGenC::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
}
void CodeGenC::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
}
void CodeGenC::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
void CodeGenC::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenC::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
void CodeGenC::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
}
void CodeGenC::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, this);
}
void CodeGenC::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
}
void CodeGenC::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
}
void CodeGenC::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
}
void CodeGenC::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
}
void CodeGenC::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
}
void CodeGenC::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
}
void CodeGenC::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
}
void CodeGenC::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
}
void CodeGenC::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*)
os << '!';
PrintExpr(op->a, os);
}
void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
if (op->call_type == Call::Extern ||
op->call_type == Call::PureExtern) {
os << op->name << "(";
......@@ -875,12 +875,13 @@ void CodeGenC::VisitStmt_(const IfThenElse* op) {
stream << "}\n";
}
void CodeGenC::VisitStmt_(const Block *op) {
PrintStmt(op->first);
if (op->rest.defined()) PrintStmt(op->rest);
void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
PrintStmt(stmt);
}
}
void CodeGenC::VisitStmt_(const Evaluate *op) {
void CodeGenC::VisitStmt_(const Evaluate* op) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
if (call) {
......@@ -906,7 +907,7 @@ void CodeGenC::VisitStmt_(const Evaluate *op) {
}
}
void CodeGenC::VisitStmt_(const ProducerConsumer *op) {
void CodeGenC::VisitStmt_(const ProducerConsumer* op) {
PrintStmt(op->body);
}
......
......@@ -140,7 +140,7 @@ class CodeGenC :
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*!
* Print Type represetnation of type t.
......
......@@ -1214,10 +1214,9 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
this->VisitStmt(op->body);
}
void CodeGenLLVM::VisitStmt_(const Block* op) {
this->VisitStmt(op->first);
if (op->rest.defined()) {
this->VisitStmt(op->rest);
void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
this->VisitStmt(stmt);
}
}
......
......@@ -140,7 +140,7 @@ class CodeGenLLVM :
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
......
......@@ -638,10 +638,9 @@ void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
this->VisitStmt(op->body);
}
void CodeGenSPIRV::VisitStmt_(const Block* op) {
VisitStmt(op->first);
if (op->rest.defined()) {
this->VisitStmt(op->rest);
void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
this->VisitStmt(stmt);
}
}
......
......@@ -98,7 +98,7 @@ class CodeGenSPIRV:
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
......
......@@ -268,59 +268,59 @@ void CodeGenStackVM::PushCast(DataType dst, DataType src) {
}
}
void CodeGenStackVM::VisitExpr_(const StringImm *op) {
void CodeGenStackVM::VisitExpr_(const StringImm* op) {
int sid = this->GetStrID(op->value);
this->PushOp(StackVM::PUSH_I64, sid);
}
void CodeGenStackVM::VisitExpr_(const IntImm *op) {
void CodeGenStackVM::VisitExpr_(const IntImm* op) {
CHECK(op->value >= std::numeric_limits<int>::min() &&
op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
void CodeGenStackVM::VisitExpr_(const UIntImm *op) {
void CodeGenStackVM::VisitExpr_(const UIntImm* op) {
CHECK(op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
void CodeGenStackVM::VisitExpr_(const FloatImm *op) {
void CodeGenStackVM::VisitExpr_(const FloatImm* op) {
LOG(FATAL) << "Float Imm is not supported";
}
void CodeGenStackVM::VisitExpr_(const Variable *op) {
void CodeGenStackVM::VisitExpr_(const Variable* op) {
int vid = this->GetVarID(op);
this->PushOp(StackVM::LOAD_HEAP, vid);
}
void CodeGenStackVM::VisitExpr_(const Cast *op) {
void CodeGenStackVM::VisitExpr_(const Cast* op) {
this->Push(op->value);
PushCast(op->dtype, op->value.dtype());
}
void CodeGenStackVM::VisitExpr_(const Add *op) {
void CodeGenStackVM::VisitExpr_(const Add* op) {
PushBinary(StackVM::ADD_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Sub *op) {
void CodeGenStackVM::VisitExpr_(const Sub* op) {
PushBinary(StackVM::SUB_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Mul *op) {
void CodeGenStackVM::VisitExpr_(const Mul* op) {
PushBinary(StackVM::MUL_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Div *op) {
void CodeGenStackVM::VisitExpr_(const Div* op) {
PushBinary(StackVM::DIV_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Mod *op) {
void CodeGenStackVM::VisitExpr_(const Mod* op) {
PushBinary(StackVM::MOD_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const Min *op) {
void CodeGenStackVM::VisitExpr_(const Min* op) {
this->Push(op->a);
this->Push(op->b);
this->PushOp(StackVM::PUSH_VALUE, -1);
......@@ -329,7 +329,7 @@ void CodeGenStackVM::VisitExpr_(const Min *op) {
this->PushOp(StackVM::SELECT);
}
void CodeGenStackVM::VisitExpr_(const Max *op) {
void CodeGenStackVM::VisitExpr_(const Max* op) {
this->Push(op->a);
this->Push(op->b);
this->PushOp(StackVM::PUSH_VALUE, 0);
......@@ -338,34 +338,34 @@ void CodeGenStackVM::VisitExpr_(const Max *op) {
this->PushOp(StackVM::SELECT);
}
void CodeGenStackVM::VisitExpr_(const EQ *op) {
void CodeGenStackVM::VisitExpr_(const EQ* op) {
PushBinary(StackVM::EQ_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const LE *op) {
void CodeGenStackVM::VisitExpr_(const LE* op) {
PushBinary(StackVM::LE_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const NE *op) {
void CodeGenStackVM::VisitExpr_(const NE* op) {
PushBinary(StackVM::EQ_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitExpr_(const LT *op) {
void CodeGenStackVM::VisitExpr_(const LT* op) {
PushBinary(StackVM::LT_I64, op->a, op->b);
}
void CodeGenStackVM::VisitExpr_(const GE *op) {
void CodeGenStackVM::VisitExpr_(const GE* op) {
PushBinary(StackVM::LT_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitExpr_(const GT *op) {
void CodeGenStackVM::VisitExpr_(const GT* op) {
PushBinary(StackVM::LE_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitExpr_(const And *op) {
void CodeGenStackVM::VisitExpr_(const And* op) {
this->Push(op->a);
int64_t pc_jump = this->GetPC();
int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
......@@ -375,7 +375,7 @@ void CodeGenStackVM::VisitExpr_(const And *op) {
this->SetOperand(opr_index, diff);
}
void CodeGenStackVM::VisitExpr_(const Or *op) {
void CodeGenStackVM::VisitExpr_(const Or* op) {
this->Push(op->a);
int64_t pc_jump = this->GetPC();
int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
......@@ -389,11 +389,11 @@ void CodeGenStackVM::VisitExpr_(const Not* op) {
this->PushOp(StackVM::NOT);
}
void CodeGenStackVM::VisitStmt_(const ProducerConsumer *op) {
void CodeGenStackVM::VisitStmt_(const ProducerConsumer* op) {
this->Push(op->body);
}
void CodeGenStackVM::VisitStmt_(const For *op) {
void CodeGenStackVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
int vid = this->AllocVarID(op->loop_var.get());
this->PushOp(StackVM::PUSH_I64, 0);
......@@ -417,9 +417,10 @@ void CodeGenStackVM::VisitStmt_(const For *op) {
this->SetOperand(backward_jump, loop_head - label_bjump);
}
void CodeGenStackVM::VisitStmt_(const Block *op) {
this->Push(op->first);
if (op->rest.defined()) this->Push(op->rest);
void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
this->Push(stmt);
}
}
void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
......@@ -444,7 +445,7 @@ void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
}
}
void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {
void CodeGenStackVM::VisitStmt_(const IfThenElse* op) {
this->Push(op->condition);
int64_t label_ejump = this->GetPC();
int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
......@@ -466,29 +467,29 @@ void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {
}
}
void CodeGenStackVM::VisitStmt_(const LetStmt *op) {
void CodeGenStackVM::VisitStmt_(const LetStmt* op) {
this->Push(op->value);
int64_t vid = this->AllocVarID(op->var.get());
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
this->Push(op->body);
}
void CodeGenStackVM::VisitExpr_(const Ramp *op) {
void CodeGenStackVM::VisitExpr_(const Ramp* op) {
LOG(FATAL) << "Ramp is not supported";
}
void CodeGenStackVM::VisitExpr_(const Broadcast *op) {
void CodeGenStackVM::VisitExpr_(const Broadcast* op) {
LOG(FATAL) << "Broadcast is not supported";
}
void CodeGenStackVM::VisitExpr_(const Select *op) {
void CodeGenStackVM::VisitExpr_(const Select* op) {
this->Push(op->true_value);
this->Push(op->false_value);
this->Push(op->condition);
this->PushOp(StackVM::SELECT);
}
void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
void CodeGenStackVM::VisitStmt_(const AssertStmt* op) {
if (const auto* str = op->message.as<StringImm>()) {
int sid = this->GetStrID(str->value);
this->Push(op->condition);
......@@ -497,11 +498,11 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) {
this->Push(op->body);
}
void CodeGenStackVM::VisitStmt_(const AttrStmt *op) {
void CodeGenStackVM::VisitStmt_(const AttrStmt* op) {
this->Push(op->body);
}
void CodeGenStackVM::VisitExpr_(const Let *op) {
void CodeGenStackVM::VisitExpr_(const Let* op) {
this->Push(op->value);
int64_t vid = this->AllocVarID(op->var.get());
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -148,7 +148,7 @@ class CodeGenStackVM
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const AssertStmt* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const Block* op) final;
void VisitStmt_(const SeqStmtNode* op) final;
void VisitStmt_(const ProducerConsumer* op) final;
private:
......
......@@ -76,24 +76,24 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) {
os << t.bits();
}
void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*)
os << op->value;
}
void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << op->value << ")";
}
void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << std::setprecision(20) << op->value << ")";
}
void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*)
os << "'" << op->value << "'";
}
template<typename T>
inline void PrintBinaryExpr(const T* op,
const char *opstr,
const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented";
......@@ -115,7 +115,7 @@ inline void PrintBinaryExpr(const T* op,
}
inline void PrintBinaryIntrinsitc(const Call* op,
const char *opstr,
const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented";
......@@ -127,7 +127,7 @@ inline void PrintBinaryIntrinsitc(const Call* op,
os << ')';
}
void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*)
if (op->dtype == op->value.dtype()) {
PrintExpr(op->value, stream);
} else {
......@@ -138,76 +138,76 @@ void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
}
}
void CodeGenHybrid::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
}
void CodeGenHybrid::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
}
void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
}
void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const FloorDiv* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
void CodeGenHybrid::VisitExpr_(const FloorMod *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const FloorMod* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
}
void CodeGenHybrid::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, this);
}
void CodeGenHybrid::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
}
void CodeGenHybrid::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
}
void CodeGenHybrid::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
}
void CodeGenHybrid::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
}
void CodeGenHybrid::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
}
void CodeGenHybrid::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
}
void CodeGenHybrid::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
}
void CodeGenHybrid::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
}
void CodeGenHybrid::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*)
os << "not ";
PrintExpr(op->a, os);
}
void CodeGenHybrid::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
void CodeGenHybrid::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
if (op->call_type == Call::Halide) {
os << GetTensorID(op->func, op->value_index);
os << "[";
......@@ -313,7 +313,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
}
}
void CodeGenHybrid::VisitStmt_(const Realize *op) {
void CodeGenHybrid::VisitStmt_(const Realize* op) {
CHECK(alloc_storage_scope_.count(op->func));
if (!alloc_storage_scope_[op->func].empty()) {
PrintIndent();
......@@ -389,19 +389,20 @@ void CodeGenHybrid::VisitStmt_(const IfThenElse* op) {
}
}
void CodeGenHybrid::VisitStmt_(const Block *op) {
PrintStmt(op->first);
if (op->rest.defined()) PrintStmt(op->rest);
void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
PrintStmt(stmt);
}
}
void CodeGenHybrid::VisitStmt_(const Evaluate *op) {
void CodeGenHybrid::VisitStmt_(const Evaluate* op) {
if (is_const(op->value)) return;
std::string str = PrintExpr(op->value);
if (!str.empty())
stream << str << "\n";
}
void CodeGenHybrid::VisitStmt_(const ProducerConsumer *op) {
void CodeGenHybrid::VisitStmt_(const ProducerConsumer* op) {
PrintStmt(op->body);
}
......
......@@ -131,7 +131,7 @@ class CodeGenHybrid :
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*!
* \brief Print Type represetnation of type t.
......
......@@ -504,31 +504,10 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo
return Stmt(node);
}
Stmt Block::make(Stmt first, Stmt rest) {
CHECK(first.defined());
CHECK(rest.defined());
ObjectPtr<Block> node = make_object<Block>();
// canonicalize.
if (const Block* b = first.as<Block>()) {
node->first = b->first;
node->rest = Block::make(b->rest, rest);
} else {
node->first = std::move(first);
node->rest = std::move(rest);
}
return Stmt(node);
}
Stmt Block::make(const std::vector<Stmt>& stmts) {
if (stmts.empty()) {
return Stmt();
}
Stmt result = stmts.back();
for (size_t i = stmts.size() - 1; i != 0; --i) {
result = Block::make(stmts[i - 1], result);
}
return result;
/*!
 * \brief Construct a SeqStmt from an array of statements.
 * \param seq The statements; ownership of the array is moved into the node.
 */
SeqStmt::SeqStmt(Array<Stmt> seq) {
  auto node = make_object<SeqStmtNode>();
  node->seq = std::move(seq);
  // Assigning to data_ (the underlying ObjectRef storage) completes
  // construction of the reference.
  data_ = std::move(node);
}
Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
......@@ -1032,10 +1011,11 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<Block>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const Block*>(node.get());
p->Print(op->first);
if (op->rest.defined()) p->Print(op->rest);
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const SeqStmtNode*>(node.get());
for (Stmt stmt : op->seq) {
p->Print(stmt);
}
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
......@@ -1212,7 +1192,7 @@ TVM_REGISTER_NODE_TYPE(Provide);
TVM_REGISTER_NODE_TYPE(Allocate);
TVM_REGISTER_NODE_TYPE(Free);
TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(SeqStmtNode);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
......
......@@ -337,8 +337,8 @@ void MakeReduction(const ComputeOpNode* op,
provides.emplace_back(Provide::make(
t->op, t->value_index, update_value[i], args));
}
*init = Block::make(inits);
*provide = Block::make(provides);
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
if (!is_one(reduce->condition)) {
*provide = IfThenElse::make(reduce->condition, *provide);
}
......@@ -382,7 +382,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
if (debug_keep_trivial_loop) {
provide = MergeNest(common, provide);
} else {
provide = MergeNest(common, Block::make(init, provide));
provide = MergeNest(common, SeqStmt::Flatten(init, provide));
}
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
......@@ -392,7 +392,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
for (size_t i = 0; i < self->body.size(); ++i) {
provides.emplace_back(MakeProvide(self, stage->op.output(i)));
}
Stmt provide = Block::make(provides);
Stmt provide = SeqStmt::Flatten(provides);
provide = MergeNest(n.main_nest, provide);
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
......
......@@ -100,10 +100,10 @@ Stmt MakeCrossThreadReduction(
stage->op, idx,
Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = Block::make(assigns);
Stmt assign_body = SeqStmt::Flatten(assigns);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Block::make(reduce_body, assign_body);
Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
body = Allocate::make(
res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
......
......@@ -242,7 +242,7 @@ Stmt TensorComputeOpNode::BuildProvide(
update = MergeNest(binder.asserts(), update);
update = op::Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, Block::make(init, update));
return MergeNest(common, SeqStmt::Flatten(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
CHECK(this->intrin->body.defined())
......
......@@ -478,7 +478,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
update = MergeNest(binder.asserts(), update);
update = Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, Block::make(init, update));
return MergeNest(common, SeqStmt::Flatten(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
CHECK(intrin->body.defined())
......
......@@ -240,7 +240,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
stride_err_msg.str(), Evaluate::make(0));
check = IfThenElse::make(Not::make(is_null), check, Stmt());
asserts_.emplace_back(Block::make(check, Evaluate::make(0)));
asserts_.emplace_back(SeqStmt({check, Evaluate::make(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
DataType stype = buffer->DefaultIndexType();
......
......@@ -655,24 +655,14 @@ class CoProcSyncInserter : public StmtMutator {
}
Stmt VisitStmt(const Stmt& stmt) final {
Stmt before, after;
auto it = insert_before_.find(stmt.get());
if (it != insert_before_.end()) {
before = MergeSeq(std::vector<Stmt>(
it->second.rbegin(), it->second.rend()));
}
it = insert_after_.find(stmt.get());
if (it != insert_after_.end()) {
after = MergeSeq(it->second);
}
auto it_before = insert_before_.find(stmt.get());
auto it_after = insert_after_.find(stmt.get());
Stmt new_stmt = StmtMutator::VisitStmt(stmt);
if (before.defined()) {
new_stmt = Block::make(before, new_stmt);
}
if (after.defined()) {
new_stmt = Block::make(new_stmt, after);
}
return new_stmt;
return SeqStmt::Flatten(
it_before != insert_before_.end() ? it_before->second : std::vector<Stmt>(),
new_stmt,
it_after != insert_after_.end() ? it_after->second : std::vector<Stmt>());
}
private:
......
......@@ -147,7 +147,7 @@ class DoubleBufferInjector : public StmtExprMutator {
}
Stmt loop = For::make(
outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
MergeSeq(loop_seq));
SeqStmt::Flatten(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
......@@ -158,9 +158,9 @@ class DoubleBufferInjector : public StmtExprMutator {
IfThenElse::make(idx < old_loop->extent,
Substitute(tail_body, vmap)));
}
stmt = Block::make(loop, MergeSeq(tail_seq));
stmt = SeqStmt::Flatten(loop, tail_seq);
}
stmt = Block::make(MergeSeq(it->second), stmt);
stmt = SeqStmt::Flatten(it->second, stmt);
}
it = loop_allocs_.find(op);
if (it != loop_allocs_.end()) {
......
......@@ -59,7 +59,7 @@ class PrefetchInjector : public StmtMutator {
vectorized_.erase(iter_var);
Stmt prefetch = Prefetch::make(ts->op, ts->value_index, ts->dtype, region);
return Block::make(prefetch, op->body);
return SeqStmt({prefetch, op->body});
}
return ret;
}
......
......@@ -356,20 +356,18 @@ class VTInjector : public StmtExprMutator {
return IfThenElse::make(condition, then_case, else_case);
}
}
// Block
Stmt VisitStmt_(const Block* op) final {
// Seq
Stmt VisitStmt_(const SeqStmtNode* op) final {
CHECK_EQ(max_loop_depth_, 0);
Stmt first = this->VisitStmt(op->first);
int temp = max_loop_depth_;
max_loop_depth_ = 0;
Stmt rest = this->VisitStmt(op->rest);
max_loop_depth_ = std::max(max_loop_depth_, temp);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return GetRef<Stmt>(op);
} else {
return Block::make(first, rest);
}
auto fmutate = [this](const Stmt& s) {
int temp = max_loop_depth_;
max_loop_depth_ = 0;
Stmt ret = this->VisitStmt(s);
max_loop_depth_ = std::max(max_loop_depth_, temp);
return ret;
};
return StmtMutator::VisitSeqStmt_(op, false, fmutate);
}
// Allocate
Stmt VisitStmt_(const Allocate* op) final {
......@@ -442,12 +440,11 @@ class VTInjector : public StmtExprMutator {
// only unroll if number of vthreads are small
if (max_loop_depth_ == 0 && num_threads_ < 16) {
// do unrolling if it is inside innermost content.
Stmt blk = Substitute(stmt, {{var_, make_zero(var_.dtype())}});
for (int i = 1; i < num_threads_; ++i) {
blk = Block::make(
blk, Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
Array<Stmt> seq;
for (int i = 0; i < num_threads_; ++i) {
seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
}
return blk;
return SeqStmt::Flatten(seq);
} else {
// insert a for loop
Var idx(var_->name_hint + ".s", var_->dtype);
......
......@@ -179,10 +179,12 @@ class IRDeepCompare :
if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
}
void VisitStmt_(const Block* op, const Stmt& other) final {
const Block* rhs = other.as<Block>();
if (CompareStmt(op->first, rhs->first) != 0) return;
if (CompareStmt(op->rest, rhs->rest) != 0) return;
  // Structural comparison of two statement sequences: compare lengths
  // first, then compare elements pairwise.  A non-zero result from
  // CompareValue/CompareStmt means a difference has already been found,
  // so the remaining elements need not be examined.
  // NOTE(review): `rhs` is not null-checked — this assumes `other` is
  // also a SeqStmt, presumably guaranteed by the dispatch mechanism;
  // verify against the surrounding visitor.
  void VisitStmt_(const SeqStmtNode* op, const Stmt& other) final {
    const SeqStmtNode* rhs = other.as<SeqStmtNode>();
    if (CompareValue(op->size(), rhs->size()) != 0) return;
    for (size_t i = 0; i < op->size(); ++i) {
      if (CompareStmt(op->seq[i], rhs->seq[i]) != 0) return;
    }
  }
void VisitStmt_(const Evaluate* op, const Stmt& other) final {
......
......@@ -209,9 +209,10 @@ void StmtVisitor::VisitStmt_(const Prefetch* op) {
});
}
void StmtVisitor::VisitStmt_(const Block* op) {
this->VisitStmt(op->first);
this->VisitStmt(op->rest);
// Visit every child statement of the sequence via the generic
// array-visiting helper.
void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
  VisitArray(op->seq, [this](const Stmt& s) {
    this->VisitStmt(s);
  });
}
void StmtVisitor::VisitStmt_(const Evaluate* op) {
......@@ -490,20 +491,63 @@ Stmt StmtMutator::VisitStmt_(const Prefetch* op) {
}
}
Stmt StmtMutator::VisitStmt_(const Block* op) {
Stmt first = this->VisitStmt(op->first);
Stmt rest = this->VisitStmt(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
Array<Stmt> seq = Internal::Mutate(this, op->seq);
if (seq.same_as(op->seq)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->first = std::move(first);
n->rest = std::move(rest);
n->seq = std::move(seq);
return Stmt(n);
}
}
/*!
 * \brief Advanced mutation helper for SeqStmt.
 *
 *  Optionally flattens nested SeqStmt children into the parent sequence
 *  before mutation, and allows a caller-supplied per-statement mutation
 *  function instead of the default visit.
 *
 * \param op The sequence node to mutate.
 * \param flatten_before_visit Whether to flatten nested SeqStmt children
 *        before visiting; downgraded to false when no child is a SeqStmt.
 * \param fmutate Optional custom mutator applied to each statement;
 *        when nullptr the default Internal::Mutate path is used.
 * \return The mutated statement (may be the original if nothing changed).
 */
Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op,
                                bool flatten_before_visit,
                                std::function<Stmt(const Stmt&)> fmutate) {
  if (flatten_before_visit) {
    // Pass 1, check if we need to flatten.
    // Only pay the flattening cost when at least one child is itself
    // a SeqStmt.
    bool need_flatten = false;
    for (size_t i = 0; i < op->seq.size(); ++i) {
      Stmt tmp = (*op)[i];
      if (tmp.as<SeqStmtNode>()) need_flatten = true;
    }
    flatten_before_visit = need_flatten;
  }
  // function to run the visit.
  // Note: the lambda parameter intentionally shadows the outer `op` so the
  // same code path serves both the flattened copy and the original node.
  auto frunvisit = [&](const SeqStmtNode* op) {
    Array<Stmt> seq =
        fmutate != nullptr ?
        MutateArray(op->seq, fmutate, allow_copy_on_write_) :
        Internal::Mutate(this, op->seq);
    if (seq.same_as(op->seq)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = CopyOnWrite(op);
      n->seq = std::move(seq);
      return Stmt(n);
    }
  };
  if (flatten_before_visit) {
    Array<Stmt> seq;
    SeqStmt::Flattener flattener(&seq);
    flattener(0, op->seq);
    // NOTE: If copy on write is allowed
    // the assignment to seq below will
    // destruct the original seq.
    //
    // Such destruction removes duplicated reference
    // count to children and still enables COW for
    // child Stmt.
    ObjectPtr<SeqStmtNode> n = CopyOnWrite(op);
    n->seq = std::move(seq);
    return frunvisit(n.operator->());
  } else {
    return frunvisit(op);
  }
}
Stmt StmtMutator::VisitStmt_(const AssertStmt* op) {
Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message);
......
......@@ -51,10 +51,10 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else if (const auto* block = s.as<Block>()) {
auto n = make_object<Block>(*block);
CHECK(is_no_op(n->rest));
n->rest = body;
} else if (const auto* seq = s.as<SeqStmtNode>()) {
auto n = make_object<SeqStmtNode>(*seq);
CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
n->seq.Set(n->size() - 1, body);
body = Stmt(n);
} else if (const auto* assert_ = s.as<AssertStmt>()) {
auto n = make_object<AssertStmt>(*assert_);
......@@ -80,14 +80,5 @@ Stmt MergeNest(const std::vector<std::vector<Stmt> >& nest, Stmt body) {
return body;
}
Stmt MergeSeq(const std::vector<Stmt>& seq) {
if (seq.size() == 0) return Evaluate::make(0);
Stmt body = seq[0];
for (size_t i = 1; i < seq.size(); ++i) {
body = Block::make(body, seq[i]);
}
return body;
}
} // namespace ir
} // namespace tvm
......@@ -48,13 +48,6 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body);
Stmt MergeNest(const std::vector<std::vector<Stmt> >& nest, Stmt body);
/*!
* \brief combine sequence of operations.
* \param seq The sequence.
* \return The combined Stmt
*/
Stmt MergeSeq(const std::vector<Stmt>& seq);
/*!
* \brief update array with an unary function
* \param arr array
* \param fupdate an unary function
......
......@@ -75,17 +75,56 @@ class AttrScopeLifter : public StmtMutator {
}
}
Stmt VisitStmt_(const Block* op) final {
std::vector<Stmt> seq;
FlattenSeq(op->first, &seq);
FlattenSeq(op->rest, &seq);
seq = MutateSeq(seq);
if (seq.size() == 2 &&
seq[0].same_as(op->first) &&
seq[1].same_as(op->rest)) {
return GetRef<Stmt>(op);
Stmt VisitStmt_(const SeqStmtNode* op) final {
// remember the decorations.
std::vector<ObjectRef> attr_node;
std::vector<Expr> attr_value;
auto fmutate = [&](const Stmt& s) {
attr_node_ = ObjectRef();
attr_value_ = Expr();
Stmt ret = this->VisitStmt(s);
attr_node.push_back(attr_node_);
attr_value.push_back(attr_value_);
return ret;
};
Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate);
if (attr_node.size() == 0) return ret;
op = ret.as<SeqStmtNode>();
CHECK(op != nullptr);
Array<Stmt> reorg;
// check if all decorations are common.
for (size_t begin = 0; begin < attr_node.size();) {
size_t end = begin + 1;
while (end < attr_node.size() &&
attr_node[end].same_as(attr_node[begin]) &&
ValueSame(attr_value[end], attr_value[begin])) {
++end;
}
// covers everything
// lift attr to parent.
if (begin == 0 && end == attr_node.size()) {
attr_node_ = attr_node[0];
attr_value_ = attr_value[0];
return ret;
}
// construct subsegments.
Array<Stmt> seq;
for (size_t i = begin; i < end; ++i) {
seq.push_back(op->seq[i]);
}
Stmt stmt = SeqStmt::Flatten(seq);
if (attr_node[begin].defined()) {
stmt = AttrStmt::make(
attr_node[begin], attr_key_, attr_value[begin], stmt);
}
reorg.push_back(stmt);
begin = end;
}
return MergeSeq(seq);
attr_node_ = ObjectRef();
attr_value_ = Expr();
return SeqStmt::Flatten(reorg);
}
Stmt VisitStmt_(const IfThenElse* op) final {
......@@ -132,71 +171,10 @@ class AttrScopeLifter : public StmtMutator {
}
private:
void FlattenSeq(Stmt s, std::vector<Stmt>* res) {
if (const Block* op = s.as<Block>()) {
FlattenSeq(op->first, res);
FlattenSeq(op->rest, res);
} else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) {
if (!op->is_producer) {
FlattenSeq(op->body, res);
} else {
res->emplace_back(s);
}
} else {
res->emplace_back(s);
}
}
std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
std::vector<Stmt> res_seq;
ObjectRef curr_node;
Expr curr_value;
Stmt curr_stmt;
for (const Stmt & stmt : seq) {
attr_node_ = ObjectRef();
attr_value_ = Expr();
Stmt rest = this->VisitStmt(stmt);
if (attr_node_.defined() &&
attr_value_.defined() &&
curr_node.defined() &&
curr_value.defined() &&
attr_node_.same_as(curr_node) &&
ValueSame(attr_value_, curr_value)) {
curr_stmt = Block::make(curr_stmt, rest);
} else {
if (curr_stmt.defined()) {
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
}
curr_stmt = rest;
curr_node = attr_node_;
curr_value = attr_value_;
}
}
if (curr_stmt.defined()) {
// keep attr_node_, attr_node_
if (res_seq.size() == 0) {
return {curr_stmt};
}
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
// reset
attr_node_ = ObjectRef();
attr_value_ = Expr();
}
return res_seq;
}
// value comparison that also compares content of int constant
static bool ValueSame(const Expr& a, const Expr& b) {
if (a.same_as(b)) return true;
if (!a.defined() || !b.defined()) return false;
if (a->type_index() != b->type_index()) return false;
if (a.dtype() != b.dtype()) return false;
if (const IntImm* op = a.as<IntImm>()) {
......
......@@ -106,14 +106,16 @@ class CandidateSelector final : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const Block* op) final {
bool temp = no_split_;
this->VisitStmt(op->first);
// erase the no split state of first when visit rest.
std::swap(temp, no_split_);
this->VisitStmt(op->rest);
// restore the no split flag.
no_split_ = no_split_ || temp;
void VisitStmt_(const SeqStmtNode* op) final {
bool init_no_split = no_split_;
for (Stmt stmt : op->seq) {
// erase the no split state of before visiting the next one.
bool temp = init_no_split;
std::swap(temp, no_split_);
this->VisitStmt(stmt);
// restore the no split flag.
no_split_ = no_split_ || temp;
}
}
void VisitExpr_(const Call* op) final {
......@@ -402,16 +404,6 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
return std::make_pair(interval, cond_set);
}
Stmt AppendStmts(const Stmt& a, const Stmt& b) {
if (!a.defined()) {
return b;
} else if (!b.defined()) {
return a;
} else {
return Block::make(a, b);
}
}
/*
* Tries to recursively partition the range of the variable (given by var) of
* the for loop (given by node and stmt) into a
......@@ -577,8 +569,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
}
}
}
s = AppendStmts(pre_stmt, mid_stmt);
s = AppendStmts(s, post_stmt);
s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt);
} else {
Expr cond = const_true();
if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
......
......@@ -185,7 +185,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
Var buffer_var = Downcast<Var>(call->args[2+size+i]);
stores[i] = Store::make(buffer_var, values[i], 0, pred);
}
return Block::make(stores);
return SeqStmt::Flatten(stores);
}
// Whether the threadIdx.x is involved in reduction.
if (vred[0].scope.dim_index == 0) {
......@@ -218,7 +218,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
{Expr(group_extent), Expr(reduce_extent)},
pred, Evaluate::make(0));
}
return MergeSeq(seq);
return SeqStmt::Flatten(seq);
}
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
......@@ -252,7 +252,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
for (size_t i = 0; i < size; ++i) {
stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
}
return Block::make(stores);
return SeqStmt::Flatten(stores);
};
// Step one, check for
if (reduce_align > reduce_extent) {
......@@ -280,11 +280,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
seq.emplace_back(SyncThread("warp"));
}
if (in_warp_seq.size() != 0) {
Stmt warp_body = MergeSeq(in_warp_seq);
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
seq.emplace_back(SyncThread("shared"));
}
return MergeSeq(seq);
return SeqStmt::Flatten(seq);
}
// Flatten the thread index.
// Also return a warp number,
......
......@@ -72,11 +72,14 @@ class BuiltinLower : public StmtExprMutator {
auto stmt = StmtExprMutator::VisitStmt(s);
CHECK_EQ(run_shape_stack_, 0);
CHECK_EQ(run_array_stack_, 0);
while (prep_seq_.size() != 0) {
stmt = Block::make(prep_seq_.back(), stmt);
prep_seq_.pop_back();
if (prep_seq_.size() != 0) {
Stmt ret = SeqStmt::Flatten(prep_seq_, stmt);
prep_seq_.clear();
return ret;
} else {
return stmt;
}
return stmt;
}
Stmt VisitStmt_(const Allocate* op) {
......@@ -107,12 +110,12 @@ class BuiltinLower : public StmtExprMutator {
intrinsic::tvm_throw_last_error, {},
Call::Intrinsic));
Stmt body = Block::make(
Stmt body = SeqStmt({
IfThenElse::make(Call::make(DataType::Bool(1),
intrinsic::tvm_handle_is_null,
{op->buffer_var}, Call::PureIntrinsic),
throw_last_error),
op->body);
op->body});
Stmt alloca = LetStmt::make(
op->buffer_var,
......@@ -133,7 +136,7 @@ class BuiltinLower : public StmtExprMutator {
op->buffer_var},
Call::Extern);
Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = Block::make(alloca, free_stmt);
body = SeqStmt({alloca, free_stmt});
body = AttrStmt::make(
op->buffer_var, attr::storage_alignment,
make_const(DataType::Int(32), runtime::kTempAllocaAlignment),
......
......@@ -189,7 +189,7 @@ LoweredFunc MakeAPI(Stmt body,
DataType::Int(32), intrinsic::tvm_call_packed,
{StringImm::make(runtime::symbol::tvm_set_device),
device_type, device_id}, Call::Intrinsic)));
body = Block::make(set_device, body);
body = SeqStmt({set_device, body});
}
n->body = MergeNest(
{seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
......
......@@ -93,15 +93,35 @@ class NoOpRemover : public StmtMutator {
if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
return Evaluate::make(0);
}
Stmt VisitStmt_(const Block* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<Block>();
if (is_no_op(op->first)) {
return op->rest;
} else if (is_no_op(op->rest)) {
return op->first;
Stmt VisitStmt_(const SeqStmtNode* op) final {
Stmt ret = StmtMutator::VisitSeqStmt_(op, true);
op = ret.as<SeqStmtNode>();
CHECK(op != nullptr);
bool need_compact = false;
for (size_t i = 0; i < op->size(); ++i) {
if (is_no_op(op->seq[i])) need_compact = true;
}
if (need_compact) {
auto n = CopyOnWrite(op);
size_t top = 0;
for (size_t i = 0; i < n->seq.size(); ++i) {
if (!is_no_op(n->seq[i])) {
n->seq.Set(top++, n->seq[i]);
}
}
if (top == 1) {
return n->seq[0];
} else {
n->seq.resize(top);
return Stmt(n);
}
} else {
return stmt;
if (op->size() == 1) {
return op->seq[0];
} else {
return ret;
}
}
}
......@@ -118,7 +138,7 @@ class NoOpRemover : public StmtMutator {
for (Expr e : values) {
if (HasSideEffect(e)) {
if (stmt.defined()) {
stmt = Block::make(stmt, Evaluate::make(e));
stmt = SeqStmt({stmt, Evaluate::make(e)});
} else {
stmt = Evaluate::make(e);
}
......
......@@ -216,7 +216,7 @@ class ThreadSyncInserter : public StmtExprMutator {
}
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
ret = Block::make(barrier, ret);
ret = SeqStmt({barrier, ret});
return ret;
} else {
return StmtExprMutator::VisitStmt(stmt);
......@@ -313,10 +313,10 @@ class ThreadSyncInserter : public StmtExprMutator {
rw_stats_.clear();
Stmt kinit = Evaluate::make(
Call::make(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic));
body = Block::make(kinit, body);
body = SeqStmt({kinit, body});
body = AttrStmt::make(
op->node, op->attr_key, op->value, body);
return Block::make(prep, body);
return SeqStmt({prep, body});
}
Stmt MakeGlobalBarrier() {
CHECK(sync_scope_.rank == StorageRank::kGlobal);
......
......@@ -120,26 +120,21 @@ class LoopUnroller : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const Block* op) final {
Stmt first = this->VisitStmt(op->first);
// cleanup state
int step_count = step_count_;
int unroll_depth = unroll_depth_;
int normal_loop_depth = normal_loop_depth_;
step_count_ = 0;
unroll_depth_ = 0;
normal_loop_depth_ = 0;
// work on rest part
Stmt rest = this->VisitStmt(op->rest);
step_count_ += step_count;
normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
unroll_depth_ = std::max(unroll_depth_, unroll_depth);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return GetRef<Stmt>(op);
} else {
return Block::make(first, rest);
}
  // Mutate a statement sequence, giving each child a fresh set of
  // unroll-bookkeeping counters and merging the results back:
  // step counts accumulate (sum), loop depths take the maximum.
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    auto fmutate = [this](const Stmt& s) {
      // Save the counters accumulated so far, then reset for this child.
      int step_count = step_count_;
      int unroll_depth = unroll_depth_;
      int normal_loop_depth = normal_loop_depth_;
      step_count_ = 0;
      unroll_depth_ = 0;
      normal_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      // Merge the child's counters with the saved ones.
      step_count_ += step_count;
      normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
      unroll_depth_ = std::max(unroll_depth_, unroll_depth);
      return ret;
    };
    // No flattening needed here; only per-child counter isolation.
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
  }
Stmt Unroll(const For* op) {
......@@ -149,17 +144,13 @@ class LoopUnroller : public StmtExprMutator {
if (value == 0) return Evaluate::make(0);
Stmt body = op->body;
Map<Var, Expr> vmap;
Stmt unrolled;
Array<Stmt> unrolled;
for (int i = 0; i < value; ++i) {
vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
Stmt step = Substitute(body, vmap);
if (unrolled.defined()) {
unrolled = Block::make(unrolled, step);
} else {
unrolled = step;
}
unrolled.push_back(step);
}
return unrolled;
return SeqStmt::Flatten(unrolled);
}
private:
......
......@@ -53,7 +53,7 @@ Stmt MakePipeline(const Stage& s,
if (consumer.defined() && !is_no_op(consumer)) {
consumer = ProducerConsumer::make(s->op, false, consumer);
pipeline = Block::make(producer, consumer);
pipeline = SeqStmt({producer, consumer});
}
pipeline = s->op->BuildRealize(s, dom_map, pipeline);
// use attribute to mark scope of the operation.
......
......@@ -155,6 +155,9 @@ TEST(IRF, StmtMutator) {
Expr VisitExpr_(const Add* op) final {
return op->a;
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
return StmtMutator::VisitSeqStmt_(op, true);
}
Expr VisitExpr(const Expr& expr) final {
return ExprMutator::VisitExpr(expr);
}
......@@ -219,6 +222,35 @@ TEST(IRF, StmtMutator) {
auto res = v(std::move(body));
CHECK(res.as<Evaluate>()->value.as<Call>()->args[0].same_as(x));
}
{
auto body = fmakealloc();
Stmt body2 = Evaluate::make(1);
auto* ref2 = body2.get();
auto* extentptr = body.as<Allocate>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
body = SeqStmt({body, body2});
body = SeqStmt({body, body2});
body = v(std::move(body));
// the seq get flattened
CHECK(body.as<SeqStmtNode>()->size() == 3);
CHECK(body.as<SeqStmtNode>()->seq[0].as<Allocate>()->extents.get() == extentptr);
CHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
}
{
// Cannot cow because of bref
auto body = fmakealloc();
Stmt body2 = Evaluate::make(1);
auto* extentptr = body.as<Allocate>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
auto bref = body;
body = SeqStmt({body, body2});
body = v(std::move(body));
// the seq get flattened
CHECK(body.as<SeqStmtNode>()->seq[0].as<Allocate>()->extents.get() != extentptr);
}
}
int main(int argc, char ** argv) {
......
......@@ -123,12 +123,12 @@ def test_outer_product():
assert ibody.extent.name == 'm'
#Check loop body
jblock = ibody.body
assert isinstance(jblock, tvm.stmt.Block)
jbody = jblock.first
assert isinstance(jblock, tvm.stmt.SeqStmt)
jbody = jblock[0]
assert isinstance(jbody, tvm.stmt.AssertStmt)
assert isinstance(jbody.message, tvm.expr.StringImm)
assert jbody.message.value == "index out of range!"
jbody = jblock.rest
jbody = jblock[1]
assert isinstance(jbody, tvm.stmt.Provide)
assert jbody.func.name == 'c'
assert len(jbody.args) == 2
......@@ -191,12 +191,12 @@ def test_fanout():
assert abody.func.name == 'sigma'
#Check i loop body
rbody = abody.body
assert isinstance(rbody.first, tvm.stmt.Provide)
assert rbody.first.func.name == 'sigma'
assert len(rbody.first.args) == 1
assert rbody.first.args[0].value == 0
assert isinstance(rbody[0], tvm.stmt.Provide)
assert rbody[0].func.name == 'sigma'
assert len(rbody[0].args) == 1
assert rbody[0].args[0].value == 0
#Check fanout loop
jloop = rbody.rest.first
jloop = rbody[1]
assert jloop.loop_var.name == 'j'
assert jloop.min.value == 0
assert jloop.extent.value == 3
......@@ -214,7 +214,7 @@ def test_fanout():
assert value.b.name == 'a'
assert len(value.b.args) == 1
assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
divide= rbody.rest.rest.first
divide= rbody[2]
assert isinstance(divide, tvm.stmt.Provide)
assert len(divide.args) == 1
assert divide.args[0].value == 0
......@@ -224,7 +224,7 @@ def test_fanout():
assert len(value.a.args) == 1
assert value.a.args[0].value == 0
assert abs(value.b.value - (1 / 3.0)) < 1e-5
write = rbody.rest.rest.rest
write = rbody[3]
assert isinstance(write, tvm.stmt.Provide)
assert write.func.name == 'b'
assert write.value.name == 'sigma'
......@@ -257,9 +257,9 @@ def test_looptype():
ir = d.op.body
except:
return
iloop = ir.first
jloop = ir.rest.first
kloop = ir.rest.rest
iloop = ir[0]
jloop = ir[1]
kloop = ir[2]
assert iloop.for_type == tvm.stmt.For.Parallel
assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled
......@@ -802,7 +802,7 @@ def test_array_inputs():
inputs = []
for i in range(n):
inputs.append(tvm.placeholder((10,), name='t%s' % i, dtype='float32'))
out = sum_array(tvm.convert(inputs))
assert len(out.op.inputs) == n
......
......@@ -34,8 +34,8 @@ def test_for():
body = body.body
assert isinstance(body, tvm.stmt.For)
body = body.body
assert isinstance(body, tvm.stmt.Block)
assert isinstance(body.rest, tvm.stmt.For)
assert isinstance(body, tvm.stmt.SeqStmt)
assert isinstance(body[1], tvm.stmt.For)
def test_if():
ib = tvm.ir_builder.create()
......
......@@ -146,12 +146,6 @@ def test_stmt_constructor():
assert isinstance(x, tvm.stmt.AttrStmt)
assert x.value.value == 1
x = tvm.stmt.Block(tvm.stmt.Evaluate(11),
nop)
assert isinstance(x, tvm.stmt.Block)
assert x.first.value.value == 11
assert x.rest == nop
x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"),
tvm.convert("hellow"),
nop)
......
......@@ -171,8 +171,8 @@ def test_tensor_compute2():
s = tvm.create_schedule(C.op)
stmt = tvm.lower(s, [A, B, C], simple_mode=True)
assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body[0], tvm.stmt.Evaluate)
assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate)
def test_tensor_scan():
m = tvm.var("m")
......
......@@ -53,6 +53,7 @@ def test_equal_compute():
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
A[j] = A[j] + 2
A[j] = A[j] + 2
return ib.get()
assert tvm.ir_pass.Equal(func1(), func1())
......
......@@ -92,8 +92,8 @@ def test_vthread_if_then_else():
B[i] = A[i * nthread + tx] + 2
stmt = ib.get()
stmt = tvm.ir_pass.InjectVirtualThread(stmt)
assert stmt.body.body.body.first.else_case != None
assert stmt.body.body.body.rest.else_case == None
assert stmt.body.body.body[0].else_case != None
assert stmt.body.body.body[1].else_case == None
if __name__ == "__main__":
test_vthread_extern()
......
......@@ -31,10 +31,29 @@ def test_coproc_lift():
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[j] = A[j] + 2
A[j] = A[j] + 3
A[j] = A[j] + 3
body = ib.get()
body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
assert body.body.body.node == cp
# only able to lift to the common pattern of the last two fors.
ib = tvm.ir_builder.create()
A = ib.allocate("float32", n, name="A", scope="global")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A[j] = A[j] + 1
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[i] = A[i] + 2
body = ib.get()
body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
assert body.body.body.body[1].node == cp
assert len(body.body.body.body) == 2
if __name__ == "__main__":
test_coproc_lift()
......@@ -64,7 +64,7 @@ def test_basic():
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
assert('if' not in str(stmt.body.body.body[0]))
def test_const_loop():
n = 21
......@@ -79,7 +79,7 @@ def test_const_loop():
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
assert('if' not in str(stmt.body.body.body[0]))
def test_multi_loop():
ib = tvm.ir_builder.create()
......@@ -95,7 +95,7 @@ def test_multi_loop():
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multi_if():
ib = tvm.ir_builder.create()
......@@ -115,7 +115,7 @@ def test_multi_if():
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.first))
assert('if' not in str(stmt.body[0]))
def test_thread_axis():
m = tvm.var('m')
......@@ -134,7 +134,7 @@ def test_thread_axis():
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert('if' not in str(stmt.body.body.body.first))
assert('if' not in str(stmt.body.body.body[0]))
def test_vectorize():
n = tvm.var('n')
......@@ -169,7 +169,7 @@ def test_condition():
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, False)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select))))
def test_condition_EQ():
ib = tvm.ir_builder.create()
......@@ -181,7 +181,7 @@ def test_condition_EQ():
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select))))
def test_thread_axis2():
n = tvm.convert(4096)
......@@ -197,7 +197,7 @@ def test_thread_axis2():
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
stmt = lower(s, [A, B])
for_body = stmt.body.body.body.body.body.first
for_body = stmt.body.body.body.body.body[0]
assert('threadIdx' not in str(for_body.extent))
def test_everything_during_deduction():
......
......@@ -16,6 +16,9 @@
# under the License.
import tvm
def nop():
    """Return a trivial no-op statement (Evaluate of constant 0).

    Used as filler in the RemoveNoOp tests below: such statements are
    recognized as no-ops and should be eliminated by the pass.
    """
    noop_stmt = tvm.stmt.Evaluate(0)
    return noop_stmt
def test_remove_no_op():
i = tvm.var('i')
j = tvm.var('j')
......@@ -37,12 +40,13 @@ def test_remove_no_op():
store = tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1)
stmt2 = tvm.make.Block(stmt, store)
stmt2 = tvm.stmt.SeqStmt([nop(), tvm.stmt.SeqStmt([store, nop()])])
assert(tvm.ir_pass.RemoveNoOp(stmt2) == store)
# remove zero extent loop
stmt3 = tvm.make.For(i, 0, 0, 0, 0, store)
ret = tvm.ir_pass.RemoveNoOp(stmt3)
assert(isinstance(ret, tvm.stmt.Evaluate))
if __name__ == "__main__":
test_remove_no_op()
......@@ -119,10 +119,10 @@ def test_coproc_sync3():
stmt = ib.get()
stmt = tvm.ir_pass.CoProcSync(stmt)
slist = tvm.make.stmt_list(stmt.first.body.body)
slist = tvm.make.stmt_list(stmt[0].body.body)
push_st = slist[2]
slist = tvm.make.stmt_list(slist[-1])
pop_st = slist[0].body.first
pop_st = slist[0].body[0]
assert(push_st.value.name == "cop.coproc_dep_push")
assert(__check_list(push_st.value.args, [2,3]))
......
......@@ -43,13 +43,13 @@ def test_unroll_loop():
ib.scope_attr(tvm.const(0, "int32"), "pragma_auto_unroll_max_step", 16)
ib.emit(stmt)
wrapped = ib.get()
wrapped = tvm.make.Block(wrapped, stmt)
wrapped = tvm.stmt.SeqStmt([wrapped, stmt])
assert isinstance(ret, tvm.stmt.For)
ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
assert isinstance(ret.first, tvm.stmt.For)
assert ret.first.for_type == tvm.stmt.For.Unrolled
assert isinstance(ret.rest, tvm.stmt.For)
assert ret.rest.for_type != tvm.stmt.For.Unrolled
assert isinstance(ret[0], tvm.stmt.For)
assert ret[0].for_type == tvm.stmt.For.Unrolled
assert isinstance(ret[1], tvm.stmt.For)
assert ret[1].for_type != tvm.stmt.For.Unrolled
def test_unroll_fake_loop():
ib = tvm.ir_builder.create()
......@@ -65,7 +65,7 @@ def test_unroll_fake_loop():
stmt = ib.get()
ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
assert isinstance(ret.first, tvm.stmt.Store)
assert isinstance(ret[0], tvm.stmt.Store)
def test_unroll_single_count_loops():
n = tvm.var('n')
......
......@@ -71,7 +71,7 @@ def test_schedule_scan():
s = tvm.create_schedule(res.op)
s = s.normalize()
ir = tvm.lower(s, [s_state], simple_mode=True)
assert not hasattr(ir.body.body.body.body.rest.body.body.rest.body, "condition")
assert not hasattr(ir.body.body.body.body[1].body.body[1].body, "condition")
bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......
......@@ -135,7 +135,7 @@ def fold_uop_loop(stmt_in):
if body == stmt.body:
return stmt
ends = list(reversed(ends))
body = tvm.make.stmt_seq(*(begins + [body] + ends))
body = tvm.stmt.stmt_seq(*(begins + [body] + ends))
return tvm.make.AttrStmt(
stmt.node, stmt.attr_key, stmt.value, body)
return None
......@@ -307,7 +307,7 @@ def inject_coproc_sync(stmt_in):
success[0] = True
sync = tvm.make.Call(
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
return tvm.stmt.SeqStmt([stmt.body, tvm.make.Evaluate(sync)])
if _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.stmt.For)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment