Commit 5079987e by tqchen

Enable array, basic form of tensor

parent 8278e02f
...@@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ ...@@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC -Iinclude -Idmlc-core/include -fPIC
# specify tensor path # specify tensor path
.PHONY: clean all test .PHONY: clean all test doc
all: lib/libtvm.a all: lib/libtvm.a
SRC = $(wildcard src/*.cc src/*/*.cc) SRC = $(wildcard src/*.cc src/*/*.cc)
...@@ -27,6 +27,9 @@ lib/libtvm.a: $(ALL_DEP) ...@@ -27,6 +27,9 @@ lib/libtvm.a: $(ALL_DEP)
lint: lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src python2 dmlc-core/scripts/lint.py tvm cpp include src
doc:
doxygen docs/Doxyfile
clean: clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
......
This source diff could not be displayed because it is too large. You can view the blob instead.
/*!
* Copyright (c) 2016 by Contributors
* \file array.h
* \brief Array container in the DSL graph.
*/
#ifndef TVM_ARRAY_H_
#define TVM_ARRAY_H_
#include <type_traits>
#include <vector>
#include <initializer_list>
#include "./base.h"
namespace tvm {
/*! \brief node content in array */
class ArrayNode : public Node {
public:
/*! \brief the data content */
std::vector<std::shared_ptr<Node> > data;
const char* type_key() const override {
return "ArrayNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
LOG(FATAL) << "need to specially handle list";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
LOG(FATAL) << "need to specially handle list";
}
};
/*!
* \brief Immutable array container of NodeRef in DSL graph.
* \tparam T The content NodeRef type.
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
class Array : public NodeRef {
public:
/*!
* \brief default constructor
*/
Array() {}
/*!
* \brief move constructor
* \param other source
*/
Array(Array<T> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
/*!
* \brief copy constructor
* \param other source
*/
Array(const Array<T> &other) { // NOLINT(*)
node_ = other.node_;
}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(Array<T> && other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(const Array<T> & other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<ArrayNode>();
n->data.reserve(end - begin);
for (IterType i = begin; i < end; ++i) {
n->data.push_back(i->node_);
}
node_ = std::move(n);
}
/*!
* \brief Read i-th element from array.
* \param i The index
* \return the i-th element.
*/
inline T operator[](size_t i) const {
T inst;
inst.node_ = static_cast<const ArrayNode*>(node_.get())->data[i];
return inst;
}
/*! \return The size of the array */
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size();
}
};
} // namespace tvm
#endif // TVM_ARRAY_H_
...@@ -123,10 +123,11 @@ class NodeRef { ...@@ -123,10 +123,11 @@ class NodeRef {
/*! \return wheyjer the expression is null */ /*! \return wheyjer the expression is null */
inline bool is_null() const; inline bool is_null() const;
protected:
template<typename T, typename>
friend class Array;
NodeRef() = default; NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {} explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
protected:
/*! \brief the internal node */ /*! \brief the internal node */
std::shared_ptr<Node> node_; std::shared_ptr<Node> node_;
}; };
......
...@@ -7,10 +7,16 @@ ...@@ -7,10 +7,16 @@
#define TVM_DOMAIN_H_ #define TVM_DOMAIN_H_
#include <memory> #include <memory>
#include "./base.h"
#include "./array.h"
#include "./expr.h"
namespace tvm { namespace tvm {
class RDom {
};
// using Domain = Array<Range>;
} // namespace tvm } // namespace tvm
......
...@@ -36,7 +36,7 @@ class UnaryOp { ...@@ -36,7 +36,7 @@ class UnaryOp {
* \param src left operand * \param src left operand
* \return the result expr * \return the result expr
*/ */
Expr operator()(Expr lhs, Expr rhs) const; Expr operator()(Expr src) const;
}; };
...@@ -111,6 +111,11 @@ DEFINE_OP_OVERLOAD(/, DivOp); ...@@ -111,6 +111,11 @@ DEFINE_OP_OVERLOAD(/, DivOp);
DEFINE_BINARY_OP_FUNCTION(max, MaxOp); DEFINE_BINARY_OP_FUNCTION(max, MaxOp);
DEFINE_BINARY_OP_FUNCTION(min, MinOp); DEFINE_BINARY_OP_FUNCTION(min, MinOp);
// overload negation
inline Expr operator-(Expr src) {
return src * (-1);
}
} // namespace tvm } // namespace tvm
#endif // TVM_OP_H_ #endif // TVM_OP_H_
...@@ -7,14 +7,54 @@ ...@@ -7,14 +7,54 @@
#define TVM_TENSOR_H_ #define TVM_TENSOR_H_
#include "./expr.h" #include "./expr.h"
#include "./array.h"
namespace tvm { namespace tvm {
class Tensor { /*! \brief Node to represent a tensor */
private: class TensorNode : public Node {
public:
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The index on each dimension */
Array<Var> dim_index;
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */ /*! \brief source expression */
Expr src_expr; Expr source;
/*! \brief constructor */
TensorNode() {
}
const char* type_key() const override {
return "TensorNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("dim_index", &dim_index);
fvisit("shape", &shape);
fvisit("source", &source);
}
};
class Tensor : public NodeRef {
public:
Tensor(Array<Expr> shape);
Tensor(Array<Expr> shape, std::function<Expr (Var, Var, Var)> f3) {
}
inline size_t ndim() const;
template<typename... Args>
inline Expr operator()(Args&& ...args) const {
Array<Expr> indices{std::forward<Args>(args)...};
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read";
return Expr{};
}
}; };
......
...@@ -10,5 +10,7 @@ ...@@ -10,5 +10,7 @@
#include "./expr.h" #include "./expr.h"
#include "./op.h" #include "./op.h"
#include "./tensor.h" #include "./tensor.h"
#include "./domain.h"
#include "./array.h"
#endif // TVM_TVM_H_ #endif // TVM_TVM_H_
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
TEST(Array, Expr) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Array<Expr> list{x, z, z};
LOG(INFO) << list.size();
LOG(INFO) << list[0];
LOG(INFO) << list[1];
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
TEST(Tensor, Basic) {
using namespace tvm;
Var m, n, k;
Tensor A({m, k});
Tensor B({n, k});
auto x = [=](Var i, Var j, Var k) {
return A(i, k) * B(j, k);
};
auto C = Tensor({m, n}, x);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment