graph.h 3.95 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23
/*!
 * \file graph.h
 * \brief Utilities to get information about schedule graph.
 */
24 25
#ifndef TVM_TE_SCHEDULE_GRAPH_H_
#define TVM_TE_SCHEDULE_GRAPH_H_
26

27
#include <tvm/tir/expr.h>
28 29
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
30
#include <unordered_map>
31
#include <unordered_set>
32 33 34
#include <vector>

namespace tvm {
35
namespace te {
36 37 38 39

/*!
 * \brief data structure of Operation->Tensors it reads
 */
40
using ReadGraph = Map<Operation, Array<Tensor> >;
41 42

/*!
 * \brief AttachPath maps each Operation to a list of IterVar.
 */
using AttachPath = Map<Operation, Array<IterVar> >;

/*!
 * \brief Feed graph: maps each Tensor to the Operations it feeds to.
 */
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;

/*!
53 54 55 56
 * \brief Get read graph of each operation to all the
 *  Tensors that it directly depends on.
 *
 *  The result map contains Operations needed to finish root Operation.
57
 * \param roots The root operation.
58 59
 * \return The result map.
 */
60
ReadGraph CreateReadGraph(const Array<Operation>& roots);
61 62

/*!
 * \brief Get minimum subgraph between outputs and inputs.
 *  The subgraph contains the operations that are reachable from any of the
 *  inputs and from which any of the outputs can be reached.
 *
 *  The outputs are included in the subgraph. Whether the inputs are included
 *  is presumably controlled by include_inputs (an earlier version of this note
 *  stated that inputs are never included -- TODO confirm against the
 *  implementation).
 *
 * \param outputs The outputs of the subgraph
 * \param inputs The inputs to the subgraph.
 * \param include_inputs Whether to include inputs
 *
 * \return The subgraph.
 */
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs);

/*!
80
 * \brief Get a post DFS ordered of operations in the graph.
81
 * \param roots The root of the graph.
82 83 84 85 86 87
 * \param g The read graph.
 * \return vector order of Operations in PostDFS order.
 *
 * \note PostDFSOrder is a special case of Topoligical order,
 *   and can be used when topoligical order is needed.
 */
88
Array<Operation> PostDFSOrder(
89
    const Array<Operation>& roots, const ReadGraph& g);
90

91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
/*!
 * \brief Create the feed graph from a given read graph.
 * \param g The read graph.
 * \return The created feed graph.
 */
FeedGraph CreateFeedGraph(const ReadGraph& g);

/*!
 * \brief Create AttachPath that maps op -> a list of IterVar
 *  that represents the loop nest the op sits in, from innermost to outermost.
 *  Also inserts attach_stage for scan updates when needed.
 *
 * \param sch The schedule.
 * \return The attach path.
 */
AttachPath CreateAttachPath(Schedule sch);

/*!
 * \brief Get all operations inside the recursion of scan.
110
 * \param scan_op The scan node ops.
111 112
 * \return The body operations, in read dependency order.
 */
113
Array<Operation> ScanGetBody(const Operation& scan_op);
114 115 116 117 118 119 120 121 122 123 124 125

/*!
 * \brief Analyze each spatial dimension of scan's result.
 *  Give check on whether each dimension is fix point,
 *  An axis is a fixed point if it only refers back to itself in recursion
 *  and it is not used in axis of other recursion field.
 *
 *  next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
 *
 * \param scan The scan node.
 * \return Map of spatial_axis -> IntImm
 */
126
Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
127

128
}  // namespace te
129 130
}  // namespace tvm

131
#endif  // TVM_TE_SCHEDULE_GRAPH_H_