Commit 0992873a by Tianqi Chen, committed by GitHub

[LANG] Include buffer semantics, introduce pylint (#11)

* [LANG] Include buffer semantics, introduce pylint

* Refactor inline add support for buffer indexing

* fix doc
parent 69a80cce
......@@ -37,7 +37,7 @@ LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
lint:
-python2 dmlc-core/scripts/lint.py tvm cpp include src
+python2 dmlc-core/scripts/lint.py tvm all include src python
doc:
doxygen docs/Doxyfile
......
-Subproject commit f294fc2271b27b0b6e2b117003ed2dc3d3ba8fda
+Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16
/*!
* Copyright (c) 2016 by Contributors
* \file buffer.h
* \brief Symbolic n-dimensional array, to represent a memory buffer.
*/
#ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_
#include <tvm/container.h>
#include <string>
#include "./base.h"
#include "./expr.h"
namespace tvm {
// Internal node container Buffer
class BufferNode;
/*!
* \brief Buffer is a symbolic n-dimensional array structure.
* It is a composition of primitive symbolic types,
* used to specify input/output structure of the program.
*/
class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct a new contiguous buffer based on shape; strides are left empty.
*/
explicit Buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
/*!
* \brief Generate a load expression that reads the given index location of the buffer.
* \param index The index to the buffer.
* \return The load expression.
*/
Expr MakeLoad(Array<Expr> index) const;
/*!
* \brief Generate a store statement.
* \param index The index to the buffer.
* \param value The value to be stored.
* \return The store statement.
*/
Stmt MakeStore(Array<Expr> index, Expr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;
};
/*! \brief Node to represent a buffer */
class BufferNode : public Node {
public:
/*! \brief optional name of the buffer */
std::string name;
/*! \brief The pointer to the head of the data */
Var ptr;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating that the array is contiguous.
*/
Array<Expr> strides;
/*! \brief data type in the content of the tensor */
Type dtype;
// Maybe need more information(alignment) later
/*! \brief constructor */
BufferNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("ptr", &ptr);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
}
static Buffer make(std::string name,
Var ptr,
Array<Expr> shape,
Array<Expr> strides,
Type dtype);
static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
};
inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_BUFFER_H_
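From the Python frontend added in this commit, a buffer with a symbolic shape can be declared as in the minimal sketch below, which mirrors the test_buffer test further down in the diff:

import tvm

# Shapes are built from symbolic variables rather than concrete integers.
m = tvm.Var('m')
n = tvm.Var('n')

# A contiguous buffer: strides default to the empty array.
Ab = tvm.Buffer((m, n), tvm.float32)
assert isinstance(Ab, tvm.collections.Buffer)
assert tuple(Ab.shape) == (m, n)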
......@@ -13,6 +13,7 @@
#include <unordered_map>
#include <vector>
#include "./expr.h"
#include "./buffer.h"
#include "./schedule.h"
namespace tvm {
......@@ -56,10 +57,22 @@ Stmt ConvertSSA(Stmt stmt);
*
* \note All the passes in this file use SSA form and output SSA form.
*/
-Stmt Inline(FunctionRef f,
-            Array<Var> args,
-            Expr body,
-            Stmt stmt);
+Stmt Inline(Stmt stmt,
+            FunctionRef f,
+            Array<Var> args,
+            Expr body);
/*!
* \brief Flatten multi-dimensional reads/writes
* to single-dimensional Load/Store operations.
*
* \param stmt The stmt to be transformed.
* \param extern_buffer Map specifying the external
* buffer assignment of inputs and outputs.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);
} // namespace ir
} // namespace tvm
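As a rough illustration of what the flattening computes (plain Python arithmetic, not the pass itself): an access A[i, j] on a contiguous (m, n) buffer becomes a one-dimensional access at offset i * n + j.

# Worked instance of row-major flattening with concrete numbers.
m, n = 3, 4
i, j = 1, 2
flat = i * n + j
assert flat == 6  # row 1 starts at offset 4; column 2 lands at offset 6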
......
# pylint: disable=redefined-builtin, wildcard-import
"""C++ backend related python scripts"""
from __future__ import absolute_import as _abs
from ._ctypes._api import register_node
......
# coding: utf-8
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, no-member
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import
import sys
import os
import ctypes
import numpy as np
from . import libinfo
......
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
......@@ -14,6 +15,7 @@ from .._base import check_call, ctypes2docstring
from .. import _function_internal
class ArgVariant(ctypes.Union):
"""ArgVariant in C API"""
_fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p),
......@@ -30,8 +32,8 @@ NODE_TYPE = {
def _return_node(x):
handle = x.v_handle
-    if not isinstance(handle, ctypes.c_void_p):
-        handle = ctypes.c_void_p(handle)
+    if not isinstance(handle, NodeHandle):
+        handle = NodeHandle(handle)
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
ret_success = ctypes.c_int()
......@@ -47,7 +49,7 @@ RET_SWITCH = {
kLong: lambda x: x.v_long,
kDouble: lambda x: x.v_double,
kStr: lambda x: py_str(x.v_str),
-    kNodeHandle: lambda x: _return_node(x)
+    kNodeHandle: _return_node
}
class SliceBase(object):
......@@ -251,6 +253,7 @@ def register_node(type_key=None):
"""
if isinstance(type_key, str):
def register(cls):
"""internal register function"""
NODE_TYPE[type_key] = cls
return cls
return register
......@@ -273,9 +276,9 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
namespace_match = {
-    "_make_" : sys.modules["%s.make" % root_namespace],
-    "_pass_" : sys.modules["%s.ir_pass" % root_namespace],
-    "_schedule_" : sys.modules["%s.schedule" % root_namespace]
+    "_make_": sys.modules["%s.make" % root_namespace],
+    "_pass_": sys.modules["%s.ir_pass" % root_namespace],
+    "_schedule_": sys.modules["%s.schedule" % root_namespace]
}
for name in op_names:
......
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
......@@ -6,6 +7,7 @@ from . import expr as _expr
@register_node
class Array(NodeBase):
"""Array container of TVM"""
def __getitem__(self, i):
if i >= len(self):
raise IndexError("array index out of range")
......@@ -19,6 +21,7 @@ class Array(NodeBase):
@register_node
class Map(NodeBase):
"""Map container of TVM"""
def __getitem__(self, k):
return _function_internal._MapGetItem(self, k)
......@@ -26,6 +29,7 @@ class Map(NodeBase):
return _function_internal._MapCount(self, k) != 0
def items(self):
"""Get the items from the map"""
akvs = _function_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
......@@ -38,9 +42,17 @@ class Map(NodeBase):
@register_node
class Range(NodeBase):
"""Represent range in TVM"""
pass
@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable."""
pass
@register_node
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass
# pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import make as _make
......@@ -174,7 +175,7 @@ class Call(Expr):
Halide = 3
Intrinsic = 4
PureIntrinsic = 5
pass
@register_node
class Let(Expr):
......
# pylint: disable=protected-access, no-member, invalid-name
# pylint: disable=redefined-builtin, undefined-variable
"""Functions defined in TVM."""
from __future__ import absolute_import as _abs
-from numbers import Number as _Number, Integral as _Integral
+from numbers import Integral as _Integral
from ._ctypes._api import _init_function_module, convert
from . import _function_internal
from . import make as _make
......@@ -8,6 +11,7 @@ from . import collections as _collections
int32 = "int32"
float32 = "float32"
handle = "handle"
def const(value, dtype=None):
"""construct a constant"""
......@@ -65,7 +69,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)
-def placeholder(shape, dtype = None, name="placeholder"):
+def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
......@@ -84,6 +88,7 @@ def placeholder(shape, dtype = None, name="placeholder"):
tensor: tensor.Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
return _function_internal._Placeholder(
shape, dtype, name)
......@@ -111,8 +116,7 @@ def compute(shape, fcompute, name="compute"):
tensor: tensor.Tensor
The created tensor
"""
-    if isinstance(shape, _expr.Expr):
-        shape = (shape, )
+    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
ndim = len(shape)
arg_names = fcompute.__code__.co_varnames
......@@ -125,7 +129,44 @@ def compute(shape, fcompute, name="compute"):
op_node = _function_internal._ComputeOp(
name, dim_var, body)
return _function_internal._Tensor(
-        shape, name, body.dtype, op_node, 0)
+        shape, body.dtype, op_node, 0)
def Buffer(shape, dtype=None,
name="buffer", ptr=None,
strides=None):
"""Create a new buffer
Parameters
----------
shape : tuple of Expr
The shape of the buffer.
dtype : str, optional
The data type of the buffer.
name : str, optional
The name of the buffer.
ptr : Var, optional
The data pointer in the buffer.
strides : array of Expr, optional
The strides of the buffer.
Returns
-------
buffer : Buffer
The created buffer
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if ptr is None:
ptr = Var(name, "handle")
return _function_internal._Buffer(
name, ptr, shape, strides, dtype)
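A minimal usage sketch: binding the buffer to an existing handle variable through the optional ptr argument (the names Aptr and A are illustrative only):

import tvm

n = tvm.Var('n')
# The handle Var stands in for the actual allocation, bound later at runtime.
Aptr = tvm.Var('Aptr', "handle")
A = tvm.Buffer((n, n), dtype=tvm.float32, name="A", ptr=Aptr)
assert A.dtype == tvm.float32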
def IterVar(dom, name='iter', thread_tag=''):
......
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
......@@ -6,15 +7,18 @@ from . import tensor as _tensor
@register_node
class Split(NodeBase):
"""Split operation on axis."""
pass
@register_node
class Fuse(NodeBase):
"""Fuse operation on axis."""
pass
@register_node
class Schedule(NodeBase):
"""Schedule for all the stages."""
def __getitem__(self, k):
if isinstance(k, _tensor.Tensor):
k = k.op
......@@ -26,6 +30,7 @@ class Schedule(NodeBase):
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
def split(self, parent, factor=None, outer=None):
"""Split the stage either by factor providing outer scope, or both
......@@ -132,6 +137,32 @@ class Stage(NodeBase):
_function_internal._StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor):
"""Perform tiling on two dimensions.
The final loop order from outermost to innermost is
[x_outer, y_outer, x_inner, y_inner]
Parameters
----------
x_parent : IterVar
The original x dimension
y_parent : IterVar
The original y dimension
x_factor : Expr
The stride factor on x axis
y_factor : Expr
The stride factor on y axis
Returns
-------
x_outer : IterVar
Outer axis of x dimension
y_outer : IterVar
Outer axis of y dimension
x_inner : IterVar
Inner axis of x dimension
y_inner : IterVar
Inner axis of y dimension
"""
x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner
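A minimal usage sketch, assuming a Schedule is constructed directly from the output operation as in the tests of this era; the shapes, names, and tiling factors below are illustrative, not part of this commit:

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j] + 1, name='B')
s = tvm.Schedule(B.op)
# Tile both axes; the resulting loop order is [xo, yo, xi, yi].
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=8, y_factor=4)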
# pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import make as _make
class Stmt(NodeBase):
pass
......@@ -23,7 +23,6 @@ class For(Stmt):
Parallel = 1
Vectorized = 2
Unrolled = 3
pass
@register_node
class Store(Stmt):
......
# pylint: disable=protected-access, no-member, invalid-name
"""Tensor related abstractions"""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, SliceBase, register_node, convert
from . import collections as _collections
......@@ -51,10 +53,12 @@ class Tensor(NodeBase):
@property
def ndim(self):
"""Dimension of the tensor."""
return len(self.shape)
class Operation(NodeBase):
"""Represent an operation that generates a tensor"""
def output(self, index):
"""Get the index-th output of the operation
......@@ -72,8 +76,10 @@ class Operation(NodeBase):
@register_node
class ComputeOp(Operation):
"""Compute operation."""
pass
@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
pass
......@@ -12,6 +12,7 @@
namespace tvm {
inline std::string Type2String(const Type& t) {
if (t.code() == Type::Handle) return "handle";
std::ostringstream os;
os << t;
return os.str();
......@@ -28,6 +29,8 @@ inline Type String2Type(std::string s) {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 0, 0);
} else {
LOG(FATAL) << "unknown type " << s;
}
......
......@@ -123,6 +123,7 @@ REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Load);
REGISTER_MAKE3(Store);
REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free);
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
#include "./c_api_registry.h"
......@@ -140,14 +141,23 @@ TVM_REGISTER_API(Range)
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "extent of the range");
-TVM_REGISTER_API(_Tensor)
+TVM_REGISTER_API(_Buffer)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    *ret = TensorNode::make(args.at(0),
+    *ret = BufferNode::make(args.at(0),
                             args.at(1),
                             args.at(2),
                             args.at(3),
                             args.at(4));
 });
+TVM_REGISTER_API(_Tensor)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = TensorNode::make(args.at(0),
+                            args.at(1),
+                            args.at(2),
+                            args.at(3));
+});
TVM_REGISTER_API(_TensorEqual)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = args.at(0).operator Tensor() == args.at(1).operator Tensor();
......
/*!
* Copyright (c) 2016 by Contributors
* \file buffer.cc
*/
#include <tvm/buffer.h>
#include <tvm/ir.h>
namespace tvm {
// Row-major strides from a shape: the innermost dimension has stride 1,
// and each outer stride is the product of all inner extents.
Array<Expr> GetStrides(Array<Expr> shape) {
  CHECK_NE(shape.size(), 0U);
  std::vector<Expr> vec{make_const(shape[0].type(), 1)};
  for (size_t i = shape.size() - 1; i != 0; --i) {
    // Multiply by the extent of dimension i, not i - 1,
    // so that stride[d] == prod(shape[d+1:]).
    vec.push_back(shape[i] * vec.back());
  }
  return Array<Expr>(vec.rbegin(), vec.rend());
}
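The same row-major computation in plain Python, for concrete integer shapes (a sketch of the loop above, not TVM code):

def get_strides(shape):
    """Row-major strides: the innermost dimension has stride 1."""
    strides = [1]
    for extent in reversed(shape[1:]):
        strides.append(extent * strides[-1])
    return list(reversed(strides))

assert get_strides([2, 3, 4]) == [12, 4, 1]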
Buffer::Buffer(Array<Expr> shape,
Type dtype,
std::string name)
: Buffer(BufferNode::make(
name,
Var(name, Type(Type::Handle, 0, 0)),
shape, Array<Expr>(), dtype)) {
}
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr base;
if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size());
base = index[0];
for (size_t i = 1; i < index.size(); ++i) {
base = base * n->shape[i] + index[i];
}
} else {
CHECK_EQ(n->strides.size(), index.size());
base = index[0] * n->strides[0];
for (size_t i = 1; i < index.size(); ++i) {
base = base + index[i] * n->strides[i];
}
}
return base;
}
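In plain Python, the two addressing modes of BufferOffset look like this (an illustrative sketch for concrete integers):

def buffer_offset(index, shape, strides):
    """Flat offset: Horner form over the shape when the buffer is
    contiguous, dot product with explicit strides otherwise."""
    if not strides:
        assert len(shape) == len(index)
        base = index[0]
        for extent, i in zip(shape[1:], index[1:]):
            base = base * extent + i
        return base
    assert len(strides) == len(index)
    return sum(i * s for i, s in zip(index, strides))

assert buffer_offset([1, 2], [3, 4], []) == 6       # contiguous
assert buffer_offset([1, 2], [3, 4], [4, 1]) == 6   # explicit strides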
Expr Buffer::MakeLoad(Array<Expr> index) const {
const BufferNode* n = operator->();
return ir::Load::make(n->dtype, n->ptr, BufferOffset(n, index));
}
Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const BufferNode* n = operator->();
CHECK_EQ(value.type(), n->dtype);
return ir::Store::make(n->ptr, BufferOffset(n, index), value);
}
Buffer BufferNode::make(std::string name,
Var ptr,
Array<Expr> shape,
Array<Expr> strides,
Type dtype) {
auto n = std::make_shared<BufferNode>();
n->name = name;
n->ptr = ptr;
n->shape = shape;
n->strides = strides;
n->dtype = dtype;
return Buffer(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) {
p->stream << "buffer(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(BufferNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file operation.cc
......
......@@ -8,7 +8,6 @@
namespace tvm {
namespace ir {
namespace {
// inliner to inline a function
// the result may not be SSA,
......@@ -50,12 +49,10 @@ class IRInline : public IRMutator {
}
};
} // namespace
-Stmt Inline(FunctionRef f,
-            Array<Var> args,
-            Expr body,
-            Stmt stmt) {
+Stmt Inline(Stmt stmt,
+            FunctionRef f,
+            Array<Var> args,
+            Expr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
......
......@@ -13,7 +13,6 @@
namespace tvm {
namespace ir {
namespace {
/*!
* \brief use message passing to calculate the assignment of each Var inside the loop body.
......@@ -256,7 +255,7 @@ Stmt MakePipeline(const Stage& sch,
if (sch->op.as<ComputeOpNode>()) {
provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
} else {
-    LOG(FATAL) << "not supported op";
+    LOG(FATAL) << "not supported op " << sch->op->type_key();
}
std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
Stmt producer = MergeNest(nest, provide);
......@@ -317,10 +316,9 @@ Stmt InjectInline(const Operation op, Stmt body) {
for (auto iv : compute->axis) {
args.push_back(iv->var);
}
-  return Inline(op, args, compute->body, body);
+  return Inline(body, op, args, compute->body);
}
} // namespace
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map) {
......@@ -328,6 +326,8 @@ Stmt ScheduleOps(
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
// no need to specify placeholder op.
if (s->op.as<PlaceholderOpNode>()) continue;
if (s->attach_type == kInline) {
body = InjectInline(s->op, body);
} else if (s->attach_type == kRoot || s->attach_type == kNone) {
......
......@@ -151,8 +151,10 @@ BoundProp(const Array<Operation>& post_order,
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else if (op.as<PlaceholderOpNode>()) {
// do nothing
} else {
-    LOG(FATAL) << "unknown operation mode";
+    LOG(FATAL) << "unknown operation mode " << op->type_key();
}
}
return result;
......
......@@ -42,12 +42,13 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps);
+} else if (op.as<PlaceholderOpNode>()) {
+  // empty set of deps
+  rmap.Set(op, deps);
 } else {
-  if (!op.as<PlaceholderOpNode>()) {
   LOG(FATAL) << "unknown Operation" << op->type_key();
-  }
 }
}
return rmap;
}
......@@ -56,7 +57,7 @@ void PostDFSOrder(const Operation& op,
const ReadGraph& g,
std::unordered_set<Operation>* visited,
Array<Operation>* post_order) {
-  if (op.as<PlaceholderOpNode>() || visited->count(op)) return;
+  if (visited->count(op)) return;
visited->insert(op);
for (const auto& t : g.at(op)) {
PostDFSOrder(t->op, g, visited, post_order);
......
import tvm
def test_buffer():
m = tvm.Var('m')
n = tvm.Var('n')
l = tvm.Var('l')
Ab = tvm.Buffer((m, n), tvm.float32)
Bb = tvm.Buffer((n, l), tvm.float32)
assert isinstance(Ab, tvm.collections.Buffer)
assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n)
if __name__ == "__main__":
test_buffer()
......@@ -33,6 +33,7 @@ def test_tensor_reduce():
assert(isinstance(C_loaded, tvm.tensor.Tensor))
assert(str(C_loaded) == str(C))
if __name__ == "__main__":
test_tensor()
test_tensor_reduce()
......@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(
-        T.op, [x.var for x in T.op.axis], T.op.body, stmt)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
......
......@@ -63,8 +63,8 @@ def test_create_read_graph():
assert g[A2.op][0] == A1
assert g[A1.op][0] == A
post_order = tvm.schedule.PostDFSOrder([A2.op], g)
-    assert(post_order[0] == A1.op)
-    assert(post_order[1] == A2.op)
+    assert(post_order[0] == A.op)
+    assert(post_order[1] == A1.op)
if __name__ == "__main__":
......