combine_context_call.cc 4.01 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Combine repeated calls to context-related functions into a single call.
 *
 * \file combine_context_call.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <map>
#include <string>
#include <utility>

namespace tvm {
36
namespace tir {
37 38 39

// Calculate the statistics of packed function.
// These information are needed during codegen.
40
class ContextCallCombiner final : public StmtExprMutator {
41 42
 public:
  struct CompareExpr {
43
    bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
44 45 46 47
      return Compare(lhs, rhs) < 0;
    }
  };

48
  PrimExpr VisitExpr_(const CallNode* op) final {
49 50
    if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
      CHECK_EQ(op->args.size(), 1U);
51
      PrimExpr ctx = op->args[0];
52 53 54 55
      auto it  = ctx_map_.find(ctx);
      if (it != ctx_map_.end()) {
        return it->second;
      } else {
56
        CHECK(ctx.dtype().is_handle());
57
        std::string name;
58
        if (const CallNode* call = ctx.as<CallNode>()) {
59 60 61 62
          name = call->name + "_cache";
        } else {
          name = "ctx_cache_";
        }
63
        Var ctx_var(name, ctx.dtype());
64
        ctx_map_[ctx] = ctx_var;
65
        return std::move(ctx_var);
66 67
      }
    } else {
68
      return StmtExprMutator::VisitExpr_(op);
69 70 71
    }
  }

72
  Stmt VisitStmt_(const AttrStmtNode* op) final {
73 74
    if (op->attr_key == attr::thread_extent ||
        op->attr_key == attr::coproc_uop_scope) {
75
      // Map of comparison expression to variable
76
      std::map<PrimExpr, Var, CompareExpr> temp;
77
      std::swap(temp, ctx_map_);
78
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
79 80 81
      std::swap(temp, ctx_map_);
      return BuildContext(temp, stmt);
    } else {
82
      return StmtExprMutator::VisitStmt_(op);
83 84 85
    }
  }

86
  Stmt VisitStmt_(const ForNode* op) final {
87 88
    if (op->for_type == ForType::Parallel) {
      // Map of comparison expression to variable
89
      std::map<PrimExpr, Var, CompareExpr> temp;
90
      std::swap(temp, ctx_map_);
91
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
92 93 94
      std::swap(temp, ctx_map_);
      return BuildContext(temp, stmt);
    } else {
95
      return StmtExprMutator::VisitStmt_(op);
96 97 98 99
    }
  }

  Stmt Combine(Stmt stmt) {
100
    return BuildContext(ctx_map_, this->VisitStmt(stmt));
101 102 103
  }

 private:
104
  static Stmt BuildContext(const std::map<PrimExpr, Var, CompareExpr>& cmap,
105 106
                           Stmt body) {
    for (const auto& kv : cmap) {
107
      body = LetStmtNode::make(kv.second, kv.first, body);
108 109 110 111
    }
    return body;
  }
  // Map of comparison expression to variable
112
  std::map<PrimExpr, Var, CompareExpr> ctx_map_;
113 114 115
};

// Legacy LoweredFunc entry point: return a copy of `f` whose body has had its
// tvm_thread_context calls combined. The input function node is not modified.
LoweredFunc CombineContextCall(LoweredFunc f) {
  // Copy-construct a fresh node from `f` and rewrite only the copy.
  auto copied = make_object<LoweredFuncNode>(*f.operator->());
  ContextCallCombiner combiner;
  copied->body = combiner.Combine(copied->body);
  return LoweredFunc(copied);
}

121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
namespace transform {

Pass CombineContextCall() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = ContextCallCombiner().Combine(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {});
}

TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
.set_body_typed(CombineContextCall);

}  // namespace transform
136
}  // namespace tir
137
}  // namespace tvm