/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/build_module.h
 * \brief Functions for compiling ops.
 */
#ifndef TVM_BUILD_MODULE_H_
#define TVM_BUILD_MODULE_H_

#include <string>
#include <vector>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "runtime/packed_func.h"
#include "schedule_pass.h"
#include "lowered_func.h"

namespace tvm {

/*!
* \brief Container for target device information.
40
*   Use target::llvm, target::cuda etc functions instead of constructing directly.
41
*/
42 43
class TargetNode : public Node {
 public:
44 45
  /*! \brief The name of the target device */
  std::string target_name;
46 47
  /*! \brief The name of the target device */
  std::string device_name;
48
  /*! \brief The type of the target device */
49
  int device_type;
50 51 52 53 54
  /*! \brief The maximum threads that a schedule should use for this device */
  int max_num_threads = 1;
  /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
  int thread_warp_size = 1;
  /*! \brief Keys for this target */
55
  Array<Expr> keys_array;
56
  /*! \brief Options for this target */
57 58 59
  Array<Expr> options_array;
  /*! \brief Collection of imported libs */
  Array<Expr> libs_array;
60 61

  /*! \return the full device string to pass to codegen::Build */
62
  TVM_DLL const std::string& str() const;
63

64 65
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("target_name", &target_name);
66
    v->Visit("device_name", &device_name);
67 68 69 70 71 72 73 74 75
    v->Visit("device_type", &device_type);
    v->Visit("max_num_threads", &max_num_threads);
    v->Visit("thread_warp_size", &thread_warp_size);
    v->Visit("keys_array", &keys_array);
    v->Visit("options_array", &options_array);
    v->Visit("libs_array", &libs_array);
  }

  /*! \brief Get the keys for this target as a vector of string */
76
  TVM_DLL std::vector<std::string> keys() const;
77 78

  /*! \brief Get the options for this target as a vector of string */
79
  TVM_DLL std::vector<std::string> options() const;
80 81

  /*! \brief Get the keys for this target as an unordered_set of string */
82
  TVM_DLL std::unordered_set<std::string> libs() const;
83 84 85

  static constexpr const char* _type_key = "Target";
  TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
86 87 88 89

 private:
  /*! \brief Internal string repr. */
  mutable std::string str_repr_;
90 91
};

92
/*! \brief reference cpass to the target. */
93 94 95
class Target : public NodeRef {
 public:
  Target() {}
96
  explicit Target(NodePtr<Node> n) : NodeRef(n) {}
97
  /*!
98 99 100
  * \brief Create a Target given a string
  * \param target_str the string to parse
  */
101
  TVM_DLL static Target Create(const std::string& target_str);
102
  /*!
103 104 105 106 107 108 109 110
   * \brief Get the current target context from thread local storage.
   * \param allow_not_defined If the context stack is empty and this is set to true, an
   *   undefined Target will be returned. Otherwise, an empty context stack will cause a
   *   runtime error.
   * \return The target that is the current context. The target may not be defined if
   * allow_not_defined is true.
   */
  TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
111

112
  const TargetNode* operator->() const {
113 114 115 116
      return static_cast<const TargetNode*>(node_.get());
  }

  using ContainerType = TargetNode;
117 118 119 120 121
  class Internal;
 private:
  // enable with syntax.
  friend class Internal;
  friend class With<Target>;
122
  /*!
123 124 125
   * \brief Push a new target context onto the thread local stack.
   *  The Target on top of the stack is used to determine which
   *  specialization to use when invoking a GenericFunc.
126
   */
127 128 129 130 131 132
  TVM_DLL void EnterWithScope();
  /*!
   * \brief Pop a target off the thread local context stack,
   *  restoring the previous target as the current context.
   */
  TVM_DLL void ExitWithScope();
133 134 135 136 137
};

/*! \brief This namespace provides functions to construct Target instances */
namespace target {
/*! \return A target for LLVM */
138
TVM_DLL Target llvm(const std::vector<std::string>& options =
139
                   std::vector<std::string>());
140 141

/*! \return A target for CUDA */
142
TVM_DLL Target cuda(const std::vector<std::string>& options =
143
                   std::vector<std::string>());
144 145

/*! \return A target for ROCm */
146
TVM_DLL Target rocm(const std::vector<std::string>& options =
147
                   std::vector<std::string>());
148 149

/*! \return A target for OpenCL */
150
TVM_DLL Target opencl(const std::vector<std::string>& options =
151
                     std::vector<std::string>());
152 153

/*! \return A target for Metal */
154
TVM_DLL Target metal(const std::vector<std::string>& options =
155
                    std::vector<std::string>());
156 157

/*! \return A target for rasp */
158
TVM_DLL Target rasp(const std::vector<std::string>& options =
159
                   std::vector<std::string>());
160

161
/*! \return A target for Mali */
162
TVM_DLL Target mali(const std::vector<std::string>& options =
163
                   std::vector<std::string>());
164

165
/*! \return A target for Intel Graphics */
166
TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
167
                             std::vector<std::string>());
168

169
/*! \return A target for stackvm */
170
TVM_DLL Target stackvm(const std::vector<std::string>& options =
171
                      std::vector<std::string>());
172 173 174 175

}  // namespace target

/*!
176 177
 * \brief Container for build configuration options
 */
178 179
class BuildConfigNode : public Node {
 public:
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
  /*!
   * \brief The data alignment to use when constructing buffers. If this is set to
   * -1, then TVM's internal default will be used
   */
  int data_alignment = -1;
  /*!
   * \brief The offset factor to use when constructing buffers. If this is set to
   * 0, then the offset field is not used.
   */
  int offset_factor = 0;

  /*!
   * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
   * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
   */
  int double_buffer_split_loop = 1;
  /*! \brief Threshold of number of steps in the loop to be automatically unrolled */
  int auto_unroll_max_step = 0;
  /*! \brief The maximum nested level of loops that can be automatically unrolled */
  int auto_unroll_max_depth = 8;
  /*! \brief The maximum extent of loop that will be unrolled */
  int auto_unroll_max_extent = 0;
  /*!
   * \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will
   * be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma.
   */
  bool unroll_explicit = true;

  /*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
  bool restricted_func = true;

  /*! \brief Whether to detect global barrier */
  bool detect_global_barrier = false;

214 215 216
  /*! \brief Whether to partition const loop */
  bool partition_const_loop = false;

217
  /*! \brief Whether to dump the IR of each pass (only when building from python) */
218
  std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
219 220 221 222

  /*! \brief Whether to dump the IR of each pass (only when building from python) */
  bool dump_pass_ir = false;

223 224 225
  /*! \brief Whether to instrument loads and stores with check for out of the bounds. */
  bool instrument_bound_checkers = false;

226 227 228
  /*! \brief Whether to disable select rewriting. */
  bool disable_select_rewriting = false;

229 230 231
  /*! \brief Whether to disable loop vectorization. */
  bool disable_vectorize = false;

232 233 234 235 236 237 238 239 240 241 242
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("data_alignment", &data_alignment);
    v->Visit("offset_factor", &offset_factor);
    v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
    v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
    v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
    v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent);
    v->Visit("unroll_explicit", &unroll_explicit);
    v->Visit("restricted_func", &restricted_func);
    v->Visit("detect_global_barrier", &detect_global_barrier);
    v->Visit("partition_const_loop", &partition_const_loop);
243
    v->Visit("dump_pass_ir", &dump_pass_ir);
244
    v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
245
    v->Visit("disable_select_rewriting", &disable_select_rewriting);
246
    v->Visit("disable_vectorize", &disable_vectorize);
247
  }
248 249 250

  static constexpr const char* _type_key = "BuildConfig";
  TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
251 252
};

253
/*!
254 255
 * \brief Build configuration for compilations.
 */
256 257 258
class BuildConfig : public ::tvm::NodeRef {
 public:
  BuildConfig() {}
259
  explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
260 261 262 263 264 265 266
  const BuildConfigNode* operator->() const {
    return static_cast<const BuildConfigNode*>(node_.get());
  }
  BuildConfigNode* operator->() {
    return static_cast<BuildConfigNode*>(node_.get());
  }
  /*!
267 268
   * \brief Construct a BuildConfig containing a empty build config node.
   * \return The new BuildConfig
269
   */
270
  TVM_DLL static BuildConfig Create();
271 272 273 274 275
  /*!
   * \brief Get the current BuildConfig context from thread local storage, or a default
   * configuration if a BuildConfig scope has not been entered.
   * \return The configuration that is the current context.
   */
276
  TVM_DLL static BuildConfig Current();
277 278

  using ContainerType = BuildConfigNode;
279
  class Internal;
280

281 282 283
 private:
  // Enable with syntax.
  friend class With<BuildConfig>;
284
  /*!
285
   * \brief Push a new BuildConfig context onto the thread local stack.
286
   */
287
  TVM_DLL void EnterWithScope();
288

289 290 291 292 293
  /*!
   * \brief Pop a build config off the thread local context stack,
   * restoring the previous configuration as the current context.
   */
  TVM_DLL void ExitWithScope();
294
};
295 296

/*!
297 298 299 300 301 302 303 304
* \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param config The build configuration.
* \return The lowered function.
*/
305 306 307 308 309
TVM_DLL Array<LoweredFunc> lower(Schedule sch,
                                 const Array<Tensor>& args,
                                 const std::string& name,
                                 const std::unordered_map<Tensor, Buffer>& binds,
                                 const BuildConfig& config);
310 311 312 313 314 315 316 317 318 319 320 321 322
/*!
 * \brief Split the functions into host and device parts and run the passes
 *  needed before build.
 * \param funcs The functions to be built.
 * \param target The target device to build for.
 * \param target_host The target for building host code. To use the default, pass Target()
 * \param config The build configuration.
 * \return An Array<Array<LoweredFunc>> with 2 elements. The first is the host
 *  function array, the second is the device function array.
 */
TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
                                                        const Target& target,
                                                        const Target& target_host,
                                                        const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \param funcs The functions to be built.
* \param target The target device to build for.
328
* \param target_host The target for building host code. To use the default, pass Target()
329 330 331
* \param config The build configuration.
* \return The built module.
*/
332 333 334 335
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
                              const Target& target,
                              const Target& target_host,
                              const BuildConfig& config);
336

337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
/*!
 * \brief Build a device and host module for a specific target from a map
 * contains target to a list of lowered functions pairs. This function is used
 * for heterogeneous build.
 * \param input The map contains target to a list of lowered functions pairs.
 * \param target_host The target for building host code. To use the default,
 *        pass Target().
 * \param config The build configuration.
 * \return The built module that contains code for different processors.
 */
TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input,
                              const Target& target_host,
                              const BuildConfig& config);

/*!
 * \brief Build a device and host module from a map that associates each
 * target string with the list of lowered functions to build for it. This
 * function is used for heterogeneous build.
 * \param input The map from target string to its list of lowered functions.
 * \param target_host The target for building host code. To use the default,
 *        pass Target().
 * \param config The build configuration.
 * \return The built module that contains code for different processors.
 */
TVM_DLL runtime::Module build(const Map<std::string, Array<LoweredFunc>>& input,
                              const Target& target_host,
                              const BuildConfig& config);

366 367 368 369 370 371 372 373
class GenericFuncNode;

/*!
 * \brief Generic function that can be specialized on a per-target basis.
 */
class GenericFunc : public NodeRef {
 public:
  GenericFunc() {}
374
  explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {}
375 376 377 378 379 380 381 382

  /*!
   * \brief Set the default function implementaiton.
   * \param value The default function
   * \param allow_override If true, this call may override a previously registered function. If
   * false, an error will be logged if the call would override a previously registered function.
   * \return reference to self.
   */
383
  TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value,
384 385 386 387 388 389 390 391 392 393
                                   bool allow_override = false);
  /*!
   * \brief Register a specialized function
   * \param tags The tags for this specialization
   * \param value The specialized function
   * \param allow_override If true, this call may override previously registered tags. If false,
   * an error will be logged if the call would override previously registered tags.
   * \return reference to self.
   */
  TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
394
                                     const runtime::PackedFunc value,
395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410
                                     bool allow_override = false);
  /*!
   * \brief Call generic function by directly passing in unpacked format.
   * \param args Arguments to be passed.
   * \tparam Args arguments to be passed.
   *
   * \code
   *   // Example code on how to call generic function
   *   void CallGeneirc(GenericFunc f) {
   *     // call like normal functions by pass in arguments
   *     // return value is automatically converted back
   *     int rvalue = f(1, 2.0);
   *   }
   * \endcode
   */
  template<typename... Args>
411
  inline runtime::TVMRetValue operator()(Args&& ...args) const;
412 413 414 415 416 417
  /*!
   * \brief Invoke the relevant function for the current target context, set by set_target_context.
   * Arguments are passed in packed format.
   * \param args The arguments to pass to the function.
   * \param ret The return value
   */
418 419
  TVM_DLL void CallPacked(runtime::TVMArgs args,
                          runtime::TVMRetValue* ret) const;
420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451

  /*!
   * \brief Find or register the GenericFunc instance corresponding to the give name
   * \param name The name of the registered GenericFunc
   * \return The GenericFunc instance
   */
  TVM_DLL static GenericFunc Get(const std::string& name);

  /*!
   * \brief Add a GenericFunc instance to the registry
   * \param func The GenericFunc instance
   * \param name The name of the registered GenericFunc
   */
  TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline GenericFuncNode* operator->();

  // declare container type
  using ContainerType = GenericFuncNode;

  // Internal class.
  struct Manager;

 private:
  friend struct Manager;
};

template<typename... Args>
452
inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
453 454 455 456
  const int kNumArgs = sizeof...(Args);
  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
  TVMValue values[kArraySize];
  int type_codes[kArraySize];
457
  runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
458
    std::forward<Args>(args)...);
459
  runtime::TVMRetValue rv;
460 461 462 463 464 465 466 467 468 469 470 471
  CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
  return rv;
}

/*!
 * \brief Represents a generic function that can be specialized on a per-target basis.
 */
class GenericFuncNode : public Node {
 public:
  /*! \brief name of the function */
  std::string name_;
  /* \brief the generic builder */
472
  runtime::PackedFunc generic_func_;
473
  /* \brief map from keys to registered functions */
474
  std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498

  static constexpr const char* _type_key = "GenericFunc";
  TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
};

/*!
 * \brief Mutable access to the node container held by this reference.
 * \return The held node pointer, cast to GenericFuncNode.
 */
inline GenericFuncNode* GenericFunc::operator->() {
  auto* raw_node = node_.get();
  return static_cast<GenericFuncNode*>(raw_node);
}

/*!
 * \def TVM_GENERIC_FUNC_REG_VAR_DEF
 * \brief Declaration prefix for a static, possibly-unused GenericFunc reference
 * whose name starts with __mk_TVM. TVM_REGISTER_GENERIC_FUNC combines it with
 * __COUNTER__ to obtain a distinct variable per registration.
 */
#define TVM_GENERIC_FUNC_REG_VAR_DEF                               \
  static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM

/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Register a new generic function, or set a device-specific variant
 * of the corresponding function.
 *
 * The macro expands to a static GenericFunc reference initialized from
 * ::tvm::GenericFunc::Get(#name). Get returns by value, so the macro is
 * meant to be followed by a chained call that yields a GenericFunc&, such
 * as .set_default(...) or .register_func(...).
 *
 * \param name The name of the function
 */
#define TVM_REGISTER_GENERIC_FUNC(name)                           \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) =     \
      ::tvm::GenericFunc::Get(#name)


}  // namespace tvm

#endif  // TVM_BUILD_MODULE_H_