arg_binder.h 5.8 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23
/*!
 * \file arg_binder.h
 * \brief Helper utility to match and bind arguments.
 */
24 25
#ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_
#define TVM_TIR_TRANSFORMS_ARG_BINDER_H_
26

27 28
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
29 30
#include <tvm/arith/analyzer.h>

31 32
#include <string>
#include <vector>
33
#include <unordered_map>
34 35

namespace tvm {
36
namespace tir {
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53

/*!
 * \brief Helper utility to generate match and bind of arguments.
 *
 * \note There are many places in TVM IR where we need argument bindings.
 *
 *  Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2))).
 *  Here n is an undefined variable that is decided by the outside, tB imposes
 *  a constraint such that it can only take a tensor with shape 3, and tC imposes
 *  another constraint that its shape must equal n + 2.
 *  So if we call it with f(bufferA, bufferB, bufferC), we need to generate
 *  the following binding sequence:
 *  - define n = bufferA.shape[0]
 *  - assert bufferB.shape[0] == 3
 *  - assert bufferC.shape[0] == n + 2
 *
 *  In general, this is a constraint solving problem. We make a simplifying assumption
 *  about the binding declaration: every variable that occurs in a constraint
 *  must be declared in the argument list. So it is illegal to have the signature
 *  f(tA(shape=(n+3))) without any argument variable that corresponds to n, even though
 *  the input argument alone is already enough to derive n.
 */
class ArgBinder {
 public:
  /*!
   * \brief Constructor
   * \param def_map A definition map that contains definition of known variables.
   *   ArgBinder will update this def_map when adding new definitions.
   */
  explicit ArgBinder(
67
      std::unordered_map<const VarNode*, PrimExpr>* def_map)
68 69 70 71 72 73 74 75 76
      : def_map_(def_map) {
  }
  /*!
   * \brief Try to bind arg to value, generate constraint if necessary.
   * \param arg The argument to be binded.
   * \param value The target expression value
   * \param arg_name argument name.
   * \param with_let Whether add lets during bind
   */
77 78
  void Bind(const PrimExpr& arg,
            const PrimExpr& value,
79 80 81 82 83 84 85 86
            const std::string& arg_name,
            bool with_let = false);
  /*!
   * \brief Bind array to array
   * \param arg The argument to be binded.
   * \param value The target expression value
   * \param arg_name argument name.
   */
87 88
  void BindArray(const Array<PrimExpr>& arg,
                 const Array<PrimExpr>& value,
89 90 91 92 93 94
                 const std::string& arg_name);
  /*!
   * \brief Bind symbolic buffer to another symbolic buffer
   * \param arg The argument to be binded.
   * \param value The target expression value
   * \param arg_name argument name.
95
   * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1.
96 97 98
   */
  void BindBuffer(const Buffer& arg,
                  const Buffer& value,
99 100
                  const std::string& arg_name,
                  bool fuzzy_match);
101 102 103 104 105 106 107 108 109
  /*!
   * \brief Bind symbolic buffer to a DLTensor handle.
   * \param buffer The argument buffer to be binded.
   * \param device_type The device id to be binded.
   * \param device_id The device id to be binded.
   * \param handle The DLTensor handle.
   * \param arg_name argument name.
   */
  void BindDLTensor(const Buffer& buffer,
110 111
                    const PrimExpr& device_type,
                    const PrimExpr& device_id,
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
                    const Var& handle,
                    const std::string& arg_name);

  /*! \return The defs generated in binding. */
  const std::vector<Var>& defs() const {
    return defs_;
  }
  /*! \return The asserts generated in binding */
  const std::vector<Stmt>& asserts() const {
    return asserts_;
  }
  /*!
   * \brief Initialization nest generated
   *  This is only non-empty when BindDLTensor is called.
   *
   * \note The binder may choose to generate a let statement
   *  and simply put def_map to map Variable to itself,
   *  or update def_map to directly map to new value and not generate let statement.
   *
   *  Let statement is usually generated when bind to DLTensor and memory load is involved.
   * \return The initialization nest generated during binding.
   */
  const std::vector<Stmt>& init_nest() const {
    return init_nest_;
  }
  /*! \return Handle data type of the data */
138
  const Map<Var, PrimExpr>& def_handle_dtype() const {
139 140 141 142 143
    return def_handle_dtype_;
  }

 private:
  // Internal bind function
144 145
  bool Bind_(const PrimExpr& arg,
             const PrimExpr& value,
146 147 148
             const std::string& arg_name,
             bool with_lets);
  /*! \brief The definition map, can be uses to substitute */
149
  std::unordered_map<const VarNode*, PrimExpr>* def_map_;
150 151 152 153 154
  /*! \brief defs generated in the current binder */
  std::vector<Var> defs_;
  /*! \brief Initialize nest */
  std::vector<Stmt> init_nest_;
  /*! \brief handle data type in the defintiions */
155
  Map<Var, PrimExpr> def_handle_dtype_;
156 157
  /*! \brief asserts generated */
  std::vector<Stmt> asserts_;
158 159
  /*! \brief internal analyzer. */
  arith::Analyzer analyzer_;
160
};
161
}  // namespace tir
162
}  // namespace tvm
163
#endif  // TVM_TIR_TRANSFORMS_ARG_BINDER_H_