/*! * Copyright (c) 2017 by Contributors * \file arg_binder.h * \brief Helper utility to match and bind arguments. */ #ifndef TVM_PASS_ARG_BINDER_H_ #define TVM_PASS_ARG_BINDER_H_ #include <tvm/expr.h> #include <tvm/buffer.h> #include <string> #include <vector> namespace tvm { namespace ir { /*! * \brief Helper utility to generate match and bind of arguments. * * \note There is many places in TVM IR where we need argument bindings. * * Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2)). * Here n is a undefined variable that is decided by the outside, tB imposes * a constraint such that it can only take tensor with shape 3, tC imposes * another constraint that it's shape must equals n + 2. * So if we call it with f(bufferA, bufferB, bufferC), we need to generate * the following binding sequence: * - define n = bufferA.shape[0] * - assert bufferB.shape[0] == 3 * - assert bufferB.shape[1] == n + 3 * * In general, this is a constraint solving problem. We have simplified assumption * over the binding declaration, such that we require the variable occured in * constraint must be declared in argument list. So it is illegal to have signature * f(tA(shape=(n+3))) without any argument variable corresponds to n, even though * it is already enough to derive n from the input argument. */ class ArgBinder { public: /*! * \brief Constructor * \param def_map A definition map that contains definition of known variables. * ArgBinder will update this def_map when adding new definitions. */ explicit ArgBinder( std::unordered_map<const Variable*, Expr>* def_map) : def_map_(def_map) { } /*! * \brief Try to bind arg to value, generate constraint if necessary. * \param arg The argument to be binded. * \param value The target expression value * \param arg_name argument name. * \param with_let Whether add lets during bind */ void Bind(const Expr& arg, const Expr& value, const std::string& arg_name, bool with_let = false); /*! * \brief Bind array to array * \param arg The argument to be binded. * \param value The target expression value * \param arg_name argument name. */ void BindArray(const Array<Expr>& arg, const Array<Expr>& value, const std::string& arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer * \param arg The argument to be binded. * \param value The target expression value * \param arg_name argument name. * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1. */ void BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match); /*! * \brief Bind symbolic buffer to a DLTensor handle. * \param buffer The argument buffer to be binded. * \param device_type The device id to be binded. * \param device_id The device id to be binded. * \param handle The DLTensor handle. * \param arg_name argument name. */ void BindDLTensor(const Buffer& buffer, const Expr& device_type, const Expr& device_id, const Var& handle, const std::string& arg_name); /*! \return The defs generated in binding. */ const std::vector<Var>& defs() const { return defs_; } /*! \return The asserts generated in binding */ const std::vector<Stmt>& asserts() const { return asserts_; } /*! * \brief Initialization nest generated * This is only non-empty when BindDLTensor is called. * * \note The binder may choose to generate a let statement * and simply put def_map to map Variable to itself, * or update def_map to directly map to new value and not generate let statement. * * Let statement is usually generated when bind to DLTensor and memory load is involved. * \return The initialization nest generated during binding. */ const std::vector<Stmt>& init_nest() const { return init_nest_; } /*! \return Handle data type of the data */ const Map<Var, Expr>& def_handle_dtype() const { return def_handle_dtype_; } private: // Internal bind function bool Bind_(const Expr& arg, const Expr& value, const std::string& arg_name, bool with_lets); /*! \brief The definition map, can be uses to substitute */ std::unordered_map<const Variable*, Expr>* def_map_; /*! \brief defs generated in the current binder */ std::vector<Var> defs_; /*! \brief Initialize nest */ std::vector<Stmt> init_nest_; /*! \brief handle data type in the defintiions */ Map<Var, Expr> def_handle_dtype_; /*! \brief asserts generated */ std::vector<Stmt> asserts_; }; } // namespace ir } // namespace tvm #endif // TVM_PASS_ARG_BINDER_H_