Commit 28acb184 by Tatsuya Nishiyama Committed by Tianqi Chen

[NNVM] Move FTVMCompute registration of cast, greter, less to C++. (#1370)

parent 2f4db1b3
...@@ -53,11 +53,6 @@ reg.register_pattern("copy", OpPattern.ELEMWISE) ...@@ -53,11 +53,6 @@ reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast) reg.register_schedule("copy", _fschedule_broadcast)
# cast # cast
@reg.register_compute("cast")
def compute_cast(attrs, inputs, _):
"""Compute definition of cast"""
dtype = attrs.get_string("dtype")
return topi.cast(inputs[0], dtype)
reg.register_pattern("cast", OpPattern.ELEMWISE) reg.register_pattern("cast", OpPattern.ELEMWISE)
reg.register_schedule("cast", _fschedule_broadcast) reg.register_schedule("cast", _fschedule_broadcast)
...@@ -210,18 +205,10 @@ reg.register_pattern("ones_like", OpPattern.ELEMWISE) ...@@ -210,18 +205,10 @@ reg.register_pattern("ones_like", OpPattern.ELEMWISE)
reg.register_schedule("ones_like", _fschedule_elemwise) reg.register_schedule("ones_like", _fschedule_elemwise)
# greater # greater
@reg.register_compute("greater")
def compute_greater(_, inputs, out_info):
"""Compute definition of greater"""
return topi.greater(inputs[0], inputs[1]).astype('float32')
reg.register_pattern("greater", OpPattern.ELEMWISE) reg.register_pattern("greater", OpPattern.ELEMWISE)
reg.register_schedule("greater", _fschedule_elemwise) reg.register_schedule("greater", _fschedule_elemwise)
# less # less
@reg.register_compute("less")
def compute_less(_, inputs, out_info):
"""Compute definition of less"""
return topi.less(inputs[0], inputs[1]).astype('float32')
reg.register_pattern("less", OpPattern.ELEMWISE) reg.register_pattern("less", OpPattern.ELEMWISE)
reg.register_schedule("less", _fschedule_elemwise) reg.register_schedule("less", _fschedule_elemwise)
......
...@@ -781,6 +781,12 @@ with 1.0 if (left > right), otherwise 0.0 element-wise. ...@@ -781,6 +781,12 @@ with 1.0 if (left > right), otherwise 0.0 element-wise.
.add_argument("rhs", "Tensor", "Second input") .add_argument("rhs", "Tensor", "Second input")
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::greater(inputs[0], inputs[1]), out_info[0]->dtype) };
})
.set_support_level(4); .set_support_level(4);
...@@ -793,6 +799,12 @@ with 1.0 if (left < right), otherwise 0.0 element-wise. ...@@ -793,6 +799,12 @@ with 1.0 if (left < right), otherwise 0.0 element-wise.
.add_argument("rhs", "Tensor", "Second input") .add_argument("rhs", "Tensor", "Second input")
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::cast(topi::less(inputs[0], inputs[1]), out_info[0]->dtype) };
})
.set_support_level(4); .set_support_level(4);
NNVM_REGISTER_INDICATOR_OP(_max_mask) NNVM_REGISTER_INDICATOR_OP(_max_mask)
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
#include "../elemwise_op_common.h" #include "../elemwise_op_common.h"
#include "topi/nn/flatten.h" #include "topi/nn/flatten.h"
#include "topi/transform.h" #include "topi/transform.h"
#include "topi/elemwise.h"
#include "topi/detail/constant_utils.h" #include "topi/detail/constant_utils.h"
#include "../../compiler/compile_engine.h"
namespace nnvm { namespace nnvm {
namespace top { namespace top {
...@@ -413,6 +415,14 @@ NNVM_REGISTER_OP(cast) ...@@ -413,6 +415,14 @@ NNVM_REGISTER_OP(cast)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const CastParam& param = nnvm::get<CastParam>(attrs.parsed);
Type dtype = GetTVMType(param.dtype);
return Array<Tensor>{ topi::cast(inputs[0], dtype) };
})
.set_support_level(1); .set_support_level(1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment