/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_access.h
 * \brief Common data structure for storage access analysis.
 */
#ifndef TVM_TIR_PASS_STORAGE_ACCESS_H_
#define TVM_TIR_PASS_STORAGE_ACCESS_H_

#include <tvm/tir/expr.h>
#include <tvm/ir/attrs.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <vector>
#include <unordered_map>
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
36
namespace tir {
37

38
using runtime::StorageScope;
39
using runtime::StorageRank;
40 41 42
/*!
 * \brief Base class of storage access analysis
 */
43
class StorageAccessVisitor : public StmtExprVisitor {
44 45 46 47 48 49
 public:
  /*! \brief Storage access type */
  enum AccessType {
    kRead,
    kWrite,
    kSync,
50 51 52
    kAlloc,
    // acquired version of read, only need to handle WAR dep.
    kReadAcquire
53 54 55 56 57 58
  };
  /*! \brief An access entry */
  struct AccessEntry {
    /*! \brief The thread index that access this entry */
    Array<IterVar> threads;
    /*! \brief The buffer variable, if any */
59
    Var buffer = NullValue<Var>();
60
    /*! \brief The access data type */
61
    DataType dtype;
62 63 64 65 66 67
    /*! \brief The touched access range */
    arith::IntSet touched;
    /*! \brief The type of access */
    AccessType type;
    /*! \brief The storage scope */
    StorageScope scope;
68
    /*! \brief Whether the access is double buffer write */
69
    bool double_buffer_write = false;
70 71 72 73
  };
  /*! \brief Access pattern about a single statement */
  struct StmtEntry {
    /*! \brief The statement */
74
    const Object* stmt;
75 76 77 78
    /*! \brief access patterns in the statement */
    std::vector<AccessEntry> access;
  };
  // override visitor pattern
79 80 81 82 83 84 85
  void VisitExpr_(const LoadNode* op) final;
  void VisitStmt_(const StoreNode* op) final;
  void VisitStmt_(const EvaluateNode* op) final;
  void VisitStmt_(const AttrStmtNode* op) final;
  void VisitStmt_(const ForNode* op) final;
  void VisitStmt_(const IfThenElseNode* op) final;
  void VisitExpr_(const CallNode* op) final;
86

87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
 protected:
  StorageAccessVisitor() {
    scope_.push_back(std::vector<StmtEntry>());
  }
  /*! \return number of conditions in the current scope. */
  int condition_counter() const {
    return condition_counter_;
  }
  /*! \return whether we are in device environment. */
  bool in_device_env() const {
    return in_device_env_;
  }
  /*! \return environment threads */
  const Array<IterVar>& env_threads() const {
    return env_threads_;
  }
  /*!
   * \brief Whether we need analyze the buffer in current scope.
   * \param buffer The buffer to be checked
   * \param scope The scope of the buffer.
   * \return Whether the analysis of buffer is enabled.
   */
109
  virtual bool Enabled(const VarNode* buffer,
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
                       const StorageScope& scope) const {
    return true;
  }
  /*!
   * \brief Summarize the sequence of operations into parent.
   *
   *  Insert synchronization if necessary and remove un-necessary
   *  memory access which are already synced.
   *
   * \param seq The sequence of the access operations.
   * \param loop Pass loop node if it is a loop, otherwise nullptr.
   * \return The summarized sequence that represent access that
   *  the parent should taken care of to synchronize.
   */
  virtual std::vector<AccessEntry> Summarize(
125
      std::vector<StmtEntry> seq, const ForNode* loop) = 0;
126 127 128 129
  /*!
   * \brief Get the scope of the buffer array.
   * \return The scope of the final buffer array.
   */
130
  StorageScope GetScope(const VarNode* buf) const;
131 132
  // access scope
  std::vector<std::vector<StmtEntry> > scope_;
133

134 135 136 137 138 139 140
 private:
  // whether access appending is enabled.
  bool allow_append_{false};
  // Whether we are in device environment
  bool in_device_env_{false};
  // Whether we are inside condition.
  int condition_counter_{0};
141
  // The current double buffer write scope.
142
  const VarNode* double_buffer_write_{nullptr};
143 144 145 146 147
  // the current free stmt entry.
  StmtEntry curr_stmt_;
  // The involving threads
  Array<IterVar> env_threads_;
  // The storage scope of each buffer
148
  std::unordered_map<const VarNode*, StorageScope> storage_scope_;
149
};
150

151
}  // namespace tir
152
}  // namespace tvm
#endif  // TVM_TIR_PASS_STORAGE_ACCESS_H_