Commit 605813e4 by tqchen

schedule over operation

parent cac1b5a8
@@ -37,6 +37,8 @@ class Var : public Halide::VarExpr {
  public:
   explicit Var(const std::string& name_hint = "v",
                Type t = Int(32)) : VarExpr(name_hint, t) {}
+  explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
 };
 } // namespace tvm
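Note: Var keeps its name-hint constructor (defaulting to Int(32)) and gains a constructor that rewraps an existing node, presumably so variables can be reconstructed from stored node pointers such as the new SplitNode::var below. From Python only the named form is visible; a trivial sketch:

```python
import tvm

m = tvm.Var("m")  # a 32-bit integer variable, per the C++ default argument
```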
......
@@ -10,7 +10,6 @@
 #include "./expr.h"
 #include "./domain.h"
 namespace tvm {
 // internal node container for Operation
@@ -38,15 +37,15 @@ class OperationNode : public Node {
   Domain domain;
   /*! \brief optional name of the operation */
   std::string name;
-  /*! \brief index iteration variables on the domain of operation. */
-  Array<Var> iter_var;
 };
 /*!
- * \brief A Compute op that compute a tensor over certain range.
+ * \brief A Compute op that computes a tensor on a certain domain.
  */
 class ComputeOpNode : public OperationNode {
  public:
+  /*! \brief iteration variables over the dimensions */
+  Array<Var> dim_var;
   /*! \brief the compute expression */
   Expr body;
   /*! \brief constructor */
@@ -58,12 +57,12 @@ class ComputeOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("domain", &domain);
     v->Visit("name", &name);
-    v->Visit("iter_var", &iter_var);
+    v->Visit("dim_var", &dim_var);
     v->Visit("body", &body);
   }
   static Operation make(Domain domain,
                         std::string name,
-                        Array<Var> iter_var,
+                        Array<Var> dim_var,
                         Expr body);
 };
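The iteration variables move off the base OperationNode onto the concrete ComputeOpNode, renamed iter_var to dim_var: one Var per output dimension. A minimal sketch of how this surfaces on the Python side, assuming symbolic shapes like the tests at the bottom of this commit:

```python
import tvm

m = tvm.Var("m")
n = tvm.Var("n")
T = tvm.compute((m, n), lambda i, j: i + j)
print(T.op.dim_var)  # one Var per dimension, named after the lambda args: [i, j]
print(T.op.body)     # the compute expression built from those vars
```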
......
@@ -9,7 +9,7 @@
 #include <string>
 #include "./base.h"
 #include "./split.h"
-#include "./tensor.h"
+#include "./operation.h"
 namespace tvm {
@@ -30,7 +30,7 @@ class Schedule : public NodeRef {
  public:
   Schedule() {}
   explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
-  Schedule(Tensor tensor, std::string scope);
+  Schedule(Operation op, std::string scope);
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -77,11 +77,11 @@ class AttachSpecNode : public Node {
 /*! \brief represents the schedule of the tensor */
 class ScheduleNode : public Node {
  public:
-  /*! \brief Tensor to be scheduled */
-  Tensor tensor;
+  /*! \brief The operation to be scheduled */
+  Operation op;
   /*! \brief The thread scope level of the schedule */
   std::string scope;
-  /*! \brief Splits over domains or rdomains */
+  /*! \brief Splits over iteration domains */
   Array<Split> splits;
   /*! \brief attach specifications */
   Array<AttachSpec> attachs;
@@ -90,7 +90,7 @@ class ScheduleNode : public Node {
   }
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("scope", &scope);
-    v->Visit("tensor", &tensor);
+    v->Visit("op", &op);
     v->Visit("splits", &splits);
     v->Visit("attachs", &attachs);
   }
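ScheduleNode now records the producing Operation rather than one output Tensor, so a schedule is anchored on the op itself. A hedged sketch following the updated test at the bottom of this commit (the test still passes the tensor T, so the Python wrapper presumably resolves T to its op at the API boundary):

```python
sch = tvm.Schedule(T, scope="shared")
assert isinstance(sch, tvm.schedule.Schedule)
print(sch.scope)    # "shared"
print(sch.attachs)  # attach specifications, empty for a fresh schedule
```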
......
@@ -7,6 +7,7 @@
 #define TVM_SPLIT_H_
 #include "./base.h"
 #include "./expr.h"
+#include "./domain.h"
 namespace tvm {
@@ -34,15 +35,13 @@ class Split : public NodeRef {
  */
 class SplitNode : public Node {
  public:
-  /*! \brief whether the split is over reduction domain*/
-  bool split_over_rdom{false};
+  /*! \brief the variable to be split on */
+  Var var;
 };
 /*! \brief simple split node that splits over one dimension */
 class DimSplitNode : public SplitNode {
  public:
-  /*! \brief The dimension to split on */
-  int dim_index;
   /*! \brief The factor of the split */
   Expr factor;
   /*! \brief constructor */
@@ -51,13 +50,10 @@ class DimSplitNode : public SplitNode {
     return "DimSplit";
   }
   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("split_over_rdom", &split_over_rdom);
-    v->Visit("dim_index", &dim_index);
+    v->Visit("var", &var);
     v->Visit("factor", &factor);
   }
-  static Split make(int dim_index,
-                    Expr factor,
-                    bool over_rdom);
+  static Split make(Var var, Expr factor);
 };
 // Implementations of inline functions
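A split is now identified by the iteration variable itself instead of a positional dim_index, and the split_over_rdom flag disappears: the Var already says which domain it ranges over. A sketch mirroring the updated test, assuming attribute access follows the VisitAttrs fields:

```python
tk1 = tvm.Split(T.op.dim_var[0], 10)  # split the first dimension's var by factor 10
assert isinstance(tk1, tvm.schedule.DimSplit)
print(tk1.var)     # the Var being split
print(tk1.factor)  # the split factor, 10
```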
......
@@ -130,9 +130,9 @@ class TensorNode : public FunctionBaseNode {
   /*! \brief data type in the content of the tensor */
   Type dtype;
   /*! \brief the source operation, can be None */
-  Operation source_op;
+  Operation op;
   /*! \brief the output index from source operation */
-  int source_index{0};
+  int value_index{0};
   /*! \brief constructor */
   TensorNode() {}
   const char* type_key() const final {
@@ -142,8 +142,8 @@ class TensorNode : public FunctionBaseNode {
     v->Visit("shape", &shape);
     v->Visit("name", &name);
     v->Visit("dtype", &dtype);
-    v->Visit("source_op", &source_op);
-    v->Visit("source_index", &source_index);
+    v->Visit("op", &op);
+    v->Visit("value_index", &value_index);
   }
   const std::string& func_name() const final {
     return name;
@@ -154,8 +154,8 @@ class TensorNode : public FunctionBaseNode {
   static Tensor make(Array<Expr> shape,
                      std::string name,
                      Type dtype,
-                     Operation source_op,
-                     int source_index);
+                     Operation op,
+                     int value_index);
 };
 // implementations
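source_op/source_index are renamed to op/value_index: a tensor is the value_index-th output value of its producing op. Navigating back from a tensor, sketched under the same assumptions as above:

```python
T = tvm.compute((m, n), lambda i, j: i + j)
print(T.op)           # the ComputeOp that produces T (previously T.source_op)
print(T.value_index)  # 0: T is the op's only output here
print(T.op.body)      # the expression behind T
```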
......
@@ -91,7 +91,10 @@ def compute(shape, fcompute, name="TensorCompute"):
         The created tensor
     """
     ndim = len(shape)
-    dim_var = [Var("dim_var%d" % i) for i in range(ndim)]
+    arg_names = fcompute.__code__.co_varnames
+    if ndim != len(arg_names):
+        raise ValueError("fcompute does not match dimension")
+    dim_var = [Var(x) for x in arg_names]
     body = fcompute(*dim_var)
     dom = [Range(0, x) for x in shape]
     op_node = _function_internal._ComputeOp(
......
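compute() now derives the dimension variable names from fcompute's own argument names instead of generating dim_var0, dim_var1, ..., and rejects a callable whose arity does not match the shape. A hedged sketch of the new behavior:

```python
T = tvm.compute((m, n), lambda x, y: x + y)
# dim vars take the lambda's argument names:
# [v.name for v in T.op.dim_var] -> ["x", "y"]

tvm.compute((m, n), lambda x: x)  # ValueError: fcompute does not match dimension
```

One caveat worth noting: __code__.co_varnames lists all local variables, not only parameters, so the check relies on fcompute being a plain lambda; co_argcount would be the stricter choice.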
@@ -102,7 +102,7 @@ TVM_REGISTER_API(_RDomain)
 TVM_REGISTER_API(_DimSplit)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    *ret = DimSplitNode::make(args.at(0), args.at(1), args.at(2));
+    *ret = DimSplitNode::make(args.at(0), args.at(1));
   });
 TVM_REGISTER_API(_Schedule)
......
@@ -28,24 +28,24 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
     dom.push_back(Range(0, shape[i]));
   }
-  op_node->iter_var = Array<Var>(dim_index);
+  op_node->dim_var = Array<Var>(dim_index);
   op_node->domain = Domain(dom);
-  op_node->body = fcompute(op_node->iter_var);
+  op_node->body = fcompute(op_node->dim_var);
   op_node->name = name;
   node->dtype = op_node->body.type();
-  node->source_op = Operation(op_node);
-  node->source_index = 0;
+  node->op = Operation(op_node);
+  node->value_index = 0;
   return Tensor(node);
 }
 Operation ComputeOpNode::make(Domain domain,
                               std::string name,
-                              Array<Var> iter_var,
+                              Array<Var> dim_var,
                               Expr body) {
   auto n = std::make_shared<ComputeOpNode>();
   n->domain = domain;
   n->name = name;
-  n->iter_var = iter_var;
+  n->dim_var = dim_var;
   n->body = body;
   return Operation(n);
 }
......
@@ -6,9 +6,9 @@
 namespace tvm {
-Schedule::Schedule(Tensor tensor, std::string scope) {
+Schedule::Schedule(Operation op, std::string scope) {
   auto n = std::make_shared<ScheduleNode>();
-  n->tensor = tensor;
+  n->op = op;
   n->scope = scope;
   node_ = n;
 }
......
@@ -6,13 +6,11 @@
 namespace tvm {
-Split DimSplitNode::make(int dim_index,
-                         Expr factor,
-                         bool over_rdom) {
+Split DimSplitNode::make(Var var,
+                         Expr factor) {
   auto n = std::make_shared<DimSplitNode>();
   CHECK_EQ(factor.type().lanes(), 1);
-  n->split_over_rdom = over_rdom;
-  n->dim_index = dim_index;
+  n->var = var;
   n->factor = factor;
   return Split(n);
 }
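The CHECK_EQ(factor.type().lanes(), 1) guard requires the split factor to be a scalar expression; a vector-typed factor is rejected at construction. Both a constant and a symbolic scalar factor should pass, sketched under the same assumptions as above:

```python
tvm.Split(T.op.dim_var[0], 10)  # constant scalar factor
f = tvm.Var("f")                # symbolic Int(32) factor, lanes() == 1
tvm.Split(T.op.dim_var[0], f)
```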
......
@@ -30,14 +30,14 @@ Expr Tensor::operator()(Array<Expr> indices) const {
 Tensor TensorNode::make(Array<Expr> shape,
                        std::string name,
                        Type dtype,
-                       Operation source_op,
-                       int source_index) {
+                       Operation op,
+                       int value_index) {
   auto n = std::make_shared<TensorNode>();
   n->shape = shape;
   n->name = name;
   n->dtype = dtype;
-  n->source_op = source_op;
-  n->source_index = source_index;
+  n->op = op;
+  n->value_index = value_index;
   return Tensor(n);
 }
......
@@ -7,7 +7,7 @@ def test_inline():
     X = T(100)
     stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
     stmt = tvm.ir_pass.Inline(
-        T, T.source_op.iter_var, T.source_op.body, stmt)
+        T, T.op.dim_var, T.op.body, stmt)
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))
......
@@ -9,10 +9,11 @@ def test_schedule_create():
     T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
     sch = tvm.Schedule(T, scope="shared")
-    tk1 = tvm.Split(0, 10)
+    tk1 = tvm.Split(T.op.dim_var[0], 10)
     assert isinstance(sch, tvm.schedule.Schedule)
     assert isinstance(tk1, tvm.schedule.DimSplit)
+    print(tk1.var)
     print(sch.scope)
     print(sch.attachs)
......