/*!
 *  Copyright (c) 2017 by Contributors
 * \file arg_binder.h
 * \brief Helper utility to match and bind arguments.
 */
#ifndef TVM_PASS_ARG_BINDER_H_
#define TVM_PASS_ARG_BINDER_H_

#include <tvm/expr.h>
#include <tvm/buffer.h>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace ir {

/*!
 * \brief Helper utility to generate match and bind of arguments.
 *
 * \note There are many places in TVM IR where we need argument bindings.
 *
 *  Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2))).
 *  Here n is an undefined variable that is decided by the outside, tB imposes
 *  a constraint such that it can only take a tensor with shape 3, and tC imposes
 *  another constraint that its shape must equal n + 2.
 *  So if we call it with f(bufferA, bufferB, bufferC), we need to generate
 *  the following binding sequence:
 *  - define n = bufferA.shape[0]
 *  - assert bufferB.shape[0] == 3
 *  - assert bufferC.shape[0] == n + 2
 *
 *  In general, this is a constraint solving problem. We make a simplifying assumption
 *  about the binding declaration: every variable that occurs in a
 *  constraint must be declared in the argument list. So it is illegal to have the signature
 *  f(tA(shape=(n+3))) without any argument variable corresponding to n, even though
 *  it is already enough to derive n from the input argument.
 */
class ArgBinder {
 public:
  /*!
   * \brief Constructor
   * \param def_map A definition map that contains definition of known variables.
   *   ArgBinder will update this def_map when adding new definitions.
   */
  explicit ArgBinder(
      std::unordered_map<const Variable*, Expr>* def_map)
      : def_map_(def_map) {
  }
  /*!
   * \brief Try to bind arg to value, generate constraint if necessary.
   * \param arg The argument to be binded.
   * \param value The target expression value
   * \param arg_name argument name.
   * \param with_let Whether add lets during bind
   */
  void Bind(const Expr& arg,
            const Expr& value,
            const std::string& arg_name,
            bool with_let = false);
  /*!
   * \brief Bind array to array
   * \param arg The argument to be binded.
   * \param value The target expression value
   * \param arg_name argument name.
   */
  void BindArray(const Array<Expr>& arg,
                 const Array<Expr>& value,
                 const std::string& arg_name);
  /*!
   * \brief Bind symbolic buffer to another symbolic buffer
   * \param arg The argument to be binded.
   * \param value The target expression value
   * \param arg_name argument name.
74
   * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1.
75 76 77
   */
  void BindBuffer(const Buffer& arg,
                  const Buffer& value,
78 79
                  const std::string& arg_name,
                  bool fuzzy_match);
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
  /*!
   * \brief Bind symbolic buffer to a DLTensor handle.
   * \param buffer The argument buffer to be binded.
   * \param device_type The device id to be binded.
   * \param device_id The device id to be binded.
   * \param handle The DLTensor handle.
   * \param arg_name argument name.
   */
  void BindDLTensor(const Buffer& buffer,
                    const Expr& device_type,
                    const Expr& device_id,
                    const Var& handle,
                    const std::string& arg_name);

  /*! \return The defs generated in binding. */
  const std::vector<Var>& defs() const {
    return defs_;
  }
  /*! \return The asserts generated in binding */
  const std::vector<Stmt>& asserts() const {
    return asserts_;
  }
  /*!
   * \brief Initialization nest generated
   *  This is only non-empty when BindDLTensor is called.
   *
   * \note The binder may choose to generate a let statement
   *  and simply put def_map to map Variable to itself,
   *  or update def_map to directly map to new value and not generate let statement.
   *
   *  Let statement is usually generated when bind to DLTensor and memory load is involved.
   * \return The initialization nest generated during binding.
   */
  const std::vector<Stmt>& init_nest() const {
    return init_nest_;
  }
  /*! \return Handle data type of the data */
  const Map<Var, Expr>& def_handle_dtype() const {
    return def_handle_dtype_;
  }

 private:
  // Internal bind function
  bool Bind_(const Expr& arg,
             const Expr& value,
             const std::string& arg_name,
             bool with_lets);
  /*! \brief The definition map, can be uses to substitute */
  std::unordered_map<const Variable*, Expr>* def_map_;
  /*! \brief defs generated in the current binder */
  std::vector<Var> defs_;
  /*! \brief Initialize nest */
  std::vector<Stmt> init_nest_;
  /*! \brief handle data type in the defintiions */
  Map<Var, Expr> def_handle_dtype_;
  /*! \brief asserts generated */
  std::vector<Stmt> asserts_;
};
}  // namespace ir
}  // namespace tvm
#endif  // TVM_PASS_ARG_BINDER_H_