Unverified Commit 947ff307 by Tianqi Chen Committed by GitHub

[OP] right_shift (#1832)

parent d6ff734b
......@@ -213,6 +213,24 @@ def greater_equal(lhs, rhs):
return _make.greater_equal(lhs, rhs)
def right_shift(lhs, rhs):
"""Right shift with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.right_shift(lhs, rhs)
def concat(*args):
"""Concatenate the input tensors along the zero axis.
......
/*!
* Copyright (c) 2018 by Contributors
* \file binary.cc
* \brief binary broadcast operators.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include "../type_relations.h"
namespace tvm {
namespace relay {
#define RELAY_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel)
// Addition
RELAY_REGISTER_BINARY_OP("add")
.describe("Elementwise add with with broadcasting")
.set_support_level(1);
RELAY_REGISTER_BINARY_OP("subtract")
.describe("Elementwise substract with broadcasting")
.set_support_level(1);
RELAY_REGISTER_BINARY_OP("right_shift")
.describe("Elementwise right shift with broadcasting")
.set_support_level(4);
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName, SupportLevel) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.set_support_level(SupportLevel) \
.add_type_rel("BroadcastComp", BroadcastCompRel);
RELAY_REGISTER_CMP_OP("equal", 4);
RELAY_REGISTER_CMP_OP("not_equal", 4);
RELAY_REGISTER_CMP_OP("less", 4);
RELAY_REGISTER_CMP_OP("less_equal", 4);
RELAY_REGISTER_CMP_OP("greater", 4);
RELAY_REGISTER_CMP_OP("greater_equal", 4);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file elemwise.cc
* \brief Elementwise operators.
* \file unary.cc
* \brief Unary operators.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
......@@ -64,68 +64,6 @@ RELAY_REGISTER_UNARY_OP("sqrt")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
// Addition
TVM_REGISTER_API("relay.op._make.add")
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
});
RELAY_REGISTER_OP("add")
.set_num_inputs(2)
.add_argument("lhs", "Tensor", "The left hand side tensor.")
.add_argument("rhs", "Tensor", "The right hand side tensor.")
.set_support_level(1)
.add_type_rel("Broadcast", BroadcastRel);
// def broadcast(s1, s2):
// ...
//
// input1: Tensor[dtype, s1]
// input2: Tensor[dtype, s2]
// output: Tensor[dtype, broadcast(s1, s2)]
// Addition
TVM_REGISTER_API("relay.op._make.subtract")
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
static const Op& op = Op::Get("subtract");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
});
RELAY_REGISTER_OP("subtract")
.set_num_inputs(2)
.add_argument("lhs", "Tensor", "The left hand side tensor.")
.add_argument("rhs", "Tensor", "The right hand side tensor.")
.set_support_level(1)
.add_type_rel("Broadcast", BroadcastRel);
// def broadcast(s1, s2):
// ...
//
// input1: Tensor[dtype, s1]
// input2: Tensor[dtype, s2]
// output: Tensor[dtype, broadcast(s1, s2)]
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName, SupportLevel) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.set_support_level(SupportLevel) \
.add_type_rel("BroadcastComp", BroadcastCompRel);
RELAY_REGISTER_CMP_OP("equal", 4);
RELAY_REGISTER_CMP_OP("not_equal", 4);
RELAY_REGISTER_CMP_OP("less", 4);
RELAY_REGISTER_CMP_OP("less_equal", 4);
RELAY_REGISTER_CMP_OP("greater", 4);
RELAY_REGISTER_CMP_OP("greater_equal", 4);
// Concat
TVM_REGISTER_API("relay.op._make.concat")
......@@ -135,10 +73,10 @@ TVM_REGISTER_API("relay.op._make.concat")
});
RELAY_REGISTER_OP("concat")
.set_num_inputs(1)
.add_argument("tuple", "Tuple", "The tupled tensor arguments.")
.set_support_level(1)
.add_type_rel("Concat", ConcatRel);
.set_num_inputs(1)
.add_argument("tuple", "Tuple", "The tupled tensor arguments.")
.set_support_level(1)
.add_type_rel("Concat", ConcatRel);
} // namespace relay
} // namespace tvm
......@@ -20,5 +20,19 @@ def test_cmp_type():
assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1")
def test_binary_broadcast():
for op in [relay.right_shift]:
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func:
ib.ret(op(x.var, y.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
if __name__ == "__main__":
test_cmp_type()
test_binary_broadcast()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment