Commit 0f053c82 by Josh Pollock Committed by Tianqi Chen

[Relay][Op] Clip (#1844)

parent 4d05fd96
......@@ -515,6 +515,35 @@ def ones_like(data):
"""
return _make.ones_like(data)
def clip(a, a_min, a_max):
"""Clip the elements in `a` between `a_min` and `a_max`.
`a_min` and `a_max` are cast to `a`'s dtype.
Parameters
----------
a : relay.Expr
The input tensor.
a_min : float
The clip minimum.
a_max : float
The clip maximum.
Returns
-------
result : relay.Expr
`a` with elements clipped between `a_min` and `a_max`.
Examples
--------
.. code:: python
x = relay.Constant(tvm.nd.array([0, 1, 5, 3, 4, 2]))
relay.clip(x, 1., 4.)
# [1, 1, 4, 3, 4, 2]
"""
return _make.clip(a, a_min, a_max)
def concatenate(data, axis):
"""Concatenate the input tensors along the given axis.
......
......@@ -87,6 +87,37 @@ RELAY_REGISTER_UNARY_OP("copy")
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
// Clip
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min;
double a_max;
TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
}
};
TVM_REGISTER_API("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
auto attrs = make_node<ClipAttrs>();
attrs->a_min = a_min;
attrs->a_max = a_max;
static const Op& op = Op::Get("clip");
return CallNode::make(op, {a}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("clip")
.describe(R"code(Clip tensor values.
This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("tensor", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Clip", IdentityRel);
RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
......@@ -153,6 +184,5 @@ RELAY_REGISTER_UNARY_OP("negative")
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);
} // namespace relay
} // namespace tvm
......@@ -19,6 +19,18 @@ def test_unary_identity():
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32")
def test_clip_type():
ib = relay.ir_builder.IRBuilder()
a = ib.param("a", relay.TensorType((10, 4), "float32"))
with ib.function(a) as func:
ib.ret(relay.clip(a.var, 1., 4.))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((10, 4), "float32")
def test_copy_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
......@@ -57,6 +69,7 @@ def test_reshape_infer_type():
assert ftype.ret_type == relay.ty.TensorType(
(n, t, 2000), "float32")
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type
......@@ -78,9 +91,11 @@ def test_single_op():
tvm.relay.round, tvm.relay.abs, tvm.relay.negative]:
check_single_op(opfunc)
if __name__ == "__main__":
test_single_op()
test_unary_identity()
test_clip_type()
test_copy_infer_type()
test_transpose_infer_type()
test_reshape_infer_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment