Commit 138ec7be by Zhi Committed by Tianqi Chen

[Relay][Transform] merge PassContext and BuildConfig (#3234)

parent 415a270d
......@@ -22,17 +22,9 @@ tvm.relay.build_module
.. autofunction:: tvm.relay.build_module.build
.. autofunction:: tvm.relay.build_module.build_config
.. autofunction:: tvm.relay.build_module.optimize
.. autofunction:: tvm.relay.build_module.create_executor
.. autoclass:: tvm.relay.build_module.BuildConfig
:members:
.. autofunction:: tvm.relay.build_module.build_config
:members:
.. autoclass:: tvm.relay.build_module.GraphExecutor
:members:
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.transform
----------------------
.. automodule:: tvm.relay.transform
.. autofunction:: tvm.relay.transform.build_config
.. autofunction:: tvm.relay.transform.module_pass
.. autofunction:: tvm.relay.transform.function_pass
.. autoclass:: tvm.relay.transform.Pass
:members:
.. autoclass:: tvm.relay.transform.PassInfo
:members:
.. autoclass:: tvm.relay.transform.PassContext
:members:
.. autoclass:: tvm.relay.transform.ModulePass
:members:
.. autoclass:: tvm.relay.transform.FunctionPass
:members:
.. autoclass:: tvm.relay.transform.Sequential
:members:
......@@ -56,11 +56,13 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
......@@ -83,18 +85,69 @@ class PassContextNode : public RelayNode {
*/
ErrorReporter err_reporter;
/*! \brief The default optimization level. */
int opt_level{2};
/*! \brief CPU is the default fallback device for heterogeneous execution. */
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
tvm::Array<tvm::Expr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::Expr> disabled_pass;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
}
TVM_DLL static PassContext make();
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
class PassContext : public NodeRef {
public:
PassContext() {}
explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}
/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
*/
TVM_DLL PassContext(int opt_level,
int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass);
// Get the currently used pass context.
TVM_DLL static PassContext Current();
const PassContextNode* operator->() const;
using ContainerType = PassContextNode;
class Internal;
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
};
/*
* \brief The meta data of a pass.
......@@ -150,20 +203,28 @@ class PassNode : public RelayNode {
virtual PassInfo Info() const = 0;
/*!
* \brief Set the context information for a pass.
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
*
* \param mod The module that an optimization pass runs on.
*
* \param pass_ctx The context information for a certain pass.
* \return The updated module.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;
Module operator()(const Module& mod) const {
return this->operator()(mod, PassContext::Current());
}
/*!
* \brief Execute the optimization pass using a functor.
* \brief Execute the optimization pass using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
......@@ -189,13 +250,22 @@ class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
PassInfo pass_info);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
......
......@@ -26,7 +26,8 @@ from . import module
from . import adt
from . import ir_pass
from . import transform
from .build_module import build, build_config, create_executor
from .build_module import build, create_executor
from .transform import build_config
from . import prelude
from . import parser
from . import debug
......
......@@ -28,81 +28,10 @@ from . import _build_module
from . import ir_pass
from . import ty as _ty
from . import expr as _expr
from . import transform as _transform
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
class BuildConfig(object):
"""Configuration scope to set a build config option.
Parameters
----------
kwargs
Keyword arguments of configurations to set.
"""
current = None
defaults = {
"opt_level": 2,
"add_pass": None,
"disable_pass": None,
"fallback_device": None,
}
def __init__(self, **kwargs):
self._old_scope = None
for k, _ in kwargs.items():
if k not in BuildConfig.defaults:
raise ValueError("invalid argument %s, candidates are %s" %
(k, BuildConfig.defaults.keys()))
self._attr = kwargs
def __getattr__(self, name):
if name not in self._attr:
return BuildConfig.defaults[name]
return self._attr[name]
def __enter__(self):
# pylint: disable=protected-access
self._old_scope = BuildConfig.current
attr = BuildConfig.current._attr.copy()
attr.update(self._attr)
self._attr = attr
BuildConfig.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope
BuildConfig.current = self._old_scope
BuildConfig.current = BuildConfig()
def build_config(**kwargs):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, default=2
Optimization level. See OPT_PASS_LEVEL for level of each pass.
add_pass: set of str
Optimization pass to be added regardless of optimization level.
disable_pass: set of str
Optimization pass to be disabled during optimization.
fallback_device : str or tvm.TVMContext
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
Returns
-------
config: BuildConfig
The build configuration
"""
return BuildConfig(**kwargs)
def _update_target(target):
target = target if target else _target.current_target()
if target is None:
......@@ -189,7 +118,7 @@ class BuildModule(object):
return graph_json, mod, params
def _setup_build_config(self, params):
cfg = BuildConfig.current
cfg = _transform.PassContext.current()
# Set opt_level.
self.set_opt_level(cfg.opt_level)
......@@ -199,24 +128,24 @@ class BuildModule(object):
self.set_fallback_device(cfg.fallback_device)
# Add required passes.
if cfg.add_pass:
if cfg.required_pass:
passes = set()
if isinstance(cfg.add_pass, (list, tuple, set)):
passes = set(cfg.add_pass)
if isinstance(cfg.required_pass, (list, tuple, set)):
passes = set(cfg.required_pass)
else:
raise TypeError("add_pass must be list, tuple, or set, but " +
"got {}".format(type(cfg.add_pass)))
"got {}".format(type(cfg.required_pass)))
for pass_name in passes:
self.add_pass(pass_name)
# Add disabled passes.
if cfg.disable_pass:
if cfg.disabled_pass:
passes = set()
if isinstance(cfg.disable_pass, (list, tuple, set)):
passes = set(cfg.disable_pass)
if isinstance(cfg.disabled_pass, (list, tuple, set)):
passes = set(cfg.disabled_pass)
else:
raise TypeError("disable_pass must be list, tuple, or set, " +
"but got {}".format(type(cfg.disable_pass)))
"but got {}".format(type(cfg.disabled_pass)))
for pass_name in passes:
self.disable_pass(pass_name)
......@@ -287,12 +216,11 @@ class BuildModule(object):
fallback_device : str or tvm.TVMContext
The fallback device used for heterogeneous execution.
"""
if isinstance(fallback_device, str):
if isinstance(fallback_device, (int, str)):
fallback_device = _nd.context(fallback_device)
if not isinstance(fallback_device, TVMContext):
raise TypeError("fallback_device is expected to be str " +
"TVMContext, or dict of device name to target, " +
"but received: {}".format(type(fallback_device)))
raise TypeError("fallback_device is expected to be str, int, or " +
"TVMContext but received: {}".format(type(fallback_device)))
self._set_fallback_device(fallback_device.device_type)
......
......@@ -22,7 +22,7 @@ import numpy as np
from . import _quantize
from .. import expr as _expr
from .. import ir_pass as _ir_pass
from .. import build_module as _build
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
......@@ -301,7 +301,7 @@ def optimize(func, params=None):
"FoldConstant",
"CanonicalizeOps"]
cfg = _build.build_config(add_pass=opt_passes)
cfg = _transform.build_config(required_pass=opt_passes)
if params:
name_dict = {}
......@@ -321,25 +321,25 @@ def optimize(func, params=None):
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
if "SimplifyInference" in cfg.add_pass:
if "SimplifyInference" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in cfg.add_pass:
if "FoldConstant" in cfg.required_pass:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in cfg.add_pass:
if "FoldScaleAxis" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in cfg.add_pass:
if "CanonicalizeOps" in cfg.required_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in cfg.add_pass:
if "FoldConstant" in cfg.required_pass:
func = _ir_pass.fold_constant(func)
return func
......
......@@ -23,8 +23,10 @@ conveniently.
"""
import types
from tvm._ffi.runtime_ctypes import TVMContext
from . import _transform
from .base import RelayNode, register_relay_node
from .. import nd as _nd
@register_relay_node
......@@ -57,10 +59,102 @@ class PassContext(RelayNode):
Each pass context contains a number of auxiliary information that is used
to help an optimization pass. Such information includes the error reporter
to record the errors of during the optimization, etc.
opt_level : Optional[int]
The optimization level of this pass.
fallback_device : Optional[Union[int, str, TVMContext]]
The fallback device type. It is also used as the default device for
operators that are not annotated during heterogeneous execution.
required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are disabled.
"""
def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
fallback_device = fallback_device.device_type
if not isinstance(fallback_device, int):
raise TypeError("required_pass is expected to be the type of " +
"int/str/TVMContext.")
required = list(required_pass) if required_pass else []
if not isinstance(required, (list, tuple)):
raise TypeError("required_pass is expected to be the type of " +
"list/tuple/set.")
def __init__(self):
self.__init_handle_by_constructor__(_transform.PassContext)
disabled = list(disabled_pass) if disabled_pass else []
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled_pass is expected to be the type of " +
"list/tuple/set.")
self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
fallback_device, required,
disabled)
def __enter__(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace):
_transform.ExitPassContext(self)
@staticmethod
def current():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()
def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
"""Configure the build behavior by setting config variables.
Parameters
----------
opt_level: int, optional
Optimization level. The optimization pass name and level are as the
following:
.. code-block:: python
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
}
fallback_device : int, str, or tvm.TVMContext, optional
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.
required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.
Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return PassContext(opt_level, fallback_device, required_pass,
disabled_pass)
@register_relay_node
......@@ -70,20 +164,6 @@ class Pass(RelayNode):
conveniently interact with the base class.
"""
def set_pass_context(self, pass_ctx):
"""Setup the pass context for analysis and optimizations. This context
could be shared by different passes for sequential passes.
Parameters
----------
pass_ctx : PassContext
The context that is used to help perform a certain pass or a series
of passes.
"""
if not isinstance(pass_ctx, PassContext):
raise TypeError("pass_ctx is expected to be the PassContext type")
_transform.SetContext(self, pass_ctx)
@property
def info(self):
"""Get the pass meta."""
......@@ -150,32 +230,23 @@ class Sequential(Pass):
required : Optional[List[str]]
The list of passes that the sequential pass is dependent on.
disabled : Optional[List[str]]
A list of disabled passes.
"""
def __init__(self,
passes=None,
opt_level=2,
name="sequential",
required=None,
disabled=None):
required=None):
passes = passes if passes else []
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a list of Pass objects.")
disabled = disabled if disabled else []
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled must be a list or tuple of pass names")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of list/tuple.")
self.__init_handle_by_constructor__(_transform.Sequential,
passes, opt_level, name, required,
disabled)
passes, opt_level, name, required)
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
......
......@@ -31,7 +31,7 @@ import model_zoo
def get_tvm_output(func, x, params, target, ctx,
out_shape=(1, 1000), input_name='image', dtype='float32'):
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
......@@ -72,7 +72,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
dtype_dict = {input_name: input_data.dtype}
func, params = relay.frontend.from_coreml(coreml_model, shape_dict)
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
from tvm.contrib import graph_runtime
......
......@@ -43,7 +43,7 @@ def verify_keras_frontend(keras_model, need_transpose=True):
def get_tvm_output(xs, target, ctx, dtype='float32'):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
func, params = relay.frontend.from_keras(keras_model, shape_dict)
with relay.build_module.build_config(opt_level=2):
with relay.transform.build_config(opt_level=2):
graph, lib, params = relay.build(func, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
for name, x in zip(keras_model.input_names, xs):
......
......@@ -144,7 +144,7 @@ func, params = relay.frontend.from_tflite(tflite_model,
# target x86 CPU
target = "llvm"
with relay.build_module.build_config(opt_level=3):
with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params)
######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment