Commit 61de73b4 by tqchen

Finalize tensor and operation

parent 605813e4
......@@ -13,6 +13,7 @@
#include <unordered_map>
#include <vector>
#include "./expr.h"
#include "./schedule.h"
namespace tvm {
namespace ir {
......@@ -50,6 +51,14 @@ Stmt Inline(FunctionRef f,
Expr body,
Stmt stmt);
/*!
* \brief Lower schedule s and its dependent operations into a statement.
*
* \param s The schedule to be realized.
* \return The lowered result as a Stmt.
*/
Stmt ScheduleOps(Schedule s);
} // namespace ir
} // namespace tvm
......
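The new ScheduleOps entry point lowers a whole schedule into a single statement. A minimal sketch of a caller, assuming only the Schedule type from ./schedule.h (the helper function itself is hypothetical):

#include <tvm/ir_pass.h>
#include <tvm/schedule.h>

namespace tvm {
namespace ir {
// Hypothetical helper: lower a schedule to its statement form.
Stmt LowerExample(Schedule s) {
  // ScheduleOps realizes s and every operation it depends on,
  // returning the combined statement tree.
  return ScheduleOps(s);
}
}  // namespace ir
}  // namespace tvm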
......@@ -9,43 +9,15 @@
#include <string>
#include "./expr.h"
#include "./domain.h"
#include "./tensor.h"
namespace tvm {
// internal node container for Operation
class OperationNode;
/*! \brief Split over input domain */
class Operation : public NodeRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const OperationNode* operator->() const;
};
/*!
* \brief base class of operation node.
*/
class OperationNode : public Node {
public:
/*! \brief The domain of iteration of this op. */
Domain domain;
/*! \brief optional name of the operation */
std::string name;
};
/*!
 * \brief A compute op that computes a tensor on a certain domain.
*/
class ComputeOpNode : public OperationNode {
public:
/*! \brief iter-Var over the dimensions */
Array<Var> dim_var;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
......@@ -54,6 +26,12 @@ class ComputeOpNode : public OperationNode {
const char* type_key() const final {
return "ComputeOp";
}
size_t num_outputs() const final {
return 1;
}
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("domain", &domain);
v->Visit("name", &name);
......@@ -66,9 +44,43 @@ class ComputeOpNode : public OperationNode {
Expr body);
};
// Implementations of inline functions
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(node_.get());
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
// Same as Compute, specialized for fcompute functions of fixed arity.
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
}
} // namespace tvm
......
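The Compute overloads above let callers hand in fixed-arity lambdas instead of writing FCompute by hand. A small usage sketch against these declarations, assuming Array's initializer-list constructor (the function and its name are illustrative):

#include <tvm/operation.h>

namespace tvm {
// Illustrative only: a 2-D tensor with element (i, j) = i + j.
Tensor MakeSum2D(Expr m, Expr n) {
  // The (Var, Var) overload wraps the lambda into an FCompute
  // that indexes the generic Array<Var> argument.
  return Compute({m, n}, [](Var i, Var j) { return i + j; }, "sum2d");
}
}  // namespace tvm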
......@@ -14,12 +14,14 @@
#include "./base.h"
#include "./expr.h"
#include "./operation.h"
#include "./domain.h"
namespace tvm {
// Internal node container of Tensor
class TensorNode;
// internal node container for Operation
class OperationNode;
using Halide::IR::FunctionRef;
......@@ -68,57 +70,24 @@ class Tensor : public FunctionRef {
friend std::ostream& operator<<(std::ostream &os, const Tensor& t);
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
// converters from other functions into fcompute
inline FCompute GetFCompute(std::function<Expr(Var x)> f) {
return [f] (const Array<Var>& i) { return f(i[0]); };
}
inline FCompute GetFCompute(std::function<Expr(Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1]); };
}
inline FCompute GetFCompute(std::function<Expr(Var, Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
}
inline FCompute GetFCompute(std::function<Expr(Var, Var, Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
}
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
}
/*! \brief Operation that produces tensors */
class Operation : public NodeRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const OperationNode* operator->() const;
/*!
* \brief get the i-th output of the operation.
* \param i the output index.
* \return The i-th output.
*/
Tensor output(size_t i) const;
};
/*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode {
......@@ -158,7 +127,31 @@ class TensorNode : public FunctionBaseNode {
int value_index);
};
// implementations
/*!
 * \brief Base class of operation nodes.
*/
class OperationNode : public Node {
public:
/*! \brief The domain of iteration of this op. */
Domain domain;
/*! \brief iter-Var over the dimensions */
Array<Var> dim_var;
/*! \brief optional name of the operation */
std::string name;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
virtual Array<Expr> output_shape(size_t i) const = 0;
};
// Implementations of inline functions
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(node_.get());
}
inline const TensorNode* Tensor::operator->() const {
return static_cast<const TensorNode*>(node_.get());
......
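Operation now lives beside Tensor and materializes its outputs lazily through output(). A sketch of that round trip, assuming the CHECK macros from ./base.h (the helper is hypothetical):

namespace tvm {
// Hypothetical helper: fetch the sole output of an op.
inline Tensor FirstOutput(const Operation& op) {
  CHECK_GE(op->num_outputs(), 1U);  // virtual hook on OperationNode
  return op.output(0);              // builds a TensorNode on demand
}
}  // namespace tvm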
......@@ -10,12 +10,9 @@
namespace tvm {
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto node = std::make_shared<TensorNode>();
auto op_node = std::make_shared<ComputeOpNode>();
node->name = name;
node->shape = shape;
// Construct an index variable for each dimension.
size_t ndim = node->shape.size();
size_t ndim = shape.size();
std::vector<Var> dim_index;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
......@@ -32,10 +29,8 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
op_node->domain = Domain(dom);
op_node->body = fcompute(op_node->dim_var);
op_node->name = name;
node->dtype = op_node->body.type();
node->op = Operation(op_node);
node->value_index = 0;
return Tensor(node);
return Operation(op_node).output(0);
}
Operation ComputeOpNode::make(Domain domain,
......@@ -50,6 +45,37 @@ Operation ComputeOpNode::make(Domain domain,
return Operation(n);
}
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = static_cast<int>(i);
node->name = (*this)->output_name(i);
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0);
std::vector<Expr> shape;
for (size_t j = 0; j < domain.size(); ++j) {
shape.push_back(domain[j]->extent);
}
return Array<Expr>(shape);
}
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
} // namespace tvm
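With this refactor Compute no longer fills in a TensorNode by hand: name, dtype, and shape all flow from the ComputeOpNode through its output_* hooks, and the tensor is simply output 0 of the op. A sketch of the invariant this establishes (the check function is hypothetical):

namespace tvm {
// Hypothetical check: tensor metadata mirrors its defining op.
void CheckDerivedMetadata() {
  Var n("n");
  Tensor t = Compute({n}, [](Var i) { return i + 1; }, "t");
  // dtype (and likewise shape) must agree with the op's
  // virtual output_* answers for this tensor's value_index.
  CHECK(t->dtype == t->op->output_dtype(t->value_index));
}
}  // namespace tvm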
/*!
* Copyright (c) 2016 by Contributors
* \file schedule_ops.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
namespace {
// Injects the operation's realization into the stmt.
class InjectRealize : public IRMutator {
public:
explicit InjectRealize(std::vector<Tensor> tensors)
    : tensors_(std::move(tensors)) {}
// Tensors whose realization scope is to be injected.
std::vector<Tensor> tensors_;
};
} // namespace
} // namespace ir
} // namespace tvm
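InjectRealize is only a stub here: it stores the tensors but overrides nothing yet. A hedged sketch of how it is meant to be driven once Mutate overrides land, assuming IRMutator's Stmt-level Mutate entry point from tvm/ir_mutator.h:

// Inside the same anonymous namespace in schedule_ops.cc:
Stmt InjectRealizations(Stmt body, const std::vector<Tensor>& tensors) {
  InjectRealize injector(tensors);
  // The base Mutate walks the statement tree; the overrides that
  // will emit Realize nodes per tensor are not part of this commit.
  return injector.Mutate(body);
}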
......@@ -8,7 +8,7 @@ def test_tensor():
B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
print(T.source)
print(T.op.body)
assert(tuple(T.shape) == (m, n, l))
assert(A.source is None)
......@@ -21,7 +21,7 @@ def test_tensor_reduce():
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
rd = tvm.RDomain(tvm.Range(A.shape[1]))
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd))
print(C.source)
print(C.op.body)
if __name__ == "__main__":
test_tensor()
......