Commit e011edd1 by tqchen

Add Tensor, cleanup test, all present tests pass

parent f52d0713
Subproject commit 84a568ce86ca64ff4e186b78745152061499cbf4 Subproject commit f72e313118a61b0cc49987b9eebfc77300d2de0d
...@@ -7,13 +7,27 @@ ...@@ -7,13 +7,27 @@
#define TVM_EXPR_H_ #define TVM_EXPR_H_
#include <ir/Expr.h> #include <ir/Expr.h>
#include <ir/IROperator.h>
#include <type_traits> #include <type_traits>
#include "./base.h" #include "./base.h"
namespace tvm { namespace tvm {
using Halide::Type; using Halide::Type;
using Halide::Float;
using Halide::Int;
using Halide::UInt;
using Halide::Handle;
// functions
using Halide::cast;
using Halide::min;
using Halide::max;
using Halide::abs;
using Halide::select;
using Halide::Expr; using Halide::Expr;
using Var = Halide::VarExpr;
} // namespace tvm } // namespace tvm
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
...@@ -9,9 +9,10 @@ ...@@ -9,9 +9,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <type_traits> #include <type_traits>
#include <tvm/array.h>
#include <ir/FunctionBase.h>
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./array.h"
namespace tvm { namespace tvm {
...@@ -35,11 +36,13 @@ inline FCompute GetFCompute(std::function<Expr(Var, Var, Var, Var)> f) { ...@@ -35,11 +36,13 @@ inline FCompute GetFCompute(std::function<Expr(Var, Var, Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); }; return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
} }
using Halide::IR::FunctionRef;
/*! /*!
* \brief Tensor structure representing a possible input, * \brief Tensor structure representing a possible input,
* or intermediate computation result. * or intermediate computation result.
*/ */
class Tensor : public NodeRef { class Tensor : public FunctionRef {
public: public:
/*! \brief default constructor, used internally */ /*! \brief default constructor, used internally */
Tensor() {} Tensor() {}
...@@ -51,7 +54,7 @@ class Tensor : public NodeRef { ...@@ -51,7 +54,7 @@ class Tensor : public NodeRef {
*/ */
explicit Tensor(Array<Expr> shape, explicit Tensor(Array<Expr> shape,
std::string name = "tensor", std::string name = "tensor",
DataType dtype = kFloat32); Type dtype = Float(32));
/*! /*!
* \brief constructor of intermediate result. * \brief constructor of intermediate result.
* \param shape Shape of the tensor. * \param shape Shape of the tensor.
...@@ -91,10 +94,6 @@ class Tensor : public NodeRef { ...@@ -91,10 +94,6 @@ class Tensor : public NodeRef {
* \return the result expression representing tensor read. * \return the result expression representing tensor read.
*/ */
Expr operator()(Array<Expr> indices) const; Expr operator()(Array<Expr> indices) const;
/*! \return list of input tensors to this tensor */
std::vector<Tensor> InputTensors() const;
/*! \return whether the tensor stores a result of reduction */
bool IsRTensor() const;
// overload print function // overload print function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t); friend std::ostream& operator<<(std::ostream &os, const Tensor& t);
}; };
...@@ -105,9 +104,9 @@ class TensorNode : public Node { ...@@ -105,9 +104,9 @@ class TensorNode : public Node {
/*! \brief optional name of the tensor */ /*! \brief optional name of the tensor */
std::string name; std::string name;
/*! \brief data type in the content of the tensor */ /*! \brief data type in the content of the tensor */
DataType dtype; Type dtype;
/*! \brief The index representing each dimension, used by source expression. */ /*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_index; Array<Var> dim_var;
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape; Array<Expr> shape;
/*! \brief source expression */ /*! \brief source expression */
...@@ -115,16 +114,15 @@ class TensorNode : public Node { ...@@ -115,16 +114,15 @@ class TensorNode : public Node {
/*! \brief constructor */ /*! \brief constructor */
TensorNode() {} TensorNode() {}
const char* type_key() const override { const char* type_key() const override {
return "TensorNode"; return "Tensor";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype);
} }
void VisitNodeRefFields(FNodeRefVisit fvisit) override { void VisitAttrs(AttrVisitor* v) final {
fvisit("dim_index", &dim_index); v->Visit("name", &name);
fvisit("shape", &shape); v->Visit("dtype", &dtype);
fvisit("source", &source); v->Visit("dim_var", &dim_var);
v->Visit("shape", &shape);
v->Visit("source", &source);
} }
}; };
......
...@@ -8,10 +8,6 @@ ...@@ -8,10 +8,6 @@
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./op.h"
#include "./tensor.h" #include "./tensor.h"
#include "./domain.h"
#include "./array.h"
#include "./expr_util.h"
#endif // TVM_TVM_H_ #endif // TVM_TVM_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tensor.cc
*/
#include <tvm/tensor.h>
#include <ir/IR.h>
#include <memory>
namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, Type dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Tensor::Tensor(Array<Expr> shape, FCompute fcompute, std::string name) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->shape = std::move(shape);
size_t ndim = node->shape.size();
std::vector<Var> dim_index;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "dim_index" << i;
dim_index.push_back(Var(os.str()));
}
node->dim_var = Array<Var>(dim_index);
node->source = fcompute(node->dim_var);
node->dtype = node->source.type();
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const {
using Halide::Internal::Call;
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
return Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this);
}
TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm
...@@ -11,43 +11,6 @@ TEST(Expr, Basic) { ...@@ -11,43 +11,6 @@ TEST(Expr, Basic) {
CHECK(os.str() == "max(((x + 1) + 2), 100)"); CHECK(os.str() == "max(((x + 1) + 2), 100)");
} }
TEST(Expr, Reduction) {
using namespace tvm;
Var x("x");
RDomain rdom({{0, 3}});
auto z = sum(x + 1 + 2, rdom);
std::ostringstream os;
os << z;
CHECK(os.str() == "reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))");
}
TEST(Expr, Simplify) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, x + 10) * 100;
std::ostringstream os;
os << Simplify(z);
CHECK(os.str() == "((x * 100) + 1000)");
}
TEST(Expr, Bind) {
using namespace tvm;
Var x("x"), y("y"), z("z");
Var i("i"), j("j");
Tensor A({y, z}, "A");
Expr e1 = x * 5;
std::unordered_map<Expr, Expr> dict = {{x, y * 10 + z}};
std::ostringstream os1, os2;
os1 << Bind(e1, dict);
CHECK(os1.str() == "(((y * 10) + z) * 5)");
Expr e2 = A(i, j);
dict.clear();
dict[i] = 64 * x;
dict[j] = z + 16 * y;
os2 << Bind(e2, dict);
CHECK(os2.str() == "A[(64 * x), (z + (16 * y))]");
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
......
...@@ -8,16 +8,10 @@ TEST(Tensor, Basic) { ...@@ -8,16 +8,10 @@ TEST(Tensor, Basic) {
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A"); Tensor A({m, l}, "A");
Tensor B({n, l}, "B"); Tensor B({n, l}, "B");
RDomain rd({{0, l}});
auto C = Tensor({m, n}, [&](Var i, Var j) { auto C = Tensor({m, n}, [&](Var i, Var j) {
return sum(A(i, rd.i0()) * B(j, rd.i0()), rd); return A(i, j) * B(j, i);
}, "C"); }, "C");
auto inputs = C.InputTensors();
CHECK(inputs[0] == A);
CHECK(inputs[1] == B);
CHECK(C.IsRTensor());
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
......
...@@ -4,7 +4,7 @@ GTEST_INC=$(GTEST_PATH)/include/ ...@@ -4,7 +4,7 @@ GTEST_INC=$(GTEST_PATH)/include/
TEST_SRC = $(wildcard tests/cpp/*_test.cc) TEST_SRC = $(wildcard tests/cpp/*_test.cc)
TEST = $(patsubst tests/cpp/%_test.cc, tests/cpp/%_test, $(TEST_SRC)) TEST = $(patsubst tests/cpp/%_test.cc, tests/cpp/%_test, $(TEST_SRC))
tests/cpp/%_test: tests/cpp/%_test.cc lib/libtvm.a tests/cpp/%_test: tests/cpp/%_test.cc lib/libtvm.a HalideIR/lib/libHalideIR.a
$(CXX) -std=c++11 $(CFLAGS) -MM -MT tests/cpp/$* $< >tests/cpp/$*.d $(CXX) -std=c++11 $(CFLAGS) -MM -MT tests/cpp/$* $< >tests/cpp/$*.d
$(CXX) -std=c++11 $(CFLAGS) -I$(GTEST_INC) -o $@ $(filter %.cc %.a, $^) \ $(CXX) -std=c++11 $(CFLAGS) -I$(GTEST_INC) -o $@ $(filter %.cc %.a, $^) \
-L$(GTEST_LIB) $(LDFLAGS) -lgtest -L$(GTEST_LIB) $(LDFLAGS) -lgtest
......
import tvm
def test_range_infer():
x = tvm.Var('x')
y = tvm.Var('y')
t = tvm.Var('t')
z = x + y + t
zr = tvm.infer_range(z, {x: tvm.Range(10, 20), y : tvm.Range(10, 11)})
assert str(zr) == "((t0 + 20), (t0 + 30))"
def test_tensor_dom_infer():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
rd = tvm.RDom(tvm.Range(A.shape[1]))
T = tvm.Tensor(2, lambda i, j:
tvm.reduce_sum(A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
C = tvm.Tensor(2, lambda i, j: T(i,j),
shape=(A.shape[0], B.shape[0]))
cdom = [tvm.Range(0, 10), tvm.Range(1, 11)]
tdom = C.infer_input_domains(cdom, inputs=[T])[T]
assert T.is_rtensor
assert str(tdom[0]) == "(0, 10)"
if __name__ == "__main__":
test_range_infer()
test_tensor_dom_infer()
import tvm
def test_schedule():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
rd = tvm.RDom(tvm.Range(A.shape[1]))
T = tvm.Tensor(2, lambda i, j:
tvm.reduce_sum(A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]), name="T")
C = tvm.Tensor(2, lambda i, j: T(i,j),
shape=(A.shape[0], B.shape[0]), name="C")
bufA = tvm.Buffer(tvm.Scope.Thread, name='A')
bufB = tvm.Buffer(tvm.Scope.Thread, name='B')
bufT = tvm.Buffer(tvm.Scope.Thread, name='T')
schA = tvm.Schedule(A, buffer=bufA)
schB = tvm.Schedule(B, buffer=bufB)
schT = tvm.Schedule(T, buffer=bufT)
schC = tvm.Schedule(C)
Cx0 = tvm.Split(dim=0, factor=64)
Cy0 = tvm.Split(dim=1, factor=64)
Cx1 = tvm.Split(dim=0, factor=8)
Cy1 = tvm.Split(dim=1, factor=8)
Tk = tvm.Split(dim=0, factor=8, rdom=True)
schC.add_split(Cx0)
schC.add_split(Cy0)
schC.add_split(Cx1)
schC.add_split(Cy1)
schT.add_split(Tk)
schC.attach(Cy1, schT)
schT.attach(Tk, schA)
schT.attach(Tk, schB)
body = schC.realize()
print('\n'.join(body))
if __name__ == "__main__":
test_schedule()
import tvm
def test_split_dom_infer():
A = tvm.Tensor(2, name='A')
split1 = tvm.Split(0, 64)
split2 = tvm.Split(1, 64)
split3 = tvm.Split(0, 8)
dom = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])]
dom1 = split1.infer_inner_domain(dom)
dom2 = split2.infer_inner_domain(dom1)
dom3 = split3.infer_inner_domain(dom2)
i1 = split1.loop_index.name
i2 = split2.loop_index.name
i3 = split3.loop_index.name
assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1)
assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2)
assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2)
if __name__ == "__main__":
test_split_dom_infer()
import tvm
def test_tensor():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
print(tvm.format_str(T.expr))
def test_tensor_inputs():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
assert(T.input_tensors() == set([A, B]))
def test_tensor_reduce():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
rd = tvm.RDom(tvm.Range(A.shape[1]))
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
print(tvm.format_str(C.expr))
if __name__ == "__main__":
test_tensor()
test_tensor_inputs()
test_tensor_reduce()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment