Commit aa424139 authored by Philip Hyunsu Cho, committed by Wuwei Lin

[TOPI] FIFO buffer op, to accelerate sequence modeling with dilated convolutions (#4039)

* Add FIFO buffer op to enable explicit computation re-use in convolution

* Add a test

* Add end-to-end test with 1D convolution

* Add a stub in MXNet frontend

* Address reviewer comments

* Add back stub for MXNet frontend
parent 47e50e1e
@@ -376,6 +376,15 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {}
};
/*! \brief Attributes for FIFO buffer operator */
struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
int axis;
TVM_DECLARE_ATTRS(FIFOBufferAttrs, "relay.attrs.FIFOBufferAttrs") {
TVM_ATTR_FIELD(axis).set_default(0);
}
};
/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale;
@@ -1026,6 +1026,12 @@ def _mx_one_hot(inputs, attrs):
return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)
def _mx_contrib_fifo_buffer(inputs, attrs):
new_attrs = {}
new_attrs['axis'] = attrs.get_int('axis')
return _op.nn.fifo_buffer(*inputs, **new_attrs)
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
@@ -1198,6 +1204,7 @@ _convert_map = {
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
"contrib_fifo_buffer" : _mx_contrib_fifo_buffer,
}
# set identity list
@@ -69,6 +69,20 @@ def schedule_dense(attrs, outputs, target):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute('nn.fifo_buffer')
def compute_fifo_buffer(attrs, inputs, out_type, target):
return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))]
@reg.register_schedule('nn.fifo_buffer')
def schedule_fifo_buffer(attrs, outputs, target):
with target:
return topi.generic.schedule_injective(outputs)
reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)
# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
@@ -601,6 +601,36 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)
def fifo_buffer(data, buffer, axis):
"""FIFO buffer
Compute equivalent of
```
concat(buffer, data, axis=axis) \
.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis])
```
Useful for
* Encoding explicit re-use of computation in convolution ops operating on a sliding window input
* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
Parameters
----------
data : tvm.relay.Expr
The input data
buffer : tvm.relay.Expr
Previous value of the FIFO buffer
axis : int
Specify which axis should be used for buffering
Returns
-------
result : tvm.relay.Expr
Updated value for the buffer
"""
return _make.fifo_buffer(data, buffer, axis)
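To make the documented equivalence concrete, here is a minimal NumPy sketch (an editor's illustration, not part of the patch; `np_fifo_buffer` is a hypothetical name):

```python
import numpy as np

def np_fifo_buffer(data, buffer, axis):
    # concat(buffer, data) along `axis`, then keep the trailing window
    # of length buffer.shape[axis]
    begin = data.shape[axis]
    end = begin + buffer.shape[axis]
    sl = tuple(slice(begin, end) if i == axis else slice(None)
               for i in range(buffer.ndim))
    return np.concatenate((buffer, data), axis=axis)[sl]

buf = np.array([1., 2., 3., 4.])
print(np_fifo_buffer(np.array([5., 6.]), buf, axis=0))  # [3. 4. 5. 6.]
```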
def relu(data):
"""Rectified linear unit.
@@ -55,6 +55,11 @@ class DenseAttrs(Attrs):
@register_relay_attr_node
class FIFOBufferAttrs(Attrs):
"""Attributes for nn.fifo_buffer"""
@register_relay_attr_node
class UpSamplingAttrs(Attrs):
"""Attributes for nn.upsampling"""
@@ -100,6 +100,73 @@ RELAY_REGISTER_OP("nn.bias_add")
});
// relay.nn.fifo_buffer
TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs);
Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) {
auto attrs = make_node<FIFOBufferAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.fifo_buffer");
return CallNode::make(op, {input, buffer}, Attrs(attrs), {});
}
bool FIFOBufferRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* input = types[0].as<TensorTypeNode>();
const auto* buffer = types[1].as<TensorTypeNode>();
const FIFOBufferAttrs* param = attrs.as<FIFOBufferAttrs>();
if (input == nullptr || buffer == nullptr) {
return false;
}
CHECK(param != nullptr);
CHECK_EQ(input->shape.size(), buffer->shape.size());
const size_t buffer_axis
= static_cast<size_t>(param->axis < 0 ? static_cast<int>(buffer->shape.size()) + param->axis
: param->axis);
reporter->Assert(buffer_axis < buffer->shape.size());
for (size_t i = 0; i < buffer->shape.size(); ++i) {
if (i != buffer_axis) {
reporter->AssertEQ(input->shape[i], buffer->shape[i]);
}
}
reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]);
Array<tvm::Expr> oshape = buffer->shape;
reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
return true;
}
TVM_REGISTER_API("relay.op.nn._make.fifo_buffer")
.set_body_typed(MakeFIFOBuffer);
RELAY_REGISTER_OP("nn.fifo_buffer")
.describe(R"code(FIFO buffer
Compute equivalent of
```
concat(buffer, data, axis=axis) \
.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis])
```
Useful for
* Encoding explicit re-use of computation in convolution ops operating on a sliding window input
* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.FIFOBufferAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "Latest input")
.add_argument("buffer", "Tensor",
"Buffer storing latest [length_buffer] inputs")
.set_support_level(3)
.add_type_rel("FIFOBuffer", FIFOBufferRel);
// relay.nn.dense
TVM_REGISTER_NODE_TYPE(DenseAttrs);
@@ -22,3 +22,4 @@ from .l2_normalize import *
from .batch_matmul import *
from .sparse import *
from .pad import *
from .fifo_buffer import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FIFO buffer op"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from ..transform import concatenate, strided_slice
@tvm.tag_scope(tag=tag.INJECTIVE+",fifo_buffer")
def fifo_buffer(data, buffer, axis):
"""
Implements the FIFO buffer
"""
assert len(data.shape) == len(buffer.shape), \
'buffer and data must have same number of dimensions, ' + \
'buffer.shape = {}, data.shape = {}'.format(buffer.shape, data.shape)
assert len(buffer.shape) >= 1, 'Zero-dimension tensor not supported'
assert 0 <= axis < len(buffer.shape), 'buffer axis out of range'
for i in range(len(data.shape)):
if i == axis:
# int(str(...)) turns symbolic extents (IntImm) into concrete Python ints
assert int(str(data.shape[i])) <= int(str(buffer.shape[i]))
else:
assert int(str(data.shape[i])) == int(str(buffer.shape[i]))
buflen = buffer.shape[axis]
data_size = data.shape[axis]
# Explicitly write out formula up to 4D, and then use concat+slice combo for 5D and higher
if len(buffer.shape) == 1:
return tvm.compute(buffer.shape,
lambda i:
tvm.if_then_else(i < buflen - data_size,
buffer[i + data_size],
data[i - buflen + data_size]),
name='new_buffer')
elif len(buffer.shape) == 2:
if axis == 0:
return tvm.compute(buffer.shape,
lambda i, j:
tvm.if_then_else(i < buflen - data_size,
buffer[i + data_size, j],
data[i - buflen + data_size, j]),
name='new_buffer')
if axis == 1:
return tvm.compute(buffer.shape,
lambda i, j:
tvm.if_then_else(j < buflen - data_size,
buffer[i, j + data_size],
data[i, j - buflen + data_size]),
name='new_buffer')
assert False, 'Invalid value for axis; it should be less than {}'.format(len(buffer.shape))
elif len(buffer.shape) == 3:
if axis == 0:
return tvm.compute(buffer.shape,
lambda i, j, k:
tvm.if_then_else(i < buflen - data_size,
buffer[i + data_size, j, k],
data[i - buflen + data_size, j, k]),
name='new_buffer')
if axis == 1:
return tvm.compute(buffer.shape,
lambda i, j, k:
tvm.if_then_else(j < buflen - data_size,
buffer[i, j + data_size, k],
data[i, j - buflen + data_size, k]),
name='new_buffer')
if axis == 2:
return tvm.compute(buffer.shape,
lambda i, j, k:
tvm.if_then_else(k < buflen - data_size,
buffer[i, j, k + data_size],
data[i, j, k - buflen + data_size]),
name='new_buffer')
assert False, 'Invalid value for axis; it should be less than {}'.format(len(buffer.shape))
elif len(buffer.shape) == 4:
if axis == 0:
return tvm.compute(buffer.shape,
lambda i, j, k, l:
tvm.if_then_else(i < buflen - data_size,
buffer[i + data_size, j, k, l],
data[i - buflen + data_size, j, k, l]),
name='new_buffer')
if axis == 1:
return tvm.compute(buffer.shape,
lambda i, j, k, l:
tvm.if_then_else(j < buflen - data_size,
buffer[i, j + data_size, k, l],
data[i, j - buflen + data_size, k, l]),
name='new_buffer')
if axis == 2:
return tvm.compute(buffer.shape,
lambda i, j, k, l:
tvm.if_then_else(k < buflen - data_size,
buffer[i, j, k + data_size, l],
data[i, j, k - buflen + data_size, l]),
name='new_buffer')
if axis == 3:
return tvm.compute(buffer.shape,
lambda i, j, k, l:
tvm.if_then_else(l < buflen - data_size,
buffer[i, j, k, l + data_size],
data[i, j, k, l - buflen + data_size]),
name='new_buffer')
assert False, 'Invalid value for axis; it should be less than {}'.format(len(buffer.shape))
else:
# Implement FIFO buffer as combination of concat and slice
begin = [0] * len(buffer.shape)
begin[axis] = data.shape[axis]
end = list(buffer.shape[:])
end[axis] += data.shape[axis]
return strided_slice(concatenate((buffer, data), axis=axis), begin=begin, end=end)
return None
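A quick standalone check of the TOPI op on the LLVM target (an editor's sketch mirroring the unit test below; assumes the ~v0.6 `tvm.placeholder`/`tvm.build` API):

```python
import numpy as np
import tvm
import topi

buf = tvm.placeholder((6,), name='buffer')
data = tvm.placeholder((2,), name='data')
out = topi.nn.fifo_buffer(data, buf, axis=0)
s = tvm.create_schedule(out.op)
f = tvm.build(s, [data, buf, out], 'llvm')

ctx = tvm.cpu(0)
out_nd = tvm.nd.empty((6,), ctx=ctx)
f(tvm.nd.array(np.array([100., 101.], dtype='float32'), ctx),
  tvm.nd.array(np.arange(6, dtype='float32'), ctx),
  out_nd)
# The two oldest entries are dropped and the new data appended:
print(out_nd.asnumpy())  # [2. 3. 4. 5. 100. 101.]
```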
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for FIFO buffer"""
import tvm
import topi
import numpy as np
from common import get_all_backend
from tvm.contrib.pickle_memoize import memoize
def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'):
buffer = tvm.placeholder(buffer_shape, name='buffer', dtype=dtype)
data = tvm.placeholder(data_shape, name='data', dtype=dtype)
# Use memoize to pickle the test data for reuse in future runs
@memoize('topi.tests.test_fifo_buffer')
def get_ref_data():
buffer_np = np.random.uniform(size=buffer_shape).astype(dtype)
data_np = np.random.uniform(size=data_shape).astype(dtype)
# Reference implementation of FIFO queue
begin = data_np.shape[axis]
end = buffer_np.shape[axis] + data_np.shape[axis]
ndim = len(buffer_np.shape)
ss = tuple((slice(begin, end, 1) if x == axis else slice(None)) for x in range(ndim))
out_np = np.concatenate((buffer_np, data_np), axis=axis)[ss]
return (buffer_np, data_np, out_np)
# Get the test data
buffer_np, data_np, out_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print(' Skip because %s is not enabled' % device)
return
print(' Running on target: {}'.format(device))
with tvm.target.create(device):
out = topi.nn.fifo_buffer(data, buffer, axis=axis)
s = topi.generic.schedule_injective([out])
buffer_tvm = tvm.nd.array(buffer_np, ctx=ctx)
data_tvm = tvm.nd.array(data_np, ctx=ctx)
out_tvm = tvm.nd.empty(shape=buffer_shape, ctx=ctx, dtype=dtype)
f = tvm.build(s, [data, buffer, out], device, name='fifo')
f(data_tvm, buffer_tvm, out_tvm)
tvm.testing.assert_allclose(out_tvm.asnumpy(), out_np)
for device in get_all_backend():
check_device(device)
def verify_conv1d_integration():
batch_size = 1
num_channel = 1
num_filter = 1
# Note: TVM doesn't have a separate op for 1D convolution, so we use conv2d instead.
# We set height=1 to indicate that convolution is really 1D.
stride = (1, 1)
dilate = (1, 1)
padding = (0, 0)
kernel_size = (1, 3)
input_window_size = (1, 10)
inc_input_size = (1, 2)
context_size = (1, 4)
inc_output_size = (1, 2)
output_window_size = (1, 8)
num_iteration = 20
buffer_axis = 3
kernel_shape = (num_filter, num_channel, kernel_size[0], kernel_size[1])
input_window_shape = (batch_size, num_channel, input_window_size[0], input_window_size[1])
inc_input_shape = (batch_size, num_channel, inc_input_size[0], inc_input_size[1])
inc_output_shape = (batch_size, num_filter, inc_output_size[0], inc_output_size[1])
context_shape = (batch_size, num_channel, context_size[0], context_size[1])
output_window_shape = (batch_size, num_filter, output_window_size[0], output_window_size[1])
# Rule: convolving Tensor[context_shape] with Tensor[kernel_shape]
# produces Tensor[inc_output_shape]
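# Editor's aside (not in the original patch): with stride 1 and no padding,
# out_width = context_width - (kernel_width - 1) * dilation = 4 - 2 * 1 = 2,
# which matches inc_output_size = (1, 2)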
dtype = 'float32'
inc_input = tvm.placeholder(inc_input_shape, name='inc_input', dtype=dtype)
input_window = tvm.placeholder(input_window_shape, name='input_window', dtype=dtype)
context = tvm.placeholder(context_shape, name='context', dtype=dtype)
kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=dtype)
inc_output = tvm.placeholder(inc_output_shape, name='inc_output', dtype=dtype)
output_window = tvm.placeholder(output_window_shape, name='output_window', dtype=dtype)
# Use memoize to pickle the test data for reuse in future runs
@memoize('topi.tests.test_fifo_buffer_conv1d_integration')
def get_data():
# Generate [num_iteration] slices of input
inc_input_np = np.random.uniform(size=tuple([num_iteration] + list(inc_input_shape)))\
.astype(dtype)
input_window_np = np.zeros(input_window_shape, dtype=dtype)
kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
context_np = np.zeros(context_shape, dtype=dtype)
output_window_np = np.zeros(output_window_shape, dtype=dtype)
return (inc_input_np, input_window_np, kernel_np, context_np, output_window_np)
# Get the test data
inc_input_np, input_window_np, kernel_np, context_np, output_window_np = get_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print(' Skip because %s is not enabled' % device)
return
print(' Running on target: {}'.format(device))
with tvm.target.create(device):
out = topi.nn.fifo_buffer(inc_input, context, axis=buffer_axis)
s = topi.generic.schedule_injective([out])
update_context = tvm.build(s, [inc_input, context, out], device, name='update_context')
out = topi.nn.conv2d(context, kernel, strides=stride, padding=padding, dilation=dilate,
layout='NCHW', out_dtype=dtype)
s = topi.generic.schedule_conv2d_nchw([out])
conv2d_inc = tvm.build(s, [context, kernel, out], device, name='conv2d_inc')
out = topi.nn.fifo_buffer(inc_output, output_window, axis=buffer_axis)
s = topi.generic.schedule_injective([out])
update_output_window = tvm.build(s, [inc_output, output_window, out], device,
name='update_output_window')
out = topi.nn.fifo_buffer(inc_input, input_window, axis=buffer_axis)
s = topi.generic.schedule_injective([out])
update_input_window = tvm.build(s, [inc_input, input_window, out], device,
name='update_input_window')
out = topi.nn.conv2d(input_window, kernel, strides=stride, padding=padding,
dilation=dilate, layout='NCHW', out_dtype=dtype)
s = topi.generic.schedule_conv2d_nchw([out])
conv2d = tvm.build(s, [input_window, kernel, out], device, name='conv2d')
input_window_tvm = tvm.nd.array(input_window_np, ctx=ctx)
new_input_window_tvm = tvm.nd.empty(shape=input_window_shape, ctx=ctx, dtype=dtype)
kernel_tvm = tvm.nd.array(kernel_np, ctx=ctx)
context_tvm = tvm.nd.array(context_np, ctx=ctx)
new_context_tvm = tvm.nd.empty(shape=context_shape, ctx=ctx, dtype=dtype)
inc_output_tvm = tvm.nd.empty(shape=inc_output_shape, ctx=ctx, dtype=dtype)
output_window_tvm = tvm.nd.array(output_window_np, ctx=ctx)
new_output_window_tvm = tvm.nd.empty(shape=output_window_shape, ctx=ctx, dtype=dtype)
output_window_ref_tvm = tvm.nd.empty(shape=output_window_shape, ctx=ctx, dtype=dtype)
for i in range(num_iteration):
# Take i-th slice of inc_input_np
inc_input_tvm = tvm.nd.array(inc_input_np[i], ctx=ctx)
# Compute new output window incrementally, using the FIFO buffer op
update_context(inc_input_tvm, context_tvm, new_context_tvm)
conv2d_inc(new_context_tvm, kernel_tvm, inc_output_tvm)
update_output_window(inc_output_tvm, output_window_tvm, new_output_window_tvm)
context_tvm = new_context_tvm
output_window_tvm = new_output_window_tvm
# Compute full input window, so that we have a baseline
update_input_window(inc_input_tvm, input_window_tvm, new_input_window_tvm)
input_window_tvm = new_input_window_tvm
conv2d(input_window_tvm, kernel_tvm, output_window_ref_tvm)
# Incrementally updating the output window should be equivalent to computing it from
# scratch using the input window
tvm.testing.assert_allclose(output_window_tvm.asnumpy(),
output_window_ref_tvm.asnumpy())
for device in get_all_backend():
check_device(device)
def test_fifo_buffer():
for ndim in [1, 2, 3, 4, 5, 6]:
for axis in range(ndim):
buffer_shape = tuple(7 for _ in range(ndim))
data_shape = tuple((2 if i == axis else 7) for i in range(ndim))
print('Testing FIFO buffer op: buffer_shape = {}, data_shape = {}, axis = {}'
.format(buffer_shape, data_shape, axis))
verify_fifo_buffer(buffer_shape, data_shape, axis)
def test_conv1d_integration():
print('Testing FIFO buffer with 1D convolution')
verify_conv1d_integration()
if __name__ == '__main__':
test_fifo_buffer()
test_conv1d_integration()