Commit 816419be by tqchen

Check in basic schedule container

parent 03735b42
...@@ -147,7 +147,7 @@ class Array : public NodeRef { ...@@ -147,7 +147,7 @@ class Array : public NodeRef {
/*! /*!
* \brief set i-th element of the array. * \brief set i-th element of the array.
* \param i The index * \param i The index
* \param other The value to be setted. * \param value The value to be setted.
*/ */
inline void Set(size_t i, const T& value) { inline void Set(size_t i, const T& value) {
this->CopyOnWrite(); this->CopyOnWrite();
...@@ -161,7 +161,7 @@ class Array : public NodeRef { ...@@ -161,7 +161,7 @@ class Array : public NodeRef {
size_t index; size_t index;
/*! /*!
* \brief assign operator * \brief assign operator
* \param value The value to be assigned * \param other The value to be assigned
* \return reference to self. * \return reference to self.
*/ */
inline ArrayItemRef& operator=(const T& other) { inline ArrayItemRef& operator=(const T& other) {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include <memory> #include <memory>
#include <functional> #include <functional>
#include <typeinfo> #include <typeinfo>
#include <type_traits>
namespace tvm { namespace tvm {
......
/*!
* Copyright (c) 2016 by Contributors
* \file codegen.h
* \brief Common data structure for codegen
*/
#ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_
namespace tvm {
// incomplete spec.
struct Assign : public Node {
Expr src;
Expr offset;
Var ptr;
};
struct Assign : public Node {
Expr src;
Expr offset;
Var ptr;
};
struct Loop : public Node {
Expr init;
Expr cond;
Stmt body;
};
struct IfThenElse : public Node {
Expr cond;
Expr then_;
Stmt else_;
};
} // namespace tvm
#endif // TVM_CODEGEN_H_
/*!
* Copyright (c) 2016 by Contributors
* \file schedule.h
* \brief Define a schedule.
*/
#ifndef TVM_SCHEDULE_H_
#define TVM_SCHEDULE_H_
#include <string>
#include "./base.h"
#include "./split.h"
#include "./tensor.h"
namespace tvm {
// Node container for Schedule
class ScheduleNode;
// Node container for AttachSpec
class AttachSpecNode;
/*! \brief the attachment type */
enum AttachType : int {
kRoot = 0,
kInline = 1,
kSplit = 2
};
/*! \brief schedule container */
class Schedule : public NodeRef {
public:
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
};
/*! \brief schedule container */
class AttachSpec : public NodeRef {
public:
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const AttachSpecNode* operator->() const;
};
// defintion of node containers
/*! \brief The attach specification of each subschedule */
class AttachSpecNode : public Node {
public:
/*! \brief The attachment type */
AttachType attach_type;
/*!
* \brief The split to be attached to,
* only valid when attach_type is kRoot
*/
Split attach_split;
/*! \brief the child schedule to be attached. */
Schedule schedule;
const char* type_key() const override {
return "AttachSpecNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("attach_type", &attach_type);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("attach_split", &attach_split);
fvisit("schedule", &schedule);
}
};
/*! \brief represents the schedule of the tensor */
class ScheduleNode : public Node {
public:
/*! \brief Tensor to be scheduled */
Tensor tensor;
/*! \brief The thread scope level of the schedule */
std::string scope;
/*! \brief Splits over domains or rdomains */
Array<Split> splits;
/*! \brief attach specifications */
Array<AttachSpec> attachs;
const char* type_key() const override {
return "AttachSpecNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("scope", &scope);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("tensor", &tensor);
fvisit("splits", &splits);
fvisit("attachs", &attachs);
}
};
// implementations
inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline const AttachSpecNode* AttachSpec::operator->() const {
return static_cast<const AttachSpecNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SCHEDULE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file split.h
* \brief Define a split over Domain or RDomain
*/
#ifndef TVM_SPLIT_H_
#define TVM_SPLIT_H_
#include "./base.h"
#include "./array.h"
#include "./domain.h"
namespace tvm {
// internal node container for split.
class SplitNode;
/*! \brief Split over input domain */
class Split : public NodeRef {
public:
/*! \brief default constructor */
Split() {}
/*! \return Whether the split is over RDomain or not */
inline bool is_over_rdom() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SplitNode* operator->() const;
};
/*!
* \brief base class of split node,
* specifies a split over domain
* split also defines how to generate
*/
class SplitNode : public Node {
public:
/*! \brief whether the split is over reduction domain*/
int split_over_rdom{0};
/*!
* \brief given the output domain, infer input domain
* \param split_index The index to be splitted on
* \param out_domain The outer domain
* \return The inferred inner domain.
*/
virtual Domain InferInnerDomain(Expr split_index, Domain out_domain) const = 0;
};
/*! \brief simple split node that splits over one dimension */
class DimSplitNode : public SplitNode {
public:
/*! \brief The dimension to split on */
int64_t dim_index;
/*! \brief The factor of the split */
Expr factor;
/*! \brief constructor */
DimSplitNode() {}
const char* type_key() const override {
return "DimSplitNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("split_over_rdom", &split_over_rdom);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("factor", &factor);
}
Domain InferInnerDomain(Expr split_index, Domain out_domain) const override {
LOG(FATAL) << "not implemented";
return Domain();
}
};
// Implementations of inline functions
inline const SplitNode* Split::operator->() const {
return static_cast<const SplitNode*>(node_.get());
}
inline bool Split::is_over_rdom() const {
return (*this)->split_over_rdom != 0;
}
} // namespace tvm
#endif // TVM_SPLIT_H_
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/c_api.h> #include <tvm/c_api.h>
#include <memory> #include <memory>
#include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -84,6 +85,8 @@ struct APIVariantValue { ...@@ -84,6 +85,8 @@ struct APIVariantValue {
} }
inline operator int() const { inline operator int() const {
CHECK_EQ(type_id, kLong); CHECK_EQ(type_id, kLong);
CHECK_LE(v_union.v_long,
std::numeric_limits<int>::max());
return v_union.v_long; return v_union.v_long;
} }
inline operator std::string() const { inline operator std::string() const {
......
/*!
* Copyright (c) 2016 by Contributors
* \file schedule.cc
*/
#include <tvm/schedule.h>
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment