Commit 7591714a by tqchen

checkin initial of itervar

parent 70d93028
...@@ -15,6 +15,7 @@ namespace tvm { ...@@ -15,6 +15,7 @@ namespace tvm {
/*! \brief container class of reduction domain */ /*! \brief container class of reduction domain */
class RDomainNode; class RDomainNode;
class IterDomainNode;
/*! /*!
* \brief same as Halide::IR::Range * \brief same as Halide::IR::Range
...@@ -40,6 +41,9 @@ class Range : public Halide::IR::Range { ...@@ -40,6 +41,9 @@ class Range : public Halide::IR::Range {
static Range make_with_min_extent(Expr min, Expr extent); static Range make_with_min_extent(Expr min, Expr extent);
}; };
/*! \brief Domain is a multi-dimensional range */ /*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>; using Domain = Array<Range>;
...@@ -83,6 +87,20 @@ class RDomain : public NodeRef { ...@@ -83,6 +87,20 @@ class RDomain : public NodeRef {
/*! \brief use RDom as alias of RDomain */ /*! \brief use RDom as alias of RDomain */
using RDom = RDomain; using RDom = RDomain;
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional domain.
*/
class IterVarNode : public Node {
/*! \brief The */
Var var;
/*! \brief the domain of iteration */
Range dom;
/*! \brief additional tag on the iteration variable */
std::string tag;
};
/*! \brief reduction domain node */ /*! \brief reduction domain node */
class RDomainNode : public Node { class RDomainNode : public Node {
public: public:
......
...@@ -20,6 +20,7 @@ using Halide::Internal::ExprNode; ...@@ -20,6 +20,7 @@ using Halide::Internal::ExprNode;
using Halide::Internal::StmtNode; using Halide::Internal::StmtNode;
using Halide::Internal::IRNodeType; using Halide::Internal::IRNodeType;
using Halide::Internal::ForType; using Halide::Internal::ForType;
using Halide::DeviceAPI;
/*! \brief Reduction operator operator */ /*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> { struct Reduce : public ExprNode<Reduce> {
......
...@@ -38,8 +38,36 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) { ...@@ -38,8 +38,36 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
return body; return body;
} }
void MakeLoop(const DimSplitNode* op,
const Split& s,
Scope<AttrKey, Expr>* pscope,
std::vector<Stmt>* nest) {
auto& scope = *pscope;
Expr out_min = scope[{op->var, "min"}];
Expr out_ext = scope[{op->var, "extent"}];
Expr stride = op->factor;
Var offset(s->var->name_hint + ".offset", Int(32));
// for loop with stride
// TODO(tqchen) split the loop to deal with tails
nest->emplace_back(
For::make(
offset, out_min, out_ext,
ForType::Parallel, DeviceAPI::None, Stmt()));
Expr in_min = offset + out_min;
Expr in_ext = min(stride, out_ext - offset);
// declare min and extent of the corresponding variable
nest->emplace_back(AttrStmt::make(op->var, "min", in_min, Stmt()));
nest->emplace_back(AttrStmt::make(op->var, "extent", in_ext, Stmt()));
// declare this is the loop
nest->emplace_back(AttrStmt::make(s, "split", 0, Stmt()));
// setup the scope.
pscope->Push({op->var, "min"}, in_min);
pscope->Push({op->var, "extent"}, in_ext);
}
Stmt MakePipeline(const Schedule& sch, Stmt body) { Stmt MakePipeline(const Schedule& sch, Stmt body) {
return body; return body;
} }
...@@ -50,10 +78,17 @@ class InjectRealize : public IRMutator { ...@@ -50,10 +78,17 @@ class InjectRealize : public IRMutator {
: sch_(sch) {} : sch_(sch) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr) {
attr_scope_.Push({op->node, op->type_key}, op->value);
stmt = IRMutator::Mutate(stmt);
attr_scope_.Pop({op->node, op->type_key});
} else {
stmt = IRMutator::Mutate(stmt);
}
if (op != nullptr && if (op != nullptr &&
op->type_key == "Split" && op->type_key == "split" &&
op->node == sch_->attach_parent) { op->node == sch_->attach_parent) {
return AttrStmt::make( return AttrStmt::make(
op->node, op->type_key, op->value, op->node, op->type_key, op->value,
...@@ -66,6 +101,7 @@ class InjectRealize : public IRMutator { ...@@ -66,6 +101,7 @@ class InjectRealize : public IRMutator {
private: private:
// the operations to be carried // the operations to be carried
Schedule sch_; Schedule sch_;
Scope<AttrKey, Expr> attr_scope_;
}; };
} // namespace } // namespace
......
...@@ -8,7 +8,7 @@ def test_schedule_create(): ...@@ -8,7 +8,7 @@ def test_schedule_create():
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
sch = tvm.Schedule(T, scope="shared") sch = tvm.Schedule(T.op, scope="shared")
tk1 = tvm.Split(T.op.dim_var[0], 10) tk1 = tvm.Split(T.op.dim_var[0], 10)
assert isinstance(sch, tvm.schedule.Schedule) assert isinstance(sch, tvm.schedule.Schedule)
assert isinstance(tk1, tvm.schedule.DimSplit) assert isinstance(tk1, tvm.schedule.DimSplit)
......
...@@ -21,6 +21,7 @@ def test_tensor_reduce(): ...@@ -21,6 +21,7 @@ def test_tensor_reduce():
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
rd = tvm.RDomain(tvm.Range(A.shape[1])) rd = tvm.RDomain(tvm.Range(A.shape[1]))
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd)) C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd))
print(C.op.body) print(C.op.body)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment