Commit 0992873a by Tianqi Chen Committed by GitHub

[LANG] Include buffer semnatics, introduce pylint (#11)

* [LANG] Include buffer semnatics, introduce pylint

* Refactor inline add support for buffer indexing

* fix doc
parent 69a80cce
...@@ -37,7 +37,7 @@ LIBHALIDEIR: ...@@ -37,7 +37,7 @@ LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR) + cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
lint: lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src python2 dmlc-core/scripts/lint.py tvm all include src python
doc: doc:
doxygen docs/Doxyfile doxygen docs/Doxyfile
......
Subproject commit f294fc2271b27b0b6e2b117003ed2dc3d3ba8fda Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16
/*!
* Copyright (c) 2016 by Contributors
* \file buffer.h
* \brief Symbolic n-dimensional array, to represent a memory buffer.
*/
#ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_
#include <tvm/container.h>
#include <string>
#include "./base.h"
#include "./expr.h"
namespace tvm {
// Internal node container Buffer
class BufferNode;
/*!
* \brief Buffer is a symbolic n-darray structure.
* It is a composition of primitive symbolic types,
* used to specify input/output strcuture of the program.
*/
class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct a new buffer based on shape and strides.
*/
explicit Buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
/*!
* \brief Generate a load expression loading the index location of buffer.
* \param index The index to the buffer.
* \return The load expression.
*/
Expr MakeLoad(Array<Expr> index) const;
/*!
* \brief Generate a store statement.
* \param index The index to the buffer.
* \param value The value to be stored.
* \return The load expression.
*/
Stmt MakeStore(Array<Expr> index, Expr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;
};
/*! \brief Node to represent a buffer */
class BufferNode : public Node {
public:
/*! \brief optional name of the buffer */
std::string name;
/*! \brief The pointer to the head of the data */
Var ptr;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating array is contiguous
*/
Array<Expr> strides;
/*! \brief data type in the content of the tensor */
Type dtype;
// Maybe need more information(alignment) later
/*! \brief constructor */
BufferNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("ptr", &ptr);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
}
static Buffer make(std::string name,
Var ptr,
Array<Expr> shape,
Array<Expr> strides,
Type dtype);
static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
};
inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_BUFFER_H_
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "./expr.h" #include "./expr.h"
#include "./buffer.h"
#include "./schedule.h" #include "./schedule.h"
namespace tvm { namespace tvm {
...@@ -56,10 +57,22 @@ Stmt ConvertSSA(Stmt stmt); ...@@ -56,10 +57,22 @@ Stmt ConvertSSA(Stmt stmt);
* *
* \note All the passes in this file uses SSA form and outputs SSA form. * \note All the passes in this file uses SSA form and outputs SSA form.
*/ */
Stmt Inline(FunctionRef f, Stmt Inline(Stmt stmt,
FunctionRef f,
Array<Var> args, Array<Var> args,
Expr body, Expr body);
Stmt stmt);
/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param stmt The stmt to be trasnformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
# pylint: disable=redefined-builtin, wildcard-import
"""C++ backend related python scripts""" """C++ backend related python scripts"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import register_node from ._ctypes._api import register_node
......
# coding: utf-8 # coding: utf-8
# pylint: disable=invalid-name # pylint: disable=invalid-name, no-member
""" ctypes library of nnvm and helper functions """ """ ctypes library of nnvm and helper functions """
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
import os
import ctypes import ctypes
import numpy as np import numpy as np
from . import libinfo from . import libinfo
......
# coding: utf-8 # coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines # pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API.""" """Symbolic configuration API."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
...@@ -14,6 +15,7 @@ from .._base import check_call, ctypes2docstring ...@@ -14,6 +15,7 @@ from .._base import check_call, ctypes2docstring
from .. import _function_internal from .. import _function_internal
class ArgVariant(ctypes.Union): class ArgVariant(ctypes.Union):
"""ArgVariant in C API"""
_fields_ = [("v_long", ctypes.c_long), _fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double), ("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p), ("v_str", ctypes.c_char_p),
...@@ -30,8 +32,8 @@ NODE_TYPE = { ...@@ -30,8 +32,8 @@ NODE_TYPE = {
def _return_node(x): def _return_node(x):
handle = x.v_handle handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p): if not isinstance(handle, NodeHandle):
handle = ctypes.c_void_p(handle) handle = NodeHandle(handle)
ret_val = ArgVariant() ret_val = ArgVariant()
ret_typeid = ctypes.c_int() ret_typeid = ctypes.c_int()
ret_success = ctypes.c_int() ret_success = ctypes.c_int()
...@@ -47,7 +49,7 @@ RET_SWITCH = { ...@@ -47,7 +49,7 @@ RET_SWITCH = {
kLong: lambda x: x.v_long, kLong: lambda x: x.v_long,
kDouble: lambda x: x.v_double, kDouble: lambda x: x.v_double,
kStr: lambda x: py_str(x.v_str), kStr: lambda x: py_str(x.v_str),
kNodeHandle: lambda x: _return_node(x) kNodeHandle: _return_node
} }
class SliceBase(object): class SliceBase(object):
...@@ -251,6 +253,7 @@ def register_node(type_key=None): ...@@ -251,6 +253,7 @@ def register_node(type_key=None):
""" """
if isinstance(type_key, str): if isinstance(type_key, str):
def register(cls): def register(cls):
"""internal register function"""
NODE_TYPE[type_key] = cls NODE_TYPE[type_key] = cls
return cls return cls
return register return register
...@@ -273,9 +276,9 @@ def _init_function_module(root_namespace): ...@@ -273,9 +276,9 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace] module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace] module_internal = sys.modules["%s._function_internal" % root_namespace]
namespace_match = { namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace], "_make_": sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace], "_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_schedule_" : sys.modules["%s.schedule" % root_namespace] "_schedule_": sys.modules["%s.schedule" % root_namespace]
} }
for name in op_names: for name in op_names:
......
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
...@@ -6,6 +7,7 @@ from . import expr as _expr ...@@ -6,6 +7,7 @@ from . import expr as _expr
@register_node @register_node
class Array(NodeBase): class Array(NodeBase):
"""Array container of TVM"""
def __getitem__(self, i): def __getitem__(self, i):
if i >= len(self): if i >= len(self):
raise IndexError("array index out ot range") raise IndexError("array index out ot range")
...@@ -19,6 +21,7 @@ class Array(NodeBase): ...@@ -19,6 +21,7 @@ class Array(NodeBase):
@register_node @register_node
class Map(NodeBase): class Map(NodeBase):
"""Map container of TVM"""
def __getitem__(self, k): def __getitem__(self, k):
return _function_internal._MapGetItem(self, k) return _function_internal._MapGetItem(self, k)
...@@ -26,6 +29,7 @@ class Map(NodeBase): ...@@ -26,6 +29,7 @@ class Map(NodeBase):
return _function_internal._MapCount(self, k) != 0 return _function_internal._MapCount(self, k) != 0
def items(self): def items(self):
"""Get the items from the map"""
akvs = _function_internal._MapItems(self) akvs = _function_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
...@@ -38,9 +42,17 @@ class Map(NodeBase): ...@@ -38,9 +42,17 @@ class Map(NodeBase):
@register_node @register_node
class Range(NodeBase): class Range(NodeBase):
"""Represent range in TVM"""
pass pass
@register_node @register_node
class IterVar(NodeBase, _expr.ExprOp): class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable."""
pass
@register_node
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass pass
# pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import make as _make from . import make as _make
...@@ -174,7 +175,7 @@ class Call(Expr): ...@@ -174,7 +175,7 @@ class Call(Expr):
Halide = 3 Halide = 3
Intrinsic = 4 Intrinsic = 4
PureIntrinsic = 5 PureIntrinsic = 5
pass
@register_node @register_node
class Let(Expr): class Let(Expr):
......
# pylint: disable=protected-access, no-member, invalid-name
# pylint: disable=redefined-builtin, undefined-variable
"""Functions defined in TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from numbers import Number as _Number, Integral as _Integral from numbers import Integral as _Integral
from ._ctypes._api import _init_function_module, convert from ._ctypes._api import _init_function_module, convert
from . import _function_internal from . import _function_internal
from . import make as _make from . import make as _make
...@@ -8,6 +11,7 @@ from . import collections as _collections ...@@ -8,6 +11,7 @@ from . import collections as _collections
int32 = "int32" int32 = "int32"
float32 = "float32" float32 = "float32"
handle = "handle"
def const(value, dtype=None): def const(value, dtype=None):
"""construct a constant""" """construct a constant"""
...@@ -65,7 +69,7 @@ def Var(name="tindex", dtype=int32): ...@@ -65,7 +69,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype) return _function_internal._Var(name, dtype)
def placeholder(shape, dtype = None, name="placeholder"): def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object. """Construct an empty tensor object.
Parameters Parameters
...@@ -84,6 +88,7 @@ def placeholder(shape, dtype = None, name="placeholder"): ...@@ -84,6 +88,7 @@ def placeholder(shape, dtype = None, name="placeholder"):
tensor: tensor.Tensor tensor: tensor.Tensor
The created tensor The created tensor
""" """
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype dtype = float32 if dtype is None else dtype
return _function_internal._Placeholder( return _function_internal._Placeholder(
shape, dtype, name) shape, dtype, name)
...@@ -111,8 +116,7 @@ def compute(shape, fcompute, name="compute"): ...@@ -111,8 +116,7 @@ def compute(shape, fcompute, name="compute"):
tensor: tensor.Tensor tensor: tensor.Tensor
The created tensor The created tensor
""" """
if isinstance(shape, _expr.Expr): shape = (shape,) if isinstance(shape, _expr.Expr) else shape
shape = (shape, )
ndim = len(shape) ndim = len(shape)
arg_names = fcompute.__code__.co_varnames arg_names = fcompute.__code__.co_varnames
...@@ -125,7 +129,44 @@ def compute(shape, fcompute, name="compute"): ...@@ -125,7 +129,44 @@ def compute(shape, fcompute, name="compute"):
op_node = _function_internal._ComputeOp( op_node = _function_internal._ComputeOp(
name, dim_var, body) name, dim_var, body)
return _function_internal._Tensor( return _function_internal._Tensor(
shape, name, body.dtype, op_node, 0) shape, body.dtype, op_node, 0)
def Buffer(shape, dtype=None,
name="buffer", ptr=None,
strides=None):
"""Create a new buffer
Parameters
----------
shape : tuple of Expr
The shape of the buffer.
dtype : str, optional
The data type of the buffer.
name : str, optional
The name of the buffer.
ptr : Var, optional
The data pointer in the buffer.
strides: array of Expr
The stride of the buffer.
Returns
-------
buffer : Buffer
The created buffer
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if ptr is None:
ptr = Var(name, "handle")
return _function_internal._Buffer(
name, ptr, shape, strides, dtype)
def IterVar(dom, name='iter', thread_tag=''): def IterVar(dom, name='iter', thread_tag=''):
......
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
...@@ -6,15 +7,18 @@ from . import tensor as _tensor ...@@ -6,15 +7,18 @@ from . import tensor as _tensor
@register_node @register_node
class Split(NodeBase): class Split(NodeBase):
"""Split operation on axis."""
pass pass
@register_node @register_node
class Fuse(NodeBase): class Fuse(NodeBase):
"""Fuse operation on axis."""
pass pass
@register_node @register_node
class Schedule(NodeBase): class Schedule(NodeBase):
"""Schedule for all the stages."""
def __getitem__(self, k): def __getitem__(self, k):
if isinstance(k, _tensor.Tensor): if isinstance(k, _tensor.Tensor):
k = k.op k = k.op
...@@ -26,6 +30,7 @@ class Schedule(NodeBase): ...@@ -26,6 +30,7 @@ class Schedule(NodeBase):
@register_node @register_node
class Stage(NodeBase): class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
def split(self, parent, factor=None, outer=None): def split(self, parent, factor=None, outer=None):
"""Split the stage either by factor providing outer scope, or both """Split the stage either by factor providing outer scope, or both
...@@ -132,6 +137,32 @@ class Stage(NodeBase): ...@@ -132,6 +137,32 @@ class Stage(NodeBase):
_function_internal._StageReorder(self, args) _function_internal._StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor): def tile(self, x_parent, y_parent, x_factor, y_factor):
""" Perform tiling on two dimensions
The final loop order from outmost to inner most are
[x_outer, y_outer, x_inner, y_inner]
Parameters
----------
x_parent : IterVar
The original x dimension
y_parent : IterVar
The original y dimension
x_factor : Expr
The stride factor on x axis
y_factor : Expr The stride factor on y axis
Returns
-------
x_outer : IterVar
Outer axis of x dimension
y_outer : IterVar
Outer axis of y dimension
x_inner : IterVar
Inner axis of x dimension
p_y_inner : IterVar
Inner axis of y dimension
"""
x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile( x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor) self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner return x_outer, y_outer, x_inner, y_inner
# pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import make as _make
class Stmt(NodeBase): class Stmt(NodeBase):
pass pass
...@@ -23,7 +23,6 @@ class For(Stmt): ...@@ -23,7 +23,6 @@ class For(Stmt):
Parallel = 1 Parallel = 1
Vectorized = 2 Vectorized = 2
Unrolled = 3 Unrolled = 3
pass
@register_node @register_node
class Store(Stmt): class Store(Stmt):
......
# pylint: disable=protected-access, no-member, invalid-name
"""Tensor related abstractions"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, SliceBase, register_node, convert from ._ctypes._api import NodeBase, SliceBase, register_node, convert
from . import collections as _collections from . import collections as _collections
...@@ -51,10 +53,12 @@ class Tensor(NodeBase): ...@@ -51,10 +53,12 @@ class Tensor(NodeBase):
@property @property
def ndim(self): def ndim(self):
"""Dimension of the tensor."""
return len(self.shape) return len(self.shape)
class Operation(NodeBase): class Operation(NodeBase):
"""Represent an operation that generate a tensor"""
def output(self, index): def output(self, index):
"""Get the index-th output of the operation """Get the index-th output of the operation
...@@ -72,8 +76,10 @@ class Operation(NodeBase): ...@@ -72,8 +76,10 @@ class Operation(NodeBase):
@register_node @register_node
class ComputeOp(Operation): class ComputeOp(Operation):
"""Compute operation."""
pass pass
@register_node @register_node
class PlaceholderOp(Operation): class PlaceholderOp(Operation):
"""Placeholder operation."""
pass pass
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
namespace tvm { namespace tvm {
inline std::string Type2String(const Type& t) { inline std::string Type2String(const Type& t) {
if (t.code() ==Type::Handle) return "handle";
std::ostringstream os; std::ostringstream os;
os << t; os << t;
return os.str(); return os.str();
...@@ -28,6 +29,8 @@ inline Type String2Type(std::string s) { ...@@ -28,6 +29,8 @@ inline Type String2Type(std::string s) {
code = Type::Float; s = s.substr(5); code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") { } else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5); code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 0, 0);
} else { } else {
LOG(FATAL) << "unknown type " << s; LOG(FATAL) << "unknown type " << s;
} }
......
...@@ -123,6 +123,7 @@ REGISTER_MAKE3(Let); ...@@ -123,6 +123,7 @@ REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt); REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Load);
REGISTER_MAKE3(Store); REGISTER_MAKE3(Store);
REGISTER_MAKE4(Provide); REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free); REGISTER_MAKE1(Free);
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
...@@ -140,14 +141,23 @@ TVM_REGISTER_API(Range) ...@@ -140,14 +141,23 @@ TVM_REGISTER_API(Range)
.add_argument("begin", "Expr", "beginning of the range.") .add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "extent of the range"); .add_argument("end", "Expr", "extent of the range");
TVM_REGISTER_API(_Tensor) TVM_REGISTER_API(_Buffer)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0), *ret = BufferNode::make(args.at(0),
args.at(1),
args.at(2), args.at(2),
args.at(3), args.at(3),
args.at(4)); args.at(4));
}); });
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3));
});
TVM_REGISTER_API(_TensorEqual) TVM_REGISTER_API(_TensorEqual)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = args.at(0).operator Tensor() == args.at(1).operator Tensor(); *ret = args.at(0).operator Tensor() == args.at(1).operator Tensor();
......
/*!
* Copyright (c) 2016 by Contributors
* \file buffer.cc
*/
#include <tvm/buffer.h>
#include <tvm/ir.h>
namespace tvm {
Array<Expr> GetStrides(Array<Expr> shape) {
CHECK_NE(shape.size(), 0U);
std::vector<Expr> vec{make_const(shape[0].type(), 1)};
for (size_t i = shape.size() - 1; i != 0; --i) {
vec.push_back(shape[i - 1] * vec.back());
}
return Array<Expr>(vec.rbegin(), vec.rend());
}
Buffer::Buffer(Array<Expr> shape,
Type dtype,
std::string name)
: Buffer(BufferNode::make(
name,
Var(name, Type(Type::Handle, 0, 0)),
shape, Array<Expr>(), dtype)) {
}
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr base;
if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size());
base = index[0];
for (size_t i = 1; i < index.size(); ++i) {
base = base * n->shape[i] + index[i];
}
} else {
CHECK_EQ(n->strides.size(), index.size());
base = index[0] * n->strides[0];
for (size_t i = 1; i < index.size(); ++i) {
base = base + index[i] * n->strides[i];
}
}
return base;
}
Expr Buffer::MakeLoad(Array<Expr> index) const {
const BufferNode* n = operator->();
return ir::Load::make(n->dtype, n->ptr, BufferOffset(n, index));
}
Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const BufferNode* n = operator->();
CHECK_EQ(value.type(), n->dtype);
return ir::Store::make(n->ptr, BufferOffset(n, index), value);
}
Buffer BufferNode::make(std::string name,
Var ptr,
Array<Expr> shape,
Array<Expr> strides,
Type dtype) {
auto n = std::make_shared<BufferNode>();
n->name = name;
n->ptr = ptr;
n->shape = shape;
n->strides = strides;
n->dtype = dtype;
return Buffer(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) {
p->stream << "buffer(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(BufferNode);
} // namespace tvm
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file operation.cc * \file operation.cc
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
namespace {
// inliner to inline a function // inliner to inline a function
// the result may not be SSA, // the result may not be SSA,
...@@ -50,12 +49,10 @@ class IRInline : public IRMutator { ...@@ -50,12 +49,10 @@ class IRInline : public IRMutator {
} }
}; };
} // namespace Stmt Inline(Stmt stmt,
FunctionRef f,
Stmt Inline(FunctionRef f,
Array<Var> args, Array<Var> args,
Expr body, Expr body) {
Stmt stmt) {
CHECK_EQ(f->num_outputs(), 1) CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation"; << "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt)); return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
namespace {
/*! /*!
* \brief use message passing to calculate the assignment of each Var inside the loop body. * \brief use message passing to calculate the assignment of each Var inside the loop body.
...@@ -256,7 +255,7 @@ Stmt MakePipeline(const Stage& sch, ...@@ -256,7 +255,7 @@ Stmt MakePipeline(const Stage& sch,
if (sch->op.as<ComputeOpNode>()) { if (sch->op.as<ComputeOpNode>()) {
provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors); provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
} else { } else {
LOG(FATAL) << "not supported op"; LOG(FATAL) << "not supported op " << sch->op->type_key();
} }
std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map); std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
Stmt producer = MergeNest(nest, provide); Stmt producer = MergeNest(nest, provide);
...@@ -317,10 +316,9 @@ Stmt InjectInline(const Operation op, Stmt body) { ...@@ -317,10 +316,9 @@ Stmt InjectInline(const Operation op, Stmt body) {
for (auto iv : compute->axis) { for (auto iv : compute->axis) {
args.push_back(iv->var); args.push_back(iv->var);
} }
return Inline(op, args, compute->body, body); return Inline(body, op, args, compute->body);
} }
} // namespace
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map) { Schedule sch, Map<IterVar, Range> dom_map) {
...@@ -328,6 +326,8 @@ Stmt ScheduleOps( ...@@ -328,6 +326,8 @@ Stmt ScheduleOps(
// reverse the post DFS order. // reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) { for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1]; Stage s = sch->stages[i - 1];
// no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue;
if (s->attach_type == kInline) { if (s->attach_type == kInline) {
body = InjectInline(s->op, body); body = InjectInline(s->op, body);
} else if (s->attach_type == kRoot || s-> attach_type == kNone) { } else if (s->attach_type == kRoot || s-> attach_type == kNone) {
......
...@@ -151,8 +151,10 @@ BoundProp(const Array<Operation>& post_order, ...@@ -151,8 +151,10 @@ BoundProp(const Array<Operation>& post_order,
} }
}; };
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else if (op.as<PlaceholderOpNode>()) {
// do nothing
} else { } else {
LOG(FATAL) << "unknown operation mode"; LOG(FATAL) << "unknown operation mode " << op->type_key();
} }
} }
return result; return result;
......
...@@ -42,12 +42,13 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) { ...@@ -42,12 +42,13 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
}; };
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps); rmap.Set(op, deps);
} else if (op.as<PlaceholderOpNode>()) {
// empty set of deps
rmap.Set(op, deps);
} else { } else {
if (!op.as<PlaceholderOpNode>()) {
LOG(FATAL) << "unknown Operation" << op->type_key(); LOG(FATAL) << "unknown Operation" << op->type_key();
} }
} }
}
return rmap; return rmap;
} }
...@@ -56,7 +57,7 @@ void PostDFSOrder(const Operation& op, ...@@ -56,7 +57,7 @@ void PostDFSOrder(const Operation& op,
const ReadGraph& g, const ReadGraph& g,
std::unordered_set<Operation>* visited, std::unordered_set<Operation>* visited,
Array<Operation>* post_order) { Array<Operation>* post_order) {
if (op.as<PlaceholderOpNode>() || visited->count(op)) return; if (visited->count(op)) return;
visited->insert(op); visited->insert(op);
for (const auto& t : g.at(op)) { for (const auto& t : g.at(op)) {
PostDFSOrder(t->op, g, visited, post_order); PostDFSOrder(t->op, g, visited, post_order);
......
import tvm
def test_buffer():
m = tvm.Var('m')
n = tvm.Var('n')
l = tvm.Var('l')
Ab = tvm.Buffer((m, n), tvm.float32)
Bb = tvm.Buffer((n, l), tvm.float32)
assert isinstance(Ab, tvm.collections.Buffer)
assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n)
if __name__ == "__main__":
test_buffer()
...@@ -33,6 +33,7 @@ def test_tensor_reduce(): ...@@ -33,6 +33,7 @@ def test_tensor_reduce():
assert(isinstance(C_loaded, tvm.tensor.Tensor)) assert(isinstance(C_loaded, tvm.tensor.Tensor))
assert(str(C_loaded) == str(C)) assert(str(C_loaded) == str(C))
if __name__ == "__main__": if __name__ == "__main__":
test_tensor() test_tensor()
test_tensor_reduce() test_tensor_reduce()
...@@ -6,7 +6,7 @@ def test_inline(): ...@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100]) stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline( stmt = tvm.ir_pass.Inline(
T.op, [x.var for x in T.op.axis], T.op.body, stmt) stmt, T.op, [x.var for x in T.op.axis], T.op.body)
print(stmt) print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt)) assert(tvm.ir_pass.VerifySSA(stmt))
......
...@@ -63,8 +63,8 @@ def test_create_read_graph(): ...@@ -63,8 +63,8 @@ def test_create_read_graph():
assert g[A2.op][0] == A1 assert g[A2.op][0] == A1
assert g[A1.op][0] == A assert g[A1.op][0] == A
post_order = tvm.schedule.PostDFSOrder([A2.op], g) post_order = tvm.schedule.PostDFSOrder([A2.op], g)
assert(post_order[0] == A1.op) assert(post_order[0] == A.op)
assert(post_order[1] == A2.op) assert(post_order[1] == A1.op)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment