/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file alter_op_layout.cc
 * \brief Alter the operator layouts. Keep inferred layouts (if any) from previous stages.
 *        e.g., convolution may compute faster with an NCHW16c layout.
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/layout.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/pass_functions.h>
#include <tvm/operation.h>
#include <algorithm>
#include <functional>
#include <limits>
#include <unordered_map>
#include <vector>
#include "compile_engine.h"
#include "graph_transform.h"

namespace nnvm {
namespace compiler {
namespace {

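// Build placeholder tensors describing each output of node `nid`, using the
// shapes and dtypes that were inferred on the original graph.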
tvm::Array<tvm::Tensor> GetTensorInfo(const IndexedGraph& idx_graph,
                                      const uint32_t nid,
                                      const ShapeVector& shape_vec,
                                      const DTypeVector& dtype_vec) {
  tvm::Array<tvm::Tensor> vec;
  for (uint32_t i = 0; i < idx_graph[nid].source->num_outputs(); ++i) {
    tvm::Array<tvm::Expr> shape;
    for (int64_t x : shape_vec[idx_graph.entry_id(nid, i)]) {
      CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
      shape.push_back(tvm::make_const(tvm::Int(32), x));
    }
    vec.push_back(tvm::placeholder(
      shape, GetTVMType(dtype_vec[idx_graph.entry_id(nid, i)])));
  }
  return vec;
}
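
// For reference, a minimal sketch of how an operator could opt into this pass
// by registering FTVMAlterOpLayout (the operator and body are illustrative,
// not part of this file); the signature matches the call made in the
// transform lambda below:
//
//   NNVM_REGISTER_OP(conv2d)
//   .set_attr<nnvm::compiler::FTVMAlterOpLayout>(
//     "FTVMAlterOpLayout",
//     [](const nnvm::NodeAttrs& attrs,
//        const nnvm::Symbol& inputs,
//        const tvm::Array<tvm::Tensor>& tinfos,
//        nnvm::Symbol* ret) {
//       // fill *ret with a replacement subgraph and return true,
//       // or return false to keep the original operator.
//       return false;
//     });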

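// Replace operators that register FTVMAlterOpLayout with their preferred
// variants, and keep the recorded layouts of unchanged nodes so that the
// later LayoutTransform pass can insert the correct layout conversions.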
Graph AlterOpLayout(const Graph& src) {
  static auto& falter_op_layout =
    Op::GetAttr<nnvm::compiler::FTVMAlterOpLayout>("FTVMAlterOpLayout");

  const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
  const DTypeVector& dtype_vec = src.GetAttr<DTypeVector>("dtype");
  const IndexedGraph& idx_graph = src.indexed_graph();

  std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
  std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
  // nodes that were not altered, keyed to their original node id,
  // so that their recorded layouts can be restored after the transform.
  std::unordered_map<const Node*, uint32_t> unchanged_nodes;

  if (src.HasAttr("layout")) {
    // record layouts so that the LayoutTransform pass can fix layouts correctly,
    // e.g., conv2d may be replaced by a contrib implementation
    // whose layout differs from the original one
    // (which was imported from a model file).
    const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
    for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
      const auto& inode = idx_graph[nid];
      // record input layouts for all nodes;
      // replaced nodes will ignore these records and have undefined input layouts.
      std::vector<Layout> in_layout;
      for (const auto& e : inode.inputs) {
        in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
      }
      in_layouts_of_node[nid] = in_layout;

      std::vector<Layout> out_layout;
      for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
        out_layout.emplace_back(layouts[idx_graph.entry_id(nid, i)]);
      }
      out_layouts_of_node[nid] = out_layout;
    }
  }

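  // GraphTransform callback: returns true and sets *ret to the replacement
  // outputs when the operator is altered, or returns false (and remembers
  // the node) to keep it unchanged.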
  auto transform = [&](uint32_t nid,
                       const NodePtr& n,
                       std::vector<NodeEntry>* ret) {
    nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
      falter_op_layout.get(n->op(), nullptr);
    if (fn_alter_op_layout == nullptr) {
      // will restore the original input layouts later.
      unchanged_nodes[n.get()] = nid;
      return false;
    }

    // construct parameters for registered function
    std::vector<Symbol> op_inputs;
    tvm::Array<tvm::Tensor> tensor_infos;
    CHECK_EQ(n->num_inputs(), idx_graph[nid].inputs.size());
    for (uint32_t i = 0; i < n->num_inputs(); ++i) {
      const nnvm::NodeEntry& input = n->inputs[i];
      // input operator
      Symbol op_input;
      op_input.outputs.push_back(input);
      op_inputs.push_back(op_input);

      // input tinfo, extracted from the original graph,
      // since that is where infer_shape & infer_type were applied.
      tvm::Array<tvm::Tensor> op_output_tinfos =
        GetTensorInfo(idx_graph, idx_graph[nid].inputs[i].node_id,
                      shape_vec, dtype_vec);
      tensor_infos.push_back(op_output_tinfos[input.index]);
    }
    // invoke the registered function to get a new operator.
    Symbol op;
    bool do_alter =
      fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);

    if (do_alter) {
      *ret = op.outputs;
    } else {
      // will restore the original input layouts later.
      unchanged_nodes[n.get()] = nid;
    }
    return do_alter;
  };

  Graph ret = nnvm::compiler::GraphTransform(src, transform);

  if (src.HasAttr("layout")) {
    // restore the recorded layouts in the returned graph
    const auto& ret_idx = ret.indexed_graph();
    std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
    for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
      const auto& inode = ret_idx[nid];
      if (unchanged_nodes.count(inode.source)) {
        const std::vector<Layout>& in_layouts =
          in_layouts_of_node[unchanged_nodes[inode.source]];
        for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
          const auto& e = inode.inputs[i];
          ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
        }
        const std::vector<Layout>& out_layouts =
          out_layouts_of_node[unchanged_nodes[inode.source]];
        for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
          ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
        }
      }
    }

    // cannot return the original Graph after indexed_graph() has been called on it,
    // thus create a new one.
    nnvm::Graph new_ret;
    new_ret.outputs = ret.outputs;
    new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
    return new_ret;
  }

  return ret;
}

// register pass
NNVM_REGISTER_PASS(AlterOpLayout)
.set_body(AlterOpLayout)
.set_change_graph(true);
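
// A minimal usage sketch (noting that the graph must already carry the
// "shape" and "dtype" attributes this pass reads):
//
//   nnvm::Graph g = ...;
//   g = nnvm::ApplyPass(std::move(g), "AlterOpLayout");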

}  // namespace
}  // namespace compiler
}  // namespace nnvm