Commit ab0d1862 by Tatsuya Nishiyama Committed by Tianqi Chen

[NNVM] Move FTVMCompute registration of the elementwise operator to c++ (#1351)

parent b154e6b9
......@@ -182,69 +182,30 @@ reg.register_pattern("clip", OpPattern.ELEMWISE)
reg.register_schedule("clip", _fschedule_elemwise)
# elemwise sum
@reg.register_compute("elemwise_sum")
def compute_elemwise_sum(attrs, inputs, _):
    """Element-wise sum of every input tensor; validates the input count."""
    expected = attrs.get_int("num_args")
    assert expected == len(inputs), "Number of tensors does not match num_args."
    return topi.tensor.elemwise_sum(inputs)
reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE)
reg.register_schedule("elemwise_sum", _fschedule_elemwise)
# full
@reg.register_compute("full")
def compute_full(attrs, inputs, _):
    """Build a constant tensor of the requested shape/dtype filled with fill_value."""
    out_shape = attrs.get_int_tuple("shape")
    out_dtype = attrs.get_string("dtype")
    value = attrs.get_float("fill_value")
    return topi.tensor.full(out_shape, out_dtype, value)
reg.register_pattern("full", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("full", _fschedule_elemwise)
# full_like
@reg.register_compute("full_like")
def compute_full_like(attrs, inputs, _):
    """Build a tensor shaped like the first input, filled with fill_value."""
    reference = inputs[0]
    value = attrs.get_float("fill_value")
    return topi.tensor.full_like(reference, value)
reg.register_pattern("full_like", OpPattern.ELEMWISE)
reg.register_schedule("full_like", _fschedule_elemwise)
# zeros
@reg.register_compute("zeros")
def compute_zeros(attrs, inputs, _):
    """Build a tensor of the requested shape/dtype filled with 0."""
    out_shape = attrs.get_int_tuple("shape")
    out_dtype = attrs.get_string("dtype")
    return topi.tensor.full(out_shape, out_dtype, 0)
reg.register_pattern("zeros", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("zeros", _fschedule_elemwise)
# zeros_like
@reg.register_compute("zeros_like")
def compute_zeros_like(_, inputs, out_info):
    """Build a tensor of zeros shaped like the first input."""
    reference = inputs[0]
    return topi.tensor.full_like(reference, 0)
reg.register_pattern("zeros_like", OpPattern.ELEMWISE)
reg.register_schedule("zeros_like", _fschedule_elemwise)
# ones
@reg.register_compute("ones")
def compute_ones(attrs, inputs, _):
    """Compute definition of ones.

    Builds a tensor of the shape/dtype given in the op attributes, filled
    with 1 — the exact mirror of compute_zeros. The stray commented-out
    debug line (``#tvm.tensor.Tensor()``) served no purpose and was removed.
    """
    shape = attrs.get_int_tuple("shape")
    dtype = attrs.get_string("dtype")
    return topi.tensor.full(shape, dtype, 1)
reg.register_pattern("ones", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("ones", _fschedule_elemwise)
# ones_like
@reg.register_compute("ones_like")
def compute_ones_like(_, inputs, out_info):
    """Build a tensor of ones shaped like the first input."""
    reference = inputs[0]
    return topi.tensor.full_like(reference, 1)
reg.register_pattern("ones_like", OpPattern.ELEMWISE)
reg.register_schedule("ones_like", _fschedule_elemwise)
......
......@@ -7,6 +7,7 @@
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include <cmath>
#include "../op_common.h"
......@@ -14,6 +15,7 @@
#include "topi/broadcast.h"
#include "topi/elemwise.h"
#include "topi/tags.h"
#include "../../compiler/compile_engine.h"
namespace nnvm {
namespace top {
......@@ -382,6 +384,16 @@ NNVM_REGISTER_INIT_OP(full)
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
// FTVMCompute for "full": emit a constant tensor of the attrs' shape/dtype,
// filled with the scalar `fill_value`. `inputs`/`out_info` are unused — all
// information comes from the parsed InitOpWithScalarParam.
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const InitOpWithScalarParam& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
Type dtype = GetTVMType(param.dtype);
// Cast the double-typed fill_value to the target dtype before filling.
Expr fill_value = tvm::make_const(dtype, param.fill_value);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
.set_support_level(4);
NNVM_REGISTER_INIT_OP(zeros)
......@@ -395,6 +407,16 @@ NNVM_REGISTER_INIT_OP(zeros)
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
// FTVMCompute for "zeros": same shape/dtype handling as "full", with the
// fill value fixed at 0 (so it uses InitOpParam, which has no fill_value).
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
Type dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, 0);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
.set_support_level(4);
NNVM_REGISTER_INIT_OP(ones)
......@@ -408,6 +430,16 @@ NNVM_REGISTER_INIT_OP(ones)
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
// FTVMCompute for "ones": identical to "zeros" except the constant is 1.
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
Type dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, 1);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
.set_support_level(4);
// full_like
......@@ -419,6 +451,14 @@ as the input array
.add_arguments(FillValueParam::__FIELDS__())
.set_attr_parser(ParamParser<FillValueParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>)
// FTVMCompute for "full_like": fill a tensor shaped like inputs[0] with the
// scalar fill_value, cast to the inferred output dtype (out_info[0]).
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const FillValueParam& param = nnvm::get<FillValueParam>(attrs.parsed);
const Expr fill_value = tvm::make_const(out_info[0]->dtype, param.fill_value);
return Array<Tensor> { topi::full_like(inputs[0], fill_value) };
})
.set_support_level(4);
NNVM_REGISTER_INIT_LIKE_OP(zeros_like)
......@@ -426,6 +466,13 @@ NNVM_REGISTER_INIT_LIKE_OP(zeros_like)
as the input array.
)code")
// FTVMCompute for "zeros_like": a 0-filled tensor with inputs[0]'s shape and
// the inferred output dtype.
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor> { topi::full_like(inputs[0],
tvm::make_const(out_info[0]->dtype, 0)) };
})
.set_support_level(4);
NNVM_REGISTER_INIT_LIKE_OP(ones_like)
......@@ -433,6 +480,13 @@ NNVM_REGISTER_INIT_LIKE_OP(ones_like)
as the input array.
)code")
// FTVMCompute for "ones_like": a 1-filled tensor with inputs[0]'s shape and
// the inferred output dtype.
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor> { topi::full_like(inputs[0],
tvm::make_const(out_info[0]->dtype, 1)) };
})
.set_support_level(4);
// unary scalar op
......@@ -684,6 +738,14 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum)
.describe(R"code(Adds all input arguments element-wise.
)code" NNVM_ADD_FILELINE)
// FTVMCompute for "elemwise_sum": element-wise sum of all input tensors.
// Validates that the number of supplied tensors matches the `num_args`
// attribute before delegating to topi.
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ElementWiseReduceParam& param = nnvm::get<ElementWiseReduceParam>(attrs.parsed);
    // Fix: the previous diagnostic was a leaked Python docstring
    // ("""Compute definition of elemwise sum""" — adjacent-literal
    // concatenation made it compile but the message was meaningless).
    // Use the real message from the removed Python assert instead.
    CHECK_EQ(param.num_args, inputs.size())
        << "Number of tensors does not match num_args.";
    return Array<Tensor>{ topi::elemwise_sum(inputs) };
  })
.set_attr<nnvm::FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment