Commit 138ec7be by Zhi Committed by Tianqi Chen

[Relay][Transform] merge PassContext and BuildConfig (#3234)

parent 415a270d
......@@ -22,17 +22,9 @@ tvm.relay.build_module
.. autofunction:: tvm.relay.build_module.build
.. autofunction:: tvm.relay.build_module.build_config
.. autofunction:: tvm.relay.build_module.optimize
.. autofunction:: tvm.relay.build_module.create_executor
.. autoclass:: tvm.relay.build_module.BuildConfig
:members:
.. autofunction:: tvm.relay.build_module.build_config
:members:
.. autoclass:: tvm.relay.build_module.GraphExecutor
:members:
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.transform
----------------------
.. automodule:: tvm.relay.transform
.. autofunction:: tvm.relay.transform.build_config
.. autofunction:: tvm.relay.transform.module_pass
.. autofunction:: tvm.relay.transform.function_pass
.. autoclass:: tvm.relay.transform.Pass
:members:
.. autoclass:: tvm.relay.transform.PassInfo
:members:
.. autoclass:: tvm.relay.transform.PassContext
:members:
.. autoclass:: tvm.relay.transform.ModulePass
:members:
.. autoclass:: tvm.relay.transform.FunctionPass
:members:
.. autoclass:: tvm.relay.transform.Sequential
:members:
......@@ -56,11 +56,13 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
......@@ -83,18 +85,69 @@ class PassContextNode : public RelayNode {
*/
ErrorReporter err_reporter;
/*! \brief The default optimization level. */
int opt_level{2};
/*! \brief CPU is the default fallback device for heterogeneous execution. */
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
tvm::Array<tvm::Expr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::Expr> disabled_pass;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
}
TVM_DLL static PassContext make();
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
class PassContext : public NodeRef {
public:
PassContext() {}
explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}
/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
*/
TVM_DLL PassContext(int opt_level,
int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass);
// Get the currently used pass context.
TVM_DLL static PassContext Current();
const PassContextNode* operator->() const;
using ContainerType = PassContextNode;
class Internal;
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
};
/*
* \brief The meta data of a pass.
......@@ -150,20 +203,28 @@ class PassNode : public RelayNode {
virtual PassInfo Info() const = 0;
/*!
* \brief Set the context information for a pass.
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
*
* \param mod The module that an optimization pass runs on.
*
* \param pass_ctx The context information for a certain pass.
* \return The updated module.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;
Module operator()(const Module& mod) const {
return this->operator()(mod, PassContext::Current());
}
/*!
* \brief Execute the optimization pass using a functor.
* \brief Execute the optimization pass using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
......@@ -189,13 +250,22 @@ class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
PassInfo pass_info);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
......
......@@ -26,7 +26,8 @@ from . import module
from . import adt
from . import ir_pass
from . import transform
from .build_module import build, build_config, create_executor
from .build_module import build, create_executor
from .transform import build_config
from . import prelude
from . import parser
from . import debug
......
......@@ -28,81 +28,10 @@ from . import _build_module
from . import ir_pass
from . import ty as _ty
from . import expr as _expr
from . import transform as _transform
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
class BuildConfig(object):
"""Configuration scope to set a build config option.
Parameters
----------
kwargs
Keyword arguments of configurations to set.
"""
current = None
defaults = {
"opt_level": 2,
"add_pass": None,
"disable_pass": None,
"fallback_device": None,
}
def __init__(self, **kwargs):
self._old_scope = None
for k, _ in kwargs.items():
if k not in BuildConfig.defaults:
raise ValueError("invalid argument %s, candidates are %s" %
(k, BuildConfig.defaults.keys()))
self._attr = kwargs
def __getattr__(self, name):
if name not in self._attr:
return BuildConfig.defaults[name]
return self._attr[name]
def __enter__(self):
# pylint: disable=protected-access
self._old_scope = BuildConfig.current
attr = BuildConfig.current._attr.copy()
attr.update(self._attr)
self._attr = attr
BuildConfig.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope
BuildConfig.current = self._old_scope
BuildConfig.current = BuildConfig()
def build_config(**kwargs):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, default=2
Optimization level. See OPT_PASS_LEVEL for level of each pass.
add_pass: set of str
Optimization pass to be added regardless of optimization level.
disable_pass: set of str
Optimization pass to be disabled during optimization.
fallback_device : str or tvm.TVMContext
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
Returns
-------
config: BuildConfig
The build configuration
"""
return BuildConfig(**kwargs)
def _update_target(target):
target = target if target else _target.current_target()
if target is None:
......@@ -189,7 +118,7 @@ class BuildModule(object):
return graph_json, mod, params
def _setup_build_config(self, params):
cfg = BuildConfig.current
cfg = _transform.PassContext.current()
# Set opt_level.
self.set_opt_level(cfg.opt_level)
......@@ -199,24 +128,24 @@ class BuildModule(object):
self.set_fallback_device(cfg.fallback_device)
# Add required passes.
if cfg.add_pass:
if cfg.required_pass:
passes = set()
if isinstance(cfg.add_pass, (list, tuple, set)):
passes = set(cfg.add_pass)
if isinstance(cfg.required_pass, (list, tuple, set)):
passes = set(cfg.required_pass)
else:
raise TypeError("add_pass must be list, tuple, or set, but " +
"got {}".format(type(cfg.add_pass)))
"got {}".format(type(cfg.required_pass)))
for pass_name in passes:
self.add_pass(pass_name)
# Add disabled passes.
if cfg.disable_pass:
if cfg.disabled_pass:
passes = set()
if isinstance(cfg.disable_pass, (list, tuple, set)):
passes = set(cfg.disable_pass)
if isinstance(cfg.disabled_pass, (list, tuple, set)):
passes = set(cfg.disabled_pass)
else:
raise TypeError("disable_pass must be list, tuple, or set, " +
"but got {}".format(type(cfg.disable_pass)))
"but got {}".format(type(cfg.disabled_pass)))
for pass_name in passes:
self.disable_pass(pass_name)
......@@ -287,12 +216,11 @@ class BuildModule(object):
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
"""
if isinstance(fallback_device, str):
if isinstance(fallback_device, (int, str)):
fallback_device = _nd.context(fallback_device)
if not isinstance(fallback_device, TVMContext):
raise TypeError("fallback_device is expected to be str " +
"TVMContext, or dict of device name to target, " +
"but received: {}".format(type(fallback_device)))
raise TypeError("fallback_device is expected to be str, int, or " +
"TVMContext but received: {}".format(type(fallback_device)))
self._set_fallback_device(fallback_device.device_type)
......
......@@ -22,7 +22,7 @@ import numpy as np
from . import _quantize
from .. import expr as _expr
from .. import ir_pass as _ir_pass
from .. import build_module as _build
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
......@@ -301,7 +301,7 @@ def optimize(func, params=None):
"FoldConstant",
"CanonicalizeOps"]
cfg = _build.build_config(add_pass=opt_passes)
cfg = _transform.build_config(required_pass=opt_passes)
if params:
name_dict = {}
......@@ -321,25 +321,25 @@ def optimize(func, params=None):
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
if "SimplifyInference" in cfg.add_pass:
if "SimplifyInference" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in cfg.add_pass:
if "FoldConstant" in cfg.required_pass:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in cfg.add_pass:
if "FoldScaleAxis" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in cfg.add_pass:
if "CanonicalizeOps" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in cfg.add_pass:
if "FoldConstant" in cfg.required_pass:
func = _ir_pass.fold_constant(func)
return func
......
......@@ -23,8 +23,10 @@ conveniently.
"""
import types
from tvm._ffi.runtime_ctypes import TVMContext
from . import _transform
from .base import RelayNode, register_relay_node
from .. import nd as _nd
@register_relay_node
......@@ -57,10 +59,102 @@ class PassContext(RelayNode):
Each pass context contains a number of auxiliary information that is used
to help an optimization pass. Such information includes the error reporter
to record the errors of during the optimization, etc.
opt_level : Optional[int]
The optimization level of this pass.
fallback_device : Optional[Union[int, str, TVMContext]]
The fallback device type. It is also used as the default device for
operators that are not annotated during heterogeneous execution.
required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are disabled.
"""
def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
fallback_device = fallback_device.device_type
if not isinstance(fallback_device, int):
raise TypeError("required_pass is expected to be the type of " +
"int/str/TVMContext.")
required = list(required_pass) if required_pass else []
if not isinstance(required, (list, tuple)):
raise TypeError("required_pass is expected to be the type of " +
"list/tuple/set.")
def __init__(self):
self.__init_handle_by_constructor__(_transform.PassContext)
disabled = list(disabled_pass) if disabled_pass else []
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled_pass is expected to be the type of " +
"list/tuple/set.")
self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
fallback_device, required,
disabled)
def __enter__(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace):
_transform.ExitPassContext(self)
@staticmethod
def current():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()
def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, optional
Optimization level. The optimization pass name and level are as the
following:
.. code-block:: python
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
}
fallback_device : int, str, or tvm.TVMContext, optional
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.
Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return PassContext(opt_level, fallback_device, required_pass,
disabled_pass)
@register_relay_node
......@@ -70,20 +164,6 @@ class Pass(RelayNode):
conveniently interact with the base class.
"""
def set_pass_context(self, pass_ctx):
"""Setup the pass context for analysis and optimizations. This context
could be shared by different passes for sequential passes.
Parameters
----------
pass_ctx : PassContext
The context that is used to help perform a certain pass or a series
of passes.
"""
if not isinstance(pass_ctx, PassContext):
raise TypeError("pass_ctx is expected to be the PassContext type")
_transform.SetContext(self, pass_ctx)
@property
def info(self):
"""Get the pass meta."""
......@@ -150,32 +230,23 @@ class Sequential(Pass):
required : Optional[List[str]]
The list of passes that the sequential pass is dependent on.
disabled : Optional[List[str]]
A list of disabled passes.
"""
def __init__(self,
passes=None,
opt_level=2,
name="sequential",
required=None,
disabled=None):
required=None):
passes = passes if passes else []
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a list of Pass objects.")
disabled = disabled if disabled else []
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled must be a list or tuple of pass names")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of list/tuple.")
self.__init_handle_by_constructor__(_transform.Sequential,
passes, opt_level, name, required,
disabled)
passes, opt_level, name, required)
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
......
......@@ -22,8 +22,14 @@
* \file src/relay/pass/pass_manager.cc
* \brief Relay pass manager implementation.
*/
#include <dmlc/thread_local.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/device_api.h>
#include <algorithm>
#include <stack>
#include <unordered_set>
namespace tvm {
namespace relay {
......@@ -31,6 +37,98 @@ namespace transform {
using tvm::IRPrinter;
/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*/
class OptPassLevel {
public:
/*!
* \brief Get level for an optimization pass
*
* \param key pass name
* \return int level
*/
int operator[](const std::string& key) const {
const auto data = CreateMap();
auto it = data.find(key);
if (it == data.end()) {
return -1;
}
return it->second;
}
private:
static const std::unordered_map<std::string, int> CreateMap() {
const std::unordered_map<std::string, int> m = {
{"SimplifyInference", 0},
{"OpFusion", 1},
{"FoldConstant", 2},
{"CombineParallelConv2D", 3},
{"FoldScaleAxis", 3},
{"AlterOpLayout", 3},
{"CanonicalizeOps", 3},
{"EliminateCommonSubexpr", 3}
};
return m;
}
};
PassContext::PassContext(int opt_level, int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass) {
auto ctx = make_node<PassContextNode>();
ctx->opt_level = opt_level;
ctx->fallback_device = fallback_device;
ctx->required_pass = std::move(required_pass);
ctx->disabled_pass = std::move(disabled_pass);
node_ = std::move(ctx);
}
const PassContextNode* PassContext::operator->() const {
return static_cast<const PassContextNode*>(node_.get());
}
struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
RelayPassContextThreadLocalEntry() {
default_context = PassContext(make_node<PassContextNode>());
}
};
/*! \brief Thread local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
RelayPassContextThreadLocalStore;
void PassContext::EnterWithScope() {
RelayPassContextThreadLocalEntry* entry =
RelayPassContextThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void PassContext::ExitWithScope() {
RelayPassContextThreadLocalEntry* entry =
RelayPassContextThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
PassContext PassContext::Current() {
RelayPassContextThreadLocalEntry* entry =
RelayPassContextThreadLocalStore::Get();
if (!entry->context_stack.empty()) {
return entry->context_stack.top();
} else {
return entry->default_context;
}
}
class ModulePass;
/*!
......@@ -58,38 +156,26 @@ class ModulePassNode : public PassNode {
}
/*!
* \brief Run a module pass on a certain module.
* \brief Run a module pass on given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that an optimization pass is applied on.
* \param mod The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }
/*!
* \brief Set the context information for a module pass.
*
* \param pass_ctx The context information for a module pass.
*/
void SetContext(const PassContext& pass_ctx) final;
TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.ModulePass";
TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode);
private:
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext pass_ctx_;
};
RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass);
......@@ -124,26 +210,20 @@ class FunctionPassNode : public PassNode {
}
/*!
* \brief Run a function pass on a certain module.
* \brief Run a function pass on given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that an optimization pass is applied on.
* \param mod The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }
/*!
* \brief Set the context information for a function-level pass.
*
* \param pass_ctx The context information for a function-level pass.
*/
void SetContext(const PassContext& pass_ctx) final;
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
PassInfo pass_info);
......@@ -160,11 +240,6 @@ class FunctionPassNode : public PassNode {
* \return Return true if the function will be skipped, otherwise false.
*/
bool SkipFunction(const Function& func) const;
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext pass_ctx_;
};
RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);
......@@ -182,18 +257,17 @@ class SequentialNode : public PassNode {
/* \brief The pass meta data.*/
PassInfo pass_info;
/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
/*!
* \brief A list of disabled passes that should be excluded when executing the
* sequential pass.
* \brief A helper struct to get the optimization pass name to opt level
* mapping.
*/
tvm::Array<tvm::Expr> disabled;
OptPassLevel opt_pass_level;
/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
v->Visit("disabled", &disabled);
}
/*!
......@@ -211,6 +285,15 @@ class SequentialNode : public PassNode {
}
/*!
* \brief Check if a pass is enabled.
*
* \param pass_name The name of an optimization/analysis pass.
*
* \return true if the pass is enabled. Otherwise, false.
*/
bool pass_enabled(const std::string& pass_name) const;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them.
*
......@@ -224,7 +307,11 @@ class SequentialNode : public PassNode {
*/
void ResolveDependency(const Module& mod);
TVM_DLL std::vector<std::string> DisabledPasses() const;
std::unordered_set<std::string> DisabledPasses(
const Array<tvm::Expr>& disabled) const;
std::unordered_set<std::string> RequiredPasses(
const Array<tvm::Expr>& disabled) const;
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
......@@ -232,27 +319,15 @@ class SequentialNode : public PassNode {
* be overloaded to focus on different metrics, i.e. performance,
* memory footprint, etc.
*
* \param mod The module that an optimization pass runs on.
* \param mod The module that these passes are applied on.
* \param pass_ctx The context that these passes execute on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;
/*!
* \brief Set the context information for a sequential pass.
*
* \param pass_ctx The context information for a sequential pass.
*/
void SetContext(const PassContext& pass_ctx) final;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
private:
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext pass_ctx_;
};
PassInfo PassInfoNode::make(int opt_level, std::string name,
......@@ -264,11 +339,6 @@ PassInfo PassInfoNode::make(int opt_level, std::string name,
return PassInfo(pass_info);
}
PassContext PassContextNode::make() {
auto ctx = make_node<PassContextNode>();
return PassContext(ctx);
}
ModulePass ModulePassNode::make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
PassInfo pass_info) {
......@@ -279,23 +349,19 @@ ModulePass ModulePassNode::make(
}
// Module -> Module optimizations.
// TODO(zhiics) 1. Check and handle the required passes.
// 2. Probably use CoW for all places that use module instead of
// returning the updated one.
Module ModulePassNode::operator()(const Module& mod) const {
// TODO(zhiics) Check and handle the required passes.
Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
auto updated_mod = pass_func(mod, pass_ctx_);
auto updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
return updated_mod;
}
void ModulePassNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx;
}
FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
PassInfo pass_info) {
......@@ -307,31 +373,22 @@ FunctionPass FunctionPassNode::make(
// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
Module FunctionPassNode::operator()(const Module& mod) const {
Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
std::vector<std::pair<GlobalVar, Function>> updated_funcs;
ModuleNode* mod_node = mod.operator->();
for (const auto& it : mod_node->functions) {
if (!SkipFunction(it.second)) {
auto updated_func = pass_func(it.second, pass_ctx_);
CHECK(updated_func.defined());
updated_funcs.push_back({std::move(it.first), std::move(updated_func)});
}
}
Module new_mod = ModuleNode::make({}, mod->type_definitions);
// Update the optimized functions.
for (const auto& it : updated_funcs) {
mod_node->Update(it.first, it.second);
// Execute the pass function and return a new module.
for (const auto& it : mod->functions) {
auto updated_func =
SkipFunction(it.second) ? it.second : pass_func(it.second, pass_ctx);
new_mod->Add(it.first, updated_func);
}
return GetRef<Module>(mod_node);
}
void FunctionPassNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx;
return new_mod;
}
// TODO(zhiics) Create an enum attribute for FunctionNode
......@@ -342,31 +399,23 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
return pval && pval->value != 0;
}
Sequential::Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled) {
Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
auto n = make_node<SequentialNode>();
n->passes = std::move(passes);
n->pass_info = std::move(pass_info);
n->disabled = std::move(disabled);
node_ = std::move(n);
}
const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(this->node_.get());
Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
auto n = make_node<SequentialNode>();
n->passes = std::move(passes);
PassInfo pass_info = PassInfoNode::make(2, std::move(name), {});
n->pass_info = std::move(pass_info);
node_ = std::move(n);
}
// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future.
Module SequentialNode::operator()(const Module& module) const {
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
const auto* pn = pass.operator->();
mod = (*pn)(mod);
}
return mod;
const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(this->node_.get());
}
void SequentialNode::ResolveDependency(const Module& mod) {
......@@ -378,18 +427,68 @@ void SequentialNode::ResolveDependency(const Module& mod) {
<< "\n";
}
std::vector<std::string> SequentialNode::DisabledPasses() const {
std::vector<std::string> ret;
std::unordered_set<std::string> SequentialNode::DisabledPasses(
const Array<tvm::Expr>& disabled) const {
std::unordered_set<std::string> ret;
for (const auto& it : disabled) {
const auto* str = it.as<tvm::ir::StringImm>();
CHECK(str) << "disabled passes must be string.";
ret.push_back(str->value);
ret.emplace(str->value);
}
return ret;
}
void SequentialNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx;
std::unordered_set<std::string> SequentialNode::RequiredPasses(
const Array<tvm::Expr>& required) const {
std::unordered_set<std::string> ret;
for (const auto& it : required) {
const auto* str = it.as<tvm::ir::StringImm>();
CHECK(str) << "disabled passes must be string.";
ret.emplace(str->value);
}
return ret;
}
bool SequentialNode::pass_enabled(const std::string& pass_name) const {
PassContext ctx = PassContext::Current();
const PassContextNode* ctx_node = ctx.operator->();
auto required = RequiredPasses(ctx_node->required_pass);
auto disabled = DisabledPasses(ctx_node->required_pass);
if (disabled.count(pass_name)) {
return false;
}
if (required.count(pass_name)) {
return true;
}
return ctx_node->opt_level >= opt_pass_level[pass_name];
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future.
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
const auto* ctx_node = pass_ctx.operator->();
int opt_level = ctx_node->opt_level;
auto disabled = DisabledPasses(ctx_node->disabled_pass);
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
PassInfo info = pass->Info();
const auto& pass_name = info.operator->()->name;
const auto& pass_opt_level = info.operator->()->opt_level;
// Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled.
if (pass_opt_level > opt_level || disabled.count(pass_name)) {
continue;
}
const auto* pn = pass.operator->();
mod = (*pn)(mod, pass_ctx);
}
return mod;
}
Pass CreateModulePass(
......@@ -481,9 +580,8 @@ TVM_REGISTER_API("relay._transform.Sequential")
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::Expr> required = args[3];
tvm::Array<tvm::Expr> disabled = args[4];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
*ret = Sequential(passes, pass_info, disabled);
*ret = Sequential(passes, pass_info);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......@@ -501,26 +599,58 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "]";
});
TVM_REGISTER_API("relay._transform.SetContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
PassContext pass_ctx = args[1];
pass->SetContext(pass_ctx);
});
TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._transform.PassContext")
.set_body_typed(PassContextNode::make);
.set_body([](TVMArgs args, TVMRetValue* ret) {
int opt_level = args[0];
int fallback_device = args[1];
tvm::Array<tvm::Expr> required = args[2];
tvm::Array<tvm::Expr> disabled = args[3];
*ret = PassContext(opt_level, fallback_device, required, disabled);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const PassContextNode* node,
tvm::IRPrinter* p) {
p->stream << "TODO(zhiics): printing context";
LOG(FATAL) << "PassContext printer has not been implemented yet."
<< "\n";
tvm::IRPrinter* p) {
p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level)
<< "\n";
p->stream << "\trequired passes: [" << node->opt_level;
for (const auto& it : node->required_pass) {
p->stream << it << " ";
}
p->stream << "]\n";
p->stream << "\tdisabled passes: [" << node->opt_level;
for (const auto& it : node->disabled_pass) {
p->stream << it << " ";
}
p->stream << "]";
});
class PassContext::Internal {
public:
static void EnterScope(PassContext pass_ctx) {
pass_ctx.EnterWithScope();
}
static void ExitScope(PassContext pass_ctx) {
pass_ctx.ExitWithScope();
}
};
TVM_REGISTER_API("relay._transform.GetCurrentPassContext")
.set_body_typed(PassContext::Current);
TVM_REGISTER_API("relay._transform.EnterPassContext")
.set_body_typed(PassContext::Internal::EnterScope);
TVM_REGISTER_API("relay._transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope);
} // namespace transform
} // namespace relay
} // namespace tvm
......@@ -31,7 +31,7 @@ import model_zoo
def get_tvm_output(func, x, params, target, ctx,
out_shape=(1, 1000), input_name='image', dtype='float32'):
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
......@@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
dtype_dict = {input_name: input_data.dtype}
func, params = relay.frontend.from_coreml(coreml_model, shape_dict)
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
from tvm.contrib import graph_runtime
......
......@@ -43,7 +43,7 @@ def verify_keras_frontend(keras_model, need_transpose=True):
def get_tvm_output(xs, target, ctx, dtype='float32'):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
func, params = relay.frontend.from_keras(keras_model, shape_dict)
with relay.build_module.build_config(opt_level=2):
with relay.transform.build_config(opt_level=2):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
for name, x in zip(keras_model.input_names, xs):
......
......@@ -144,7 +144,7 @@ func, params = relay.frontend.from_tflite(tflite_model,
# target x86 CPU
target = "llvm"
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment