Commit cf0fc361 by tqchen

[REFACTOR] Move Node always bebind NodeRef, expose ->

parent 13383928
...@@ -13,27 +13,10 @@ ...@@ -13,27 +13,10 @@
namespace tvm { namespace tvm {
/*! \brief range over one dimension */ // Internal node container of Range
class RangeNode : public Node { class RangeNode;
public: // Internal node container of RDomain
/*! \brief beginning of the node */ class RDomainNode;
Expr begin;
/*! \brief end of the node */
Expr end;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr && begin, Expr && end)
: begin(std::move(begin)), end(std::move(end)) {
}
const char* type_key() const override {
return "RangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("begin", &begin);
fvisit("end", &end);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief Node range */ /*! \brief Node range */
class Range : public NodeRef { class Range : public NodeRef {
...@@ -48,14 +31,16 @@ class Range : public NodeRef { ...@@ -48,14 +31,16 @@ class Range : public NodeRef {
Range(Expr begin, Expr end); Range(Expr begin, Expr end);
/*! \return The extent of the range */ /*! \return The extent of the range */
Expr extent() const; Expr extent() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RangeNode* operator->() const;
/*! \return the begining of the range */ /*! \return the begining of the range */
inline const Expr& begin() const { inline const Expr& begin() const;
return static_cast<const RangeNode*>(node_.get())->begin;
}
/*! \return the end of the range */ /*! \return the end of the range */
inline const Expr& end() const { inline const Expr& end() const;
return static_cast<const RangeNode*>(node_.get())->end; // overload print function
}
friend std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*) friend std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*)
os << '[' << r.begin() << ", " << r.end() <<')'; os << '[' << r.begin() << ", " << r.end() <<')';
return os; return os;
...@@ -65,28 +50,6 @@ class Range : public NodeRef { ...@@ -65,28 +50,6 @@ class Range : public NodeRef {
/*! \brief Domain is a multi-dimensional range */ /*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>; using Domain = Array<Range>;
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomainNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("index", &index);
fvisit("domain", &domain);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief reduction domain */ /*! \brief reduction domain */
class RDomain : public NodeRef { class RDomain : public NodeRef {
public: public:
...@@ -104,35 +67,27 @@ class RDomain : public NodeRef { ...@@ -104,35 +67,27 @@ class RDomain : public NodeRef {
explicit RDomain(std::initializer_list<Range> domain) explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {} : RDomain(Domain(domain)) {}
/*! /*!
* \brief constructor from node pointer * \brief access the internal node container
* \param nptr Another node shared pointer * \return the pointer to the internal node container
*/ */
explicit RDomain(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) { inline const RDomainNode* operator->() const;
CHECK(node_.get() != nullptr);
CHECK(node_->is_type<RDomainNode>());
}
/*! \return The dimension of the RDomain */ /*! \return The dimension of the RDomain */
inline size_t ndim() const { inline size_t ndim() const;
return static_cast<const RDomainNode*>(node_.get())->index.size();
}
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
/*! /*!
* \param i the index. * \param i the index.
* \return i-th index variable in the RDomain * \return i-th index variable in the RDomain
*/ */
inline Var index(size_t i) const { inline Var index(size_t i) const;
return static_cast<const RDomainNode*>(node_.get())->index[i]; /*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
} }
/*! /*!
* \return The domain of the reduction. * \return The domain of the reduction.
*/ */
inline const Domain& domain() const { inline const Domain& domain() const;
return static_cast<const RDomainNode*>(node_.get())->domain; // overload print function
} friend std::ostream& operator<<(std::ostream &os, const RDomain& r){ // NOLINT(*)
friend std::ostream& operator<<(std::ostream &os, const RDomain& r) { // NOLINT(*)
os << "rdomain(" << r.domain() << ")"; os << "rdomain(" << r.domain() << ")";
return os; return os;
} }
...@@ -141,6 +96,79 @@ class RDomain : public NodeRef { ...@@ -141,6 +96,79 @@ class RDomain : public NodeRef {
/*! \brief use RDom as alias of RDomain */ /*! \brief use RDom as alias of RDomain */
using RDom = RDomain; using RDom = RDomain;
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr begin;
/*! \brief end of the node */
Expr end;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr && begin, Expr && end)
: begin(std::move(begin)), end(std::move(end)) {
}
const char* type_key() const override {
return "RangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("begin", &begin);
fvisit("end", &end);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomainNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("index", &index);
fvisit("domain", &domain);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
// implements of inline functions
inline const RangeNode* Range::operator->() const {
return static_cast<const RangeNode*>(node_.get());
}
inline const Expr& Range::begin() const {
return (*this)->begin;
}
inline const Expr& Range::end() const {
return (*this)->end;
}
inline const RDomainNode* RDomain::operator->() const {
return static_cast<const RDomainNode*>(node_.get());
}
inline size_t RDomain::ndim() const {
return (*this)->index.size();
}
inline Var RDomain::index(size_t i) const {
return (*this)->index[i];
}
inline const Domain& RDomain::domain() const {
return (*this)->domain;
}
} // namespace tvm } // namespace tvm
#endif // TVM_DOMAIN_H_ #endif // TVM_DOMAIN_H_
...@@ -192,7 +192,7 @@ struct TensorReadNode : public ExprNode { ...@@ -192,7 +192,7 @@ struct TensorReadNode : public ExprNode {
TensorReadNode(Tensor && tensor, Array<Expr> && indices) TensorReadNode(Tensor && tensor, Array<Expr> && indices)
: tensor(std::move(tensor)), indices(std::move(indices)) { : tensor(std::move(tensor)), indices(std::move(indices)) {
node_type_ = kReduceNode; node_type_ = kReduceNode;
dtype_ = tensor.dtype(); dtype_ = tensor->dtype;
} }
~TensorReadNode() { ~TensorReadNode() {
this->Destroy(); this->Destroy();
...@@ -201,7 +201,7 @@ struct TensorReadNode : public ExprNode { ...@@ -201,7 +201,7 @@ struct TensorReadNode : public ExprNode {
return "TensorReadNode"; return "TensorReadNode";
} }
void Verify() const override { void Verify() const override {
CHECK_EQ(dtype_, tensor.dtype()); CHECK_EQ(dtype_, tensor->dtype);
for (size_t i = 0; i < indices.size(); ++i) { for (size_t i = 0; i < indices.size(); ++i) {
CHECK_EQ(indices[i].dtype(), kInt32); CHECK_EQ(indices[i].dtype(), kInt32);
} }
......
...@@ -15,34 +15,8 @@ ...@@ -15,34 +15,8 @@
namespace tvm { namespace tvm {
/*! \brief Node to represent a tensor */ // Internal node container of Tensor
class TensorNode : public Node { class TensorNode;
public:
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_index;
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */
Expr source;
/*! \brief constructor */
TensorNode() {}
const char* type_key() const override {
return "TensorNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("dim_index", &dim_index);
fvisit("shape", &shape);
fvisit("source", &source);
}
};
/*! \brief The compute function to specify the input source of a Tensor */ /*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
...@@ -94,30 +68,13 @@ class Tensor : public NodeRef { ...@@ -94,30 +68,13 @@ class Tensor : public NodeRef {
:Tensor(shape, GetFCompute(f), name) {} :Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var, Var)> f, std::string name = "tensor") Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {} :Tensor(shape, GetFCompute(f), name) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorNode* operator->() const;
/*! \return The dimension of the tensor */ /*! \return The dimension of the tensor */
inline size_t ndim() const { inline size_t ndim() const;
return static_cast<const TensorNode*>(node_.get())->shape.size();
}
/*! \return The name of the tensor */
inline const std::string& name() const {
return static_cast<const TensorNode*>(node_.get())->name;
}
/*! \return The data type tensor */
inline DataType dtype() const {
return static_cast<const TensorNode*>(node_.get())->dtype;
}
/*! \return The source expression of intermediate tensor */
inline const Expr& source() const {
return static_cast<const TensorNode*>(node_.get())->source;
}
/*! \return The internal dimension index used by source expression */
inline const Array<Var>& dim_index() const {
return static_cast<const TensorNode*>(node_.get())->dim_index;
}
/*! \return The shape of the tensor */
inline const Array<Expr>& shape() const {
return static_cast<const TensorNode*>(node_.get())->shape;
}
/*! /*!
* \brief Take elements from the tensor * \brief Take elements from the tensor
* \param args The indices * \param args The indices
...@@ -138,14 +95,55 @@ class Tensor : public NodeRef { ...@@ -138,14 +95,55 @@ class Tensor : public NodeRef {
std::vector<Tensor> InputTensors() const; std::vector<Tensor> InputTensors() const;
/*! \return whether the tensor stores a result of reduction */ /*! \return whether the tensor stores a result of reduction */
bool IsRTensor() const; bool IsRTensor() const;
// printt function // overload print function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*) friend std::ostream& operator<<(std::ostream &os, const Tensor& t);
os << "Tensor(shape=" << t.shape() };
<< ", source=" << t.source()
<< ", name=" << t.name() << ')'; /*! \brief Node to represent a tensor */
return os; class TensorNode : public Node {
public:
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_index;
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */
Expr source;
/*! \brief constructor */
TensorNode() {}
const char* type_key() const override {
return "TensorNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("dim_index", &dim_index);
fvisit("shape", &shape);
fvisit("source", &source);
} }
}; };
// implementations
inline const TensorNode* Tensor::operator->() const {
return static_cast<const TensorNode*>(node_.get());
}
inline size_t Tensor::ndim() const {
return (*this)->shape.size();
}
inline std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t->shape
<< ", source=" << t->source
<< ", name=" << t->name << ')';
return os;
}
} // namespace tvm } // namespace tvm
#endif // TVM_TENSOR_H_ #endif // TVM_TENSOR_H_
...@@ -256,7 +256,7 @@ void Expr::Print(std::ostream& os) const { ...@@ -256,7 +256,7 @@ void Expr::Print(std::ostream& os) const {
} }
case kTensorReadNode: { case kTensorReadNode: {
const auto* n = Get<TensorReadNode>(); const auto* n = Get<TensorReadNode>();
os << n->tensor.name() << n->indices; os << n->tensor->name << n->indices;
return; return;
} }
default: { default: {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment