Commit d114dfc9 by Tianqi Chen Committed by GitHub

[SCHEDULE] Mutate dataflow in schedule, refactor Stage (#44)

parent 820a8597
...@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>; ...@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
* \param dtype the data type of the tensor. * \param dtype the data type of the tensor.
* \param name The name of the Tensor. * \param name The name of the Tensor.
*/ */
Tensor Placeholder(Array<Expr> shape, Tensor placeholder(Array<Expr> shape,
Type dtype = Float(32), Type dtype = Float(32),
std::string name = "placeholder"); std::string name = "placeholder");
...@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape, ...@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor. * \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor. * \param name The optional name of the tensor.
*/ */
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"); Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*! /*!
* \brief Construct new tensors by scan over scan_axis. * \brief Construct new tensors by scan over scan_axis.
...@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor" ...@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param state_placeholder The placeholder for the states. * \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor. * \param name The optional name of the tensor.
*/ */
Array<Tensor> Scan(IterVar scan_axis, Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init, Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> state_placeholder,
std::string name = "scan"); std::string name = "scan");
// same as compute, specialized for different fcompute function // same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f, std::function<Expr(Var)> f,
std::string name = "tensor") { std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name); return compute(shape, fc, name);
} }
inline Tensor Compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f, std::function<Expr(Var, Var)> f,
std::string name = "tensor") { std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name); return compute(shape, fc, name);
} }
inline Tensor Compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f, std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") { std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name); return compute(shape, fc, name);
} }
inline Tensor Compute(Array<Expr> shape, inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f, std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") { std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); }; FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name); return compute(shape, fc, name);
} }
} // namespace tvm } // namespace tvm
......
...@@ -132,6 +132,13 @@ class Stage : public NodeRef { ...@@ -132,6 +132,13 @@ class Stage : public NodeRef {
IterVar* p_x_inner, IterVar* p_y_inner, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor); Expr x_factor, Expr y_factor);
/*! /*!
* \brief Specify thread launching group in
* outer most scope of the stage.
* This is only valid for composite operators.
* \param threads The threads to be launched.
*/
Stage& outermost_threads(Array<IterVar> threads);
/*!
* \brief Vectorize iteration. * \brief Vectorize iteration.
* \param var The axis to be vectorized. * \param var The axis to be vectorized.
* \return reference to self. * \return reference to self.
...@@ -180,6 +187,28 @@ class Schedule : public NodeRef { ...@@ -180,6 +187,28 @@ class Schedule : public NodeRef {
return this->operator[](tensor->op); return this->operator[](tensor->op);
} }
/*! /*!
* \brief create a cache read of original tensor for readers.
* This will mutate the body of the readers.
* A new stage will be created for the tensor.
* \param tensor The tensor cached.
* \param scope The scope of the cache.
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
* \brief Create a cache write tensor for producing tensor.
* The the tensor will take over body of original tensor op.
* The original tensor's body will be changed to an identity read
* from the corresponding cache.
* \param tensor The tensor to be produced.
* \param scope The scope of the storage.
* \return The created tensor.
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Normalize the schedule. * \brief Normalize the schedule.
* This is needed before bound inference. * This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars * Insert necessary RebaseNode to make sure all leaf_iter_vars
...@@ -193,6 +222,11 @@ class Schedule : public NodeRef { ...@@ -193,6 +222,11 @@ class Schedule : public NodeRef {
* \return the pointer to the internal node container * \return the pointer to the internal node container
*/ */
inline const ScheduleNode* operator->() const; inline const ScheduleNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
// declare container type // declare container type
using ContainerType = ScheduleNode; using ContainerType = ScheduleNode;
}; };
...@@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef { ...@@ -244,10 +278,16 @@ class IterVarAttr : public NodeRef {
*/ */
class StageNode : public Node { class StageNode : public Node {
public: public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the stage */ /*! \brief The thread scope level of the stage */
std::string scope; std::string scope;
/*! \brief The operation of stage, can be different from original op. */
Operation op;
/*!
* \brief The original operator.
* The op field can change during schedule to alternate the dataflow,
* while origin_op remains fixed.
*/
Operation origin_op;
/*! \brief All the nodes in the iter var */ /*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars; Array<IterVar> all_iter_vars;
/*! /*!
...@@ -255,6 +295,11 @@ class StageNode : public Node { ...@@ -255,6 +295,11 @@ class StageNode : public Node {
* Operations can only be performed in leaves. * Operations can only be performed in leaves.
*/ */
Array<IterVar> leaf_iter_vars; Array<IterVar> leaf_iter_vars;
/*!
* \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan.
*/
Array<IterVar> outermost_threads;
/*! \brief The relation bwteen of IterVars */ /*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations; Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */ /*! \brief additional attributes about iter var. */
...@@ -265,17 +310,22 @@ class StageNode : public Node { ...@@ -265,17 +310,22 @@ class StageNode : public Node {
IterVar attach_ivar; IterVar attach_ivar;
/*! \brief The stage this node attaches to */ /*! \brief The stage this node attaches to */
Stage attach_stage; Stage attach_stage;
/*! \brief Whether this is an output stage */
bool is_output{false};
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope); v->Visit("scope", &scope);
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars); v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars); v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("outermost_threads", &outermost_threads);
v->Visit("relations", &relations); v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs); v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type); v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar); v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage); v->Visit("attach_stage", &attach_stage);
v->Visit("is_output", &is_output);
} }
static constexpr const char* _type_key = "Stage"; static constexpr const char* _type_key = "Stage";
...@@ -285,18 +335,18 @@ class StageNode : public Node { ...@@ -285,18 +335,18 @@ class StageNode : public Node {
/*! \brief node container for schedule */ /*! \brief node container for schedule */
class ScheduleNode : public Node { class ScheduleNode : public Node {
public: public:
/*! \brief The root operations */ /*! \brief The output operations in original data flow graph */
Array<Operation> roots; Array<Operation> outputs;
/*! /*!
* \brief list of all stages for non-placeholder ops * \brief list of all stages for non-placeholder ops.
* The stage are ordered in PostDFS order of their op. * The stages are sorted in dependency order.
*/ */
Array<Stage> stages; Array<Stage> stages;
/*! \brief map of operation to the stages */ /*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map; Map<Operation, Stage> stage_map;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots); v->Visit("outputs", &outputs);
v->Visit("stages", &stages); v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map); v->Visit("stage_map", &stage_map);
} }
...@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() { ...@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {
inline bool Stage::is_scheduled() const { inline bool Stage::is_scheduled() const {
const StageNode* n = operator->(); const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kNone); return !(n->relations.empty() && n->attach_type == kNone &&
n->all_iter_vars.same_as(n->leaf_iter_vars));
} }
inline const ScheduleNode* Schedule::operator->() const { inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get()); return static_cast<const ScheduleNode*>(node_.get());
} }
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const { inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get()); return static_cast<const IterVarRelationNode*>(node_.get());
......
...@@ -63,7 +63,6 @@ def build(sch, ...@@ -63,7 +63,6 @@ def build(sch,
arg_list.append(x) arg_list.append(x)
else: else:
raise ValueError("args must be Tensor, Buffer or Var") raise ValueError("args must be Tensor, Buffer or Var")
# lowering # lowering
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
......
...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
from . import _api_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
from . import collections as _collections
@register_node @register_node
class Buffer(NodeBase): class Buffer(NodeBase):
...@@ -41,6 +42,53 @@ class Schedule(NodeBase): ...@@ -41,6 +42,53 @@ class Schedule(NodeBase):
""" """
_api_internal._ScheduleNormalize(self) _api_internal._ScheduleNormalize(self)
def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers.
This will mutate the body of the readers.
A new cache stage will be created for the tensor.
Call this before doing any split/fuse schedule.
Parameters
----------
tensor : Tensor
The tensor to be cached.
scope : str
The scope of cached
readers : list of Tensor or Operation
The readers to read the cache.
Returns
-------
cache : Tensor
The created cache tensor.
"""
if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
readers = [readers]
readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)
def cache_write(self, tensor, scope):
"""Create a cache write of original tensor, before storing into tensor.
This will mutate the body of the tensor.
A new cache stage will created before feed into the tensor.
Parameters
----------
tensor : Tensor
The tensor to be feed to.
scope : str
The scope of cached
Returns
-------
cache : Tensor
The created cache tensor.
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)
@register_node @register_node
class Stage(NodeBase): class Stage(NodeBase):
"""A Stage represents schedule for one operation.""" """A Stage represents schedule for one operation."""
...@@ -104,6 +152,18 @@ class Stage(NodeBase): ...@@ -104,6 +152,18 @@ class Stage(NodeBase):
""" """
return _api_internal._StageSetScope(self, scope) return _api_internal._StageSetScope(self, scope)
def outermost_threads(self, threads):
"""Force launch threads at outermost scope of the stage.
Parameters
----------
threads : list of threads
The threads to be launched.
"""
if isinstance(threads, _collections.IterVar):
threads = [threads]
_api_internal._StageOutermostThreads(self, threads)
def compute_at(self, parent, scope): def compute_at(self, parent, scope):
"""Attach the stage at parent's scope """Attach the stage at parent's scope
......
...@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash) ...@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)
TVM_REGISTER_API(_Placeholder) TVM_REGISTER_API(_Placeholder)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Placeholder(args[0], *ret = placeholder(args[0],
args[1], args[1],
args[2]); args[2]);
}); });
...@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile) ...@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
}); });
TVM_REGISTER_API(_StageOutermostThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.outermost_threads(args[1]);
});
TVM_REGISTER_API(_StageUnroll) TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage() args[0].operator Stage()
...@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize) ...@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize(); .normalize();
}); });
TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]);
});
TVM_REGISTER_API(_ScheduleCacheWrite)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_write(args[1], args[2]);
});
} // namespace tvm } // namespace tvm
...@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name, ...@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) { Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0); return PlaceholderOpNode::make(name, shape, dtype).output(0);
} }
...@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const { ...@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape); return Array<Expr>(shape);
} }
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) { Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
...@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name, ...@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
return Operation(n); return Operation(n);
} }
Array<Tensor> Scan(IterVar scan_axis, Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init, Array<Tensor> init,
Array<Tensor> update, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> state_placeholder,
......
...@@ -6,9 +6,11 @@ ...@@ -6,9 +6,11 @@
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
namespace tvm { namespace tvm {
namespace ir { namespace schedule {
using namespace ir;
class ElemWiseDetector : public IRVisitor { class ElemWiseDetector : public ir::IRVisitor {
public: public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {} explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
...@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor { ...@@ -25,10 +27,7 @@ class ElemWiseDetector : public IRVisitor {
} }
for (size_t i = 0; i < axis_.size(); ++i) { for (size_t i = 0; i < axis_.size(); ++i) {
// const Variable *v1 = axis_[i]->var.as<Variable>();
// const Variable *v2 = axis[i].as<Variable>();
if (!axis[i].same_as(axis_[i]->var)) { if (!axis[i].same_as(axis_[i]->var)) {
// if (!(v1 && v2) || (v1 != v2)) {
is_elem_wise_ = false; is_elem_wise_ = false;
return; return;
} }
...@@ -52,22 +51,10 @@ bool IsElemWise(const Operation& op) { ...@@ -52,22 +51,10 @@ bool IsElemWise(const Operation& op) {
return false; return false;
} }
} // namespace ir
namespace schedule {
void AutoInlineElemWise(Schedule sch) { void AutoInlineElemWise(Schedule sch) {
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
if (!s.is_scheduled() && ir::IsElemWise(s->op)) { if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
bool is_root = false; s.compute_inline();
for (auto r : sch->roots) {
if (r == s->op) {
is_root = true;
break;
}
}
if (!is_root)
s.compute_inline();
} }
} }
} }
......
...@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan, ...@@ -294,7 +294,6 @@ void GatherOpBound(const ScanOpNode* scan,
const TensorDom& d = tmap.at(output[i]); const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end()); time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
} }
LOG(INFO) << time_dom.size();
CHECK(!rmap->count(scan->scan_axis)); CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom; Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom); Range r = arith::Union(time_dom).cover_range(sdom);
...@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op, ...@@ -321,7 +320,7 @@ void GatherOpBound(const Operation& op,
const ComputeOpNode* compute = op.as<ComputeOpNode>(); const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0)); const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) { for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom); Range r = arith::Union(tdom.data.at(i)).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i])); CHECK(!rmap->count(compute->axis[i]));
(*rmap)[compute->axis[i]] = r; (*rmap)[compute->axis[i]] = r;
} }
...@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage, ...@@ -392,6 +391,8 @@ void InferRootBound(const Stage& stage,
direct_consume_by_parent = true; direct_consume_by_parent = true;
} }
} }
} else {
LOG(INFO) << "not in feed graph consumer = " << stage->op;
} }
} }
// The relax set // The relax set
...@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage, ...@@ -486,7 +487,11 @@ void InferRootBound(const Stage& stage,
} }
FeedGraph CreateFeedGraph(const Schedule& sch) { FeedGraph CreateFeedGraph(const Schedule& sch) {
auto g = CreateReadGraph(sch->roots); Array<Operation> roots;
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
auto g = CreateReadGraph(roots);
FeedGraph fg; FeedGraph fg;
for (auto kv : g) { for (auto kv : g) {
for (Tensor t : kv.second) { for (Tensor t : kv.second) {
...@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) { ...@@ -523,6 +528,7 @@ AttachPath CreateAttachPath(const Schedule& sch) {
Map<IterVar, Range> InferBound(const Schedule& sch) { Map<IterVar, Range> InferBound(const Schedule& sch) {
FeedGraph feed_graph = CreateFeedGraph(sch); FeedGraph feed_graph = CreateFeedGraph(sch);
AttachPath attach_path = CreateAttachPath(sch); AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret; std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1]; const Stage& stage = sch->stages[i - 1];
......
...@@ -6,10 +6,10 @@ TEST(Tensor, Basic) { ...@@ -6,10 +6,10 @@ TEST(Tensor, Basic) {
using namespace tvm; using namespace tvm;
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A = Placeholder({m, l}, Float(32), "A"); Tensor A = placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B"); Tensor B = placeholder({n, l}, Float(32), "B");
auto C = Compute({m, n}, [&](Var i, Var j) { auto C = compute({m, n}, [&](Var i, Var j) {
return A[i][j]; return A[i][j];
}, "C"); }, "C");
...@@ -20,11 +20,11 @@ TEST(Tensor, Basic) { ...@@ -20,11 +20,11 @@ TEST(Tensor, Basic) {
TEST(Tensor, Reduce) { TEST(Tensor, Reduce) {
using namespace tvm; using namespace tvm;
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A = Placeholder({m, l}, Float(32), "A"); Tensor A = placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B"); Tensor B = placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k"); IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) { auto C = compute({m, n}, [&](Var i, Var j) {
return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
}, "C"); }, "C");
LOG(INFO) << C->op.as<ComputeOpNode>()->body; LOG(INFO) << C->op.as<ComputeOpNode>()->body;
......
...@@ -2,17 +2,6 @@ import tvm ...@@ -2,17 +2,6 @@ import tvm
from tvm.addon import nvcc_compiler from tvm.addon import nvcc_compiler
import numpy as np import numpy as np
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc_compiler.compile_source(code, target="ptx")
print(ptx.decode("utf-8"))
return ptx
@tvm.register_func
def tvm_callback_cuda_postproc(code):
print(code)
return code
def test_gemm(): def test_gemm():
# graph # graph
nn = 1024 nn = 1024
...@@ -22,21 +11,14 @@ def test_gemm(): ...@@ -22,21 +11,14 @@ def test_gemm():
l = n l = n
A = tvm.placeholder((n, l), name='A') A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B') B = tvm.placeholder((m, l), name='B')
AA = tvm.compute(A.shape, lambda *i : A(*i), name="AA")
BB = tvm.compute(B.shape, lambda *i : B(*i), name="BB")
k = tvm.IterVar((0, l), name='k') k = tvm.IterVar((0, l), name='k')
CC = tvm.compute( C = tvm.compute(
(n, m), (n, m),
lambda ii, jj: tvm.sum(AA[ii, k] * BB[jj, k], axis=k), lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
name='CC') name='CC')
C = tvm.compute(CC.shape, lambda *i: CC(*i), name="C")
# schedule # schedule
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
xtile, ytile = 32, 32 xtile, ytile = 32, 32
s[AA].set_scope("shared")
s[BB].set_scope("shared")
scale = 8 scale = 8
num_thread = 8 num_thread = 8
block_factor = scale * num_thread block_factor = scale * num_thread
...@@ -45,6 +27,9 @@ def test_gemm(): ...@@ -45,6 +27,9 @@ def test_gemm():
block_y = tvm.IterVar(thread_tag="blockIdx.y") block_y = tvm.IterVar(thread_tag="blockIdx.y")
thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y") thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y")
CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC])
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y) _, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x) _, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x)
s[C].reorder(block_y, block_x, yi, xi) s[C].reorder(block_y, block_x, yi, xi)
...@@ -92,8 +77,8 @@ def test_gemm(): ...@@ -92,8 +77,8 @@ def test_gemm():
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda") check_device("cuda")
#tvm.init_opencl() tvm.init_opencl()
#check_device("opencl") check_device("opencl")
if __name__ == "__main__": if __name__ == "__main__":
test_gemm() test_gemm()
...@@ -22,13 +22,13 @@ def test_schedule_create(): ...@@ -22,13 +22,13 @@ def test_schedule_create():
json_str = tvm.save_json(s) json_str = tvm.save_json(s)
s_loaded = tvm.load_json(json_str) s_loaded = tvm.load_json(json_str)
assert isinstance(s_loaded, tvm.schedule.Schedule) assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body)) assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
# pickle unpickle # pickle unpickle
dump = pkl.dumps(s) dump = pkl.dumps(s)
s_loaded = pkl.loads(dump) s_loaded = pkl.loads(dump)
assert isinstance(s_loaded, tvm.schedule.Schedule) assert isinstance(s_loaded, tvm.schedule.Schedule)
assert(str(s_loaded.roots[0].body) == str(s.roots[0].body)) assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))
def test_reorder(): def test_reorder():
m = tvm.Var('m') m = tvm.Var('m')
......
...@@ -74,6 +74,20 @@ def test_auto_inline(): ...@@ -74,6 +74,20 @@ def test_auto_inline():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')
s = tvm.Schedule(C.op)
AA = s.cache_read(A, "shared", readers=[C])
CC = s.cache_write(C, "shared")
s[AA].compute_at(s[CC], CC.op.axis[0])
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_scan() test_schedule_scan()
...@@ -81,3 +95,4 @@ if __name__ == "__main__": ...@@ -81,3 +95,4 @@ if __name__ == "__main__":
test_schedule1() test_schedule1()
test_schedule2() test_schedule2()
test_auto_inline() test_auto_inline()
test_schedule_cache()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment