/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2017 by Contributors
 * \file storage_access.h
 * \brief Common data structure for storage access analysis.
 */
#ifndef TVM_PASS_STORAGE_ACCESS_H_
#define TVM_PASS_STORAGE_ACCESS_H_

#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {
37

38
using runtime::StorageScope;
39
using runtime::StorageRank;
40 41 42 43 44 45 46 47 48 49
/*!
 * \brief Base class of storage access analysis
 */
class StorageAccessVisitor : public IRVisitor {
 public:
  /*! \brief Storage access type */
  enum AccessType {
    kRead,
    kWrite,
    kSync,
50 51 52
    kAlloc,
    // acquired version of read, only need to handle WAR dep.
    kReadAcquire
53 54 55 56 57 58
  };
  /*! \brief An access entry */
  struct AccessEntry {
    /*! \brief The thread index that access this entry */
    Array<IterVar> threads;
    /*! \brief The buffer variable, if any */
59
    VarExpr buffer;
60 61 62 63 64 65 66 67
    /*! \brief The access data type */
    Type dtype;
    /*! \brief The touched access range */
    arith::IntSet touched;
    /*! \brief The type of access */
    AccessType type;
    /*! \brief The storage scope */
    StorageScope scope;
68 69
    /*! \brief Whether the access is double buffer write */
    bool double_buffer_write{false};
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
  };
  /*! \brief Access pattern about a single statement */
  struct StmtEntry {
    /*! \brief The statement */
    const Node* stmt;
    /*! \brief access patterns in the statement */
    std::vector<AccessEntry> access;
  };
  // override visitor pattern
  void Visit_(const Load* op) final;
  void Visit_(const Store* op) final;
  void Visit_(const Evaluate* op) final;
  void Visit_(const AttrStmt* op) final;
  void Visit_(const For* op) final;
  void Visit_(const IfThenElse* op) final;
  void Visit_(const Call* op) final;
86

87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
 protected:
  StorageAccessVisitor() {
    scope_.push_back(std::vector<StmtEntry>());
  }
  /*! \return number of conditions in the current scope. */
  int condition_counter() const {
    return condition_counter_;
  }
  /*! \return whether we are in device environment. */
  bool in_device_env() const {
    return in_device_env_;
  }
  /*! \return environment threads */
  const Array<IterVar>& env_threads() const {
    return env_threads_;
  }
  /*!
   * \brief Whether we need analyze the buffer in current scope.
   * \param buffer The buffer to be checked
   * \param scope The scope of the buffer.
   * \return Whether the analysis of buffer is enabled.
   */
  virtual bool Enabled(const Variable* buffer,
                       const StorageScope& scope) const {
    return true;
  }
  /*!
   * \brief Summarize the sequence of operations into parent.
   *
   *  Insert synchronization if necessary and remove un-necessary
   *  memory access which are already synced.
   *
   * \param seq The sequence of the access operations.
   * \param loop Pass loop node if it is a loop, otherwise nullptr.
   * \return The summarized sequence that represent access that
   *  the parent should taken care of to synchronize.
   */
  virtual std::vector<AccessEntry> Summarize(
      std::vector<StmtEntry> seq, const For* loop) = 0;
  /*!
   * \brief Get the scope of the buffer array.
   * \return The scope of the final buffer array.
   */
  StorageScope GetScope(const Variable* buf) const;
131 132
  // access scope
  std::vector<std::vector<StmtEntry> > scope_;
133

134 135 136 137 138 139 140
 private:
  // whether access appending is enabled.
  bool allow_append_{false};
  // Whether we are in device environment
  bool in_device_env_{false};
  // Whether we are inside condition.
  int condition_counter_{0};
141 142
  // The current double buffer write scope.
  const Variable* double_buffer_write_{nullptr};
143 144 145 146 147 148
  // the current free stmt entry.
  StmtEntry curr_stmt_;
  // The involving threads
  Array<IterVar> env_threads_;
  // The storage scope of each buffer
  std::unordered_map<const Variable*, StorageScope> storage_scope_;
149
};
}  // namespace ir
}  // namespace tvm
#endif  // TVM_PASS_STORAGE_ACCESS_H_