Commit b528acc1 by Tianqi Chen, committed by GitHub

[LINT][PY] Fixes for pylint==2.4.4 (#4849)

parent b46c2548
@@ -94,10 +94,7 @@ javadoc:
 # Cython build
 cython:
-	cd python; python setup.py build_ext --inplace
-cython2:
-	cd python; python2 setup.py build_ext --inplace
+	cd python; python3 setup.py build_ext --inplace
 cython3:
 	cd python; python3 setup.py build_ext --inplace
...
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # coding: utf-8
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, import-outside-toplevel
 """Base library for TVM FFI."""
 import sys
 import os
@@ -204,14 +204,14 @@ def _find_error_type(line):
         if _valid_error_name(err_name):
             return err_name
         return None
-    else:
-        end_pos = line.find(":")
-        if end_pos == -1:
-            return None
-        err_name = line[:end_pos]
-        if _valid_error_name(err_name):
-            return err_name
-        return None
+    end_pos = line.find(":")
+    if end_pos == -1:
+        return None
+    err_name = line[:end_pos]
+    if _valid_error_name(err_name):
+        return err_name
+    return None

 def c2pyerror(err_msg):
...
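Most of the control-flow rewrites in this commit satisfy pylint's no-else-return check (R1705), enforced by the 2.4 series: when the `if` branch ends in `return`, the `else`/`elif` keywords only add an indentation level. A minimal before/after sketch with illustrative names (not code from the commit):

    # Before: R1705 (no-else-return) flags the else branch.
    def find_colon(line):
        if not line:
            return None
        else:
            return line.find(":")

    # After: the early return makes the else redundant.
    def find_colon(line):
        if not line:
            return None
        return line.find(":")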
@@ -104,6 +104,7 @@ class RedisDatabase(Database):
     MAGIC_SPLIT = "$"

     def __init__(self, db_index=REDIS_PROD):
+        # pylint: disable=import-outside-toplevel
         import redis

         if db_index == RedisDatabase.REDIS_TEST:
...
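The single most frequent change is a local `# pylint: disable=import-outside-toplevel` comment: pylint 2.4 added the C0415 (import-outside-toplevel) check, but TVM deliberately defers heavy or optional dependencies (redis, topi, torch, xgboost, vta, ...) to the call site so that importing tvm itself stays cheap, so the check is silenced where the lazy import is intentional. A sketch of the convention, with a hypothetical optional dependency named somepkg:

    def connect(db_index=0):
        """Connect lazily; importing this module never requires somepkg."""
        # pylint: disable=import-outside-toplevel
        import somepkg  # hypothetical optional dependency
        return somepkg.connect(db=db_index)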
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name,
 """Extract feature of iter vars

 There are two types of feature
@@ -148,6 +148,7 @@ def get_flatten_name(fea):
     }

     if isinstance(fea, str):
+        # pylint: disable=import-outside-toplevel
         from .record import decode
         # flatten line to feature
         line = fea
...
@@ -539,4 +539,3 @@ class BaseGraphTuner(object):
     @abstractmethod
     def run(self, **kwargs):
         """Run graph tuning."""
-        pass
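A docstring is itself a valid function body, so a trailing `pass` trips pylint's W0107 (unnecessary-pass) check; the commit simply deletes these statements. Sketch (shown as a free function for brevity):

    # Before: W0107 (unnecessary-pass); the docstring already forms the body.
    def run(**kwargs):
        """Run graph tuning."""
        pass

    # After: delete the pass statement.
    def run(**kwargs):
        """Run graph tuning."""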
@@ -65,6 +65,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
                              % op_name)
         topi_funcs += OP2COMPUTE[op_name]
     env.reset(topi_funcs)
+    # pylint: disable=not-context-manager
     with env:
         _expr2graph_impl(expr, target_ops, node_dict, node_list)
         task_pos = 0
...
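not-context-manager (E1129) is a static-inference check: pylint reports a `with` block whenever it cannot prove the subject defines `__enter__`/`__exit__`. `TaskExtractEnv` does implement the protocol, so the warning is treated as a false positive and silenced at the use site instead of restructuring the code. A minimal sketch of the shape involved (illustrative class, not TVM's):

    class Env:
        """A context manager that static analysis may fail to recognize."""
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            return False  # do not swallow exceptions

    env = Env()
    # pylint: disable=not-context-manager  (needed only when inference fails)
    with env:
        print("inside the environment")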
@@ -208,6 +208,7 @@ def measure_option(builder, runner):
         Using `min_repeat_ms` can dynamically adjusts `number`, so it is recommended.
         The typical value for NVIDIA GPU is 150 ms.
     """
+    # pylint: disable=import-outside-toplevel
     from .measure_methods import LocalBuilder, LocalRunner

     if isinstance(builder, str):
...
@@ -324,11 +324,11 @@ class LocalRunner(RPCRunner):
         self.server = None

     def set_task(self, task):
-        self.task = task
+        # pylint: disable=import-outside-toplevel
         from ...rpc.tracker import Tracker
         from ...rpc.server import Server
+        self.task = task

         tracker = Tracker('0.0.0.0', port=9000, port_end=10000, silent=True)
         device_key = '$local$device$%d' % tracker.port
         server = Server('0.0.0.0', port=9000, port_end=10000,
@@ -362,6 +362,7 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
         # if target is vta, we need to use vta build
         if hasattr(measure_input.target, 'device_name') and \
             measure_input.target.device_name == 'vta':
+            # pylint: disable=import-outside-toplevel
             import vta
             func = vta.build(s, args, target_host=task.target_host)
         else:
@@ -460,6 +461,7 @@ def run_through_rpc(measure_input, build_result,
     # Program the FPGA every single time when targeting VTA
     if hasattr(measure_input.target, 'device_name') and \
         measure_input.target.device_name == 'vta':
+        # pylint: disable=import-outside-toplevel
         from vta import program_fpga, reconfig_runtime
         program_fpga(remote, None)
         reconfig_runtime(remote)
...
@@ -282,6 +282,7 @@ class ApplyHistoryBest(DispatchContext):
             Each row of this file is an encoded record pair.
             Otherwise, it is an iterator.
         """
+        # pylint: disable=import-outside-toplevel
         from pathlib import Path
         from ..record import load_from_file
@@ -454,6 +455,7 @@ class ApplyGraphBest(DispatchContext):
             Each row of this file is an encoded record pair.
             Otherwise, it is an iterator.
         """
+        # pylint: disable=import-outside-toplevel
         from ..record import load_from_file

         super(ApplyGraphBest, self).__init__()
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-variable,invalid-name
+# pylint: disable=unused-variable,invalid-name, not-context-manager
 """
 Decorator and utilities for the integration with TOPI and Relay
 99.9% copy-paste of implementation by @MerryMercy
@@ -37,7 +37,7 @@ def _lower(mod,
            params):
     """ Helper to lower VTA properly.
     """
+    # pylint: disable=import-outside-toplevel
     from tvm import relay
     from tvm.relay.backend import graph_runtime_codegen
@@ -114,6 +114,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
     task: Array of autotvm.task.Task
         collected tasks
     """
+    # pylint: disable=import-outside-toplevel
     import tvm.relay.op
     from tvm import relay
     import topi
...
@@ -76,6 +76,7 @@ class TaskExtractEnv:
     registered = None

     def __init__(self, allow_duplicate=False):
+        # pylint: disable=import-outside-toplevel
         import topi

         # topi compute -> autotvm task name
@@ -168,6 +169,7 @@ class TaskExtractEnv:
     def _register_topi_task(self):
         """register tuning wrapper for topi function"""
+        # pylint: disable=import-outside-toplevel
         import topi

         # Avoid double registration for certain targets
...
@@ -147,6 +147,7 @@ def check_backend(tophub_location, backend):
     if os.path.isfile(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)):
         return True

+    # pylint: disable=import-outside-toplevel
     if sys.version_info >= (3,):
         import urllib.request as urllib2
     else:
...
@@ -53,6 +53,7 @@ def log_to_file(file_out, protocol='json'):
         for inp, result in zip(inputs, results):
             file_out.write(record.encode(inp, result, protocol) + "\n")

+    # pylint: disable=import-outside-toplevel
     from pathlib import Path
     if isinstance(file_out, Path):
         file_out = str(file_out)
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=consider-using-enumerate, invalid-name
+# pylint: disable=consider-using-enumerate, invalid-name, invalid-sequence-index
 """
 Cost model optimizer based on simulated annealing
 """
...
@@ -420,6 +420,7 @@ def _extract_curve_feature_log(arg):
 def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
                     maximize=False, verbose_eval=True):
     """callback function for xgboost to support multiple custom evaluation functions"""
+    # pylint: disable=import-outside-toplevel
     from xgboost.core import EarlyStopException
     from xgboost.callback import _fmt_metric
     from xgboost.training import aggcv
...
@@ -467,7 +467,7 @@ def _build_for_device(flist, target, target_host):
             func = ir_pass.InferFragment(func)
             warp_size = target.thread_warp_size
             func = ir_pass.LowerThreadAllreduce(func, warp_size)
-            fsplits = [s for s in ir_pass.SplitHostDevice(func)]
+            fsplits = list(ir_pass.SplitHostDevice(func))
             fhost.append(fsplits[0])
             for x in fsplits[1:]:
                 fdevice.append(x)
...
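`[s for s in xs]` is an identity comprehension; pylint 2.4's R1721 (unnecessary-comprehension) check suggests `list(xs)`, which states the intent directly and skips the per-element loop. A small illustrative sketch:

    items = range(5)

    copy_a = [x for x in items]   # Before: flagged by R1721
    copy_b = list(items)          # After: an explicit shallow copy

    assert copy_a == copy_b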
@@ -76,8 +76,7 @@ def get_target_by_dump_machine(compiler):
                 msg += py_str(out)
                 return None
             return py_str(out)
-        else:
-            return None
+        return None

     return get_target_triple
...
@@ -54,6 +54,7 @@ def to_pytorch_func(tvm_func):
     wrapped_func: Function
         Wrapped tvm function that operates on PyTorch tensors
     """
+    # pylint: disable=import-outside-toplevel
     import torch
     import torch.utils.dlpack

     return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)
@@ -15,9 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Helper utility for downloading"""
-from __future__ import print_function
-from __future__ import absolute_import as _abs
-
 import os
 import sys
 import time
@@ -48,10 +45,8 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=
     retries: int, optional
         Number of time to retry download, default at 3.
     """
-    if sys.version_info >= (3,):
-        import urllib.request as urllib2
-    else:
-        import urllib2
+    # pylint: disable=import-outside-toplevel
+    import urllib.request as urllib2

     if os.path.isfile(path) and not overwrite:
         if size_compare:
@@ -114,9 +109,8 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=
             if os.path.exists(tempfile):
                 os.remove(tempfile)
             raise err
-        else:
-            print("download failed due to {}, retrying, {} attempt{} left"
-                  .format(repr(err), retries, 's' if retries > 1 else ''))
+        print("download failed due to {}, retrying, {} attempt{} left"
+              .format(repr(err), retries, 's' if retries > 1 else ''))

 if "TEST_DATA_ROOT_PATH" in os.environ:
...
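Dropping the Python 2 fallback lets the version check around urllib collapse to a single deferred import. A sketch of the resulting shape, using only the Python 3 standard library (fetch is an illustrative name, not the commit's helper):

    def fetch(url, path):
        """Download url to path, Python 3 only."""
        # pylint: disable=import-outside-toplevel
        import urllib.request
        with urllib.request.urlopen(url) as resp, open(path, "wb") as out:
            out.write(resp.read())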
@@ -49,7 +49,7 @@ def to_mxnet_func(func, const_loc=None):
         Run asynchrously in MXNet's async engine.
     """
     # only import mxnet when wrap get called.
-    # pylint: disable=import-self
+    # pylint: disable=import-self, import-outside-toplevel
     import mxnet
     if isinstance(func, Module):
         func = func.entry_func
...
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Common system utilities"""
-from __future__ import absolute_import as _abs
 import os
 import tempfile
 import shutil
@@ -167,35 +166,3 @@ def which(exec_name):
         if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
             return full_path
     return None
-
-
-def get_lower_ir(s):
-    """Get lower ir code of a schedule.
-    This is useful for debug, since you don't have to find all inputs/outputs
-    for a schedule in a fused subgraph.
-
-    Parameters
-    ----------
-    s: Schedule
-
-    Returns
-    -------
-    ir: str
-        The lower ir
-    """
-    from .. import tensor
-    from ..build_module import lower
-
-    outputs = s.outputs
-
-    inputs = []
-    def find_all(op):
-        if isinstance(op, tensor.PlaceholderOp):
-            inputs.append(op.output(0))
-        else:
-            for x in op.input_tensors:
-                find_all(x.op)
-
-    for out in outputs:
-        find_all(out)
-
-    return lower(s, inputs, simple_mode=True)
@@ -50,7 +50,8 @@ def script(pyfunc):
     hybrid_func : function
         A decorated hybrid script function.
     """
-    def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
+    # pylint: disable=import-outside-toplevel, missing-docstring
+    def wrapped_func(func, *args, **kwargs):
         from .util import _is_tvm_arg_types
         if _is_tvm_arg_types(args):
             src = _pruned_source(func)
...
@@ -69,6 +69,7 @@ def bind(func_id, args):

 def _math_intrin(func_id, args):
+    # pylint: disable=import-outside-toplevel
     from .. import intrin
     return getattr(intrin, func_id)(*args)
...
@@ -198,7 +198,7 @@ class HybridParser(ast.NodeVisitor):
             ty, entry = self.symbols[key] #pylint: disable=invalid-name
             if ty in [Symbol.Input, Symbol.OutputBuffer]:
                 continue
-            elif 'Buffer' in ty.name:
+            if 'Buffer' in ty.name:
                 _buf = entry
                 _scope = 'global' if ty is Symbol.BufferVar else ty.name[:-6].lower()
                 to_pop.append(key)
...
@@ -70,6 +70,7 @@ def _pruned_source(func):

 def replace_io(body, rmap):
     """Replacing tensors usage according to the dict given"""
+    # pylint: disable=import-outside-toplevel
     from .. import ir_pass

     def replace(op):
...
@@ -78,7 +78,7 @@ class ParseError(Exception):

 class OpWrapper:
     """Overload the __call__ for op."""
-    pass

 class ExprOp(OpWrapper):
     """Call an expr. The default, but does not handle attrs well."""
@@ -273,7 +273,7 @@ class ParseTreeToRelayIR(RelayVisitor):
     def _type_expr_name(self, e):
         if isinstance(e, adt.Constructor):
             return "`{0}` ADT constructor".format(e.belong_to.name_hint)
-        elif isinstance(e, ty.GlobalTypeVar):
+        if isinstance(e, ty.GlobalTypeVar):
             if e.kind == ty.Kind.AdtHandle:
                 return "ADT definition"
         return "function definition"
@@ -623,7 +623,7 @@ class ParseTreeToRelayIR(RelayVisitor):
     def call(self, func, args, attrs, type_args):
         if isinstance(func, OpWrapper):
             return func(args, attrs, type_args)
-        elif isinstance(func, adt.Constructor):
+        if isinstance(func, adt.Constructor):
             return func(*args)
         return expr.Call(func, args, attrs, type_args)
...
@@ -384,7 +384,7 @@ def detect_feature(a, b=None):
     """
     if isinstance(a, Module):
         a, b = b, a
-    return set([Feature(int(x)) for x in _analysis.detect_feature(a, b)])
+    return {Feature(int(x)) for x in _analysis.detect_feature(a, b)}

 def structural_hash(value):
...
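`set([... for ...])` builds a throwaway list before the set; pylint 2.4's R1718 (consider-using-set-comprehension) prefers the direct set comprehension, which avoids the intermediate allocation. Illustrative sketch:

    values = ["1", "2", "2", "3"]

    s_old = set([int(v) for v in values])  # Before: flagged by R1718
    s_new = {int(v) for v in values}       # After: no intermediate list

    assert s_old == s_new == {1, 2, 3}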
@@ -44,8 +44,9 @@ def lower(sch, inputs, func_name, source_func):
     lowered_funcs : List[tvm.LoweredFunc]
         The result of lowering.
     """
+    # pylint: disable=broad-except, import-outside-toplevel
     import traceback
-    # pylint: disable=broad-except
     try:
         f = _build.lower(sch, inputs, name=func_name)
         # logging.debug("lower function %s", func_name)
...
@@ -86,7 +86,7 @@ class CompileEngine(Object):
         cached_func: CachedFunc
             The result of lowering.
         """
-        # pylint: disable=broad-except
+        # pylint: disable=broad-except, import-outside-toplevel
         try:
             key = _get_cache_key(source_func, target)
             return _backend._CompileEngineLower(self, key)
...
...@@ -407,7 +407,6 @@ def create_executor(kind="debug", ...@@ -407,7 +407,6 @@ def create_executor(kind="debug",
return _interpreter.Interpreter(mod, ctx, target) return _interpreter.Interpreter(mod, ctx, target)
if kind == "graph": if kind == "graph":
return GraphExecutor(mod, ctx, target) return GraphExecutor(mod, ctx, target)
elif kind == "vm": if kind == "vm":
return VMExecutor(mod, ctx, target) return VMExecutor(mod, ctx, target)
else: raise RuntimeError("unknown execution strategy: {0}".format(kind))
raise RuntimeError("unknown execution strategy: {0}".format(kind))
@@ -20,7 +20,7 @@ from __future__ import absolute_import
 from ..api import register_func

-# pylint: disable=unused-argument
+# pylint: disable=unused-argument, import-outside-toplevel
 def _debugger_init(expr, stack):
     import pdb
     pdb.set_trace()
...
@@ -125,8 +125,7 @@ class Expr(RelayNode):
     def __rsub__(self, other):
         if isinstance(other, _Number):
             raise TypeError('convert "%s" with `const` first' % str(other))
-        else:
-            raise TypeError("type %s not supported" % str(type(other)))
+        raise TypeError("type %s not supported" % str(type(other)))

     def __mul__(self, other):
         if isinstance(other, Expr):
@@ -150,8 +149,7 @@ class Expr(RelayNode):
     def __rdiv__(self, other):
         if isinstance(other, _Number):
             raise TypeError('convert "%s" with `const` first' % str(other))
-        else:
-            raise TypeError("type %s not supported" % str(type(other)))
+        raise TypeError("type %s not supported" % str(type(other)))

     def __truediv__(self, other):
         return self.__div__(other)
...
@@ -401,6 +401,7 @@ class Caffe2NetDef(object):
         params : dict
             A dict of name: tvm.nd.array pairs, used as pretrained weights
         """
+        # pylint: disable=import-outside-toplevel
         from caffe2.python import workspace

         workspace.RunNetOnce(init_net)
...
@@ -302,7 +302,7 @@ class ExprTable(object):
         self.exprs[name] = expr

     def has_expr(self, name):
-        return True if name in self.exprs else False
+        return name in self.exprs

     def set_padding(self, paddings):
         self.paddings = paddings
@@ -391,7 +391,7 @@ class AttrCvt(object):
             if k in self._excludes:
                 raise NotImplementedError('Attribute %s in operator %s is not' +
                                           ' supported.', k, op_name)
-            elif k in self._disables:
+            if k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
                 if k != 'tvm_custom':
@@ -485,6 +485,7 @@ def infer_value(input_val, params):
     portion of the relay graph. This is often needed for functions that
     whose output shape depends on the value of a tensor.
     """
+    # pylint: disable=import-outside-toplevel
     from tvm.contrib import graph_runtime

     # Check that all free variables have associated parameters.
     assert all(var.name_hint in params.keys() for var in analysis.free_vars(
...
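The `has_expr` rewrite is pylint's R1719 (simplifiable-if-expression): `True if cond else False` is just `bool(cond)`, and a membership test already yields a bool. Illustrative sketch:

    exprs = {"x": 1}

    def has_expr_old(name):
        return True if name in exprs else False  # Before: R1719

    def has_expr_new(name):
        return name in exprs                     # After

    assert has_expr_old("x") == has_expr_new("x")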
@@ -14,7 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, import-self, unused-argument, unused-variable, inconsistent-return-statements
+# pylint: disable=invalid-name, import-self, unused-argument, unused-variable
+# pylint: disable=inconsistent-return-statements, import-outside-toplevel
 """CoreML frontend."""
 from __future__ import absolute_import as _abs
 import math
@@ -111,14 +112,13 @@ def _BatchnormLayerParams(op, inexpr, etab):
     if op.instanceNormalization:
         raise tvm.error.OpNotImplemented(
             'Operator "instance normalization" is not supported in frontend CoreML.')
-    else:
-        params = {'gamma':etab.new_const(list(op.gamma.floatValue)),
-                  'beta':etab.new_const(list(op.beta.floatValue)),
-                  'moving_mean':etab.new_const(list(op.mean.floatValue)),
-                  'moving_var': etab.new_const(list(op.variance.floatValue)),
-                  'epsilon': op.epsilon}
-        result, moving_mean, moving_var = _op.nn.batch_norm(data=inexpr, **params)
-        return result
+    params = {'gamma':etab.new_const(list(op.gamma.floatValue)),
+              'beta':etab.new_const(list(op.beta.floatValue)),
+              'moving_mean':etab.new_const(list(op.mean.floatValue)),
+              'moving_var': etab.new_const(list(op.variance.floatValue)),
+              'epsilon': op.epsilon}
+    result, moving_mean, moving_var = _op.nn.batch_norm(data=inexpr, **params)
+    return result

 def _ActivationParams(op, inexpr, etab):
@@ -197,37 +197,36 @@ def _PoolingLayerParams(op, inexpr, etab):
         raise tvm.error.OpNotImplemented(
             'Only Max and Average Pooling are supported in frontend CoreML.')

-    else:
-        params = {'pool_size':list(op.kernelSize),
-                  'strides':list(op.stride)}
-
-        if op.WhichOneof('PoolingPaddingType') == 'valid':
-            valid = op.valid
-            if valid.paddingAmounts.borderAmounts:
-                assert len(valid.paddingAmounts.borderAmounts) == 2
-                pad_t = valid.paddingAmounts.borderAmounts[0].startEdgeSize
-                pad_l = valid.paddingAmounts.borderAmounts[1].startEdgeSize
-                pad_b = valid.paddingAmounts.borderAmounts[0].endEdgeSize
-                pad_r = valid.paddingAmounts.borderAmounts[1].endEdgeSize
-                if not all(v == 0 for v in (pad_t, pad_l, pad_b, pad_r)):
-                    params['padding'] = [pad_t, pad_l, pad_b, pad_r]
-        elif op.WhichOneof('PoolingPaddingType') == 'includeLastPixel':
-            # I don't know if this is correct
-            valid = op.includeLastPixel
-            padding = list(valid.paddingAmounts)
-            params['padding'] = padding
-            params['ceil_mode'] = True
-        else:
-            msg = 'PoolingPaddingType {} is not supported in operator Pooling.'
-            op_name = op.WhichOneof('PoolingPaddingType')
-            raise tvm.error.OpAttributeUnImplemented(msg.format(op_name))
-
-        if op.type == 0:
-            return _op.nn.max_pool2d(inexpr, **params)
-        if op.type == 1:
-            return _op.nn.avg_pool2d(inexpr, **params)
-        raise tvm.error.OpNotImplemented(
-            'Only Max and Average Pooling are supported in CoreML.')
+    params = {'pool_size':list(op.kernelSize),
+              'strides':list(op.stride)}
+
+    if op.WhichOneof('PoolingPaddingType') == 'valid':
+        valid = op.valid
+        if valid.paddingAmounts.borderAmounts:
+            assert len(valid.paddingAmounts.borderAmounts) == 2
+            pad_t = valid.paddingAmounts.borderAmounts[0].startEdgeSize
+            pad_l = valid.paddingAmounts.borderAmounts[1].startEdgeSize
+            pad_b = valid.paddingAmounts.borderAmounts[0].endEdgeSize
+            pad_r = valid.paddingAmounts.borderAmounts[1].endEdgeSize
+            if not all(v == 0 for v in (pad_t, pad_l, pad_b, pad_r)):
+                params['padding'] = [pad_t, pad_l, pad_b, pad_r]
+    elif op.WhichOneof('PoolingPaddingType') == 'includeLastPixel':
+        # I don't know if this is correct
+        valid = op.includeLastPixel
+        padding = list(valid.paddingAmounts)
+        params['padding'] = padding
+        params['ceil_mode'] = True
+    else:
+        msg = 'PoolingPaddingType {} is not supported in operator Pooling.'
+        op_name = op.WhichOneof('PoolingPaddingType')
+        raise tvm.error.OpAttributeUnImplemented(msg.format(op_name))

+    if op.type == 0:
+        return _op.nn.max_pool2d(inexpr, **params)
+    if op.type == 1:
+        return _op.nn.avg_pool2d(inexpr, **params)
+    raise tvm.error.OpNotImplemented(
+        'Only Max and Average Pooling are supported in CoreML.')

 def _SoftmaxLayerParams(op, inexpr, etab):
@@ -297,10 +296,8 @@ def _PaddingLayerParams(op, inexpr, etab):
                                     (0, 0),
                                     (pad_t, pad_b),
                                     (pad_l, pad_r)))
-
-    else:
-        raise tvm.error.OpNotImplemented(
-            'Non-constant padding is not supported in frontend CoreML.')
+    raise tvm.error.OpNotImplemented(
+        'Non-constant padding is not supported in frontend CoreML.')

 def _PermuteLayerParams(op, inexpr, etab):
...
@@ -14,9 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, import-self
+# pylint: disable=invalid-name, import-self, import-outside-toplevel
 """Keras frontend."""
-from __future__ import absolute_import as _abs
 import sys
 import numpy as np
 import tvm
@@ -133,7 +132,7 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
         # f(x) = max_value, for x >= max_value
         # f(x) = x, for threshold <= x < max_value
         return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
-    elif keras_layer.max_value and _op.greater(threshold, inexpr).astype('float32'):
+    if keras_layer.max_value and _op.greater(threshold, inexpr).astype('float32'):
         # f(x) = negative_slope * (inexpr - threshold)
         negative_slope = _expr.const(keras_layer.negative_slope, dtype='float32')
         return _op.multiply(negative_slope, _op.subtract(inexpr, threshold))
...
@@ -16,15 +16,13 @@
 # under the License.
 # pylint: disable=invalid-name, import-self, len-as-condition
 """Utility functions common to NNVM and MxNet conversion."""
-from __future__ import absolute_import as _abs
+import warnings

 from .. import expr as _expr
 from .. import op as _op
 from .common import get_relay_op
 from .common import infer_type as _infer_type

 def _warn_not_used(attr, op='nnvm'):
-    import warnings
     err = "{} is ignored in {}.".format(attr, op)
     warnings.warn(err)
...
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
+# pylint: disable=import-outside-toplevel
 """ONNX: Open Neural Network Exchange frontend for Relay."""
 from __future__ import absolute_import as _abs
...
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
+# pylint: disable=import-outside-toplevel
 """TF: Tensorflow frontend."""
 from __future__ import absolute_import as _abs
 from __future__ import print_function
...
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """TF: Tensorflow parser"""
-from __future__ import absolute_import as _abs
-from __future__ import print_function
+# pylint: disable=import-outside-toplevel, assignment-from-no-return

 import os
 from tvm.contrib import util
...
@@ -14,7 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument, too-many-lines
+# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
 """Tensorflow lite frontend."""
 from __future__ import absolute_import as _abs
 import math
@@ -1458,8 +1459,7 @@ class OperatorConverter(object):
                 raise tvm.error.OpNotImplemented(
                     'Operator {} with fused activation is not supported yet.'
                     .format('qnn.op.pool2d'))
-            else:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
         return out

     def convert_pad(self, op):
...
@@ -46,6 +46,7 @@ from ..base import register_relay_node

 def _register_op_make():
+    # pylint: disable=import-outside-toplevel
     from . import _make
     from .. import expr
     expr._op_make = _make
...
@@ -200,13 +200,12 @@ def take_shape_func(attrs, inputs, out_ndims):
     """
     if attrs.axis is None:
         return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
-    else:
-        axis = get_const_int(attrs.axis)
-        data_ndim = int(inputs[0].shape[0])
-        if axis < 0:
-            axis += data_ndim
-        assert 0 <= axis < data_ndim
-        return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
+    axis = get_const_int(attrs.axis)
+    data_ndim = int(inputs[0].shape[0])
+    if axis < 0:
+        axis += data_ndim
+    assert 0 <= axis < data_ndim
+    return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]

 @script
 def _argwhere_shape_func_1d(condition):
@@ -275,13 +274,13 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
     """
     if len(inputs[0].shape) == 1:
         return [_argwhere_shape_func_1d(inputs[0])]
-    elif len(inputs[0].shape) == 2:
+    if len(inputs[0].shape) == 2:
         return [_argwhere_shape_func_2d(inputs[0])]
-    elif len(inputs[0].shape) == 3:
+    if len(inputs[0].shape) == 3:
         return [_argwhere_shape_func_3d(inputs[0])]
-    elif len(inputs[0].shape) == 4:
+    if len(inputs[0].shape) == 4:
         return [_argwhere_shape_func_4d(inputs[0])]
-    elif len(inputs[0].shape) == 5:
+    if len(inputs[0].shape) == 5:
         return [_argwhere_shape_func_5d(inputs[0])]
     return ValueError("Does not support rank higher than 5 in argwhere")
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments
+# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
@@ -265,6 +265,7 @@ def schedule_conv2d(attrs, outs, target):
 @reg.register_alter_op_layout("nn.conv2d")
 def alter_op_layout_conv2d(attrs, inputs, tinfos):
     """Alternate the layout of conv2d"""
+    # pylint: disable=import-outside-toplevel
     from ... import op
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
@@ -309,7 +310,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
     result : tvm.relay.Expr
         The transformed expr
     """
+    # pylint: disable=import-outside-toplevel
     from tvm import relay
     data_layout = attrs['data_layout']
     kernel_layout = attrs['kernel_layout']
...
@@ -14,6 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=import-outside-toplevel
 """Transform operators."""

 from . import _make
...
@@ -22,6 +22,7 @@ from .. import register_func
 @register_func("relay.fromtext")
 def fromtext(data, source_name=None):
     """Parse a Relay program."""
+    # pylint: disable=import-outside-toplevel
     from tvm.relay import _parser
     x = _parser.fromtext(data + "\n", source_name)
     if x is None:
...
@@ -16,7 +16,7 @@
 # under the License.
 #pylint: disable=unused-argument
 """The register functions for the QNN dialect."""
-from tvm.relay.op.op import register as register
+from tvm.relay.op.op import register

 def register_qnn_legalize(op_name, legal_op=None, level=10):
     """Register legal transformation function for a QNN op
...
@@ -88,7 +88,7 @@ def add_partition_generic(ref_call, new_args, ctx):
         lhs = new_args[0].realize()
         rhs = new_args[1].realize()
         return _forward_op(ref_call, [lhs, rhs])
-    elif not lhs_cond and rhs_cond:
+    if not lhs_cond and rhs_cond:
         # - introduced by residual connection in ResNet
         #     ...
         #     %13 = nn.conv2d(%12, %meta[relay.Constant])
@@ -104,7 +104,7 @@ def add_partition_generic(ref_call, new_args, ctx):
         #     ...
         rhs = new_args[1].realize()
         return _forward_op(ref_call, [lhs, rhs])
-    elif lhs_cond and not rhs_cond:
+    if lhs_cond and not rhs_cond:
         if _analysis.check_constant(rhs):
             # - introduced by batch_norm: add(out, bias)
             return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
@@ -121,11 +121,11 @@ def add_partition_generic(ref_call, new_args, ctx):
         #     ...
         lhs = new_args[0].realize()
         return _forward_op(ref_call, [lhs, rhs])
-    elif not lhs_cond and not rhs_cond:
+    if not lhs_cond and not rhs_cond:
         # trivial case
         return None
-    else:
-        raise ValueError
+    raise ValueError

 # TODO(ziheng) enhance `register_partition_function` to dispatch
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument
+#pylint: disable=unused-argument, not-context-manager
 """Automatic quantization toolkit."""
 from __future__ import absolute_import
 from . import _quantize
...
@@ -41,8 +41,7 @@ class WithScope(object):
     def __exit__(self, ptype, value, trace):
         if value:
             raise value
-        else:
-            self._exit_cb()
+        self._exit_cb()

 def _make_lets(bindings, ret_value):
     """Make a nested let expressions.
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
+# pylint: disable=invalid-name, unused-variable, unused-argument, no-init, unpacking-non-sequence
 """
 Compile DarkNet Models
 ====================
...
@@ -85,24 +85,25 @@ def residual_unit(data,
             data=act1, channels=num_filter, kernel_size=(1, 1),
             strides=stride, name=name+'_sc')
         return relay.add(conv3, shortcut)
-    else:
-        bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1')
-        act1 = relay.nn.relu(data=bn1)
-        conv1 = layers.conv2d(
-            data=act1, channels=num_filter, kernel_size=(3, 3),
-            strides=stride, padding=(1, 1), name=name + '_conv1')
-        bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
-        act2 = relay.nn.relu(data=bn2)
-        conv2 = layers.conv2d(
-            data=act2, channels=num_filter, kernel_size=(3, 3),
-            strides=(1, 1), padding=(1, 1), name=name + '_conv2')
-        if dim_match:
-            shortcut = data
-        else:
-            shortcut = layers.conv2d(
-                data=act1, channels=num_filter, kernel_size=(1, 1),
-                strides=stride, name=name+'_sc')
-        return relay.add(conv2, shortcut)
+
+    bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1')
+    act1 = relay.nn.relu(data=bn1)
+    conv1 = layers.conv2d(
+        data=act1, channels=num_filter, kernel_size=(3, 3),
+        strides=stride, padding=(1, 1), name=name + '_conv1')
+    bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
+    act2 = relay.nn.relu(data=bn2)
+    conv2 = layers.conv2d(
+        data=act2, channels=num_filter, kernel_size=(3, 3),
+        strides=(1, 1), padding=(1, 1), name=name + '_conv2')
+    if dim_match:
+        shortcut = data
+    else:
+        shortcut = layers.conv2d(
+            data=act1, channels=num_filter, kernel_size=(1, 1),
+            strides=stride, name=name+'_sc')
+    return relay.add(conv2, shortcut)

 def resnet(units,
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
+# pylint: disable=invalid-name, unused-variable, unused-argument, no-init, import-outside-toplevel
 """
 Tensorflow Model Helpers
 ========================
@@ -346,7 +346,7 @@ def get_workload_ptb():
     sample_data_file = 'simple-examples.tgz'
     sample_url = sample_repo+sample_data_file
     ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'
+    # pylint: disable=import-outside-toplevel
     import tarfile
     file_path = download_testdata(sample_url, sample_data_file, module=['data', 'ptb_data'])
     dir_path = os.path.dirname(file_path)
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
+# pylint: disable=invalid-name, unused-variable, unused-argument, no-init,
 """
 Yolo detection boxes helper functions
 ====================
@@ -224,6 +224,7 @@ def _draw_label(im, r, c, label, rgb):
             _set_pixel(im, i+c, j+r, k, val)#rgb[k] * val)

 def _get_label(font_path, labelstr, rgb):
+    # pylint: disable=import-outside-toplevel
     from PIL import Image
     from PIL import ImageDraw
     from PIL import ImageFont
...
@@ -508,8 +508,7 @@ class Proxy(object):
             except socket.error as sock_err:
                 if sock_err.errno in [98, 48]:
                     continue
-                else:
-                    raise sock_err
+                raise sock_err
         if not self.port:
             raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
         logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
@@ -569,7 +568,7 @@ def websocket_proxy_server(url, key=""):
     magic = struct.unpack('<i', msg[:4])[0]
     if magic == base.RPC_CODE_DUPLICATE:
         raise RuntimeError("key: %s has already been used in proxy" % key)
-    elif magic == base.RPC_CODE_MISMATCH:
+    if magic == base.RPC_CODE_MISMATCH:
         logging.info("RPCProxy do not have matching client key %s", key)
     elif magic != base.RPC_CODE_SUCCESS:
         raise RuntimeError("%s is not RPC Proxy" % url)
...
@@ -161,11 +161,10 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
                 conn.close()
                 logger.warning("mismatch key from %s", addr)
                 continue
-            else:
-                conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
-                conn.sendall(struct.pack("<i", len(server_key)))
-                conn.sendall(server_key.encode("utf-8"))
-                return conn, addr, _parse_server_opt(arr[1:])
+            conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
+            conn.sendall(struct.pack("<i", len(server_key)))
+            conn.sendall(server_key.encode("utf-8"))
+            return conn, addr, _parse_server_opt(arr[1:])

     # Server logic
     tracker_conn = None
@@ -208,6 +207,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
             server_proc.join(opts.get("timeout", None))
             if server_proc.is_alive():
                 logger.info("Timeout in RPC session, kill..")
+                # pylint: disable=import-outside-toplevel
                 import psutil
                 parent = psutil.Process(server_proc.pid)
                 # terminate worker childs
@@ -233,7 +233,8 @@ def _connect_proxy_loop(addr, key, load_library):
         magic = struct.unpack("<i", base.recvall(sock, 4))[0]
         if magic == base.RPC_CODE_DUPLICATE:
             raise RuntimeError("key: %s has already been used in proxy" % key)
-        elif magic == base.RPC_CODE_MISMATCH:
+
+        if magic == base.RPC_CODE_MISMATCH:
             logger.warning("RPCProxy do not have matching client key %s", key)
         elif magic != base.RPC_CODE_SUCCESS:
             raise RuntimeError("%s is not RPC Proxy" % str(addr))
@@ -380,8 +381,7 @@ class Server(object):
             except socket.error as sock_err:
                 if sock_err.errno in [98, 48]:
                     continue
-                else:
-                    raise sock_err
+                raise sock_err
         if not self.port:
             raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
         logger.info("bind to %s:%d", host, self.port)
...
@@ -92,8 +92,8 @@ class TCPHandler(object):
             except socket.error as err:
                 if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                     break
-                else:
-                    self.on_error(err)
+                self.on_error(err)

         if self._pending_write:
             self._ioloop.update_handler(
                 self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR | self._ioloop.WRITE)
...
@@ -393,8 +393,7 @@ class Tracker(object):
             except socket.error as sock_err:
                 if sock_err.errno in [98, 48]:
                     continue
-                else:
-                    raise sock_err
+                raise sock_err
         if not self.port:
             raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
         logger.info("bind to %s:%d", host, self.port)
...
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-import
+# pylint: disable=invalid-name, unused-import, import-outside-toplevel
 """Runtime Module namespace."""
 import ctypes
 import struct
...
@@ -184,7 +184,7 @@ class NDArray(NDArrayBase):
         """
         if isinstance(target, NDArrayBase):
             return self._copyto(target)
-        elif isinstance(target, TVMContext):
+        if isinstance(target, TVMContext):
             res = empty(self.shape, self.dtype, target)
             return self._copyto(res)
         raise ValueError("Unsupported target type %s" % str(type(target)))
...
@@ -179,7 +179,6 @@ class BaseComputeOp(Operation):
 @tvm._ffi.register_object
 class ComputeOp(BaseComputeOp):
     """Scalar operation."""
-    pass

 @tvm._ffi.register_object
...
@@ -112,7 +112,7 @@ def decl_tensor_intrin(op,
         raise TypeError("expect Operation")
     inputs = op.input_tensors
     binds = binds if binds else {}
-    tensors = [x for x in inputs]
+    tensors = list(inputs)
     for i in range(op.num_outputs):
         tensors.append(op.output(i))
...
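`[x for x in inputs]` is an identity comprehension, which pylint 2.4 flags; `list(inputs)` builds the same independent copy more directly. The same rewrite recurs in the `copy_inputs = list(inputs)` and `list(range(...))` hunks below. For example:

    inputs = ("a", "b", "c")
    tensors = list(inputs)   # preferred over [x for x in inputs]
    tensors.append("d")      # the copy is independent of `inputs`
    assert tensors == ["a", "b", "c", "d"] and len(inputs) == 3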
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument
+# pylint: disable=invalid-name, unused-variable, no-else-return, unused-argument, import-outside-toplevel
 """Conv2D schedule for ARM CPU"""
 from __future__ import absolute_import as _abs
@@ -528,8 +528,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     Unlike other TOPI functions, this function operates on both graph level and operator level,
     so we have to pass 'F' to make it support our two versions of graph IR, Relay.
     """
-    copy_inputs = [s for s in inputs]
+    copy_inputs = list(inputs)
     new_attrs = {k: attrs[k] for k in attrs.keys()}
     if F.__name__ == 'tvm.relay.op':
...
@@ -74,8 +74,7 @@ def conv2d_bifrost(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     if layout == 'NCHW':
         return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
                                         dilation, out_dtype, num_tile=3)
-    else:
-        raise ValueError("Unsupported layout {}".format(layout))
+    raise ValueError("Unsupported layout {}".format(layout))

 @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'bifrost', ['direct', 'winograd'])
...
@@ -328,7 +328,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
     if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs:
         return None
-    copy_inputs = [s for s in inputs]
+    copy_inputs = list(inputs)
     new_attrs = {k: attrs[k] for k in attrs.keys()}
...
@@ -14,7 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
+# pylint: disable=bad-continuation, unused-argument
 """Non-maximum suppression operator"""
 import math
 import tvm
@@ -397,7 +398,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
                                name="get_valid_counts_phase_four")
     valid_count, out_tensor = \
         tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
-            lambda ins, outs: get_valid_counts_ir(
-                ins[0], ins[1], ins[2], outs[0], outs[1]),
+                   lambda ins, outs: get_valid_counts_ir(
+                       ins[0], ins[1], ins[2], outs[0], outs[1]),
                    dtype=["int32", data.dtype],
                    in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, singleton-comparison
+# pylint: disable=invalid-name, singleton-comparison, bad-continuation
 """Proposal operator"""
 import math
 import tvm
@@ -177,7 +177,7 @@ def argsort_ir(data_buf, out_index_buf):
         with ib.for_range(0, num_bbox) as k:
             offset = start + 2 * tid + idxm(k, 2)
             with ib.if_scope(
-                tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
+                    tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
                 temp_data[0] = p_data[offset]
                 p_data[offset] = p_data[offset + 1]
                 p_data[offset + 1] = temp_data[0]
...
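The proposal and NMS files handle pylint's `bad-continuation` check both ways: re-indenting wrapped conditions (as in `argsort_ir` above, where the continuation gets an extra level so it reads apart from the block body) and disabling the check in the file header where rewrapping is not worth it. A small sketch of the accepted indentation, standard Python only:

    def swap_if_greater(values, offset, limit):
        # Continuation indented past the opening parenthesis level so it
        # is not confused with the body below it.
        if (offset + 1 < limit and
                values[offset] < values[offset + 1]):
            values[offset], values[offset + 1] = values[offset + 1], values[offset]
        return values

    assert swap_if_greater([1, 3], 0, 2) == [3, 1]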
@@ -54,9 +54,9 @@ def schedule_softmax(outs):
     if len(softmax.shape) > 2:
         ops = [max_elem.op, expsum.op, softmax.op]
-        if exp != None:
+        if exp is not None:
             ops.append(exp.op)
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
     else:
@@ -64,7 +64,7 @@ def schedule_softmax(outs):
         block_x = tvm.thread_axis("blockIdx.x")
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-        if exp != None:
+        if exp is not None:
             s[exp].bind(exp.op.axis[0], block_x)
         s[max_elem].bind(max_elem.op.axis[0], block_x)
...
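`exp != None` trips pylint's `singleton-comparison`: `None` is a singleton, so identity (`is not`) is the correct test, and unlike `!=` it cannot be subverted by an operand's custom `__ne__`. The same substitution repeats in the softmax schedules below. A sketch of why identity is safer:

    class Tricky:
        def __ne__(self, other):
            return False        # lies: claims to equal everything

    exp = Tricky()
    assert not (exp != None)    # `!=` is fooled by __ne__
    assert exp is not None      # identity against the None singleton is not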
@@ -42,6 +42,7 @@ def _schedule_sort(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
+    # pylint: disable=import-outside-toplevel
     from .injective import schedule_injective_from_existing
     def traverse(op):
         if tag.is_injective(op.tag):
...
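pylint 2.4 added `import-outside-toplevel`, which is why so many hunks in this commit either hoist imports or, as in `_schedule_sort` above, annotate deliberate function-local ones. Local imports remain useful for breaking import cycles and deferring optional dependencies. A sketch, with `decode_record` as a hypothetical helper:

    def decode_record(line):
        # Deliberate local import: defers the dependency until first use
        # (and, in the real code base, avoids a circular import).
        # pylint: disable=import-outside-toplevel
        import json
        return json.loads(line)

    assert decode_record('{"op": "conv2d"}') == {"op": "conv2d"}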
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable, unused-argument, no-member
+# pylint: disable=invalid-name, unused-variable, unused-argument, no-member, import-outside-toplevel
 """Schedule for vision operators"""
 from __future__ import absolute_import as _abs
 import tvm
...
@@ -275,7 +275,7 @@ def schedule_softmax(outs):
         raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
 Got {0}'.format(op_tag))
-    if exp != None:
+    if exp is not None:
         s[exp].compute_at(s[softmax], s[softmax].op.axis[1])
     s[expsum].compute_at(s[softmax], s[softmax].op.axis[1])
...
@@ -38,17 +38,17 @@ from ..util import simplify, get_const_tuple
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
     if is_depthwise:
         raise RuntimeError("Depthwise not supported for intel graphics.")
-    else:
-        batch_size, in_channel, height, width = get_const_tuple(data.shape)
-        out_channel, _, hkernel, _ = get_const_tuple(kernel.shape)
-        HSTR, _ = strides
-        ic_bn = 1
-        oc_bn, oc_bn_upper = 16, 16
-        for i in range(oc_bn_upper, 0, -1):
-            if out_channel % i == 0:
-                oc_bn = i
-                break
+    batch_size, in_channel, height, width = get_const_tuple(data.shape)
+    out_channel, _, hkernel, _ = get_const_tuple(kernel.shape)
+    HSTR, _ = strides
+    ic_bn = 1
+    oc_bn, oc_bn_upper = 16, 16
+    for i in range(oc_bn_upper, 0, -1):
+        if out_channel % i == 0:
+            oc_bn = i
+            break
     if HSTR == 2:
         if out_channel + hkernel == 515:
@@ -189,7 +189,7 @@ def __topi_nn_conv2d_NCHWc(*args, **kwargs):
 @conv2d_alter_layout.register(["intel_graphics"])
 def _alter_conv2d_layout(attrs, inputs, tinfo, F):
-    copy_inputs = [s for s in inputs]
+    copy_inputs = list(inputs)
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     if F.__name__ == 'tvm.relay.op':
...
@@ -60,7 +60,7 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"):
             for i in range(n+1):
                 if i == bit_axis:
                     continue
-                elif i == pack_axis:
+                if i == pack_axis:
                     idx[j] = indices[i] * data_width + k
                 else:
                     idx[j] = indices[i]
@@ -88,4 +88,3 @@ def binary_op_multiplier(pack_dtype):
     pack_dtype: string
         pack type for the operator (must be a uint)"""
-    return int(pack_dtype[4:])
\ No newline at end of file
+    return int(pack_dtype[4:])
@@ -66,9 +66,9 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
     # default declaration
     if layout == 'NCHW':
         return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
-    elif layout == 'HWCN':
+    if layout == 'HWCN':
         return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
-    elif layout == 'NHWC':
+    if layout == 'NHWC':
         return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
@@ -764,6 +764,7 @@ def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
     output : tvm.Tensor
         4-D with shape [alpha, alpha, CO, CI]
     """
+    # pylint: disable=import-outside-toplevel
     from tvm.contrib import nnpack
     return nnpack.convolution_inference_weight_transform(
         kernel, algorithm=convolution_algorithm, dtype=out_dtype)
...
@@ -76,7 +76,7 @@ def fifo_buffer(data, buffer, axis):
                                   buffer[i + data_size],
                                   data[i - buflen + data_size]),
                            name='new_buffer')
-    elif len(buffer.shape) == 2:
+    if len(buffer.shape) == 2:
         if axis == 0:
             return tvm.compute(buffer.shape,
                                lambda i, j:
...
@@ -51,7 +51,7 @@ def schedule_softmax(outs):
         raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
 Got {0}'.format(op_tag))
-    if exp != None:
+    if exp is not None:
         s[exp].opengl()
     s[max_elem].opengl()
...
@@ -62,7 +62,7 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
             indices_index += 1
     out = np.empty(oshape)
-    output_indices = [index for index in np.ndindex(out.shape)]
+    output_indices = list(np.ndindex(out.shape))
     for output_index in output_indices:
         indices_indices = []
         for i, out_idx in enumerate(output_index):
...
@@ -238,13 +238,10 @@ def strided_set(a, v, begin, end, strides=None):
         from_val = []
         index_tuple = []
         for i in range(n):
-            from_val.append(
-                within_index(begin[i], end[i], strides[i], indices[i]))
+            from_val.append(within_index(begin[i], end[i], strides[i], indices[i]))
             index_tuple.append(
                 make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i]))
-        return tvm.if_then_else(tvm.all(*from_val),
-                                v(*index_tuple),
-                                a(*indices))
+        return tvm.if_then_else(tvm.all(*from_val), v(*index_tuple), a(*indices))
     return tvm.compute(a.shape, _select, name="strided_set")
@@ -568,7 +565,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
     assert len(data.shape) >= 2,\
         "only support data.ndim >= 2, received data.shape = {}".format(data.shape)
-    assert axis == 0 or axis == 1, "only support axis = 0, 1, received axis = {}".format(axis)
+    assert axis in (0, 1), "only support axis = 0, 1, received axis = {}".format(axis)
     return cpp.sequence_mask(data, valid_length, mask_value, axis)
...
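The `sequence_mask` assert is the membership-test rewrite (pylint's `consider-using-in`): `axis == 0 or axis == 1` collapses to `axis in (0, 1)`. The `mcpu in ('skylake-avx512', 'cascadelake')` hunks below apply the same rule. Sketch:

    def check_axis(axis):
        # One membership test replaces a chain of equality comparisons.
        assert axis in (0, 1), "only support axis = 0, 1, received axis = {}".format(axis)
        return axis

    assert check_axis(1) == 1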
@@ -25,7 +25,6 @@ from . import tag, cpp
 class InvalidShapeError(ValueError):
     """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
-    pass

 def nchw_pack_layout(layout_info):
     """Check whether the layout type is NCHWinic"""
@@ -350,7 +349,7 @@ def get_shape(src_shape, src_layout, dst_layout):
     layout_mapping = bijective_layout(src_layout, dst_layout)
     dst_indices = layout_mapping.forward_index(
-        tvm.convert([i for i in range(len(src_layout))]))
+        tvm.convert(list(range(len(src_layout)))))
     return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, singleton-comparison
+# pylint: disable=invalid-name, singleton-comparison, bad-continuation
 """Proposal operator"""
 import math
 import tvm
@@ -303,7 +303,7 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
     with ib.for_range(0, batch) as b:
         with ib.if_scope(nkeep[b] > 0):
             with ib.for_range(0, tvm.ceil(
-                tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[b]).astype('int32')):
+                    tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[b]).astype('int32')):
                 with ib.for_range(0, num_bbox) as j:
                     offset_j = (b * num_bbox + j) * 5
                     offset_i = (b * rpn_post_nms_top_n + i[b]) * 5
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member,import-outside-toplevel
 """Conv2D schedule on x86"""
 import logging
@@ -126,7 +126,7 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     # # specialize for INT8 1X1 conv on X86
     # return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
     #                                                   padding, dilation, out_dtype)
-    elif layout == 'NHWC':
+    if layout == 'NHWC':
         return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))
...
@@ -63,7 +63,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
     is_depthwise = groups == kshape[0] and kshape[1] == 1
     # Save the input exprs.
-    copy_inputs = [s for s in inputs]
+    copy_inputs = list(inputs)
     # Set the new attrs
     new_attrs = {k : attrs[k] for k in attrs.keys()}
...
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member, import-outside-toplevel
 """Conv2D int8 schedule on x86"""
 import re
@@ -70,7 +70,7 @@ def _is_int8_hw_support(data_dtype, kernel_dtype):
     # 3) Check target
     mcpu = tvm.target.current_target().mcpu
     is_target_support = False
-    if mcpu == 'skylake-avx512' or mcpu == 'cascadelake':
+    if mcpu in ('skylake-avx512', 'cascadelake'):
         is_target_support = True
     return is_dtype_support and is_llvm_support and is_target_support
...
@@ -63,7 +63,7 @@ def schedule_softmax(outs):
     s[max_elem].compute_at(s[softmax], fused_outer_axes)
     s[expsum].compute_at(s[softmax], fused_outer_axes)
-    if exp != None:
+    if exp is not None:
         s[exp].compute_at(s[softmax], fused_outer_axes)
     return s
@@ -21,6 +21,6 @@ import tvm
 def get_fp32_len():
     mcpu = tvm.target.current_target().mcpu
     fp32_vec_len = 8
-    if mcpu == 'skylake-avx512' or mcpu == 'cascadelake':
+    if mcpu in ('skylake-avx512', 'cascadelake'):
         fp32_vec_len = 16
     return fp32_vec_len
@@ -79,11 +79,10 @@ compilation guide to get Xilinx toolchains setup) and add it to your \
 $VTA_CACHE_PATH. Alternatively edit your config.json back to its default \
 settings. You can see the list of available bitstreams under {}"
                              .format(url, BITSTREAM_URL))
-        else:
-            raise RuntimeError(
-                # This could happen when trying to access the URL behind a proxy
-                "Something went wrong when trying to access {}. Check your \
+        raise RuntimeError(
+            # This could happen when trying to access the URL behind a proxy
+            "Something went wrong when trying to access {}. Check your \
 internet connection or proxy settings."
-                .format(url))
+            .format(url))
     return success
@@ -231,9 +231,9 @@ class Environment(object):
         """The target host"""
         if self.TARGET in ["pynq", "de10nano"]:
             return "llvm -target=armv7-none-linux-gnueabihf"
-        elif self.TARGET == "ultra96":
+        if self.TARGET == "ultra96":
             return "llvm -target=aarch64-linux-gnu"
-        elif self.TARGET in ["sim", "tsim"]:
+        if self.TARGET in ["sim", "tsim"]:
             return "llvm"
         raise ValueError("Unknown target %s" % self.TARGET)
...
@@ -66,6 +66,7 @@ def server_start():
     @tvm.register_func("tvm.contrib.vta.init", override=True)
     def program_fpga(file_name):
+        # pylint: disable=import-outside-toplevel
         from pynq import xlnk
         # Reset xilinx driver
         xlnk.Xlnk().xlnk_reset()
...
@@ -15,9 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Additional IR Pass for VTA"""
-# pylint: disable=len-as-condition
-from __future__ import absolute_import as _abs
+# pylint: disable=len-as-condition, no-else-return
 import tvm
 from topi import util
...
@@ -43,6 +43,7 @@ def main():
     bitstream_program(args.target, args.bitstream)

 def pynq_bitstream_program(bitstream_path):
+    # pylint: disable=import-outside-toplevel
     from pynq import Bitstream
     bitstream = Bitstream(bitstream_path)
     bitstream.download()
...
@@ -151,14 +151,12 @@ class ExprPack(ExprMutator):
                 assert not self.start_pack
                 self.start_pack = True
                 return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor)
-            elif call.op == self.bitpack_end:
+            if call.op == self.bitpack_end:
                 if self.start_pack:
                     self.start_pack = False
                     data = args[0]
                     data_shape = _get_shape(call.args[0])
                     return _unpack_batch_channel(data, data_shape)
-                else:
-                    pass
             if self.start_pack:
                 # Operator cases
                 if call.op == self.conv2d and odtype == 'int32':
@@ -188,7 +186,8 @@ class ExprPack(ExprMutator):
                         kernel_layout=kernel_layout,
                         out_dtype=call.attrs.out_dtype)
                     return conv2d
-                elif call.op == self.conv2d_transpose and odtype == 'int32':
+
+                if call.op == self.conv2d_transpose and odtype == 'int32':
                     self.number_of_conv2d += 1
                     assert 8 % self.weight_bits == 0
                     w_lanes = 8 // self.weight_bits
@@ -213,7 +212,7 @@ class ExprPack(ExprMutator):
                         output_padding=call.attrs.output_padding,
                         out_dtype=call.attrs.out_dtype)
                     return conv2d
-                elif call.op == self.add and \
+                if call.op == self.add and \
                         tuple(input_types[0].shape) == tuple(input_types[1].shape):
                     pass
                 elif call.op == self.add and len(input_types[1].shape) == 3:
@@ -272,7 +271,7 @@ def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, cou
                        _recursion(anf.body, start_found, stop_found,
                                   operator_current_idx),
                        anf.ret_type, anf.type_params, anf.attrs)
-    elif isinstance(anf, relay.expr.Let):
+    if isinstance(anf, relay.expr.Let):
         value = anf.value
         if isinstance(value, relay.expr.Call):
             if isinstance(value.op, relay.op.Op):
...
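In `ExprPack.visit_call` above, the `else: pass` arm did nothing, so it is deleted outright and the surviving branches are flattened where each live arm returns. A reduced sketch of removing a no-op branch, using a toy `visit` function:

    def visit(op, start_pack):
        if op == "bitpack_start":
            return "packed"
        if op == "bitpack_end":
            if start_pack:
                return "unpacked"
            # no `else: pass` needed; control simply falls through
        return "unchanged"

    assert visit("bitpack_end", False) == "unchanged"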
@@ -127,10 +127,9 @@ def compute_conv2d_transpose(attrs, inputs, output_type, target):
         if is_packed_layout(layout):
             return [topi.nn.conv2d_transpose_nchw(
                 inputs[0], inputs[1], strides, padding, out_dtype)]
-        else:
-            # If it's not packed, run on ARM CPU
-            with tvm.target.arm_cpu(tvm.target.current_target().model):
-                return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
+        # If it's not packed, run on ARM CPU
+        with tvm.target.arm_cpu(tvm.target.current_target().model):
+            return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
     # If VTA is not the target, default to _nn def
     return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
@@ -145,10 +144,9 @@ def schedule_conv2d_transpose(attrs, outputs, target):
     if target.device_name == "vta":
         if is_packed_layout(layout):
             return topi.nn.schedule_conv2d_transpose_nchw(outputs)
-        else:
-            # If it's not packed, run on ARM CPU
-            with tvm.target.arm_cpu(tvm.target.current_target().model):
-                return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target())
+        # If it's not packed, run on ARM CPU
+        with tvm.target.arm_cpu(tvm.target.current_target().model):
+            return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target())
     # If VTA is not the target, default to _nn def
     return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.current_target())
...
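These two VTA hunks are the trickier `no-else-return` shape: the `else:` wraps a `with` block, so the fix dedents the whole block one level after the early `return`. Behavior is unchanged because the `return` inside the `with` still exits before the final fallback. A sketch with a hypothetical stand-in for the target context:

    from contextlib import contextmanager

    @contextmanager
    def arm_cpu():
        yield "arm_cpu"  # stand-in for tvm.target.arm_cpu(...)

    def compute(packed):
        if packed:
            return "vta-packed"
        # dedented from under `else:`; reached only when not packed
        with arm_cpu() as tgt:
            return "fallback-on-" + tgt

    assert compute(False) == "fallback-on-arm_cpu"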
@@ -23,7 +23,6 @@ import os
 import tvm
 from tvm import autotvm
-from tvm.contrib.util import get_lower_ir
 import topi
 import vta
 import vta.testing
...
@@ -23,7 +23,6 @@ import os
 import tvm
 from tvm import autotvm
-from tvm.contrib.util import get_lower_ir
 import topi
 import vta
 import vta.testing
...
@@ -23,7 +23,6 @@ import os
 import tvm
 from tvm import autotvm
-from tvm.contrib.util import get_lower_ir
 import topi
 import vta
 import vta.testing
...
@@ -23,7 +23,6 @@ import os
 import tvm
 from tvm import autotvm
-from tvm.contrib.util import get_lower_ir
 import topi
 import vta
 import vta.testing
...