Commit cb8a70f4 by Yizhi Liu, committed by Tianqi Chen

[Relay] compute & schedule for relu, softmax (#2127)

parent ade98e14
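This commit wires nn.relu and nn.softmax to their TOPI compute (FTVMCompute) and schedule implementations, so both ops can now be lowered and executed from Relay rather than only type-checked. A minimal sketch of what that enables, mirroring the graph-executor pattern used in the updated tests below (the (10, 4) shape, llvm target, and CPU context are illustrative choices, not part of the commit):

import numpy as np
import tvm
from tvm import relay

# Build a one-op Relay function around nn.relu and run it through the graph
# executor; the compute/schedule registrations added in this commit are what
# allow the lowering to succeed.
x = relay.var("x", shape=(10, 4))
func = relay.Function([x], relay.nn.relu(x))

x_data = np.random.uniform(low=-1.0, high=1.0, size=(10, 4)).astype("float32")
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data)
np.testing.assert_allclose(op_res.asnumpy(), np.maximum(x_data, 0.0), rtol=1e-5)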
......@@ -5,6 +5,20 @@ from topi.util import get_const_int, get_const_tuple
from .. import op as reg
from ..op import OpPattern, schedule_injective
# relu
reg.register_schedule("nn.relu", schedule_injective)
reg.register_pattern("nn.relu", OpPattern.ELEMWISE)

# softmax
@reg.register_schedule("nn.softmax")
def schedule_softmax(_, outputs, target):
    """Schedule definition of softmax"""
    with target:
        return topi.generic.schedule_softmax(outputs)

reg.register_pattern("nn.softmax", OpPattern.OPAQUE)

# dense
@reg.register_compute("nn.dense")
def compute_dense(attrs, inputs, out_type, target):
......
......@@ -7,6 +7,8 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h>
#include <topi/nn.h>
#include <topi/nn/softmax.h>
#include <vector>
#include "../type_relations.h"
#include "../op_common.h"
......@@ -252,7 +254,15 @@ RELAY_REGISTER_OP("nn.softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<SoftmaxAttrs>();
CHECK(param != nullptr);
return Array<Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
});
TVM_REGISTER_API("relay.op.nn._make.log_softmax")
......@@ -364,7 +374,13 @@ RELAY_REGISTER_OP("nn.relu")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
});
// Positional relay function to create LRN operator used by frontend FFI.
......
......@@ -3,6 +3,7 @@ import tvm
import numpy as np
from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing
def sigmoid(x):
    one = np.ones_like(x)
......@@ -42,7 +43,7 @@ def test_unary_op():
                        (tvm.relay.sqrt, np.sqrt),
                        (tvm.relay.sigmoid, sigmoid),
                        (tvm.relay.tanh, np.tanh),
                        (relay.nn.relu, None)]:  # Just add RELU here after registering.
                        (relay.nn.relu, relu)]:
        check_single_op(opfunc, ref)
......@@ -120,12 +121,19 @@ def test_expand_dims_infer_type():
def test_softmax():
    n, d = tvm.var("n"), tvm.var("d")
    x = relay.var("x", shape=(n, d))
    shape = (10, 4)
    x = relay.var("x", shape=shape)
    y = relay.nn.softmax(x, axis=1)
    assert "nn.softmax" in y.astext()
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.TensorType((n, d))
    assert yy.checked_type == relay.TensorType(shape)
    func = relay.Function([x], y)
    x_data = np.random.uniform(size=shape).astype("float32")
    ref_res = topi.testing.softmax_python(x_data)
    for target, ctx in ctx_list():
        intrp = relay.create_executor("graph", ctx=ctx, target=target)
        op_res = intrp.evaluate(func)(x_data)
        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

def test_log_softmax():
......
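For reference, topi.testing.softmax_python used in test_softmax above computes a plain row-wise softmax. A NumPy restatement of the equivalent computation (my own sketch under that assumption, not the TOPI source):

import numpy as np

def softmax_rows(x):
    # Row-wise softmax with max-subtraction for numerical stability; this is
    # what topi.testing.softmax_python is expected to return for 2-D input.
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

# Sanity check on the same (10, 4) shape used in the test.
x = np.random.uniform(size=(10, 4)).astype("float32")
assert np.allclose(softmax_rows(x).sum(axis=-1), 1.0, atol=1e-5)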