Commit b71edd76 by Animesh Jain Committed by Tianqi Chen

Relay Op sprint (part 2) - Level 1 - log_softmax (#2128)

parent 81da33f8
......@@ -9,7 +9,6 @@ from ..op import OpPattern, schedule_injective
reg.register_schedule("nn.relu", schedule_injective)
reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
@reg.register_schedule("nn.softmax")
def schedule_softmax(_, outputs, target):
"""Schedule definition of softmax"""
......@@ -19,6 +18,15 @@ def schedule_softmax(_, outputs, target):
reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
@reg.register_schedule("nn.log_softmax")
def schedule_log_softmax(_, outputs, target):
"""Schedule definition of log_softmax"""
with target:
return topi.generic.schedule_softmax(outputs)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
# dense
@reg.register_compute("nn.dense")
def compute_dense(attrs, inputs, out_type, target):
......
......@@ -291,7 +291,18 @@ RELAY_REGISTER_OP("nn.log_softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<SoftmaxAttrs>();
CHECK(param != nullptr);
CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
<< "log_softmax currently only works on last dimension";
return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
});
// BatchFlatten
......
......@@ -137,12 +137,19 @@ def test_softmax():
def test_log_softmax():
n, d = tvm.var("n"), tvm.var("d")
x = relay.var("x", shape=(n, d))
y = relay.nn.log_softmax(x, axis=0)
shape = (10, 4)
x = relay.var("x", shape=shape)
y = relay.nn.log_softmax(x, axis=1)
assert "nn.log_softmax" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, d))
assert yy.checked_type == relay.TensorType(shape)
func = relay.Function([x], y)
x_data = np.random.uniform(size=shape).astype("float32")
ref_res = topi.testing.log_softmax_python(x_data)
for target, ctx in ctx_list():
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
def test_concatenate():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment