Commit 760475f9 by Xingjian Shi Committed by Tianqi Chen

[TOPI] Add broadcast and reduce operators (#267)

[TOPI] Add broadcast and reduce operators
parent a59774e3
......@@ -41,6 +41,7 @@ def find_lib_path(name=None, search_path=None):
# Default cmake build directory
dll_path.append(os.path.join(source_dir, "build"))
dll_path.append(os.path.join(source_dir, "build", "Release"))
# Default make build directory
dll_path.append(os.path.join(source_dir, "lib"))
......@@ -57,8 +58,10 @@ def find_lib_path(name=None, search_path=None):
runtime_dll_path = []
else:
if sys.platform.startswith('win32'):
lib_dll_path = [os.path.join(p, 'libtvm.dll') for p in dll_path]
runtime_dll_path = [os.path.join(p, 'libtvm_runtime.dll') for p in dll_path]
lib_dll_path = [os.path.join(p, 'libtvm.dll') for p in dll_path] +\
[os.path.join(p, 'tvm.dll') for p in dll_path]
runtime_dll_path = [os.path.join(p, 'libtvm_runtime.dll') for p in dll_path] +\
[os.path.join(p, 'tvm_runtime.dll') for p in dll_path]
elif sys.platform.startswith('darwin'):
lib_dll_path = [os.path.join(p, 'libtvm.dylib') for p in dll_path]
runtime_dll_path = [os.path.join(p, 'libtvm_runtime.dylib') for p in dll_path]
......
......@@ -320,23 +320,27 @@ class Stage(NodeBase):
outer, inner = _api_internal._StageSplitByFactor(self, parent, factor)
return outer, inner
def fuse(self, outer, inner):
"""Fuse inner and outer to a single iteration variable.
def fuse(self, *args):
"""Fuse multiple consecutive iteration variables into a single iteration variable.
fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
The order is from outer to inner.
Parameters
----------
outer : IterVar
The outer variable of iteration.
inner : IterVar
The inner variable of iteration.
args : list of IterVars
IterVars that are consecutive, ordered from outer to inner
Returns
-------
fused : IterVar
The fused variable of iteration.
"""
return _api_internal._StageFuse(self, outer, inner)
assert len(args) >= 1, "Length of the arguments must be >=1 for fuse."
fused = args[0]
for i in range(1, len(args)):
fused = _api_internal._StageFuse(self, fused, args[i])
return fused
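With the variadic signature, several consecutive axes can now be fused in a single call. A minimal usage sketch (hypothetical tensor and axis names, using the tvm API of this commit):
import tvm
n = tvm.var("n")
A = tvm.placeholder((n, n, n), name="A")
C = tvm.compute((n, n, n), lambda i, j, k: A[i, j, k] * 2, name="C")
s = tvm.create_schedule(C.op)
i, j, k = s[C].op.axis
fused = s[C].fuse(i, j, k)  # same as fuse(fuse(i, j), k), outer to inner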
def set_scope(self, scope):
"""Set the thread scope of this stage
......
......@@ -6,6 +6,7 @@
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <iterator>
#include "../arithmetic/compute_expr.h"
namespace tvm {
......@@ -33,8 +34,196 @@ Buffer decl_buffer(Array<Expr> shape,
0, 0);
}
// Split the given expression w.r.t the add operator
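// A non-recursive traversal that flattens nested Adds left to right,
// e.g. (a + b) + c * d is split into [a, b, c * d].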
inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
using namespace ir;
std::vector<const Expr*> ret;
std::stack<const Expr*> split_buffer;
split_buffer.push(&expr);
while (!split_buffer.empty()) {
const Expr* top_ele = split_buffer.top();
split_buffer.pop();
auto expr_add_match = top_ele->as<Add>();
if (expr_add_match) {
split_buffer.push(&expr_add_match->b);
split_buffer.push(&expr_add_match->a);
} else {
ret.emplace_back(top_ele);
}
}
return ret;
}
// Searches for the following types of expr:
// mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
// mod_l_expr = c
// mod_r_expr = k1 * k2 * ... * ki
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c)
// Currently we do not search the add/mul combinations exhaustively,
// as that would take too much computation.
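// Illustration with t = i = 1: for mult_expr = (a + c / k1) * k1,
// mod_l_expr = c and mod_r_expr = k1, the pattern matches and the
// merged result is a * k1 + c, because c / k1 * k1 + c % k1 == c.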
inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
const Expr &mod_l_expr,
const Expr &mod_r_expr) {
using namespace ir;
const Mul* mult_ptr = mult_expr.as<Mul>();
if (!mult_ptr) return std::make_pair(false, Expr());
Expr mult_outer = mult_ptr->b;
const Expr* inner = &(mult_ptr->a);
// 1. Calculate the outer multiplier
while (true) {
mult_ptr = inner->as<Mul>();
if (mult_ptr) {
inner = &(mult_ptr->a);
mult_outer = mult_ptr->b * mult_outer;
} else {
break;
}
}
// 2. Search for the pattern c / (...) * (...) + c % (...)
// We match the search element against Add, Mul and Div.
// If Add is found, we continue the search on its rhs.
// If Mul is found, we expand the inner multiplication factor.
// If Div is found, we test whether its lhs matches the lhs of the mod expr
// and, if so, return the optimization result.
const Expr* search_ptr = inner;
Expr mult_inner; // The inner multiplication factor
Expr no_opt_sum; // Sum of the exprs that cannot be optimized
while (true) {
auto inner_div_ptr = search_ptr->as<Div>();
auto inner_mult_ptr = search_ptr->as<Mul>();
auto inner_add_ptr = search_ptr->as<Add>();
if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
return std::make_pair(false, Expr());
} else if (inner_div_ptr) {
Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
if (Equal(overall_mult, inner_div_ptr->b)
&& Equal(overall_mult, mod_r_expr)
&& Equal(inner_div_ptr->a, mod_l_expr)) {
// Found!
Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
return std::make_pair(true, ret);
} else {
return std::make_pair(false, Expr());
}
} else if (inner_mult_ptr) {
mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
search_ptr = &(inner_mult_ptr->a);
} else if (inner_add_ptr) {
if (mult_inner.get()) {
return std::make_pair(false, Expr());
}
no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
search_ptr = &(inner_add_ptr->b);
} else {
LOG(FATAL) << "Unexpected search result!";
break;
}
}
return std::make_pair(false, Expr());
}
// Insert the elements into the corresponding mult_exprs and mod_exprs.
// If an element matches Mul, it is pushed to mult_exprs.
// If an element matches Mod, it is pushed to mod_exprs.
// Otherwise, the element is added to the no_opt_sum variable.
inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
std::list<Expr>* mult_exprs,
std::list<std::pair<Expr, Expr> >* mod_exprs,
Expr* no_opt_sum,
bool* has_mult,
bool* has_mod) {
using namespace ir;
*has_mult = false;
*has_mod = false;
for (const Expr* ele : eles) {
auto mod_ptr = ele->as<Mod>();
auto mult_ptr = ele->as<Mul>();
if (mod_ptr) {
*has_mod = true;
mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b)));
} else if (mult_ptr) {
*has_mult = true;
mult_exprs->emplace_back(*ele);
} else {
*no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele;
}
}
}
// Searches for the following pattern:
// (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
// + c % (k1 * k2 * ... * ki)
// and simplifies it to (a1 + a2 + ... + aj) * kt * ... * ki + c
// The search is performed repeatedly until no pattern is found.
// Returns the merged expression if the pattern is found;
// otherwise returns the simplified base expression unchanged.
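// For instance, x / 8 * 8 + x % 8 merges to x.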
inline Expr MergeMulMod(const Expr &base) {
using namespace ir;
// 1. Prepare the lists.
// We store two lists: one containing all the elements that match Mul and
// one containing all the elements that match Mod.
// The elements in the Mod will be used to match against the elements in Mul.
// The result will then be split and pushed back to these two lists.
Expr simplified_base = Simplify(base);
std::vector<const Expr*> eles = ExprSplitAddition(simplified_base);
std::list<Expr> mult_exprs;
std::list<std::pair<Expr, Expr> > mod_exprs;
Expr no_opt_sum;
bool has_mult;
bool has_mod;
MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs,
&no_opt_sum, &has_mult, &has_mod);
bool find_opt = false;
std::list<std::pair<Expr, Expr> >::iterator search_mod_it = mod_exprs.begin();
// 2. Exhaustive Search
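// Try to merge every mod element with some mult element. On success both
// elements are erased, the merged result is split and re-inserted, and the
// scan restarts so that newly created mult elements can be matched again.
// Iterators are advanced before erasing so they stay valid.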
while (search_mod_it != mod_exprs.end()) {
std::list<Expr>::iterator mult_it = mult_exprs.begin();
bool inner_find_opt = false;
while (mult_it != mult_exprs.end()) {
std::pair<bool, Expr> ret = MergeMulModInner(*mult_it,
search_mod_it->first,
search_mod_it->second);
if (ret.first) {
inner_find_opt = true;
auto temp_mod_it = search_mod_it;
++search_mod_it;
mod_exprs.erase(temp_mod_it);
mult_exprs.erase(mult_it);
std::vector<const Expr*> ret_eles = ExprSplitAddition(ret.second);
MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs,
&no_opt_sum, &has_mult, &has_mod);
if (has_mult) {
search_mod_it = mod_exprs.begin();
} else if (has_mod && search_mod_it == mod_exprs.end()) {
search_mod_it--;
}
break;
} else {
++mult_it;
}
}
find_opt = find_opt || inner_find_opt;
if (!inner_find_opt) {
++search_mod_it;
}
}
if (!find_opt) {
return simplified_base;
}
for (std::list<Expr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
}
for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
it != mod_exprs.end(); ++it) {
no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second;
}
return no_opt_sum;
}
// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
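// MergeMulMod is applied after every accumulation step so that a mergeable
// div/mod pair is simplified before the next dimension's arithmetic is
// folded on top of it.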
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset;
if (n->strides.size() == 0) {
......@@ -44,18 +233,19 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
} else {
base = base + index[0];
}
base = MergeMulMod(base);
for (size_t i = 1; i < index.size(); ++i) {
base = base * n->shape[i] + index[i];
base = MergeMulMod(base * n->shape[i] + index[i]);
}
} else {
CHECK_EQ(n->strides.size(), index.size());
if (is_zero(base)) {
base = index[0] * n->strides[0];
base = MergeMulMod(index[0] * n->strides[0]);
} else {
base = base + index[0] * n->strides[0];
base = MergeMulMod(base + index[0] * n->strides[0]);
}
for (size_t i = 1; i < index.size(); ++i) {
base = base + index[i] * n->strides[i];
base = MergeMulMod(base + index[i] * n->strides[i]);
}
}
return base;
......
......@@ -11,6 +11,7 @@
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <algorithm>
#include <vector>
#include <string>
#include <cstring>
......
......@@ -23,6 +23,38 @@ def test_buffer_access_ptr():
aptr = Ab.access_ptr("w")
assert aptr.args[4].value == Buffer.WRITE
def test_buffer_index_merge_mult_mod():
m = tvm.var('m')
n = tvm.var('n')
s = tvm.var('s')
k0 = tvm.var('k0')
k1 = tvm.var('k1')
A = tvm.decl_buffer((m, n), tvm.float32)
A_stride = tvm.decl_buffer((m, n), tvm.float32, strides=(s, 1))
def assert_simplified_equal(index_simplified, index_direct):
assert tvm.ir_pass.Equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
# Test Case1
index_simplified = A_stride.vload(((k0 % k1) / s, (k0 % k1) % s + (k0 / k1) * k1))
index_direct = A_stride.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case2
index_simplified = A.vload(((k0 % (k1 / s)) / n,
(k0 % (k1 / s)) % n + (k0 % k1)))
index_direct = A.vload((0, k0 % k1 + k0 % (k1 / s)))
assert_simplified_equal(index_simplified, index_direct)
# Test Case3
index_simplified = A.vload((((k0 / (k1 / s)) * (k1 / s)) / n + (k0 % (k1 / s)) / n,
((k0 / (k1 / s)) * (k1 / s)) % n + (k0 % (k1 / s)) % n))
index_direct = A.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case4 (not able to simplify)
index_simplified = A.vload(((k0 % (k1 / s)) / n,
(k0 % (k1 / n)) % n + (k0 % k1)))
index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
assert_simplified_equal(index_simplified, index_direct)
if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
test_buffer_index_merge_mult_mod()
......@@ -7,6 +7,8 @@ optimizing tvm generated kernels.
from __future__ import absolute_import as _abs
from .math import *
from .reduction import *
from .broadcast import *
from . import nn
from . import cuda
from . import testing
# pylint: disable=no-member,consider-using-enumerate
"""Broadcast operators"""
from __future__ import absolute_import as _abs
import tvm
def _get_bcast_info(original_shape, target_shape):
"""Get the broadcasting info.
bcast_info = _get_bcast_info(original_shape, target_shape)
In bcast_info:
-1 means the dim is a padding dim (it does not exist in the original shape)
0 means the dim is the same as in the original shape
1 means the dim is broadcast
E.g.
original: (2, 1, 5), target: (2, 4, 5) => bcast_info: (0, 1, 0)
original: (2, 5), target: (4, 2, 5) => bcast_info: (-1, 0, 0)
original: (1, 5), target: (4, 2, 5) => bcast_info: (-1, 1, 0)
Parameters
----------
original_shape : tuple of tvm.expr.IntImm
The original shape before broadcasting
target_shape : tuple
The target shape
Returns
-------
bcast_info : list
"""
assert len(target_shape) >= len(original_shape)
bcast_info = [-1 for _ in range(len(target_shape))]
original_shape = [original_shape[i] for i in range(len(original_shape))]
original_shape = original_shape[::-1]
target_shape = target_shape[::-1]
for i in range(len(original_shape)):
if not isinstance(original_shape[i], tvm.expr.IntImm):
raise ValueError("Element of original_shape tuple should be IntImm")
if tvm.ir_pass.Equal(tvm.convert(target_shape[i]), original_shape[i]):
bcast_info[i] = 0
elif tvm.ir_pass.Equal(original_shape[i], tvm.convert(1)):
bcast_info[i] = 1
else:
raise ValueError("Original Shape: {} cannot be broadcast to {}"
.format(original_shape[::-1], target_shape[::-1]))
bcast_info = bcast_info[::-1]
return bcast_info
@tvm.tag_scope(tag="broadcast_to")
def broadcast_to(data, shape):
"""Broadcast the src to the target shape
We follow the numpy broadcasting rule.
See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
Parameters
----------
data : tvm.Tensor
shape : list or tuple
Returns
-------
ret : tvm.Tensor
"""
def _bcast_to_arg_eval(data, bcast_info, *args):
indices_tuple = []
for i in range(len(args)):
if bcast_info[i] == 0:
indices_tuple.append(args[i])
elif bcast_info[i] == 1:
indices_tuple.append(0)
return data[tuple(indices_tuple)]
original_shape = data.shape
bcast_info = _get_bcast_info(original_shape=original_shape, target_shape=shape)
ret = tvm.compute([tvm.convert(ele) for ele in shape],
lambda *args: _bcast_to_arg_eval(data,
bcast_info,
*args), name=data.name + "_broadcast")
return ret
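A minimal usage sketch, mirroring the tests below:
import tvm
import topi
A = tvm.placeholder((1, 5), name="A")
B = topi.broadcast_to(A, (4, 2, 5))  # bcast_info here is (-1, 1, 0)
s = topi.cuda.schedule_broadcast_to(B)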
......@@ -5,3 +5,5 @@ from __future__ import absolute_import as _abs
from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
# pylint: disable=invalid-name,unused-variable
"""Schedule for broadcast operators"""
from __future__ import absolute_import as _abs
import tvm
def _schedule_broadcast_to(op, sch):
data_in = op.input_tensors[0]
data_out = op.output(0)
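# Strategy: split the innermost axis and vectorize the inner piece, fuse
# all remaining axes together with the outer piece of the split, then
# split by num_thread and bind to CUDA blocks/threads.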
num_thread = 512
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
xo, vi = sch[data_out].split(sch[data_out].op.axis[len(sch[data_out].op.axis) - 1],
factor=4)
sch[data_out].vectorize(vi)
fused_axis = sch[data_out].fuse(*[sch[data_out].op.axis[i]
for i in range(len(sch[data_out].op.axis) - 1)] + [xo])
bx, tx = sch[data_out].split(fused_axis, factor=num_thread)
sch[data_out].bind(bx, block_x)
sch[data_out].bind(tx, thread_x)
return sch
def schedule_broadcast_to(outs):
"""Schedule for broadcast_to ops + element-wise ops.
Parameters
----------
outs: Array of Tensor
The computation graph description of broadcast_to in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs])
def traverse(operator):
if operator.tag == 'ewise' or operator.tag == 'scale_shift':
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif operator.tag == 'broadcast_to':
_schedule_broadcast_to(operator, sch)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
traverse(outs[0].op)
return sch
# pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition
"""Schedule for reduce operators"""
from __future__ import absolute_import as _abs
import tvm
def _schedule_reduce(op, sch):
data_in = op.input_tensors[0]
data_out = op.output(0)
assert len(sch[data_out].op.reduce_axis) > 0, "the operator must contain at least one reduce_axis!"
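# Two cases: if the output keeps spatial axes, reduce along threadIdx.x
# and bind the spatial axes to blockIdx.x/threadIdx.y; if everything is
# reduced, use a single block of 512 threads.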
if len(sch[data_out].op.axis) > 0:
all_reduce = False
num_thread = 16
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
else:
all_reduce = True
num_thread = 512
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
# Fuse and refactor the reduce axis
fused_reduce = sch[data_out].fuse(*[sch[data_out].op.reduce_axis[i]
for i in range(len(sch[data_out].op.reduce_axis))])
ko, ki = sch[data_out].split(fused_reduce, factor=num_thread)
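# rfactor turns the inner piece (ki) of the reduction into a separate
# stage: each thread computes a partial result, and the final
# cross-thread reduction runs over the remaining reduce axis of data_out.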
data_out_rf = sch.rfactor(data_out, ki)
sch[data_out_rf].compute_at(sch[data_out], sch[data_out].op.reduce_axis[0])
if not all_reduce:
# Fuse and split the axis
fused_outer = sch[data_out].fuse(*[sch[data_out].op.axis[i]
for i in range(len(sch[data_out].op.axis))])
bx, outer_in = sch[data_out].split(fused_outer, factor=num_thread)
# Bind the axes to threads and blocks
sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
sch[data_out].bind(outer_in, thread_y)
sch[data_out].bind(bx, block_x)
else:
sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
return sch
def schedule_reduce(outs):
"""Schedule for reduce ops + ewise + scale_shift ops.
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs])
def traverse(operator):
if operator.tag == 'ewise' or operator.tag == 'scale_shift':
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
traverse(outs[0].op)
return sch
......@@ -2,7 +2,7 @@
"""Neural network operators"""
from __future__ import absolute_import as _abs
from .mapping import *
from .ewise import *
from .mapping import *
from .conv import *
from .dilate import *
# pylint: disable=redefined-builtin,consider-using-enumerate
"""Reduce operators"""
from __future__ import absolute_import as _abs
import tvm
def _get_real_axis(ndim, axis):
if axis is None:
real_axis = list(range(ndim))
else:
if isinstance(axis, int):
axis = [axis]
else:
assert isinstance(axis, (list, tuple))
real_axis = []
for ele in axis:
if ele < 0:
ele += ndim
if ele >= ndim:
raise ValueError(
"{} exceeds the maximum dimension {}. Received axis={}".format(ele, ndim, axis))
real_axis.append(ele)
real_axis = sorted(set(real_axis))  # Sort and remove the duplicates
return real_axis
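For illustration, two hypothetical calls (not part of the module):
# _get_real_axis(4, (-1, 1)) -> [1, 3]   (negative axes count from the end)
# _get_real_axis(3, None)    -> [0, 1, 2]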
def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
real_axis = _get_real_axis(len(src_shape), axis)
if keepdims:
dst_shape = [src_shape[i] if i in real_axis else 1 for i in range(len(src_shape))]
else:
dst_shape = []
for i in range(len(src_shape)):
if i not in real_axis:
dst_shape.append(src_shape[i])
return dst_shape
@tvm.tag_scope(tag="comm_reduce")
def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
"""Reducing the data
Parameters
----------
data : tvm.Tensor
axis : None or int or tuple of int
Axis or axes along which the reduction is performed.
The default, axis=None, will reduce over all of the elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
func : function
functions like tvm.sum, tvm.max, tvm.min
Returns
-------
ret : tvm.Tensor
"""
def _build_reduce_compute_func(data, real_axis, reduce_axes, keepdims,
func, *args):
eval_range = []
if not keepdims:
arg_counter = 0
else:
arg_counter = None
red_counter = 0
for i in range(len(data.shape)):
if i in real_axis:
eval_range.append(reduce_axes[red_counter])
red_counter += 1
else:
if not keepdims:
eval_range.append(args[arg_counter])
arg_counter += 1
else:
eval_range.append(args[i])
return func(data[tuple(eval_range)], axis=reduce_axes)
ndim = len(data.shape)
real_axis = _get_real_axis(ndim, axis)
if real_axis == list(range(ndim)) and keepdims is False:
raise ValueError("Currently we do not support all reduce + keepdims = False!"
" axis={}, keepdims={}".format(axis, keepdims))
reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
if keepdims:
target_shape = [tvm.convert(1) if i in real_axis else tvm.convert(data.shape[i])
for i in range(ndim)]
else:
target_shape = []
for i in range(ndim):
if i not in real_axis:
target_shape.append(tvm.convert(data.shape[i]))
out = tvm.compute(target_shape,
lambda *args: _build_reduce_compute_func(data,
real_axis,
reduce_axes,
keepdims, func, *args),
name=data.name + "_red")
return out
def sum(data, axis=None, keepdims=False):
"""Sum of array elements over a given axis or a list of axes
Parameters
----------
data : tvm.Tensor
The input tvm tensor
axis : None or int or tuple of int
Axis or axes along which a sum is performed.
The default, axis=None, will sum all of the elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.sum)
def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes
Parameters
----------
data : tvm.Tensor
The input tvm tensor
axis : None or int or tuple of int
Axis or axes along which the maximum is computed.
The default, axis=None, will find the maximum over all of the elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.max)
def min(data, axis=None, keepdims=False):
"""Minimum of array elements over a given axis or a list of axes
Parameters
----------
data : tvm.Tensor
The input tvm tensor
axis : None or int or tuple of int
Axis or axes along which the minimum is computed.
The default, axis=None, will find the minimum over all of the elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.min)
import os
import tvm
from tvm.contrib import nvcc
import numpy as np
import topi
TASK = "reduce_map"
USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
return ptx
def write_code(code, fname):
with open(fname, "w") as f:
f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
os.mkdir("perf")
write_code(code, "perf/%s_generated.cu" % TASK)
if USE_MANUAL_CODE:
code = open("perf/%s_manual.cu" % TASK).read()
return code
def test_broadcast_to(in_shape, out_shape):
global TASK
TASK = "bcast_to_i" + "_".join([str(ele) for ele in in_shape])\
+ "o" + "_".join([str(ele) for ele in out_shape])
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.broadcast_to(A, out_shape)
s = topi.cuda.schedule_broadcast_to(B)
fcuda = tvm.build(s, [A, B], "cuda", name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, tvm.gpu())
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), tvm.gpu())
for _ in range(2):
fcuda(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
if __name__ == "__main__":
test_broadcast_to((1,), (10,))
test_broadcast_to((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
test_broadcast_to((1, 128, 1, 32), (64, 128, 64, 32))
import os
import tvm
from tvm.contrib import nvcc
import numpy as np
import topi
TASK = "reduce_map"
USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
return ptx
def write_code(code, fname):
with open(fname, "w") as f:
f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
os.mkdir("perf")
write_code(code, "perf/%s_generated.cu" % TASK)
if USE_MANUAL_CODE:
code = open("perf/%s_manual.cu" % TASK).read()
return code
def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0):
global TASK
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
if type == "sum":
TASK = "sum_map_id%d" %test_id
B = topi.sum(A, axis=axis, keepdims=keepdims)
elif type == "max":
TASK = "max_map_id%d" %test_id
B = topi.max(A, axis=axis, keepdims=keepdims)
elif type == "min":
TASK = "min_map_id%d" %test_id
B = topi.min(A, axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
s = topi.cuda.schedule_reduce(B)
with tvm.build_config(auto_unroll_max_step=16,
auto_unroll_min_depth=0):
fcuda = tvm.build(s, [A, B], "cuda", name="sum")
# Test
in_npy = np.random.normal(size=in_shape).astype(np.float32)
if type == "sum":
out_npy = in_npy.sum(axis=axis, keepdims=keepdims)
elif type == "max":
out_npy = in_npy.max(axis=axis, keepdims=keepdims)
elif type == "min":
out_npy = in_npy.min(axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
data_tvm = tvm.nd.array(in_npy, ctx=tvm.gpu())
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=tvm.gpu())
for _ in range(2):
fcuda(data_tvm, out_tvm)
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 4E-4, 4E-4)
if __name__ == "__main__":
test_reduce_map(in_shape=(128, 24, 128, 24),
axis=(1, 2, 3),
keepdims=True,
type="sum",
test_id=0)
test_reduce_map(in_shape=(128, 24 * 128 * 24),
axis=(1,),
keepdims=False,
type="max",
test_id=1)
test_reduce_map(in_shape=(32, 128, 24),
axis=None,
keepdims=True,
type="sum",
test_id=2)
test_reduce_map(in_shape=(128, 24, 128, 24),
axis=(0, 2),
keepdims=False,
type="min",
test_id=3)
"""Test code for broadcasting operators."""
import os
import numpy as np
import tvm
import topi
def verify_broadcast_to_ele(in_shape, out_shape):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.broadcast_to(A, out_shape)
s = topi.cuda.schedule_broadcast_to(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
for _ in range(1):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("opencl")
check_device("cuda")
check_device("metal")
def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,))
verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32))
if __name__ == "__main__":
test_broadcast_to()
"""Test code for reduce."""
import os
import numpy as np
import tvm
import topi
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
if type == "sum":
B = topi.sum(A, axis=axis, keepdims=keepdims)
elif type == "max":
B = topi.max(A, axis=axis, keepdims=keepdims)
elif type == "min":
B = topi.min(A, axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
s = topi.cuda.schedule_reduce(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="sum")
# Test
in_npy = np.random.normal(size=in_shape).astype(np.float32)
if type == "sum":
out_npy = in_npy.sum(axis=axis, keepdims=keepdims)
elif type == "max":
out_npy = in_npy.max(axis=axis, keepdims=keepdims)
elif type == "min":
out_npy = in_npy.min(axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
data_tvm = tvm.nd.array(in_npy, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx)
for _ in range(1):
foo(data_tvm, out_tvm)
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
check_device("opencl")
check_device("cuda")
check_device("metal")
def test_reduce_map():
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=(1, 2, 3),
keepdims=True,
type="sum")
verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
axis=(1,),
keepdims=False,
type="max")
verify_reduce_map_ele(in_shape=(32, 128, 24),
axis=None,
keepdims=True,
type="sum")
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=(0, 2),
keepdims=False,
type="min")
if __name__ == "__main__":
test_reduce_map()
......@@ -73,12 +73,12 @@ print(tvm.lower(s, [A, B], simple_mode=True))
# Reduction Factoring and Parallelization
# ---------------------------------------
# One problem of building a reduction is that we cannot simply
# parallelize over the reduction axis. We need to devide the computation
# of the reduction, store the local reduction result in a temporal array.
# Before doing a reduction over the temp array.
# parallelize over the reduction axis. We need to divide the computation
# of the reduction, store the local reduction result in a temporary array
# before doing a reduction over the temp array.
#
# The rfactor primitive does such rewrite of the computation.
# In the following schedule, the result of B is write written to a temporary
# In the following schedule, the result of B is written to a temporary
# result B.rf. The factored dimension becomes the first dimension of B.rf.
#
s = tvm.create_schedule(B.op)
......
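For reference, a minimal sketch of the rfactor pattern described above, reusing the tutorial's names s, A and B (assumed from the surrounding tutorial code):
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
print(tvm.lower(s, [A, B], simple_mode=True))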