/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2017 by Contributors
 * \file storage_access.h
 * \brief Common data structure for storage access analysis.
 */
#ifndef TVM_PASS_STORAGE_ACCESS_H_
#define TVM_PASS_STORAGE_ACCESS_H_

#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using runtime::StorageScope;
using runtime::StorageRank;
/*!
 * \brief Base class of storage access analysis.
 *
 *  Visits loads, stores, calls and the statements that structure them
 *  (attributes, loops, conditionals), recording access patterns per
 *  statement in a stack of scopes. Derived classes decide how a finished
 *  sequence of statement accesses is condensed by implementing Summarize().
 */
class StorageAccessVisitor : public IRVisitor {
 public:
  /*! \brief Storage access type */
  enum AccessType {
    kRead,
    kWrite,
    kSync,
    kAlloc,
    // Acquired version of read; only the WAR dependency needs handling.
    kReadAcquire
  };
  /*! \brief An access entry */
  struct AccessEntry {
    /*! \brief The threads that access this entry */
    Array<IterVar> threads;
    /*! \brief The buffer variable, if any */
    Var buffer = NullValue<Var>();
    /*! \brief The access data type */
    Type dtype;
    /*! \brief The touched access range */
    arith::IntSet touched;
    /*!
     * \brief The type of access.
     *  Given a default so a default-constructed entry never carries an
     *  indeterminate enum value; producers are expected to set it.
     */
    AccessType type = kRead;
    /*! \brief The storage scope */
    StorageScope scope;
    /*! \brief Whether the access is a double buffer write */
    bool double_buffer_write = false;
  };
  /*! \brief Access pattern of a single statement */
  struct StmtEntry {
    /*! \brief The statement (non-owning); nullptr until assigned */
    const Node* stmt = nullptr;
    /*! \brief Access patterns in the statement */
    std::vector<AccessEntry> access;
  };
  // override visitor pattern
  void Visit_(const Load* op) final;
  void Visit_(const Store* op) final;
  void Visit_(const Evaluate* op) final;
  void Visit_(const AttrStmt* op) final;
  void Visit_(const For* op) final;
  void Visit_(const IfThenElse* op) final;
  void Visit_(const Call* op) final;

 protected:
  StorageAccessVisitor() {
    // Seed the scope stack with an outermost, empty access scope.
    scope_.emplace_back();
  }
  /*! \return Number of conditions enclosing the current position. */
  int condition_counter() const {
    return condition_counter_;
  }
  /*! \return Whether we are in the device environment. */
  bool in_device_env() const {
    return in_device_env_;
  }
  /*! \return The environment threads. */
  const Array<IterVar>& env_threads() const {
    return env_threads_;
  }
  /*!
   * \brief Whether the buffer needs to be analyzed in the current scope.
   *  The default implementation enables analysis of every buffer.
   * \param buffer The buffer to be checked.
   * \param scope The scope of the buffer.
   * \return Whether the analysis of the buffer is enabled.
   */
  virtual bool Enabled(const Variable* buffer,
                       const StorageScope& scope) const {
    return true;
  }
  /*!
   * \brief Summarize a sequence of operations into its parent.
   *
   *  Implementations may insert synchronization where necessary and
   *  drop memory accesses that are already synchronized.
   *
   * \param seq The sequence of the access operations.
   * \param loop The loop node if the sequence is a loop body, otherwise nullptr.
   * \return The summarized accesses that the parent must take care of
   *  synchronizing.
   */
  virtual std::vector<AccessEntry> Summarize(
      std::vector<StmtEntry> seq, const For* loop) = 0;
  /*!
   * \brief Get the scope of the buffer array.
   * \param buf The buffer variable to look up.
   * \return The scope of the final buffer array.
   */
  StorageScope GetScope(const Variable* buf) const;
  // Stack of access scopes; each level holds the statement entries
  // collected so far at that nesting level.
  std::vector<std::vector<StmtEntry> > scope_;

 private:
  // Whether appending accesses is enabled.
  bool allow_append_{false};
  // Whether we are in the device environment.
  bool in_device_env_{false};
  // Number of conditions enclosing the current position.
  int condition_counter_{0};
  // The current double buffer write scope.
  const Variable* double_buffer_write_{nullptr};
  // The current free statement entry.
  StmtEntry curr_stmt_;
  // The involved threads.
  Array<IterVar> env_threads_;
  // The storage scope of each buffer.
  std::unordered_map<const Variable*, StorageScope> storage_scope_;
};

}  // namespace ir
}  // namespace tvm
#endif  // TVM_PASS_STORAGE_ACCESS_H_