Commit 5f829774 by tqchen

Add domain

parent 5324b211
...@@ -128,6 +128,18 @@ class Array : public NodeRef { ...@@ -128,6 +128,18 @@ class Array : public NodeRef {
if (node_.get() == nullptr) return 0; if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size(); return static_cast<const ArrayNode*>(node_.get())->data.size();
} }
friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*)
for (size_t i = 0; i < r.size(); ++i) {
if (i == 0) {
os << '[';
} else {
os << ", ";
}
os << r[i];
}
os << ']';
return os;
}
}; };
} // namespace tvm } // namespace tvm
......
...@@ -13,14 +13,133 @@ ...@@ -13,14 +13,133 @@
namespace tvm { namespace tvm {
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr begin;
/*! \brief end of the node */
Expr end;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr && begin, Expr && end)
: begin(std::move(begin)), end(std::move(end)) {
}
const char* type_key() const override {
return "RangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("begin", &begin);
fvisit("end", &end);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
//using Domain = Array<Range>; /*! \brief Node range */
class Range : public NodeRef {
public:
/*! \brief constructor */
Range() {}
/*!
* \brief constructor
* \param begin start of the range.
* \param end end of the range.
*/
Range(Expr begin, Expr end);
/*! \return The extent of the range */
Expr extent() const;
/*! \return the begining of the range */
inline const Expr& begin() const {
return static_cast<const RangeNode*>(node_.get())->begin;
}
/*! \return the end of the range */
inline const Expr& end() const {
return static_cast<const RangeNode*>(node_.get())->end;
}
friend std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*)
os << '[' << r.begin() << ", " << r.end() <<')';
return os;
}
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
class RDomain : public NodeRef { /*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomainNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("index", &index);
fvisit("domain", &domain);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
/*! \brief constructor*/
RDomain() {}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit RDomain(Domain domain);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit RDomain(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
CHECK(node_->is_type<RDomainNode>());
}
/*! \return The dimension of the RDomain */
inline size_t ndim() const {
return static_cast<const RDomainNode*>(node_.get())->index.size();
}
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const {
return static_cast<const RDomainNode*>(node_.get())->index[i];
}
/*!
* \return The domain of the reduction.
*/
inline const Domain& domain() const {
return static_cast<const RDomainNode*>(node_.get())->domain;
}
friend std::ostream& operator<<(std::ostream &os, const RDomain& r) { // NOLINT(*)
os << "rdomain(" << r.domain() << ")";
return os;
}
}; };
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
} // namespace tvm } // namespace tvm
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
#include "./tensor.h" #include "./tensor.h"
#include "./expr.h" #include "./expr.h"
namespace tvm { namespace tvm {
/*! \brief variable node for symbolic variables */ /*! \brief variable node for symbolic variables */
class VarNode : public ExprNode { class VarNode : public ExprNode {
public: public:
......
...@@ -16,7 +16,9 @@ namespace tvm { ...@@ -16,7 +16,9 @@ namespace tvm {
* \param src The source expression * \param src The source expression
* \return the simplified expression. * \return the simplified expression.
*/ */
Expr Simplify(const Expr& src); inline Expr Simplify(Expr src) {
return src;
}
/*! /*!
* \brief visit the exression node in expr tree in post DFS order. * \brief visit the exression node in expr tree in post DFS order.
......
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/domain.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <tvm/expr_util.h>
namespace tvm {
Range::Range(Expr begin, Expr end) {
node_ = std::make_shared<RangeNode>(
std::move(begin), std::move(end));
}
Expr Range::extent() const {
return Simplify(end() - begin());
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
index.push_back(Var("reduction_index"));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
std::move(idx), std::move(domain));
}
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
...@@ -48,6 +48,13 @@ void Expr::Print(std::ostream& os) const { ...@@ -48,6 +48,13 @@ void Expr::Print(std::ostream& os) const {
os << ')'; os << ')';
return; return;
} }
case kReduceNode: {
const auto* n = Get<ReduceNode>();
os << "reduce("<< n->op->FunctionName() << ", ";
n->src.Print(os);
os << ", " << n->rdom << ')';
return;
}
default: { default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name(); LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
} }
......
...@@ -42,5 +42,6 @@ TVM_REGISTER_NODE_TYPE(IntNode); ...@@ -42,5 +42,6 @@ TVM_REGISTER_NODE_TYPE(IntNode);
TVM_REGISTER_NODE_TYPE(FloatNode); TVM_REGISTER_NODE_TYPE(FloatNode);
TVM_REGISTER_NODE_TYPE(UnaryOpNode); TVM_REGISTER_NODE_TYPE(UnaryOpNode);
TVM_REGISTER_NODE_TYPE(BinaryOpNode); TVM_REGISTER_NODE_TYPE(BinaryOpNode);
TVM_REGISTER_NODE_TYPE(ReduceNode);
} // namespace tvm } // namespace tvm
...@@ -11,6 +11,16 @@ TEST(Expr, Basic) { ...@@ -11,6 +11,16 @@ TEST(Expr, Basic) {
CHECK(os.str() == "max(((x + 1) + 2), 100)"); CHECK(os.str() == "max(((x + 1) + 2), 100)");
} }
TEST(Expr, Reduction) {
using namespace tvm;
Var x("x");
RDomain rdom({{0, 3}});
auto z = sum(x + 1 + 2, rdom);
std::ostringstream os;
os << z;
CHECK(os.str() == "reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))");
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
...@@ -7,11 +8,6 @@ TEST(Tensor, Basic) { ...@@ -7,11 +8,6 @@ TEST(Tensor, Basic) {
Var m, n, k; Var m, n, k;
Tensor A({m, k}); Tensor A({m, k});
Tensor B({n, k}); Tensor B({n, k});
auto x = [=](Var i, Var j, Var k) {
return A(i, k) * B(j, k);
};
auto C = Tensor({m, n}, x);
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment