Commit d903342b by tqchen

Checkin Schedule and split construction in front-end

parent ed99ddc7
Subproject commit 2b3ea8f5207152340014fd0a1ab12816ac48c326
Subproject commit ec84af1359c841df622f683048968348381e328a
......@@ -28,6 +28,9 @@ enum AttachType : int {
/*! \brief schedule container */
class Schedule : public NodeRef {
public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
Schedule(Tensor tensor, std::string scope);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -38,6 +41,8 @@ class Schedule : public NodeRef {
/*! \brief schedule container */
class AttachSpec : public NodeRef {
public:
AttachSpec() {}
explicit AttachSpec(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -59,15 +64,13 @@ class AttachSpecNode : public Node {
Split attach_split;
/*! \brief the child schedule to be attached. */
Schedule schedule;
const char* type_key() const override {
return "AttachSpecNode";
const char* type_key() const final {
return "AttachSpec";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("attach_type", &attach_type);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("attach_split", &attach_split);
fvisit("schedule", &schedule);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("attach_type", &attach_type);
v->Visit("attach_split", &attach_split);
v->Visit("schedule", &schedule);
}
};
......@@ -82,16 +85,14 @@ class ScheduleNode : public Node {
Array<Split> splits;
/*! \brief attach specifications */
Array<AttachSpec> attachs;
const char* type_key() const override {
return "AttachSpecNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("scope", &scope);
const char* type_key() const final {
return "Schedule";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("tensor", &tensor);
fvisit("splits", &splits);
fvisit("attachs", &attachs);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("tensor", &tensor);
v->Visit("splits", &splits);
v->Visit("attachs", &attachs);
}
};
......
......@@ -7,7 +7,6 @@
#define TVM_SPLIT_H_
#include "./base.h"
#include "./array.h"
#include "./domain.h"
namespace tvm {
......@@ -20,8 +19,7 @@ class Split : public NodeRef {
public:
/*! \brief default constructor */
Split() {}
/*! \return Whether the split is over RDomain or not */
inline bool is_over_rdom() const;
explicit Split(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -37,38 +35,29 @@ class Split : public NodeRef {
class SplitNode : public Node {
public:
/*! \brief whether the split is over reduction domain*/
int split_over_rdom{0};
/*!
* \brief given the output domain, infer input domain
* \param split_index The index to be splitted on
* \param out_domain The outer domain
* \return The inferred inner domain.
*/
virtual Domain InferInnerDomain(Expr split_index, Domain out_domain) const = 0;
bool split_over_rdom{false};
};
/*! \brief simple split node that splits over one dimension */
class DimSplitNode : public SplitNode {
public:
/*! \brief The dimension to split on */
int64_t dim_index;
int dim_index;
/*! \brief The factor of the split */
Expr factor;
/*! \brief constructor */
DimSplitNode() {}
const char* type_key() const override {
return "DimSplitNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("split_over_rdom", &split_over_rdom);
const char* type_key() const final {
return "DimSplit";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("factor", &factor);
}
Domain InferInnerDomain(Expr split_index, Domain out_domain) const override {
LOG(FATAL) << "not implemented";
return Domain();
void VisitAttrs(AttrVisitor* v) final {
v->Visit("split_over_rdom", &split_over_rdom);
v->Visit("dim_index", &dim_index);
v->Visit("factor", &factor);
}
static Split make(int dim_index,
Expr factor,
bool over_rdom);
};
// Implementations of inline functions
......@@ -76,9 +65,5 @@ inline const SplitNode* Split::operator->() const {
return static_cast<const SplitNode*>(node_.get());
}
inline bool Split::is_over_rdom() const {
return (*this)->split_over_rdom != 0;
}
} // namespace tvm
#endif // TVM_SPLIT_H_
......@@ -115,7 +115,7 @@ class TensorNode : public Node {
Expr source;
/*! \brief constructor */
TensorNode() {}
const char* type_key() const override {
const char* type_key() const final {
return "Tensor";
}
void VisitAttrs(AttrVisitor* v) final {
......
"""C++ backend related python scripts"""
from __future__ import absolute_import as _abs
from .function import *
from ._ctypes._api import register_node
from . import tensor as tensor
from . import expr
from . import stmt
from . import make
from . import collections
from . import tensor
from . import schedule
from .function import *
......@@ -62,7 +62,7 @@ class NodeBase(object):
self.handle = handle
def __repr__(self):
return _function_internal.format_str(self)
return _function_internal._format_str(self)
def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle))
......
......@@ -122,22 +122,6 @@ def sum(expr, rdom):
x = _make.Reduce("Add", expr, rdom)
return x
def sum(expr, rdom):
"""Create a sum expression over rdom
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domain
"""
assert isinstance(expr, _expr.Expr)
assert isinstance(rdom, _collections.RDomain)
x = _make.Reduce("Add", expr, rdom)
return x
def min(expr, rdom):
"""Create a min expression over rdom
......@@ -154,7 +138,6 @@ def min(expr, rdom):
x = _make.Reduce("Min", expr, rdom)
return x
def max(expr, rdom):
"""Create a max expression over rdom
......@@ -172,4 +155,12 @@ def max(expr, rdom):
return x
def Schedule(tensor, scope="global"):
    """Create a schedule for the given tensor.

    Parameters
    ----------
    tensor : Tensor
        The tensor the schedule is created for.
    scope : str, optional
        The storage scope of the schedule; defaults to "global".

    Returns
    -------
    sch : Schedule
        The created schedule node.
    """
    # Delegates to the C++ backend entry _Schedule registered in the C API.
    return _function_internal._Schedule(tensor, scope)
def Split(dim, factor, over_rdom=False):
    """Create a split over one dimension.

    Parameters
    ----------
    dim : int
        The index of the dimension to split on.
    factor : Expr
        The factor of the split.
    over_rdom : bool, optional
        Whether the split is over a reduction domain.

    Returns
    -------
    split : DimSplit
        The created split node.
    """
    # Delegates to the C++ backend entry _DimSplit (DimSplitNode::make).
    return _function_internal._DimSplit(dim, factor, over_rdom)
_init_function_module("tvm")
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node
class DimSplit(NodeBase):
    """Python handle to the C++ "DimSplit" node; constructed via tvm.Split."""
    pass
@register_node
class AttachSpec(NodeBase):
    """Python handle to the C++ "AttachSpec" node
    (fields: attach_type, attach_split, schedule)."""
    pass
@register_node
class Schedule(NodeBase):
    """Python handle to the C++ "Schedule" node; constructed via tvm.Schedule."""
    pass
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import make as _make
from . import expr as _expr
......
......@@ -6,6 +6,8 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/domain.h>
#include <tvm/split.h>
#include <tvm/schedule.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
......@@ -73,6 +75,7 @@ TVM_REGISTER_API(Range)
*ret = Range(args.at(0), args.at(1));
}
})
.describe("create a domain range")
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "extent of the range");
......@@ -90,4 +93,14 @@ TVM_REGISTER_API(_RDomain)
*ret = RDomain(args.at(0).operator Domain());
});
// Front-end entry: build a DimSplit node from (dim_index, factor, over_rdom).
TVM_REGISTER_API(_DimSplit)
.set_body([](const ArgStack& args, RetValue *ret) {
    *ret = DimSplitNode::make(args.at(0), args.at(1), args.at(2));
  });

// Front-end entry: build a Schedule bound to (tensor, scope).
TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) {
    *ret = Schedule(args.at(0), args.at(1));
  });
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file schedule.cc
*/
#include <tvm/schedule.h>
namespace tvm {
// Construct a schedule that binds `tensor` to the given storage `scope`.
Schedule::Schedule(Tensor tensor, std::string scope) {
  std::shared_ptr<ScheduleNode> node = std::make_shared<ScheduleNode>();
  node->scope = scope;
  node->tensor = tensor;
  node_ = node;
}
TVM_REGISTER_NODE_TYPE(AttachSpecNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file split.cc
*/
#include <tvm/split.h>
namespace tvm {
// Create a split over dimension `dim_index` with the given split `factor`.
// `over_rdom` marks whether the split applies to a reduction domain.
Split DimSplitNode::make(int dim_index,
                         Expr factor,
                         bool over_rdom) {
  // Only scalar split factors are accepted.
  CHECK_EQ(factor.type().lanes(), 1);
  std::shared_ptr<DimSplitNode> node = std::make_shared<DimSplitNode>();
  node->dim_index = dim_index;
  node->factor = factor;
  node->split_over_rdom = over_rdom;
  return Split(node);
}
TVM_REGISTER_NODE_TYPE(DimSplitNode);
} // namespace tvm
......@@ -9,7 +9,7 @@ def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
print(tvm.format_str(z))
print(z)
def test_ir():
x = tvm.const(1)
......@@ -22,7 +22,7 @@ def test_basic():
a = tvm.Var('a')
b = tvm.Var('b')
c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
assert str(c) == '(%s + %s)' % (a.name, b.name)
def test_array():
a = tvm.convert([1,2,3])
......
import tvm
def test_schedule_create():
    # Symbolic shape variables.
    m = tvm.Var('m')
    n = tvm.Var('n')
    l = tvm.Var('l')
    # Two input placeholders and a computed tensor T(i, j, k) = A(i, k) * B(j, k).
    A = tvm.Tensor((m, l), name='A')
    B = tvm.Tensor((n, l), name='B')
    T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
    # Build a schedule in "shared" scope and a split of dimension 0 by factor 10.
    sch = tvm.Schedule(T, scope="shared")
    tk1 = tvm.Split(0, 10)
    # The front-end wrappers should return the registered node proxy classes.
    assert isinstance(sch, tvm.schedule.Schedule)
    assert isinstance(tk1, tvm.schedule.DimSplit)
    print(sch.scope)
    print(sch.attachs)


if __name__ == "__main__":
    test_schedule_create()
......@@ -7,7 +7,8 @@ def test_tensor():
A = tvm.Tensor((m, l), name='A')
B = tvm.Tensor((n, l), name='B')
T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
print(tvm.format_str(T.source))
print(T.source)
assert(tuple(T.shape) == (m, n, l))
assert(A.source is None)
......@@ -20,7 +21,7 @@ def test_tensor_reduce():
T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
rd = tvm.RDomain(tvm.Range(A.shape[1]))
C = tvm.Tensor((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd))
print(tvm.format_str(C.source))
print(C.source)
if __name__ == "__main__":
test_tensor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment