Commit 3fb85796 by Tianqi Chen Committed by GitHub

[REFACTOR] Add Types to IterVar, Isolate Operator (#62)

* [IterVar/REFACTOR] Add types to IterVar

* [ARITH/REFACTOR] Move IntSet to include

* [REFACTOR/OP] Move Op detail to seperate folder.

* fix test
parent c8ebfbe3
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file int_set.h * \file arithmetic.h
* \brief Abstraction for all integer set operations. * \brief Algebra and set operations.
*/ */
#ifndef TVM_ARITHMETIC_INT_SET_H_ #ifndef TVM_ARITHMETIC_H_
#define TVM_ARITHMETIC_INT_SET_H_ #define TVM_ARITHMETIC_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <vector> #include <vector>
#include <unordered_map>
#include <memory>
#include "./expr.h"
namespace tvm { namespace tvm {
/*! \brief namespace of arithmetic */
namespace arith { namespace arith {
/*!
* \brief Sign of an expression or set.
*/
enum SignType { enum SignType {
kPositive, kPositive,
kNegative, kNegative,
...@@ -102,6 +106,41 @@ class IntSet : public NodeRef { ...@@ -102,6 +106,41 @@ class IntSet : public NodeRef {
}; };
/*! /*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief The base */
int base;
/*! \brief linear co-efficient */
int coeff;
/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.base = 0; e.coeff = 1;
return e;
}
/*!
* \brief Add two modular entries together to get a new modular entry.
* \param a The left operand.
* \param b The right operand.
* \return The combined modular entry.
*/
static ModularEntry Add(const ModularEntry& a,
const ModularEntry& b);
};
/*!
* \brief Base class of all IntSet containers. * \brief Base class of all IntSet containers.
*/ */
struct IntSetNode : public Node { struct IntSetNode : public Node {
...@@ -109,9 +148,6 @@ struct IntSetNode : public Node { ...@@ -109,9 +148,6 @@ struct IntSetNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
}; };
using ExprIntSetMap = std::unordered_map<Expr, IntSet,
Halide::ExprHash, Halide::ExprEqual>;
/*! /*!
* \brief Find an symbolic integer set that contains all possible values of * \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables. * e given the domain of each iteration variables.
...@@ -122,6 +158,13 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ...@@ -122,6 +158,13 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet,
*/ */
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map); const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const Variable*, IntSet>& dom_map);
...@@ -135,11 +178,18 @@ IntSet EvalSet(Expr e, ...@@ -135,11 +178,18 @@ IntSet EvalSet(Expr e,
*/ */
IntSet EvalSet(Range r, IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map); const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \param r The range to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r, IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
/*! /*!
* \brief Find the integer set of every sub-expression, given the * \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables. * domain of each iteration variables.
...@@ -148,7 +198,8 @@ IntSet EvalSet(Range r, ...@@ -148,7 +198,8 @@ IntSet EvalSet(Range r,
* \param dom_map The domain of each variable. * \param dom_map The domain of each variable.
* \return the map from the expression to its possible value. * \return the map from the expression to its possible value.
*/ */
ExprIntSetMap EvalSetForEachSubExpr(Expr r, ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! /*!
...@@ -165,11 +216,6 @@ IntSet Union(const Array<IntSet>& sets); ...@@ -165,11 +216,6 @@ IntSet Union(const Array<IntSet>& sets);
*/ */
IntSet Intersect(const Array<IntSet>& sets); IntSet Intersect(const Array<IntSet>& sets);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
/*! /*!
* \brief Deduce the bound of the target variable in a expression, * \brief Deduce the bound of the target variable in a expression,
* give the domain of each variables. Return undefined IntSet to * give the domain of each variables. Return undefined IntSet to
...@@ -178,18 +224,49 @@ inline const IntSetNode* IntSet::operator->() const { ...@@ -178,18 +224,49 @@ inline const IntSetNode* IntSet::operator->() const {
* \param v The target variable to be deduced. * \param v The target variable to be deduced.
* \param cond The conditional expression. * \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce. * \param hint_map The domain of variable, used to help deduce.
* \param relax The domain of each variable, used to relax the domain. * \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values. * \return An integer set that can cover all the possible values.
*/ */
IntSet DeduceBound(Expr v, Expr cond, IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map, const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map); const Map<Var, IntSet>& relax_map);
IntSet DeduceBound(Expr v, Expr e, /*!
const std::unordered_map<const Variable*, IntSet>& hint_map, * \brief Same as DeduceBound with unordered_map signature.
const std::unordered_map<const Variable*, IntSet>& relax_map); *
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);
/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
#endif // TVM_ARITHMETIC_H_
#endif // TVM_ARITHMETIC_INT_SET_H_
...@@ -22,6 +22,8 @@ using Halide::Bool; ...@@ -22,6 +22,8 @@ using Halide::Bool;
using Halide::Int; using Halide::Int;
using Halide::UInt; using Halide::UInt;
using Halide::Handle; using Halide::Handle;
using Halide::ExprHash;
using Halide::ExprEqual;
using Halide::Expr; using Halide::Expr;
using Halide::VarExpr; using Halide::VarExpr;
...@@ -57,7 +59,14 @@ class Var : public Halide::VarExpr { ...@@ -57,7 +59,14 @@ class Var : public Halide::VarExpr {
Type t = Int(32)) : VarExpr(name_hint, t) {} Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {} explicit Var(VarExpr v) : VarExpr(v) {}
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
}
/*! \brief type indicate the container type */ /*! \brief type indicate the container type */
using ContainerType = Variable; using ContainerType = Variable;
}; };
...@@ -91,6 +100,72 @@ class Range : public Halide::IR::Range { ...@@ -91,6 +100,72 @@ class Range : public Halide::IR::Range {
}; };
/*! /*!
* \brief Type of iteration variable.
* Each IterVar have a specific type.
*
* The type of iter var can be overriden via
* stage.iter_var_attrs given they are compatible.
*/
enum IterVarType : int {
/*!
* \brief Data parallel iteration.
* This normally corresponds to axis of Tensor.
* Allow all IterVar manipulations.
*
* \note This does not mean the loop
* have to be executed in parallel fashion.
*/
kDataPar = 0,
/*!
* \brief The IterVar itself is a thread-index
* of a fixed thread launching group.
* Note that this is already assumed to be paralellized.
*
* Disallow: split/fuse/vectorize/parallel
*/
kThreadIndex = 1,
/*!
* \brief Communicative reduction.
* Cannot be directly parallelized.
*
* Disallow: parallel/vectorize
*/
kCommReduce = 2,
/*!
* \brief Serial loops with loop carry dependency,
* the iteration must execute in order.
* Cannot be re-ordered.
*
* Disallow: reorder/parallel/vectorize
*/
kOrdered = 3,
/*!
* \brief IterVar is opaque,
*
* May not corresponds to any generated loop
* Disallow all IterVar manipulations and compute_at
*
* \note This is usually used to implement composite op
* or external op, where the
*/
kOpaque = 4,
// The following are possible additional
// types that are provided during schedule
/*!
* \brief The execution is unrolled.
*/
kUnrolled = 5,
/*!
* \brief The loop is vectorized.
*/
kVectorized = 6,
/*!
* \brief The loop is parallelized.
*/
kParallelized = 7
};
/*!
* \brief Iteration Variable, * \brief Iteration Variable,
* represents an iteration over an integer interval. * represents an iteration over an integer interval.
*/ */
...@@ -101,13 +176,6 @@ class IterVar : public NodeRef { ...@@ -101,13 +176,6 @@ class IterVar : public NodeRef {
// construct from shared ptr. // construct from shared ptr.
explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {} explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief construction of iteration variable.
* \param dom The iteration domain.
* \param var_name The name of iteration variable.
* \param thread_tag The additional tag to indicate whether the var is binded to fixed-thread.
*/
explicit IterVar(Range dom, std::string var_name = "i", std::string thread_tag = "");
/*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
*/ */
...@@ -120,6 +188,22 @@ class IterVar : public NodeRef { ...@@ -120,6 +188,22 @@ class IterVar : public NodeRef {
using ContainerType = IterVarNode; using ContainerType = IterVarNode;
}; };
/*!
* \brief Create a new IterVar that represents an axis in thread.
*
* \param dom Optional, domain of the thread axis.
* \param tag The thread tag of the axis.
*/
IterVar thread_axis(Range dom, std::string tag);
/*!
* \brief Create a new IterVar for reduction operations.
*
* \param dom The domain of the reduction axis.
* \param name The name of the reduction axis.
*/
IterVar reduce_axis(Range dom, std::string name = "rv");
using Domain = Array<Range>; using Domain = Array<Range>;
// functions // functions
...@@ -168,6 +252,8 @@ class IterVarNode : public Node { ...@@ -168,6 +252,8 @@ class IterVarNode : public Node {
Range dom; Range dom;
/*! \brief The looping variable */ /*! \brief The looping variable */
Var var; Var var;
/*! \brief The type of the IterVar */
IterVarType iter_type;
/*! /*!
* \brief additional tag on the iteration variable, * \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag. * set this if this is binded already to a known thread tag.
...@@ -177,10 +263,13 @@ class IterVarNode : public Node { ...@@ -177,10 +263,13 @@ class IterVarNode : public Node {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("dom", &dom); v->Visit("dom", &dom);
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("iter_type", &iter_type);
v->Visit("thread_tag", &thread_tag); v->Visit("thread_tag", &thread_tag);
} }
static IterVar make(Range dom, Var var, std::string thread_tag); static IterVar make(Range dom, Var var,
IterVarType iter_type,
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar"; static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
...@@ -195,6 +284,20 @@ inline IterVar::operator Expr() const { ...@@ -195,6 +284,20 @@ inline IterVar::operator Expr() const {
return (*this)->var; return (*this)->var;
} }
inline const char* IterVarType2String(IterVarType t) {
switch (t) {
case kDataPar: return "DataPar";
case kThreadIndex: return "ThreadIndex";
case kCommReduce: return "CommRedude";
case kOrdered: return "Ordered";
case kOpaque: return "Opaque";
case kUnrolled: return "Unrolled";
case kVectorized: return "Vectorized";
case kParallelized: return "Parallelized";
}
return "Unknown";
}
} // namespace tvm } // namespace tvm
namespace std { namespace std {
......
...@@ -32,15 +32,23 @@ struct Reduce : public ExprNode<Reduce> { ...@@ -32,15 +32,23 @@ struct Reduce : public ExprNode<Reduce> {
Expr source; Expr source;
/*! \brief The reduction axis */ /*! \brief The reduction axis */
Array<IterVar> axis; Array<IterVar> axis;
/*!
* \brief Predicate on the reduction
* Only add the body to reduction if condition is true.
*/
Expr condition;
/*! \brief construct expr from op and rdom */ /*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src, Array<IterVar> rdom); static Expr make(std::string op, Expr src,
Array<IterVar> rdom,
Expr condition = make_const(Bool(1), true));
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type); v->Visit("dtype", &type);
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("source", &source); v->Visit("source", &source);
v->Visit("axis", &axis); v->Visit("axis", &axis);
v->Visit("condition", &condition);
} }
static const IRNodeType _type_info = IRNodeType::ExtensionExpr; static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce"; static constexpr const char* _type_key = "Reduce";
...@@ -86,6 +94,12 @@ constexpr const char* storage_scope = "storage_scope"; ...@@ -86,6 +94,12 @@ constexpr const char* storage_scope = "storage_scope";
* \brief Mark storage scope of realizations * \brief Mark storage scope of realizations
*/ */
constexpr const char* realize_scope = "realize_scope"; constexpr const char* realize_scope = "realize_scope";
/*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
} // namespace attr } // namespace attr
/*! \brief namespace of TVM Intrinsic functions */ /*! \brief namespace of TVM Intrinsic functions */
......
...@@ -61,8 +61,8 @@ namespace ir { ...@@ -61,8 +61,8 @@ namespace ir {
* // These traps may not happen if we program carefully * // These traps may not happen if we program carefully
* // But it is recommended to use ExprFunctor, which allows direct * // But it is recommended to use ExprFunctor, which allows direct
* // return the value, this helps us to avoid such problems. * // return the value, this helps us to avoid such problems.
* \encode
* *
* \endcode
*/ */
class IRVisitor { class IRVisitor {
public: public:
......
...@@ -7,11 +7,136 @@ ...@@ -7,11 +7,136 @@
#define TVM_OPERATION_H_ #define TVM_OPERATION_H_
#include <string> #include <string>
#include <vector>
#include <unordered_map>
#include "./expr.h" #include "./expr.h"
#include "./tensor.h" #include "./tensor.h"
#include "./schedule.h"
#include "./arithmetic.h"
namespace tvm { namespace tvm {
using arith::IntSet;
/*!
* \brief Temporary data structure to store union
* of bounds of each axis of Tensor.
*/
struct TensorDom {
// constructor
explicit TensorDom(int ndim)
: data(ndim) {}
/*! \brief The domain data */
std::vector<std::vector<IntSet> > data;
};
/*!
* \brief The map beteen tensor and operation it feeds to.
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*! \brief The graph context used during bound inference. */
struct GraphContext {
/*! \brief The feed graph */
FeedGraph feed_graph;
};
/*!
* \brief Base class of all operation nodes
*/
class OperationNode : public FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*!
* \return The list of iteration variable at root
* \note root_iter_vars dedides the shape of the outputs.
*/
virtual Array<IterVar> root_iter_vars() const = 0;
/*!
* \brief Get data type. i-th output tensor.
* \param i The output index.
* \return type of i-th output.
*/
virtual Type output_dtype(size_t i) const = 0;
/*!
* \brief Get shape of i-th output tensor.
* \param i The output index.
* \return shape of i-th output.
*/
virtual Array<Expr> output_shape(size_t i) const = 0;
/*!
* \brief List all the input Tensors.
* \return List if input tensors.
*/
virtual Array<Tensor> InputTensors() const = 0;
/*!
* \brief Replace the input of the operation by pattern specified by rmap.
*
* \param self The reference to self.
* \param rmap The replacement map.
* \return self if nothing is replaced, otherwise return replaced op.
*/
virtual Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
/*!
* \brief Propagate the bounds to inputs
* \param self The reference to self.
* \param dom_map the domain map of Variables(corresponds to root_iter_vars)
* \param out_dom_map The output domain.
* The function is only asked to fill the bounds for Tensors that
* is already in the out_dom_map
*/
virtual void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
* \brief Gather the bound from output tensor.
* Set the range of each root_iter_vars in the op to out_dom_map
*
* \param self The reference to self.
* \param graph_ctx The global graph context information.
* \param tensor_dom Domain map of Tensor->access set of each dimension.
* \param out_dom_map The output domain map of each IterVar to be setted.
*/
virtual void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
/*!
* \brief Build the Realize statement that realizes
* the op's output tensors.
* \param self The reference to self.
* \param realize_map The realization domain map of the operators.
* \param body The body that is going to get
* \return A realization statement that wraps body.
*/
virtual Stmt BuildRealize(
const Operation& self,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const = 0;
/*!
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains.
* \return A statement that add production and wraps consumer.
*/
virtual Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const = 0;
static constexpr const char* _type_key = "Operation";
TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
};
/*! /*!
* \brief A placeholder op represents an input placeholder. * \brief A placeholder op represents an input placeholder.
*/ */
...@@ -21,13 +146,31 @@ class PlaceholderOpNode : public OperationNode { ...@@ -21,13 +146,31 @@ class PlaceholderOpNode : public OperationNode {
Array<Expr> shape; Array<Expr> shape;
/*! \brief The data type of the input. */ /*! \brief The data type of the input. */
Type dtype; Type dtype;
// override behavior.
int num_outputs() const final { int num_outputs() const final;
return 1;
}
Array<IterVar> root_iter_vars() const final; Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final; Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final; Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Operation& self,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -55,13 +198,31 @@ class ComputeOpNode : public OperationNode { ...@@ -55,13 +198,31 @@ class ComputeOpNode : public OperationNode {
Expr body; Expr body;
/*! \brief constructor */ /*! \brief constructor */
ComputeOpNode() {} ComputeOpNode() {}
// override functions
int num_outputs() const final { int num_outputs() const final;
return 1;
}
Array<IterVar> root_iter_vars() const final; Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final; Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final; Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Operation& self,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -107,6 +268,26 @@ class ScanOpNode : public OperationNode { ...@@ -107,6 +268,26 @@ class ScanOpNode : public OperationNode {
Array<IterVar> root_iter_vars() const final; Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final; Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final; Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Operation& self,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -188,19 +369,9 @@ inline Tensor compute(Array<Expr> shape, ...@@ -188,19 +369,9 @@ inline Tensor compute(Array<Expr> shape,
return compute(shape, fc, name); return compute(shape, fc, name);
} }
// inline function.
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(node_.get());
}
} // namespace tvm } // namespace tvm
namespace std {
template <>
struct hash<::tvm::Tensor> {
std::size_t operator()(const ::tvm::Tensor& k) const {
if (k.defined() && k->op.defined()) {
return k->op.hash();
} else{
return k.hash();
}
}
};
} // namespace std
#endif // TVM_OPERATION_H_ #endif // TVM_OPERATION_H_
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <string> #include <string>
#include "./base.h" #include "./base.h"
#include "./operation.h" #include "./tensor.h"
namespace tvm { namespace tvm {
...@@ -31,13 +31,6 @@ enum AttachType : int { ...@@ -31,13 +31,6 @@ enum AttachType : int {
kScanUpdate = 5 kScanUpdate = 5
}; };
/*! \brief IterVar type */
enum IterVarType : int {
kUnrolled = 1,
kVectorized = 2,
kParallel = 3
};
/*! \brief Stage, contains scheduling for a stage of computation. */ /*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef { class Stage : public NodeRef {
public: public:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./arithmetic.h"
namespace tvm { namespace tvm {
...@@ -156,34 +157,8 @@ class TensorNode : public Node { ...@@ -156,34 +157,8 @@ class TensorNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node); TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node);
}; };
/*!
* \brief base class of operation node.
*/
class OperationNode : public FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
virtual Array<Expr> output_shape(size_t i) const = 0;
static constexpr const char* _type_key = "Operation";
TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
};
// Implementations of inline functions // Implementations of inline functions
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(node_.get());
}
inline const TensorNode* Tensor::operator->() const { inline const TensorNode* Tensor::operator->() const {
return static_cast<const TensorNode*>(node_.get()); return static_cast<const TensorNode*>(node_.get());
} }
...@@ -249,5 +224,16 @@ struct hash<::tvm::Operation> { ...@@ -249,5 +224,16 @@ struct hash<::tvm::Operation> {
return k.hash(); return k.hash();
} }
}; };
}
template <>
struct hash<::tvm::Tensor> {
std::size_t operator()(const ::tvm::Tensor& k) const {
if (k.defined() && k->op.defined()) {
return k->op.hash();
} else{
return k.hash();
}
}
};
} // namespace std
#endif // TVM_TENSOR_H_ #endif // TVM_TENSOR_H_
...@@ -132,7 +132,7 @@ def compute(shape, fcompute, name="compute"): ...@@ -132,7 +132,7 @@ def compute(shape, fcompute, name="compute"):
if ndim != len(arg_names): if ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim) raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)] dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var]) body = fcompute(*[v.var for v in dim_var])
body = convert(body) body = convert(body)
op_node = _api_internal._ComputeOp( op_node = _api_internal._ComputeOp(
...@@ -181,7 +181,7 @@ def scan(init, update, state_placeholder, name="scan"): ...@@ -181,7 +181,7 @@ def scan(init, update, state_placeholder, name="scan"):
state_placeholder = [state_placeholder] state_placeholder = [state_placeholder]
if len(init) != len(update) or len(init) != len(state_placeholder): if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length") raise ValueError("init, update, state_placeholder must have same length")
axis = IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name) axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder) op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
res = [op.output(i) for i in range(len(update))] res = [op.output(i) for i in range(len(update))]
return (res[0] if len(res) == 1 else res) return (res[0] if len(res) == 1 else res)
...@@ -225,16 +225,19 @@ def Buffer(shape, dtype=None, ...@@ -225,16 +225,19 @@ def Buffer(shape, dtype=None,
name, ptr, shape, strides, dtype) name, ptr, shape, strides, dtype)
def IterVar(dom=None, name=None, thread_tag=''): def _IterVar(dom, name, iter_type, thread_tag=''):
"""Create a iteration variable """Internal function to create IterVar
Parameters Parameters
---------- ----------
dom : Range dom : Range
The domain of iteration. The domain of iteration.
name : str name : str
The name of iteration variable. The name of iteration variable.
iter_type : int
The type of iteration.
thread_tag : str thread_tag : str
The thread tag of the iteration variable. The thread tag of the iteration variable.
...@@ -252,10 +255,41 @@ def IterVar(dom=None, name=None, thread_tag=''): ...@@ -252,10 +255,41 @@ def IterVar(dom=None, name=None, thread_tag=''):
if not isinstance(dom, _collections.Range): if not isinstance(dom, _collections.Range):
raise ValueError("dom need to be Range") raise ValueError("dom need to be Range")
if name is None:
name = thread_tag if thread_tag else name
name = name if name else 'iter' name = name if name else 'iter'
return _api_internal._IterVar(dom, name, thread_tag) var = Var(name)
return _api_internal._IterVar(dom, var, iter_type, thread_tag)
def thread_axis(dom, tag, name=''):
"""Create a new IterVar to represent thread index.
Parameters
----------
dom : Range
The domain of iteration.
tag : str
The thread tag
name : str, optional
The name of the var.
"""
name = name if name else tag
return _IterVar(dom, name, 1, tag)
def reduce_axis(dom, name="rv"):
"""Create a new IterVar for reduction.
Parameters
----------
dom : Range
The domain of iteration.
name : str
The name of the variable.
"""
return _IterVar(dom, name, 2)
def sum(expr, axis): def sum(expr, axis):
......
...@@ -55,7 +55,14 @@ class Range(NodeBase): ...@@ -55,7 +55,14 @@ class Range(NodeBase):
@register_node @register_node
class IterVar(NodeBase, _expr.ExprOp): class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.""" """Represent iteration variable."""
pass DataPar = 0
ThreadIndex = 1
CommReduce = 2
Ordered = 3
DimInfo = 4
Unrolled = 5
Vectorized = 6
Parallelized = 7
@register_node @register_node
......
# Code organization # Code Organization
- api API functionr registration Header files in include are public APIs that share across modules.
There can be internal header files within each module that sit in src.
The current code modules in src.
- api API function registration
- lang The definition of DSL related data structure - lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR.
- arithmetic Arithmetic expression and set simplification - arithmetic Arithmetic expression and set simplification
- op The detail implementations about each operation(compute, scan, placeholder)
- schedule The operations on the schedule graph before converting to IR.
- pass The optimization pass on the IR structure - pass The optimization pass on the IR structure
- runtime Minimum runtime related codes. - codegen The code generator.
- codegen The code generator - runtime Minimum runtime related codes
...@@ -6,8 +6,7 @@ ...@@ -6,8 +6,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include "../arithmetic/int_set.h" #include <tvm/arithmetic.h>
#include "../arithmetic/modular.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
...@@ -188,10 +189,12 @@ TVM_REGISTER_API(_OpGetOutput) ...@@ -188,10 +189,12 @@ TVM_REGISTER_API(_OpGetOutput)
static_cast<size_t>(args[1].operator int64_t())); static_cast<size_t>(args[1].operator int64_t()));
}); });
TVM_REGISTER_API(_IterVar) TVM_REGISTER_API(_IterVar)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IterVar(args[0], args[1], args[2]); *ret = IterVarNode::make(
args[0], args[1],
static_cast<IterVarType>(args[2].operator int()),
args[3]);
}); });
TVM_REGISTER_API(_Schedule) TVM_REGISTER_API(_Schedule)
......
...@@ -6,10 +6,11 @@ ...@@ -6,10 +6,11 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/arithmetic.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include "./int_set.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* \brief Canonicalize simplification. * \brief Canonicalize simplification.
*/ */
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include "./int_set.h" #include <tvm/arithmetic.h>
#include "./canonical.h" #include "./canonical.h"
#include "./compute_expr.h" #include "./compute_expr.h"
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <pass/Interval.h> #include <pass/Interval.h>
#include <unordered_map> #include <unordered_map>
#include "./int_set.h"
#include "./compute_expr.h" #include "./compute_expr.h"
#include "./int_set_internal.h" #include "./int_set_internal.h"
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./int_set.h" #include <tvm/arithmetic.h>
#include "./modular.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/arithmetic.h>
#include <limits> #include <limits>
#include "./modular.h"
#include "./int_set_internal.h" #include "./int_set_internal.h"
namespace tvm { namespace tvm {
......
/*!
* Copyright (c) 2017 by Contributors
* \file modular.h
* \brief Modular integer set analysis
*/
#ifndef TVM_ARITHMETIC_MODULAR_H_
#define TVM_ARITHMETIC_MODULAR_H_
#include <tvm/expr.h>
#include "./int_set.h"
namespace tvm {
namespace arith {
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x \in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief The base */
int base;
/*! \brief linear co-efficient */
int coeff;
/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.base = 0; e.coeff = 1;
return e;
}
/*!
* \brief Add two modular entries together to get a new modular entry.
* \param a The left operand.
* \param b The right operand.
* \return The combined modular entry.
*/
static ModularEntry Add(const ModularEntry& a,
const ModularEntry& b);
};
/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_MODULAR_H_
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string> #include <string>
#include "./llvm_common.h" #include "./llvm_common.h"
#include "../../arithmetic/modular.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
......
...@@ -26,27 +26,36 @@ Range Range::make_with_min_extent(Expr min, Expr extent) { ...@@ -26,27 +26,36 @@ Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent)); return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
} }
IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag) IterVar IterVarNode::make(Range dom, Var var,
: IterVar(IterVarNode::make(dom, Var(var_name, Int(32)), thread_tag)) {} IterVarType t, std::string thread_tag) {
IterVar IterVarNode::make(Range dom, Var var, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>(); std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>();
n->dom = dom; n->dom = dom;
n->var = var; n->var = var;
n->iter_type = t;
n->thread_tag = thread_tag; n->thread_tag = thread_tag;
return IterVar(n); return IterVar(n);
} }
IterVar thread_axis(Range dom, std::string tag) {
return IterVarNode::make(
dom, Var(tag), kThreadIndex, tag);
}
IterVar reduce_axis(Range dom, std::string name) {
return IterVarNode::make(
dom, Var(name), kCommReduce);
}
Expr sum(Expr source, Array<IterVar> rdom) { Expr sum(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Add", source, rdom); return ir::Reduce::make("Add", source, rdom, make_const(Bool(1), true));
} }
Expr max(Expr source, Array<IterVar> rdom) { Expr max(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Max", source, rdom); return ir::Reduce::make("Max", source, rdom, make_const(Bool(1), true));
} }
Expr min(Expr source, Array<IterVar> rdom) { Expr min(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Min", source, rdom); return ir::Reduce::make("Min", source, rdom, make_const(Bool(1), true));
} }
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
......
...@@ -26,7 +26,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -26,7 +26,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< op->op << op->op
<< ", "; << ", ";
p->print(op->source); p->print(op->source);
p->stream << ", axis=" << op->axis << ")"; p->stream << ", axis=" << op->axis;
if (!is_const(op->condition, 1)) {
p->stream << ", condition=" << op->condition;
}
p->stream << ")";
}); });
} // namespace Internal } // namespace Internal
...@@ -35,7 +39,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -35,7 +39,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm { namespace tvm {
namespace ir { namespace ir {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) { Expr Reduce::make(std::string op, Expr source,
Array<IterVar> axis, Expr condition) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce)
<< "Can only take axis created by reduce_axis";
}
auto n = std::make_shared<Reduce>(); auto n = std::make_shared<Reduce>();
CHECK(source.defined()); CHECK(source.defined());
for (size_t i = 0; i < axis.size(); ++i) { for (size_t i = 0; i < axis.size(); ++i) {
...@@ -45,6 +54,7 @@ Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) { ...@@ -45,6 +54,7 @@ Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
n->source = source; n->source = source;
n->op = op; n->op = op;
n->axis = axis; n->axis = axis;
n->condition = condition;
return Expr(n); return Expr(n);
} }
......
...@@ -4,209 +4,3 @@ ...@@ -4,209 +4,3 @@
*/ */
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory>
namespace tvm {
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = 0;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
p->stream << "placeholder(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
return {};
}
Type PlaceholderOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return dtype;
}
Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
auto n = std::make_shared<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
return Operation(n);
}
Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
// ComputeOpNode
Array<IterVar> ComputeOpNode::root_iter_vars() const {
if (reduce_axis.size() == 0) return axis;
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
ret.push_back(iv);
}
return ret;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVar(Range(0, shape[i]), os.str()));
args.push_back(axis.back()->var);
}
op_node->axis = Array<IterVar>(axis);
op_node->body = fcompute(args);
op_node->name = name;
return Operation(op_node).output(0);
}
Operation ComputeOpNode::make(std::string name,
Array<IterVar> axis,
Expr body) {
auto n = std::make_shared<ComputeOpNode>();
n->name = name;
n->axis = axis;
n->body = body;
if (n->body->is_type<ir::Reduce>()) {
n->reduce_axis = n->body.as<ir::Reduce>()->axis;
}
return Operation(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
p->stream << "compute(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
// Scan
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
int ScanOpNode::num_outputs() const {
return static_cast<int>(update.size());
}
Array<IterVar> ScanOpNode::root_iter_vars() const {
return Array<IterVar>{scan_axis};
}
Type ScanOpNode::output_dtype(size_t i) const {
return update[i]->dtype;
}
Array<Expr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}
Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "shate_placeholder.shape[0] need to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
<< "The update.ndim need to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k]));
if (k != 0) {
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
}
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = name;
n->scan_axis = axis;
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
return Operation(n);
}
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
IterVar scan_axis(
Range::make_with_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
name + ".idx");
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
}
return res;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
p->stream << "scan(" << op->name << ", " << op << ")";
});
} // namespace tvm
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file tensor.cc * \file tensor.cc
*/ */
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/operation.h>
#include <ir/IR.h> #include <ir/IR.h>
#include <memory> #include <memory>
...@@ -40,4 +41,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -40,4 +41,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \brief Compute Op.
* \file compute_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include "./make_loop.h"
namespace tvm {
using namespace ir;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
p->stream << "compute(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
int ComputeOpNode::num_outputs() const {
return 1;
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
if (reduce_axis.size() == 0) return axis;
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
ret.push_back(iv);
}
return ret;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVarNode::make(
Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
args.push_back(axis.back()->var);
}
op_node->axis = Array<IterVar>(axis);
op_node->body = fcompute(args);
op_node->name = name;
return Operation(op_node).output(0);
}
Operation ComputeOpNode::make(std::string name,
Array<IterVar> axis,
Expr body) {
auto n = std::make_shared<ComputeOpNode>();
n->name = name;
n->axis = axis;
n->body = body;
if (n->body->is_type<ir::Reduce>()) {
n->reduce_axis = n->body.as<ir::Reduce>()->axis;
}
return Operation(n);
}
// The schedule related logics
Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret;
std::unordered_set<Tensor> visited;
ir::PostOrderVisit(body, [&ret, &visited](const NodeRef& n) {
const ir::Call *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (!visited.count(t)) {
ret.push_back(t);
visited.insert(t);
}
}
});
return ret;
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
public:
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Expr Mutate_(const ir::Call* op, const Expr& e) {
if (op->call_type == ir::Call::Halide) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Expr ret = ir::Call::make(
op->type, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
}
}
return IRMutator::Mutate_(op, e);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Operation ComputeOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
TensorReplacer repl(rmap);
Expr new_body = repl.Mutate(this->body);
if (repl.found) {
return ComputeOpNode::make(name, axis, new_body);
} else {
return self;
}
}
void ComputeOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && out_dom_map->count(t)) {
TensorDom& dom = out_dom_map->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
dom.data[i].push_back(EvalSet(call->args[i], dom_map));
}
}
}
};
ir::PostOrderVisit(body, fvisit);
}
void ComputeOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
CHECK(!out_dom_map->count(this->axis[i]));
(*out_dom_map)[this->axis[i]] = r;
}
for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
CHECK(!out_dom_map->count(this->reduce_axis[i]));
(*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
}
}
Stmt ComputeOpNode::BuildRealize(
const Operation& self,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& realize_body) const {
CHECK_EQ(self.operator->(), this);
Tensor t = self.output(0);
Halide::Internal::Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
return ir::Realize::make(t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
}
// Build a reduction body.
void MakeReduction(const ComputeOpNode* op,
const Tensor& t,
Stmt* init,
Stmt* provide) {
Stmt no_op = Evaluate::make(0);
std::vector<Stmt> nest;
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
const Reduce* reduce = op->body.as<Reduce>();
CHECK(reduce);
Expr init_value, update_value;
if (reduce->op == "Add") {
init_value = make_zero(reduce->type);
update_value = Add::make(t(args), reduce->source);
} else if (reduce->op == "Max") {
init_value = reduce->type.min();
update_value = Max::make(t(args), reduce->source);
} else if (reduce->op == "Min") {
init_value = reduce->type.max();
update_value = Min::make(t(args), reduce->source);
} else {
LOG(FATAL) << "Unsupported reduction " << reduce->op;
}
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
}
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return Provide::make(t->op, t->value_index, op->body, args);
}
// message passing to find if IterVar is related to reduction.
void PassDownReduceFlag(const Stage& s,
std::unordered_map<IterVar, int>* p_state) {
auto& state = *p_state;
for (IterVarRelation rel : s->relations) {
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
int flag = state.at(s->parent);
state[s->outer] = flag;
state[s->inner] = flag;
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
int flag_outer = state.at(s->outer);
int flag_inner = state.at(s->inner);
state[s->fused] = flag_outer | flag_inner;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
int flag = state.at(s->parent);
state[s->rebased] = flag;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
for (const auto& kv : value_map) {
temp.Set(kv.first->var, kv.second);
}
return ir::Substitute(s, temp);
}
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt init, provide;
if (this->reduce_axis.size() == 0) {
provide = MakeProvide(this, stage->op.output(0));
} else {
MakeReduction(this, stage->op.output(0), &init, &provide);
}
// make loop nest
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
nest.push_back(op::MakeBoundCheck(
stage, dom_map, false,
std::unordered_set<IterVar>(), value_map));
provide = Substitute(provide, value_map);
if (init.defined()) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> update_state;
for (IterVar iv : this->reduce_axis) {
update_state[iv] = 2;
}
for (IterVar iv : this->axis) {
update_state[iv] = 1;
}
// find which iter var is related to reduction and which is related to axis.
PassDownReduceFlag(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
std::unordered_map<IterVar, Expr> init_value_map;
// first first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i; break;
}
init_value_map[iv] = value_map.at(iv);
}
// skip loops that does not relates to axis.
std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) {
int flag = kv.second;
if ((flag & 1) == 0) skip_iter.insert(kv.first);
}
auto init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &init_value_map);
init_nest.push_back(
op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map));
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
provide = MergeNest(reduce, provide);
return MergeNest(common, Block::make(init, provide));
} else {
return MergeNest(nest, provide);
}
}
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \brief Utility to make loop nest.
* \file make_loop.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/operation.h>
#include "./make_loop.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace op {
using namespace arith;
using namespace ir;
/*!
* \brief use message passing to calculate the assignment of each Var inside the loop body.
* \param s The schedule to be used.
* \param dom_map The domain map of each iteration variable's domain
* \param p_state The message passing state
* IterVar->The assignment.
*/
void PassUpOffset(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
Expr outer = state.at(s->outer);
Expr inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
state[s->parent] = inner + outer * factor;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = state[s->parent] + parent_min;
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
Expr value = state.at(s->fused);
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
}
if (!is_zero(inner_min)) {
state[s->inner] = state[s->inner] + inner_min;
}
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
Expr value = state.at(s->rebased);
Expr parent_min = dom_map.at(s->parent)->min;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = value + parent_min;
} else {
state[s->parent] = value;
}
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
size_t begin_iter_pos,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map) {
auto leaf_iter_vars = stage->leaf_iter_vars;
Stmt no_op = Evaluate::make(0);
// create the loop nest
std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1);
std::unordered_map<IterVar, Expr>& value_map = *p_value_map;
for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
// skip this iteration.
value_map[iv] = iv->var;
continue;
}
Range dom = dom_map.at(iv);
// initialize the offset and loop_level
Var var = iv->var;
if (new_loop_var) {
var = Var(iv->var->name_hint + ".init", iv->var.type());
}
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
if (stage->iter_var_attrs.count(iv)) {
switch (stage->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break;
case kParallelized: for_type = ForType::Parallel; break;
default: LOG(FATAL) << "Unknown iter type"
<< stage->iter_var_attrs[iv]->iter_type
<< " in the iter_var_attrs";
}
}
if (is_one(dom->extent)) {
nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op));
value_map[iv] = dom->min;
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(
For::make(var, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(iv->var->name_hint + ".idx", iv->var.type());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
Expr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
} else if (iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
value_map[iv] = var;
}
// annotate the extent of the IterVar
if (!new_loop_var) {
nest[i + 1].emplace_back(
AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
PassUpOffset(stage, dom_map, &value_map);
return nest;
}
/*!
* \brief message passing to find if boundary checking on IterVar is needed.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
void PassUpBoundCheck(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, bool>* p_state) {
auto& state = *p_state;
using Halide::Internal::can_prove;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
bool outer = state.at(s->outer);
bool inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr step = dom_map.at(s->outer)->extent;
if (outer || inner) {
state[s->parent] = true;
} else {
if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
state[s->parent] = false;
} else {
state[s->parent] = true;
}
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
bool fused = state.at(s->fused);
state[s->outer] = fused;
state[s->inner] = fused;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
std::vector<Stmt> MakeBoundCheck(
const Stage& stage,
const Map<IterVar, Range>& dom_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map) {
Stmt no_op = Evaluate::make(0);
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) {
bound_state[iv] = false;
}
PassUpBoundCheck(stage, dom_map, &bound_state);
// insert conditions
std::vector<Stmt> nest;
for (IterVar iv : stage->op->root_iter_vars()) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
Range dom = dom_map.at(iv);
if (bound_state.at(iv)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent;
nest.emplace_back(IfThenElse::make(condition, no_op));
}
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent;
nest.emplace_back(IfThenElse::make(condition, no_op));
}
}
return nest;
}
} // namespace op
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file make_loop.h
* \brief Utility to make loop nest from schedule stage info.
*/
#ifndef TVM_OP_MAKE_LOOP_H_
#define TVM_OP_MAKE_LOOP_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../pass/ir_util.h"
namespace tvm {
namespace op {
using ir::MergeNest;
/*!
* \brief Build loop nest for stage.
*
* \param stage The stage to create a loop nest.
* \param dom_map The range of each iter var.
* \param begin_iter_pos The beginning position of leaf_iter_vars to generate loop.
* \param new_loop_var Whether create new loop variable.
* \param skip_iter Whether skip certain iteration.
* \param p_value_map The result value of each IterVar.
*/
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
size_t begin_iter_pos,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map);
/*!
* \brief Create boundary check condition for given stage.
*
* \param stage The stage to create a loop nest.
* \param dom_map The range of each iter var.
* \param skip_ivar_domain Whether we can skip check for IterVar's original domain.
* \param skip_iter Whether skip certain iteration.
* \param value_map The result value of each IterVar.
*/
std::vector<Stmt>
MakeBoundCheck(const Stage& stage,
const Map<IterVar, Range>& dom_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map);
} // namespace op
} // namespace tvm
#endif // TVM_OP_MAKE_LOOP_H_
/*!
* Copyright (c) 2017 by Contributors
* \brief Placeholder op.
* \file placeholder_op.cc
*/
#include <tvm/operation.h>
namespace tvm {
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
p->stream << "placeholder(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
int PlaceholderOpNode::num_outputs() const {
return 1;
}
Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
return {};
}
Type PlaceholderOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return dtype;
}
Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
auto n = std::make_shared<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
return Operation(n);
}
Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
Array<Tensor> PlaceholderOpNode::InputTensors() const {
return {};
}
Operation PlaceholderOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
return self;
}
void PlaceholderOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
}
void PlaceholderOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
Stmt PlaceholderOpNode::BuildRealize(
const Operation& self,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
return body;
}
Stmt PlaceholderOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
return Stmt();
}
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \brief Scan Operator.
* \file scan_op.cc
*/
#include <tvm/operation.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./make_loop.h"
#include "../schedule/graph.h"
namespace tvm {
using namespace ir;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
p->stream << "scan(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(ScanOpNode);
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
int ScanOpNode::num_outputs() const {
return static_cast<int>(update.size());
}
Array<IterVar> ScanOpNode::root_iter_vars() const {
Array<IterVar> ret{scan_axis};
for (IterVar iv : spatial_axis_) {
ret.push_back(iv);
}
return ret;
}
Type ScanOpNode::output_dtype(size_t i) const {
return update[i]->dtype;
}
Array<Expr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}
Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "shate_placeholder.shape[0] need to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
<< "The update.ndim need to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k]));
if (k != 0) {
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k;
n->spatial_axis_.push_back(
IterVarNode::make(
Range::make_with_min_extent(0, update[i]->shape[k]),
Var(spatial_name.str()), kOpaque));
}
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = name;
n->scan_axis = axis;
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
return Operation(n);
}
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
IterVar scan_axis =
IterVarNode::make(
Range::make_with_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
Var(name + ".idx"), kOrdered);
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
}
return res;
}
Array<Tensor> ScanOpNode::InputTensors() const {
Array<Tensor> ret;
for (Tensor t : init) {
ret.push_back(t);
}
for (Tensor t : update) {
ret.push_back(t);
}
return ret;
}
Operation ScanOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
std::shared_ptr<ScanOpNode> n = std::make_shared<ScanOpNode>(*this);
for (size_t i = 0; i < n->init.size(); ++i) {
if (rmap.count(n->init[i])) {
n->init.Set(i, rmap.at(n->init[i]));
}
if (rmap.count(n->update[i])) {
n->update.Set(i, rmap.at(n->update[i]));
}
}
if (!n->init.same_as(init) ||
!n->update.same_as(update)) {
return Operation(n);
} else {
return self;
}
}
void ScanOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
TensorDom* init_dom = nullptr;
TensorDom* update_dom = nullptr;
if (out_dom_map->count(this->init[i])) {
init_dom = &out_dom_map->at(this->init[i]);
}
if (out_dom_map->count(this->update[i])) {
update_dom = &out_dom_map->at(this->update[i]);
}
// first dimension, always needed.
if (init_dom) {
init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, this->init[i]->shape[0])));
}
if (update_dom) {
update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
}
// The update dimensions
for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
if (init_dom) {
init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
if (update_dom) {
update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
}
}
}
void ScanOpNode::GatherBound(
const Operation& self,
const GraphContext& graph_ctx,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
using namespace schedule;
CHECK(!out_dom_map->count(this->scan_axis));
std::vector<Tensor> output(this->num_outputs());
for (size_t i = 0; i < output.size(); ++i) {
output[i] = self.output(i);
}
// Update for time axis.
std::vector<IntSet> time_dom;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tensor_dom.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
CHECK(!out_dom_map->count(this->scan_axis));
Range sdom = this->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Array<Operation> body = ScanGetBody_(this, graph_ctx.feed_graph);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self, body);
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tensor_dom.at(output[i]);
for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
CHECK(!out_dom_map->count(sp_ax));
CHECK(fix_pt.count(sp_ax));
if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
// fix point, we can slice it.
(*out_dom_map)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom);
} else {
// not a fix point, need to include everything.
(*out_dom_map)[sp_ax] = sp_ax->dom;
}
}
}
}
Stmt ScanOpNode::BuildRealize(
const Operation& self,
const std::unordered_map<IterVar, Range>& dom_map,
const Stmt& body) const {
CHECK_EQ(self.operator->(), this);
Range sdom = dom_map.at(this->scan_axis);
Range tdom = Range::make_with_min_extent(
0, ir::Simplify(sdom->extent + sdom->min));
Stmt ret = body;
size_t sp_idx = 0;
for (size_t i = 0; i < update.size(); ++i) {
Tensor t = self.output(i);
CHECK_EQ(static_cast<size_t>(t->value_index), i);
Halide::Internal::Region bounds;
bounds.push_back(tdom);
for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
ret = ir::Realize::make(t->op, t->value_index, t->dtype,
bounds, const_true(), ret);
}
return ret;
}
Stmt ScanOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt provide = AttrStmt::make(
stage->op, attr::scan_update_scope, this->scan_axis->var,
Evaluate::make(0));
Stmt init = AttrStmt::make(
stage->op, attr::scan_init_scope, 0,
Evaluate::make(0));
size_t begin_scan = 0;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
CHECK_EQ(begin_scan, i);
begin_scan = i + 1;
}
}
std::unordered_map<IterVar, Expr> vmap;
std::unordered_set<IterVar> empty;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, empty, &vmap);
nest[begin_scan].push_back(init);
nest.push_back(
op::MakeBoundCheck(stage, dom_map, false, empty, vmap));
return MergeNest(nest, provide);
}
} // namespace tvm
...@@ -44,7 +44,7 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) { ...@@ -44,7 +44,7 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
if (!r->extent.same_as(new_extent)) changed = true; if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = IterVarNode::make( new_dom[i] = IterVarNode::make(
Range::make_with_min_extent(new_min, new_extent), Range::make_with_min_extent(new_min, new_extent),
v->var, v->thread_tag); v->var, v->iter_type, v->thread_tag);
} }
if (!changed) { if (!changed) {
return rdom; return rdom;
...@@ -322,11 +322,13 @@ DEFINE_BIOP_EXPR_MUTATE_(Or) ...@@ -322,11 +322,13 @@ DEFINE_BIOP_EXPR_MUTATE_(Or)
Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this); Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Expr new_source = this->Mutate(op->source); Expr new_source = this->Mutate(op->source);
Expr new_cond = this->Mutate(op->condition);
if (op->axis.same_as(new_axis) && if (op->axis.same_as(new_axis) &&
op->source.same_as(new_source)) { op->source.same_as(new_source) &&
op->condition.same_as(new_cond)) {
return e; return e;
} else { } else {
return Reduce::make(op->op, new_source, new_axis); return Reduce::make(op->op, new_source, new_axis, new_cond);
} }
} }
......
...@@ -6,15 +6,17 @@ ...@@ -6,15 +6,17 @@
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "../arithmetic/int_set.h"
#include "../arithmetic/int_set_internal.h" #include "../arithmetic/int_set_internal.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
using arith::IntSet; using arith::IntSet;
using arith::DeduceBound;
using arith::Intersect;
// a partition means the expr is equal to true in the interval // a partition means the expr is equal to true in the interval
struct Partition { struct Partition {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file auto_inline_elem_wise.cc * \file auto_inline_elem_wise.cc
*/ */
#include <tvm/schedule_pass.h> #include <tvm/schedule_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
namespace tvm { namespace tvm {
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file bound.cc * \file bound.cc
...@@ -8,10 +7,11 @@ ...@@ -8,10 +7,11 @@
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/schedule_pass.h> #include <tvm/schedule_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/operation.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "./graph.h" #include "./graph.h"
#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h" #include "../runtime/thread_storage_scope.h"
namespace tvm { namespace tvm {
...@@ -55,8 +55,10 @@ void PassDown(const Stage& s, ...@@ -55,8 +55,10 @@ void PassDown(const Stage& s,
bool match = is_zero(outer_rng->min); bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false; if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match) CHECK(match)
<< r->outer
<< "IterVar is used in two places as outer scope," << "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same"; << " cannot prove their extents are the same "
<< outer_ext << " vs " << outer_rng->extent;
} }
} }
} else { } else {
...@@ -195,162 +197,6 @@ void PassUp(const Stage& s, ...@@ -195,162 +197,6 @@ void PassUp(const Stage& s,
} }
} }
// All the itervars that are needed to output bound of op.
// For most op, it is root_iter_vars
// For Scan, it also contains the additional spatial axis.
Array<IterVar> OutputRelatedIterVars(const Operation& op) {
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<IterVar> ret{scan->scan_axis};
for (IterVar iv : scan->spatial_axis_) {
ret.push_back(iv);
}
return ret;
} else {
return op->root_iter_vars();
}
}
/*! \brief temporary data structure to store Tensor domain */
struct TensorDom {
// constructor
explicit TensorDom(int ndim)
: data(ndim) {}
/*! \brief The domain data*/
std::vector<std::vector<IntSet> > data;
};
/*!
* \brief Propagate bound to target
* \param dom_map The domain map to be propagated
* \param out The tensor set to be passed
* \return The result bound
*/
void BoundProp(const Operation& op,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom> *out) {
if (op.as<ComputeOpNode>()) {
auto fvisit = [&dom_map, out](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && out->count(t)) {
TensorDom& dom = out->at(t);
for (size_t i = 0; i < t.ndim(); ++i) {
dom.data[i].push_back(EvalSet(call->args[i], dom_map));
}
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
size_t sp_idx = 0;
for (size_t i = 0; i < scan->init.size(); ++i) {
TensorDom* init_dom = nullptr;
TensorDom* update_dom = nullptr;
if (out->count(scan->init[i])) {
init_dom = &out->at(scan->init[i]);
}
if (out->count(scan->update[i])) {
update_dom = &out->at(scan->update[i]);
}
// first dimension, always needed.
if (init_dom) {
init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, scan->init[i]->shape[0])));
}
if (update_dom) {
update_dom->data[0].push_back(dom_map.at(scan->scan_axis->var.get()));
}
// The update dimensions
for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
if (init_dom) {
init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
if (update_dom) {
update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
}
}
} else if (op.as<PlaceholderOpNode>()) {
// do nothing
} else {
LOG(FATAL) << "unknown operation mode " << op->type_key();
}
}
// Given the bound of output of op
// Pass the bound to the related axis in op.
void GatherOpBound(const ScanOpNode* scan,
const Operation& op,
const FeedGraph& fg,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
CHECK(!rmap->count(scan->scan_axis));
std::vector<Tensor> output(op->num_outputs());
for (size_t i = 0; i < output.size(); ++i) {
output[i] = op.output(i);
}
// Update for time axis.
std::vector<IntSet> time_dom;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
(*rmap)[scan->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Array<Operation> body = ScanGetBody_(scan, fg);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(op, body);
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tmap.at(output[i]);
for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
CHECK(!rmap->count(sp_ax));
CHECK(fix_pt.count(sp_ax));
if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
// fix point, we can slice it.
(*rmap)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom);
} else {
// not a fix point, need to include everything.
(*rmap)[sp_ax] = sp_ax->dom;
}
}
}
}
void GatherOpBound(const Operation& op,
const FeedGraph& fg,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
if (op.as<ComputeOpNode>()) {
const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i]));
(*rmap)[compute->axis[i]] = r;
}
for (size_t i = 0; i < compute->reduce_axis.size(); ++i) {
CHECK(!rmap->count(compute->reduce_axis[i]));
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
}
} else if (op.as<ScanOpNode>()) {
GatherOpBound(op.as<ScanOpNode>(), op, fg, tmap, rmap);
} else if (op.as<PlaceholderOpNode>()) {
// dp nothing
} else {
LOG(FATAL) << "unknown operation mode " << op->type_key();
}
}
// check if scope // check if scope
inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
using runtime::ThreadScope; using runtime::ThreadScope;
...@@ -362,14 +208,14 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { ...@@ -362,14 +208,14 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
} }
void InferRootBound(const Stage& stage, void InferRootBound(const Stage& stage,
const FeedGraph& feed_graph, const GraphContext& ctx,
const AttachPath& attach_path, const AttachPath& attach_path,
std::unordered_map<IterVar, Range>* rmap) { std::unordered_map<IterVar, Range>* rmap) {
CHECK_NE(stage->attach_type, kInline) CHECK_NE(stage->attach_type, kInline)
<< "call schedule.normalize before scheduleops"; << "call schedule.normalize before scheduleops";
if (stage->attach_type == kInlinedAlready) return; if (stage->attach_type == kInlinedAlready) return;
if (stage->is_output || stage->op.as<PlaceholderOpNode>()) { if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
for (auto iv : OutputRelatedIterVars(stage->op)) { for (auto iv : stage->op->root_iter_vars()) {
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
CHECK(!rmap->count(iv)); CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom; (*rmap)[iv] = iv->dom;
...@@ -390,8 +236,8 @@ void InferRootBound(const Stage& stage, ...@@ -390,8 +236,8 @@ void InferRootBound(const Stage& stage,
for (int i = 0; i < stage->op->num_outputs(); ++i) { for (int i = 0; i < stage->op->num_outputs(); ++i) {
Tensor t = stage->op.output(i); Tensor t = stage->op.output(i);
tmap.emplace(t, TensorDom(static_cast<int>(t.ndim()))); tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
auto it = feed_graph.find(t); auto it = ctx.feed_graph.find(t);
if (it != feed_graph.end()) { if (it != ctx.feed_graph.end()) {
for (const Operation& op : it->second) { for (const Operation& op : it->second) {
if (!parent.defined() || op != parent->op) { if (!parent.defined() || op != parent->op) {
consumers.insert(op); consumers.insert(op);
...@@ -443,7 +289,7 @@ void InferRootBound(const Stage& stage, ...@@ -443,7 +289,7 @@ void InferRootBound(const Stage& stage,
PassUp(parent, *rmap, &up_state); PassUp(parent, *rmap, &up_state);
std::unordered_map<const Variable*, IntSet> dom_map; std::unordered_map<const Variable*, IntSet> dom_map;
for (auto iv : OutputRelatedIterVars(parent->op)) { for (auto iv : parent->op->root_iter_vars()) {
Range r; Range r;
if (up_state.count(iv)) { if (up_state.count(iv)) {
r = up_state.at(iv).cover_range(iv->dom); r = up_state.at(iv).cover_range(iv->dom);
...@@ -457,7 +303,7 @@ void InferRootBound(const Stage& stage, ...@@ -457,7 +303,7 @@ void InferRootBound(const Stage& stage,
} }
} }
// prop from parent. // prop from parent.
BoundProp(parent->op, dom_map, &tmap); parent->op->PropBoundToInputs(parent->op, dom_map, &tmap);
} }
// Bound prop by other consumers. // Bound prop by other consumers.
// To explain the the general logic, consider the example: // To explain the the general logic, consider the example:
...@@ -490,13 +336,13 @@ void InferRootBound(const Stage& stage, ...@@ -490,13 +336,13 @@ void InferRootBound(const Stage& stage,
CHECK(found || attach.size() == 0) CHECK(found || attach.size() == 0)
<< "Invalid Schedule, cannot find the producer " << stage->op << "Invalid Schedule, cannot find the producer " << stage->op
<< " along the loop nest specified by compute_at of consumer " << op; << " along the loop nest specified by compute_at of consumer " << op;
for (auto iv : OutputRelatedIterVars(op)) { for (auto iv : op->root_iter_vars()) {
Range r = rmap->at(iv); Range r = rmap->at(iv);
dom_map[iv->var.get()] = EvalSet(r, relax_set); dom_map[iv->var.get()] = EvalSet(r, relax_set);
} }
BoundProp(op, dom_map, &tmap); op->PropBoundToInputs(op, dom_map, &tmap);
} }
GatherOpBound(stage->op, feed_graph, tmap, rmap); stage->op->GatherBound(stage->op, ctx, tmap, rmap);
} }
Map<IterVar, Range> InferBound(const Schedule& sch) { Map<IterVar, Range> InferBound(const Schedule& sch) {
...@@ -504,13 +350,14 @@ Map<IterVar, Range> InferBound(const Schedule& sch) { ...@@ -504,13 +350,14 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
for (Operation op : sch->outputs) { for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op); roots.push_back(sch->stage_map[op]->op);
} }
FeedGraph feed_graph = CreateFeedGraph(CreateReadGraph(roots)); GraphContext ctx;
ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
AttachPath attach_path = CreateAttachPath(sch); AttachPath attach_path = CreateAttachPath(sch);
std::unordered_map<IterVar, Range> ret; std::unordered_map<IterVar, Range> ret;
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
const Stage& stage = sch->stages[i - 1]; const Stage& stage = sch->stages[i - 1];
InferRootBound(stage, feed_graph, attach_path, &ret); InferRootBound(stage, ctx, attach_path, &ret);
// pass down to get bound of all iter vars. // pass down to get bound of all iter vars.
PassDown(stage, &ret); PassDown(stage, &ret);
// setup outer most threads. // setup outer most threads.
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <unordered_set> #include <unordered_set>
#include "./graph.h" #include "./graph.h"
...@@ -69,28 +70,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) { ...@@ -69,28 +70,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
while (!stack.empty()) { while (!stack.empty()) {
Operation op = stack.back(); Operation op = stack.back();
stack.pop_back(); stack.pop_back();
Array<Tensor> deps; Array<Tensor> deps = op->InputTensors();
if (op.as<ComputeOpNode>()) {
auto fvisit = [&deps](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Operation call_op(call->func.node_);
deps.push_back(call_op.output(call->value_index));
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
for (Tensor t : scan->init) {
deps.push_back(t);
}
for (Tensor t : scan->update) {
deps.push_back(t);
}
} else if (op.as<PlaceholderOpNode>()) {
} else {
LOG(FATAL) << "unknown Operation" << op->type_key();
}
rmap.Set(op, deps); rmap.Set(op, deps);
for (Tensor t : deps) { for (Tensor t : deps) {
if (t->op.defined() && visited.count(t->op.get()) == 0) { if (t->op.defined() && visited.count(t->op.get()) == 0) {
...@@ -137,7 +117,6 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) { ...@@ -137,7 +117,6 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) {
AttachPath CreateAttachPath(Schedule sch) { AttachPath CreateAttachPath(Schedule sch) {
AttachPath ret; AttachPath ret;
for (Stage stage : sch->stages) { for (Stage stage : sch->stages) {
if (stage->attach_type == kScanUpdate) { if (stage->attach_type == kScanUpdate) {
const Stage& parent = stage->attach_stage; const Stage& parent = stage->attach_stage;
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/operation.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -21,11 +22,6 @@ namespace schedule { ...@@ -21,11 +22,6 @@ namespace schedule {
using ReadGraph = Map<Operation, Array<Tensor> >; using ReadGraph = Map<Operation, Array<Tensor> >;
/*! /*!
* \brief The map beteen tensor and operation it feeds to
*/
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
/*!
* \brief AttachPath maps op-> a list of IterVar * \brief AttachPath maps op-> a list of IterVar
*/ */
using AttachPath = Map<Operation, Array<IterVar> >; using AttachPath = Map<Operation, Array<IterVar> >;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file schedule_dataflow_rewrite.cc * \file schedule_dataflow_rewrite.cc
*/ */
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <unordered_set> #include <unordered_set>
...@@ -19,35 +20,7 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { ...@@ -19,35 +20,7 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) {
return array_node->data.size(); return array_node->data.size();
} }
using ir::TensorKey;
// The replacer of cache. // The replacer of cache.
class TensorReplacer : public ir::IRMutator {
public:
explicit TensorReplacer(const std::unordered_map<TensorKey, Tensor>& vmap)
: vmap_(vmap) {}
Expr Mutate_(const ir::Call* op, const Expr& e) {
if (op->call_type == ir::Call::Halide) {
ir::TensorKey key{op->func, op->value_index};
auto it = vmap_.find(key);
if (it != vmap_.end()) {
Expr ret = ir::Call::make(
op->type, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
}
}
return IRMutator::Mutate_(op, e);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<TensorKey, Tensor>& vmap_;
};
class VarReplacer : public ir::IRMutator { class VarReplacer : public ir::IRMutator {
public: public:
explicit VarReplacer( explicit VarReplacer(
...@@ -66,46 +39,14 @@ class VarReplacer : public ir::IRMutator { ...@@ -66,46 +39,14 @@ class VarReplacer : public ir::IRMutator {
// Replace data flow appears in all stages given the tensor change. // Replace data flow appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow need to be replaced. // Also update vmap if subsequent dataflow need to be replaced.
void ReplaceDataFlow(const Array<Stage>& stages, void ReplaceDataFlow(const Array<Stage>& stages,
std::unordered_map<TensorKey, Tensor>* vmap) { std::unordered_map<Tensor, Tensor>* vmap) {
for (Stage s : stages) { for (Stage s : stages) {
if (s->op.as<ComputeOpNode>()) { Operation op = s->op->ReplaceInputs(s->op, *vmap);
const ComputeOpNode* compute = s->op.as<ComputeOpNode>(); if (!op.same_as(s->op)) {
TensorReplacer repl(*vmap); for (int i = 0; i < op->num_outputs(); ++i) {
Expr body = repl.Mutate(compute->body); (*vmap)[s->op.output(i)] = op.output(i);
if (repl.found) {
Operation op = ComputeOpNode::make(
compute->name, compute->axis, body);
(*vmap)[TensorKey{s->op, 0}] = op.output(0);
s->op = op;
}
} else if (s->op.as<ScanOpNode>()) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
std::shared_ptr<ScanOpNode> n =
std::make_shared<ScanOpNode>(*scan);
// copy on write semantics ganrantees correctness
for (size_t i = 0; i < n->init.size(); ++i) {
TensorKey key{n->init[i]->op, n->init[i]->value_index};
if (vmap->count(key)) {
n->init.Set(i, vmap->at(key));
}
}
for (size_t i = 0; i < n->update.size(); ++i) {
TensorKey key{n->update[i]->op, n->update[i]->value_index};
if (vmap->count(key)) {
n->update.Set(i, vmap->at(key));
}
} }
if (!n->init.same_as(scan->init) || s->op = op;
!n->update.same_as(scan->update)) {
Operation op(n);
for (int i = 0; i < op->num_outputs(); ++i) {
(*vmap)[TensorKey{s->op, i}] = op.output(i);
}
s->op = op;
}
} else if (s->op.as<PlaceholderOpNode>()) {
} else {
LOG(FATAL) << "unhandled problem";
} }
} }
} }
...@@ -124,25 +65,17 @@ Tensor Schedule::cache_read(const Tensor& tensor, ...@@ -124,25 +65,17 @@ Tensor Schedule::cache_read(const Tensor& tensor,
Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) { Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
return tensor(Array<Expr>(i.begin(), i.end())); return tensor(Array<Expr>(i.begin(), i.end()));
}, os.str()); }, os.str());
std::unordered_map<TensorKey, Tensor> vsub; std::unordered_map<Tensor, Tensor> vsub;
vsub[TensorKey{tensor->op, tensor->value_index}] = cache; vsub[tensor] = cache;
std::unordered_map<TensorKey, Tensor> vmap; std::unordered_map<Tensor, Tensor> vmap;
for (Operation op : readers) { for (Operation op : readers) {
const ComputeOpNode* compute = op.as<ComputeOpNode>();
CHECK(compute)
<< "cache read only take ComputeOp as readers";
Stage s = operator[](op); Stage s = operator[](op);
compute = s->op.as<ComputeOpNode>(); Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
CHECK(!repl_op.same_as(s->op))
TensorReplacer repl(vsub);
Expr body = repl.Mutate(compute->body);
CHECK(repl.found)
<< "Cannot find " << tensor << "Cannot find " << tensor
<< " in the body of specified reader " << op; << " in the inputs of " << s->op;
Operation repl_op = ComputeOpNode::make( vmap[s->op.output(0)] = repl_op.output(0);
compute->name, compute->axis, body);
vmap[TensorKey{s->op, 0}] = repl_op.output(0);
s->op = repl_op; s->op = repl_op;
} }
ReplaceDataFlow((*this)->stages, &vmap); ReplaceDataFlow((*this)->stages, &vmap);
...@@ -172,7 +105,8 @@ Tensor Schedule::cache_write(const Tensor& tensor, ...@@ -172,7 +105,8 @@ Tensor Schedule::cache_write(const Tensor& tensor,
std::unordered_map<const Variable*, Expr> vsub; std::unordered_map<const Variable*, Expr> vsub;
for (IterVar iv : compute->axis) { for (IterVar iv : compute->axis) {
args.push_back(iv->var); args.push_back(iv->var);
IterVar new_iv(iv->dom, iv->var->name_hint + ".c"); IterVar new_iv = IterVarNode::make(
iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
new_axis.push_back(new_iv); new_axis.push_back(new_iv);
vsub[iv->var.get()] = new_iv->var; vsub[iv->var.get()] = new_iv->var;
} }
...@@ -185,8 +119,8 @@ Tensor Schedule::cache_write(const Tensor& tensor, ...@@ -185,8 +119,8 @@ Tensor Schedule::cache_write(const Tensor& tensor,
compute->name, compute->axis, compute->name, compute->axis,
cache_tensor(args)); cache_tensor(args));
std::unordered_map<TensorKey, Tensor> vmap; std::unordered_map<Tensor, Tensor> vmap;
vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0); vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
ReplaceDataFlow((*this)->stages, &vmap); ReplaceDataFlow((*this)->stages, &vmap);
// mutate orig stage // mutate orig stage
...@@ -227,7 +161,8 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ...@@ -227,7 +161,8 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
size_t idx = FindNodeRef(leaf_vars, iv); size_t idx = FindNodeRef(leaf_vars, iv);
if (idx < leaf_vars->data.size()) { if (idx < leaf_vars->data.size()) {
// insert rebase // insert rebase
IterVar rebased(Range(), iv->var->name_hint + ".rb"); IterVar rebased = IterVarNode::make(
Range(), iv->var.copy_with_suffix(".rb"), iv->iter_type);
s->relations.push_back(RebaseNode::make(iv, rebased)); s->relations.push_back(RebaseNode::make(iv, rebased));
leaf_vars->data[idx] = rebased.node_; leaf_vars->data[idx] = rebased.node_;
rebase_map[iv] = rebased; rebase_map[iv] = rebased;
...@@ -286,7 +221,7 @@ void InjectInline(const Schedule& sch) { ...@@ -286,7 +221,7 @@ void InjectInline(const Schedule& sch) {
} }
} }
} }
std::unordered_map<TensorKey, Tensor> repl; std::unordered_map<Tensor, Tensor> repl;
// rewrite dataflow // rewrite dataflow
for (size_t i = 0; i < sch->stages.size(); ++i) { for (size_t i = 0; i < sch->stages.size(); ++i) {
if (new_body[i].defined() && if (new_body[i].defined() &&
...@@ -295,7 +230,7 @@ void InjectInline(const Schedule& sch) { ...@@ -295,7 +230,7 @@ void InjectInline(const Schedule& sch) {
CHECK(compute); CHECK(compute);
Operation op = ComputeOpNode::make( Operation op = ComputeOpNode::make(
compute->name, compute->axis, new_body[i]); compute->name, compute->axis, new_body[i]);
repl[TensorKey{sch->stages[i]->op, 0}] = op.output(0); repl[sch->stages[i]->op.output(0)] = op.output(0);
Stage s = sch->stages[i]; Stage s = sch->stages[i];
s->op = op; s->op = op;
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file schedule_lang.cc * \file schedule_lang.cc
*/ */
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <unordered_set> #include <unordered_set>
#include "./graph.h" #include "./graph.h"
...@@ -35,16 +36,31 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) ...@@ -35,16 +36,31 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
return 0; return 0;
} }
void Split(StageNode* self, IterVar parent, void CheckSplit(StageNode* self, IterVar parent, IterVar outer) {
IterVar outer, IterVar inner, Expr factor) { // Check if split is valid.
if (self->attach_type == kScanUpdate) { if (self->attach_type == kScanUpdate) {
CHECK(!parent.same_as(self->all_iter_vars[0])) CHECK(!parent.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update"; << "Cannot split on axis[0] of scan update";
} }
if (outer.defined()) {
CHECK_EQ(outer->iter_type, kThreadIndex)
<< "outer in split have to be ThreadIndex";
CHECK_EQ(parent->iter_type, kDataPar)
<< "Split by by kThreadIndex requires kDataPar IterVar "
<< " given " << IterVarType2String(parent->iter_type);
} else {
CHECK(parent->iter_type == kDataPar ||
parent->iter_type == kCommReduce ||
parent->iter_type == kOrdered)
<< "Cannot split on " << IterVarType2String(parent->iter_type);
}
}
void Split(StageNode* self, IterVar parent,
IterVar outer, IterVar inner, Expr factor) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
size_t pos = FindLeafVar(all_vars, leaf_vars, parent); size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
self->relations.push_back(SplitNode::make(parent, outer, inner, factor)); self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
// add vars to all vars // add vars to all vars
all_vars->data.push_back(outer.node_); all_vars->data.push_back(outer.node_);
...@@ -66,11 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -66,11 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) { .set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
switch (op->iter_type) { p->stream << IterVarType2String(op->iter_type);
case kUnrolled: p->stream << "unroll"; break;
case kVectorized: p->stream << "vectorize"; break;
case kParallel: p->stream << "parallel"; break;
}
}); });
Stage::Stage(Operation op) { Stage::Stage(Operation op) {
...@@ -78,7 +90,16 @@ Stage::Stage(Operation op) { ...@@ -78,7 +90,16 @@ Stage::Stage(Operation op) {
n->op = op; n->op = op;
n->origin_op = op; n->origin_op = op;
n->all_iter_vars = op->root_iter_vars(); n->all_iter_vars = op->root_iter_vars();
n->leaf_iter_vars = n->all_iter_vars; // remove opaque var from leaf.
Array<IterVar> clean;
for (IterVar iv : n->all_iter_vars) {
if (iv->iter_type != kOpaque) clean.push_back(iv);
}
if (clean.size() == n->all_iter_vars.size()) {
n->leaf_iter_vars = n->all_iter_vars;
} else {
n->leaf_iter_vars = clean;
}
node_ = n; node_ = n;
} }
...@@ -122,18 +143,22 @@ Stage& Stage::compute_root() { // NOLINT(*) ...@@ -122,18 +143,22 @@ Stage& Stage::compute_root() { // NOLINT(*)
Stage& Stage::split( Stage& Stage::split(
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*) IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
// place holder for the splitted results. CheckSplit(operator->(), parent, IterVar());
IterVar outer(Range(), parent->var->name_hint + ".outer"); IterVar outer = IterVarNode::make(
IterVar inner(Range(), parent->var->name_hint + ".inner"); Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
*p_outer = outer; *p_inner = inner; IterVar inner = IterVarNode::make(
Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
*p_outer = outer;
*p_inner = inner;
Split(operator->(), parent, outer, inner, factor); Split(operator->(), parent, outer, inner, factor);
return *this; return *this;
} }
Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*) Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
// place holder for the splitted results. CheckSplit(operator->(), parent, outer);
IterVar inner(Range(), parent->var->name_hint + ".inner"); std::string name_inner = parent->var->name_hint + ".inner";
IterVar inner = IterVarNode::make(
Range(), Var(name_inner, parent->var.type()), parent->iter_type);
*p_inner = inner; *p_inner = inner;
Split(operator->(), parent, outer, inner, factor); Split(operator->(), parent, outer, inner, factor);
...@@ -144,11 +169,27 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT ...@@ -144,11 +169,27 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
StageNode* self = operator->(); StageNode* self = operator->();
if (self->attach_type == kScanUpdate) { if (self->attach_type == kScanUpdate) {
CHECK(!inner.same_as(self->all_iter_vars[0])) CHECK(!inner.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update"; << "Cannot fuse on axis[0] of scan update";
CHECK(!outer.same_as(self->all_iter_vars[0])) CHECK(!outer.same_as(self->all_iter_vars[0]))
<< "Cannot split on axis[0] of scan update"; << "Cannot fuse on axis[0] of scan update";
} }
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused"); CHECK(outer->iter_type == kDataPar ||
outer->iter_type == kCommReduce ||
outer->iter_type == kOrdered)
<< "Cannot fuse " << IterVarType2String(outer->iter_type);
CHECK(inner->iter_type == kDataPar ||
inner->iter_type == kCommReduce ||
inner->iter_type == kOrdered)
<< "Cannot fuse " << IterVarType2String(outer->iter_type);
IterVarType iter_type = outer->iter_type;
if (inner->iter_type > iter_type) iter_type = inner->iter_type;
std::string fused_name =
outer->var->name_hint + "." + inner->var->name_hint + ".fused";
IterVar fused = IterVarNode::make(
Range(), Var(fused_name, outer->var.type()), iter_type);
*p_target = fused; *p_target = fused;
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
...@@ -169,8 +210,13 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT ...@@ -169,8 +210,13 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
StageNode* self = operator->(); StageNode* self = operator->();
CHECK(!self->op.as<ScanOpNode>()) for (IterVar iv : order) {
<< "Cannot reorder axis of scan"; CHECK(iv->iter_type == kDataPar ||
iv->iter_type == kCommReduce ||
iv->iter_type == kThreadIndex)
<< "Cannot reorder IterVar("
<< IterVarType2String(iv->iter_type) << ")";
}
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::vector<size_t> pos; std::vector<size_t> pos;
...@@ -248,7 +294,7 @@ Stage& Stage::unroll(IterVar var) { // NOLINT(*) ...@@ -248,7 +294,7 @@ Stage& Stage::unroll(IterVar var) { // NOLINT(*)
} }
Stage& Stage::parallel(IterVar var) { // NOLINT(*) Stage& Stage::parallel(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kParallel)); SetAttr(operator->(), var, IterVarAttr(kParallelized));
return *this; return *this;
} }
......
...@@ -6,445 +6,25 @@ ...@@ -6,445 +6,25 @@
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <tvm/schedule_pass.h> #include <tvm/schedule_pass.h>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "./graph.h" #include "./graph.h"
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
using namespace arith;
using namespace ir; using namespace ir;
// Two private scope marks
namespace attr {
constexpr const char* loop_scope = "loop_scope";
constexpr const char* scan_update_scope = "scan_update_scope";
constexpr const char* scan_init_scope = "scan_init_scope";
} // namespace attr
/*!
* \brief message passing to find if IterVar is related to reduction.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
void PassDownFlag(const Stage& s,
std::unordered_map<IterVar, int>* p_state) {
auto& state = *p_state;
for (IterVarRelation rel : s->relations) {
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
int flag = state.at(s->parent);
state[s->outer] = flag;
state[s->inner] = flag;
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
int flag_outer = state.at(s->outer);
int flag_inner = state.at(s->inner);
state[s->fused] = flag_outer | flag_inner;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
int flag = state.at(s->parent);
state[s->rebased] = flag;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
/*!
* \brief message passing to find if boundary checking on IterVar is needed.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
void PassUpBoundCheck(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, bool>* p_state) {
auto& state = *p_state;
using Halide::Internal::can_prove;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
bool outer = state.at(s->outer);
bool inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr step = dom_map.at(s->outer)->extent;
if (outer || inner) {
state[s->parent] = true;
} else {
if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
state[s->parent] = false;
} else {
state[s->parent] = true;
}
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
bool fused = state.at(s->fused);
state[s->outer] = fused;
state[s->inner] = fused;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
/*!
* \brief use message passing to calculate the assignment of each Var inside the loop body.
* \param s The schedule to be used.
* \param dom_map The domain map of each iteration variable's domain
* \param p_state The message passing state
* IterVar->The assignment.
*/
void PassUpOffset(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
Expr outer = state.at(s->outer);
Expr inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
state[s->parent] = inner + outer * factor;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = state[s->parent] + parent_min;
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
Expr value = state.at(s->fused);
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = state[s->outer] + outer_min;
}
if (!is_zero(inner_min)) {
state[s->inner] = state[s->inner] + inner_min;
}
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
Expr value = state.at(s->rebased);
Expr parent_min = dom_map.at(s->parent)->min;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = value + parent_min;
} else {
state[s->parent] = value;
}
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& sch,
const Map<IterVar, Range>& dom_map,
size_t begin_loop,
bool reduce_init_loop,
const std::unordered_map<IterVar, bool>& bound_state,
const std::unordered_map<IterVar, bool>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map) {
auto leaf_iter_vars = sch->leaf_iter_vars;
Stmt no_op = Evaluate::make(0);
// create the loop nest
std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1);
std::unordered_map<IterVar, Expr>& value_map = *p_value_map;
for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
if (skip_iter.count(iv) && skip_iter.at(iv)) {
// skip this iteration.
value_map[iv] = iv->var;
continue;
}
Range dom = dom_map.at(iv);
// initialize the offset and loop_level
Var var = iv->var;
if (reduce_init_loop) {
var = Var(iv->var->name_hint + ".init", iv->var.type());
}
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
if (sch->iter_var_attrs.count(iv)) {
switch (sch->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kParallel: for_type = ForType::Parallel; break;
case kVectorized: for_type = ForType::Vectorized; break;
}
}
if (is_one(dom->extent)) {
nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op));
value_map[iv] = dom->min;
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(
For::make(var, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(iv->var->name_hint + ".idx", iv->var.type());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
Expr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(
LetStmt::make(var, new_value, no_op));
}
} else if (iv->thread_tag == "vthread") {
// virtual thread
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
value_map[iv] = var;
}
if (!reduce_init_loop) {
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
PassUpOffset(sch, dom_map, &value_map);
// insert conditions
for (IterVar iv : sch->op->root_iter_vars()) {
if (skip_iter.count(iv)) continue;
Range dom = dom_map.at(iv);
if (bound_state.at(iv)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent;
nest.back().emplace_back(IfThenElse::make(condition, no_op));
}
CHECK(iv->dom.defined());
if (!reduce_init_loop && !iv->dom.same_as(dom)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent;
nest.back().emplace_back(IfThenElse::make(condition, no_op));
}
}
return nest;
}
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
for (const auto& kv : value_map) {
temp.Set(kv.first->var, kv.second);
}
return ir::Substitute(s, temp);
}
Stmt MakeLoop(const Stage& s,
const Map<IterVar, Range>& dom_map,
Stmt provide,
Stmt init) {
std::unordered_map<IterVar, Expr> value_map;
// bound check state.
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : s->leaf_iter_vars) {
bound_state[iv] = false;
}
PassUpBoundCheck(s, dom_map, &bound_state);
auto nest = MakeLoopNest(
s, dom_map, 0, false,
bound_state, {{}}, &value_map);
provide = Substitute(provide, value_map);
if (init.defined()) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> update_state;
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (compute) {
for (IterVar iv : compute->reduce_axis) {
update_state[iv] = 2;
}
for (IterVar iv : compute->axis) {
update_state[iv] = 1;
}
} else if (scan) {
update_state[scan->scan_axis] = 2;
for (IterVar iv : s->outermost_threads) {
update_state[iv] = 1;
}
}
// find which iter var is related to reduction and which is related to axis.
PassDownFlag(s, &update_state);
auto leaf_iter_vars = s->leaf_iter_vars;
std::unordered_map<IterVar, Expr> init_value_map;
// first first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = update_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i; break;
}
init_value_map[iv] = value_map.at(iv);
}
// skip loops that does not relates to axis.
std::unordered_map<IterVar, bool> skip_iter;
for (auto kv : update_state) {
int flag = kv.second;
if ((flag & 1) == 0) skip_iter[kv.first] = true;
}
auto init_nest = MakeLoopNest(
s, dom_map, begin_loop, true,
bound_state, skip_iter, &init_value_map);
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
provide = MergeNest(reduce, provide);
return MergeNest(
common, Block::make(init, provide));
} else {
return MergeNest(nest, provide);
}
}
Stmt MakeProvide(const ComputeOpNode* op,
const std::vector<Tensor>& tensors) {
Tensor t = tensors[0];
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return Provide::make(t->op, t->value_index, op->body, args);
}
Stmt MakeRealize(const ComputeOpNode* op,
const Map<IterVar, Range>& dom_map,
const std::vector<Tensor>& tensors,
Stmt body) {
Tensor t = tensors[0];
Halide::Internal::Region bounds;
for (IterVar iv : op->axis) {
bounds.push_back(dom_map.at(iv));
}
return Realize::make(t->op, t->value_index, t->dtype,
bounds, make_const(Bool(1), true), body);
}
Stmt MakeRealize(const ScanOpNode* op,
const Map<IterVar, Range>& dom_map,
const std::vector<Tensor>& tensors,
Stmt body) {
Range sdom = dom_map.at(op->scan_axis);
Range tdom = Range::make_with_min_extent(
0, ir::Simplify(sdom->extent + sdom->min));
size_t sp_idx = 0;
for (size_t i = 0; i < tensors.size(); ++i) {
const Tensor& t = tensors[i];
CHECK_EQ(static_cast<size_t>(t->value_index), i);
Halide::Internal::Region bounds;
bounds.push_back(tdom);
for (size_t k = 1; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = op->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
body = Realize::make(t->op, t->value_index, t->dtype,
bounds, make_const(Bool(1), true), body);
}
return body;
}
void MakeReduction(const ComputeOpNode* op,
const std::vector<Tensor>& tensors,
Stmt* init,
Stmt* provide) {
Stmt no_op = Evaluate::make(0);
Tensor t = tensors[0];
std::vector<Stmt> nest;
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
const Reduce* reduce = op->body.as<Reduce>();
CHECK(reduce);
Expr init_value, update_value;
if (reduce->op == "Add") {
init_value = make_zero(reduce->type);
update_value = Add::make(t(args), reduce->source);
} else if (reduce->op == "Max") {
init_value = reduce->type.min();
update_value = Max::make(t(args), reduce->source);
} else if (reduce->op == "Min") {
init_value = reduce->type.max();
update_value = Min::make(t(args), reduce->source);
} else {
LOG(FATAL) << "Unsupported reduction " << reduce->op;
}
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
}
Stmt MakePipeline(const Stage& s, Stmt MakePipeline(const Stage& s,
const Map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
Stmt consumer) { Stmt consumer) {
std::vector<Tensor> tensors; Stmt producer = s->op->BuildProvide(s, dom_map);
for (int i = 0; i < s->op->num_outputs(); ++i) { if (producer.defined()) {
tensors.emplace_back(s->op.output(i)); producer = ProducerConsumer::make(s->op, true, producer);
} }
Stmt init, provide;
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (compute) {
if (compute->reduce_axis.size() == 0) {
provide = MakeProvide(compute, tensors);
} else {
MakeReduction(compute, tensors, &init, &provide);
}
} else if (scan) {
// Provide is done by the sub operations.
provide = AttrStmt::make(
s->op, attr::scan_update_scope, scan->scan_axis->var,
Evaluate::make(0));
init = AttrStmt::make(
s->op, attr::scan_init_scope, 0,
Evaluate::make(0));
} else {
LOG(FATAL) << "not supported op " << s->op->type_key();
}
Stmt producer = MakeLoop(s, dom_map, provide, init);
producer = ProducerConsumer::make(s->op, true, producer);
Stmt pipeline = producer; Stmt pipeline = producer;
// check if consumer is nop. // check if consumer is nop.
bool is_no_op{false}; bool is_no_op{false};
...@@ -455,16 +35,7 @@ Stmt MakePipeline(const Stage& s, ...@@ -455,16 +35,7 @@ Stmt MakePipeline(const Stage& s,
consumer = ProducerConsumer::make(s->op, false, consumer); consumer = ProducerConsumer::make(s->op, false, consumer);
pipeline = Block::make(producer, consumer); pipeline = Block::make(producer, consumer);
} }
pipeline = s->op->BuildRealize(s->op, dom_map, pipeline);
if (s->op.as<ComputeOpNode>()) {
pipeline = MakeRealize(s->op.as<ComputeOpNode>(),
dom_map, tensors, pipeline);
} else if (s->op.as<ScanOpNode>()) {
pipeline = MakeRealize(s->op.as<ScanOpNode>(),
dom_map, tensors, pipeline);
} else {
LOG(FATAL) << "not supported op";
}
// use attribute to mark scope of the operation. // use attribute to mark scope of the operation.
pipeline = AttrStmt::make( pipeline = AttrStmt::make(
s->op, ir::attr::realize_scope, s->op, ir::attr::realize_scope,
...@@ -477,7 +48,7 @@ Stmt MakePipeline(const Stage& s, ...@@ -477,7 +48,7 @@ Stmt MakePipeline(const Stage& s,
class InjectAttach : public IRMutator { class InjectAttach : public IRMutator {
public: public:
InjectAttach(const Stage& stage, InjectAttach(const Stage& stage,
const Map<IterVar, Range>& dom_map) const std::unordered_map<IterVar, Range>& dom_map)
: stage_(stage), dom_map_(dom_map) {} : stage_(stage), dom_map_(dom_map) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
...@@ -515,7 +86,7 @@ class InjectAttach : public IRMutator { ...@@ -515,7 +86,7 @@ class InjectAttach : public IRMutator {
// the operations to be carried // the operations to be carried
const Stage& stage_; const Stage& stage_;
// domain map // domain map
const Map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// internal stack about realization scope. // internal stack about realization scope.
std::vector<const Node*> producer_; std::vector<const Node*> producer_;
}; };
...@@ -525,7 +96,7 @@ class InjectScanStep : public IRMutator { ...@@ -525,7 +96,7 @@ class InjectScanStep : public IRMutator {
public: public:
InjectScanStep(const Stage& stage, InjectScanStep(const Stage& stage,
const Operation& scan_op, const Operation& scan_op,
const Map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool is_init) bool is_init)
: stage_(stage), scan_op_(scan_op), : stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init) {} dom_map_(dom_map), is_init_(is_init) {}
...@@ -556,23 +127,11 @@ class InjectScanStep : public IRMutator { ...@@ -556,23 +127,11 @@ class InjectScanStep : public IRMutator {
const Stage& stage_; const Stage& stage_;
const Operation& scan_op_; const Operation& scan_op_;
// domain map // domain map
const Map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// whether it is init. // whether it is init.
bool is_init_; bool is_init_;
}; };
Stmt InjectInline(const Operation op, Stmt body) {
CHECK(body.defined());
const ComputeOpNode* compute = op.as<ComputeOpNode>();
CHECK(compute != nullptr)
<< "can only inline compute op";
Array<Var> args;
for (auto iv : compute->axis) {
args.push_back(iv->var);
}
return Inline(body, op, args, compute->body);
}
// Postprocessing of schedule op // Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer. // Replace the init and update's expression by scan's buffer.
...@@ -719,7 +278,7 @@ class SchedulePostProc : public IRMutator { ...@@ -719,7 +278,7 @@ class SchedulePostProc : public IRMutator {
}; };
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map) { Schedule sch, Map<IterVar, Range> dom_map_) {
Stmt body = Stmt(); Stmt body = Stmt();
// scan init and scan updates // scan init and scan updates
std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach; std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach;
...@@ -743,6 +302,10 @@ Stmt ScheduleOps( ...@@ -743,6 +302,10 @@ Stmt ScheduleOps(
} }
} }
} }
std::unordered_map<IterVar, Range> dom_map;
for (auto kv : dom_map_) {
dom_map[kv.first] = kv.second;
}
// reverse the post DFS order. // reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
......
...@@ -22,7 +22,7 @@ TEST(Tensor, Reduce) { ...@@ -22,7 +22,7 @@ TEST(Tensor, Reduce) {
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A = placeholder({m, l}, Float(32), "A"); Tensor A = placeholder({m, l}, Float(32), "A");
Tensor B = placeholder({n, l}, Float(32), "B"); Tensor B = placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k"); IterVar rv = reduce_axis(Range{0, l}, "k");
auto C = compute({m, n}, [&](Var i, Var j) { auto C = compute({m, n}, [&](Var i, Var j) {
return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
......
...@@ -29,7 +29,7 @@ def test_dot(): ...@@ -29,7 +29,7 @@ def test_dot():
n = tvm.Var('n') n = tvm.Var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
k = tvm.IterVar((0, n), name='k') k = tvm.reduce_axis((0, n), 'k')
C = tvm.compute((1,), lambda _: tvm.sum(A[k] * B[k], axis=k), name='C') C = tvm.compute((1,), lambda _: tvm.sum(A[k] * B[k], axis=k), name='C')
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
fapi = lower(s, [A, B, C]) fapi = lower(s, [A, B, C])
......
...@@ -11,8 +11,8 @@ def test_add(): ...@@ -11,8 +11,8 @@ def test_add():
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
# create iter var and assign them tags. # create iter var and assign them tags.
num_thread = 256 num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x") block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x) _, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x)
_, x = s[C].split(x, outer=thread_x) _, x = s[C].split(x, outer=thread_x)
_, x = s[C].split(x, factor=4) _, x = s[C].split(x, factor=4)
......
...@@ -11,7 +11,7 @@ def test_gemm(): ...@@ -11,7 +11,7 @@ def test_gemm():
l = n l = n
A = tvm.placeholder((n, l), name='A') A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B') B = tvm.placeholder((m, l), name='B')
k = tvm.IterVar((0, l), name='k') k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute( C = tvm.compute(
(n, m), (n, m),
lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k), lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
...@@ -22,10 +22,10 @@ def test_gemm(): ...@@ -22,10 +22,10 @@ def test_gemm():
scale = 8 scale = 8
num_thread = 8 num_thread = 8
block_factor = scale * num_thread block_factor = scale * num_thread
block_x = tvm.IterVar(thread_tag="blockIdx.x") block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_y = tvm.IterVar(thread_tag="blockIdx.y") block_y = tvm.thread_axis(None, "blockIdx.y")
thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y") thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
CC = s.cache_write(C, "local") CC = s.cache_write(C, "local")
AA = s.cache_read(A, "shared", [CC]) AA = s.cache_read(A, "shared", [CC])
......
...@@ -6,14 +6,14 @@ def test_sum(): ...@@ -6,14 +6,14 @@ def test_sum():
n = tvm.Var('n') n = tvm.Var('n')
m = tvm.Var('m') m = tvm.Var('m')
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
k = tvm.IterVar((0, m)) k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B') B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
# schedule # schedule
s = tvm.Schedule(B.op) s = tvm.Schedule(B.op)
# create iter var and assign them tags. # create iter var and assign them tags.
num_thread = 1 num_thread = 1
block_x = tvm.IterVar(thread_tag="blockIdx.x") block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x) _, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[B].split(x, outer=thread_x) _, x = s[B].split(x, outer=thread_x)
......
...@@ -4,7 +4,6 @@ import numpy as np ...@@ -4,7 +4,6 @@ import numpy as np
def test_scan(): def test_scan():
m = tvm.Var("m") m = tvm.Var("m")
n = tvm.Var("n") n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X") X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n)) s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i]) s_init = tvm.compute((1, n), lambda _, i: X[0, i])
...@@ -14,8 +13,8 @@ def test_scan(): ...@@ -14,8 +13,8 @@ def test_scan():
# schedule # schedule
s = tvm.Schedule(res.op) s = tvm.Schedule(res.op)
num_thread = 256 num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x") block_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x) _, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_init].split(x, outer=thread_x) _, x = s[s_init].split(x, outer=thread_x)
_, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x) _, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x)
......
...@@ -11,8 +11,8 @@ def test_add_pipeline(): ...@@ -11,8 +11,8 @@ def test_add_pipeline():
# GPU schedule have to split by gridIdx and threadIdx # GPU schedule have to split by gridIdx and threadIdx
num_thread = 256 num_thread = 256
grid_x = tvm.IterVar(thread_tag="blockIdx.x") grid_x = tvm.thread_axis(None, "blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x) _, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x)
_, x = s[C].split(x, outer=thread_x) _, x = s[C].split(x, outer=thread_x)
......
...@@ -8,7 +8,10 @@ def test_llvm_add_pipeline(): ...@@ -8,7 +8,10 @@ def test_llvm_add_pipeline():
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
print(s[C])
print("a?")
xo, xi = s[C].split(C.op.axis[0], factor=4) xo, xi = s[C].split(C.op.axis[0], factor=4)
print("a?")
s[C].parallel(xo) s[C].parallel(xo)
s[C].vectorize(xi) s[C].vectorize(xi)
def check_llvm(): def check_llvm():
...@@ -83,6 +86,9 @@ def test_llvm_madd_pipeline(): ...@@ -83,6 +86,9 @@ def test_llvm_madd_pipeline():
if __name__ == "__main__": if __name__ == "__main__":
print("a")
test_llvm_add_pipeline() test_llvm_add_pipeline()
print("a")
test_llvm_flip_pipeline() test_llvm_flip_pipeline()
print("a")
test_llvm_madd_pipeline() test_llvm_madd_pipeline()
...@@ -86,8 +86,8 @@ def test_vectorize(): ...@@ -86,8 +86,8 @@ def test_vectorize():
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
s[T].vectorize(yi) s[T].vectorize(yi)
s[T].unroll(xi) s[T].unroll(xi)
UNROLL = 1 UNROLL = tvm.collections.IterVar.Unrolled
VECTORIZE = 2 VECTORIZE = tvm.collections.IterVar.Vectorized
assert s[T].iter_var_attrs[xi].iter_type == UNROLL assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
......
...@@ -25,7 +25,7 @@ def test_tensor_reduce(): ...@@ -25,7 +25,7 @@ def test_tensor_reduce():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
rv = tvm.IterVar((0, A.shape[1]), name="k") rv = tvm.reduce_axis((0, A.shape[1]), "k")
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv)) C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv))
# json load save # json load save
C_json = tvm.save_json(C) C_json = tvm.save_json(C)
...@@ -37,7 +37,6 @@ def test_tensor_reduce(): ...@@ -37,7 +37,6 @@ def test_tensor_reduce():
def test_tensor_scan(): def test_tensor_scan():
m = tvm.Var("m") m = tvm.Var("m")
n = tvm.Var("n") n = tvm.Var("n")
t = tvm.IterVar((1, m), "t")
x = tvm.placeholder((m, n)) x = tvm.placeholder((m, n))
s = tvm.placeholder((m, n)) s = tvm.placeholder((m, n))
res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]), res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]),
......
...@@ -9,7 +9,7 @@ def test_storage_sync(): ...@@ -9,7 +9,7 @@ def test_storage_sync():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
block_x = tvm.IterVar(thread_tag="blockIdx.x") block_x = tvm.thread_axis(None, "blockIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x) xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x)
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared") s[A1].set_scope("shared")
......
...@@ -7,8 +7,7 @@ def test_virtual_thread(): ...@@ -7,8 +7,7 @@ def test_virtual_thread():
A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2') A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')
s = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
vx = tvm.thread_axis((0, 2), "vthread", name="vx")
vx = tvm.IterVar((0, 2), "vx", thread_tag="vthread")
xo, xi = s[A2].split(A2.op.axis[0], outer=vx) xo, xi = s[A2].split(A2.op.axis[0], outer=vx)
xo, xi = s[A2].split(xi, 8) xo, xi = s[A2].split(xi, 8)
s[A1].compute_at(s[A2], xo) s[A1].compute_at(s[A2], xo)
......
...@@ -38,7 +38,7 @@ def test_bound3(): ...@@ -38,7 +38,7 @@ def test_bound3():
s = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
s[A1].set_scope("shared") s[A1].set_scope("shared")
thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x") thread_x = tvm.thread_axis((0, 16), "threadIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], 32) xo, xi = s[A2].split(A2.op.axis[0], 32)
xi0, xi1 = s[A2].split(xi, outer=thread_x) xi0, xi1 = s[A2].split(xi, outer=thread_x)
yo, yi = s[A2].split(A2.op.axis[1], 16) yo, yi = s[A2].split(A2.op.axis[1], 16)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment