Commit cf0fc361 by tqchen

[REFACTOR] Move Node always bebind NodeRef, expose ->

parent 13383928
......@@ -13,27 +13,10 @@
namespace tvm {
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr begin;
/*! \brief end of the node */
Expr end;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr && begin, Expr && end)
: begin(std::move(begin)), end(std::move(end)) {
}
const char* type_key() const override {
return "RangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("begin", &begin);
fvisit("end", &end);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
// Internal node container of Range
class RangeNode;
// Internal node container of RDomain
class RDomainNode;
/*! \brief Node range */
class Range : public NodeRef {
......@@ -48,14 +31,16 @@ class Range : public NodeRef {
Range(Expr begin, Expr end);
/*! \return The extent of the range */
Expr extent() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RangeNode* operator->() const;
/*! \return the begining of the range */
inline const Expr& begin() const {
return static_cast<const RangeNode*>(node_.get())->begin;
}
inline const Expr& begin() const;
/*! \return the end of the range */
inline const Expr& end() const {
return static_cast<const RangeNode*>(node_.get())->end;
}
inline const Expr& end() const;
// overload print function
friend std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*)
os << '[' << r.begin() << ", " << r.end() <<')';
return os;
......@@ -65,28 +50,6 @@ class Range : public NodeRef {
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomainNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("index", &index);
fvisit("domain", &domain);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
......@@ -104,35 +67,27 @@ class RDomain : public NodeRef {
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
* \brief access the internal node container
* \return the pointer to the internal node container
*/
explicit RDomain(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
CHECK(node_->is_type<RDomainNode>());
}
inline const RDomainNode* operator->() const;
/*! \return The dimension of the RDomain */
inline size_t ndim() const {
return static_cast<const RDomainNode*>(node_.get())->index.size();
}
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
inline size_t ndim() const;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const {
return static_cast<const RDomainNode*>(node_.get())->index[i];
inline Var index(size_t i) const;
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
/*!
* \return The domain of the reduction.
*/
inline const Domain& domain() const {
return static_cast<const RDomainNode*>(node_.get())->domain;
}
friend std::ostream& operator<<(std::ostream &os, const RDomain& r) { // NOLINT(*)
inline const Domain& domain() const;
// overload print function
friend std::ostream& operator<<(std::ostream &os, const RDomain& r){ // NOLINT(*)
os << "rdomain(" << r.domain() << ")";
return os;
}
......@@ -141,6 +96,79 @@ class RDomain : public NodeRef {
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr begin;
/*! \brief end of the node */
Expr end;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr && begin, Expr && end)
: begin(std::move(begin)), end(std::move(end)) {
}
const char* type_key() const override {
return "RangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("begin", &begin);
fvisit("end", &end);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomainNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("index", &index);
fvisit("domain", &domain);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
// implements of inline functions
inline const RangeNode* Range::operator->() const {
return static_cast<const RangeNode*>(node_.get());
}
inline const Expr& Range::begin() const {
return (*this)->begin;
}
inline const Expr& Range::end() const {
return (*this)->end;
}
inline const RDomainNode* RDomain::operator->() const {
return static_cast<const RDomainNode*>(node_.get());
}
inline size_t RDomain::ndim() const {
return (*this)->index.size();
}
inline Var RDomain::index(size_t i) const {
return (*this)->index[i];
}
inline const Domain& RDomain::domain() const {
return (*this)->domain;
}
} // namespace tvm
#endif // TVM_DOMAIN_H_
......@@ -192,7 +192,7 @@ struct TensorReadNode : public ExprNode {
TensorReadNode(Tensor && tensor, Array<Expr> && indices)
: tensor(std::move(tensor)), indices(std::move(indices)) {
node_type_ = kReduceNode;
dtype_ = tensor.dtype();
dtype_ = tensor->dtype;
}
~TensorReadNode() {
this->Destroy();
......@@ -201,7 +201,7 @@ struct TensorReadNode : public ExprNode {
return "TensorReadNode";
}
void Verify() const override {
CHECK_EQ(dtype_, tensor.dtype());
CHECK_EQ(dtype_, tensor->dtype);
for (size_t i = 0; i < indices.size(); ++i) {
CHECK_EQ(indices[i].dtype(), kInt32);
}
......
......@@ -15,34 +15,8 @@
namespace tvm {
/*! \brief Node to represent a tensor */
class TensorNode : public Node {
public:
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_index;
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */
Expr source;
/*! \brief constructor */
TensorNode() {}
const char* type_key() const override {
return "TensorNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("dim_index", &dim_index);
fvisit("shape", &shape);
fvisit("source", &source);
}
};
// Internal node container of Tensor
class TensorNode;
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
......@@ -94,30 +68,13 @@ class Tensor : public NodeRef {
:Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorNode* operator->() const;
/*! \return The dimension of the tensor */
inline size_t ndim() const {
return static_cast<const TensorNode*>(node_.get())->shape.size();
}
/*! \return The name of the tensor */
inline const std::string& name() const {
return static_cast<const TensorNode*>(node_.get())->name;
}
/*! \return The data type tensor */
inline DataType dtype() const {
return static_cast<const TensorNode*>(node_.get())->dtype;
}
/*! \return The source expression of intermediate tensor */
inline const Expr& source() const {
return static_cast<const TensorNode*>(node_.get())->source;
}
/*! \return The internal dimension index used by source expression */
inline const Array<Var>& dim_index() const {
return static_cast<const TensorNode*>(node_.get())->dim_index;
}
/*! \return The shape of the tensor */
inline const Array<Expr>& shape() const {
return static_cast<const TensorNode*>(node_.get())->shape;
}
inline size_t ndim() const;
/*!
* \brief Take elements from the tensor
* \param args The indices
......@@ -138,14 +95,55 @@ class Tensor : public NodeRef {
std::vector<Tensor> InputTensors() const;
/*! \return whether the tensor stores a result of reduction */
bool IsRTensor() const;
// printt function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t.shape()
<< ", source=" << t.source()
<< ", name=" << t.name() << ')';
return os;
// overload print function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t);
};
/*! \brief Node to represent a tensor */
class TensorNode : public Node {
public:
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_index;
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */
Expr source;
/*! \brief constructor */
TensorNode() {}
const char* type_key() const override {
return "TensorNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("dim_index", &dim_index);
fvisit("shape", &shape);
fvisit("source", &source);
}
};
// implementations
inline const TensorNode* Tensor::operator->() const {
return static_cast<const TensorNode*>(node_.get());
}
inline size_t Tensor::ndim() const {
return (*this)->shape.size();
}
inline std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t->shape
<< ", source=" << t->source
<< ", name=" << t->name << ')';
return os;
}
} // namespace tvm
#endif // TVM_TENSOR_H_
......@@ -256,7 +256,7 @@ void Expr::Print(std::ostream& os) const {
}
case kTensorReadNode: {
const auto* n = Get<TensorReadNode>();
os << n->tensor.name() << n->indices;
os << n->tensor->name << n->indices;
return;
}
default: {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment