/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/merge_compiler_regions.cc
 *
 * \brief After operators have been annotated with the targets that support
 * them, this pass creates regions of the operators for each target. It
 * is guaranteed that the regions will have a topological ordering so that
 * no data dependency issues exist.
 *
 * This pass only introduces annotations to indicate the regions.
 * partition_graph must subsequently be called to lift these regions out
 * as external functions.
 */

#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../analysis/annotated_region_set.h"

namespace tvm {
namespace relay {
namespace partitioning {

// Cache the compiler_begin and compiler_end annotation ops in file-local
// references. The classes below compare call->op against these cached ops
// (an equivalence check) instead of doing a registry lookup each time.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

/*! \brief Prerequisite step for the region-merging pass.
 *  AnnotateRestDefault puts "default" compiler annotations on every node
 *  that is not annotated already. This ensures that no un-annotated nodes
 *  are left when the MergeCompilerRegions pass is run, because that pass
 *  assumes every node is annotated.
 */
class AnnotateRestDefault : public ExprMutator {
 public:
  explicit AnnotateRestDefault(const Expr& expr) {
65
    regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
66 67 68 69 70 71 72
  }

  Expr Annotate(const Expr& expr) {
    // Its a function that is being passed on to annotate
    func_ = Downcast<Function>(expr);

    // Corner Case CC1 : If the last node does not belong
73
    // to a region node to add a compiler_end
74 75 76 77 78
    auto region = regions_->GetRegion(func_->body);
    auto mutated_expr = this->VisitExpr(expr);
    if (!region.defined()) {
      func_ = Downcast<Function>(mutated_expr);
      // CC1 : add that compiler end after mutation
79 80
      auto body = InsertEnd(func_->body);
      func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs());
81 82 83 84 85 86
      return Downcast<Expr>(func_);
    }
    return mutated_expr;
  }

  /*! \brief This function adds compiler ends to nodes that
87
   * don't belong to a region already (default).
88 89 90
   * \param expr The expression to add a compiler end to.
   * \return expr The expression with or without a compiler end added.
   */
91 92 93 94 95 96 97
  Expr InsertEnd(const Expr& expr) {
    if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance<VarNode>() &&
        !expr->IsInstance<ConstantNode>()) {
      const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
      CHECK(end_op);
      Expr end = (*end_op)(expr, target_);
      return end;
98
    }
99
    return expr;
100 101
  }

102 103 104 105 106 107 108 109 110 111 112
  /*! \brief This function adds compiler begins to nodes that
   * don't belong to a region already (default).
   * \param expr The expression to add a compiler begin to.
   * \return expr The expression with or without a compiler begin added.
   */
  Expr InsertBegin(const Expr& expr) {
    const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
    CHECK(begin_op);
    Expr begin = (*begin_op)(expr, target_);
    annotated_nodes_.insert(begin);
    return begin;
113 114
  }

115 116 117 118
  Expr VisitExpr_(const CallNode* cn) final {
    auto region = regions_->GetRegion(GetRef<Call>(cn));
    auto new_e = ExprMutator::VisitExpr_(cn);
    Call call = Downcast<Call>(new_e);
119

120
    // Add compiler ends if the parent isn't annotated
121 122
    Array<Expr> args;
    for (auto arg : call->args) {
123
      args.push_back(InsertEnd(arg));
124 125
    }

126 127 128 129 130 131 132 133
    Expr updated_call = Call(call->op, args, call->attrs);
    if (!region.defined()) {
      // if the current node does not belong to annotated region
      // annotate the all incoming edges (args)
      // with "default" compiler_begin annotations.
      Array<Expr> compiler_begins;
      for (auto arg : args) {
        compiler_begins.push_back(InsertBegin(arg));
134
      }
135 136 137
      updated_call = Call(call->op, compiler_begins, call->attrs);
    } else {
      annotated_nodes_.insert(updated_call);
138
    }
139
    return updated_call;
140 141
  };

142 143
  Expr VisitExpr_(const TupleNode* op) {
    auto region = regions_->GetRegion(GetRef<Tuple>(op));
144
    auto new_e = ExprMutator::VisitExpr_(op);
145 146 147
    Tuple tup = Downcast<Tuple>(new_e);

    Array<Expr> fields;
148
    for (auto field : tup->fields) {
149
      fields.push_back(InsertEnd(field));
150
    }
151 152 153 154 155 156 157 158 159 160 161 162

    Expr updated_tuple = Tuple(fields);
    if (!region.defined()) {
      Array<Expr> compiler_begins;
      for (const auto& field : fields) {
        compiler_begins.push_back(InsertBegin(field));
      }
      updated_tuple = Tuple(compiler_begins);
    } else {
      annotated_nodes_.insert(updated_tuple);
    }
    return updated_tuple;
163 164
  }

165 166
  Expr VisitExpr_(const TupleGetItemNode* op) {
    auto region = regions_->GetRegion(GetRef<TupleGetItem>(op));
167 168
    auto new_e = ExprMutator::VisitExpr_(op);
    auto get = Downcast<TupleGetItem>(new_e);
169 170 171 172 173 174 175 176 177

    auto updated_tuple = InsertEnd(get->tuple);
    Expr updated_get = TupleGetItem(updated_tuple, get->index);
    if (!region.defined()) {
      updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index);
    } else {
      annotated_nodes_.insert(updated_get);
    }
    return updated_get;
178 179
  }

180 181
  Expr VisitExpr_(const IfNode* op) {
    auto region = regions_->GetRegion(GetRef<If>(op));
182
    auto new_e = ExprMutator::VisitExpr_(op);
183 184 185 186 187 188 189 190 191 192 193
    auto iff = Downcast<If>(new_e);

    if (!region.defined()) {
      return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)),
                InsertBegin(InsertEnd(iff->false_branch)));
    } else {
      Expr updated_iff =
          If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch));
      annotated_nodes_.insert(updated_iff);
      return updated_iff;
    }
194 195
  }

196
  Expr VisitExpr_(const LetNode* op) {
197
    auto new_e = ExprMutator::VisitExpr_(op);
198 199
    auto let = Downcast<Let>(new_e);
    return Let(let->var, InsertEnd(let->value), InsertEnd(let->body));
200 201
  }

202
  Expr VisitExpr_(const RefCreateNode* op) {
203 204
    auto new_e = ExprMutator::VisitExpr_(op);
    auto create = Downcast<RefCreate>(new_e);
205
    return RefCreate(InsertEnd(create->value));
206 207
  }

208
  Expr VisitExpr_(const RefReadNode* op) {
209 210
    auto new_e = ExprMutator::VisitExpr_(op);
    auto read = Downcast<RefRead>(new_e);
211
    return RefRead(InsertEnd(read->ref));
212 213
  }

214
  Expr VisitExpr_(const RefWriteNode* op) {
215 216
    auto new_e = ExprMutator::VisitExpr_(op);
    auto write = Downcast<RefWrite>(new_e);
217
    return RefWrite(InsertEnd(write->ref), InsertEnd(write->value));
218 219 220
  }

 private:
221 222 223 224
  AnnotatedRegionSet regions_;
  const std::string target_ = "default";
  Function func_;
  std::unordered_set<Expr, ObjectHash, ObjectEqual> annotated_nodes_;
225 226 227 228 229 230 231
};

/*!
 * \brief Mutator that removes annotations which are no longer needed: all
 *  'default' annotations, and compiler_begin/compiler_end pairs that sit next
 *  to each other inside the same region.
 */
class MergeAnnotations : public ExprMutator {
 public:
  /*!
   * \brief Constructor.
   * \param regions The region set used to decide whether a begin/end pair
   *        belongs to the same region.
   */
  explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}

  /*!
   * \brief Drop removable annotation calls and keep visiting their argument.
   * \param call The call being visited.
   * \return The mutated argument when the annotation is removed, otherwise
   *         the result of the default call mutation.
   */
  Expr VisitExpr_(const CallNode* call) final {
    // Remove 'default' annotations.
    const auto* attrs = call->attrs.as<CompilerAttrs>();
    if (attrs != nullptr && attrs->compiler == "default") {
      return VisitExpr(call->args[0]);
    }
    // Merge annotations which are now internal to a region.
    // This happens if we see a compiler begin next to a
    // compiler end and they're both in the same region.
    if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
      auto arg = Downcast<Call>(call->args[0]);
      if (arg->op == compiler_end_op) {
        auto region1 = regions_->GetRegion(GetRef<Call>(call));
        auto region2 = regions_->GetRegion(arg);
        if (region1 == region2) {
          // Skip both the begin and the end; continue below the end.
          return VisitExpr(arg->args[0]);
        }
      }
    }
    return ExprMutator::VisitExpr_(call);
  }

 private:
  /*! \brief The (already merged) regions consulted during mutation. */
  AnnotatedRegionSet regions_;
};

class RegionMerger : public ExprVisitor {
 public:
  explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}

  void VisitExpr_(const CallNode* call) final {
    if (call->op == compiler_end_op) {
      auto region = regions_->GetRegion(GetRef<Call>(call));
266
      if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
      // set the region target
      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
      region_targets_[region->GetID()] = compiler_attrs->compiler;
      // first look at the region args to determine the parent regions
      for (const auto& arg : region->GetInputs()) {
        // all args should be begin annotations
        auto begin = Downcast<Call>(arg);
        CHECK_EQ(begin->op, compiler_begin_op);
        // the arguments of the begin annotations will be in the parent regions
        auto parent_region = regions_->GetRegion(begin->args[0]);
        // if there is no parent region, move on
        if (!parent_region.defined()) continue;
        // merge the parent region if it hasn't been done already
        if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
          VisitExpr(begin->args[0]);
        }
283 284
      }
      // get the mergeable regions now all the parents have been visited
285
      std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
286 287 288 289 290
      for (const auto& arg : region->GetInputs()) {
        auto begin = Downcast<Call>(arg);
        CHECK_EQ(begin->op, compiler_begin_op);
        auto parent_region = regions_->GetRegion(begin->args[0]);
        if (!parent_region.defined()) continue;
291
        mergeable_regions.insert(parent_region);
292 293 294 295 296
      }
      auto& region_restrictions = region_restrictions_[region->GetID()];
      for (const auto& parent_region : mergeable_regions) {
        // add all the parent restrictions to the current region
        auto parent_restrictions = region_restrictions_[parent_region->GetID()];
297
        region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
298 299 300 301 302 303 304 305 306
      }
      for (const auto& parent_region : mergeable_regions) {
        bool merged = false;
        // check the parent region has the same target
        if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) {
          // check the parent region isn't in the restrictions
          if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) {
            // merge the parent region into the current region
            regions_->MergeRegions(parent_region, region);
307 308
            // update the restrictions of all other regions to reflect the
            // change in id
309 310 311 312 313 314 315 316 317 318
            for (const auto& r : regions_) {
              auto& restrictions = region_restrictions_[r->GetID()];
              if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
                restrictions.erase(parent_region->GetID());
                restrictions.insert(region->GetID());
              }
            }
            merged = true;
          }
        }
319 320 321
        // if the parent wasn't merged, add it as a restriction to the current
        // region
        if (!merged) region_restrictions.insert(parent_region->GetID());
322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
      }
      merged_regions_.insert(region->GetID());
    }
    ExprVisitor::VisitExpr_(call);
  }

 private:
  AnnotatedRegionSet regions_;
  std::unordered_set<int> merged_regions_;
  std::map<int, std::unordered_set<int>> region_restrictions_;
  std::map<int, std::string> region_targets_;
};

/*!
 * \brief Merge the annotated compiler regions of an expression.
 * \param expr The annotated expression; AnnotateRestDefault downcasts it to
 *        Function, so it must be a function.
 * \return The expression with 'default' annotations and annotations internal
 *         to a merged region removed.
 */
Expr MergeCompilerRegions(const Expr& expr) {
  // Annotate all the nodes that aren't annotated as 'default'.
  AnnotateRestDefault anno_default(expr);
  auto expr_all_annotated = anno_default.Annotate(expr);

  // Create regions using the annotations.
  AnnotatedRegionSet regions =
      AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op);

  // By now, all the nodes have some sort of annotation.
  // Region merger is an ExprVisitor that will update the
  // AnnotatedRegionSet, merging all the regions that can be merged.
  RegionMerger merger(regions);
  merger.VisitExpr(expr_all_annotated);

  // This updates the expression to remove annotations that are now
  // 'internal' to a merged region.
  MergeAnnotations merge_anno(regions);
  return merge_anno.Mutate(expr_all_annotated);
}

}  // namespace partitioning

namespace transform {

/*!
 * \brief Create the MergeCompilerRegions pass.
 * \return A Sequential pass that runs the region-merging function pass
 *         followed by InferType.
 */
Pass MergeCompilerRegions() {
  // Per-function callback: merge the regions and hand back a Function.
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> merge_func =
      [](Function func, IRModule mod, PassContext ctx) {
        Expr merged = partitioning::MergeCompilerRegions(func);
        return Downcast<Function>(merged);
      };
  auto merge_pass = CreateFunctionPass(merge_func, 0, "MergeCompilerRegions", {});
  // InferType runs after the merge pass in the same Sequential.
  return Sequential({merge_pass, InferType()});
}

// Register the pass constructor as a global function under the name
// "relay._transform.MergeCompilerRegions".
TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
    .set_body_typed(transform::MergeCompilerRegions);
371 372 373 374 375

}  // namespace transform

}  // namespace relay
}  // namespace tvm