/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file graph.h
 * \brief Utilities to get information about schedule graph.
 */
#ifndef TVM_SCHEDULE_GRAPH_H_
#define TVM_SCHEDULE_GRAPH_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace schedule {

/*!
 * \brief data structure of Operation->Tensors it reads
 */
41
using ReadGraph = Map<Operation, Array<Tensor> >;
42 43

/*!
 * \brief AttachPath maps each Operation to a list of IterVar.
 */
using AttachPath = Map<Operation, Array<IterVar> >;

/*!
 * \brief Feed graph: maps each Tensor to the Operations it feeds to.
 */
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;

/*!
54 55 56 57
 * \brief Get read graph of each operation to all the
 *  Tensors that it directly depends on.
 *
 *  The result map contains Operations needed to finish root Operation.
58
 * \param roots The root operation.
59 60
 * \return The result map.
 */
61
ReadGraph CreateReadGraph(const Array<Operation>& roots);
62 63

/*!
 * \brief Get the minimum subgraph between outputs and inputs.
 *  The result contains the operations that are reachable from any of the
 *  inputs and that can reach any of the outputs.
 *
 *  The outputs are included in the subgraph. The inputs are excluded unless
 *  requested (presumably via include_inputs -- TODO confirm against the
 *  implementation).
 *
 * \param outputs The outputs of the subgraph.
 * \param inputs The inputs to the subgraph.
 * \param include_inputs Whether to include inputs.
 *
 * \return The subgraph.
 */
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs);

/*!
81
 * \brief Get a post DFS ordered of operations in the graph.
82
 * \param roots The root of the graph.
83 84 85 86 87 88
 * \param g The read graph.
 * \return vector order of Operations in PostDFS order.
 *
 * \note PostDFSOrder is a special case of Topoligical order,
 *   and can be used when topoligical order is needed.
 */
89
Array<Operation> PostDFSOrder(
90
    const Array<Operation>& roots, const ReadGraph& g);
91

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
/*!
 * \brief Create feedgraph for given Schedule
 * \param  g The read graph.
 * \return The created feedgraph.
 */
FeedGraph CreateFeedGraph(const ReadGraph& g);

/*!
 * \brief Create the AttachPath, which maps each op to a list of IterVar
 *  representing the loop nest the op sits in, from innermost to outermost.
 *  Also inserts attach_stage for scan updates when needed.
 *
 * \param sch The schedule.
 * \return The attach path.
 */
AttachPath CreateAttachPath(Schedule sch);

/*!
 * \brief Get all operations inside the recursion of scan.
111
 * \param scan_op The scan node ops.
112 113
 * \return The body operations, in read dependency order.
 */
114
Array<Operation> ScanGetBody(const Operation& scan_op);
115 116 117 118 119 120 121 122 123 124 125 126

/*!
 * \brief Analyze each spatial dimension of scan's result.
 *  Give check on whether each dimension is fix point,
 *  An axis is a fixed point if it only refers back to itself in recursion
 *  and it is not used in axis of other recursion field.
 *
 *  next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
 *
 * \param scan The scan node.
 * \return Map of spatial_axis -> IntImm
 */
127
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan);
128

}  // namespace schedule
}  // namespace tvm

#endif  // TVM_SCHEDULE_GRAPH_H_