Commit d4af7ad6 by Tianqi Chen, committed by GitHub


[TEST/PYTHON] Add unittest folder, add a build pipeline. Rename Buffer.ptr to Buffer.data to be consistent with Array. (#29)
parent 891630ed
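The user-visible effect of the rename is that the handle variable of a symbolic buffer is exposed as data instead of ptr, matching the naming used by Array. A minimal Python sketch of the new attribute, reusing the tvm.Buffer and tvm.make.Load calls that appear in the tests further down; the exact behaviour at this commit is assumed and the snippet is illustrative, not part of the diff itself:

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
i = tvm.Var('i')
# Symbolic buffer; its handle variable is now Ab.data (previously Ab.ptr).
Ab = tvm.Buffer((m, n), tvm.float32)
# The handle is a Var that can be used when constructing IR by hand,
# e.g. a load at flattened index i, as in the stack VM tests below.
load = tvm.make.Load(tvm.float32, Ab.data, i)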
@@ -61,7 +61,7 @@ class BufferNode : public Node {
/*! \brief optional name of the buffer */
std::string name;
/*! \brief The pointer to the head of the data */
Var ptr;
Var data;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
@@ -77,7 +77,7 @@ class BufferNode : public Node {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("ptr", &ptr);
v->Visit("data", &data);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
......
@@ -17,3 +17,4 @@ from .ndarray import cpu, gpu, opencl, init_opencl, cl
from ._base import TVMError
from .api import *
from .build import build
@@ -145,7 +145,7 @@ def Buffer(shape, dtype=None,
name="buffer",
ptr=None,
strides=None):
"""Create a new buffer
"""Create a new symbolic buffer
Parameters
----------
......
"""The build pipeline in python.
Eventually some of these pipelines will be moved to C++.
But the first pipeline will be kept in python for ease of change and evolving.
"""
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
from . import api
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import codegen
def build(sch,
          args,
          target,
          name="default_function",
          binds=None,
          record_codes=None):
    """Build a function with arguments as signature.

    Parameters
    ----------
    sch : tvm.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument list to the function.

    target : str
        The target of the compilation.

    name : str
        The name of the result function.

    binds : dict, optional
        Dictionary that maps each Tensor to the symbolic Buffer it is bound to.
        By default, a new Buffer is created for each Tensor in the argument list.

    Returns
    -------
    f : Function, or pair of functions
        The result function.
        If the function requires host space allocation,
        a pair of functions will be returned.
    """
    binds = {} if binds is None else binds.copy()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            buf = api.Buffer(x.shape, dtype=x.dtype, name=x.op.name)
            assert x not in binds
            binds[x] = buf
            arg_list.append(buf)
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, expr.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    # lowering
    bounds = schedule.InferBound(sch)
    stmt = ir_pass.ScheduleOps(sch, bounds)
    stmt = ir_pass.StorageFlatten(stmt, binds)
    stmt = ir_pass.Simplify(stmt)
    fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
    fsplits = codegen.SplitHostDevice(fapi)
    if record_codes is not None:
        output_ssa = False
        for i, f in enumerate(fsplits):
            t = target if i >= 1 else "c"
            record_codes.append(codegen.CompileToC(f, output_ssa, t))
    if target == "cuda":
        ret = codegen.BuildNVRTC(fsplits, "stackvm")
    elif target == "opencl":
        ret = codegen.BuildOpenCL(fsplits, "stackvm")
    else:
        raise ValueError("Unknown target %s" % target)
    return ret
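For reference, the binds argument described above lets the caller pre-create the symbolic Buffer for a Tensor instead of letting build allocate a fresh one per Tensor argument. A minimal sketch reusing the graph and schedule of the GPU test further down; the explicit Ab buffer and its binding are illustrative assumptions, not part of this commit:

import tvm

n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op)
num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[C].split(x, outer=thread_x)
# Pre-create the symbolic buffer for A and pass it both in args and in binds;
# build() then reuses Ab for A instead of creating a fresh Buffer for it.
Ab = tvm.Buffer(A.shape, dtype=A.dtype, name='A')
fadd = tvm.build(s, args=[Ab, B, C], target="cuda", name="myadd",
                 binds={A: Ab})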
@@ -59,12 +59,6 @@ class IterVar(NodeBase, _expr.ExprOp):
@register_node
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass
@register_node
class LoweredFunc(NodeBase):
"""Represent a LoweredFunc in TVM."""
pass
@@ -6,6 +6,11 @@ from . import _api_internal
from . import tensor as _tensor
@register_node
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass
@register_node
class Split(NodeBase):
"""Split operation on axis."""
pass
......
@@ -138,9 +138,9 @@ LoweredFunc MakeAPI(Stmt body,
UIntImm::make(UInt(16), dtype.lanes()));
seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// Data Field
if (f_push(buf->ptr, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
v_arg->name_hint + ".data")) {
Var vptr(buf->ptr);
Var vptr(buf->data);
handle_data_type.Set(vptr, make_const(buf->dtype, 0));
}
// shape field
......
@@ -45,23 +45,23 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr Buffer::MakeLoad(Array<Expr> index) const {
const BufferNode* n = operator->();
return ir::Load::make(n->dtype, n->ptr, BufferOffset(n, index));
return ir::Load::make(n->dtype, n->data, BufferOffset(n, index));
}
Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const BufferNode* n = operator->();
CHECK_EQ(value.type(), n->dtype);
return ir::Store::make(n->ptr, value, BufferOffset(n, index));
return ir::Store::make(n->data, value, BufferOffset(n, index));
}
Buffer BufferNode::make(std::string name,
Var ptr,
Var data,
Array<Expr> shape,
Array<Expr> strides,
Type dtype) {
auto n = std::make_shared<BufferNode>();
n->name = name;
n->ptr = ptr;
n->data = data;
n->shape = shape;
n->strides = strides;
n->dtype = dtype;
......
@@ -138,7 +138,7 @@ class StorageFlattener : public IRMutator {
buf_map_[key].released = true;
return Allocate::make(
e.buffer->ptr, e.buffer->dtype, e.buffer->shape,
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
}
}
......
import tvm
import numpy as np

def test_add():
    # graph
    n = tvm.Var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    # schedule
    s = tvm.Schedule(C.op)
    # create iter var and assign them tags.
    num_thread = 256
    block_x = tvm.IterVar(thread_tag="blockIdx.x")
    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
    _, x = s[C].split(C.op.axis[0], factor=num_thread, outer=block_x)
    _, x = s[C].split(x, outer=thread_x)
    # one line to build the function.
    codes = []
    fadd = tvm.build(s, args=[A, B, C],
                     target="cuda", name="myadd",
                     record_codes=codes)
    for c in codes:
        print(c)
    # call the function
    num_device = 1
    for i in range(num_device):
        ctx = tvm.gpu(i)
        if not ctx.enabled:
            continue
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        fadd(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())

if __name__ == "__main__":
    test_add()
@@ -18,7 +18,7 @@ def test_makeapi():
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
num_packed_args = 2
f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
assert(f.handle_data_type[Ab.ptr].dtype == Ab.dtype)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
assert(len(f.args) == 5)
output_ssa = False
......
@@ -7,7 +7,7 @@ def test_buffer():
Ab = tvm.Buffer((m, n), tvm.float32)
Bb = tvm.Buffer((n, l), tvm.float32)
assert isinstance(Ab, tvm.collections.Buffer)
assert isinstance(Ab, tvm.schedule.Buffer)
assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n)
......
@@ -37,8 +37,8 @@ def test_stack_vm_loop():
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.Block(
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 1,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1),
tvm.make.Evaluate(tvm_call_global("tvm_stack_vm_print", i))))
print(stmt)
@@ -59,10 +59,10 @@ def test_stack_vm_cond():
i, 0, n - 1, 0, 0,
tvm.make.IfThenElse(
tvm.make.EQ(i, 4),
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 1, i + 1),
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 2, i + 1)))
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1, i + 1),
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 2, i + 1)))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "test", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
......
@@ -38,10 +38,10 @@ fi
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
make all || exit -1
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
python -m nose -v tests/python/ || exit -1
python3 -m nose -v tests/python/ || exit -1
python -m nose -v tests/python/unittest || exit -1
python3 -m nose -v tests/python/unittest || exit -1
else
nosetests -v tests/python/ || exit -1
nosetests3 -v tests/python/ || exit -1
nosetests -v tests/python/unittest || exit -1
nosetests3 -v tests/python/unittest || exit -1
fi
fi