Commit 6edb3564 by Siju, committed by Tianqi Chen

[RELAY] sch & comp for ops in nn.py (#2092)

parent 7761416f
@@ -327,7 +327,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
-IndexExpr size;
+int size;
int axis;
double bias;
double alpha;
...
@@ -17,6 +17,7 @@ def schedule_softmax(_, outputs, target):
reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
schedule_broadcast = schedule_injective
@reg.register_schedule("nn.log_softmax")
def schedule_log_softmax(_, outputs, target):
@@ -194,3 +195,47 @@ def schedule_global_avg_pool2d(_, outs, target):
return topi.generic.schedule_global_pool(outs)
reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# leaky_relu
reg.register_schedule("nn.leaky_relu", schedule_broadcast)
reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE)
# prelu
reg.register_schedule("nn.prelu", schedule_broadcast)
reg.register_pattern("nn.prelu", OpPattern.BROADCAST)
# flatten
reg.register_schedule("nn.batch_flatten", schedule_broadcast)
reg.register_pattern("nn.batch_flatten", OpPattern.INJECTIVE)
# lrn
@reg.register_compute("nn.lrn")
def compute_lrn(attrs, inputs, out_dtype, target):
"""Compute definition of lrn"""
assert len(inputs) == 1
return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
attrs.alpha, attrs.beta, attrs.bias)]
@reg.register_schedule("nn.lrn")
def schedule_lrn(attrs, outs, target):
"""Schedule definition of lrn"""
with target:
return topi.generic.schedule_lrn(outs)
reg.register_pattern("nn.lrn", OpPattern.OPAQUE)
# l2_normalize
@reg.register_compute("nn.l2_normalize")
def compute_l2_normalize(attrs, inputs, out_dtype, target):
"""Compute definition of l2 normalize"""
return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]
@reg.register_schedule("nn.l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
"""Schedule definition of l2 normalize"""
with target:
return topi.generic.schedule_l2_normalize(outs)
reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -9,6 +9,7 @@
#include <tvm/relay/attrs/image.h>
#include <topi/nn.h>
#include <topi/nn/softmax.h>
#include <topi/nn/flatten.h>
#include <vector>
#include "../type_relations.h"
#include "../op_common.h"
@@ -169,7 +170,15 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<LeakyReluAttrs>();
return Array<Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
});
TVM_REGISTER_NODE_TYPE(PReluAttrs);
@@ -225,7 +234,15 @@ where :math:`*` is a channelwise multiplication for each sample in the batch.
.add_argument("data", "Tensor", "Input data.")
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel);
.add_type_rel("PRelu", PReluRel)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<PReluAttrs>();
return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
});
TVM_REGISTER_API("relay.op.nn._make.softmax")
@@ -365,7 +382,14 @@ Example::
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("BatchFlatten", BatchFlattenRel);
.add_type_rel("BatchFlatten", BatchFlattenRel)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return Array<Tensor>{ topi::nn::flatten(inputs[0]) };
});
// relu
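The batch_flatten compute maps to topi::nn::flatten, which keeps the batch axis and collapses all remaining axes, e.g. (1, 5, 10, 10) becomes (1, 500). A NumPy reference sketch (illustrative only, mirroring the reference computation in the new test):

import numpy as np

def batch_flatten_ref(x):
    # Keep the first (batch) dimension, flatten the rest.
    return x.reshape(x.shape[0], -1)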
@@ -398,7 +422,7 @@ RELAY_REGISTER_OP("nn.relu")
TVM_REGISTER_NODE_TYPE(LRNAttrs);
Expr MakeLRN(Expr data,
-IndexExpr size,
+int size,
int axis,
double alpha,
double beta,
...
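With `size` now a plain int instead of an IndexExpr, the Python frontend passes an ordinary integer. A minimal call, matching the parameters of the new LRN test (illustrative only):

# Illustrative only: nn.lrn's window size is now an ordinary Python int.
from tvm import relay

x = relay.var("x", relay.TensorType((1, 5, 10, 10), "float32"))
z = relay.nn.lrn(x, size=5, axis=1, bias=0.5, alpha=0.00001, beta=0.75)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((1, 5, 10, 10), "float32")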
@@ -295,6 +295,25 @@ def test_flatten_infer_type():
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((d1, ((2*d3)*3)), "float32")
shape = (1, 5, 10, 10)
o_shape = (1, 500)
dtype = "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
z = relay.nn.batch_flatten(x)
yy = relay.ir_pass.infer_type(z)
assert yy.checked_type == relay.TensorType(o_shape, dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
ref_res = x_data.flatten().reshape(o_shape)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def test_pad_infer_type():
# entirely concrete case
n, c, h, w = 1, 2, 3, 4
@@ -320,6 +339,29 @@ def test_lrn():
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c , h, w))
shape = (1, 5, 10, 10)
dtype = "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
size = 5
axis = 1
bias = 0.5
alpha = 0.00001
beta = 0.75
z = relay.nn.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
yy = relay.ir_pass.infer_type(z)
assert yy.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
ref_res = topi.testing.lrn_python(x_data, size, axis, bias, alpha, beta)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def test_l2_normalize():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", shape=(n, c , h, w))
@@ -328,6 +370,26 @@ def test_l2_normalize():
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c , h, w))
shape = (1, 5, 10, 10)
dtype = "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
eps = 0.001
axis = 1
z = relay.nn.l2_normalize(x, eps=eps, axis=[axis])
yy = relay.ir_pass.infer_type(z)
assert yy.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
ref_res = topi.testing.l2_normalize_python(x_data, eps, axis)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
if __name__ == "__main__":
test_pool2d()
...
@@ -4,6 +4,7 @@ import tvm
import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.testing import ctx_list
from nose.tools import raises
def test_zeros_ones():
@@ -214,6 +215,25 @@ def test_infer_type_leaky_relu():
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
shape = (1, 5, 10, 10)
dtype = "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
z = relay.nn.leaky_relu(x, alpha=0.1)
assert "alpha=0.1" in z.astext()
yy = relay.ir_pass.infer_type(z)
assert yy.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
ref_res = np.where(x_data > 0, x_data, x_data * 0.1)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
x = relay.var("data", relay.TensorType(data, dtype))
if alpha:
@@ -230,6 +250,27 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
alpha_shape = (data[axis],)
assert zz.args[1].checked_type == relay.TensorType(alpha_shape, "float32")
if all(isinstance(v, tvm.expr.Var) for v in data) or not alpha:
return
func = relay.Function([x, y], z)
x_data = np.random.uniform(low=-1, high=1, size=data).astype(dtype)
a_data = np.random.uniform(low=-1, high=1, size=alpha).astype(dtype)
if axis == 1:
ref_res = (x_data < 0) * (x_data * a_data.reshape(3, 1, 1)) + (x_data>=0) * x_data
else:
ref_res = (x_data < 0) * (x_data * a_data.reshape(1, 1, 3)) + (x_data>=0) * x_data
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data, a_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data, a_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def test_infer_type_prelu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w))
...
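Note that the new execution check in verify_infer_type_prelu only runs for fully concrete shapes; the isinstance(v, tvm.expr.Var) guard skips it for the symbolic call above. A hypothetical concrete-shape call (values chosen here purely for illustration, matching the 3-channel reshape in the reference computation) would exercise the new prelu FTVMCompute as well:

verify_infer_type_prelu((1, 3, 10, 10), (3,), 1, (1, 3, 10, 10))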