message_passing.h 2.63 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
/*!
 *  Copyright (c) 2017 by Contributors
 * \file message_passing.h
 * \brief Common utilities to do message passing
 *  on the schedule hyper graph.
 */
#ifndef TVM_SCHEDULE_MESSAGE_PASSING_H_
#define TVM_SCHEDULE_MESSAGE_PASSING_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace schedule {
/*!
 * \brief Downward inference of the domain of each IterVar.
 *  The caller sets the range of the root in \p p_state, then the
 *  function propagates it towards the leaves.
 *
 * \param stage The stage to operate on.
 * \param p_state The domain (Range) state of each IterVar; provides the
 *  root range on input and receives the ranges propagated below it.
 * \param allow_missing Whether to allow missing values.
 */
void PassDownDomain(
    const Stage& stage,
    std::unordered_map<IterVar, Range>* p_state,
    bool allow_missing = false);

/*!
 * \brief Upward inference of the index of each IterVar.
 *  Given the index assignment of the leaves, propagates indices
 *  towards the root.
 *
 * \param stage The stage to operate on.
 * \param dom_map Map from each IterVar to its domain.
 * \param p_state The index state (Expr) of each IterVar.
 * \param allow_missing Whether to allow missing values.
 */
void PassUpIndex(const Stage& stage,
                 const Map<IterVar, Range>& dom_map,
                 std::unordered_map<IterVar, Expr>* p_state,
                 bool allow_missing = false);

/*!
 * \brief Upward inference of the domain set of each IterVar.
 *  Given the domain assignment of the leaves, propagates domain
 *  sets towards the root.
 *
 * \param stage The stage to operate on.
 * \param dom_map Map from each IterVar to its maximum domain.
 * \param p_state The domain-set (IntSet) state of each IterVar.
 */
void PassUpDomain(const Stage& stage,
                  const std::unordered_map<IterVar, Range>& dom_map,
                  std::unordered_map<IterVar, IntSet>* p_state);

/*!
 * \brief Upward message passing of bitmasks combined with bitwise OR.
 * \param stage The stage to operate on.
 * \param p_state The bitmask state of each IterVar.
 * \param allow_missing Whether to allow missing values.
 */
void PassUpBitMaskOr(const Stage& stage,
                     std::unordered_map<IterVar, int>* p_state,
                     bool allow_missing = false);

/*!
 * \brief Downward message passing of bitmasks combined with bitwise OR.
 * \param stage The stage to operate on.
 * \param p_state The bitmask state of each IterVar.
 * \param allow_missing Whether to allow missing values.
 */
void PassDownBitMaskOr(const Stage& stage,
                       std::unordered_map<IterVar, int>* p_state,
                       bool allow_missing = false);
}  // namespace schedule
}  // namespace tvm
#endif  // TVM_SCHEDULE_MESSAGE_PASSING_H_