thread_storage_scope.h 6.44 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
/*!
 *  Copyright (c) 2017 by Contributors
 * \file thread_storage_scope.h
 * \brief Extract thread axis configuration from TVMArgs.
 */
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_

#include <tvm/runtime/packed_func.h>
#include <algorithm>
#include <functional>
#include <string>
#include <vector>

namespace tvm {
namespace runtime {

/*!
 * \brief Memory hierarchy rank in the storage system.
 * \note The global rank and shared rank have one to one
 *       correspondence to the thread rank (see DefaultStorageRank).
 */
enum class StorageRank {
  /*! \brief global memory */
  kGlobal = 0,
  /*! \brief shared memory among thread group */
  kShared = 1,
  /*!
   * \brief reserved for warp memory.
   *  This is only used by programming model.
   *  There is no such memory usually in GPU.
   *  Instead, we can simulate it by registers and shuffle.
   */
  kWarp = 2,
  /*! \brief thread local memory */
  kLocal = 3
};

/*!
 * \brief Pick the storage rank used by default at a given thread scope level.
 * \param thread_scope_rank The thread scope rank: -1, 0 (the rank ThreadScope
 *        assigns to blockIdx) or 1 (the rank ThreadScope assigns to threadIdx).
 * \return kGlobal for -1, kShared for 0, kLocal for 1.
 *         Any other value is a fatal error.
 */
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
  if (thread_scope_rank == -1) return StorageRank::kGlobal;
  if (thread_scope_rank == 0) return StorageRank::kShared;
  if (thread_scope_rank == 1) return StorageRank::kLocal;
  LOG(FATAL) << "unknown rank";
  // Only reached if LOG(FATAL) does not terminate; keeps every path returning.
  return StorageRank::kGlobal;
}

72 73 74
/*! \brief class to represent storage scope */
struct StorageScope {
  /*! \brief The rank of the storage */
75
  StorageRank rank{StorageRank::kGlobal};
76
  /*! \brief tag for special purpose memory. */
77
  std::string tag;
78 79
  // comparator
  inline bool operator==(const StorageScope& other) const {
80
    return rank == other.rank && tag == other.tag;
81
  }
82 83 84
  inline bool operator!=(const StorageScope& other) const {
    return !(*this == other);
  }
85
  inline std::string to_string() const {
86
    std::string ret;
87
    switch (rank) {
88 89 90 91
      case StorageRank::kGlobal: return "global" + tag;
      case StorageRank::kShared: return "shared" + tag;
      case StorageRank::kWarp: return "warp" + tag;
      case StorageRank::kLocal: return "local" + tag;
92 93 94 95 96 97 98 99 100 101
      default: LOG(FATAL) << "unknown storage scope"; return "";
    }
  }
  /*!
   * \brief make storage scope from string
   * \param s The string to be parsed.
   * \return The storage scope.
   */
  static StorageScope make(const std::string& s) {
    StorageScope r;
102
    if (s.compare(0, 6, "global")  == 0) {
103
      r.rank = StorageRank::kGlobal;
104 105
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 6, "shared") == 0) {
106
      r.rank = StorageRank::kShared;
107
      r.tag = s.substr(6, std::string::npos);
108 109 110
    } else if (s.compare(0, 4, "warp") == 0) {
      r.rank = StorageRank::kWarp;
      r.tag = s.substr(4, std::string::npos);
111
    } else if (s.compare(0, 5, "local") == 0) {
112
      r.rank = StorageRank::kLocal;
113
      r.tag = s.substr(5, std::string::npos);
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
    } else {
      LOG(FATAL) << "unknown storage scope " << s;
    }
    return r;
  }
};

/*! \brief class to represent thread scope */
struct ThreadScope {
  /*! \brief The rank of thread scope: 0 for blockIdx, 1 for threadIdx and virtual threads */
  int rank{0};
  /*! \brief the dimension index under the rank: 0/1/2 for x/y/z, -1 for virtual threads */
  int dim_index{0};
  /*!
   * \brief make thread scope from string
   * \param s The thread tag to be parsed: "vthread", "cthread",
   *          "blockIdx.{x,y,z}" or "threadIdx.{x,y,z}".
   *          Anything else is a fatal error.
   * \return The thread scope.
   */
  static ThreadScope make(const std::string& s) {
    // Read the axis letter located right after the first prefix_len characters.
    // Only x, y and z are valid. Any other letter (or a missing one) would give
    // a dim_index outside [0, 3), which callers use directly as an array
    // offset, so it is rejected here.
    auto parse_dim = [&s](size_t prefix_len) -> int {
      const char axis = s.size() > prefix_len ? s[prefix_len] : '\0';
      if (axis < 'x' || axis > 'z') {
        LOG(FATAL) << "Unknown threadscope " << s;
      }
      return static_cast<int>(axis - 'x');
    };
    ThreadScope r;
    if (s == "vthread" || s == "cthread") {
      // virtual thread at the same level as local
      r.rank = 1;
      r.dim_index = -1;
    } else if (s.compare(0, 9, "blockIdx.") == 0) {
      r.rank = 0;
      r.dim_index = parse_dim(9);
    } else if (s.compare(0, 10, "threadIdx.") == 0) {
      r.rank = 1;
      r.dim_index = parse_dim(10);
    } else {
      LOG(FATAL) << "Unknown threadscope " << s;
    }
    return r;
  }
};


/*! \brief workload specification */
struct ThreadWorkLoad {
  // Launch extents. work_size[0..2] are the grid dims (the blockIdx.x/y/z
  // extents) and work_size[3..5] are the block dims (the threadIdx.x/y/z
  // extents); see grid_dim() and block_dim() below.
  size_t work_size[6];
  /*!
   * \param i The block dimension (0, 1 or 2 for x, y, z).
   * \return i-th block dim
   */
  inline size_t block_dim(size_t i) const {
    return work_size[i + 3];
  }
  /*!
   * \param i The grid dimension (0, 1 or 2 for x, y, z).
   * \return i-th grid dim
   */
  inline size_t grid_dim(size_t i) const {
    return work_size[i];
  }
};
/*! \brief Thread axis configuration */
class ThreadAxisConfig {
 public:
  void Init(size_t base,
            const std::vector<std::string>& thread_axis_tags)  {
    base_ = base;
    std::vector<bool> filled(6, false);
    for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
      const std::string& tag = thread_axis_tags[i];
      ThreadScope ts = ThreadScope::make(tag);
      arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
      filled[ts.rank * 3 + ts.dim_index] = true;
    }
184
    work_dim_ = 1;
185
    for (int i = 0; i < 3; ++i) {
186 187
      if (filled[i] || filled[i + 3]) {
        work_dim_ = i + 1;
188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
      }
    }
  }
  // extract workload from arguments.
  ThreadWorkLoad Extract(TVMArgs x) const {
    ThreadWorkLoad w;
    std::fill(w.work_size, w.work_size + 6, 1);
    for (size_t i = 0; i < arg_index_map_.size(); ++i) {
      w.work_size[arg_index_map_[i]] =
          static_cast<size_t>(x.values[base_ + i].v_int64);
    }
    return w;
  }
  // return the work dim
  size_t work_dim() const {
    return work_dim_;
  }

 private:
  /*! \brief base axis */
  size_t base_;
  /*! \brief The worker dimension */
  size_t work_dim_;
  /*! \brief The index mapping. */
  std::vector<uint32_t> arg_index_map_;
};

}  // namespace runtime
}  // namespace tvm

namespace std {
/*!
 * \brief Hash for StorageScope, consistent with StorageScope::operator==.
 *  Both rank and tag are mixed in, so scopes with the same rank but
 *  different tags no longer all land in the same bucket.
 */
template <>
struct hash<::tvm::runtime::StorageScope> {
  std::size_t operator()(const ::tvm::runtime::StorageScope& k) const {
    std::size_t seed = static_cast<std::size_t>(k.rank);
    // hash_combine-style mixing of the tag hash into the rank.
    seed ^= std::hash<std::string>()(k.tag) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};
}  // namespace std
#endif  // TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_