/*!
 *  Copyright (c) 2017 by Contributors
 * \file message_passing.h
 * \brief Common utilities to do message passing
 *  on the schedule hyper graph.
 */
#ifndef TVM_SCHEDULE_MESSAGE_PASSING_H_
#define TVM_SCHEDULE_MESSAGE_PASSING_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace schedule {
/*!
 * \brief Downward inference of the domain of each IterVar.
 *
 *  The caller seeds p_state with the range of the root IterVar(s);
 *  the function then propagates those ranges towards the leaves.
 *
 * \param stage The stage whose iteration relations are traversed.
 * \param p_state In/out map from each IterVar to its inferred Range.
 * \param allow_missing Whether to allow a missing value.
 */
void PassDownDomain(const Stage& stage,
                    std::unordered_map<IterVar, Range>* p_state,
                    bool allow_missing = false);

/*!
 * \brief Upward inference of the index of each IterVar,
 *  given an index assignment of the leaves.
 *
 * \param stage The stage to operate on.
 * \param dom_map The domain map of each iteration variable's domain.
 * \param p_state The index state of each IterVar.
 * \param allow_missing Whether to allow a missing value.
 */
void PassUpIndex(const Stage& stage,
                 const Map<IterVar, Range>& dom_map,
                 std::unordered_map<IterVar, Expr>* p_state,
                 bool allow_missing = false);

/*!
 * \brief Downward inference of the index of each IterVar,
 *  given an index assignment of the roots.
 *
 * \param stage The stage to operate on.
 * \param dom_map The domain map of each iteration variable's domain.
 * \param p_state The index state of each IterVar.
 * \param allow_missing Whether to allow a missing value.
 */
void PassDownIndex(const Stage& stage,
                   const Map<IterVar, Range>& dom_map,
                   std::unordered_map<IterVar, Expr>* p_state,
                   bool allow_missing = false);

/*!
 * \brief Upward inference of the domain set of each IterVar,
 *  given a domain assignment of the leaves.
 *
 * \param stage The stage to operate on.
 * \param dom_map The domain map of each iteration variable's maximum domain.
 * \param p_state The domain-set (IntSet) state of each IterVar.
 */
void PassUpDomain(const Stage& stage,
                  const std::unordered_map<IterVar, Range>& dom_map,
                  std::unordered_map<IterVar, IntSet>* p_state);

/*!
 * \brief Upward message passing of bitmasks, combined with an OR relation.
 * \param stage The stage to operate on.
 * \param p_state The bitmask state of each IterVar.
 * \param allow_missing Whether to allow a missing value.
 */
void PassUpBitMaskOr(const Stage& stage,
                     std::unordered_map<IterVar, int>* p_state,
                     bool allow_missing = false);

/*!
 * \brief Downward message passing of bitmasks, combined with an OR relation.
 * \param stage The stage to operate on.
 * \param p_state The bitmask state of each IterVar.
 * \param allow_missing Whether to allow a missing value.
 */
void PassDownBitMaskOr(const Stage& stage,
                       std::unordered_map<IterVar, int>* p_state,
                       bool allow_missing = false);

/*!
 * \brief Build the boundary-check predicates implied by a remapped
 *  value of the root IterVar(s).
 *
 * \param stage The stage being operated on.
 * \param dom_map The domain of each IterVar.
 * \param value_map The value assigned to each root IterVar.
 * \param skip_ivar_domain Whether to skip the check against an IterVar's
 *  original domain.
 * \param skip_iter The IterVars for which no bound condition is generated.
 * \return The list of predicates that need to be checked.
 */
std::vector<Expr> MakeBoundCheck(const Stage& stage,
                                 const Map<IterVar, Range>& dom_map,
                                 const std::unordered_map<IterVar, Expr>& value_map,
                                 bool skip_ivar_domain,
                                 const std::unordered_set<IterVar>& skip_iter);

}  // namespace schedule
}  // namespace tvm
#endif  // TVM_SCHEDULE_MESSAGE_PASSING_H_