Commit ee74d00e, authored by Jon Soifer, committed by Yao Wang

[TOPI] Update softmax compute and CPU schedule (#3680)

* Update Softmax compute and CPU schedule

* Add C++ compute

* Fix schedule

* Update CUDA and OpenGL schedules

* Fix log_softmax

* Fix hls and opengl schedules

* Fix CUDA schedule
parent 7ce6a41d
......@@ -50,13 +50,32 @@ inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs
auto s = create_schedule(out_ops);
auto softmax = outs[0];
auto max_elem = softmax->op->InputTensors()[1];
auto expsum = softmax->op->InputTensors()[2];
tvm::Tensor max_elem;
tvm::Tensor expsum;
tvm::Tensor exp;
bool has_exp = false;
auto tag = softmax->op.as<ComputeOpNode>()->tag;
if (tag == "softmax_output") {
expsum = softmax->op->InputTensors()[1];
exp = softmax->op->InputTensors()[0];
max_elem = s[exp]->op->InputTensors()[1];
has_exp = true;
} else if (tag == "log_softmax_output") {
max_elem = softmax->op->InputTensors()[1];
expsum = softmax->op->InputTensors()[2];
} else {
LOG(ERROR) << "Tag is expected to be softmax_output or log_softmax_output. Got " << tag;
}
int num_thread = 64;
auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
if (has_exp) {
s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
}
s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
......
......@@ -62,6 +62,9 @@ inline Tensor softmax(const Tensor &x,
auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2");
auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);
tvm::Map<std::string, NodeRef> attrs;
attrs.Set("axis", Integer(axis));
auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
const IterVar &reduce_index) {
Array<Expr> eval_range;
......@@ -75,35 +78,48 @@ inline Tensor softmax(const Tensor &x,
return eval_range;
};
// Returns the entries of `indices` at positions 0..ndim-1, skipping the one
// at position `axis`. The result is used below to index tensors that are
// called with one fewer index than the input (max_elem, expsum).
auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
Array<Expr> non_reduce_indices;
for (size_t i = 0; i < ndim; ++i) {
// i is size_t; it is cast to int before being compared with `axis`.
// NOTE(review): presumably `axis` is a non-negative int here — confirm
// against the part of softmax() that is outside this hunk.
if (static_cast<int>(i) != axis)
non_reduce_indices.push_back(indices[i]);
}
return non_reduce_indices;
};
// Compute body for the max tensor: builds the full index list by passing
// `indices` and the reduce variable k1 to insert_reduce_index, then applies
// topi::MaxOp to x over k1.
// NOTE(review): k1 is declared outside this hunk; presumably it ranges over
// input_shape[axis] like k2 — confirm.
auto _compute_max = [&](const Array<Var> &indices) {
auto eval_range = insert_reduce_index(indices, k1);
return topi::MaxOp(x(eval_range), {k1});
};
auto _compute_expsum = [&](const Tensor &max_elem,
auto _compute_exp = [&](const Tensor &max_elem,
const Array<Var> &indices) {
auto non_reduce_indices = get_non_reduce_indices(indices);
return tvm::exp(x(indices) - max_elem(non_reduce_indices));
};
auto _compute_expsum = [&](const Tensor &exp,
const Array<Var> &indices) {
auto eval_range = insert_reduce_index(indices, k2);
return tvm::sum(tvm::exp(x(eval_range) - max_elem(indices)), {k2});
return tvm::sum(exp(eval_range), {k2});
};
auto _normalize = [&](const Tensor &max_elem, const Tensor &expsum,
auto _normalize = [&](const Tensor &exp, const Tensor &expsum,
const Array<Var> &indices) {
Array<Expr> non_reduce_indices;
for (size_t i = 0; i < ndim; ++i) {
if (static_cast<int>(i) != axis)
non_reduce_indices.push_back(indices[i]);
}
return tvm::exp(x(indices) - max_elem(non_reduce_indices)) /
expsum(non_reduce_indices);
auto non_reduce_indices = get_non_reduce_indices(indices);
return exp(indices) / expsum(non_reduce_indices);
};
auto max_elem = tvm::compute(reduced_shape, _compute_max);
auto exp = tvm::compute(input_shape, [&](const Array<Var> &indices) {
return _compute_exp(max_elem, indices);
});
auto expsum = tvm::compute(reduced_shape, [&](const Array<Var> &indices) {
return _compute_expsum(max_elem, indices);
return _compute_expsum(exp, indices);
});
return tvm::compute(input_shape, [&](const Array<Var> &indices) {
return _normalize(max_elem, expsum, indices);
}, name, tag);
return _normalize(exp, expsum, indices);
}, name, tag, attrs);
}
/*!
......
......@@ -38,17 +38,35 @@ def schedule_softmax(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
softmax = outs[0]
op_tag = softmax.op.tag
if op_tag == 'softmax_output':
expsum = softmax.op.input_tensors[1]
exp = softmax.op.input_tensors[0]
max_elem = s[exp].op.input_tensors[1]
elif op_tag == 'log_softmax_output':
exp = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
else:
raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
Got {0}'.format(op_tag))
if len(softmax.shape) > 2:
for op in [max_elem.op, expsum.op, softmax.op]:
ops = [max_elem.op, expsum.op, softmax.op]
if exp != None:
ops.append(exp.op)
for op in ops:
s = _schedule_injective(op, s)
else:
num_thread = 64
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
if exp != None:
s[exp].bind(exp.op.axis[0], block_x)
s[max_elem].bind(max_elem.op.axis[0], block_x)
k = expsum.op.reduce_axis[0]
ko, ki = s[expsum].split(k, factor=num_thread)
......
......@@ -261,8 +261,22 @@ def schedule_softmax(outs):
tvm.schedule.AutoInlineInjective(s)
softmax = outs[0]
op_tag = softmax.op.tag
if op_tag == 'softmax_output':
expsum = softmax.op.input_tensors[1]
exp = softmax.op.input_tensors[0]
max_elem = s[exp].op.input_tensors[1]
elif op_tag == 'log_softmax_output':
exp = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
else:
raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
Got {0}'.format(op_tag))
if exp != None:
s[exp].compute_at(s[softmax], s[softmax].op.axis[1])
s[expsum].compute_at(s[softmax], s[softmax].op.axis[1])
s[max_elem].compute_at(s[softmax], s[softmax].op.axis[1])
......
......@@ -48,24 +48,33 @@ def softmax(x, axis=-1):
# Returns `indices` with `reduce_index` spliced in at position `axis`.
# `indices` must be a tuple, since it is concatenated with a 1-tuple.
# NOTE(review): a negative `axis` would splice at the wrong position; assumes
# `axis` was normalized to a non-negative value earlier in softmax() — confirm.
def insert_reduce_index(indices, reduce_index):
return indices[:axis] + (reduce_index,) + indices[axis:]
# Returns a tuple of the entries of `indices` whose position is not `axis`.
# NOTE(review): positions from enumerate() are never negative, so a negative
# `axis` would drop nothing; assumes `axis` is already non-negative — confirm.
def get_non_reduce_indices(indices):
return tuple([var for (i, var) in enumerate(indices) if i != axis])
# Compute body for the max tensor: inserts the reduce variable k1 into
# `indices` via insert_reduce_index and returns tvm.max of x at that
# position, reduced over k1.
def _compute_max(*indices):
eval_range = insert_reduce_index(indices, k1)
return tvm.max(x[eval_range], axis=k1)
def _compute_expsum(max_elem, *indices):
def _compute_exp(max_elem, *indices):
non_reduce_indices = get_non_reduce_indices(indices)
return tvm.exp(x[indices] - max_elem[non_reduce_indices])
def _compute_expsum(exp, *indices):
eval_range = insert_reduce_index(indices, k2)
return tvm.sum(tvm.exp(x[eval_range] - max_elem[indices]), axis=k2)
return tvm.sum(exp[eval_range], axis=k2)
def _normalize(max_elem, expsum, *indices):
non_reduce_indices = tuple([var for (i, var) in enumerate(indices) if i != axis])
return tvm.exp(x[indices] - max_elem[non_reduce_indices]) / expsum[non_reduce_indices]
def _normalize(exp, expsum, *indices):
non_reduce_indices = get_non_reduce_indices(indices)
return exp[indices] / expsum[non_reduce_indices]
reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
max_elem = tvm.compute(reduced_shape, _compute_max, name='T_softmax_maxelem')
expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices),
exp = tvm.compute(shape, lambda *indices: _compute_exp(max_elem, *indices),
name='T_softmax_exp')
expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(exp, *indices),
name='T_softmax_expsum')
return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices),
name='T_softmax_norm')
return tvm.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
name='T_softmax_norm', attrs={"axis" : axis})
@tvm.tag_scope(tag='log_softmax_output')
......
......@@ -37,8 +37,23 @@ def schedule_softmax(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
softmax = outs[0]
op_tag = softmax.op.tag
if op_tag == 'softmax_output':
expsum = softmax.op.input_tensors[1]
exp = softmax.op.input_tensors[0]
max_elem = s[exp].op.input_tensors[1]
elif op_tag == 'log_softmax_output':
exp = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
else:
raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
Got {0}'.format(op_tag))
if exp != None:
s[exp].opengl()
s[max_elem].opengl()
s[expsum].opengl()
s[softmax].opengl()
......
......@@ -36,15 +36,34 @@ def schedule_softmax(outs):
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
x = outs[0]
softmax = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
if len(s[x].op.axis) >= 5:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 3:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
s[x].parallel(fused)
op_tag = softmax.op.tag
if op_tag == 'softmax_output':
exp = softmax.op.input_tensors[0]
expsum = softmax.op.input_tensors[1]
max_elem = s[exp].op.input_tensors[1]
axis = int(softmax.op.attrs['axis'])
elif op_tag == 'log_softmax_output':
exp = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
axis = 1
else:
s[x].parallel(s[x].op.axis[0])
raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
Got {0}'.format(op_tag))
# only parallelize outer dimensions up to axis
outer_axes = [s[softmax].op.axis[i] for i in range(0, axis)]
fused_outer_axes = s[softmax].fuse(*outer_axes)
s[softmax].parallel(fused_outer_axes)
# move computations with the same outer dimensions under the same root
s[max_elem].compute_at(s[softmax], fused_outer_axes)
s[expsum].compute_at(s[softmax], fused_outer_axes)
if exp != None:
s[exp].compute_at(s[softmax], fused_outer_axes)
return s
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment