Commit 605813e4 by tqchen

schedule over operation

parent cac1b5a8
@@ -37,6 +37,8 @@ class Var : public Halide::VarExpr {
  public:
   explicit Var(const std::string& name_hint = "v",
                Type t = Int(32)) : VarExpr(name_hint, t) {}
+  explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
 };
 }  // namespace tvm
...
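The added constructor lets a Var be rebuilt from a generic node container, the pattern the API plumbing uses when node references cross the C++/Python boundary. A minimal sketch of the user-facing side (assuming `tvm.Var` is exported and mirrors the C++ defaults of `name_hint="v"` and `Int(32)`):

```python
import tvm

# Hypothetical usage; tvm.Var is assumed to mirror the C++ constructor,
# which creates a named integer index variable of type Int(32).
i = tvm.Var("i")
```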
@@ -10,7 +10,6 @@
 #include "./expr.h"
 #include "./domain.h"
 namespace tvm {
 // internal node container for Operation
@@ -38,15 +37,15 @@ class OperationNode : public Node {
   Domain domain;
   /*! \brief optional name of the operation */
   std::string name;
-  /*! \brief index iteration variables on the domain of operation. */
-  Array<Var> iter_var;
 };
 /*!
- * \brief A Compute op that compute a tensor over certain range.
+ * \brief A Compute op that compute a tensor on certain domain.
  */
 class ComputeOpNode : public OperationNode {
  public:
+  /*! \brief iter-Var over the dimensions */
+  Array<Var> dim_var;
   /*! \brief the compute expression */
   Expr body;
   /*! \brief constructor */
@@ -58,12 +57,12 @@ class ComputeOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("domain", &domain);
     v->Visit("name", &name);
-    v->Visit("iter_var", &iter_var);
+    v->Visit("dim_var", &dim_var);
     v->Visit("body", &body);
   }
   static Operation make(Domain domain,
                         std::string name,
-                        Array<Var> iter_var,
+                        Array<Var> dim_var,
                         Expr body);
 };
...
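With this move, the per-dimension index variables live on ComputeOpNode as `dim_var` rather than on the base OperationNode. A short sketch of inspecting them through the Python API, modeled on the updated tests at the bottom of this commit (the shape variables are illustrative):

```python
import tvm

m = tvm.Var("m")
n = tvm.Var("n")
# dim_var and body are the fields visited above as "dim_var" and "body".
T = tvm.compute((m, n), lambda i, j: i + j)
print(T.op.dim_var)  # per-dimension index Vars of the compute op
print(T.op.body)     # the compute expression built from them
```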
@@ -9,7 +9,7 @@
 #include <string>
 #include "./base.h"
 #include "./split.h"
-#include "./tensor.h"
+#include "./operation.h"
 namespace tvm {
@@ -30,7 +30,7 @@ class Schedule : public NodeRef {
  public:
   Schedule() {}
   explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
-  Schedule(Tensor tensor, std::string scope);
+  Schedule(Operation op, std::string scope);
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -77,11 +77,11 @@ class AttachSpecNode : public Node {
 /*! \brief represents the schedule of the tensor */
 class ScheduleNode : public Node {
  public:
-  /*! \brief Tensor to be scheduled */
-  Tensor tensor;
+  /*! \brief The operation to be scheduled */
+  Operation op;
   /*! \brief The thread scope level of the schedule */
   std::string scope;
-  /*! \brief Splits over domains or rdomains */
+  /*! \brief Splits over iteration domains */
   Array<Split> splits;
   /*! \brief attach specifications */
   Array<AttachSpec> attachs;
@@ -90,7 +90,7 @@ class ScheduleNode : public Node {
   }
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("scope", &scope);
-    v->Visit("tensor", &tensor);
+    v->Visit("op", &op);
     v->Visit("splits", &splits);
     v->Visit("attachs", &attachs);
   }
...
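A schedule is now attached to the producing Operation rather than to one of its output Tensors, in line with the new Tensor `op`/`value_index` fields below, where a tensor is just one output value of its op. Sketch based on test_schedule_create at the bottom of this commit (constructor call and field names come from this diff):

```python
import tvm

m = tvm.Var("m")
n = tvm.Var("n")
T = tvm.compute((m, n), lambda i, j: i + j)
sch = tvm.Schedule(T, scope="shared")
print(sch.op)       # the scheduled Operation (visited as "op" above)
print(sch.scope)    # "shared"
print(sch.attachs)  # attach specifications, empty for a fresh schedule
```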
@@ -7,6 +7,7 @@
 #define TVM_SPLIT_H_
 #include "./base.h"
+#include "./expr.h"
 #include "./domain.h"
 namespace tvm {
@@ -34,15 +35,13 @@ class Split : public NodeRef {
  */
 class SplitNode : public Node {
  public:
-  /*! \brief whether the split is over reduction domain*/
-  bool split_over_rdom{false};
+  /*! \brief the variable to be splitted on */
+  Var var;
 };
 /*! \brief simple split node that splits over one dimension */
 class DimSplitNode : public SplitNode {
  public:
-  /*! \brief The dimension to split on */
-  int dim_index;
   /*! \brief The factor of the split */
   Expr factor;
   /*! \brief constructor */
@@ -51,13 +50,10 @@ class DimSplitNode : public SplitNode {
     return "DimSplit";
   }
   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("split_over_rdom", &split_over_rdom);
-    v->Visit("dim_index", &dim_index);
+    v->Visit("var", &var);
     v->Visit("factor", &factor);
   }
-  static Split make(int dim_index,
-                    Expr factor,
-                    bool over_rdom);
+  static Split make(Var var, Expr factor);
 };
 // Implementations of inline functions
...
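A split is now described by the iteration Var it divides instead of a positional `dim_index` plus an over-rdom flag, so any axis, data-parallel or reduction, is addressed uniformly through its Var. Usage sketch mirroring the updated test at the bottom of this commit:

```python
import tvm

m = tvm.Var("m")
n = tvm.Var("n")
T = tvm.compute((m, n), lambda i, j: i + j)
tk = tvm.Split(T.op.dim_var[0], 10)  # split the first axis by factor 10
print(tk.var)     # the Var being split (visited as "var" above)
print(tk.factor)  # the split factor, 10
```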
@@ -130,9 +130,9 @@ class TensorNode : public FunctionBaseNode {
   /*! \brief data type in the content of the tensor */
   Type dtype;
   /*! \brief the source operation, can be None */
-  Operation source_op;
+  Operation op;
   /*! \brief the output index from source operation */
-  int source_index{0};
+  int value_index{0};
   /*! \brief constructor */
   TensorNode() {}
   const char* type_key() const final {
@@ -142,8 +142,8 @@ class TensorNode : public FunctionBaseNode {
     v->Visit("shape", &shape);
     v->Visit("name", &name);
     v->Visit("dtype", &dtype);
-    v->Visit("source_op", &source_op);
-    v->Visit("source_index", &source_index);
+    v->Visit("op", &op);
+    v->Visit("value_index", &value_index);
   }
   const std::string& func_name() const final {
     return name;
@@ -154,8 +154,8 @@ class TensorNode : public FunctionBaseNode {
   static Tensor make(Array<Expr> shape,
                      std::string name,
                      Type dtype,
-                     Operation source_op,
-                     int source_index);
+                     Operation op,
+                     int value_index);
 };
 // implementations
...
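The renames make the Tensor-to-Operation link read naturally: a tensor is output number `value_index` of its `op`. A short sketch (field names as visited above; shape is illustrative):

```python
import tvm

m = tvm.Var("m")
T = tvm.compute((m,), lambda i: i + 1)
print(T.op)           # the source operation, formerly "source_op"
print(T.value_index)  # which output of op this tensor is, formerly "source_index"
```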
@@ -91,7 +91,10 @@ def compute(shape, fcompute, name="TensorCompute"):
         The created tensor
     """
     ndim = len(shape)
-    dim_var = [Var("dim_var%d" % i) for i in range(ndim)]
+    arg_names = fcompute.__code__.co_varnames
+    if ndim != len(arg_names):
+        raise ValueError("fcompute do not match dimension")
+    dim_var = [Var(x) for x in arg_names]
     body = fcompute(*dim_var)
     dom = [Range(0, x) for x in shape]
     op_node = _function_internal._ComputeOp(
...
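compute now reads the index-variable names directly off fcompute via `__code__.co_varnames` (positional argument names come first in co_varnames, so this works for plain lambdas), so the generated Vars carry the user's own names instead of dim_var0, dim_var1, ..., and an arity mismatch fails fast. Illustrative sketch:

```python
import tvm

m = tvm.Var("m")
n = tvm.Var("n")
T = tvm.compute((m, n), lambda i, j: i + j)
print(T.op.dim_var)  # Vars named "i" and "j", taken from the lambda

# A lambda of the wrong arity is rejected eagerly:
try:
    tvm.compute((m, n), lambda i: i)
except ValueError as err:
    print(err)  # "fcompute do not match dimension"
```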
@@ -102,7 +102,7 @@ TVM_REGISTER_API(_RDomain)
 TVM_REGISTER_API(_DimSplit)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    *ret = DimSplitNode::make(args.at(0), args.at(1), args.at(2));
+    *ret = DimSplitNode::make(args.at(0), args.at(1));
   });
 TVM_REGISTER_API(_Schedule)
...
@@ -28,24 +28,24 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
     dom.push_back(Range(0, shape[i]));
   }
-  op_node->iter_var = Array<Var>(dim_index);
+  op_node->dim_var = Array<Var>(dim_index);
   op_node->domain = Domain(dom);
-  op_node->body = fcompute(op_node->iter_var);
+  op_node->body = fcompute(op_node->dim_var);
   op_node->name = name;
   node->dtype = op_node->body.type();
-  node->source_op = Operation(op_node);
-  node->source_index = 0;
+  node->op = Operation(op_node);
+  node->value_index = 0;
   return Tensor(node);
 }
 Operation ComputeOpNode::make(Domain domain,
                               std::string name,
-                              Array<Var> iter_var,
+                              Array<Var> dim_var,
                               Expr body) {
   auto n = std::make_shared<ComputeOpNode>();
   n->domain = domain;
   n->name = name;
-  n->iter_var = iter_var;
+  n->dim_var = dim_var;
   n->body = body;
   return Operation(n);
 }
...
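The C++ path mirrors the Python one: Compute fills dim_var, domain, body, and name before wrapping the node, while ComputeOpNode::make is the raw constructor exposed to the frontend as `_function_internal._ComputeOp` (called inside compute above). A hedged sketch of driving the maker directly, with the argument order assumed from the make signature and `tvm.Range` assumed exported to match the `Range(0, x)` calls in compute:

```python
import tvm
from tvm import _function_internal

# Internal-API sketch; the order (domain, name, dim_var, body) is assumed
# from ComputeOpNode::make and is not part of the public surface.
i = tvm.Var("i")
dom = [tvm.Range(0, 10)]
op = _function_internal._ComputeOp(dom, "myop", [i], i + 1)
print(op.dim_var, op.body)
```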
@@ -6,9 +6,9 @@
 namespace tvm {
-Schedule::Schedule(Tensor tensor, std::string scope) {
+Schedule::Schedule(Operation op, std::string scope) {
   auto n = std::make_shared<ScheduleNode>();
-  n->tensor = tensor;
+  n->op = op;
   n->scope = scope;
   node_ = n;
 }
...
@@ -6,13 +6,11 @@
 namespace tvm {
-Split DimSplitNode::make(int dim_index,
-                         Expr factor,
-                         bool over_rdom) {
+Split DimSplitNode::make(Var var,
+                         Expr factor) {
   auto n = std::make_shared<DimSplitNode>();
   CHECK_EQ(factor.type().lanes(), 1);
-  n->split_over_rdom = over_rdom;
-  n->dim_index = dim_index;
+  n->var = var;
   n->factor = factor;
   return Split(n);
 }
...
@@ -30,14 +30,14 @@ Expr Tensor::operator()(Array<Expr> indices) const {
 Tensor TensorNode::make(Array<Expr> shape,
                         std::string name,
                         Type dtype,
-                        Operation source_op,
-                        int source_index) {
+                        Operation op,
+                        int value_index) {
   auto n = std::make_shared<TensorNode>();
   n->shape = shape;
   n->name = name;
   n->dtype = dtype;
-  n->source_op = source_op;
-  n->source_index = source_index;
+  n->op = op;
+  n->value_index = value_index;
   return Tensor(n);
 }
...
@@ -7,7 +7,7 @@ def test_inline():
     X = T(100)
     stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
     stmt = tvm.ir_pass.Inline(
-        T, T.source_op.iter_var, T.source_op.body, stmt)
+        T, T.op.dim_var, T.op.body, stmt)
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))
...
@@ -9,10 +9,11 @@ def test_schedule_create():
     T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
     sch = tvm.Schedule(T, scope="shared")
-    tk1 = tvm.Split(0, 10)
+    tk1 = tvm.Split(T.op.dim_var[0], 10)
     assert isinstance(sch, tvm.schedule.Schedule)
     assert isinstance(tk1, tvm.schedule.DimSplit)
+    print(tk1.var)
     print(sch.scope)
     print(sch.attachs)
...