Unverified Commit 14ba49c6 by Zhi Committed by GitHub

[refactor][relay pass] Separate analysis and transform passes (#5035)

* [refactor][relay pass] Separate analysis and transform passes into different subfolders

* remove pass folder
parent d2a79a5f
...@@ -146,7 +146,9 @@ file(GLOB_RECURSE RELAY_OP_SRCS ...@@ -146,7 +146,9 @@ file(GLOB_RECURSE RELAY_OP_SRCS
src/relay/op/*.cc src/relay/op/*.cc
) )
file(GLOB_RECURSE RELAY_PASS_SRCS file(GLOB_RECURSE RELAY_PASS_SRCS
src/relay/pass/*.cc src/relay/analysis/*.cc
src/relay/transforms/*.cc
src/relay/quantize/*.cc
) )
file(GLOB RELAY_BACKEND_SRCS file(GLOB RELAY_BACKEND_SRCS
src/relay/backend/*.cc src/relay/backend/*.cc
......
...@@ -23,7 +23,6 @@ configuring the passes and scripting them in Python. ...@@ -23,7 +23,6 @@ configuring the passes and scripting them in Python.
from tvm.ir import RelayExpr, IRModule from tvm.ir import RelayExpr, IRModule
from . import _analysis from . import _analysis
from . import _make
from .ty import Type from .ty import Type
from .feature import Feature from .feature import Feature
...@@ -237,7 +236,7 @@ def alpha_equal(lhs, rhs): ...@@ -237,7 +236,7 @@ def alpha_equal(lhs, rhs):
result : bool result : bool
True iff lhs is alpha equal to rhs. True iff lhs is alpha equal to rhs.
""" """
return bool(_make._alpha_equal(lhs, rhs)) return bool(_analysis._alpha_equal(lhs, rhs))
def assert_alpha_equal(lhs, rhs): def assert_alpha_equal(lhs, rhs):
...@@ -251,7 +250,7 @@ def assert_alpha_equal(lhs, rhs): ...@@ -251,7 +250,7 @@ def assert_alpha_equal(lhs, rhs):
rhs : tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
""" """
_make._assert_alpha_equal(lhs, rhs) _analysis._assert_alpha_equal(lhs, rhs)
def graph_equal(lhs, rhs): def graph_equal(lhs, rhs):
...@@ -273,7 +272,7 @@ def graph_equal(lhs, rhs): ...@@ -273,7 +272,7 @@ def graph_equal(lhs, rhs):
result : bool result : bool
True iff lhs is data-flow equivalent to rhs. True iff lhs is data-flow equivalent to rhs.
""" """
return bool(_make._graph_equal(lhs, rhs)) return bool(_analysis._graph_equal(lhs, rhs))
def assert_graph_equal(lhs, rhs): def assert_graph_equal(lhs, rhs):
...@@ -290,7 +289,7 @@ def assert_graph_equal(lhs, rhs): ...@@ -290,7 +289,7 @@ def assert_graph_equal(lhs, rhs):
rhs : tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
""" """
_make._assert_graph_equal(lhs, rhs) _analysis._assert_graph_equal(lhs, rhs)
def collect_device_info(expr): def collect_device_info(expr):
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include "doc.h" #include "doc.h"
#include "meta_data.h" #include "meta_data.h"
#include "../relay/pass/dependency_graph.h" #include "../relay/analysis/dependency_graph.h"
#include "../ir/attr_functor.h" #include "../ir/attr_functor.h"
namespace tvm { namespace tvm {
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file src/tvm/relay/ir/alpha_equal.cc * \file src/relay/analysis/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes. * \brief Alpha equality check by deep comparing two nodes.
*/ */
#include <tvm/ir/type_functor.h> #include <tvm/ir/type_functor.h>
...@@ -593,8 +593,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { ...@@ -593,8 +593,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs); return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
} }
// TODO(@jroesch): move to correct namespace? TVM_REGISTER_GLOBAL("relay._analysis._alpha_equal")
TVM_REGISTER_GLOBAL("relay._make._alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(false, false).Equal(a, b); return AlphaEqualHandler(false, false).Equal(a, b);
}); });
...@@ -604,18 +603,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal") ...@@ -604,18 +603,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
return AlphaEqual(a, b); return AlphaEqual(a, b);
}); });
TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal") TVM_REGISTER_GLOBAL("relay._analysis._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
}); });
TVM_REGISTER_GLOBAL("relay._make._graph_equal") TVM_REGISTER_GLOBAL("relay._analysis._graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(true, false).Equal(a, b); return AlphaEqualHandler(true, false).Equal(a, b);
}); });
TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal") TVM_REGISTER_GLOBAL("relay._analysis._assert_graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file tvm/relay/pass/call_graph.cc * \file src/relay/analysis/call_graph.cc
* \brief Implementation of APIs to handle the call graph of a Relay module. * \brief Implementation of APIs to handle the call graph of a Relay module.
*/ */
......
...@@ -18,15 +18,15 @@ ...@@ -18,15 +18,15 @@
*/ */
/*! /*!
* \file tvm/relay/pass/call_graph.h * \file src/relay/analysis/call_graph.h
* \brief Define data structures for the call graph of a IRModule. It borrows * \brief Define data structures for the call graph of a IRModule. It borrows
* the idea how LLVM constructs CallGraph. * the idea how LLVM constructs CallGraph.
* *
* https://llvm.org/doxygen/CallGraph_8h_source.html * https://llvm.org/doxygen/CallGraph_8h_source.html
*/ */
#ifndef TVM_RELAY_PASS_CALL_GRAPH_H_ #ifndef TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
#define TVM_RELAY_PASS_CALL_GRAPH_H_ #define TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -510,4 +510,4 @@ class CallGraphEntry { ...@@ -510,4 +510,4 @@ class CallGraphEntry {
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_CALL_GRAPH_H_ #endif // TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
*/ */
/*! /*!
* \file tvm/relay/pass/dependency_graph.cc * \file src/relay/analysis/dependency_graph.cc
* \brief * \brief Implementation of dependency graph APIs.
*/ */
#include "dependency_graph.h" #include "dependency_graph.h"
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
......
...@@ -18,16 +18,16 @@ ...@@ -18,16 +18,16 @@
*/ */
/*! /*!
* \file tvm/relay/pass/dependency_graph.h * \file src/relay/analysis/dependency_graph.h
* \brief create a dependency graph. * \brief create a dependency graph.
*/ */
#ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ #ifndef TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_
#define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ #define TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "let_list.h" #include "../transforms/let_list.h"
#include "../../support/arena.h" #include "../../support/arena.h"
namespace tvm { namespace tvm {
...@@ -72,4 +72,4 @@ class DependencyGraph { ...@@ -72,4 +72,4 @@ class DependencyGraph {
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ #endif // TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include "pass_util.h" #include "../transforms/pass_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include "pattern_util.h" #include "../transforms/pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* \file type_solver.h * \file type_solver.h
* \brief Solver logic for type inference. * \brief Solver logic for type inference.
*/ */
#ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_ #ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
#define TVM_RELAY_PASS_TYPE_SOLVER_H_ #define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
...@@ -34,7 +34,6 @@ ...@@ -34,7 +34,6 @@
#include <unordered_set> #include <unordered_set>
#include "../../support/arena.h" #include "../../support/arena.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -219,4 +218,4 @@ class TypeSolver { ...@@ -219,4 +218,4 @@ class TypeSolver {
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_SOLVER_H_ #endif // TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include "pass_util.h" #include "../transforms/pass_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <vector> #include <vector>
#include "../utils.h" #include "../utils.h"
#include "../../backend/compile_engine.h" #include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h" #include "../../transforms/pass_util.h"
#include "../../op/op_common.h" #include "../../op/op_common.h"
#include "compiler.h" #include "compiler.h"
......
...@@ -41,7 +41,7 @@ ...@@ -41,7 +41,7 @@
#include "../../../runtime/vm/profiler/vm.h" #include "../../../runtime/vm/profiler/vm.h"
#include "../../../runtime/vm/naive_allocator.h" #include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h" #include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h" #include "../../transforms/pass_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h> #include <topi/elemwise.h>
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
#include "../type_relations.h" #include "../type_relations.h"
namespace tvm { namespace tvm {
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include "type_relations.h" #include "type_relations.h"
#include "../pass/infer_layout_util.h" #include "../transforms/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include <tvm/relay/attrs/memory.h> #include <tvm/relay/attrs/memory.h>
#include "../op_common.h" #include "../op_common.h"
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
#include "../type_relations.h" #include "../type_relations.h"
namespace tvm { namespace tvm {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include "../op_common.h" #include "../op_common.h"
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
#include "../op_common.h" #include "../op_common.h"
#include "convolution.h" #include "convolution.h"
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "../type_relations.h" #include "../type_relations.h"
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
#include "../op_common.h" #include "../op_common.h"
#include "nn.h" #include "nn.h"
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h> #include <topi/nn/pooling.h>
#include <vector> #include <vector>
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "type_relations.h" #include "type_relations.h"
#include "../pass/infer_layout_util.h" #include "../transforms/infer_layout_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -36,8 +36,8 @@ ...@@ -36,8 +36,8 @@
#include <vector> #include <vector>
#include "../op_common.h" #include "../op_common.h"
#include "../../../arith/compute_expr.h" #include "../../../arith/compute_expr.h"
#include "../../pass/infer_layout_util.h" #include "../../transforms/infer_layout_util.h"
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "transform.h" #include "transform.h"
namespace tvm { namespace tvm {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
#include "op_common.h" #include "op_common.h"
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include "../../op/tensor/transform.h" #include "../../op/tensor/transform.h"
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
namespace tvm { namespace tvm {
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "../../op/nn/convolution.h" #include "../../op/nn/convolution.h"
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
namespace tvm { namespace tvm {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include "../../op/nn/nn.h" #include "../../op/nn/nn.h"
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
namespace tvm { namespace tvm {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
namespace tvm { namespace tvm {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
#include "op_common.h" #include "op_common.h"
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
namespace tvm { namespace tvm {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h> #include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h" #include "../../transforms/pattern_util.h"
#include "../util.h" #include "../util.h"
namespace tvm { namespace tvm {
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include "util.h" #include "util.h"
#include "../pass/pattern_util.h" #include "../transforms/pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
*/ */
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "../pattern_util.h" #include "../transforms/pattern_util.h"
#include "./quantize.h" #include "./quantize.h"
namespace tvm { namespace tvm {
......
...@@ -18,16 +18,16 @@ ...@@ -18,16 +18,16 @@
*/ */
/*! /*!
* \file tvm/relay/pass/quantize.h * \file tvm/relay/quantize.h
* \brief Header of definitions for quantization * \brief Header of definitions for quantization
*/ */
#ifndef TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ #ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_
#define TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ #define TVM_RELAY_QUANTIZE_QUANTIZE_H_
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
#include "../pattern_util.h" #include "../transforms/pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -163,4 +163,4 @@ struct QConfigContext { ...@@ -163,4 +163,4 @@ struct QConfigContext {
} // namespace quantize } // namespace quantize
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ #endif // TVM_RELAY_QUANTIZE_QUANTIZE_H_
...@@ -29,8 +29,8 @@ ...@@ -29,8 +29,8 @@
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include "./quantize.h" #include "./quantize.h"
#include "../pattern_util.h" #include "../transforms/pattern_util.h"
#include "../../qnn/util.h" #include "../qnn/util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file src/relay/pass/annotate_target.cc * \file src/relay/transforms/annotate_target.cc
* \brief Wraps a call with compiler_begin and compiler_end to indicate that * \brief Wraps a call with compiler_begin and compiler_end to indicate that
* the op of this call node will use external compiler. * the op of this call node will use external compiler.
*/ */
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "pattern_util.h"
#include "pass_util.h" #include "pass_util.h"
#include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -41,8 +41,8 @@ ...@@ -41,8 +41,8 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "./expr_subst.h" #include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h" #include "./combine_parallel_op.h"
#include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "./expr_subst.h" #include "./expr_subst.h"
#include "./pattern_util.h" #include "pattern_util.h"
#include "./combine_parallel_op_batch.h" #include "./combine_parallel_op_batch.h"
namespace tvm { namespace tvm {
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
* \file combine_parallel_op.h * \file combine_parallel_op.h
* \brief Abstract class to combine parallel ops and their successive element-wise ops. * \brief Abstract class to combine parallel ops and their successive element-wise ops.
*/ */
#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ #ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "./expr_subst.h" #include "./expr_subst.h"
#include "./pattern_util.h" #include "pattern_util.h"
namespace tvm { namespace tvm {
...@@ -237,4 +237,4 @@ class ParallelOpCombiner { ...@@ -237,4 +237,4 @@ class ParallelOpCombiner {
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ #endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
...@@ -53,9 +53,9 @@ ...@@ -53,9 +53,9 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "./expr_subst.h" #include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h" #include "./combine_parallel_op.h"
#include "./combine_parallel_op_batch.h" #include "./combine_parallel_op_batch.h"
#include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
*/ */
/*! /*!
* \file combine_parallel_op_batch.cc * \file combine_parallel_op_batch.h
* \brief Combine parallel ops into a single batch op. * \brief Combine parallel ops into a single batch op.
*/ */
#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ #ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
...@@ -34,8 +34,8 @@ ...@@ -34,8 +34,8 @@
#include <unordered_set> #include <unordered_set>
#include <string> #include <string>
#include "./expr_subst.h" #include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h" #include "./combine_parallel_op.h"
#include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -145,4 +145,4 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { ...@@ -145,4 +145,4 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner {
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ #endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
...@@ -563,9 +563,6 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) { ...@@ -563,9 +563,6 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceInfo") TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceInfo")
.set_body_typed(CollectDeviceInfo); .set_body_typed(CollectDeviceInfo);
TVM_REGISTER_GLOBAL("relay._analysis.RewriteDeviceAnnotation")
.set_body_typed(RewriteAnnotatedOps);
TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps") TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps")
.set_body_typed(CollectDeviceAnnotationOps); .set_body_typed(CollectDeviceAnnotationOps);
...@@ -574,7 +571,7 @@ namespace transform { ...@@ -574,7 +571,7 @@ namespace transform {
Pass RewriteAnnotatedOps(int fallback_device) { Pass RewriteAnnotatedOps(int fallback_device) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device)); return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
}; };
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
{tir::StringImmNode::make("InferType")}); {tir::StringImmNode::make("InferType")});
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <unordered_map> #include <unordered_map>
#include "./pattern_util.h" #include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* \file expr_subst.h * \file expr_subst.h
* \brief Utility functions for substituting expressions. * \brief Utility functions for substituting expressions.
*/ */
#ifndef TVM_RELAY_PASS_EXPR_SUBST_H_ #ifndef TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_
#define TVM_RELAY_PASS_EXPR_SUBST_H_ #define TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <unordered_map> #include <unordered_map>
...@@ -34,4 +34,4 @@ Expr ExprSubst(const Expr& expr, ...@@ -34,4 +34,4 @@ Expr ExprSubst(const Expr& expr,
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_EXPR_SUBST_H_ #endif // TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
/*! /*!
* *
* \file src/tvm/relay/pass/fuse_ops.cc * \file src/relay/transforms/fuse_ops.cc
* *
* \brief This is a backend-aware optimization pass. * \brief This is a backend-aware optimization pass.
* Fuse necessary ops into a single one. * Fuse necessary ops into a single one.
...@@ -29,10 +29,9 @@ ...@@ -29,10 +29,9 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "./pattern_util.h" #include "pattern_util.h"
#include "../../support/arena.h" #include "../../support/arena.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
custom layouts or other general weight pre-transformation. custom layouts or other general weight pre-transformation.
*/ */
#ifndef TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ #ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
#define TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ #define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
#include <tvm/tir/data_layout.h> #include <tvm/tir/data_layout.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -232,4 +232,4 @@ static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts ...@@ -232,4 +232,4 @@ static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ #endif // TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file tvm/relay/pass/inline.cc * \file src/relay/transforms/inline.cc
* \brief Global function inliner. It contains the following steps: * \brief Global function inliner. It contains the following steps:
* *
* - Preprocessing: eligibility checking. Only inline the functions that can * - Preprocessing: eligibility checking. Only inline the functions that can
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "call_graph.h" #include "../analysis/call_graph.h"
using namespace tvm::runtime; using namespace tvm::runtime;
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
* if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);', * if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
* the AST will contain 2 'a', as b and c are now variables. * the AST will contain 2 'a', as b and c are now variables.
*/ */
#ifndef TVM_RELAY_PASS_LET_LIST_H_ #ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_
#define TVM_RELAY_PASS_LET_LIST_H_ #define TVM_RELAY_TRANSFORMS_LET_LIST_H_
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
...@@ -149,4 +149,4 @@ class LetList { ...@@ -149,4 +149,4 @@ class LetList {
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_LET_LIST_H_ #endif // TVM_RELAY_TRANSFORMS_LET_LIST_H_
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* \file src/relay/pass/merge_composite.cc * \file src/relay/transforms/merge_composite.cc
* \brief Merges expressions matching patterns into functions marked * \brief Merges expressions matching patterns into functions marked
* as 'composite'. This is primarily intended to be used alongside the * as 'composite'. This is primarily intended to be used alongside the
* external codegen infrastructure to support the case where multiple * external codegen infrastructure to support the case where multiple
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/* /*
* \file src/relay/pass/partition_graph.cc * \file src/relay/transforms/partition_graph.cc
* *
* \brief Partition an input function into multiple functions according based * \brief Partition an input function into multiple functions according based
* on the inserted annotation nodes (i.e. compiler_begin and compiler_end). * on the inserted annotation nodes (i.e. compiler_begin and compiler_end).
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
/*! /*!
* *
* \file tvm/relay/pass/pass_util.h * \file tvm/relay/_transforms/pass_util.h
* \brief Utilities for writing passes * \brief Utilities for writing passes
*/ */
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_ #ifndef TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_ #define TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -167,4 +167,4 @@ struct TreeBranchNode : TreeNode<ConditionObjectPtr> { ...@@ -167,4 +167,4 @@ struct TreeBranchNode : TreeNode<ConditionObjectPtr> {
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_ #endif // TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
...@@ -19,12 +19,12 @@ ...@@ -19,12 +19,12 @@
/*! /*!
* *
* \file tvm/relay/pass/pattern_util.h * \file tvm/relay/_pattern_util.h
* \brief Header of internal operator functions * \brief Header of internal operator functions
* These can be used for writing passes. * These can be used for writing passes.
*/ */
#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_ #ifndef TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
#define TVM_RELAY_PASS_PATTERN_UTIL_H_ #define TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
#include <builtin_fp16.h> #include <builtin_fp16.h>
#include <tvm/tir/data_layout.h> #include <tvm/tir/data_layout.h>
...@@ -606,4 +606,4 @@ Expr CastHint(Expr data, DataType dtype); ...@@ -606,4 +606,4 @@ Expr CastHint(Expr data, DataType dtype);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ #endif // TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
/*! /*!
* *
* \file src/relay/pass/print_ir.cc * \file src/relay/transforms/print_ir.cc
* *
* \brief Print the module IR to help debugging. * \brief Print the module IR to help debugging.
*/ */
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include "./pattern_util.h" #include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -29,9 +29,9 @@ ...@@ -29,9 +29,9 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h> #include <tvm/support/logging.h>
#include "let_list.h" #include "let_list.h"
#include "../../support/arena.h"
#include "pass_util.h" #include "pass_util.h"
#include "dependency_graph.h" #include "../../support/arena.h"
#include "../analysis/dependency_graph.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
* \brief Common infrastructure for transforming the layouts. This is used for AlterOpLayout and * \brief Common infrastructure for transforming the layouts. This is used for AlterOpLayout and
* ConvertLayout pass. */ * ConvertLayout pass. */
#ifndef TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ #ifndef TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_
#define TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ #define TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_
#include <tvm/tir/data_layout.h> #include <tvm/tir/data_layout.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
...@@ -367,4 +367,4 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj ...@@ -367,4 +367,4 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ #endif // TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_
...@@ -43,8 +43,8 @@ ...@@ -43,8 +43,8 @@
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "./pass_util.h" #include "pass_util.h"
#include "type_solver.h" #include "../analysis/type_solver.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment