Commit 8f51c5fd by Tianqi Chen, committed by GitHub

[SCHEDULE] Add group, refactor thread bind api. (#82)

* [SCHEDULE] Add group, refactor thread bind api.

* fix doc

* fix g++-4.8

* More test cases

* Remove graph context from fix pt analysis
parent 6268e183
-Subproject commit ce80d58741688b200f498fed8c7b0ea33e0516c8
+Subproject commit 59fdca16978b6184bab87fbff7a00c95f1804686
@@ -32,17 +32,6 @@ struct TensorDom {
 };
 /*!
- * \brief The map beteen tensor and operation it feeds to.
- */
-using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
-/*! \brief The graph context used during bound inference. */
-struct GraphContext {
-  /*! \brief The feed graph */
-  FeedGraph feed_graph;
-};
-/*!
  * \brief Base class of all operation nodes
  */
 class OperationNode : public FunctionBaseNode {
@@ -102,13 +91,11 @@ class OperationNode : public FunctionBaseNode {
   * Set the range of each root_iter_vars in the op to out_dom_map
   *
   * \param self The reference to self.
-  * \param graph_ctx The global graph context information.
   * \param tensor_dom Domain map of Tensor->access set of each dimension.
   * \param out_dom_map The output domain map of each IterVar to be setted.
   */
  virtual void GatherBound(
      const Operation& self,
-     const GraphContext& graph_ctx,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
  /*!
@@ -162,7 +149,6 @@ class PlaceholderOpNode : public OperationNode {
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
-     const GraphContext& graph_ctx,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
@@ -214,7 +200,6 @@ class ComputeOpNode : public OperationNode {
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
-     const GraphContext& graph_ctx,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
@@ -253,6 +238,11 @@ class ScanOpNode : public OperationNode {
   /*! \brief The placeholder to refer as states in update. */
   Array<Tensor> state_placeholder;
   /*!
+   * \brief the inputs to the scan, these are optionally provided
+   *  But they can be helpful to provide hints to speedup get of scan body.
+   */
+  Array<Tensor> inputs;
+  /*!
    * \brief Spatial axis to indicate spatial dimension of each output.
    *  They corresponds to flattened spatial axis of the outputs.
    *
@@ -279,7 +269,6 @@ class ScanOpNode : public OperationNode {
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
-     const GraphContext& graph_ctx,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
@@ -296,13 +285,15 @@ class ScanOpNode : public OperationNode {
     v->Visit("init", &init);
     v->Visit("update", &update);
     v->Visit("state_placeholder", &state_placeholder);
+    v->Visit("inputs", &inputs);
     v->Visit("spatial_axis_", &spatial_axis_);
   }
   static Operation make(std::string name,
                         IterVar axis,
                         Array<Tensor> init,
                         Array<Tensor> update,
-                        Array<Tensor> state_placeholder);
+                        Array<Tensor> state_placeholder,
+                        Array<Tensor> input);
   static constexpr const char* _type_key = "ScanOp";
   TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
@@ -339,7 +330,6 @@ class ExternOpNode : public OperationNode {
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
-     const GraphContext& graph_ctx,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
@@ -388,16 +378,19 @@ Tensor placeholder(Array<Expr> shape,
 Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
 /*!
- * \brief Construct new tensors by scan over scan_axis.
+ * \brief Construct new tensors by scan.
  *
  * \param init The intialize tensor of first K steps.
  * \param update The update tensor indicated the updated result after each timestamp.
  * \param state_placeholder The placeholder for the states.
+ * \param inputs The inputs to the scan body, this is optional,
+ *    but recommended to provide concrete information about scan body.
  * \param name The optional name of the tensor.
  */
 Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
+                   Array<Tensor> inputs = Array<Tensor>(),
                    std::string name = "scan");
 // same as compute, specialized for different fcompute function
...
@@ -139,12 +139,20 @@ inline bool TVMArgValue::IsNodeType() const {
 // extensions for TVMRetValue
 inline TVMRetValue& TVMRetValue::operator=(
     const std::shared_ptr<Node>& other) {
+  if (other.get() == nullptr) {
+    SwitchToPOD(kNull);
+  } else {
     SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
+  }
   return *this;
 }
 inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
+  if (!other.defined()) {
+    SwitchToPOD(kNull);
+  } else {
     SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
+  }
   return *this;
 }
...
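Editor's note: the two `operator=` overloads above now map an empty `shared_ptr` or an undefined `NodeRef` to `kNull` instead of storing a null class handle, so a null node value round-trips cleanly through the FFI. A minimal sketch of the intended contract from the Python side; the function name is made up for illustration, and this assumes the `register_func`/`get_global_func` helpers imported in the Python API below:

import tvm

def maybe_node(flag):
    # A null node handle should surface as None on the Python side.
    return tvm.convert(1) if flag else None

tvm.register_func("demo.maybe_node", maybe_node)
f = tvm.get_global_func("demo.maybe_node")
assert f(False) is None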
@@ -8,6 +8,7 @@
 #include <string>
 #include "./base.h"
+#include "./expr.h"
 #include "./tensor.h"
 namespace tvm {
@@ -23,8 +24,7 @@ class IterVarAttrNode;
 /*! \brief the attachment type */
 enum AttachType : int {
-  kNone = 0,
-  kRoot = 1,
+  kGroupRoot = 1,
   kInline = 2,
   kInlinedAlready = 3,
   kScope = 4,
...@@ -64,44 +64,50 @@ class Stage : public NodeRef { ...@@ -64,44 +64,50 @@ class Stage : public NodeRef {
*/ */
Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*) Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*! /*!
* \brief Compute the function inline, attach it at parent. * \brief Compute the function inline.
* \return reference to self. * \return reference to self.
*/ */
Stage& compute_inline(); // NOLINT(*) Stage& compute_inline(); // NOLINT(*)
/*! /*!
* \brief Compute the function at root, attach it to its parent. * \brief Compute the function at group root.
* \return reference to self. * \return reference to self.
*/ */
Stage& compute_root(); // NOLINT(*) Stage& compute_root(); // NOLINT(*)
/*! /*!
* \brief Rebase the parent iter var as rebased variable. * \brief Bind the ivar to thread index.
* *
* \param parent The parent iteration domain. * \param ivar The IterVar to be binded.
* \param rebased The variable to be used in rebase. * \param thread_ivar The thread axis to be binded.
* \return reference to self. * \return reference to self.
*/ */
Stage& rebase(IterVar parent, IterVar rebased); Stage& bind(IterVar ivar, IterVar thread_ivar);
/*!
* \brief Specify environment threads that launched around the group's scope.
* This can only be used in group stage.
* \param threads The threads to be launched around the scope.
* \note Each thread can only appear in one env_threads.
* \return reference to self.
*/
Stage& env_threads(Array<IterVar> threads);
/*! /*!
* \brief Split the parent by factor, generate * \brief Split the parent by factor, generate
* \param parent The parent iteration domain. * \param parent The parent iteration domain.
* \param factor The split factor of the loop.
* \param p_outer The result outer domain * \param p_outer The result outer domain
* \param p_inner The result inner domain. * \param p_inner The result inner domain.
* \param factor The split factor of the loop.
* \return reference to self. * \return reference to self.
*/ */
Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*) Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*! /*!
* \brief Split the iteration with a given outer domain, * \brief Split the iteration with given number of parts.
* the outer domain must have a thread-tag.
* *
* \param parent The parent domain. * \param parent The parent domain.
* \param outer The outer domain to be spliited, must have a thread_tag. * \param nparts The number of parts in the outer domain.
* \param p_outer The result outer domain.
* \param p_inner The result inner domain. * \param p_inner The result inner domain.
* \param factor Optional, the factor of the split,
* factor must be provided such that factor * outer.extent >= parent.extent.
* \return reference to self. * \return reference to self.
*/ */
Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*) Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*! /*!
* \brief Fuse the inner outer domain to the target * \brief Fuse the inner outer domain to the target
* \param inner The inner domain to be fused * \param inner The inner domain to be fused
@@ -123,25 +129,18 @@ class Stage : public NodeRef {
   *
   * \param x_parent The original x dimension
   * \param y_parent The original y dimension
+  * \param x_factor The stride factor on x axis
+  * \param y_factor The stride factor on y axis
   * \param p_x_outer Outer axis of x dimension
   * \param p_y_outer Outer axis of y dimension
   * \param p_x_inner Inner axis of x dimension
   * \param p_y_inner Inner axis of y dimension
-  * \param x_factor The stride factor on x axis
-  * \param y_factor The stride factor on y axis
   * \return reference to self.
   */
  Stage& tile(IterVar x_parent, IterVar y_parent,  // NOLINT(*)
+             Expr x_factor, Expr y_factor,
              IterVar* p_x_outer, IterVar* p_y_outer,
-             IterVar* p_x_inner, IterVar* p_y_inner,
-             Expr x_factor, Expr y_factor);
- /*!
-  * \brief Specify thread launching group in
-  *  outer most scope of the stage.
-  *  This is only valid for composite operators.
-  * \param threads The threads to be launched.
-  */
- Stage& outermost_threads(Array<IterVar> threads);
+             IterVar* p_x_inner, IterVar* p_y_inner);
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
@@ -164,7 +163,15 @@ class Stage : public NodeRef {
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
- inline bool is_scheduled() const;
+ bool is_scheduled() const;
+ /*!
+  * \brief Get attachment spec of current stage.
+  *  If the stage compute at Group root, this function
+  *  will traverse the group function to get the
+  *  final spec from the group.
+  * \return A stage representing the attach spec of the group.
+  */
+ Stage GetAttachSpec() const;
  // declare container type
  using ContainerType = StageNode;
};
@@ -197,6 +204,18 @@ class Schedule : public NodeRef {
    return this->operator[](tensor->op);
  }
  /*!
+  * \brief Create a new stage group for all intermediate
+  *  operations between inputs and outputs.
+  *
+  * \param outputs The output boundary of the group.
+  * \param inputs The input boundary of the group.
+  * \param include_inputs Whether include inputs if they are reachable from outputs.
+  * \return The new grouped stage.
+  */
+ Stage create_group(const Array<Tensor>& outputs,
+                    const Array<Tensor>& inputs,
+                    bool include_inputs = false);
+ /*!
   * \brief create a cache read of original tensor for readers.
   *  This will mutate the body of the readers.
   *  A new stage will be created for the tensor.
@@ -274,7 +293,6 @@ class IterVarRelation : public NodeRef {
 class IterVarAttr : public NodeRef {
  public:
   IterVarAttr() {}
-  explicit IterVarAttr(IterVarType t);
   explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
   /*!
    * \brief access the internal node container
@@ -283,26 +301,27 @@ class IterVarAttr : public NodeRef {
   inline const IterVarAttrNode* operator->() const;
 };
-// defintion of node containers
 /*!
- * \brief represents the schedule of the tensor
+ * \brief represents a stage.
  *
- * A schedule is a Directed acylic hypergraph.
+ *  relations form a Directed acylic hypergraph in bipartite manner.
  *  With each node is represented by a IterVar,
  *  and each hyper-edge is represented by a IterVarRelation.
+ *  The relations connects the IterVars in the graph.
  *
- *  The relations can be Split/Fuse.
- *
- *  The current data structure stores the hyper graph in its
- *  bipartite representation.
- *
- *  The relations connects the IterVars in the graph.
+ *  Besides typical stage that corresponds to operations.
+ *  There is also group stage, which groups stages together.
+ *  Each stage's group(given by group) represent an constraint,
+ *  the stage can only be attached to stages within the group.
+ *
+ *  The group stage node can be attached to IterVars as in normal stage.
  */
 class StageNode : public Node {
  public:
-  /*! \brief The thread scope level of the stage */
-  std::string scope;
-  /*! \brief The operation of stage, can be different from original op. */
+  /*!
+   * \brief The operation of stage, can be different from original op.
+   *  If it is null, then this stage is a group stage.
+   */
   Operation op;
   /*!
    * \brief The original operator.
@@ -312,42 +331,50 @@ class StageNode : public Node {
   Operation origin_op;
   /*! \brief All the nodes in the iter var */
   Array<IterVar> all_iter_vars;
-  /*!
-   * \brief The current leafs in the schedule.
-   *  Operations can only be performed in leaves.
-   */
+  /*! \brief The current active leaf iter vars in the stage. */
   Array<IterVar> leaf_iter_vars;
   /*!
    * \brief Specify threads to be launched at the stage.
    *  This is only valid for composite ops such as Scan.
    */
-  Array<IterVar> outermost_threads;
+  Array<IterVar> env_threads;
   /*! \brief The relation bwteen of IterVars */
   Array<IterVarRelation> relations;
   /*! \brief additional attributes about iter var. */
   Map<IterVar, IterVarAttr> iter_var_attrs;
   /*! \brief The attachment type of the schedule */
-  AttachType attach_type{kNone};
+  AttachType attach_type{kGroupRoot};
   /*! \brief The attach point of this schedule. */
   IterVar attach_ivar;
   /*! \brief The stage this node attaches to */
   Stage attach_stage;
+  /*! \brief The thread storage scope level of the stage */
+  std::string scope;
   /*! \brief Whether this is an output stage */
   bool is_output{false};
+  /*!
+   * \brief The parent group of the current stage.
+   *  The stage cannot be assigned to stages outside the group.
+   */
+  Stage group;
+  /*! \brief Number of direct child stages, only used for group stage.*/
+  int num_child_stages{0};
   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("scope", &scope);
     v->Visit("op", &op);
     v->Visit("origin_op", &origin_op);
     v->Visit("all_iter_vars", &all_iter_vars);
     v->Visit("leaf_iter_vars", &leaf_iter_vars);
-    v->Visit("outermost_threads", &outermost_threads);
+    v->Visit("env_threads", &env_threads);
     v->Visit("relations", &relations);
     v->Visit("iter_var_attrs", &iter_var_attrs);
     v->Visit("attach_type", &attach_type);
     v->Visit("attach_ivar", &attach_ivar);
     v->Visit("attach_stage", &attach_stage);
+    v->Visit("scope", &scope);
     v->Visit("is_output", &is_output);
+    v->Visit("group", &group);
+    v->Visit("num_child_stages", &num_child_stages);
   }
   static constexpr const char* _type_key = "Stage";
@@ -360,19 +387,34 @@ class ScheduleNode : public Node {
   /*! \brief The output operations in original data flow graph */
   Array<Operation> outputs;
   /*!
-   * \brief list of all stages for non-placeholder ops.
+   * \brief list of all stages for ops.
    *  The stages are sorted in dependency order.
    */
   Array<Stage> stages;
-  /*! \brief map of operation to the stages */
+  /*!
+   * \brief List of all stage groups.
+   */
+  Array<Stage> groups;
+  /*! \brief map of original operation to the stages */
   Map<Operation, Stage> stage_map;
+  /*!
+   * \brief Internal stage map to map internal ops to stages.
+   *  This is created on demand and can be invalidated.
+   */
+  std::unordered_map<const Node*, Stage> op2stage_cache_;
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("outputs", &outputs);
     v->Visit("stages", &stages);
+    v->Visit("groups", &groups);
     v->Visit("stage_map", &stage_map);
   }
+  /*! \brief Initialize temp cache. */
+  void InitCache();
+  /*! \brief Invalidate temp cache. */
+  void InvalidateCache();
   static constexpr const char* _type_key = "Schedule";
   TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
 };
@@ -381,10 +423,13 @@ class ScheduleNode : public Node {
 class IterVarAttrNode : public Node {
  public:
   /*! \brief The iteration type. */
-  IterVarType iter_type;
+  IterVarType iter_type{kDataPar};
+  /*! \brief The thread this iter Var binds, can be null */
+  IterVar bind_thread;
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("iter_type", &iter_type);
+    v->Visit("bind_thread", &bind_thread);
   }
   static constexpr const char* _type_key = "IterVarAttr";
@@ -412,17 +457,22 @@ class SplitNode : public IterVarRelationNode {
   IterVar inner;
   /*! \brief The split factor */
   Expr factor;
+  /*! \brief Number of parts, only factor or nparts can be given */
+  Expr nparts;
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("parent", &parent);
     v->Visit("outer", &outer);
     v->Visit("inner", &inner);
     v->Visit("factor", &factor);
+    v->Visit("nparts", &nparts);
   }
-  static IterVarRelation make(
-      IterVar parent, IterVar outer,
-      IterVar inner, Expr factor);
+  static IterVarRelation make(IterVar parent,
+                              IterVar outer,
+                              IterVar inner,
+                              Expr factor,
+                              Expr nparts);
   static constexpr const char* _type_key = "Split";
   TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
@@ -485,12 +535,6 @@ inline StageNode* Stage::operator->() {
   return static_cast<StageNode*>(node_.get());
 }
-inline bool Stage::is_scheduled() const {
-  const StageNode* n = operator->();
-  return !(n->relations.empty() && n->attach_type == kNone &&
-           n->all_iter_vars.same_as(n->leaf_iter_vars));
-}
 inline const ScheduleNode* Schedule::operator->() const {
   return static_cast<const ScheduleNode*>(node_.get());
 }
@@ -505,6 +549,5 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
 inline const IterVarAttrNode* IterVarAttr::operator->() const {
   return static_cast<const IterVarAttrNode*>(node_.get());
 }
 }  // namespace tvm
 #endif  // TVM_SCHEDULE_H_
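Editor's note: taken together, the header now expresses thread placement through `bind` instead of the old `rebase` idiom, and both split flavors return the outer/inner pair. A rough sketch of the resulting workflow from the Python side, assuming the contemporary `tvm.Schedule` constructor, `op.axis` reflection, and CUDA-style thread tags:

import tvm

n = tvm.Var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")

s = tvm.Schedule(B.op)
# Split by factor: the inner loop xi gets extent 64.
xo, xi = s[B].split(B.op.axis[0], factor=64)
# bind attaches the loops directly to thread IterVars,
# replacing the old rebase-onto-thread-axis pattern.
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))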
@@ -11,6 +11,7 @@ from ._ctypes._function import Function
 from ._ctypes._function import _init_api_functions, register_func, get_global_func
 from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
 from . import _api_internal
+from . import _base
 from . import make as _make
 from . import expr as _expr
 from . import tensor as _tensor
@@ -142,7 +143,7 @@ def compute(shape, fcompute, name="compute"):
     return op_node.output(0)
-def scan(init, update, state_placeholder, name="scan"):
+def scan(init, update, state_placeholder, inputs=None, name="scan"):
     """Construct new tensors by scanning over axis.
     Parameters
@@ -156,6 +157,10 @@ def scan(init, update, state_placeholder, name="scan"):
     state_placeholder: Tensor or list of Tensor
         The placeholder variables used by update.
+    inputs: Tensor or list of Tensor, optional
+        The list of inputs to the scan. This is not required, but can
+        be useful for the compiler to detect scan body faster.
     name: str, optional
         The name hint of the tensor
@@ -173,7 +178,7 @@ def scan(init, update, state_placeholder, name="scan"):
       s_state = tvm.placeholder((m, n))
       s_init = tvm.compute((1, n), lambda _, i: X[0, i])
       s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
-      res = tvm.scan(s_init, s_update, s_state)
+      res = tvm.scan(s_init, s_update, s_state, X)
     """
     if isinstance(init, _tensor.Tensor):
         init = [init]
@@ -181,10 +186,14 @@ def scan(init, update, state_placeholder, name="scan"):
         update = [update]
     if isinstance(state_placeholder, _tensor.Tensor):
         state_placeholder = [state_placeholder]
+    if isinstance(inputs, _tensor.Tensor):
+        inputs = [inputs]
+    if inputs is None:
+        inputs = []
     if len(init) != len(update) or len(init) != len(state_placeholder):
         raise ValueError("init, update, state_placeholder must have same length")
     axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
-    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
+    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder, inputs)
     res = [op.output(i) for i in range(len(update))]
     return res[0] if len(res) == 1 else res
@@ -340,20 +349,25 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
     return _api_internal._IterVar(dom, var, iter_type, thread_tag)
-def thread_axis(dom, tag, name=''):
+def thread_axis(dom=None, tag='', name=''):
     """Create a new IterVar to represent thread index.
     Parameters
     ----------
-    dom : Range
-        The domain of iteration.
+    dom : Range or str
+        The domain of iteration
+        When str is passed, dom is set to None and str is used as tag
-    tag : str
+    tag : str, optional
         The thread tag
     name : str, optional
         The name of the var.
     """
+    if isinstance(dom, _base.string_types):
+        tag, dom = dom, None
+    if len(tag) == 0:
+        raise ValueError("tag must be given as Positional or keyword argument")
     name = name if name else tag
     return _IterVar(dom, name, 1, tag)
...
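Editor's note: with this refactor the common case needs only the tag; passing a string first binds it to `tag` and leaves `dom` unset. A small sketch (the tuple-domain form assumes `_IterVar`'s usual tuple-to-Range conversion):

import tvm

# Tag-only form, new in this change: dom is left as None.
tx = tvm.thread_axis("threadIdx.x")
# Explicit-domain form still works; keyword use keeps the intent clear.
bx = tvm.thread_axis((0, 64), tag="blockIdx.x")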
@@ -41,6 +41,30 @@ class Schedule(NodeBase):
         """
         _api_internal._ScheduleNormalize(self)
+    def create_group(self, outputs, inputs, include_inputs=False):
+        """Create stage group by giving output and input boundary.
+        The operators between outputs and inputs are placed as member of group.
+        outputs are include in the group, while inputs are not included.
+        Parameters
+        ----------
+        outputs : list of Tensors
+            The outputs of the group.
+        inputs : list of Tensors
+            The inputs of the group.
+        include_inputs : boolean, optional
+            Whether include input operations in the group if they are used by outputs.
+        """
+        if isinstance(outputs, _tensor.Tensor):
+            outputs = [outputs]
+        if isinstance(inputs, _tensor.Tensor):
+            inputs = [inputs]
+        return _api_internal._ScheduleCreateGroup(
+            self, outputs, inputs, include_inputs)
     def cache_read(self, tensor, scope, readers):
         """Create a cache read of original tensor for readers.
@@ -112,25 +136,7 @@ class Schedule(NodeBase):
 @register_node
 class Stage(NodeBase):
     """A Stage represents schedule for one operation."""
-    def rebase(self, parent, rebased):
-        """Rebase parent by an existing thread axis.
-        Parameters
-        ----------
-        parent : IterVar
-            The parent iter var.
-        rebased : IterVar
-            The rebased iter var.
-        Returns
-        -------
-        rebased : IterVar
-            The rebased itervar.
-        """
-        _api_internal._StageRebase(self, parent, rebased)
-        return rebased
-    def split(self, parent, factor=None, outer=None):
+    def split(self, parent, factor=None, nparts=None):
         """Split the stage either by factor providing outer scope, or both
         Parameters
@@ -141,8 +147,8 @@ class Stage(NodeBase):
         factor : Expr, optional
             The splitting factor
-        outer : IterVar, optional
-            The outer split variable
+        nparts : Expr, optional
+            The number of outer parts.
         Returns
         -------
@@ -152,11 +158,13 @@ class Stage(NodeBase):
         inner : IterVar
             The inner variable of iteration.
         """
-        if outer is not None:
-            inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
+        if nparts is not None:
+            if factor is not None:
+                raise ValueError("Donot need to provide both outer and nparts")
+            outer, inner = _api_internal._StageSplitByNParts(self, parent, nparts)
         else:
             if factor is None:
-                raise ValueError("either outer or factor need to be provided")
+                raise ValueError("Either nparts or factor need to be provided")
             outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
         return outer, inner
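Editor's note: the two modes are mutually exclusive — factor pins the inner extent while nparts pins the outer trip count. A quick self-contained sketch:

import tvm

m = tvm.Var("m")
A = tvm.placeholder((m,), name="A")
B = tvm.compute((m,), lambda i: A[i] * 2, name="B")
s = tvm.Schedule(B.op)

xo, xi = s[B].split(B.op.axis[0], factor=32)  # inner extent fixed at 32
xoo, xoi = s[B].split(xo, nparts=4)           # outer trip count fixed at 4
# Passing both factor and nparts raises ValueError, per the check above.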
@@ -188,8 +196,21 @@ class Stage(NodeBase):
         """
         return _api_internal._StageSetScope(self, scope)
-    def outermost_threads(self, threads):
-        """Force launch threads at outermost scope of the stage.
+    def bind(self, ivar, thread_ivar):
+        """Bind ivar to thread index thread_ivar
+        Parameters
+        ----------
+        ivar : IterVar
+            The iteration to be binded to thread.
+        thread_ivar : IterVar
+            The thread to be binded.
+        """
+        _api_internal._StageBind(self, ivar, thread_ivar)
+    def env_threads(self, threads):
+        """Mark threads to be launched at the outer scope of composed op.
         Parameters
         ----------
@@ -198,7 +219,7 @@ class Stage(NodeBase):
         """
         if isinstance(threads, _collections.IterVar):
             threads = [threads]
-        _api_internal._StageOutermostThreads(self, threads)
+        _api_internal._StageEnvThreads(self, threads)
     def compute_at(self, parent, scope):
         """Attach the stage at parent's scope
...
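Editor's note: env_threads is the composite-op counterpart of bind — instead of binding one leaf loop, it launches the given threads once around the whole scope of a scan (or group) stage. A hedged sketch built on the scan example from api.py above:

import tvm

m = tvm.Var("m")
n = tvm.Var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, inputs=X)

s = tvm.Schedule(res.op)
tx = tvm.thread_axis("threadIdx.x")
# Launch threadIdx.x once around the scan body instead of per stage.
s[res].env_threads([tx])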
@@ -182,7 +182,8 @@ TVM_REGISTER_API(_ScanOp)
                      args[1],
                      args[2],
                      args[3],
-                     args[4]);
+                     args[4],
+                     args[5]);
   });
 TVM_REGISTER_API(_ExternOp)
@@ -219,27 +220,26 @@ TVM_REGISTER_API(_StageSetScope)
         .set_scope(args[1]);
   });
-TVM_REGISTER_API(_StageRebase)
+TVM_REGISTER_API(_StageBind)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    IterVar outer, inner;
     args[0].operator Stage()
-        .rebase(args[1], args[2]);
+        .bind(args[1], args[2]);
   });
 TVM_REGISTER_API(_StageSplitByFactor)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     IterVar outer, inner;
     args[0].operator Stage()
-        .split(args[1], &outer, &inner, args[2]);
+        .split(args[1], args[2], &outer, &inner);
     *ret = Array<IterVar>({outer, inner});
   });
-TVM_REGISTER_API(_StageSplitByOuter)
+TVM_REGISTER_API(_StageSplitByNParts)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    IterVar inner;
+    IterVar outer, inner;
     args[0].operator Stage()
-        .split(args[1], args[2], &inner, args[3]);
-    *ret = inner;
+        .split_by_nparts(args[1], args[2], &outer, &inner);
+    *ret = Array<IterVar>({outer, inner});
   });
 TVM_REGISTER_API(_StageFuse)
@@ -278,15 +278,17 @@ TVM_REGISTER_API(_StageTile)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     IterVar x_outer, y_outer, x_inner, y_inner;
     args[0].operator Stage()
-        .tile(args[1], args[2], &x_outer, &y_outer,
-              &x_inner, &y_inner, args[3], args[4]);
+        .tile(args[1], args[2],
+              args[3], args[4],
+              &x_outer, &y_outer,
+              &x_inner, &y_inner);
     *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
   });
-TVM_REGISTER_API(_StageOutermostThreads)
+TVM_REGISTER_API(_StageEnvThreads)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     args[0].operator Stage()
-        .outermost_threads(args[1]);
+        .env_threads(args[1]);
   });
 TVM_REGISTER_API(_StageUnroll)
@@ -313,6 +315,12 @@ TVM_REGISTER_API(_ScheduleNormalize)
         .normalize();
   });
+TVM_REGISTER_API(_ScheduleCreateGroup)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .create_group(args[1], args[2], args[3]);
+  });
 TVM_REGISTER_API(_ScheduleCacheRead)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Schedule()
...
@@ -34,9 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
 REGISTER_SCHEDULE_PASS1(InferBound);
 REGISTER_SCHEDULE_PASS1(CreateReadGraph);
 REGISTER_SCHEDULE_PASS2(PostDFSOrder);
-REGISTER_SCHEDULE_PASS1(ScanGetBody);
 REGISTER_SCHEDULE_PASS1(CreateAttachPath);
-REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
+REGISTER_SCHEDULE_PASS1(ScanGetBody);
+REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
 REGISTER_SCHEDULE_PASS2(ScheduleOps);
 }  // namespace schedule
...
@@ -35,7 +35,6 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
 }
 std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
-  LOG(INFO) << "ssa get id";
   if (name_alloc_map_.count(src)) return src;
   auto it = ssa_assign_map_.find(src);
   if (it != ssa_assign_map_.end()) {
...
@@ -132,7 +132,6 @@ void ComputeOpNode::PropBoundToInputs(
 void ComputeOpNode::GatherBound(
     const Operation& self,
-    const GraphContext& graph_ctx,
     const std::unordered_map<Tensor, TensorDom>& tensor_dom,
     std::unordered_map<IterVar, Range>* out_dom_map) const {
   const TensorDom& tdom = tensor_dom.at(self.output(0));
...
@@ -99,7 +99,6 @@ void ExternOpNode::PropBoundToInputs(
 void ExternOpNode::GatherBound(
     const Operation& self,
-    const GraphContext& graph_ctx,
     const std::unordered_map<Tensor, TensorDom>& tensor_dom,
     std::unordered_map<IterVar, Range>* out_dom_map) const {
 }
...
@@ -38,20 +38,29 @@ MakeLoopNest(const Stage& stage,
       value_map[iv] = iv->var;
       continue;
     }
+    // Bind iv could be another thread.
+    IterVar bind_iv = iv;
+    if (stage->iter_var_attrs.count(iv)) {
+      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
+      if (bind_thread.defined()) bind_iv = bind_thread;
+    }
     Range dom = dom_map.at(iv);
     // initialize the offset and loop_level
-    Var var = iv->var;
+    Var var = bind_iv->var;
     if (new_loop_var) {
-      var = Var(iv->var->name_hint + ".init", iv->var.type());
+      var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
     }
     // Mark the iter var in the IR, to remember the point
-    if (iv->thread_tag.length() == 0) {
+    if (bind_iv->thread_tag.length() == 0) {
       ForType for_type = ForType::Serial;
       if (stage->iter_var_attrs.count(iv)) {
         switch (stage->iter_var_attrs[iv]->iter_type) {
          case kUnrolled: for_type = ForType::Unrolled; break;
          case kVectorized: for_type = ForType::Vectorized; break;
          case kParallelized: for_type = ForType::Parallel; break;
+         case kDataPar: break;
          default: LOG(FATAL) << "Unknown iter type"
                              << stage->iter_var_attrs[iv]->iter_type
                              << " in the iter_var_attrs";
@@ -67,7 +76,7 @@ MakeLoopNest(const Stage& stage,
                     for_type, DeviceAPI::None, no_op));
       value_map[iv] = var;
     } else {
-      Var idx(iv->var->name_hint + ".idx", iv->var.type());
+      Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
       nest[i + 1].emplace_back(
           For::make(idx, 0, dom->extent,
                     for_type, DeviceAPI::None, no_op));
@@ -76,29 +85,29 @@ MakeLoopNest(const Stage& stage,
       nest[i + 1].emplace_back(
           LetStmt::make(var, new_value, no_op));
       }
-    } else if (iv->thread_tag == "vthread") {
+    } else if (bind_iv->thread_tag == "vthread") {
       // virtual thread
       // Always restrict threaded IterVar to starts from 0.
       CHECK(is_zero(dom->min));
       CHECK(is_positive_const(dom->extent));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
+          AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
       value_map[iv] = var;
-    } else if (iv->thread_tag == "pipeline") {
+    } else if (bind_iv->thread_tag == "pipeline") {
       // pipeline marker.
       CHECK(is_zero(dom->min));
       CHECK(is_one(dom->extent));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
+          AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
       value_map[iv] = dom->min;
     } else {
       // Always restrict threaded IterVar to starts from 0.
       CHECK(is_zero(dom->min));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
+          AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
       if (is_one(dom->extent)) {
         value_map[iv] = dom->min;
       } else {
...
@@ -33,7 +33,6 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
   return shape;
 }
-
 Operation PlaceholderOpNode::make(std::string name,
                                   Array<Expr> shape,
                                   Type dtype) {
@@ -66,7 +65,6 @@ void PlaceholderOpNode::PropBoundToInputs(
 void PlaceholderOpNode::GatherBound(
     const Operation& self,
-    const GraphContext& graph_ctx,
     const std::unordered_map<Tensor, TensorDom>& tensor_dom,
     std::unordered_map<IterVar, Range>* out_dom_map) const {
 }
...
@@ -47,7 +47,8 @@ Operation ScanOpNode::make(std::string name,
                            IterVar axis,
                            Array<Tensor> init,
                            Array<Tensor> update,
-                           Array<Tensor> state_placeholder) {
+                           Array<Tensor> state_placeholder,
+                           Array<Tensor> inputs) {
   auto n = std::make_shared<ScanOpNode>();
   CHECK_EQ(init.size(), update.size());
   CHECK_EQ(init.size(), state_placeholder.size());
@@ -89,12 +90,14 @@ Operation ScanOpNode::make(std::string name,
   n->init = init;
   n->update = update;
   n->state_placeholder = state_placeholder;
+  n->inputs = inputs;
   return Operation(n);
 }
 Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
+                   Array<Tensor> inputs,
                    std::string name) {
   IterVar scan_axis =
       IterVarNode::make(
@@ -102,7 +105,7 @@ Array<Tensor> scan(Array<Tensor> init,
           init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
         Var(name + ".idx"), kOrdered);
   Operation op = ScanOpNode::make(
-      name, scan_axis, init, update, state_placeholder);
+      name, scan_axis, init, update, state_placeholder, inputs);
   Array<Tensor> res;
   for (int i = 0; i < op->num_outputs(); ++i) {
     res.push_back(op.output(i));
@@ -179,7 +182,6 @@ void ScanOpNode::PropBoundToInputs(
 void ScanOpNode::GatherBound(
     const Operation& self,
-    const GraphContext& graph_ctx,
     const std::unordered_map<Tensor, TensorDom>& tensor_dom,
     std::unordered_map<IterVar, Range>* out_dom_map) const {
   CHECK_EQ(self.operator->(), this);
@@ -200,8 +202,7 @@ void ScanOpNode::GatherBound(
   Range r = arith::Union(time_dom).cover_range(sdom);
   (*out_dom_map)[this->scan_axis] = Range::make_with_min_extent(
       sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
-  Array<Operation> body = ScanGetBody_(this, graph_ctx.feed_graph);
-  Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self, body);
+  Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
   // Update for spatial axis.
   size_t sp_idx = 0;
   for (size_t i = 0; i < output.size(); ++i) {
...
@@ -15,10 +15,23 @@
 namespace tvm {
 namespace schedule {
+/*! \brief The graph context used during bound inference. */
+struct GraphContext {
+  /*! \brief The feed graph */
+  FeedGraph feed_graph;
+};
 // check if scope
-inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
+inline bool ScopeRelax(const IterVar& ivar,
+                       const std::unordered_map<IterVar, IterVar>& bind_map,
+                       const std::string& scope) {
   using runtime::ThreadScope;
   using runtime::StorageScope;
+  auto it = bind_map.find(ivar);
+  IterVar iv = ivar;
+  if (it != bind_map.end()) {
+    iv = it->second;
+  }
   if (iv->thread_tag.length() == 0) return false;
   if (scope.length() == 0) return false;
@@ -28,10 +41,16 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
 void InferRootBound(const Stage& stage,
                     const GraphContext& ctx,
                     const AttachPath& attach_path,
+                    const std::unordered_map<IterVar, IterVar>& bind_map,
                     std::unordered_map<IterVar, Range>* rmap) {
   CHECK_NE(stage->attach_type, kInline)
       << "call schedule.normalize before scheduleops";
   if (stage->attach_type == kInlinedAlready) return;
+  if (stage->is_output) {
+    // verify correctness.
+    CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
+        << "Output must be attached at root";
+  }
   if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
     for (auto iv : stage->op->root_iter_vars()) {
       CHECK(iv->dom.defined());
@@ -42,8 +61,10 @@ void InferRootBound(const Stage& stage,
   }
   // parent stage, if any
   Stage parent;
-  if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) {
-    parent = stage->attach_stage;
+  Stage attach_spec = stage.GetAttachSpec();
+  if (attach_spec->attach_type == kScope ||
+      attach_spec->attach_type == kScanUpdate) {
+    parent = attach_spec->attach_stage;
   }
   // The tensor domain.
   std::unordered_map<Tensor, TensorDom> tmap;
@@ -72,13 +93,11 @@ void InferRootBound(const Stage& stage,
   // from the already inferred bounds.
   std::unordered_map<const Variable*, IntSet> relax_set;
   for (IterVar iv : attach_path.at(stage->op)) {
-    if (ScopeRelax(iv, stage->scope)) {
+    if (ScopeRelax(iv, bind_map, stage->scope)) {
       relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
     }
   }
   if (direct_consume_by_parent) {
-    // parent stage if exist
-    Stage parent = stage->attach_stage;
     // Bound inference logics in parent.
     std::unordered_map<IterVar, IntSet> up_state;
     bool fix_value = true;
@@ -89,16 +108,16 @@ void InferRootBound(const Stage& stage,
       CHECK(is_zero(vrange->min))
          << "InferBound requires every leaf iter var's min equals 0, "
          << " call schedule.normalize to achieve this. "
-         << " stage=" << parent;
+         << " stage=" << parent << ", vrange=" << vrange->min;
       // special optimization to remove trivial loop
       if (is_one(vrange->extent)) {
         up_state[iv] = IntSet::single_point(vrange->min);
-      } else if (fix_value && !ScopeRelax(iv, stage->scope)) {
+      } else if (fix_value && !ScopeRelax(iv, bind_map, stage->scope)) {
         up_state[iv] = IntSet::single_point(iv->var);
       } else {
         up_state[iv] = IntSet::range(vrange);
       }
-      if (stage->attach_ivar == iv) {
+      if (attach_spec->attach_ivar == iv) {
         fix_value = false;
       }
     }
@@ -159,7 +178,7 @@ void InferRootBound(const Stage& stage,
     }
     op->PropBoundToInputs(op, dom_map, &tmap);
   }
-  stage->op->GatherBound(stage->op, ctx, tmap, rmap);
+  stage->op->GatherBound(stage->op, tmap, rmap);
 }
 Map<IterVar, Range> InferBound(const Schedule& sch) {
@@ -167,18 +186,25 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
   for (Operation op : sch->outputs) {
     roots.push_back(sch->stage_map[op]->op);
   }
+  std::unordered_map<IterVar, IterVar> bind_map;
+  for (Stage stage : sch->stages) {
+    for (auto kv : stage->iter_var_attrs) {
+      if (kv.second->bind_thread.defined()) {
+        CHECK(!bind_map.count(kv.first));
+        bind_map[kv.first] = kv.second->bind_thread;
+      }
+    }
+  }
   GraphContext ctx;
   ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
   AttachPath attach_path = CreateAttachPath(sch);
   std::unordered_map<IterVar, Range> ret;
   for (size_t i = sch->stages.size(); i != 0; --i) {
     const Stage& stage = sch->stages[i - 1];
-    InferRootBound(stage, ctx, attach_path, &ret);
+    InferRootBound(stage, ctx, attach_path, bind_map, &ret);
     // pass down to get bound of all iter vars.
     PassDownDomain(stage, &ret);
-    // setup outer most threads.
-    for (IterVar iv : stage->outermost_threads) {
+    for (IterVar iv : stage->env_threads) {
       CHECK(iv->dom.defined());
       ret[iv] = iv->dom;
     }
...
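Editor's note: bound inference now resolves each stage's effective attach point through GetAttachSpec and consults bind_map, so a loop bound to a thread relaxes like the thread itself. Since InferBound is registered as a schedule pass above, a quick hedged way to inspect the inferred ranges from Python:

import tvm

n = tvm.Var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
s = tvm.Schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=16)

bounds = tvm.schedule.InferBound(s)  # Map from IterVar to Range
# The inner loop's inferred extent equals the split factor.
print(bounds[xi])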
@@ -7,6 +7,7 @@
 #include <tvm/ir_visitor.h>
 #include <tvm/operation.h>
 #include <unordered_set>
+#include <unordered_map>
 #include "./graph.h"
 namespace tvm {
@@ -82,6 +83,60 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
   return rmap;
 }
// Do DFS visit to get the subgraph.
// Return if op is inside the subgraph.
bool GetSubGraphByPostDFS_(
const Operation& op,
const std::unordered_set<const Node*>& boundary,
bool include_bounary,
std::unordered_map<const Node*, bool>* visited,
Array<Operation>* result) {
if (visited->count(op.get())) {
return visited->at(op.get());
}
if (boundary.count(op.get())) {
(*visited)[op.get()] = true;
if (include_bounary) {
result->push_back(op);
}
return true;
}
// mark to avoid loop
// Not necessary for DAG.
(*visited)[op.get()] = false;
// check if we can reach boundary.
bool reach_boundary = false;
for (Tensor t : op->InputTensors()) {
if (GetSubGraphByPostDFS_(t->op, boundary,
include_bounary,
visited, result)) {
reach_boundary = true;
}
}
(*visited)[op.get()] = reach_boundary;
if (reach_boundary) {
result->push_back(op);
}
return reach_boundary;
}
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs) {
Array<Operation> result;
std::unordered_set<const Node*> boundary;
for (Tensor t : inputs) {
boundary.insert(t->op.get());
}
std::unordered_map<const Node*, bool> visited;
for (Tensor t : outputs) {
GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
&visited, &result);
}
return result;
}
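The DFS above is compact but dense; the following Python sketch restates the same algorithm over a generic dependency graph. The names get_subgraph and deps are illustrative, not part of the codebase.

def get_subgraph(outputs, inputs, deps, include_inputs):
    """deps maps each node to the nodes it reads from (its inputs)."""
    boundary = set(inputs)
    visited = {}   # node -> whether it can reach the boundary
    result = []    # post-DFS order: producers before consumers

    def dfs(node):
        if node in visited:
            return visited[node]
        if node in boundary:
            visited[node] = True
            if include_inputs:
                result.append(node)
            return True
        visited[node] = False  # pre-mark to cut cycles, as in the C++ above
        # Visit all children without short-circuiting, mirroring the C++ loop.
        reach = any([dfs(d) for d in deps.get(node, ())])
        visited[node] = reach
        if reach:
            result.append(node)
        return reach

    for out in outputs:
        dfs(out)
    return result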
void PostDFSOrder(const Operation& op, void PostDFSOrder(const Operation& op,
const ReadGraph& g, const ReadGraph& g,
std::unordered_set<Operation>* visited, std::unordered_set<Operation>* visited,
...@@ -118,30 +173,38 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) { ...@@ -118,30 +173,38 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) {
AttachPath CreateAttachPath(Schedule sch) { AttachPath CreateAttachPath(Schedule sch) {
AttachPath ret; AttachPath ret;
for (Stage stage : sch->stages) { for (Stage stage : sch->stages) {
if (stage->attach_type == kScanUpdate) { std::unordered_set<const Node*> visited;
const Stage& parent = stage->attach_stage;
stage->attach_ivar =
parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1];
}
}
for (Stage stage : sch->stages) {
Array<IterVar> path; Array<IterVar> path;
for (Stage s = stage; s.defined();) {
for (Stage s = stage; s->attach_type == kScope || s->attach_type == kScanUpdate;) { CHECK(!visited.count(s.get()))
IterVar attach_ivar = s->attach_ivar; << "Found a loop in the compute_at attach group";
s = s->attach_stage; visited.insert(s.get());
bool start_attach = false; Stage spec = s.GetAttachSpec();
bool start_attach;
IterVar attach_ivar;
if (spec->attach_type == kScope) {
attach_ivar = spec->attach_ivar;
s = spec->attach_stage;
start_attach = false;
CHECK(attach_ivar.defined());
} else if (spec->attach_type == kScanUpdate) {
s = spec->attach_stage;
start_attach = true;
} else {
break;
}
CHECK(s.defined());
for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = s->leaf_iter_vars[i - 1]; IterVar iv = s->leaf_iter_vars[i - 1];
if (iv == attach_ivar) start_attach = true; if (!start_attach && iv.same_as(attach_ivar)) {
start_attach = true;
}
if (start_attach) path.push_back(iv); if (start_attach) path.push_back(iv);
} }
CHECK(start_attach) CHECK(start_attach)
<< "Invalid Schedule: cannot find attach point " << attach_ivar << "Invalid Schedule: cannot find attach point " << attach_ivar
<< " in the schedule of " << s->op; << " in the schedule of " << s->op;
} }
if (!ret.count(stage->op)) { if (!ret.count(stage->op)) {
ret.Set(stage->op, path); ret.Set(stage->op, path);
} }
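Restated in Python for readability, the new attach-path walk looks roughly as follows; the field and method names (get_attach_spec, attach_type, leaf_iter_vars) are hypothetical mirrors of the C++ Stage API, not actual bindings.

def create_attach_path(stage):
    visited, path = set(), []
    s = stage
    while s is not None:
        assert id(s) not in visited, "loop in compute_at attach chain"
        visited.add(id(s))
        spec = s.get_attach_spec()              # strip group-root sugar
        if spec.attach_type == "scope":
            attach_ivar, s, start = spec.attach_ivar, spec.attach_stage, False
        elif spec.attach_type == "scan_update":
            attach_ivar, s, start = None, spec.attach_stage, True
        else:
            break
        for iv in reversed(s.leaf_iter_vars):   # innermost to outermost
            if not start and iv is attach_ivar:
                start = True
            if start:
                path.append(iv)
        assert start, "cannot find attach point in parent schedule"
    return path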
...@@ -203,53 +266,22 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { ...@@ -203,53 +266,22 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
return reach; return reach;
} }
// Get all the operations that forms body of scan Array<Operation> ScanGetBody(const Operation& scan_op) {
void ScanGetBodyPostDFS_( const ScanOpNode* scan = scan_op.as<ScanOpNode>();
Operation op, // Get the body.
const ScanOpNode* scan, Array<Tensor> inputs;
const FeedGraph& feed_graph,
std::unordered_set<const Node*>* visited,
Array<Operation>* result) {
if (op.get() == scan) return;
bool empty_feed = true;
for (int i = 0; i < op->num_outputs(); ++i) {
auto it = feed_graph.find(op.output(i));
if (it != feed_graph.end() && it->second.size()) {
empty_feed = false;
for (const Operation& xop : it->second) {
if (visited->count(xop.get())) continue;
visited->insert(xop.get());
ScanGetBodyPostDFS_(xop, scan, feed_graph, visited, result);
result->push_back(xop);
}
}
}
if (empty_feed && op.get() != scan) {
LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan";
}
}
Array<Operation> ScanGetBody_(
const ScanOpNode* scan,
const FeedGraph& feed_graph) {
CHECK(scan != nullptr);
std::unordered_set<const Node*> visited;
Array<Operation> result;
for (Tensor t : scan->state_placeholder) { for (Tensor t : scan->state_placeholder) {
ScanGetBodyPostDFS_(t->op, scan, feed_graph, &visited, &result); inputs.push_back(t);
} }
return result; for (Tensor t : scan->inputs) {
} inputs.push_back(t);
}
Array<Operation> ScanGetBody(const Operation& scan) { return GetSubGraph(scan->update, inputs, false);
return ScanGetBody_(scan.as<ScanOpNode>(),
CreateFeedGraph(CreateReadGraph({scan})));
} }
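With GetSubGraph available, the scan body query collapses to one subgraph call. A hedged Python mirror of the rewrite above, reusing the get_subgraph sketch shown earlier; read_deps is an assumed mapping of each op to the producer ops it reads.

def scan_get_body(scan, read_deps):
    # Boundary ops (state placeholders and external inputs) are excluded;
    # the update ops themselves are included.
    inputs = list(scan.state_placeholder) + list(scan.inputs)
    return get_subgraph(scan.update, inputs, read_deps, include_inputs=False)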
Map<IterVar, Expr> ScanFixPointAnalysis( Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
const Operation& scan_op, const Array<Operation>& body) {
const ScanOpNode* scan = scan_op.as<ScanOpNode>(); const ScanOpNode* scan = scan_op.as<ScanOpNode>();
CHECK(body[0].get() == scan); Array<Operation> body = ScanGetBody(scan_op);
std::unordered_map<TensorDimKey, const Node*> exact_reach; std::unordered_map<TensorDimKey, const Node*> exact_reach;
std::unordered_set<const Node*> fail_set; std::unordered_set<const Node*> fail_set;
...@@ -276,8 +308,8 @@ Map<IterVar, Expr> ScanFixPointAnalysis( ...@@ -276,8 +308,8 @@ Map<IterVar, Expr> ScanFixPointAnalysis(
} }
}; };
// prop exact reach back. // prop exact reach back.
for (size_t i = body.size(); i != 1; --i) { for (size_t i = 0; i < body.size(); ++i) {
const Operation& op = body[i - 1]; const Operation& op = body[i];
if (op.as<ScanOpNode>()) { if (op.as<ScanOpNode>()) {
const auto& update = op.as<ScanOpNode>()->update; const auto& update = op.as<ScanOpNode>()->update;
const auto& init = op.as<ScanOpNode>()->init; const auto& init = op.as<ScanOpNode>()->init;
......
...@@ -27,6 +27,11 @@ using ReadGraph = Map<Operation, Array<Tensor> >; ...@@ -27,6 +27,11 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
using AttachPath = Map<Operation, Array<IterVar> >; using AttachPath = Map<Operation, Array<IterVar> >;
/*! /*!
* \brief The map between a tensor and the operations it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*!
* \brief Get read graph of each operation to all the * \brief Get read graph of each operation to all the
* Tensors that it directly depends on. * Tensors that it directly depends on.
* *
...@@ -37,6 +42,23 @@ using AttachPath = Map<Operation, Array<IterVar> >; ...@@ -37,6 +42,23 @@ using AttachPath = Map<Operation, Array<IterVar> >;
ReadGraph CreateReadGraph(const Array<Operation>& roots); ReadGraph CreateReadGraph(const Array<Operation>& roots);
/*! /*!
* \brief Get the minimum subgraph between outputs and inputs.
* The subgraph contains every operation that is reachable from
* at least one of the inputs and can reach at least one of the outputs.
*
* The inputs are not included in the subgraph; the outputs are included.
*
* \param outputs The outputs of the subgraph.
* \param inputs The inputs to the subgraph.
* \param include_inputs Whether to also include the inputs in the result.
*
* \return The subgraph.
*/
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs);
/*!
* \brief Get a post DFS ordered of operations in the graph. * \brief Get a post DFS ordered of operations in the graph.
* \param roots The root of the graph. * \param roots The root of the graph.
* \param g The read graph. * \param g The read graph.
...@@ -67,14 +89,10 @@ AttachPath CreateAttachPath(Schedule sch); ...@@ -67,14 +89,10 @@ AttachPath CreateAttachPath(Schedule sch);
/*! /*!
* \brief Get all operations inside the recursion of scan. * \brief Get all operations inside the recursion of scan.
* \param scan The scan node. * \param scan_op The scan op.
* \param feed_graph The feed graph to help analysis.
* \return The body operations, in read dependency order. * \return The body operations, in read dependency order.
*/ */
Array<Operation> ScanGetBody_( Array<Operation> ScanGetBody(const Operation& scan_op);
const ScanOpNode* scan, const FeedGraph& feed_graph);
// same as ScanGetBody_, but create FeedGraph internally.
Array<Operation> ScanGetBody(const Operation& scan);
/*! /*!
* \brief Analyze each spatial dimension of scan's result. * \brief Analyze each spatial dimension of scan's result.
...@@ -85,11 +103,9 @@ Array<Operation> ScanGetBody(const Operation& scan); ...@@ -85,11 +103,9 @@ Array<Operation> ScanGetBody(const Operation& scan);
* next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...] * next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
* *
* \param scan The scan node. * \param scan The scan node.
* \param body The body of scan, sorted in reverse PostDFSOrder.
* \return Map of spatial_axis -> IntImm * \return Map of spatial_axis -> IntImm
*/ */
Map<IterVar, Expr> ScanFixPointAnalysis( Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan);
const Operation& scan, const Array<Operation>& body);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
......
...@@ -22,6 +22,23 @@ inline bool prove_equal(Expr lhs, Expr rhs) { ...@@ -22,6 +22,23 @@ inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs)); return is_zero(ir::Simplify(lhs - rhs));
} }
void Update(std::unordered_map<IterVar, Range>* p_state,
const IterVar& iv,
Range r) {
auto it = p_state->find(iv);
if (it == p_state->end()) {
(*p_state)[iv] = r;
} else {
bool match = is_zero(it->second->min);
if (!prove_equal(r->extent, it->second->extent)) match = false;
CHECK(match)
<< iv
<< " domain already inferred,"
<< " cannot prove their extents are the same "
<< it->second->extent << " vs " << r->extent;
}
}
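The Update helper centralizes a merge rule that several call sites previously open-coded: the first write records the range, and any later write must prove the same zero-based extent. A sketch of the rule, with ranges as (min, extent) tuples and prove_equal approximated by plain equality on constants:

def update(state, iv, rng):
    """state: dict IterVar -> (min, extent); rng: (min, extent)."""
    if iv not in state:
        state[iv] = rng
        return
    cur_min, cur_extent = state[iv]
    new_extent = rng[1]
    # Mirrors the CHECK above: the recorded range must start at 0 and
    # have an extent provably equal to the new one.
    assert cur_min == 0 and cur_extent == new_extent, (
        "%s domain already inferred, cannot prove extents are the same: "
        "%s vs %s" % (iv, cur_extent, new_extent))

state = {}
update(state, "ko", (0, 16))
update(state, "ko", (0, 16))   # ok: same extent
# update(state, "ko", (0, 8))  # would raise: extents differ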
void PassDownDomain(const Stage& stage, void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state, std::unordered_map<IterVar, Range>* p_state,
bool allow_missing) { bool allow_missing) {
...@@ -36,30 +53,15 @@ void PassDownDomain(const Stage& stage, ...@@ -36,30 +53,15 @@ void PassDownDomain(const Stage& stage,
CHECK(!state.count(r->inner)); CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent); const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) { if (r->factor.defined()) {
state[r->inner] = Range::make_with_min_extent(0, r->factor); Update(p_state, r->inner, Range::make_with_min_extent(0, r->factor));
if (r->outer->dom.defined()) { Update(p_state, r->outer,
state[r->outer] = r->outer->dom; Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor)));
} else { } else {
if (!state.count(r->outer)) { Update(p_state, r->outer, Range::make_with_min_extent(0, r->nparts));
state[r->outer] = Range::make_with_min_extent( Update(p_state, r->inner,
0, DivCeil(range_parent->extent, r->factor)); Range::make_with_min_extent(
} else { 0, DivCeil(range_parent->extent, r->nparts)));
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< r->outer
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same "
<< outer_ext << " vs " << outer_rng->extent;
}
}
} else {
CHECK(r->outer->dom.defined());
state[r->outer] = r->outer->dom;
state[r->inner] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->outer->dom->extent));
} }
} else if (const FuseNode* r = rel.as<FuseNode>()) { } else if (const FuseNode* r = rel.as<FuseNode>()) {
if (!state.count(r->outer) || !state.count(r->inner)) { if (!state.count(r->outer) || !state.count(r->inner)) {
...@@ -75,20 +77,20 @@ void PassDownDomain(const Stage& stage, ...@@ -75,20 +77,20 @@ void PassDownDomain(const Stage& stage,
CHECK(allow_missing); CHECK(allow_missing);
continue; continue;
} }
Range res = Range::make_with_min_extent( Update(p_state, r->rebased,
0, state.at(r->parent)->extent); Range::make_with_min_extent(
if (r->rebased->dom.defined()) { 0, state.at(r->parent)->extent));
Range rebase_rng = r->rebased->dom;
bool match = is_zero(rebase_rng->min);
if (!prove_equal(rebase_rng->extent, res->extent)) match = false;
CHECK(match) << r->rebased
<< " does not match parent scope's range";
}
state[r->rebased] = res;
} else { } else {
LOG(FATAL) << "unknown relation type"; LOG(FATAL) << "unknown relation type";
} }
} }
// Update the extents of the bound threads.
for (auto kv : stage->iter_var_attrs) {
if (kv.second->bind_thread.defined()) {
CHECK(state.count(kv.first));
Update(p_state, kv.second->bind_thread, state.at(kv.first));
}
}
} }
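A small worked example of the two split flavors handled above (DivCeil is ceiling division); the final loop additionally copies each axis's inferred extent onto the thread it is bound to.

def div_ceil(a, b):              # DivCeil in the C++ above
    return -(-a // b)

parent_extent = 10
# split(parent, factor=3): inner gets [0, 3), outer gets [0, ceil(10/3)) = [0, 4)
factor = 3
assert (div_ceil(parent_extent, factor), factor) == (4, 3)
# split_by_nparts(parent, nparts=4): outer gets [0, 4), inner gets [0, ceil(10/4)) = [0, 3)
nparts = 4
assert (nparts, div_ceil(parent_extent, nparts)) == (4, 3)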
void PassUpIndex(const Stage& stage, void PassUpIndex(const Stage& stage,
......
...@@ -55,6 +55,7 @@ void ReplaceDataFlow(const Array<Stage>& stages, ...@@ -55,6 +55,7 @@ void ReplaceDataFlow(const Array<Stage>& stages,
Tensor Schedule::cache_read(const Tensor& tensor, Tensor Schedule::cache_read(const Tensor& tensor,
const std::string& scope, const std::string& scope,
const Array<Operation>& readers) { const Array<Operation>& readers) {
(*this)->InvalidateCache();
// create identity mapping. // create identity mapping.
std::ostringstream os; std::ostringstream os;
os << tensor->op->name; os << tensor->op->name;
...@@ -81,18 +82,25 @@ Tensor Schedule::cache_read(const Tensor& tensor, ...@@ -81,18 +82,25 @@ Tensor Schedule::cache_read(const Tensor& tensor,
} }
ReplaceDataFlow((*this)->stages, &vmap); ReplaceDataFlow((*this)->stages, &vmap);
ArrayNode* stages = (*this)->stages.CopyOnWrite(); ArrayNode* stages = (*this)->stages.CopyOnWrite();
size_t pos = FindNodeRef(stages, operator[](tensor->op)); Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages, op_stage);
Stage cache_stage = Stage(cache->op); Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope); cache_stage.set_scope(scope);
CHECK_LT(pos, stages->data.size()); CHECK_LT(pos, stages->data.size());
stages->data.insert(stages->data.begin() + pos + 1, stages->data.insert(stages->data.begin() + pos + 1,
cache_stage.node_); cache_stage.node_);
(*this)->stage_map.Set(cache->op, cache_stage); (*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache; return cache;
} }
Tensor Schedule::cache_write(const Tensor& tensor, Tensor Schedule::cache_write(const Tensor& tensor,
const std::string& scope) { const std::string& scope) {
(*this)->InvalidateCache();
Stage orig_stage = operator[](tensor->op); Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>(); const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
CHECK(compute) CHECK(compute)
...@@ -123,7 +131,6 @@ Tensor Schedule::cache_write(const Tensor& tensor, ...@@ -123,7 +131,6 @@ Tensor Schedule::cache_write(const Tensor& tensor,
std::unordered_map<Tensor, Tensor> vmap; std::unordered_map<Tensor, Tensor> vmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0); vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
ReplaceDataFlow((*this)->stages, &vmap); ReplaceDataFlow((*this)->stages, &vmap);
// mutate orig stage // mutate orig stage
orig_stage->op = orig_new_op; orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
...@@ -137,6 +144,11 @@ Tensor Schedule::cache_write(const Tensor& tensor, ...@@ -137,6 +144,11 @@ Tensor Schedule::cache_write(const Tensor& tensor,
stages->data.insert(stages->data.begin() + pos, stages->data.insert(stages->data.begin() + pos,
cache_stage.node_); cache_stage.node_);
(*this)->stage_map.Set(cache_op, cache_stage); (*this)->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache_tensor; return cache_tensor;
} }
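Usage stays as in the GEMM test further down; what is new is that the inserted cache stage inherits the group of the stage it is created from and bumps that group's child counter. A minimal sketch with this commit's API:

import tvm

n = tvm.Var("n")
k = tvm.reduce_axis((0, n))
A = tvm.placeholder((n, n), name="A")
C = tvm.compute((n, n),
                lambda i, j: tvm.sum(A[i, k] * A[j, k], axis=k), name="C")
s = tvm.Schedule(C.op)
CC = s.cache_write(C, "local")        # inserted before C's stage
AA = s.cache_read(A, "shared", [CC])  # inserted right after A's stage
# CC inherits the group of C's stage, AA the group of A's stage (if any),
# each incrementing that group's num_child_stages.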
...@@ -152,6 +164,11 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -152,6 +164,11 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
attach_mark[s.get()] = 1; attach_mark[s.get()] = 1;
} }
} }
for (Stage s : sch->groups) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
if (!attach_mark.count(s.get())) continue; if (!attach_mark.count(s.get())) continue;
...@@ -176,6 +193,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -176,6 +193,12 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
s->attach_ivar = rebase_map.at(s->attach_ivar); s->attach_ivar = rebase_map.at(s->attach_ivar);
} }
} }
for (Stage s : sch->groups) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
} }
void SetScanAttach(const Schedule& sch) { // NOLINT(*) void SetScanAttach(const Schedule& sch) { // NOLINT(*)
...@@ -188,8 +211,8 @@ void SetScanAttach(const Schedule& sch) { // NOLINT(*) ...@@ -188,8 +211,8 @@ void SetScanAttach(const Schedule& sch) { // NOLINT(*)
} }
} }
void InjectInline(ScheduleNode* sch) {
void InjectInline(const Schedule& sch) { sch->InvalidateCache();
std::vector<Expr> new_body(sch->stages.size()); std::vector<Expr> new_body(sch->stages.size());
// inline all the ops // inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
...@@ -241,12 +264,13 @@ void InjectInline(const Schedule& sch) { ...@@ -241,12 +264,13 @@ void InjectInline(const Schedule& sch) {
void Schedule::normalize() { void Schedule::normalize() {
RebaseNonZeroMinLoop(*this); RebaseNonZeroMinLoop(*this);
SetScanAttach(*this); SetScanAttach(*this);
InjectInline(*this); InjectInline(operator->());
} }
// Handle reduction factor. // Handle reduction factor.
Tensor Schedule::rfactor(const Tensor& tensor, Tensor Schedule::rfactor(const Tensor& tensor,
const IterVar& axis) { const IterVar& axis) {
(*this)->InvalidateCache();
using ir::Reduce; using ir::Reduce;
CHECK_EQ(axis->iter_type, kCommReduce) CHECK_EQ(axis->iter_type, kCommReduce)
<< "Can only factor reduction axis"; << "Can only factor reduction axis";
......
...@@ -36,39 +36,28 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) ...@@ -36,39 +36,28 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
return 0; return 0;
} }
void CheckSplit(StageNode* self, IterVar parent, IterVar outer) { void Split(StageNode* self,
IterVar parent,
Expr factor,
Expr nparts,
IterVar* p_outer,
IterVar* p_inner) {
// Check if split is valid. // Check if split is valid.
if (self->attach_type == kScanUpdate) {
CHECK(!parent.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
}
if (outer.defined()) {
if (outer->iter_type == kThreadIndex) {
CHECK_EQ(parent->iter_type, kDataPar)
<< "Split by by kThreadIndex requires kDataPar IterVar "
<< " given " << IterVarType2String(parent->iter_type);
} else if (outer->iter_type == kCommReduce) {
CHECK_EQ(parent->iter_type, kCommReduce)
<< "Split by by kCommReduce requires kCommReduce IterVar "
<< " given " << IterVarType2String(parent->iter_type);
} else {
LOG(FATAL) << "Cannot take " << IterVarType2String(parent->iter_type)
<< " as outer IterVar";
}
} else {
CHECK(parent->iter_type == kDataPar || CHECK(parent->iter_type == kDataPar ||
parent->iter_type == kCommReduce || parent->iter_type == kCommReduce ||
parent->iter_type == kOrdered) parent->iter_type == kOrdered)
<< "Cannot split on " << IterVarType2String(parent->iter_type); << "Cannot split on " << IterVarType2String(parent->iter_type);
} IterVar outer = IterVarNode::make(
} Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
IterVar inner = IterVarNode::make(
void Split(StageNode* self, IterVar parent, Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
IterVar outer, IterVar inner, Expr factor) { *p_outer = outer;
*p_inner = inner;
// The splits
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent); size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
self->relations.push_back(SplitNode::make(parent, outer, inner, factor)); self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts));
// add vars to all vars // add vars to all vars
all_vars->data.push_back(outer.node_); all_vars->data.push_back(outer.node_);
all_vars->data.push_back(inner.node_); all_vars->data.push_back(inner.node_);
...@@ -98,6 +87,21 @@ Stage::Stage(Operation op) { ...@@ -98,6 +87,21 @@ Stage::Stage(Operation op) {
node_ = n; node_ = n;
} }
bool Stage::is_scheduled() const {
const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kGroupRoot &&
n->all_iter_vars.same_as(n->leaf_iter_vars));
}
Stage Stage::GetAttachSpec() const {
Stage attach_spec = *this;
while (attach_spec->attach_type == kGroupRoot &&
attach_spec->group.defined()) {
attach_spec = attach_spec->group;
}
return attach_spec;
}
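GetAttachSpec is the de-sugaring step the rest of the codebase now relies on: a stage sitting at its group root delegates the attach decision to the enclosing group. In Python terms (hypothetical field names):

def get_attach_spec(stage):
    spec = stage
    while spec.attach_type == "group_root" and spec.group is not None:
        spec = spec.group
    return spec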
Stage& Stage::set_scope(std::string scope) { // NOLINT(*) Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
(*this)->scope = scope; (*this)->scope = scope;
return *this; return *this;
...@@ -106,6 +110,17 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*) ...@@ -106,6 +110,17 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
CHECK_NE((*this)->attach_type, kScanUpdate) CHECK_NE((*this)->attach_type, kScanUpdate)
<< "Cannot specify compute_at for scan updates"; << "Cannot specify compute_at for scan updates";
// Group constraint checking.
Stage group = (*this)->group;
if (group.defined()) {
Stage pg = parent->group;
while (pg.defined() && !pg.same_as(group)) {
pg = pg->group;
}
CHECK(pg.same_as(group))
<< "Can only assign compute_at to stages within the same group";
}
(*this)->attach_type = kScope; (*this)->attach_type = kScope;
(*this)->attach_ivar = scope; (*this)->attach_ivar = scope;
(*this)->attach_stage = parent; (*this)->attach_stage = parent;
...@@ -117,7 +132,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) ...@@ -117,7 +132,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
} }
CHECK(found) CHECK(found)
<< "Cannot find the axis " << scope << "Cannot find the axis " << scope
<< " in parent's leaf_iter_vars or outermost_threads:" << " in parent's leaf_iter_vars"
<< " parent=" << parent; << " parent=" << parent;
return *this; return *this;
} }
...@@ -132,61 +147,73 @@ Stage& Stage::compute_inline() { // NOLINT(*) ...@@ -132,61 +147,73 @@ Stage& Stage::compute_inline() { // NOLINT(*)
Stage& Stage::compute_root() { // NOLINT(*) Stage& Stage::compute_root() { // NOLINT(*)
CHECK_NE((*this)->attach_type, kScanUpdate) CHECK_NE((*this)->attach_type, kScanUpdate)
<< "Cannot specify compute_at for scan updates"; << "Cannot specify compute_at for scan updates";
(*this)->attach_type = kRoot; (*this)->attach_type = kGroupRoot;
return *this; return *this;
} }
Stage& Stage::rebase(IterVar parent, IterVar rebased) { // NOLINT(*) Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
CHECK(parent->iter_type == kDataPar || StageNode* self = operator->();
parent->iter_type == kCommReduce) CHECK(ivar->iter_type == kDataPar ||
<< "Cannot rebase " << IterVarType2String(parent->iter_type); ivar->iter_type == kCommReduce)
CHECK(rebased->iter_type == kThreadIndex) << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
<< "Cannot rebase by " << IterVarType2String(rebased->iter_type) CHECK(thread_ivar->iter_type == kThreadIndex)
<< "Cannot rebase by " << IterVarType2String(ivar->iter_type)
<< ", only thread axis is allowed so far"; << ", only thread axis is allowed so far";
ArrayNode* all_vars = (*this)->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = (*this)->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent); FindLeafVar(all_vars, leaf_vars, ivar);
(*this)->relations.push_back(RebaseNode::make(parent, rebased));
// add vars to all vars auto it = self->iter_var_attrs.find(ivar);
all_vars->data.push_back(rebased.node_); std::shared_ptr<IterVarAttrNode> n;
// replace the position. if (it != self->iter_var_attrs.end()) {
leaf_vars->data.erase(leaf_vars->data.begin() + pos); n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
leaf_vars->data.insert(leaf_vars->data.begin() + pos, rebased.node_); if (n->bind_thread.defined() &&
!n->bind_thread.same_as(thread_ivar)) {
LOG(WARNING) << "Axis " << ivar
<< " is already bind to another thread " << n->bind_thread;
}
} else {
n = std::make_shared<IterVarAttrNode>();
}
n->bind_thread = thread_ivar;
self->iter_var_attrs.Set(ivar, IterVarAttr(n));
return *this; return *this;
} }
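Because the attr record is copied on write, rebinding an axis is permitted but flagged; a sketch of the user-visible behavior (the second bind should trigger the LOG(WARNING) above and then overwrite the binding):

import tvm

n = tvm.Var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
s = tvm.Schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.y"))  # warns, then rebinds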
Stage& Stage::split( Stage& Stage::env_threads(Array<IterVar> threads) {
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*) StageNode* self = operator->();
CheckSplit(operator->(), parent, IterVar()); CHECK(self->op.defined() && self->op.as<ScanOpNode>())
IterVar outer = IterVarNode::make( << "env_threads is only valid for composite ops such as ScanOp";
Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); CHECK_EQ(self->env_threads.size(), 0U)
IterVar inner = IterVarNode::make( << "Already set env_threads";
Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
*p_outer = outer; ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
*p_inner = inner; std::vector<std::shared_ptr<Node> > temp;
Split(operator->(), parent, outer, inner, factor); for (IterVar iv : threads) {
temp.push_back(iv.node_);
}
leaf_vars->data.insert(
leaf_vars->data.begin(), temp.begin(), temp.end());
all_vars->data.insert(
all_vars->data.end(), temp.begin(), temp.end());
self->env_threads = threads;
return *this; return *this;
} }
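Assuming a Python binding named env_threads that mirrors the C++ method above (the binding name is an assumption, not shown in this diff), usage on a scan would look roughly like this; env_threads prepends the thread IterVars to the scan stage's leaf vars.

import tvm

m = tvm.Var("m")
n = tvm.Var("n")
X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, inputs=X)
s = tvm.Schedule(res.op)
block_x = tvm.thread_axis("blockIdx.x")
s[res.op].env_threads([block_x])   # hypothetical binding name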
Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*) Stage& Stage::split(
CheckSplit(operator->(), parent, outer); IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
std::string name_inner = parent->var->name_hint + ".inner"; Split(operator->(), parent, factor, Expr(), p_outer, p_inner);
IterVar inner = IterVarNode::make( return *this;
Range(), Var(name_inner, parent->var.type()), parent->iter_type); }
*p_inner = inner;
Split(operator->(), parent, outer, inner, factor);
Stage& Stage::split_by_nparts(
IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
Split(operator->(), parent, Expr(), nparts, p_outer, p_inner);
return *this; return *this;
} }
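The two entry points differ only in which side of the split is pinned: factor fixes the inner extent, nparts the outer. Usage per the updated tests:

import tvm

n = tvm.Var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] * 2, name="B")
s = tvm.Schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=8)   # xi iterates 8 times
yo, yi = s[B].split(xi, nparts=4)             # yo iterates 4 times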
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*) Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
StageNode* self = operator->(); StageNode* self = operator->();
if (self->attach_type == kScanUpdate) {
CHECK(!inner.same_as(self->all_iter_vars[0]))
<< "Cannot fuse on axis[0] of scan update";
CHECK(!outer.same_as(self->all_iter_vars[0]))
<< "Cannot fuse on axis[0] of scan update";
}
CHECK(outer->iter_type == kDataPar || CHECK(outer->iter_type == kDataPar ||
outer->iter_type == kCommReduce || outer->iter_type == kCommReduce ||
outer->iter_type == kOrdered) outer->iter_type == kOrdered)
...@@ -236,10 +263,6 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) ...@@ -236,10 +263,6 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
std::vector<size_t> pos; std::vector<size_t> pos;
for (size_t i = 0; i < order.size(); ++i) { for (size_t i = 0; i < order.size(); ++i) {
if ((*this)->attach_type == kScanUpdate) {
CHECK(!order[i].same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update";
}
pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i])); pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
} }
std::vector<std::shared_ptr<Node> > temp; std::vector<std::shared_ptr<Node> > temp;
...@@ -254,66 +277,48 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) ...@@ -254,66 +277,48 @@ Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
} }
Stage& Stage::tile(IterVar x_parent, IterVar y_parent, Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner, IterVar* p_x_inner, IterVar* p_y_inner) {
Expr x_factor, Expr y_factor) { // NOLINT(*) split(x_parent, x_factor, p_x_outer, p_x_inner);
split(x_parent, p_x_outer, p_x_inner, x_factor); split(y_parent, y_factor, p_y_outer, p_y_inner);
split(y_parent, p_y_outer, p_y_inner, y_factor);
reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
return *this; return *this;
} }
Stage& Stage::outermost_threads(Array<IterVar> threads) { inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
StageNode* self = operator->();
CHECK(self->op.as<ScanOpNode>())
<< "outermost_threads is only valid for composite ops such as ScanOp";
CHECK_EQ(self->outermost_threads.size(), 0U)
<< "Already set outermost_threads";
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
std::vector<std::shared_ptr<Node> > temp;
for (IterVar iv : threads) {
temp.push_back(iv.node_);
}
leaf_vars->data.insert(
leaf_vars->data.begin(), temp.begin(), temp.end());
all_vars->data.insert(
all_vars->data.end(), temp.begin(), temp.end());
(*this)->outermost_threads = threads;
return *this;
}
inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var); FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var); auto it = self->iter_var_attrs.find(var);
std::shared_ptr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) { if (it != self->iter_var_attrs.end()) {
CHECK_EQ((*it).second->iter_type, attr->iter_type) n = std::make_shared<IterVarAttrNode>(*(*it).second.operator->());
<< "IterVar's is already set to "
<< (*it).second << " instead of " << attr;
} else { } else {
self->iter_var_attrs.Set(var, attr); n = std::make_shared<IterVarAttrNode>();
} }
n->iter_type = iter_type;
self->iter_var_attrs.Set(var, IterVarAttr(n));
} }
Stage& Stage::vectorize(IterVar var) { // NOLINT(*) Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kVectorized)); SetAttrIterType(operator->(), var, kVectorized);
return *this; return *this;
} }
Stage& Stage::unroll(IterVar var) { // NOLINT(*) Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kUnrolled)); SetAttrIterType(operator->(), var, kUnrolled);
return *this; return *this;
} }
Stage& Stage::parallel(IterVar var) { // NOLINT(*) Stage& Stage::parallel(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kParallelized)); SetAttrIterType(operator->(), var, kParallelized);
return *this; return *this;
} }
Schedule::Schedule(Array<Operation> ops) { Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>(); auto n = std::make_shared<ScheduleNode>();
node_ = n;
n->outputs = ops; n->outputs = ops;
auto g = schedule::CreateReadGraph(n->outputs); auto g = schedule::CreateReadGraph(n->outputs);
Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g); Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
...@@ -330,14 +335,24 @@ Schedule::Schedule(Array<Operation> ops) { ...@@ -330,14 +335,24 @@ Schedule::Schedule(Array<Operation> ops) {
// mark scan updates. // mark scan updates.
if (op.as<ScanOpNode>()) { if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>(); const ScanOpNode* scan = op.as<ScanOpNode>();
Array<Tensor> inputs;
for (Tensor t : scan->state_placeholder) {
inputs.push_back(t);
}
for (Tensor t : scan->inputs) {
inputs.push_back(t);
}
// Create the scan group.
Stage scan_group = create_group(scan->update, inputs, false);
scan_group->attach_type = kScanUpdate;
scan_group->attach_stage = stage;
for (size_t i = 0; i < scan->update.size(); ++i) { for (size_t i = 0; i < scan->update.size(); ++i) {
Stage s = n->stage_map[scan->update[i]->op]; Stage s = n->stage_map[scan->update[i]->op];
s->attach_type = kScanUpdate; CHECK(scan_group.same_as(s->group));
s->attach_stage = stage;
} }
} }
} }
node_ = std::move(n);
} }
Stage Schedule::operator[](const Operation& op) { Stage Schedule::operator[](const Operation& op) {
...@@ -348,14 +363,174 @@ Stage Schedule::operator[](const Operation& op) { ...@@ -348,14 +363,174 @@ Stage Schedule::operator[](const Operation& op) {
return (*it).second; return (*it).second;
} }
IterVarRelation SplitNode::make( Stage LeastCommonAncestor(Stage g1, Stage g2) {
IterVar parent, IterVar outer, if (!g1.defined()) return g1;
IterVar inner, Expr factor) { if (!g2.defined()) return g2;
if (g1.same_as(g2)) return g1;
Stage g = g1;
while (g.defined()) {
if (g.same_as(g2)) return g2;
g = g->group;
}
g = g2;
while (g.defined()) {
if (g.same_as(g1)) return g1;
g = g->group;
}
return g;
}
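Note this is an ancestor check rather than a general least-common-ancestor: it returns a group only when one argument is an ancestor of the other, and otherwise falls back to the undefined root. A Python sketch, with None playing the role of the undefined (root-level) group:

def least_common_ancestor(g1, g2):
    if g1 is None or g2 is None:
        return None
    if g1 is g2:
        return g1
    g = g1
    while g is not None:          # is g2 an ancestor of g1?
        if g is g2:
            return g2
        g = g.group
    g = g2
    while g is not None:          # is g1 an ancestor of g2?
        if g is g1:
            return g1
        g = g.group
    return None                   # unrelated: fall back to the root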
Array<Tensor> RemapTensor(ScheduleNode* self,
const Array<Tensor>& arr) {
self->InitCache();
const auto& op2stage_cache = self->op2stage_cache_;
Array<Tensor> ret;
for (Tensor t : arr) {
if (!op2stage_cache.count(t->op.get())) {
CHECK(self->stage_map.count(t->op))
<< "Given tensor is not in the schedule plan";
t = self->stage_map[t->op]->op.output(t->value_index);
}
ret.push_back(t);
}
return ret;
}
// Group the schedule stages.
Stage Schedule::create_group(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs) {
ScheduleNode* self = operator->();
self->InitCache();
const auto& op2stage_cache = self->op2stage_cache_;
// Get the ops.
Array<Operation> ops = schedule::GetSubGraph(
RemapTensor(self, outputs),
RemapTensor(self, inputs),
include_inputs);
// local counter entry
// Automatically initialize to 0 during creation.
struct Entry {
int count{0};
};
// Map of group->touched counter
std::unordered_map<Stage, Entry, NodeHash, NodeEqual> counter;
// The parent group.
Stage parent_group;
// Detect common parent and child.
for (size_t i = 0; i < ops.size(); ++i) {
Operation op = ops[i];
auto it = op2stage_cache.find(op.get());
CHECK(it != op2stage_cache.end());
Stage op_group = it->second->group;
if (i == 0) {
parent_group = op_group;
} else {
parent_group = LeastCommonAncestor(parent_group, op_group);
}
if (op_group.defined()) {
++counter[op_group].count;
}
}
// Create the new group stage.
Stage gstage(std::make_shared<StageNode>());
gstage->group = parent_group;
if (parent_group.defined()) {
++parent_group->num_child_stages;
}
// Propagate the counter statistics upward: whenever a subgroup
// is fully covered, continue propagating to its parent group.
std::vector<Stage> stack;
for (auto &kv : counter) {
if (!kv.first.same_as(parent_group)) {
if (kv.first->num_child_stages == kv.second.count) {
stack.push_back(kv.first);
}
}
}
while (!stack.empty()) {
Stage g = stack.back();
stack.pop_back();
if (g->group.defined() && !g->group.same_as(parent_group)) {
Entry& e = counter[g->group];
++e.count;
if (e.count == g->group->num_child_stages) {
stack.push_back(g->group);
}
}
}
// Verify the counters and remap the subgroups.
for (auto &kv : counter) {
if (kv.first.same_as(parent_group)) continue;
CHECK_EQ(kv.first->num_child_stages, kv.second.count)
<< "Trying to group region that intersect with an already existed group";
if (kv.first->group.same_as(parent_group)) {
Stage s = kv.first;
s->group = gstage;
++gstage->num_child_stages;
if (parent_group.defined()) {
--parent_group->num_child_stages;
}
}
}
// Remap the group of op stages.
for (Operation op : ops) {
auto it = op2stage_cache.find(op.get());
CHECK(it != op2stage_cache.end());
Stage s = it->second;
if (s->group.same_as(parent_group)) {
s->group = gstage;
++gstage->num_child_stages;
if (parent_group.defined()) {
--parent_group->num_child_stages;
}
}
}
// Correct the attach points to keep everything inside the group.
for (Operation op : ops) {
auto it = op2stage_cache.find(op.get());
CHECK(it != op2stage_cache.end());
Stage s = it->second;
if (s->attach_type == kScope) {
Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
if (!cg.same_as(gstage)) {
LOG(WARNING) << "group invalidates some previous compute_at relation "
<< " and keeps things to be computed inside the group";
s.compute_root();
}
}
}
self->groups.push_back(gstage);
return gstage;
}
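End-to-end usage, mirroring the new test_compute_group below:

import tvm

m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
# Group x and x1 together, then attach the whole group under x2's loop.
g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
g.compute_at(s[x2], x2.op.axis[1])
assert g.num_child_stages == 2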
void ScheduleNode::InvalidateCache() {
op2stage_cache_.clear();
}
void ScheduleNode::InitCache() {
if (op2stage_cache_.size() == stages.size()) return;
InvalidateCache();
for (Stage s : stages) {
if (s->op.defined()) {
op2stage_cache_[s->op.get()] = s;
}
}
CHECK_EQ(op2stage_cache_.size(), stages.size());
}
IterVarRelation SplitNode::make(IterVar parent,
IterVar outer,
IterVar inner,
Expr factor,
Expr nparts) {
auto n = std::make_shared<SplitNode>(); auto n = std::make_shared<SplitNode>();
n->parent = parent; n->parent = parent;
n->outer = outer; n->outer = outer;
n->inner = inner; n->inner = inner;
n->factor = factor; n->factor = factor;
n->nparts = nparts;
return IterVarRelation(n); return IterVarRelation(n);
} }
...@@ -375,12 +550,6 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { ...@@ -375,12 +550,6 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
return IterVarRelation(n); return IterVarRelation(n);
} }
IterVarAttr::IterVarAttr(IterVarType t) {
std::shared_ptr<IterVarAttrNode> n = std::make_shared<IterVarAttrNode>();
n->iter_type = t;
node_ = n;
}
TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode); TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode); TVM_REGISTER_NODE_TYPE(SplitNode);
...@@ -391,7 +560,11 @@ TVM_REGISTER_NODE_TYPE(ScheduleNode); ...@@ -391,7 +560,11 @@ TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer // Printer
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) { .set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
if (op->op.defined()) {
p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
} else {
p->stream << "group-stage(" << op << ")";
}
}) })
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) { .set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
p->stream << IterVarType2String(op->iter_type); p->stream << IterVarType2String(op->iter_type);
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "./graph.h" #include "./graph.h"
#include "../op/op_util.h"
#include "../pass/ir_util.h"
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
...@@ -44,8 +46,9 @@ Stmt MakePipeline(const Stage& s, ...@@ -44,8 +46,9 @@ Stmt MakePipeline(const Stage& s,
class InjectAttach : public IRMutator { class InjectAttach : public IRMutator {
public: public:
InjectAttach(const Stage& stage, InjectAttach(const Stage& stage,
const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map) const std::unordered_map<IterVar, Range>& dom_map)
: stage_(stage), dom_map_(dom_map) {} : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined()); CHECK(stmt.defined());
...@@ -53,10 +56,11 @@ class InjectAttach : public IRMutator { ...@@ -53,10 +56,11 @@ class InjectAttach : public IRMutator {
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr && if (op != nullptr &&
op->type_key == attr::loop_scope) { op->type_key == attr::loop_scope) {
CHECK_NE(producer_.size(), 0U); if (attach_spec_->attach_type == kScope &&
if (op->node == stage_->attach_ivar && op->node == attach_spec_->attach_ivar) {
producer_.back() == stage_->attach_stage->op.get()) { CHECK(!found_attach)
CHECK(!found_attach); << "Found IterVar " << attach_spec_->attach_ivar
<< " in multiple places in the IR";
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->type_key, op->value, op->node, op->type_key, op->value,
...@@ -65,26 +69,16 @@ class InjectAttach : public IRMutator { ...@@ -65,26 +69,16 @@ class InjectAttach : public IRMutator {
} }
return stmt; return stmt;
} }
Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
if (op->is_producer) {
producer_.push_back(op->func.get());
Stmt ret = IRMutator::Mutate_(op, s);
producer_.pop_back();
return ret;
} else {
return IRMutator::Mutate_(op, s);
}
}
// whether attach point is found // whether attach point is found
bool found_attach{false}; bool found_attach{false};
private: private:
// the operations to be carried // The stage.
const Stage& stage_; const Stage& stage_;
// The attach spec, may not contain op.
const Stage& attach_spec_;
// domain map // domain map
const std::unordered_map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// internal stack about realization scope.
std::vector<const Node*> producer_;
}; };
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
...@@ -128,7 +122,6 @@ class InjectScanStep : public IRMutator { ...@@ -128,7 +122,6 @@ class InjectScanStep : public IRMutator {
bool is_init_; bool is_init_;
}; };
// Postprocessing of schedule op // Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer. // Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public IRMutator { class SchedulePostProc : public IRMutator {
...@@ -157,9 +150,8 @@ class SchedulePostProc : public IRMutator { ...@@ -157,9 +150,8 @@ class SchedulePostProc : public IRMutator {
} }
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::loop_scope) { if (op->type_key == attr::loop_scope ||
return this->Mutate(op->body); op->type_key == attr::scan_init_scope) {
} else if (op->type_key == attr::scan_init_scope) {
return this->Mutate(op->body); return this->Mutate(op->body);
} else if (op->type_key == attr::scan_update_scope) { } else if (op->type_key == attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>(); const ScanOpNode* scan = op->node.as<ScanOpNode>();
...@@ -237,6 +229,15 @@ class SchedulePostProc : public IRMutator { ...@@ -237,6 +229,15 @@ class SchedulePostProc : public IRMutator {
void Init(const Schedule& sch) { void Init(const Schedule& sch) {
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
for (auto kv : s->iter_var_attrs) {
// Update bind thread information.
if (kv.second->bind_thread.defined()) {
const Var& from = kv.first->var;
const Var& to = kv.second->bind_thread->var;
CHECK(!var_value_.count(from.get()));
var_value_[from.get()] = to;
}
}
// This must be checked for all ops, including scan. // This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) { if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0); Tensor target = s->origin_op.output(0);
...@@ -279,61 +280,67 @@ class SchedulePostProc : public IRMutator { ...@@ -279,61 +280,67 @@ class SchedulePostProc : public IRMutator {
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) { Schedule sch, Map<IterVar, Range> dom_map_) {
Stmt body = Stmt(); Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
}
// scan init and scan updates // scan init and scan updates
std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach; std::unordered_map<Operation, Operation> scan_init;
for (Stage s : sch->stages) { for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>(); const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue; if (!scan) continue;
for (Tensor t : scan->init) { for (Tensor t : scan->init) {
if (scan_attach.count(t->op)) { if (scan_init.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op)) CHECK(scan_init.at(t->op).same_as(s->op))
<< "Scan init tensor can only belong to one scan"; << "Scan init tensor can only belong to one scan";
} else { } else {
scan_attach[t->op] = std::make_pair(s->op, true); scan_init[t->op] = s->op;
}
}
for (Tensor t : scan->update) {
if (scan_attach.count(t->op)) {
CHECK(scan_attach.at(t->op).first.same_as(s->op))
<< "Scan update tensor can only belong to one scan";
} else {
scan_attach[t->op] = std::make_pair(s->op, false);
} }
} }
} }
std::unordered_map<IterVar, Range> dom_map; // verify correctness of group.
for (auto kv : dom_map_) { for (Stage g : sch->groups) {
dom_map[kv.first] = kv.second; CHECK(!g->op.defined());
CHECK_EQ(g->leaf_iter_vars.size(), 0U);
} }
// reverse the post DFS order. // reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1]; Stage s = sch->stages[i - 1];
CHECK_NE(s->attach_type, kInline) CHECK_NE(s->attach_type, kInline)
<< "call schedule.normalize before scheduleops"; << "call schedule.normalize before scheduleops";
CHECK(s->op.defined());
// no need to specify place holder op. // no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue; if (s->op.as<PlaceholderOpNode>()) continue;
if (scan_attach.count(s->op)) { // Remove grouping sugar, get the real attach spec.
CHECK(s->attach_type == kNone || Stage attach_spec = s.GetAttachSpec();
s->attach_type == kScanUpdate)
<< "Cannot specify compute_at for scan's init/update"; if (scan_init.count(s->op)) {
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined()); CHECK(body.defined());
const auto& p = scan_attach.at(s->op); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false);
InjectScanStep mu(s, p.first, dom_map, p.second);
body = mu.Mutate(body); body = mu.Mutate(body);
CHECK(mu.found_attach) CHECK(mu.found_attach)
<< "did not find attachment point for scan.init/update"; << "did not find attachment point for scan.update";
} else if (s->attach_type == kInlinedAlready) { } else if (attach_spec->attach_type == kInlinedAlready) {
// do nothing // do nothing
} else if (s->attach_type == kRoot || s-> attach_type == kNone) { } else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body); body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) { } else {
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined()); CHECK(body.defined());
InjectAttach mutator(s, dom_map); InjectAttach mutator(s, attach_spec, dom_map);
body = mutator.Mutate(body); body = mutator.Mutate(body);
CHECK(mutator.found_attach) CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in" << "did not find attachment point for " << s << " in "
<< s->attach_stage->op << " x " << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
<< ", body:\n"
<< body; << body;
} }
} }
......
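Putting the dispatch together, the rewritten loop reads roughly as follows; all helper names are illustrative sketches of the C++ above, not actual API.

def schedule_ops(stages, scan_init):
    body = None
    for s in reversed(stages):              # reverse post-DFS order
        if s.is_placeholder:
            continue
        spec = get_attach_spec(s)           # strip group-root sugar
        if s.op in scan_init:
            body = inject_scan_step(s, scan_init[s.op], body, init=True)
        elif spec.attach_type == "scan_update":
            body = inject_scan_step(s, spec.attach_stage.op, body, init=False)
        elif spec.attach_type == "inlined_already":
            pass                            # nothing to emit
        elif spec.attach_type == "group_root":
            body = make_pipeline(s, body)   # realize at the root
        else:                               # "scope": a compute_at target
            body = inject_attach(s, spec, body)
    return body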
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
namespace { namespace {
......
...@@ -11,11 +11,11 @@ def test_add(): ...@@ -11,11 +11,11 @@ def test_add():
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
# create iter var and assign them tags. # create iter var and assign them tags.
num_thread = 256 num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x") bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x)
_, x = s[C].split(x, outer=thread_x)
_, x = s[C].split(x, factor=4) _, x = s[C].split(x, factor=4)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].vectorize(x) s[C].vectorize(x)
# one line to build the function. # one line to build the function.
......
...@@ -22,31 +22,41 @@ def test_gemm(): ...@@ -22,31 +22,41 @@ def test_gemm():
scale = 8 scale = 8
num_thread = 8 num_thread = 8
block_factor = scale * num_thread block_factor = scale * num_thread
block_x = tvm.thread_axis(None, "blockIdx.x") block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_x = tvm.thread_axis("threadIdx.x")
block_y = tvm.thread_axis(None, "blockIdx.y") block_y = tvm.thread_axis("blockIdx.y")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") thread_y = tvm.thread_axis("threadIdx.y")
CC = s.cache_write(C, "local") CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC]) AA = s.cache_read(A, "shared", [CC])
BB = s.cache_read(B, "shared", [CC]) BB = s.cache_read(B, "shared", [CC])
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y) by, yi = s[C].split(C.op.axis[0], factor=block_factor)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x) bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
s[C].reorder(block_y, block_x, yi, xi) s[C].reorder(by, bx, yi, xi)
_, yi = s[C].split(yi, outer=thread_y) s[C].bind(by, block_y)
_, xi = s[C].split(xi, outer=thread_x) s[C].bind(bx, block_x)
s[C].reorder(thread_y, thread_x, yi, xi) ty, yi = s[C].split(yi, nparts=num_thread)
tx, xi = s[C].split(xi, nparts=num_thread)
s[C].reorder(ty, tx, yi, xi)
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
yo, xo = CC.op.axis yo, xo = CC.op.axis
s[CC].reorder(k, yo, xo) s[CC].reorder(k, yo, xo)
s[CC].compute_at(s[C], thread_x)
s[CC].compute_at(s[C], tx)
s[AA].compute_at(s[CC], k) s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k) s[BB].compute_at(s[CC], k)
_, xi = s[AA].split(s[AA].op.axis[0], outer=thread_y) ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
_, xi = s[AA].split(xi, outer=thread_x) tx, xi = s[AA].split(xi, nparts=num_thread)
_, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y) s[AA].bind(ty, thread_y)
_, xi = s[BB].split(xi, outer=thread_x) s[AA].bind(tx, thread_x)
ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
tx, xi = s[BB].split(xi, nparts=num_thread)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
max_auto_unroll_step = 0 max_auto_unroll_step = 0
# lowering test # lowering test
...@@ -76,9 +86,9 @@ def test_gemm(): ...@@ -76,9 +86,9 @@ def test_gemm():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda")
if tvm.module.enabled("opencl"): if tvm.module.enabled("opencl"):
tvm.module.init_opencl() tvm.module.init_opencl()
check_device("cuda")
check_device("opencl") check_device("opencl")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,10 +12,9 @@ def test_sum(): ...@@ -12,10 +12,9 @@ def test_sum():
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
# create iter var and assign them tags. # create iter var and assign them tags.
num_thread = 1 num_thread = 1
block_x = tvm.thread_axis(None, "blockIdx.x") xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x) s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
_, x = s[B].split(x, outer=thread_x)
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
...@@ -52,10 +51,9 @@ def test_rfactor(): ...@@ -52,10 +51,9 @@ def test_rfactor():
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n)) k = tvm.reduce_axis((0, n))
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B') B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
kf = tvm.reduce_axis((0, 4))
# schedule # schedule
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
_, ki = s[B].split(k, outer=kf) kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf) BF = s.rfactor(B, kf)
s[BF].parallel(BF.op.axis[0]) s[BF].parallel(BF.op.axis[0])
# one line to build the function. # one line to build the function.
...@@ -88,16 +86,14 @@ def test_rfactor_threads(): ...@@ -88,16 +86,14 @@ def test_rfactor_threads():
k = tvm.reduce_axis((0, n)) k = tvm.reduce_axis((0, n))
nthread = 16 nthread = 16
B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B') B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
tx = tvm.thread_axis((0, nthread), "threadIdx.x")
ty = tvm.thread_axis((0, nthread), "threadIdx.y")
bx = tvm.thread_axis(None, "blockIdx.x")
# schedule # schedule
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
ko, kf = s[B].split(k, factor=nthread) ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf) BF = s.rfactor(B, kf)
xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx) bx, tx = s[B].split(s[B].op.axis[0], factor=nthread)
s[B].rebase(xi, ty) s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].rebase(s[B].op.reduce_axis[0], tx) s[B].bind(tx, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], tx) s[BF].compute_at(s[B], tx)
# one line to build the function. # one line to build the function.
...@@ -128,6 +124,6 @@ def test_rfactor_threads(): ...@@ -128,6 +124,6 @@ def test_rfactor_threads():
check_target("opencl") check_target("opencl")
if __name__ == "__main__": if __name__ == "__main__":
test_rfactor_threads()
test_rfactor() test_rfactor()
test_rfactor_threads()
test_sum() test_sum()
...@@ -15,10 +15,12 @@ def test_scan(): ...@@ -15,10 +15,12 @@ def test_scan():
num_thread = 256 num_thread = 256
block_x = tvm.thread_axis(None, "blockIdx.x") block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x) xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
_, x = s[s_init].split(x, outer=thread_x) s[s_init].bind(xo, block_x)
_, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x) s[s_init].bind(xi, thread_x)
_, x = s[s_update].split(x, outer=thread_x) xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
......
...@@ -11,10 +11,9 @@ def test_add_pipeline(): ...@@ -11,10 +11,9 @@ def test_add_pipeline():
# GPU schedule have to split by gridIdx and threadIdx # GPU schedule have to split by gridIdx and threadIdx
num_thread = 256 num_thread = 256
grid_x = tvm.thread_axis(None, "blockIdx.x") xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x) s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
_, x = s[C].split(x, outer=thread_x)
# compile to IR # compile to IR
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
......
"""Test group effect"""
import tvm
def test_scan_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i])
s_update1 = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
res = tvm.scan(s_init, s_update3, s_state, inputs=x)
s = tvm.Schedule(res.op)
assert s[s_update1].group is not None
assert s[s_update2].group == s[s_update1].group
# Assigning compute_at within the group is valid.
s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
    # create a new group for [s_update1, s_update2]
g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
assert g2.group is not None
assert g2.group == s[s_update3].group
assert s[s_update2].group == g2
assert s[s_update1].group == g2
g2.compute_at(s[s_update3], s_update3.op.axis[1])
assert g2.attach_stage == s[s_update3]
try:
        # compute_at outside the group raises an error.
s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
assert False
except tvm.TVMError:
pass
def test_compute_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert s[x1].group == g
assert s[x].group == g
g.compute_at(s[x2], x2.op.axis[1])
assert g.attach_stage == s[x2]
assert g.num_child_stages == 2
def test_nest_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g1 = s.create_group(outputs=x1, inputs=x)
g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert set(s.groups) == set([g1, g2])
assert s[x].group == g2
assert s[x1].group == g1
assert g1.group == g2
assert g2.num_child_stages == 2
assert g1.num_child_stages == 1
if __name__ == "__main__":
test_nest_group()
test_compute_group()
test_scan_group()
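A condensed sketch of the group workflow these tests exercise, under the same assumptions about the tvm API at this commit (tensor names are illustrative):

import tvm

m = tvm.Var("m")
A = tvm.placeholder((m,), name="A")
A1 = tvm.compute((m,), lambda i: A[i] + 1, name="A1")
A2 = tvm.compute((m,), lambda i: A1[i] * 2, name="A2")

s = tvm.Schedule(A2.op)
# group A1 (and, with include_inputs, its input A) into one attachable unit
g = s.create_group(outputs=A1, inputs=A, include_inputs=True)
# the whole group attaches under A2's outer loop in a single call
g.compute_at(s[A2], A2.op.axis[0])
assert s[A1].group == g
assert g.attach_stage == s[A2]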
@@ -29,8 +29,8 @@ def test_basic_pipeline():
         B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
     s = tvm.Schedule(B.op)
-    px = tvm.thread_axis((0, 1), "pipeline")
-    xo, xi = s[B].split(B.op.axis[0], outer=px)
+    xo, xi = s[B].split(B.op.axis[0], nparts=1)
+    s[B].bind(xo, tvm.thread_axis("pipeline"))
     xo, xi = s[B].split(xi, factor=4)
     for S in stages:
         s[S].compute_at(s[B], xo)
@@ -50,8 +50,8 @@ def test_conv1d():
         return A[i-1] + A[i] + A[i+1]
     B = tvm.compute(n, computeB, name='B')
     s = tvm.Schedule(B.op)
-    px = tvm.thread_axis((0, 1), "pipeline")
-    xo, xi = s[B].split(B.op.axis[0], outer=px)
+    px, xi = s[B].split(B.op.axis[0], nparts=1)
+    s[B].bind(px, tvm.thread_axis("pipeline"))
     s[A].compute_at(s[B], px)
     stmt = lower(s, [B])
     stmt = tvm.ir_pass.SplitPipeline(stmt, False)
     ...
@@ -9,15 +9,15 @@ def test_storage_sync():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     s = tvm.Schedule(A2.op)
-    block_x = tvm.thread_axis(None, "blockIdx.x")
-    xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x)
+    xo, xi = s[A2].split(A2.op.axis[0], factor=8)
+    s[A2].bind(xo, tvm.thread_axis("blockIdx.x"))
     s[A1].compute_at(s[A2], xo)
     s[A1].set_scope("shared")
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
     Ab = tvm.Buffer(A.shape, A.dtype, name='A')
     A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
     ...
@@ -7,8 +7,9 @@ def test_virtual_thread():
     A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')
     s = tvm.Schedule(A2.op)
-    vx = tvm.thread_axis((0, 2), "vthread", name="vx")
-    xo, xi = s[A2].split(A2.op.axis[0], outer=vx)
+    vx = tvm.thread_axis("vthread", name="vx")
+    xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
+    s[A2].bind(xo, vx)
     xo, xi = s[A2].split(xi, 8)
     s[A1].compute_at(s[A2], xo)
     ...
@@ -36,11 +36,10 @@ def test_bound3():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     s = tvm.Schedule(A2.op)
     s[A1].set_scope("shared")
-    thread_x = tvm.thread_axis((0, 16), "threadIdx.x")
     xo, xi = s[A2].split(A2.op.axis[0], 32)
-    xi0, xi1 = s[A2].split(xi, outer=thread_x)
+    xi0, xi1 = s[A2].split(xi, nparts=16)
+    s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
     yo, yi = s[A2].split(A2.op.axis[1], 16)
     s[A2].reorder(xo, xi0, yo, xi1, yi)
     s[A1].compute_at(s[A2], yo)
@@ -60,12 +59,10 @@ def test_bound_scan():
     s_scan = tvm.scan(s_init, s_update, s_state)
     assert tuple(s_scan.shape) == (m, n)
     s = tvm.Schedule(s_scan.op)
     XX = s.cache_read(X, "local", s_update)
     xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
     s[XX].compute_at(s[s_update], xo)
     s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -105,22 +102,59 @@ def test_bound_rfactor():
     A = tvm.placeholder((n,), name='A')
     k = tvm.reduce_axis((0, n))
     B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
-    kf = tvm.reduce_axis((0, 4))
     # schedule
     s = tvm.Schedule(B.op)
-    _, ki = s[B].split(k, outer=kf)
+    kf, ki = s[B].split(k, nparts=4)
     BF = s.rfactor(B, kf)
     s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BF.op.axis[0]].extent.value == 4)
     assert(bounds[BF.op.axis[1]].extent.value == 1)
def test_bound_group_schedule():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
g.compute_at(s[x2], x2.op.axis[0])
assert s[x1].group == g
assert s[x].group == g
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent == n
def test_bound_nest_group():
m = tvm.Var("m")
n = tvm.Var("n")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
s = tvm.Schedule(x2.op)
g1 = s.create_group(outputs=x, inputs=x, include_inputs=True)
g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
assert s[x].group == g1
assert s[x1].group == g2
g2.compute_at(s[x2], x2.op.axis[0])
g1.compute_at(s[x1], s[x1].op.axis[1])
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert bounds[x.op.axis[0]].extent.value == 1
assert bounds[x.op.axis[1]].extent.value == 1
assert bounds[x1.op.axis[0]].extent.value == 1
assert bounds[x1.op.axis[1]].extent == n
 if __name__ == "__main__":
+    test_bound_nest_group()
+    test_bound_group_schedule()
+    test_bound_scan()
+    test_bound3()
     test_bound_rfactor()
     test_bound_blur()
     test_bound_conv1d()
-    test_bound_scan()
-    test_bound3()
     test_bound1()
     test_bound2()
@@ -59,7 +59,7 @@ def test_scan_fix_point():
         s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
         s_scan = tvm.scan(s_init, s_update, s_state)
         body = tvm.schedule.ScanGetBody(s_scan.op)
-        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
         assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
         assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
@@ -69,8 +69,7 @@ def test_scan_fix_point():
         s_update = tvm.compute((l, m, n),
                                lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
         s_scan = tvm.scan(s_init, s_update, s_state)
-        body = tvm.schedule.ScanGetBody(s_scan.op)
-        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
         assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
         assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
@@ -89,7 +88,7 @@ def test_scan_fix_point():
                       [s1_update, s2_update],
                       [s1, s2])
         body = tvm.schedule.ScanGetBody(r0.op)
-        fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op, body)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op)
         assert(fxpt[r1.op.spatial_axis_[0]].value == 1)
     test_scan0()
     ...
@@ -35,8 +35,8 @@ def test_add_pipeline():
     C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
     s = tvm.Schedule(C.op)
-    grid_x = tvm.thread_axis((0, 1), "pipeline")
-    _, x = s[C].split(C.op.axis[0], outer=grid_x)
+    px, x = s[C].split(C.op.axis[0], nparts=1)
+    s[C].bind(px, tvm.thread_axis("pipeline"))
     fapi = lower(s, [A, B, C], "myadd")
     fsplits = tvm.ir_pass.SplitHostDevice(fapi)
     print(fsplits[1].body)
     ...
@@ -57,9 +57,9 @@ def test_buffer_linebuff():
         # correctness checks
         if (read_data_valid.get_int()):
             # Derive convolution window indices
-            baseIdx = read_idx/(kernel_width*kernel_width)
-            offsetIdx = read_idx%(kernel_width*kernel_width)
-            yOffset = offsetIdx/kernel_width
+            baseIdx = read_idx // (kernel_width*kernel_width)
+            offsetIdx = read_idx % (kernel_width*kernel_width)
+            yOffset = offsetIdx // kernel_width
             xOffset = offsetIdx%kernel_width
             pixIndex = baseIdx + yOffset * window_width + xOffset
             assert(read_data.get_int()==test_data[pixIndex])
     ...