Unverified commit b9038343 by Tianqi Chen, committed by GitHub

[RELAY][OP] Move computes to cxx, enable concat as injective (#2166)

parent f8f06595
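
The diff below is easier to follow with the end state in mind. As a rough sketch (not part of the commit), this is the kind of Relay expression the change targets: with concatenate now registered as an injective op, a concatenate followed by an elementwise add can be fused into one primitive function whose input is a tuple of tensors.

    from tvm import relay

    # Same pattern as the updated test at the bottom of this diff.
    x = relay.var("x", shape=(10, 5))
    y = relay.var("y", shape=(10, 5))
    t = relay.var("t", shape=())
    z = relay.add(relay.concatenate((x, y), axis=1), t)
    func = relay.Function([x, y, t], z)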
@@ -236,18 +236,14 @@ class GraphRuntimeCodegen(ExprFunctor):
         self.lowered_funcs.add(loweredf)
         inputs = []
-        tuple_arg_count = 0
+        # flatten tuple in the call.
         for arg in call.args:
+            res = self.visit(arg)
             if isinstance(arg.checked_type, TupleType):
-                tuple_arg_count += 1
-            inputs.append(self.visit(arg))
-        # We need to specially handle tuple inputs and
-        # tuple output cases.
-        # Tuple input function(e.g. concat)
-        if tuple_arg_count:
-            assert len(call.args) == 1
-            assert isinstance(inputs[0], tuple)
-            inputs = list(inputs[0])
+                assert isinstance(res, tuple)
+                inputs += res
+            else:
+                inputs.append(res)
         inputs = [x.to_json() for x in inputs]
         op_name = cached_func.func_name
......
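
For readers skimming the Python hunk above: the new loop replaces the old single-tuple special case with uniform flattening. A standalone sketch of that rule, with hypothetical names (the real method also lowers each argument via self.visit):

    def flatten_call_inputs(visited_args):
        # visited_args holds whatever visit() returned per call argument:
        # a single graph node, or a tuple of nodes for a tuple-typed argument.
        inputs = []
        for res in visited_args:
            if isinstance(res, tuple):
                inputs += list(res)   # splice tuple fields into the flat list
            else:
                inputs.append(res)
        return inputs

    assert flatten_call_inputs(["a", ("b", "c"), "d"]) == ["a", "b", "c", "d"]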
@@ -589,11 +589,11 @@ def from_mxnet(symbol,
         shape, dtype = _update_shape_dtype(shape, dtype, params)
         sym = _from_mxnet_impl(symbol, shape, dtype)
     elif isinstance(symbol, mx.gluon.HybridBlock):
-        if args_params is not None or aux_params is not None:
+        if arg_params is not None or aux_params is not None:
             raise ValueError("arg_params and aux_params ae not used when importing HybridBlock")
         params = {}
         for k, v in symbol.collect_params().items():
-            params[k] = tvm.nd.array(v.data().asnumpy())
+            params[k] = _nd.array(v.data().asnumpy())
         data = mx.sym.Variable("data")
         sym = symbol(data)
         shape, dtype = _update_shape_dtype(shape, dtype, params)
......
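
Usage sketch for the Gluon path fixed above, assuming this version's from_mxnet return convention of an expression plus a params dict (the model choice and input shape are illustrative):

    import mxnet as mx
    from tvm import relay

    block = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True)
    # arg_params/aux_params must stay None on the HybridBlock path; the block's
    # collected parameters are converted to TVM NDArrays internally.
    net, params = relay.frontend.from_mxnet(block, shape={"data": (1, 3, 224, 224)})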
@@ -5,223 +5,37 @@ import topi
 from .op import register_compute, register_schedule, register_pattern
 from .op import schedule_injective, OpPattern

 schedule_broadcast = schedule_injective
 schedule_elemwise = schedule_injective

-# log
-@register_compute("log")
-def log_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.log(inputs[0])]
 register_schedule("log", schedule_broadcast)

-# exp
-@register_compute("exp")
-def exp_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.exp(inputs[0])]
 register_schedule("exp", schedule_broadcast)

-# sqrt
-@register_compute("sqrt")
-def sqrt_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.sqrt(inputs[0])]
 register_schedule("sqrt", schedule_broadcast)

-# sigmoid
-@register_compute("sigmoid")
-def sigmoid_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.sigmoid(inputs[0])]
 register_schedule("sigmoid", schedule_broadcast)

-# floor
-@register_compute("floor")
-def floor_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.floor(inputs[0])]
 register_schedule("floor", schedule_broadcast)

-# ceil
-@register_compute("ceil")
-def ceil_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.ceil(inputs[0])]
 register_schedule("ceil", schedule_broadcast)

-# trunc
-@register_compute("trunc")
-def trunc_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.trunc(inputs[0])]
 register_schedule("trunc", schedule_broadcast)

-# round
-@register_compute("round")
-def round_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.round(inputs[0])]
 register_schedule("round", schedule_broadcast)

-# abs
-@register_compute("abs")
-def abs_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.abs(inputs[0])]
 register_schedule("abs", schedule_broadcast)

-# tanh
-@register_compute("tanh")
-def tanh_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.tanh(inputs[0])]
 register_schedule("tanh", schedule_broadcast)

-# negative
-@register_compute("negative")
-def negative_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 1
-    return [topi.negative(inputs[0])]
 register_schedule("negative", schedule_broadcast)

-# add
-@register_compute("add")
-def add_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.add(inputs[0], inputs[1])]
-register_schedule("add", schedule_injective)
+register_schedule("add", schedule_broadcast)

-# subtract
-@register_compute("subtract")
-def subtract_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.subtract(inputs[0], inputs[1])]
 register_schedule("subtract", schedule_broadcast)

-# multiply
-@register_compute("multiply")
-def multiply_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.multiply(inputs[0], inputs[1])]
 register_schedule("multiply", schedule_broadcast)

-# divide
-@register_compute("divide")
-def divide_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.divide(inputs[0], inputs[1])]
 register_schedule("divide", schedule_broadcast)

-# power
-@register_compute("power")
-def power_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.power(inputs[0], inputs[1])]
 register_schedule("power", schedule_injective)

-# mod
-@register_compute("mod")
-def mod_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.mod(inputs[0], inputs[1])]
 register_schedule("mod", schedule_broadcast)

-# equal
-@register_compute("equal")
-def equal_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.equal(inputs[0], inputs[1])]
 register_schedule("equal", schedule_broadcast)

-# not_equal
-@register_compute("not_equal")
-def not_equal_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.not_equal(inputs[0], inputs[1])]
 register_schedule("not_equal", schedule_broadcast)

-# less
-@register_compute("less")
-def less_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.less(inputs[0], inputs[1])]
 register_schedule("less", schedule_broadcast)

-# less equal
-@register_compute("less_equal")
-def less_equal_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.less_equal(inputs[0], inputs[1])]
 register_schedule("less_equal", schedule_broadcast)

-# greater
-@register_compute("greater")
-def greater_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.greater(inputs[0], inputs[1])]
 register_schedule("greater", schedule_broadcast)

-# greater equal
-@register_compute("greater_equal")
-def greater_equal_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.greater_equal(inputs[0], inputs[1])]
 register_schedule("greater_equal", schedule_broadcast)

-# maximum
-@register_compute("maximum")
-def maximum_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.maximum(inputs[0], inputs[1])]
 register_schedule("maximum_compute", schedule_injective)

-# minimum
-@register_compute("minimum")
-def minimum_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.minimum(inputs[0], inputs[1])]
 register_schedule("minimum", schedule_injective)

-# right shift
-@register_compute("right_shift")
-def right_shift_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.right_shift(inputs[0], inputs[1])]
 register_schedule("right_shift", schedule_injective)

-# left shift
-@register_compute("left_shift")
-def left_shift_compute(attrs, inputs, output_type, target):
-    assert len(inputs) == 2
-    return [topi.left_shift(inputs[0], inputs[1])]
 register_schedule("left_shift", schedule_injective)

 # zeros
@@ -273,5 +87,4 @@ def concatenate_compute(attrs, inputs, output_type, target):
     return [topi.concatenate(inputs, axis=attrs.axis)]

 register_schedule("concatenate", schedule_injective)
-# TODO(tqchen): renable concat as injective
-register_pattern("concatenate", OpPattern.OPAQUE)
+register_pattern("concatenate", OpPattern.INJECTIVE)
@@ -56,30 +56,26 @@ class ScheduleGetter :
         Op::GetAttr<FTVMSchedule>("FTVMSchedule");
     auto cache_node = make_node<CachedFuncNode>();
     cache_node->target = target_;
-    if (prim_func->params.size() == 1 &&
-        prim_func->params[0]->checked_type().as<TupleTypeNode>()) {
-      // Handle tuple input type by flattening them.
-      // This is the current calling convention of tuple input.
-      Array<tvm::Tensor> inputs;
-      for (Type field : prim_func->params[0]->type_as<TupleTypeNode>()->fields) {
-        const auto* ttype = field.as<TensorTypeNode>();
-        CHECK(ttype != nullptr);
-        tvm::Tensor tensor = tvm::placeholder(
-            GetShape(ttype->shape), ttype->dtype);
-        cache_node->inputs.push_back(tensor);
-        inputs.push_back(tensor);
-      }
-      memo_[prim_func->params[0]] = inputs;
-    } else {
-      for (Var param : prim_func->params) {
-        const auto* ttype = param->type_as<TensorTypeNode>();
-        tvm::Tensor tensor = tvm::placeholder(
-            GetShape(ttype->shape), ttype->dtype);
-        cache_node->inputs.push_back(tensor);
-        memo_[param] = Array<Tensor>({tensor});
-      }
-    }
+    for (Var param : prim_func->params) {
+      Array<tvm::Tensor> inputs;
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        tvm::Tensor tensor = tvm::placeholder(
+            GetShape(ttype->shape), ttype->dtype);
+        cache_node->inputs.push_back(tensor);
+        inputs.push_back(tensor);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          CHECK(ttype != nullptr);
+          tvm::Tensor tensor = tvm::placeholder(
+              GetShape(ttype->shape), ttype->dtype);
+          cache_node->inputs.push_back(tensor);
+          inputs.push_back(tensor);
+        }
+      }
+      memo_[param] = inputs;
+    }
     readable_name_stream_ << "fused";
     cache_node->outputs = this->VisitExpr(prim_func->body);
@@ -161,8 +157,9 @@ class ScheduleGetter :
     int op_pattern = fpattern[op];
     if (op_pattern >= kCommReduce) {
-      CHECK(!master_op_.defined())
-          << "Two complicated op in a primitive function";
+      CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce)
+          << "Two complicated op in a primitive function "
+          << " master=" << master_op_ << " current=" << op;
     }
     if (op_pattern >= master_op_patetrn_) {
       master_op_ = op;
......
@@ -212,7 +212,7 @@ class Interpreter :
     // Marshal the arguments.
     // Handle tuple input/output by flattening them.
     size_t arg_len = 0;
-    for (size_t i = 0; i < args.size(); i++) {
+    for (size_t i = 0; i < args.size(); ++i) {
       if (args[i].as<TensorValueNode>()) {
         ++arg_len;
       } else {
@@ -242,22 +242,19 @@ class Interpreter :
           << context_ << ", but get " << arg_ctx;
     };
-    if (func->params.size() == 1 &&
-        func->params[0]->checked_type().as<TupleTypeNode>()) {
-      // handle tuple input.
-      const TupleValueNode* tuple = args[0].as<TupleValueNode>();
-      CHECK(tuple);
-      for (size_t i = 0; i < tuple->fields.size(); ++i) {
-        fset_input(i, tuple->fields[i]);
-      }
-    } else {
-      CHECK_EQ(num_inputs, args.size());
-      // Decide the target context.
-      // Primitive functions always sit in the same context.
-      for (size_t i = 0; i < args.size(); i++) {
-        fset_input(i, args[i]);
-      }
+    int arg_counter = 0;
+    for (Value arg : args) {
+      if (arg.as<TensorValueNode>()) {
+        fset_input(arg_counter++, arg);
+      } else {
+        const TupleValueNode* tuple = arg.as<TupleValueNode>();
+        CHECK(tuple != nullptr);
+        for (size_t i = 0; i < tuple->fields.size(); ++i) {
+          fset_input(arg_counter++, tuple->fields[i]);
+        }
+      }
     }
     // TVM's calling convention is that the final argument is the output
     // buffer. To preserve the illusion of being a functional language
     // we need to allocate space for the output buffer based on the
......
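
From the user side the interpreter change above is invisible: the "debug" executor now walks the argument list and flattens any tuple-valued argument before the packed-function call, so a fused concatenate is invoked like any other primitive. A minimal sketch (shapes are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    y = relay.var("y", shape=(2, 2))
    func = relay.Function([x, y], relay.concatenate((x, y), axis=0))
    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    res = intrp.evaluate(func)(np.ones((2, 2), "float32"),
                               np.zeros((2, 2), "float32"))
    print(res.asnumpy().shape)  # (4, 2)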
@@ -5,54 +5,75 @@
  */
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
+#include <topi/broadcast.h>
 #include "../type_relations.h"
 #include "../op_common.h"

 namespace tvm {
 namespace relay {

+#define RELAY_BINARY_COMPUTE(FTOPI)                        \
+  [] (const Attrs& attrs,                                  \
+      const Array<Tensor>& inputs,                         \
+      const Type& out_type,                                \
+      const Target& target) -> Array<Tensor> {             \
+    CHECK_EQ(inputs.size(), 2U);                           \
+    return {FTOPI(inputs[0], inputs[1])};                  \
+  }                                                        \
+
 // Addition
 RELAY_REGISTER_BINARY_OP("relay.op._make.", "add")
 .describe("Elementwise add with with broadcasting")
-.set_support_level(1);
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));

 // Subtraction
 RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract")
 .describe("Elementwise substract with broadcasting")
-.set_support_level(1);
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));

 // Right shift
 RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift")
 .describe("Elementwise right shift with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));

 RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift")
 .describe("Elementwise left shift with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));

 RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum")
 .describe("Elementwise maximum of two tensors with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));

 RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum")
 .describe("Elementwise minimum of two tensors with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));

 RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide")
 .describe("Elementwise divide with broadcasting")
-.set_support_level(1);
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));

 RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
 .describe("Elementwise multiply with broadcasting")
-.set_support_level(1);
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));

 RELAY_REGISTER_BINARY_OP("relay.op._make.", "power")
 .describe("Elementwise power with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power));

 RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
 .describe("Elementwise mod with broadcasting")
-.set_support_level(1);
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));

 // Comparisons
 #define RELAY_REGISTER_CMP_OP(OpName)                     \
@@ -70,22 +91,38 @@ RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
 RELAY_REGISTER_CMP_OP("equal")
 .describe("Elementwise equal compare with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal));

 RELAY_REGISTER_CMP_OP("not_equal")
 .describe("Elementwise not equal with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal));

 RELAY_REGISTER_CMP_OP("less")
 .describe("Elementwise less than with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less));

 RELAY_REGISTER_CMP_OP("less_equal")
 .describe("Elementwise less than or equal compare with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal));

 RELAY_REGISTER_CMP_OP("greater")
 .describe("Elementwise greater than compare with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater));

 RELAY_REGISTER_CMP_OP("greater_equal")
 .describe("Elementwise greater than or equal compare with broadcasting")
-.set_support_level(4);
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal));

 }  // namespace relay
 }  // namespace tvm
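
With FTVMCompute now attached on the C++ side for the binary and comparison ops above, an expression built from them compiles without any Python-side compute registration. A minimal sketch on the graph executor (values are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4,), dtype="float32")
    y = relay.var("y", shape=(4,), dtype="float32")
    func = relay.Function([x, y], relay.less(x, y))
    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    out = intrp.evaluate(func)(np.arange(4, dtype="float32"),
                               np.full(4, 2.0, dtype="float32"))
    print(out.asnumpy())  # [ True  True False False]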
@@ -5,12 +5,21 @@
  */
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
+#include <topi/elemwise.h>
 #include "../type_relations.h"
 #include "../op_common.h"

 namespace tvm {
 namespace relay {

+#define RELAY_UNARY_COMPUTE(FTOPI)                         \
+  [] (const Attrs& attrs,                                  \
+      const Array<Tensor>& inputs,                         \
+      const Type& out_type,                                \
+      const Target& target) -> Array<Tensor> {             \
+    return {FTOPI(inputs[0])};                             \
+  }                                                        \
+
 RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
 .describe(R"code(Returns the log input array, computed element-wise.
@@ -20,7 +29,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
 )code" TVM_ADD_FILELINE)
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
 .describe(R"code(Returns the exp input array, computed element-wise.
@@ -30,7 +41,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
 )code" TVM_ADD_FILELINE)
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt")
@@ -41,7 +53,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt")
 )code" TVM_ADD_FILELINE)
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like")
 .describe(R"code(Returns an array of zeros, with same type and shape as the input.
@@ -49,6 +63,7 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like")
 .set_support_level(1)
 .add_type_rel("Identity", IdentityRel);

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like")
 .describe(R"code(Returns an array of ones, with same type and shape as the input.
 )code" TVM_ADD_FILELINE)
@@ -63,13 +78,17 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid")
 )code" TVM_ADD_FILELINE)
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy")
 .describe(R"code(Copy a tensor.
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));

 // Clip
 struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
@@ -107,7 +126,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor")
 .describe(R"code(Returns the floor of input array, computed element-wise.
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
 .describe(R"code(Returns the ceil of input array, computed element-wise.
@@ -117,7 +138,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
 .describe(R"code(Returns the trunc of input array, computed element-wise.
@@ -127,7 +150,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
 .describe(R"code(Returns the round of input array, computed element-wise.
@@ -137,7 +162,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
 .describe(R"code(Returns the abs of input array, computed element-wise.
@@ -147,7 +174,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
 .describe(R"code(Returns the tanh of input array, computed element-wise.
@@ -157,7 +186,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
 )code" TVM_ADD_FILELINE)
 .set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));

 RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
 .describe(R"code(Returns the numeric negative of input array, computed element-wise.
@@ -167,7 +198,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
-.add_type_rel("Identity", IdentityRel);
+.add_type_rel("Identity", IdentityRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));

 }  // namespace relay
 }  // namespace tvm
@@ -188,20 +188,22 @@ def test_concatenate():
     x = relay.var("x", shape=(10, 5))
     y = relay.var("y", shape=(10, 5))
+    t = relay.var("z", shape=())
     z = relay.concatenate((x, y), axis=1)
+    z = relay.add(z, t)
     # Check result.
-    func = relay.Function([x, y], z)
+    func = relay.Function([x, y, t], z)
     x_data = np.random.rand(10, 5).astype('float32')
     y_data = np.random.rand(10, 5).astype('float32')
-    ref_res = np.concatenate((x_data, y_data), axis=1)
+    t_data = np.random.uniform(size=()).astype('float32')
+    ref_res = np.concatenate((x_data, y_data), axis=1) + t_data

     for target, ctx in ctx_list():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
         intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
-        op_res1 = intrp1.evaluate(func)(x_data, y_data)
+        op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data)
         tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01)
-        op_res2 = intrp2.evaluate(func)(x_data, y_data)
+        op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data)
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01)


 def test_dropout():
@@ -306,11 +308,11 @@ def test_dense():

 if __name__ == "__main__":
+    test_concatenate()
     test_bias_add()
     test_unary_op()
     test_binary_op()
     test_expand_dims_infer_type()
-    test_concatenate()
     test_expand_dims()
     test_softmax()
     test_log_softmax()
......