state_op.cc 2.09 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
/*!
 *  Copyright (c) 2018 by Contributors
 * \file state_op.cc
 * \brief Experimental operators
 *   Currently we only support assign
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include <topi/elemwise.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

using namespace tvm;
using namespace nnvm::compiler;

// _assign: experimental operator that assigns rhs to lhs.
//
// Inputs:
//   [0] lhs -- the assignment target; declared as a mutated input via
//              FMutateInputs below.
//   [1] rhs -- the value being assigned.
// Outputs:
//   [0] a tensor holding the value of rhs (see FTVMCompute).
NNVM_REGISTER_OP(_assign)
.describe(R"doc(Assign rhs to the lhs.

lhs must be a Variable.
This is an experimental operator.

)doc" NNVM_ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
// Marks input 0 (lhs) as mutated by this operator.
.set_attr<FMutateInputs>(
  "FMutateInputs", [](const NodeAttrs& attrs) {
    return std::vector<uint32_t>{0};
})
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    // This implementation is needed for the special
    // logic handling assign in the compiler.
    // It simply copies the result of rhs to the output.
    // The later decoration in the compiler will change
    // the memory assignment of assign to tie
    // the lhs to the output.
    return Array<Tensor>{ topi::identity(inputs[1]) };
})
// lhs, rhs and the output all share one shape.
.set_attr<FInferShape>("FInferShape", SameShape)
// Propagates lhs's layout (input 0) to rhs (input 1) and to the output.
.set_attr<FCorrectLayout>(
  "FCorrectLayout", [](const NodeAttrs& attrs,
                     std::vector<Layout> *in_layouts,
                     const std::vector<Layout> *last_in_layouts,
                     std::vector<Layout> *out_layouts) {
  NNVM_ASSIGN_LAYOUT(*in_layouts, 1, (*in_layouts)[0]);
  NNVM_ASSIGN_LAYOUT(*out_layouts, 0, (*in_layouts)[0]);
  return true;
})
// Pair {input 1, output 0}: the output may be computed in place in rhs's buffer.
.set_attr<FInplaceOption>(
  "FInplaceOption", [](const NodeAttrs& attrs) {
    return std::vector<std::pair<int, int> >{{1, 0}};
})
// Gradients: lhs receives zeros (its prior value does not reach the output);
// rhs receives the output gradient unchanged, since the forward pass is a copy.
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
    return std::vector<NodeEntry>{
      MakeNode("zeros_like", n->attrs.name + "_zero_grad",
               {n->inputs[0]}),
      ograds[0]
    };
});

}  // namespace top
}  // namespace nnvm