Commit ee74d00e authored by Jon Soifer, committed by Yao Wang

[TOPI] Update softmax compute and CPU schedule (#3680)

* Update Softmax compute and CPU schedule

* Add C++ compute

* Fix schedule

* Update CUDA and OpenGL schedules

* Fix log_softmax

* Fix hls and opengl schedules

* Fix CUDA schedule
parent 7ce6a41d
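For reference, both op tags touched by this change share the same numerically stable staging (notation mine, not from the patch; i runs over the softmax axis, j over the remaining axes):

    m_j = \max_i x_{ij}, \qquad e_{ij} = \exp(x_{ij} - m_j), \qquad s_j = \sum_i e_{ij}

    \mathrm{softmax}(x)_{ij} = \frac{e_{ij}}{s_j}, \qquad \log\mathrm{softmax}(x)_{ij} = (x_{ij} - m_j) - \log s_j

Softmax materializes e as its own stage and reuses it in both the reduction and the normalization, while log-softmax only needs m and s; that is why the schedules below treat the exp stage as optional (exp = None under the log_softmax_output tag).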
@@ -50,13 +50,32 @@ inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs
   auto s = create_schedule(out_ops);
   auto softmax = outs[0];
-  auto max_elem = softmax->op->InputTensors()[1];
-  auto expsum = softmax->op->InputTensors()[2];
+  tvm::Tensor max_elem;
+  tvm::Tensor expsum;
+  tvm::Tensor exp;
+  bool has_exp = false;
+  auto tag = softmax->op.as<ComputeOpNode>()->tag;
+  if (tag == "softmax_output") {
+    expsum = softmax->op->InputTensors()[1];
+    exp = softmax->op->InputTensors()[0];
+    max_elem = s[exp]->op->InputTensors()[1];
+    has_exp = true;
+  } else if (tag == "log_softmax_output") {
+    max_elem = softmax->op->InputTensors()[1];
+    expsum = softmax->op->InputTensors()[2];
+  } else {
+    LOG(ERROR) << "Tag is expected to be softmax_output or log_softmax_output. Got " << tag;
+  }
   int num_thread = 64;
   auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
   auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+  if (has_exp) {
+    s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
+  }
   s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
   auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
...
@@ -62,6 +62,9 @@ inline Tensor softmax(const Tensor &x,
   auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2");
   auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);
+  tvm::Map<std::string, NodeRef> attrs;
+  attrs.Set("axis", Integer(axis));
   auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
                                           const IterVar &reduce_index) {
     Array<Expr> eval_range;
@@ -75,35 +78,48 @@ inline Tensor softmax(const Tensor &x,
     return eval_range;
   };
+  auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
+    Array<Expr> non_reduce_indices;
+    for (size_t i = 0; i < ndim; ++i) {
+      if (static_cast<int>(i) != axis)
+        non_reduce_indices.push_back(indices[i]);
+    }
+    return non_reduce_indices;
+  };
   auto _compute_max = [&](const Array<Var> &indices) {
     auto eval_range = insert_reduce_index(indices, k1);
     return topi::MaxOp(x(eval_range), {k1});
   };
-  auto _compute_expsum = [&](const Tensor &max_elem,
+  auto _compute_exp = [&](const Tensor &max_elem,
+                          const Array<Var> &indices) {
+    auto non_reduce_indices = get_non_reduce_indices(indices);
+    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
+  };
+  auto _compute_expsum = [&](const Tensor &exp,
                              const Array<Var> &indices) {
     auto eval_range = insert_reduce_index(indices, k2);
-    return tvm::sum(tvm::exp(x(eval_range) - max_elem(indices)), {k2});
+    return tvm::sum(exp(eval_range), {k2});
   };
-  auto _normalize = [&](const Tensor &max_elem, const Tensor &expsum,
+  auto _normalize = [&](const Tensor &exp, const Tensor &expsum,
                         const Array<Var> &indices) {
-    Array<Expr> non_reduce_indices;
-    for (size_t i = 0; i < ndim; ++i) {
-      if (static_cast<int>(i) != axis)
-        non_reduce_indices.push_back(indices[i]);
-    }
-    return tvm::exp(x(indices) - max_elem(non_reduce_indices)) /
-           expsum(non_reduce_indices);
+    auto non_reduce_indices = get_non_reduce_indices(indices);
+    return exp(indices) / expsum(non_reduce_indices);
   };
   auto max_elem = tvm::compute(reduced_shape, _compute_max);
+  auto exp = tvm::compute(input_shape, [&](const Array<Var> &indices) {
+    return _compute_exp(max_elem, indices);
+  });
   auto expsum = tvm::compute(reduced_shape, [&](const Array<Var> &indices) {
-    return _compute_expsum(max_elem, indices);
+    return _compute_expsum(exp, indices);
   });
   return tvm::compute(input_shape, [&](const Array<Var> &indices) {
-    return _normalize(max_elem, expsum, indices);
-  }, name, tag);
+    return _normalize(exp, expsum, indices);
+  }, name, tag, attrs);
 }
 /*!
...
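The schedules below all recover the intermediate stages by pattern-matching the op tag and, in the CPU schedule, the new "axis" attribute. A minimal sketch of what the reworked compute exposes, assuming the pre-0.7 tvm/topi Python API used elsewhere in this diff:

    import tvm
    import topi

    x = tvm.placeholder((4, 12, 64), name="x")
    y = topi.nn.softmax(x, axis=-1)

    assert y.op.tag == 'softmax_output'
    exp, expsum = y.op.input_tensors      # T_softmax_exp, T_softmax_expsum
    max_elem = exp.op.input_tensors[1]    # T_softmax_maxelem feeds the exp stage
    axis = int(y.op.attrs['axis'])        # softmax axis recorded by the compute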
@@ -38,17 +38,35 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
     if len(softmax.shape) > 2:
-        for op in [max_elem.op, expsum.op, softmax.op]:
+        ops = [max_elem.op, expsum.op, softmax.op]
+        if exp != None:
+            ops.append(exp.op)
+        for op in ops:
             s = _schedule_injective(op, s)
     else:
         num_thread = 64
         block_x = tvm.thread_axis("blockIdx.x")
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        if exp != None:
+            s[exp].bind(exp.op.axis[0], block_x)
         s[max_elem].bind(max_elem.op.axis[0], block_x)
         k = expsum.op.reduce_axis[0]
         ko, ki = s[expsum].split(k, factor=num_thread)
...
@@ -261,8 +261,22 @@ def schedule_softmax(outs):
     tvm.schedule.AutoInlineInjective(s)
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
+    if exp != None:
+        s[exp].compute_at(s[softmax], s[softmax].op.axis[1])
     s[expsum].compute_at(s[softmax], s[softmax].op.axis[1])
     s[max_elem].compute_at(s[softmax], s[softmax].op.axis[1])
...
@@ -48,24 +48,33 @@ def softmax(x, axis=-1):
     def insert_reduce_index(indices, reduce_index):
         return indices[:axis] + (reduce_index,) + indices[axis:]
+    def get_non_reduce_indices(indices):
+        return tuple([var for (i, var) in enumerate(indices) if i != axis])
     def _compute_max(*indices):
         eval_range = insert_reduce_index(indices, k1)
         return tvm.max(x[eval_range], axis=k1)
-    def _compute_expsum(max_elem, *indices):
+    def _compute_exp(max_elem, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return tvm.exp(x[indices] - max_elem[non_reduce_indices])
+    def _compute_expsum(exp, *indices):
         eval_range = insert_reduce_index(indices, k2)
-        return tvm.sum(tvm.exp(x[eval_range] - max_elem[indices]), axis=k2)
+        return tvm.sum(exp[eval_range], axis=k2)
-    def _normalize(max_elem, expsum, *indices):
-        non_reduce_indices = tuple([var for (i, var) in enumerate(indices) if i != axis])
-        return tvm.exp(x[indices] - max_elem[non_reduce_indices]) / expsum[non_reduce_indices]
+    def _normalize(exp, expsum, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return exp[indices] / expsum[non_reduce_indices]
     reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
     max_elem = tvm.compute(reduced_shape, _compute_max, name='T_softmax_maxelem')
-    expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices),
+    exp = tvm.compute(shape, lambda *indices: _compute_exp(max_elem, *indices),
+                      name='T_softmax_exp')
+    expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(exp, *indices),
                          name='T_softmax_expsum')
-    return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices),
-                       name='T_softmax_norm')
+    return tvm.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
+                       name='T_softmax_norm', attrs={"axis" : axis})
 @tvm.tag_scope(tag='log_softmax_output')
...
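As a plain NumPy illustration (not TVM code), the staged compute above corresponds to the following, with variable names mirroring the T_softmax_* stages:

    import numpy as np

    def softmax_stages(x, axis=-1):
        max_elem = x.max(axis=axis, keepdims=True)   # T_softmax_maxelem
        exp = np.exp(x - max_elem)                   # T_softmax_exp, now its own stage
        expsum = exp.sum(axis=axis, keepdims=True)   # T_softmax_expsum reduces exp directly
        return exp / expsum                          # T_softmax_norm

Materializing the exp stage is what lets the schedules bind it or compute_at it independently, instead of re-evaluating the exponential inside both the reduction and the normalization.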
@@ -37,8 +37,23 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
+    if exp != None:
+        s[exp].opengl()
     s[max_elem].opengl()
     s[expsum].opengl()
     s[softmax].opengl()
...
@@ -36,15 +36,34 @@ def schedule_softmax(outs):
         The computation schedule for the op.
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    x = outs[0]
+    softmax = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
-    tvm.schedule.AutoInlineInjective(s)
-    if len(s[x].op.axis) >= 5:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 3:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
-        s[x].parallel(fused)
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        exp = softmax.op.input_tensors[0]
+        expsum = softmax.op.input_tensors[1]
+        max_elem = s[exp].op.input_tensors[1]
+        axis = int(softmax.op.attrs['axis'])
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+        axis = 1
     else:
-        s[x].parallel(s[x].op.axis[0])
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+Got {0}'.format(op_tag))
+    # only parallelize outer dimensions up to axis
+    outer_axes = [s[softmax].op.axis[i] for i in range(0, axis)]
+    fused_outer_axes = s[softmax].fuse(*outer_axes)
+    s[softmax].parallel(fused_outer_axes)
+    # move computations with the same outer dimensions under the same root
+    s[max_elem].compute_at(s[softmax], fused_outer_axes)
+    s[expsum].compute_at(s[softmax], fused_outer_axes)
+    if exp != None:
+        s[exp].compute_at(s[softmax], fused_outer_axes)
     return s
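A hedged end-to-end sketch of exercising the new CPU schedule (assumes a pre-0.7 TVM build with LLVM enabled; the shape and tolerance are arbitrary choices, not from the patch):

    import numpy as np
    import tvm
    import topi

    x = tvm.placeholder((8, 3, 128), name="x", dtype="float32")
    y = topi.nn.softmax(x, axis=-1)

    with tvm.target.create("llvm"):
        # dispatches to the x86 schedule above: the axes before the softmax
        # axis are fused and parallelized, and the stages are computed there
        s = topi.generic.schedule_softmax([y])

    f = tvm.build(s, [x, y], "llvm")
    ctx = tvm.cpu(0)
    a_np = np.random.uniform(size=(8, 3, 128)).astype("float32")
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros_like(a_np), ctx)
    f(a, b)

    # numerically stable NumPy reference, same staging as the compute above
    ref = np.exp(a_np - a_np.max(axis=-1, keepdims=True))
    ref /= ref.sum(axis=-1, keepdims=True)
    np.testing.assert_allclose(b.asnumpy(), ref, rtol=1e-5)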