Commit d114dfc9 by Tianqi Chen Committed by GitHub

[SCHEDULE] Mutate dataflow in schedule, refactor Stage (#44)

parent 820a8597
......@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Tensor placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");
......@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan over scan_axis.
......@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");
// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
} // namespace tvm
......
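With the constructors renamed to lower case, the entry points read as plain functions. A minimal sketch of the corresponding Python-frontend calls (the variables m, n, A and B here are illustrative, not part of the commit):

import tvm

# Hypothetical example of the renamed constructors.
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')                         # formerly Placeholder(...)
B = tvm.compute((m, n), lambda i, j: A(i, j) + 1, name='B')   # formerly Compute(...)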
......@@ -132,6 +132,13 @@ class Stage : public NodeRef {
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Specify the thread launching group at the
* outermost scope of the stage.
* This is only valid for composite operators.
* \param threads The threads to be launched.
*/
Stage& outermost_threads(Array<IterVar> threads);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
......@@ -180,6 +187,28 @@ class Schedule : public NodeRef {
return this->operator[](tensor->op);
}
/*!
* \brief Create a cache read of the original tensor for readers.
* This will mutate the bodies of the readers.
* A new stage will be created for the tensor.
* \param tensor The tensor to be cached.
* \param scope The scope of the cache.
* \param readers The readers to redirect to the cache.
* \return The created cache tensor.
*/
Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
* \brief Create a cache write tensor for the given tensor.
* The cache tensor will take over the body of the original tensor's op.
* The original tensor's body will be changed to an identity read
* from the corresponding cache.
* \param tensor The tensor to be produced.
* \param scope The scope of the storage.
* \return The created tensor.
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
......@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
// declare container type
using ContainerType = ScheduleNode;
};
......@@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef {
*/
class StageNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief The operation of the stage; it can differ from the original op. */
Operation op;
/*!
* \brief The original operator.
* The op field can change during scheduling to alter the dataflow,
* while origin_op remains fixed.
*/
Operation origin_op;
/*! \brief All the iteration variables in this stage */
Array<IterVar> all_iter_vars;
/*!
......@@ -255,6 +295,11 @@ class StageNode : public Node {
* Operations can only be performed in leaves.
*/
Array<IterVar> leaf_iter_vars;
/*!
* \brief Threads to be launched at the outermost scope of the stage.
* This is only valid for composite ops such as Scan.
*/
Array<IterVar> outermost_threads;
/*! \brief The relations between IterVars */
Array<IterVarRelation> relations;
/*! \brief Additional attributes about iter vars. */
......@@ -265,17 +310,22 @@ class StageNode : public Node {
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
/*! \brief Whether this is an output stage */
bool is_output{false};
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("outermost_threads", &outermost_threads);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
v->Visit("is_output", &is_output);
}
static constexpr const char* _type_key = "Stage";
......@@ -285,18 +335,18 @@ class StageNode : public Node {
/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
/*! \brief The root operations */
Array<Operation> roots;
/*! \brief The output operations in original data flow graph */
Array<Operation> outputs;
/*!
* \brief list of all stages for non-placeholder ops
* The stage are ordered in PostDFS order of their op.
* \brief list of all stages for non-placeholder ops.
* The stages are sorted in dependency order.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
}
......@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {
inline bool Stage::is_scheduled() const {
const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kNone);
return !(n->relations.empty() && n->attach_type == kNone &&
n->all_iter_vars.same_as(n->leaf_iter_vars));
}
inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
......
......@@ -63,7 +63,6 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
# lowering
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
......
......@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
from . import tensor as _tensor
from . import collections as _collections
@register_node
class Buffer(NodeBase):
......@@ -41,6 +42,53 @@ class Schedule(NodeBase):
"""
_api_internal._ScheduleNormalize(self)
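Per the header comment above, normalize() is needed before bound inference; a hedged sketch of the call order, assuming a schedule s already exists:

s.normalize()                          # insert necessary RebaseNode relations (see C++ doc above)
bounds = tvm.schedule.InferBound(s)    # bound inference expects a normalized schedule
stmt = tvm.schedule.ScheduleOps(s, bounds)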
def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers.
This will mutate the body of the readers.
A new cache stage will be created for the tensor.
Call this before doing any split/fuse schedule.
Parameters
----------
tensor : Tensor
The tensor to be cached.
scope : str
The scope of the cache.
readers : list of Tensor or Operation
The readers to read the cache.
Returns
-------
cache : Tensor
The created cache tensor.
"""
if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
readers = [readers]
readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)
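A hedged usage sketch for cache_read; the tensors and the compute_at placement are hypothetical and loosely mirror the test_schedule_cache test added later in this commit:

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A(i, j) + 1, name='B')
s = tvm.Schedule(B.op)
# AA is a new stage that stages A into "shared" scope; B now reads from AA.
AA = s.cache_read(A, "shared", readers=[B])
# Optionally nest the cache load under B's outer loop.
s[AA].compute_at(s[B], B.op.axis[0])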
def cache_write(self, tensor, scope):
"""Create a cache write of original tensor, before storing into tensor.
This will mutate the body of the tensor.
A new cache stage will created before feed into the tensor.
Parameters
----------
tensor : Tensor
The tensor to be fed to.
scope : str
The scope of the cache.
Returns
-------
cache : Tensor
The created cache tensor.
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)
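A similarly hedged sketch for cache_write (hypothetical names; compare test_schedule_cache below):

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
C = tvm.compute((m, n), lambda i, j: A(i, j) * 2, name='C')
s = tvm.Schedule(C.op)
# CC takes over C's compute body; C becomes an identity copy out of the cache.
CC = s.cache_write(C, "local")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)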
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
......@@ -104,6 +152,18 @@ class Stage(NodeBase):
"""
return _api_internal._StageSetScope(self, scope)
def outermost_threads(self, threads):
"""Force launch threads at outermost scope of the stage.
Parameters
----------
threads : IterVar or list of IterVar
The threads to be launched.
"""
if isinstance(threads, _collections.IterVar):
threads = [threads]
_api_internal._StageOutermostThreads(self, threads)
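Since outermost_threads only applies to composite operators such as scan, a full example needs a scan stage; the sketch below assumes s is a Schedule and s_scan is the output tensor of a scan op, and only illustrates the call pattern:

# Assumed: s = tvm.Schedule(...), s_scan = output tensor of a scan operator.
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, 8), thread_tag="threadIdx.x")
s[s_scan].outermost_threads([block_x, thread_x])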
def compute_at(self, parent, scope):
"""Attach the stage at parent's scope
......
......@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)
TVM_REGISTER_API(_Placeholder)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Placeholder(args[0],
*ret = placeholder(args[0],
args[1],
args[2]);
});
......@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_StageOutermostThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.outermost_threads(args[1]);
});
TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
......@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize();
});
TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheWrite)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_write(args[1], args[2]);
});
} // namespace tvm
......@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
......@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape);
}
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
......@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
......
......@@ -6,9 +6,11 @@
#include <tvm/ir_visitor.h>
namespace tvm {
namespace ir {
namespace schedule {
using namespace ir;
class ElemWiseDetector : public IRVisitor {
class ElemWiseDetector : public ir::IRVisitor {
public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
......@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor {
}
for (size_t i = 0; i < axis_.size(); ++i) {
// const Variable *v1 = axis_[i]->var.as<Variable>();
// const Variable *v2 = axis[i].as<Variable>();
if (!axis[i].same_as(axis_[i]->var)) {
// if (!(v1 && v2) || (v1 != v2)) {
is_elem_wise_ = false;
return;
}
......@@ -52,21 +51,9 @@ bool IsElemWise(const Operation& op) {
return false;
}
} // namespace ir
namespace schedule {
void AutoInlineElemWise(Schedule sch) {
for (Stage s : sch->stages) {
if (!s.is_scheduled() && ir::IsElemWise(s->op)) {
bool is_root = false;
for (auto r : sch->roots) {
if (r == s->op) {
is_root = true;
break;
}
}
if (!is_root)
if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
s.compute_inline();
}
}
......
......@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan,
const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
LOG(INFO) << time_dom.size();
CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
......@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op,
const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
Range r = arith::Union(tdom.data.at(i)).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i]));
(*rmap)[compute->axis[i]] = r;
}
......@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage,
direct_consume_by_parent = true;
}
}
} else {
LOG(INFO) << "not in feed graph consumer = " << stage->op;
}
}
// The relax set
......@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage,
}
FeedGraph CreateFeedGraph(const Schedule& sch) {
auto g = CreateReadGraph(sch->roots);
Array<Operation> roots;
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
auto g = CreateReadGraph(roots);
FeedGraph fg;
for (auto kv : g) {
for (Tensor t : kv.second) {
......@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) {
Map<IterVar, Range> InferBound(const Schedule& sch) {
FeedGraph feed_graph = CreateFeedGraph(sch);
AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1];
......
......@@ -6,10 +6,10 @@ TEST(Tensor, Basic) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
Tensor A = placeholder({m, l}, Float(32), "A");
Tensor B = placeholder({n, l}, Float(32), "B");
auto C = Compute({m, n}, [&](Var i, Var j) {
auto C = compute({m, n}, [&](Var i, Var j) {
return A[i][j];
}, "C");
......@@ -20,11 +20,11 @@ TEST(Tensor, Basic) {
TEST(Tensor, Reduce) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
Tensor A = placeholder({m, l}, Float(32), "A");
Tensor B = placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) {
auto C = compute({m, n}, [&](Var i, Var j) {
return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
}, "C");
LOG(INFO) << C->op.as<ComputeOpNode>()->body;
......
......@@ -2,17 +2,6 @@ import tvm
from tvm.addon import nvcc_compiler
import numpy as np
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc_compiler.compile_source(code, target="ptx")
print(ptx.decode("utf-8"))
return ptx
@tvm.register_func
def tvm_callback_cuda_postproc(code):
print(code)
return code
def test_gemm():
# graph
nn = 1024
......@@ -22,21 +11,14 @@ def test_gemm():
l = n
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
AA = tvm.compute(A.shape, lambda *i : A(*i), name="AA")
BB = tvm.compute(B.shape, lambda *i : B(*i), name="BB")
k = tvm.IterVar((0, l), name='k')
CC = tvm.compute(
C = tvm.compute(
(n, m),
lambda ii, jj: tvm.sum(AA[ii, k] * BB[jj, k], axis=k),
lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
name='CC')
C = tvm.compute(CC.shape, lambda *i: CC(*i), name="C")
# schedule
s = tvm.Schedule(C.op)
xtile, ytile = 32, 32
s[AA].set_scope("shared")
s[BB].set_scope("shared")
scale = 8
num_thread = 8
block_factor = scale * num_thread
......@@ -45,6 +27,9 @@ def test_gemm():
block_y = tvm.IterVar(thread_tag="blockIdx.y")
thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y")
CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC])
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x)
s[C].reorder(block_y, block_x, yi, xi)
......@@ -92,8 +77,8 @@ def test_gemm():
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda")
#tvm.init_opencl()
#check_device("opencl")
tvm.init_opencl()
check_device("opencl")
if __name__ == "__main__":
test_gemm()
......@@ -22,13 +22,13 @@ def test_schedule_create():
json_str = tvm.save_json(s)
s_loaded = tvm.load_json(json_str)
assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
# pickle unpickle
dump = pkl.dumps(s)
s_loaded = pkl.loads(dump)
assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body))
assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
def test_reorder():
m = tvm.Var('m')
......
......@@ -74,6 +74,20 @@ def test_auto_inline():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')
s = tvm.Schedule(C.op)
AA = s.cache_read(A, "shared", readers=[C])
CC = s.cache_write(C, "shared")
s[AA].compute_at(s[CC], CC.op.axis[0])
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__":
test_schedule_scan()
......@@ -81,3 +95,4 @@ if __name__ == "__main__":
test_schedule1()
test_schedule2()
test_auto_inline()
test_schedule_cache()