Commit d903342b by tqchen

Checkin Schedule and split construction in front-end

parent ed99ddc7
Subproject commit 2b3ea8f5207152340014fd0a1ab12816ac48c326 Subproject commit ec84af1359c841df622f683048968348381e328a
...@@ -28,6 +28,9 @@ enum AttachType : int { ...@@ -28,6 +28,9 @@ enum AttachType : int {
/*! \brief schedule container */ /*! \brief schedule container */
class Schedule : public NodeRef { class Schedule : public NodeRef {
public: public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
Schedule(Tensor tensor, std::string scope);
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -38,6 +41,8 @@ class Schedule : public NodeRef { ...@@ -38,6 +41,8 @@ class Schedule : public NodeRef {
/*! \brief schedule container */ /*! \brief schedule container */
class AttachSpec : public NodeRef { class AttachSpec : public NodeRef {
public: public:
AttachSpec() {}
explicit AttachSpec(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -59,15 +64,13 @@ class AttachSpecNode : public Node { ...@@ -59,15 +64,13 @@ class AttachSpecNode : public Node {
Split attach_split; Split attach_split;
/*! \brief the child schedule to be attached. */ /*! \brief the child schedule to be attached. */
Schedule schedule; Schedule schedule;
const char* type_key() const override { const char* type_key() const final {
return "AttachSpecNode"; return "AttachSpec";
} }
void VisitAttrs(AttrVisitor* visitor) override { void VisitAttrs(AttrVisitor* v) final {
visitor->Visit("attach_type", &attach_type); v->Visit("attach_type", &attach_type);
} v->Visit("attach_split", &attach_split);
void VisitNodeRefFields(FNodeRefVisit fvisit) override { v->Visit("schedule", &schedule);
fvisit("attach_split", &attach_split);
fvisit("schedule", &schedule);
} }
}; };
...@@ -82,16 +85,14 @@ class ScheduleNode : public Node { ...@@ -82,16 +85,14 @@ class ScheduleNode : public Node {
Array<Split> splits; Array<Split> splits;
/*! \brief attach specifications */ /*! \brief attach specifications */
Array<AttachSpec> attachs; Array<AttachSpec> attachs;
const char* type_key() const override { const char* type_key() const final {
return "AttachSpecNode"; return "Schedule";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("scope", &scope);
} }
void VisitNodeRefFields(FNodeRefVisit fvisit) override { void VisitAttrs(AttrVisitor* v) final {
fvisit("tensor", &tensor); v->Visit("scope", &scope);
fvisit("splits", &splits); v->Visit("tensor", &tensor);
fvisit("attachs", &attachs); v->Visit("splits", &splits);
v->Visit("attachs", &attachs);
} }
}; };
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#define TVM_SPLIT_H_ #define TVM_SPLIT_H_
#include "./base.h" #include "./base.h"
#include "./array.h"
#include "./domain.h" #include "./domain.h"
namespace tvm { namespace tvm {
...@@ -20,8 +19,7 @@ class Split : public NodeRef { ...@@ -20,8 +19,7 @@ class Split : public NodeRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Split() {} Split() {}
/*! \return Whether the split is over RDomain or not */ explicit Split(std::shared_ptr<Node> n) : NodeRef(n) {}
inline bool is_over_rdom() const;
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -37,38 +35,29 @@ class Split : public NodeRef { ...@@ -37,38 +35,29 @@ class Split : public NodeRef {
class SplitNode : public Node { class SplitNode : public Node {
public: public:
/*! \brief whether the split is over reduction domain*/ /*! \brief whether the split is over reduction domain*/
int split_over_rdom{0}; bool split_over_rdom{false};
/*!
* \brief given the output domain, infer input domain
* \param split_index The index to be splitted on
* \param out_domain The outer domain
* \return The inferred inner domain.
*/
virtual Domain InferInnerDomain(Expr split_index, Domain out_domain) const = 0;
}; };
/*! \brief simple split node that splits over one dimension */ /*! \brief simple split node that splits over one dimension */
class DimSplitNode : public SplitNode { class DimSplitNode : public SplitNode {
public: public:
/*! \brief The dimension to split on */ /*! \brief The dimension to split on */
int64_t dim_index; int dim_index;
/*! \brief The factor of the split */ /*! \brief The factor of the split */
Expr factor; Expr factor;
/*! \brief constructor */ /*! \brief constructor */
DimSplitNode() {} DimSplitNode() {}
const char* type_key() const override { const char* type_key() const final {
return "DimSplitNode"; return "DimSplit";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("split_over_rdom", &split_over_rdom);
} }
void VisitNodeRefFields(FNodeRefVisit fvisit) override { void VisitAttrs(AttrVisitor* v) final {
fvisit("factor", &factor); v->Visit("split_over_rdom", &split_over_rdom);
} v->Visit("dim_index", &dim_index);
Domain InferInnerDomain(Expr split_index, Domain out_domain) const override { v->Visit("factor", &factor);
LOG(FATAL) << "not implemented";
return Domain();
} }
static Split make(int dim_index,
Expr factor,
bool over_rdom);
}; };
// Implementations of inline functions // Implementations of inline functions
...@@ -76,9 +65,5 @@ inline const SplitNode* Split::operator->() const { ...@@ -76,9 +65,5 @@ inline const SplitNode* Split::operator->() const {
return static_cast<const SplitNode*>(node_.get()); return static_cast<const SplitNode*>(node_.get());
} }
inline bool Split::is_over_rdom() const {
return (*this)->split_over_rdom != 0;
}
} // namespace tvm } // namespace tvm
#endif // TVM_SPLIT_H_ #endif // TVM_SPLIT_H_
...@@ -115,7 +115,7 @@ class TensorNode : public Node { ...@@ -115,7 +115,7 @@ class TensorNode : public Node {
Expr source; Expr source;
/*! \brief constructor */ /*! \brief constructor */
TensorNode() {} TensorNode() {}
const char* type_key() const override { const char* type_key() const final {
return "Tensor"; return "Tensor";
} }
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
......
"""C++ backend related python scripts""" """C++ backend related python scripts"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .function import *
from ._ctypes._api import register_node from ._ctypes._api import register_node
from . import tensor as tensor
from . import expr from . import expr
from . import stmt from . import stmt
from . import make from . import make
from . import collections from . import collections
from . import tensor from . import schedule
from .function import *
...@@ -62,7 +62,7 @@ class NodeBase(object): ...@@ -62,7 +62,7 @@ class NodeBase(object):
self.handle = handle self.handle = handle
def __repr__(self): def __repr__(self):
return _function_internal.format_str(self) return _function_internal._format_str(self)
def __del__(self): def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle)) check_call(_LIB.TVMNodeFree(self.handle))
......
...@@ -122,22 +122,6 @@ def sum(expr, rdom): ...@@ -122,22 +122,6 @@ def sum(expr, rdom):
x = _make.Reduce("Add", expr, rdom) x = _make.Reduce("Add", expr, rdom)
return x return x
def sum(expr, rdom):
"""Create a sum expression over rdom
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
"""
assert isinstance(expr, _expr.Expr)
assert isinstance(rdom, _collections.RDomain)
x = _make.Reduce("Add", expr, rdom)
return x
def min(expr, rdom): def min(expr, rdom):
"""Create a min expression over rdom """Create a min expression over rdom
...@@ -154,7 +138,6 @@ def min(expr, rdom): ...@@ -154,7 +138,6 @@ def min(expr, rdom):
x = _make.Reduce("Min", expr, rdom) x = _make.Reduce("Min", expr, rdom)
return x return x
def max(expr, rdom): def max(expr, rdom):
"""Create a min expression over rdom """Create a min expression over rdom
...@@ -172,4 +155,12 @@ def max(expr, rdom): ...@@ -172,4 +155,12 @@ def max(expr, rdom):
return x return x
def Schedule(tensor, scope="global"):
return _function_internal._Schedule(tensor, scope)
def Split(dim, factor, over_rdom=False):
return _function_internal._DimSplit(dim, factor, over_rdom)
_init_function_module("tvm") _init_function_module("tvm")
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node
class DimSplit(NodeBase):
pass
@register_node
class AttachSpec(NodeBase):
pass
@register_node
class Schedule(NodeBase):
pass
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/domain.h> #include <tvm/domain.h>
#include <tvm/split.h>
#include <tvm/schedule.h>
#include <ir/IROperator.h> #include <ir/IROperator.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
...@@ -73,6 +75,7 @@ TVM_REGISTER_API(Range) ...@@ -73,6 +75,7 @@ TVM_REGISTER_API(Range)
*ret = Range(args.at(0), args.at(1)); *ret = Range(args.at(0), args.at(1));
} }
}) })
.describe("create a domain range")
.add_argument("begin", "Expr", "beginning of the range.") .add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "extent of the range"); .add_argument("end", "Expr", "extent of the range");
...@@ -90,4 +93,14 @@ TVM_REGISTER_API(_RDomain) ...@@ -90,4 +93,14 @@ TVM_REGISTER_API(_RDomain)
*ret = RDomain(args.at(0).operator Domain()); *ret = RDomain(args.at(0).operator Domain());
}); });
TVM_REGISTER_API(_DimSplit)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = DimSplitNode::make(args.at(0), args.at(1), args.at(2));
});
TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0), args.at(1));
});
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file schedule.cc
*/
#include <tvm/schedule.h>
namespace tvm {
Schedule::Schedule(Tensor tensor, std::string scope) {
auto n = std::make_shared<ScheduleNode>();
n->tensor = tensor;
n->scope = scope;
node_ = n;
}
TVM_REGISTER_NODE_TYPE(AttachSpecNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file split.cc
*/
#include <tvm/split.h>
namespace tvm {
Split DimSplitNode::make(int dim_index,
Expr factor,
bool over_rdom) {
auto n = std::make_shared<DimSplitNode>();
CHECK_EQ(factor.type().lanes(), 1);
n->split_over_rdom = over_rdom;
n->dim_index = dim_index;
n->factor = factor;
return Split(n);
}
TVM_REGISTER_NODE_TYPE(DimSplitNode);
} // namespace tvm
...@@ -9,7 +9,7 @@ def test_make(): ...@@ -9,7 +9,7 @@ def test_make():
x = tvm.const(1) x = tvm.const(1)
y = tvm.make.IntImm('int32', 1) y = tvm.make.IntImm('int32', 1)
z = x + y z = x + y
print(tvm.format_str(z)) print(z)
def test_ir(): def test_ir():
x = tvm.const(1) x = tvm.const(1)
...@@ -22,7 +22,7 @@ def test_basic(): ...@@ -22,7 +22,7 @@ def test_basic():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
c = a + b c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name) assert str(c) == '(%s + %s)' % (a.name, b.name)
def test_array(): def test_array():
a = tvm.convert([1,2,3]) a = tvm.convert([1,2,3])
......
import tvm
def test_schedule_create():
m = tvm.Var('m')
n = tvm.Var('n')
l = tvm.Var('l')
A = tvm.Tensor((m, l), name='A')
B = tvm.Tensor((n, l), name='B')
T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
sch = tvm.Schedule(T, scope="shared")
tk1 = tvm.Split(0, 10)
assert isinstance(sch, tvm.schedule.Schedule)
assert isinstance(tk1, tvm.schedule.DimSplit)
print(sch.scope)
print(sch.attachs)
if __name__ == "__main__":
test_schedule_create()
...@@ -7,7 +7,8 @@ def test_tensor(): ...@@ -7,7 +7,8 @@ def test_tensor():
A = tvm.Tensor((m, l), name='A') A = tvm.Tensor((m, l), name='A')
B = tvm.Tensor((n, l), name='B') B = tvm.Tensor((n, l), name='B')
T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
print(tvm.format_str(T.source))
print(T.source)
assert(tuple(T.shape) == (m, n, l)) assert(tuple(T.shape) == (m, n, l))
assert(A.source is None) assert(A.source is None)
...@@ -20,7 +21,7 @@ def test_tensor_reduce(): ...@@ -20,7 +21,7 @@ def test_tensor_reduce():
T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
rd = tvm.RDomain(tvm.Range(A.shape[1])) rd = tvm.RDomain(tvm.Range(A.shape[1]))
C = tvm.Tensor((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd)) C = tvm.Tensor((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd))
print(tvm.format_str(C.source)) print(C.source)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor() test_tensor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment