Commit 56e10eb0 by tqchen

Tensor API

parent 5f829774
......@@ -23,10 +23,10 @@ class ArrayNode : public Node {
return "ArrayNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
LOG(FATAL) << "need to specially handle list";
LOG(FATAL) << "need to specially handle list attrs";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
LOG(FATAL) << "need to specially handle list";
// Do nothing, specially handled
}
};
......
......@@ -141,7 +141,7 @@ struct BinaryOpNode : public ExprNode {
}
};
/*! \brief Binary mapping operator */
/*! \brief Reduction operator */
struct ReduceNode : public ExprNode {
public:
/*! \brief The operator */
......@@ -178,6 +178,43 @@ struct ReduceNode : public ExprNode {
}
};
/*! \brief Tensor read operator */
struct TensorReadNode : public ExprNode {
public:
/*! \brief The tensor to be read from */
Tensor tensor;
/*! \brief The indices of read */
Array<Expr> indices;
/*! \brief default constructor, do not use directly */
TensorReadNode() {
node_type_ = kTensorReadNode;
}
TensorReadNode(Tensor && tensor, Array<Expr> && indices)
: tensor(std::move(tensor)), indices(std::move(indices)) {
node_type_ = kTensorReadNode;
dtype_ = this->tensor.dtype();
}
~TensorReadNode() {
this->Destroy();
}
const char* type_key() const override {
return "TensorReadNode";
}
void Verify() const override {
CHECK_EQ(dtype_, tensor.dtype());
for (size_t i = 0; i < indices.size(); ++i) {
CHECK_EQ(indices[i].dtype(), kInt32);
}
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("tensor", &tensor);
fvisit("indices", &indices);
}
};
} // namespace tvm
#endif // TVM_EXPR_NODE_H_
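For illustration only (not part of this commit): the TensorReadNode added above is normally produced through the Tensor::operator() overloads introduced later in this change, so user code never constructs it by hand. A minimal sketch under that assumption; ReadElement is a hypothetical helper name.

#include <tvm/expr.h>
#include <tvm/tensor.h>

// Builds the element-read expression A[i, j]; internally Tensor::operator()
// allocates a TensorReadNode whose tensor field is A, whose indices are
// {i, j}, and whose dtype is taken from A.
tvm::Expr ReadElement(const tvm::Tensor& A, tvm::Var i, tvm::Var j) {
  return A(i, j);
}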
......@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_
#include <string>
#include <type_traits>
#include "./expr.h"
#include "./array.h"
......@@ -19,15 +20,14 @@ class TensorNode : public Node {
std::string name;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The index on each dimension */
/*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_index;
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */
Expr source;
/*! \brief constructor */
TensorNode() {
}
TensorNode() {}
const char* type_key() const override {
return "TensorNode";
}
......@@ -42,20 +42,104 @@ class TensorNode : public Node {
}
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
// converters from other functions into fcompute
inline FCompute GetFCompute(std::function<Expr (Var x)> f) {
return [f](const Array<Var>& i) { return f(i[0]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
}
/*!
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public NodeRef {
public:
explicit Tensor(Array<Expr> shape);
inline size_t ndim() const;
/*! \brief default constructor, used internally */
Tensor() {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
DataType dtype = kFloat32);
/*!
* \brief constructor of intermediate result.
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
// same constructor, overloaded for fcompute functions of different arity
Tensor(Array<Expr> shape, std::function<Expr(Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
Tensor(Array<Expr> shape, std::function<Expr(Var, Var, Var, Var)> f, std::string name = "tensor")
:Tensor(shape, GetFCompute(f), name) {}
/*! \return The number of dimensions of the tensor */
inline size_t ndim() const {
return static_cast<const TensorNode*>(node_.get())->shape.size();
}
/*! \return The name of the tensor */
inline const std::string& name() const {
return static_cast<const TensorNode*>(node_.get())->name;
}
/*! \return The data type of the tensor */
inline DataType dtype() const {
return static_cast<const TensorNode*>(node_.get())->dtype;
}
/*! \return The source expression of an intermediate tensor */
inline const Expr& source() const {
return static_cast<const TensorNode*>(node_.get())->source;
}
/*! \return The internal dimension index used by source expression */
inline const Array<Var>& dim_index() const {
return static_cast<const TensorNode*>(node_.get())->dim_index;
}
/*! \return The shape of the tensor */
inline const Array<Expr>& shape() const {
return static_cast<const TensorNode*>(node_.get())->shape;
}
/*!
* \brief Take elements from the tensor
* \param args The indices
* \return the result expression representing tensor read.
*/
template<typename... Args>
inline Expr operator()(Args&& ...args) const {
Array<Expr> indices{std::forward<Args>(args)...};
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read";
return Expr{};
return operator()(indices);
}
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \return the result expression representing tensor read.
*/
Expr operator()(Array<Expr> indices) const;
// print function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t.shape()
<< ", source=" << t.source()
<< ", name=" << t.name() << ')';
return os;
}
};
} // namespace tvm
#endif // TVM_TENSOR_H_
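As a usage sketch of the header above (illustration only, mirroring the test added at the end of this commit; MatmulExample is a hypothetical function name and extra includes may be needed for RDomain and sum): input tensors carry just a shape, name, and dtype, while a computed tensor passes a lambda that GetFCompute adapts to FCompute and whose body becomes the source expression.

#include <tvm/tensor.h>

using namespace tvm;

void MatmulExample() {
  Var m("m"), n("n"), l("l");
  // Input tensors: shape plus a name; dtype defaults to kFloat32.
  Tensor A({m, l}, "A");
  Tensor B({n, l}, "B");
  // Reduction domain over [0, l); rd.i0() is its first index variable.
  RDomain rd({{0, l}});
  // Computed tensor: the two-argument lambda defines the source expression
  // C[i, j] = sum over rd of A[i, rd.i0()] * B[j, rd.i0()].
  Tensor C({m, n}, [&](Var i, Var j) {
    return sum(A(i, rd.i0()) * B(j, rd.i0()), rd);
  }, "C");
}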
......@@ -22,7 +22,9 @@ Expr Range::extent() const {
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
index.push_back(Var("reduction_index"));
std::ostringstream os;
os << "reduction_index" << i;
index.push_back(Var(os.str()));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
......
......@@ -55,6 +55,11 @@ void Expr::Print(std::ostream& os) const {
os << ", " << n->rdom << ')';
return;
}
case kTensorReadNode: {
const auto* n = Get<TensorReadNode>();
os << n->tensor.name() << n->indices;
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
......
......@@ -43,5 +43,6 @@ TVM_REGISTER_NODE_TYPE(FloatNode);
TVM_REGISTER_NODE_TYPE(UnaryOpNode);
TVM_REGISTER_NODE_TYPE(BinaryOpNode);
TVM_REGISTER_NODE_TYPE(ReduceNode);
TVM_REGISTER_NODE_TYPE(TensorReadNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file tensor.cc
*/
#include <tvm/tensor.h>
#include <tvm/expr_node.h>
#include <memory>
namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, DataType dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Tensor::Tensor(Array<Expr> shape, FCompute fcompute, std::string name) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->shape = std::move(shape);
size_t ndim = node->shape.size();
std::vector<Var> dim_index;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "dim_index" << i;
dim_index.push_back(Var(os.str()));
}
node->dim_index = Array<Var>(dim_index);
node->source = fcompute(node->dim_index);
node->dtype = node->source.dtype();
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
auto node = std::make_shared<TensorReadNode>();
node->tensor = *this;
node->indices = std::move(indices);
return Expr(std::move(node));
}
TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm
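To make the FCompute constructor above concrete (an illustrative sketch, not part of the commit; InspectCompute and the tensor name "Copy" are hypothetical): the constructor generates one index variable per dimension, records them in dim_index, and evaluates fcompute on them to obtain the source expression and dtype.

#include <tvm/tensor.h>

using namespace tvm;

void InspectCompute() {
  Var n("n");
  Tensor A({n, n}, "A");
  // The two-argument lambda is adapted to FCompute; the constructor creates
  // vars named "dim_index0" and "dim_index1", stores them in Copy.dim_index(),
  // and Copy.source() becomes the TensorReadNode reading A at those vars.
  // Copy.dtype() is taken from that source expression (kFloat32 here).
  Tensor Copy({n, n}, [&](Var i, Var j) { return A(i, j); }, "Copy");
}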
......@@ -5,9 +5,14 @@
TEST(Tensor, Basic) {
using namespace tvm;
Var m, n, k;
Tensor A({m, k});
Tensor B({n, k});
Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B");
RDomain rd({{0, l}});
auto C = Tensor({m, n}, [&](Var i, Var j) {
return sum(A(i, rd.i0()) * B(j, rd.i0()), rd);
}, "C");
}
int main(int argc, char ** argv) {
......