/*!
 *  Copyright (c) 2017 by Contributors
 * \file
 * \brief Operator declarations.
 */
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include "./op_attr_types.h"

namespace tvm {
namespace contrib {

using namespace nnvm;

/*!
 * \brief Shape inference for ops whose inputs and outputs all share one shape.
 *
 * The reference shape is the first already-known shape (ndim() != 0) found
 * by scanning the inputs first and then the outputs. That shape is then
 * written to every input and output entry. Scanning inputs first keeps the
 * common forward case identical to using input 0. Falling back to later
 * inputs and to the outputs lets inference also succeed when input 0 is
 * still unknown.
 *
 * \param attrs  Node attributes. Unused here; required by the FInferShape
 *               signature.
 * \param ishape Input shapes. Every entry is set to the reference shape.
 * \param oshape Output shapes. Every entry is set to the reference shape.
 * \return false if no input or output shape is known yet; true once all
 *         shapes have been assigned.
 *
 * Fails a CHECK if two already-known shapes disagree. Previously the
 * mismatching shape was silently overwritten.
 */
inline bool SameShape(const NodeAttrs& attrs,
                      std::vector<TShape> *ishape,
                      std::vector<TShape> *oshape) {
  // Find the first known shape, inputs before outputs.
  const TShape* known = nullptr;
  for (const std::vector<TShape>* shapes : {ishape, oshape}) {
    for (const TShape& s : *shapes) {
      if (s.ndim() != 0) {
        known = &s;
        break;
      }
    }
    if (known != nullptr) break;
  }
  if (known == nullptr) return false;

  // Copy before writing: `known` points into one of the vectors updated below.
  const TShape ref = *known;
  for (std::vector<TShape>* shapes : {ishape, oshape}) {
    for (TShape& s : *shapes) {
      if (s.ndim() != 0) {
        CHECK(s == ref) << "SameShape: inconsistent shapes, expected "
                        << ref << " but got " << s;
      }
      s = ref;
    }
  }
  return true;
}

// Attribute group shared by the elementwise ops below. Including it attaches
// the kBroadcast op pattern and the SameShape shape-inference function.
NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
.set_attr<FInferShape>("FInferShape", SameShape);

// Two-input addition; both inputs and the output share one shape.
NNVM_REGISTER_OP(__add_symbol__)
.describe("add two data together")
.set_num_inputs(2)
.include("ElementwiseOpAttr");

// Single-input exponential; input and output share one shape.
NNVM_REGISTER_OP(exp)
.describe("Take exp")
.set_num_inputs(1)
.include("ElementwiseOpAttr");

}  // namespace contrib
}  // namespace tvm