Commit e20ef0d4 by Marcus Shawcroft, committed by Tianqi Chen

Fix pylint 2.2.2 gripes. (#2642)

parent 81334be3
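Most of the hunks below apply the same family of pylint fixes: no-else-return (an else branch after a branch that already returns or raises is dropped and its body dedented) and unnecessary-pass (a `pass` under a docstring is redundant). A minimal before/after sketch of the pattern, using hypothetical names for illustration:

# Before: flagged by pylint (else after return, redundant pass).
class DemoError(Exception):
    """Example error class."""
    pass

def pick(flag):
    if flag:
        return 1
    else:
        return 2

# After: same behaviour; the if branch already returns, so the else is unnecessary,
# and the docstring alone is a sufficient class body.
class DemoError(Exception):
    """Example error class."""

def pick(flag):
    if flag:
        return 1
    return 2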
@@ -31,7 +31,7 @@ else:
 class NNVMError(Exception):
     """Error that will be throwed by all nnvm functions"""
-    pass
 def _load_lib():
     """Load libary by searching possible path."""
...
@@ -42,7 +42,6 @@ class AttrScope(object):
         if attr:
             ret.update(attr)
             return ret
-        else:
         return attr
     def __enter__(self):
...
@@ -23,13 +23,11 @@ class GraphKey(tvm.node.NodeBase):
 @tvm.register_node
 class GraphCacheEntry(tvm.node.NodeBase):
     """CacheEntry of compilation into a TVM Function"""
-    pass
 @tvm.register_node
 class GraphFunc(tvm.node.NodeBase):
     """Compiled result of a graph into a TVM Function"""
-    pass
 class Engine(object):
...
@@ -73,7 +73,6 @@ class Caffe2OpConverter(object):
         if hasattr(cls, '_impl'):
             return getattr(cls, '_impl')
-        else:
         raise NotImplementedError('{} not implemented'.format(
             cls.__name__))
@@ -175,9 +174,8 @@ class Concat(Caffe2OpConverter):
         order = order if isinstance(order, str) else order.decode('UTF-8')
         if order == 'NCHW':
             return 1
-        elif order == 'NHWC':
+        if order == 'NHWC':
             return 3
-        else:
         raise RuntimeError(
             "Unsupported storage order: {} in caffe2".format(order))
...
@@ -98,33 +98,33 @@ def ActivationParams(op, insym, symtab):
     par = getattr(op, whichActivation)
     if whichActivation == 'linear':
         return _sym.__add_scalar__(_sym.__mul_scalar__(insym, scalar=par.alpha), scalar=par.beta)
-    elif whichActivation == 'ReLU':
+    if whichActivation == 'ReLU':
         return _sym.relu(insym)
-    elif whichActivation == 'leakyReLU':
+    if whichActivation == 'leakyReLU':
         return _sym.leaky_relu(insym, alpha=par.alpha)
-    elif whichActivation == 'thresholdedReLU':
+    if whichActivation == 'thresholdedReLU':
         alpha_tensor = _sym.full_like(insym, fill_value=float(par.alpha))
         return _sym.elemwise_mul(insym, _sym.greater(insym, alpha_tensor))
-    elif whichActivation == 'PReLU':
+    if whichActivation == 'PReLU':
         return _sym.prelu(insym, alpha=par.alpha)
-    elif whichActivation == 'tanh':
+    if whichActivation == 'tanh':
         return _sym.tanh(insym)
-    elif whichActivation == 'scaledTanh':
+    if whichActivation == 'scaledTanh':
         return _sym.__mul_scalar__(_sym.tanh(_sym.__mul_scalar__(
             insym, scalar=par.beta)), scalar=par.alpha)
-    elif whichActivation == 'sigmoid':
+    if whichActivation == 'sigmoid':
         return _sym.sigmoid(insym)
-    elif whichActivation == 'sigmoidHard':
+    if whichActivation == 'sigmoidHard':
         transformX = (par.alpha * insym) + par.beta
         return _sym.clip(transformX, a_min=0, a_max=1)
-    elif whichActivation == 'ELU':
+    if whichActivation == 'ELU':
         return _sym.__mul_scalar__(_sym.__add_scalar__(
             _sym.exp(insym), scalar=-1), scalar=par.alpha)
-    elif whichActivation == 'softsign':
+    if whichActivation == 'softsign':
         return insym / (1 + (_sym.relu(insym) + _sym.relu(_sym.negative(insym))))
-    elif whichActivation == 'softplus':
+    if whichActivation == 'softplus':
         return _sym.log(_sym.__add_scalar__(_sym.exp(insym), scalar=1))
-    elif whichActivation == 'parametricSoftplus':
+    if whichActivation == 'parametricSoftplus':
         alpha = list(par.alpha.floatValue)
         beta = list(par.alpha.floatValue)
         if len(alpha) == 1:
@@ -136,7 +136,6 @@ def ActivationParams(op, insym, symtab):
         betasym = symtab.new_const(beta)
         return _sym.broadcast_mul(_sym.log(_sym.broadcast_add(
             _sym.exp(insym), betasym)), alphasym)
-    else:
     raise NotImplementedError('%s not implemented' % whichActivation)
 def ScaleLayerParams(op, insym, symtab):
@@ -157,9 +156,8 @@ def PoolingLayerParams(op, insym, symtab):
     if op.globalPooling:
         if op.type == 0:
             return _sym.global_max_pool2d(insym)
-        elif op.type == 1:
+        if op.type == 1:
             return _sym.global_avg_pool2d(insym)
-        else:
         raise NotImplementedError("Only max and average pooling implemented")
     else:
@@ -190,9 +188,8 @@ def PoolingLayerParams(op, insym, symtab):
         if op.type == 0:
             return _sym.max_pool2d(insym, **params)
-        elif op.type == 1:
+        if op.type == 1:
             return _sym.avg_pool2d(insym, **params)
-        else:
         raise NotImplementedError("Only max and average pooling implemented")
 def SoftmaxLayerParams(op, insym, symtab):
...
@@ -921,8 +921,6 @@ class GraphProto(object):
         if layer_num != self.net.n-1:
             self._outs.insert(0, sym)
-        return
     def from_darknet(self):
         """To convert the darknet symbol to nnvm symbols."""
         for i in range(self.net.n):
...
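The bare `return` removed above (and again in the yolo detection helpers later in this commit) is pylint's useless-return: a function that simply falls off the end already returns None. A tiny sketch with an illustrative helper:

# Before: the trailing bare return adds nothing.
def _append(outs, sym):
    outs.insert(0, sym)
    return

# After: identical behaviour without the final return.
def _append(outs, sym):
    outs.insert(0, sym)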
@@ -47,34 +47,33 @@ def _convert_activation(insym, keras_layer, _):
         beta = keras_layer.beta if hasattr(keras_layer, "beta") else 0
         return _sym.__add_scalar__(_sym.__mul_scalar__(insym, \
             scalar=alpha), scalar=beta)
-    elif act_type == 'softmax':
+    if act_type == 'softmax':
         return _sym.softmax(insym, axis=1)
-    elif act_type == 'sigmoid':
+    if act_type == 'sigmoid':
         return _sym.sigmoid(insym)
-    elif act_type == 'tanh':
+    if act_type == 'tanh':
         return _sym.tanh(insym)
-    elif act_type == 'relu':
+    if act_type == 'relu':
         return _sym.relu(insym)
-    elif act_type == 'softplus':
+    if act_type == 'softplus':
         return _sym.log(_sym.__add_scalar__(_sym.exp(insym), scalar=1))
-    elif act_type == 'elu':
+    if act_type == 'elu':
         alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1
         return _get_elu(insym, alpha)
-    elif act_type == 'selu':
+    if act_type == 'selu':
         # Alpha, Gamma values, obtained from https://arxiv.org/abs/1706.02515
         alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") \
             else 1.6732632423543772848170429916717
         gamma = keras_layer.gamma if hasattr(keras_layer, "gamma") \
             else 1.0507009873554804934193349852946
         return gamma * _get_elu(insym, alpha)
-    elif act_type == 'relu6':
+    if act_type == 'relu6':
         return _sym.clip(insym, a_min=0, a_max=6)
-    elif act_type == 'softsign':
+    if act_type == 'softsign':
         return insym / (1 + (_sym.relu(insym) + _sym.relu(_sym.negative(insym))))
-    elif act_type == 'hard_sigmoid':
+    if act_type == 'hard_sigmoid':
         transformX = (0.2 * insym) + 0.5
         return _sym.clip(transformX, a_min=0, a_max=1)
-    else:
     raise TypeError("Unsupported activation type : {}".format(act_type))
@@ -84,12 +83,12 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
         if keras_layer.max_value:
             return _sym.clip(insym, a_min=0, a_max=keras_layer.max_value)
         return _sym.relu(insym)
-    elif act_type == 'LeakyReLU':
+    if act_type == 'LeakyReLU':
         return _sym.leaky_relu(insym, alpha=keras_layer.alpha)
-    elif act_type == 'ELU':
+    if act_type == 'ELU':
         alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1
         return _get_elu(insym, alpha)
-    elif act_type == 'PReLU':
+    if act_type == 'PReLU':
         assert hasattr(keras_layer, "alpha"), \
             "alpha required for PReLU."
         _check_data_format(keras_layer)
@@ -97,11 +96,10 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
         return -symtab.new_const(keras_layer.get_weights()[0] \
                                  .transpose(np.roll(range(size), 1))) \
                                  * _sym.relu(-insym) + _sym.relu(insym)
-    elif act_type == 'ThresholdedReLU':
+    if act_type == 'ThresholdedReLU':
         theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0
         theta_tensor = _sym.full_like(insym[0], fill_value=float(theta))
         return _sym.elemwise_mul(insym[0], _sym.greater(insym[0], theta_tensor, out_type="float32"))
-    else:
     raise TypeError("Unsupported advanced activation type : {}".format(act_type))
@@ -280,9 +278,8 @@ def _convert_pooling(insym, keras_layer, symtab):
     # global pool in keras = global pool + flatten in nnvm
     if pool_type == 'GlobalMaxPooling2D':
         return _convert_flatten(_sym.global_max_pool2d(insym), keras_layer, symtab)
-    elif pool_type == 'GlobalAveragePooling2D':
+    if pool_type == 'GlobalAveragePooling2D':
         return _convert_flatten(_sym.global_avg_pool2d(insym), keras_layer, symtab)
-    else:
     pool_h, pool_w = keras_layer.pool_size
     stride_h, stride_w = keras_layer.strides
     params = {'pool_size': [pool_h, pool_w],
@@ -300,10 +297,9 @@ def _convert_pooling(insym, keras_layer, symtab):
         raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
     if pool_type == 'MaxPooling2D':
         return _sym.max_pool2d(insym, **params)
-    elif pool_type == 'AveragePooling2D':
+    if pool_type == 'AveragePooling2D':
         # TODO: in keras, padded zeros are not calculated
         return _sym.avg_pool2d(insym, **params)
-    else:
     raise TypeError("Unsupported pooling type : {}".format(keras_layer))
...
@@ -424,7 +424,7 @@ def _topo_sort(symbol):
         if childs is None:
             dep_cnts[name] = 0
         else:
-            dep_cnts[name] = len(set([c.attr('name') for c in childs]))
+            dep_cnts[name] = len({c.attr('name') for c in childs})
             for child in childs:
                 child_name = child.attr('name')
                 if child_name not in deps:
...
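The `_topo_sort` hunk above is pylint's consider-using-set-comprehension fix: building the set directly avoids materialising an intermediate list. A small sketch with illustrative data:

names = ["conv1", "relu1", "conv1"]
# Before: a list comprehension is built, then converted to a set.
count = len(set([n.upper() for n in names]))
# After: the set comprehension gives the same count without the temporary list.
count = len({n.upper() for n in names})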
@@ -9,7 +9,6 @@ def dimension_picker(prefix, surfix=''):
         kernel = attr['kernel_shape']
         if len(kernel) == 2:
             return prefix + '2d' + surfix
-        else:
         raise NotImplementedError("Only 2d kernel supported.")
     return _impl
...
@@ -68,7 +68,6 @@ def _dimension_picker(prefix, surfix=''):
         kernel = attr['kernel_shape']
         if len(kernel) == 2:
             return prefix + '2d' + surfix
-        else:
         raise NotImplementedError("Only 2d kernel supported.")
     return _impl
@@ -433,7 +432,6 @@ def _reshape():
                 op_name="reshape",
                 extras={'shape':tuple(params_new[0].asnumpy().flatten())},
                 ignores=['Tshape'])(inputs, attr)
-        else:
         raise RuntimeError("Reshape with dynamic shape input not supported yet.")
     return _impl
@@ -1394,7 +1392,7 @@ class GraphProto(object):
                 self._nodes[name] = _sym.Variable(name=name,
                                                   shape=self._params[name].shape)
             else:
-                if key != 'dtype' and key != '_output_shapes' and key != '_class':
+                if key not in ('dtype', '_output_shapes', '_class'):
                     raise NotImplementedError \
                         ("Other attributes for a Const(param) Node {} ? .".format(key))
...
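The last hunk above swaps a chain of `!=` comparisons for a tuple membership test, as suggested by pylint's consider-using-in check. A minimal sketch with an illustrative key:

key = '_class'
# Before: every excluded attribute is compared individually.
if key != 'dtype' and key != '_output_shapes' and key != '_class':
    raise NotImplementedError("unexpected attribute {}".format(key))
# After: one membership test, same behaviour.
if key not in ('dtype', '_output_shapes', '_class'):
    raise NotImplementedError("unexpected attribute {}".format(key))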
@@ -115,6 +115,8 @@ class TFParser(object):
         """TODO: Load checkpoint model."""
         raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
                            "not supported yet.")
+        # pylint: disable=unreachable
+        return 0
     def parse(self):
         """Parse tensorflow models: checkpoints, saved models, and single pb
...
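Here the warning is suppressed rather than refactored: the `return 0` after an unconditional raise can never execute, so it sits behind an explicit `# pylint: disable=unreachable` comment. A sketch of the same pattern under assumed names:

def _load_ckpt(self):
    """Hypothetical stub that always raises."""
    raise RuntimeError("checkpoint loading not supported yet")
    # pylint: disable=unreachable
    return 0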
@@ -50,9 +50,8 @@ class Symbol(SymbolBase):
         """x.__add__(y) <=> x+y"""
         if isinstance(other, Symbol):
             return __add_symbol__(self, other)
-        elif isinstance(other, _Number):
+        if isinstance(other, _Number):
             return __add_scalar__(self, scalar=other)
-        else:
         raise TypeError("type %s not supported" % str(type(other)))
     def __radd__(self, other):
@@ -64,13 +63,11 @@ class Symbol(SymbolBase):
             return __sub_symbol__(self, other)
         if isinstance(other, _Number):
             return __sub_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __rsub__(self, other):
         if isinstance(other, _Number):
             return __rsub_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __mul__(self, other):
@@ -79,7 +76,6 @@ class Symbol(SymbolBase):
             return __mul_symbol__(self, other)
         if isinstance(other, _Number):
             return __mul_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __rmul__(self, other):
@@ -91,27 +87,23 @@ class Symbol(SymbolBase):
             return __div_symbol__(self, other)
         if isinstance(other, _Number):
             return __div_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __rdiv__(self, other):
         if isinstance(other, _Number):
             return __rdiv_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __lshift__(self, other):
         """x.__lshift__(y) <=> x << y"""
         if isinstance(other, _Number):
             return __lshift_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __rshift__(self, other):
         """x.__rshift__(y) <=> x >> y"""
         if isinstance(other, _Number):
             return __rshift_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __truediv__(self, other):
@@ -126,13 +118,11 @@ class Symbol(SymbolBase):
             return __pow_symbol__(self, other)
         if isinstance(other, _Number):
             return __pow_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __rpow__(self, other):
         if isinstance(other, _Number):
             return __rpow_scalar__(self, scalar=other)
-        else:
         raise TypeError('type %s not supported' % str(type(other)))
     def __neg__(self):
@@ -238,11 +228,10 @@ class Symbol(SymbolBase):
         """internal function to get list option"""
         if option == 'all':
             return _ctypes.c_int(0)
-        elif option == 'read_only':
+        if option == 'read_only':
             return _ctypes.c_int(1)
-        elif option == 'aux_state':
+        if option == 'aux_state':
             return _ctypes.c_int(2)
-        else:
         raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
     def list_input_variables(self, option='all'):
...
@@ -23,10 +23,9 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
 def Pooling(data, kernel, stride, pad, pool_type, name):
     if pool_type == 'max':
         return sym.max_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad, name=name)
-    elif pool_type == 'avg':
+    if pool_type == 'avg':
         return sym.avg_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad, name=name,
                               count_include_pad=True)
-    else:
     raise ValueError("Invalid pooling type: " + pool_type)
 def Inception7A(data,
...
@@ -88,7 +88,6 @@ def _get_yolo_detections(l, im_shape, net_shape, thresh, relative, dets):
         before_correct_dets.append(detection)
     dets.extend(_correct_boxes(before_correct_dets, im_shape[0], im_shape[1],
                                net_shape[0], net_shape[1], relative))
-    return
 def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets):
     data = l['output']
@@ -114,7 +113,6 @@ def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets):
     _correct_boxes(before_correct_dets, im_shape[0], im_shape[1],
                    net_shape[0], net_shape[1], relative)
     dets.extend(before_correct_dets)
-    return
 def fill_network_boxes(net_shape, im_shape,
                        thresh, relative, tvm_out):
...
@@ -129,13 +129,12 @@ class AttrDict(object):
         lowercase = self[key].lower()
         if lowercase == "1":
             return True
-        elif lowercase == "0":
+        if lowercase == "0":
             return False
-        elif lowercase == "true":
+        if lowercase == "true":
             return True
-        elif lowercase == "false":
+        if lowercase == "false":
             return False
-        else:
         raise ValueError("Wrong bool format for key %s" % key)
     def get_str(self, key):
...
@@ -32,7 +32,6 @@ else:
 class TVMError(Exception):
     """Error thrown by TVM function"""
-    pass
 def _load_lib():
...
@@ -51,7 +51,6 @@ class Function(_FunctionBase):
     tvm.register_func: How to register global function.
     tvm.get_global_func: How to get global function.
     """
-    pass
 class ModuleBase(object):
@@ -207,10 +206,10 @@ def get_global_func(name, allow_missing=False):
     check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
     if handle.value:
         return Function(handle, False)
-    else:
     if allow_missing:
         return None
-    else:
     raise ValueError("Cannot find global function %s" % name)
...
@@ -36,16 +36,16 @@ def convert_to_node(value):
     """
     if isinstance(value, _CLASS_NODE_BASE):
         return value
-    elif isinstance(value, bool):
+    if isinstance(value, bool):
         return const(value, 'uint1x1')
-    elif isinstance(value, Number):
+    if isinstance(value, Number):
         return const(value)
-    elif isinstance(value, string_types):
+    if isinstance(value, string_types):
         return _api_internal._str(value)
-    elif isinstance(value, (list, tuple)):
+    if isinstance(value, (list, tuple)):
         value = [convert_to_node(x) for x in value]
         return _api_internal._Array(*value)
-    elif isinstance(value, dict):
+    if isinstance(value, dict):
         vlist = []
         for item in value.items():
             if (not isinstance(item[0], _CLASS_NODE_BASE) and
@@ -54,11 +54,11 @@ def convert_to_node(value):
             vlist.append(item[0])
             vlist.append(convert_to_node(item[1]))
         return _api_internal._Map(*vlist)
-    elif isinstance(value, NodeGeneric):
+    if isinstance(value, NodeGeneric):
         return value.asnode()
-    elif value is None:
+    if value is None:
         return None
-    else:
     raise ValueError("don't know how to convert type %s to node" % type(value))
...
@@ -31,11 +31,11 @@ class IntervalSet(IntSet):
 @register_node
 class StrideSet(IntSet):
     """Represent set of strided integers"""
-    pass
 @register_node
 class ModularSet(IntSet):
     """Represent range of (coeff * x + base) for x in Z """
-    pass
 _init_api("tvm.arith")
@@ -69,15 +69,14 @@ class Future(object):
 class FutureError(RuntimeError):
     """Base error class of all future events"""
-    pass
 # pylint:disable=redefined-builtin
 class TimeoutError(FutureError):
     """Error raised when a task is timeout."""
-    pass
 class ExecutionError(FutureError):
     """
     Error raised when future execution crashes or failed.
     """
-    pass
@@ -83,7 +83,7 @@ def encode(inp, result, protocol='json'):
                     "v": AUTOTVM_LOG_VERSION
                     }
         return json.dumps(json_dict)
-    elif protocol == 'pickle':
+    if protocol == 'pickle':
         row = (str(inp.target),
                str(base64.b64encode(pickle.dumps([inp.task.name,
                                                   inp.task.args,
@@ -92,7 +92,7 @@ def encode(inp, result, protocol='json'):
               str(base64.b64encode(pickle.dumps(inp.config)).decode()),
               str(base64.b64encode(pickle.dumps(tuple(result))).decode()))
        return '\t'.join(row)
-    else:
     raise RuntimeError("Invalid log protocol: " + protocol)
@@ -136,7 +136,7 @@ def decode(row, protocol='json'):
         result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["r"]])
         return inp, result
-    elif protocol == 'pickle':
+    if protocol == 'pickle':
         items = row.split("\t")
         tgt = _target.create(items[0])
         task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
@@ -146,7 +146,7 @@ def decode(row, protocol='json'):
         tsk = task.Task(task_tuple[0], task_tuple[1])
         tsk.workload = task_tuple[3]
         return MeasureInput(tgt, tsk, config), MeasureResult(*result)
-    else:
     raise RuntimeError("Invalid log protocol: " + protocol)
...
@@ -32,7 +32,6 @@ class InstantiationError(ValueError):
     raised by cfg.raise_error
     e.g. too many unrolling, too many threads in a block
     """
-    pass
 class TransformSpace(object):
@@ -321,7 +320,7 @@ class ReorderSpace(TransformSpace):
         if np.sum(tmp_pt) == size:
             merged.append(list(tmp_stack))
             return
-        else:
         for i in range(len(chains)):
             # use i == np.argmax(....) here to take spatial order into consideration
             # if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
@@ -441,7 +440,7 @@ class AnnotateSpace(TransformSpace):
         if now == self.num_axis:
             # only vectorize inner most dimension
             vec_ct = tmp_stack.count('vec')
-            if vec_ct == 0 or vec_ct == 1:
+            if vec_ct in (0, 1):
                 self.entities.append(AnnotateEntity(list(tmp_stack)))
         else:
             for ann in self.anns[now]:
...
@@ -294,7 +294,7 @@ def get_config():
 class FlopCalculationError(RuntimeError):
     """Error happens when estimating FLOP for a compute op"""
-    pass
 def compute_flop(sch):
     """Calculate number of FLOP (floating number operations) of the compute ops in a schedule
@@ -328,13 +328,13 @@ def compute_flop(sch):
         if len(source) != 1:
             raise FlopCalculationError("Found multiple output in the source of reduce op")
         return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
-    elif isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
+    if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
         return 0
-    elif isinstance(exp, expr.Cast):
+    if isinstance(exp, expr.Cast):
         return _count_flop(exp.value)
-    elif isinstance(exp, expr.Var):
+    if isinstance(exp, expr.Var):
         return 0
-    elif isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
+    if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
                         expr.Max, expr.Min,
                         expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                         expr.And, expr.Or, expr.Not)):
@@ -344,12 +344,12 @@ def compute_flop(sch):
             return base + _count_flop(exp.a)
         return base + _count_flop(exp.a) + _count_flop(exp.b)
-    elif isinstance(exp, expr.Select):
+    if isinstance(exp, expr.Select):
         return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
                                                 _count_flop(exp.false_value))
-    elif isinstance(exp, expr.Call):
+    if isinstance(exp, expr.Call):
         return sum([_count_flop(x) for x in exp.args])
-    else:
     raise FlopCalculationError("Found unsupported operator in the compute expr")
 def traverse(ops):
...
@@ -69,7 +69,7 @@ class Tuner(object):
         results: Array of autotvm.measure.MeasureResult
             result for measurement
         """
-        pass
     def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
         """Begin tuning
...
@@ -90,7 +90,7 @@ class Range(NodeBase):
     You do not need to create Range explicitly.
     Python list and tuple will be converted automatically to Range in api functions.
     """
-    pass
 @register_node
 class LoweredFunc(NodeBase):
...
@@ -151,14 +151,14 @@ def find_libdevice_path(arch):
     selected_ver = 0
     selected_path = None
     cuda_ver = get_cuda_version(cuda_path)
-    if cuda_ver == 9.0 or cuda_ver == 9.1:
+    if cuda_ver in (9.0, 9.1):
         path = os.path.join(lib_path, "libdevice.10.bc")
     else:
         for fn in os.listdir(lib_path):
             if not fn.startswith("libdevice"):
                 continue
             ver = int(fn.split(".")[-3].split("_")[-1])
-            if ver > selected_ver and ver <= arch:
+            if selected_ver < ver <= arch:
                 selected_ver = ver
                 selected_path = fn
     if selected_path is None:
...
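The two rewrites above correspond to pylint's consider-using-in and chained-comparison suggestions: an or-chain of equality tests becomes a tuple membership check, and a two-sided bound becomes a single chained comparison. A small sketch with illustrative values:

cuda_ver, selected_ver, ver, arch = 9.1, 20, 30, 52
# Membership test instead of `cuda_ver == 9.0 or cuda_ver == 9.1`.
is_cuda9 = cuda_ver in (9.0, 9.1)
# Chained comparison instead of `ver > selected_ver and ver <= arch`.
in_range = selected_ver < ver <= arch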
@@ -118,7 +118,6 @@ def _find_vpi_path():
     vpi_found = [p for p in vpi_path if os.path.exists(p) and os.path.isfile(p)]
     if vpi_found:
         return os.path.dirname(vpi_found[0])
-    else:
     raise ValueError("Cannot find tvm_vpi.vpi, make sure you did `make verilog`")
 def search_path():
...
@@ -189,9 +189,9 @@ class HybridParser(ast.NodeVisitor):
         _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
         if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
             return entry
-        elif ty is Symbol.ConstVar:
+        if ty is Symbol.ConstVar:
             return entry if isinstance(node.ctx, ast.Load) else None
-        elif ty is Symbol.BufferVar:
+        if ty is Symbol.BufferVar:
             if isinstance(node.ctx, ast.Load):
                 return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
                                   _expr.Call.Halide, entry.op, entry.value_index)
@@ -274,7 +274,7 @@ class HybridParser(ast.NodeVisitor):
                 buf, args = lhs
                 return _make.Provide(buf.op, 0, rhs, args)
             return util.make_nop()
-        else:
         lhs, args = self.visit(lhs)
         _internal_assert(isinstance(lhs, Tensor), \
                          "An array access's LHS is expected to be a expr.Call!")
@@ -347,7 +347,7 @@ class HybridParser(ast.NodeVisitor):
         if isinstance(cond, _expr.UIntImm):
             if cond.value:
                 return visit_list_to_block(self.visit, node.body)
-            elif node.orelse:
+            if node.orelse:
                 return visit_list_to_block(self.visit, node.orelse)
             return util.make_nop()
@@ -451,7 +451,7 @@ class HybridParser(ast.NodeVisitor):
                 bodies.append(body)
             return concat_list_to_block(bodies)
-        elif iter_var is None:
+        if iter_var is None:
             _internal_assert(for_type is not None, "The loop bind function parse error!")
             offset = iter_var = _api.var(_name)
             if not _ir_pass.Equal(low, _api.const(0, 'int32')):
...
@@ -60,7 +60,7 @@ def replace_io(body, rmap):
     if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
         buf = rmap[op.func]
         return _make.Provide(buf.op, op.value_index, op.value, op.args)
-    elif isinstance(op, _expr.Call) and op.func in rmap.keys():
+    if isinstance(op, _expr.Call) and op.func in rmap.keys():
         buf = rmap[op.func]
         return _make.Call(buf.dtype, buf.name, op.args, \
                           _expr.Call.Halide, buf.op, buf.value_index)
...
@@ -495,7 +495,7 @@ def _rule_float_suffix(op):
     """
     if op.dtype == "float32":
         return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
-    elif op.dtype == "float64":
+    if op.dtype == "float64":
         return call_pure_extern(op.dtype, op.name, *op.args)
     return op
...
@@ -56,7 +56,7 @@ def static_cast(dtype, expr):
     if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits:
         if src_type.lanes == target_type.lanes:
             return expr
-        elif src_type.lanes == 1 and target_type.lanes > 1:
+        if src_type.lanes == 1 and target_type.lanes > 1:
             return Broadcast(expr, target_type.lanes)
     return Cast(dtype, expr)
...
@@ -23,7 +23,6 @@ class NDArray(NDArrayBase):
     Instead, this is a minimal data structure to demonstrate
     how can we use TVM in existing project which might have their own array containers.
     """
-    pass
 def cpu(dev_id=0):
...
@@ -43,8 +43,8 @@ try:
     from antlr4.tree.Tree import TerminalNode
 except ImportError:
     raise ParseError("Couldn't find ANTLR runtime." +
-                     "Try running `pip{} install antlr4-python{}-runtime`."
-                     .format(PYTHON_VERSION, PYTHON_VERSION))
+                     "Try running `pip{version} install antlr4-python{version}-runtime`."
+                     .format(version=PYTHON_VERSION))
 BINARY_OPS = {
     RelayParser.MUL: op.multiply,
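This hunk removes a duplicated format() argument by switching to a named placeholder, which is the fix pylint suggests when the same value is passed for several positional fields. A minimal sketch with an assumed version value:

PYTHON_VERSION = 3  # illustrative
# Before: the same value is supplied twice for the two positional fields.
msg = "pip{} install antlr4-python{}-runtime".format(PYTHON_VERSION, PYTHON_VERSION)
# After: one named argument fills both {version} placeholders.
msg = "pip{version} install antlr4-python{version}-runtime".format(version=PYTHON_VERSION)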
@@ -179,32 +179,30 @@ class ParseTreeToRelayIR(RelayVisitor):
         # variables
         if node_type == RelayLexer.GLOBAL_VAR:
             return lookup(deque([self.global_var_scope]), node_text[1:])
-        elif node_type == RelayLexer.LOCAL_VAR:
+        if node_type == RelayLexer.LOCAL_VAR:
             # Remove the leading '%' and lookup the name.
             var = lookup(self.var_scopes, name)
             if var is None:
                 raise ParseError("Couldn't resolve `{}`.".format(name))
             return var
-        elif node_type == RelayLexer.GRAPH_VAR:
+        if node_type == RelayLexer.GRAPH_VAR:
             try:
                 return self.graph_expr[int(name)]
             except IndexError:
                 raise ParseError("Couldn't resolve `{}`".format(name))
         # data types
-        elif node_type == RelayLexer.NAT:
+        if node_type == RelayLexer.NAT:
             return int(node_text)
-        elif node_type == RelayLexer.FLOAT:
+        if node_type == RelayLexer.FLOAT:
             return float(node_text)
-        elif node_type == RelayLexer.BOOL_LIT:
+        if node_type == RelayLexer.BOOL_LIT:
             if node_text == "True":
                 return True
-            elif node_text == "False":
+            if node_text == "False":
                 return False
-            else:
             raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
-        else:
         raise ParseError("todo: {}".format(node_text))
     def visit_list(self, ctx_list):
...
@@ -8,7 +8,7 @@ from .expr import Expr, Call
 class Pattern(RelayNode):
     """Base type for pattern matching constructs."""
-    pass
 @register_relay_node
 class PatternWildcard(Pattern):
...
@@ -10,7 +10,6 @@ from . import _backend
 class CachedFunc(NodeBase):
     """Low-level tensor function to back a relay primitive function.
     """
-    pass
 @register_relay_node
@@ -34,7 +33,6 @@ class CCacheKey(NodeBase):
 class CCacheValue(NodeBase):
     """Value in the CompileEngine, including usage statistics.
     """
-    pass
 def _get_cache_key(source_func, target):
...
@@ -49,7 +49,6 @@ class TupleValue(Value):
 @register_relay_node
 class Closure(Value):
     """A closure produced by the interpreter."""
-    pass
 @register_relay_node
...
@@ -444,7 +444,6 @@ def create_executor(kind="debug",
         target = _target.create(target)
     if kind == "debug":
         return _interpreter.Interpreter(mod, ctx, target)
-    elif kind == "graph":
+    if kind == "graph":
         return GraphExecutor(mod, ctx, target)
-    else:
     raise RuntimeError("unknown mode {0}".format(mode))
@@ -15,7 +15,6 @@ def dimension_picker(prefix, surfix=''):
         kernel = attr['kernel_shape']
         if len(kernel) == 2:
             return prefix + '2d' + surfix
-        else:
         raise NotImplementedError("Only 2d kernel supported.")
     return _impl
@@ -104,7 +103,6 @@ class Caffe2OpConverter(object):
         if hasattr(cls, '_impl'):
             return getattr(cls, '_impl')
-        else:
         raise NotImplementedError('{} not implemented'.format(
             cls.__name__))
@@ -234,9 +232,8 @@ class Concat(Caffe2OpConverter):
         order = order if isinstance(order, str) else order.decode('UTF-8')
         if order == 'NCHW':
             return 1
-        elif order == 'NHWC':
+        if order == 'NHWC':
             return 3
-        else:
         raise RuntimeError(
             "Unsupported storage order: {} in caffe2".format(order))
...
@@ -10,7 +10,6 @@ from .. import op as _op
 class RequiredAttr(object):
     """Dummpy class to represent required attr"""
-    pass
 class StrAttrsDict(object):
...
@@ -100,37 +100,37 @@ def _ActivationParams(op, inexpr, etab):
         alpha = _expr.const(par.alpha, dtype='float32')
         beta = _expr.const(par.beta, dtype='float32')
         return _op.add(_op.multiply(inexpr, alpha), beta)
-    elif whichActivation == 'ReLU':
+    if whichActivation == 'ReLU':
         return _op.nn.relu(inexpr)
-    elif whichActivation == 'leakyReLU':
+    if whichActivation == 'leakyReLU':
         _op.nn.leaky_relu(inexpr, alpha=_expr.const(par.alpha, dtype='float32'))
     elif whichActivation == 'thresholdedReLU':
         alpha_tensor = _op.full_like(inexpr, fill_value=_expr.const(par.alpha, dtype='float32'))
         return _op.multiply(inexpr, _op.greater(inexpr, alpha_tensor).as_type('float32'))
-    elif whichActivation == 'PReLU':
+    if whichActivation == 'PReLU':
         return _op.nn.prelu(inexpr, alpha=_expr.const(par.alpha, dtype='float32'))
-    elif whichActivation == 'tanh':
+    if whichActivation == 'tanh':
         return _op.tanh(inexpr)
-    elif whichActivation == 'scaledTanh':
+    if whichActivation == 'scaledTanh':
         alpha = _expr.const(par.alpha, dtype='float32')
         beta = _expr.const(par.beta, dtype='float32')
         return _op.multiply(_op.tanh(_op.multiply(inexpr, beta)), alpha)
-    elif whichActivation == 'sigmoid':
+    if whichActivation == 'sigmoid':
         return _op.sigmoid(inexpr)
-    elif whichActivation == 'sigmoidHard':
+    if whichActivation == 'sigmoidHard':
         alpha = _expr.const(par.alpha, dtype='float32')
         beta = _expr.const(par.beta, dtype='float32')
         transformX = (alpha * inexpr) + beta
         return _op.clip(transformX, a_min=0., a_max=1.)
-    elif whichActivation == 'ELU':
+    if whichActivation == 'ELU':
         return _op.multiply(_op.add(_op.exp(inexpr), _expr.const(-1, dtype='float32')),
                             _expr.const(par.alpha, dtype='float32'))
-    elif whichActivation == 'softsign':
+    if whichActivation == 'softsign':
         return inexpr / (_expr.const(1, dtype='float32') + (
             op.nn.relu(inexpr) + _op.nn.relu(_op.negative(inexpr))))
-    elif whichActivation == 'softplus':
+    if whichActivation == 'softplus':
         return _op.log(_op.add(_op.exp(inexpr), _expr.const(1, dtype='float32')))
-    elif whichActivation == 'parametricSoftplus':
+    if whichActivation == 'parametricSoftplus':
         alpha = list(par.alpha.floatValue)
         beta = list(par.alpha.floatValue)
         if len(alpha) == 1:
@@ -142,7 +142,6 @@ def _ActivationParams(op, inexpr, etab):
         alpha_expr = etab.new_const(alpha)
         beta_expr = etab.new_const(beta)
         return _op.multiply(_op.log(_op.add(_op.exp(inexpr), beta_expr)), alpha_expr)
-    else:
     raise NotImplementedError('%s not implemented' % whichActivation)
@@ -163,9 +162,8 @@ def _PoolingLayerParams(op, inexpr, etab):
     if op.globalPooling:
         if op.type == 0:
             return _op.nn.global_max_pool2d(inexpr)
-        elif op.type == 1:
+        if op.type == 1:
             return _op.nn.global_avg_pool2d(inexpr)
-        else:
         raise NotImplementedError("Only max and average pooling implemented")
     else:
@@ -196,9 +194,8 @@ def _PoolingLayerParams(op, inexpr, etab):
         if op.type == 0:
             return _op.nn.max_pool2d(inexpr, **params)
-        elif op.type == 1:
+        if op.type == 1:
             return _op.nn.avg_pool2d(inexpr, **params)
-        else:
         raise NotImplementedError("Only max and average pooling implemented")
...
@@ -60,21 +60,21 @@ def _convert_activation(inexpr, keras_layer, _):
         alpha = _expr.const(alpha, dtype='float32')
         beta = _expr.const(beta, dtype='float32')
         return _op.add(_op.multiply(inexpr, alpha), beta)
-    elif act_type == 'softmax':
+    if act_type == 'softmax':
         return _op.nn.softmax(inexpr, axis=1)
-    elif act_type == 'sigmoid':
+    if act_type == 'sigmoid':
         return _op.sigmoid(inexpr)
-    elif act_type == 'tanh':
+    if act_type == 'tanh':
         return _op.tanh(inexpr)
-    elif act_type == 'relu':
+    if act_type == 'relu':
         return _op.nn.relu(inexpr)
-    elif act_type == 'softplus':
+    if act_type == 'softplus':
         return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
-    elif act_type == 'elu':
+    if act_type == 'elu':
         alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
         alpha = _expr.const(alpha, dtype='float32')
         return _get_elu(inexpr, alpha)
-    elif act_type == 'selu':
+    if act_type == 'selu':
         # Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
         alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') \
             else 1.6732632423543772848170429916717
@@ -83,14 +83,14 @@ def _convert_activation(inexpr, keras_layer, _):
         alpha = _expr.const(alpha, dtype='float32')
         gamma = _expr.const(gamma, dtype='float32')
         return gamma * _get_elu(inexpr, alpha)
-    elif act_type == 'relu6':
+    if act_type == 'relu6':
         return _op.clip(inexpr, a_min=0., a_max=6.)
-    elif act_type == 'softsign':
+    if act_type == 'softsign':
         return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
-    elif act_type == 'hard_sigmoid':
+    if act_type == 'hard_sigmoid':
         x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
         return _op.clip(x, a_min=0., a_max=1.)
-    else:
     raise TypeError("Unsupported activation type : {}".format(act_type))
@@ -100,24 +100,24 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
         if keras_layer.max_value:
             return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
         return _op.nn.relu(inexpr)
-    elif act_type == 'LeakyReLU':
+    if act_type == 'LeakyReLU':
         return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
-    elif act_type == 'ELU':
+    if act_type == 'ELU':
         alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
         alpha = _expr.const(alpha, dtype='float32')
         return _get_elu(inexpr, alpha)
-    elif act_type == 'PReLU':
+    if act_type == 'PReLU':
         assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
         _check_data_format(keras_layer)
         size = len(keras_layer.alpha.shape)
         alpha = etab.new_const(keras_layer.get_weights()[0] \
                                .transpose(np.roll(range(size), 1)))
         return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
-    elif act_type == 'ThresholdedReLU':
+    if act_type == 'ThresholdedReLU':
         theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
         return _op.multiply(inexpr, _op.greater(inexpr, \
             _expr.const(theta, dtype='float32')).astype('float32'))
-    else:
     raise TypeError("Unsupported advanced activation type : {}".format(act_type))
@@ -297,9 +297,8 @@ def _convert_pooling(inexpr, keras_layer, etab):
     # global pool in keras = global pool + flatten in nnvm/relay
     if pool_type == 'GlobalMaxPooling2D':
         return _convert_flatten(_op.nn.global_max_pool2d(inexpr), keras_layer, etab)
-    elif pool_type == 'GlobalAveragePooling2D':
+    if pool_type == 'GlobalAveragePooling2D':
         return _convert_flatten(_op.nn.global_avg_pool2d(inexpr), keras_layer, etab)
-    else:
     pool_h, pool_w = keras_layer.pool_size
     stride_h, stride_w = keras_layer.strides
     params = {'pool_size': [pool_h, pool_w],
@@ -317,10 +316,9 @@ def _convert_pooling(inexpr, keras_layer, etab):
         raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
     if pool_type == 'MaxPooling2D':
         return _op.nn.max_pool2d(inexpr, **params)
-    elif pool_type == 'AveragePooling2D':
+    if pool_type == 'AveragePooling2D':
         params['count_include_pad'] = False
         return _op.nn.avg_pool2d(inexpr, **params)
-    else:
     raise TypeError("Unsupported pooling type : {}".format(keras_layer))
......
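The keras frontend hunks above also flatten long `if/elif/.../else` activation dispatch chains into a sequence of plain `if` blocks. Because every branch returns, the control flow is identical; only the nesting changes. A self-contained sketch under that assumption, with toy math standing in for the real relay ops:

    def _convert_activation_old(act_type, x):
        if act_type == 'relu':
            return max(x, 0.0)
        elif act_type == 'linear':
            return x
        else:
            raise TypeError("Unsupported activation type : {}".format(act_type))

    def _convert_activation_new(act_type, x):
        if act_type == 'relu':
            return max(x, 0.0)
        if act_type == 'linear':
            return x
        raise TypeError("Unsupported activation type : {}".format(act_type))

    # Same result either way for every supported act_type.
    assert _convert_activation_old('relu', -2.0) == _convert_activation_new('relu', -2.0)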
...@@ -39,7 +39,7 @@ def _mx_fully_connected(inputs, attrs): ...@@ -39,7 +39,7 @@ def _mx_fully_connected(inputs, attrs):
def _get_channel_axis(layout, op_name): def _get_channel_axis(layout, op_name):
if layout == "NCHW": if layout == "NCHW":
return 1 return 1
elif layout == "NHWC": if layout == "NHWC":
return 3 return 3
raise RuntimeError("layout: {} is not supported in {}".format(layout, op_name)) raise RuntimeError("layout: {} is not supported in {}".format(layout, op_name))
...@@ -49,11 +49,11 @@ def _mx_activations(inputs, attrs): ...@@ -49,11 +49,11 @@ def _mx_activations(inputs, attrs):
assert len(inputs) == 1 assert len(inputs) == 1
if act_type == "sigmoid": if act_type == "sigmoid":
return _op.sigmoid(inputs[0]) return _op.sigmoid(inputs[0])
elif act_type == "tanh": if act_type == "tanh":
return _op.tanh(inputs[0]) return _op.tanh(inputs[0])
elif act_type == "relu": if act_type == "relu":
return _op.nn.relu(inputs[0]) return _op.nn.relu(inputs[0])
elif act_type == "softrelu": if act_type == "softrelu":
def _stable_softrelu(x): def _stable_softrelu(x):
# log(1 + exp(-abs(x))) + relu(x) # log(1 + exp(-abs(x))) + relu(x)
one = _expr.const(1, dtype="float32") one = _expr.const(1, dtype="float32")
...@@ -147,7 +147,7 @@ def _mx_pooling(inputs, attrs): ...@@ -147,7 +147,7 @@ def _mx_pooling(inputs, attrs):
if global_pool: if global_pool:
return _op.nn.global_max_pool2d(inputs[0]) return _op.nn.global_max_pool2d(inputs[0])
return _pool2d(_op.nn.max_pool2d, False) return _pool2d(_op.nn.max_pool2d, False)
elif pool_type == "avg": if pool_type == "avg":
if global_pool: if global_pool:
return _op.nn.global_avg_pool2d(inputs[0]) return _op.nn.global_avg_pool2d(inputs[0])
return _pool2d(_op.nn.avg_pool2d, True) return _pool2d(_op.nn.avg_pool2d, True)
...@@ -209,10 +209,10 @@ def _mx_leaky_relu(inputs, attrs): ...@@ -209,10 +209,10 @@ def _mx_leaky_relu(inputs, attrs):
act_type = attrs.get_str("act_type") act_type = attrs.get_str("act_type")
if act_type == "leaky": if act_type == "leaky":
return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25)) return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25))
elif act_type == "prelu": if act_type == "prelu":
assert len(inputs) == 2 assert len(inputs) == 2
return _op.nn.prelu(*inputs) return _op.nn.prelu(*inputs)
elif act_type == "elu": if act_type == "elu":
# -slope * relu(1-exp(x)) + relu(x) # -slope * relu(1-exp(x)) + relu(x)
slope = attrs.get_float("slope", 0.25) slope = attrs.get_float("slope", 0.25)
one = _expr.const(1, dtype="float32") one = _expr.const(1, dtype="float32")
...@@ -220,7 +220,7 @@ def _mx_leaky_relu(inputs, attrs): ...@@ -220,7 +220,7 @@ def _mx_leaky_relu(inputs, attrs):
mslope = _op.nn.relu(_op.subtract(one, _op.exp(x))) mslope = _op.nn.relu(_op.subtract(one, _op.exp(x)))
mslope = _op.multiply(mslope, _expr.const(-slope, dtype="float32")) mslope = _op.multiply(mslope, _expr.const(-slope, dtype="float32"))
return _op.add(mslope, _op.nn.relu(x)) return _op.add(mslope, _op.nn.relu(x))
elif act_type == "rrelu": if act_type == "rrelu":
# NOTE this is only converted for inference. # NOTE this is only converted for inference.
lower_bound = attrs.get_float("lower_bound") lower_bound = attrs.get_float("lower_bound")
upper_bound = attrs.get_float("upper_bound") upper_bound = attrs.get_float("upper_bound")
......
...@@ -18,7 +18,6 @@ def dimension_picker(prefix, surfix=''): ...@@ -18,7 +18,6 @@ def dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape'] kernel = attr['kernel_shape']
if len(kernel) == 2: if len(kernel) == 2:
return prefix + '2d' + surfix return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.") raise NotImplementedError("Only 2d kernel supported.")
return _impl return _impl
......
...@@ -175,7 +175,6 @@ def _dimension_picker(prefix, surfix=''): ...@@ -175,7 +175,6 @@ def _dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape'] kernel = attr['kernel_shape']
if len(kernel) == 2: if len(kernel) == 2:
return prefix + '2d' + surfix return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.") raise NotImplementedError("Only 2d kernel supported.")
return _impl return _impl
...@@ -522,7 +521,6 @@ def _reshape(): ...@@ -522,7 +521,6 @@ def _reshape():
op_name="reshape", op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().flatten())}, extras={'newshape':tuple(params_new.asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr) ignores=['Tshape'])(inputs, attr)
else:
raise RuntimeError("Reshape with dynamic shape input not supported yet.") raise RuntimeError("Reshape with dynamic shape input not supported yet.")
return _impl return _impl
...@@ -1385,7 +1383,7 @@ class GraphProto(object): ...@@ -1385,7 +1383,7 @@ class GraphProto(object):
shape=self._params[name].shape, shape=self._params[name].shape,
dtype=self._params[name].dtype)] dtype=self._params[name].dtype)]
else: else:
if key != 'dtype' and key != '_output_shapes' and key != '_class': if key not in ('dtype', '_output_shapes', '_class'):
raise NotImplementedError \ raise NotImplementedError \
("Other attributes for a Const(param) Node {} ? .".format(key)) ("Other attributes for a Const(param) Node {} ? .".format(key))
......
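The tensorflow hunk above replaces a chain of `!=` comparisons with a single membership test, which is what pylint's consider-using-in (R1714) asks for. The two forms are equivalent for these string constants; a quick sketch:

    key = '_output_shapes'

    # Old style: repeated comparisons against individual constants.
    legacy = key != 'dtype' and key != '_output_shapes' and key != '_class'

    # New style, as in the commit: one membership test on a tuple literal.
    cleaned = key not in ('dtype', '_output_shapes', '_class')

    assert legacy == cleaned  # both False for this key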
...@@ -126,13 +126,12 @@ class OperatorConverter(object): ...@@ -126,13 +126,12 @@ class OperatorConverter(object):
if tensor_wrapper.tensor.Type() == TensorType.UINT8: if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape( return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy()) tensor_wrapper.tensor.ShapeAsNumpy())
elif tensor_wrapper.tensor.Type() == TensorType.FLOAT32: if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape( return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy()) tensor_wrapper.tensor.ShapeAsNumpy())
elif tensor_wrapper.tensor.Type() == TensorType.INT32: if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape( return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy()) tensor_wrapper.tensor.ShapeAsNumpy())
else:
raise NotImplementedError("Not support tensor type {}" raise NotImplementedError("Not support tensor type {}"
.format(str(tensor_wrapper.tensor.Type()))) .format(str(tensor_wrapper.tensor.Type())))
...@@ -145,11 +144,10 @@ class OperatorConverter(object): ...@@ -145,11 +144,10 @@ class OperatorConverter(object):
if tensor_type == TensorType.UINT8: if tensor_type == TensorType.UINT8:
return "uint8" return "uint8"
elif tensor_type == TensorType.FLOAT32: if tensor_type == TensorType.FLOAT32:
return "float32" return "float32"
elif tensor_type == TensorType.INT32: if tensor_type == TensorType.INT32:
return "int32" return "int32"
else:
raise NotImplementedError("Not support tensor type {}".format(str(tensor_type))) raise NotImplementedError("Not support tensor type {}".format(str(tensor_type)))
def convert_conv2d(self, op): def convert_conv2d(self, op):
...@@ -192,7 +190,7 @@ class OperatorConverter(object): ...@@ -192,7 +190,7 @@ class OperatorConverter(object):
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
if input_shape_length == 1 or input_shape_length == 2: if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W). # The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing. # length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing. # length of 2 means N*H*W, C, do nothing.
...@@ -275,7 +273,7 @@ class OperatorConverter(object): ...@@ -275,7 +273,7 @@ class OperatorConverter(object):
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
# TFLite is N H W C, our layout is N C H W # TFLite is N H W C, our layout is N C H W
if input_shape_length == 1 or input_shape_length == 2: if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W). # The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing. # length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing. # length of 2 means N*H*W, C, do nothing.
...@@ -299,7 +297,7 @@ class OperatorConverter(object): ...@@ -299,7 +297,7 @@ class OperatorConverter(object):
# 3: N H W C, reshape to N H*W C, transpose to N C H*W # 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 4: N H W C, transpose to N C H W # 4: N H W C, transpose to N C H W
# add more if we need target shapes in future # add more if we need target shapes in future
if output_shape_length == 1 or output_shape_length == 2: if output_shape_length in (1, 2):
pass pass
elif output_shape_length == 3: elif output_shape_length == 3:
out = _op.transpose(out, axes=(0, 2, 1)) out = _op.transpose(out, axes=(0, 2, 1))
...@@ -320,13 +318,12 @@ class OperatorConverter(object): ...@@ -320,13 +318,12 @@ class OperatorConverter(object):
assert fused_activation_fn != ActivationFunctionType.NONE assert fused_activation_fn != ActivationFunctionType.NONE
if fused_activation_fn == ActivationFunctionType.RELU6: if fused_activation_fn == ActivationFunctionType.RELU6:
return _op.clip(in_expr, a_min=0, a_max=6) return _op.clip(in_expr, a_min=0, a_max=6)
elif fused_activation_fn == ActivationFunctionType.RELU: if fused_activation_fn == ActivationFunctionType.RELU:
return _op.nn.relu(in_expr) return _op.nn.relu(in_expr)
elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
return _op.clip(in_expr, a_min=-1, a_max=1) return _op.clip(in_expr, a_min=-1, a_max=1)
elif fused_activation_fn == ActivationFunctionType.TANH: if fused_activation_fn == ActivationFunctionType.TANH:
return _op.tanh(in_expr) return _op.tanh(in_expr)
else:
fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
raise NotImplementedError("Unsupported fused activation fn {}" raise NotImplementedError("Unsupported fused activation fn {}"
.format(fused_activation_fn_str)) .format(fused_activation_fn_str))
...@@ -401,7 +398,7 @@ class OperatorConverter(object): ...@@ -401,7 +398,7 @@ class OperatorConverter(object):
# weight tensor type should be UINT8 (quantization) or FLOAT32 # weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type() weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type == TensorType.UINT8 or weight_tensor_type == TensorType.FLOAT32 assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
in_expr = self.get_expr(input_tensor_idx) in_expr = self.get_expr(input_tensor_idx)
...@@ -434,7 +431,7 @@ class OperatorConverter(object): ...@@ -434,7 +431,7 @@ class OperatorConverter(object):
bias_tensor = input_tensors[2] bias_tensor = input_tensors[2]
bias_tensor_type = bias_tensor.tensor.Type() bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32 # bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type == TensorType.INT32 or bias_tensor_type == TensorType.FLOAT32 assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor), bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
dtype=bias_tensor_type_str) dtype=bias_tensor_type_str)
......
...@@ -57,7 +57,7 @@ def compute_conv2d(attrs, inputs, out_type, target): ...@@ -57,7 +57,7 @@ def compute_conv2d(attrs, inputs, out_type, target):
layout = attrs.data_layout layout = attrs.data_layout
kernel_layout = attrs.kernel_layout kernel_layout = attrs.kernel_layout
out_dtype = attrs.out_dtype out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "") out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype) else out_dtype)
assert layout in ["NCHW", "NHWC", "NCHW4c"] assert layout in ["NCHW", "NHWC", "NCHW4c"]
...@@ -95,15 +95,15 @@ def schedule_conv2d(attrs, outs, target): ...@@ -95,15 +95,15 @@ def schedule_conv2d(attrs, outs, target):
with target: with target:
if groups == 1 and layout == "NCHW": if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs) return topi.generic.schedule_conv2d_nchw(outs)
elif groups == 1 and layout == "NCHW4c": if groups == 1 and layout == "NCHW4c":
return topi.generic.schedule_conv2d_nchw(outs) return topi.generic.schedule_conv2d_nchw(outs)
elif groups == 1 and layout == "NHWC": if groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs) return topi.generic.schedule_conv2d_nhwc(outs)
elif groups != 1: if groups != 1:
if layout == "NCHW": if layout == "NCHW":
# TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
return topi.generic.schedule_depthwise_conv2d_nchw(outs) return topi.generic.schedule_depthwise_conv2d_nchw(outs)
elif layout == "NHWC" and kernel_layout == "HWOI": if layout == "NHWC" and kernel_layout == "HWOI":
return topi.generic.schedule_depthwise_conv2d_nhwc(outs) return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
raise ValueError("No compatible schedule") raise ValueError("No compatible schedule")
...@@ -127,7 +127,7 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target): ...@@ -127,7 +127,7 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
groups = attrs.groups groups = attrs.groups
layout = attrs.data_layout layout = attrs.data_layout
out_dtype = attrs.out_dtype out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "") out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype) else out_dtype)
assert layout == "NCHW", "only support nchw for now" assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now" assert dilation == (1, 1), "not support dilate now"
......
...@@ -6,19 +6,18 @@ from ..base import register_relay_attr_node ...@@ -6,19 +6,18 @@ from ..base import register_relay_attr_node
@register_relay_attr_node @register_relay_attr_node
class Conv2DAttrs(Attrs): class Conv2DAttrs(Attrs):
"""Attribute of nn.conv2d""" """Attribute of nn.conv2d"""
pass
@register_relay_attr_node @register_relay_attr_node
class Conv2DWinogradAttrs(Attrs): class Conv2DWinogradAttrs(Attrs):
"""Attribute of nn.contrib_conv2d_winograd_without_weight_transform""" """Attribute of nn.contrib_conv2d_winograd_without_weight_transform"""
pass
@register_relay_attr_node @register_relay_attr_node
class Conv2DWinogradWeightTransformAttrs(Attrs): class Conv2DWinogradWeightTransformAttrs(Attrs):
"""Attribute of nn.contrib_conv2d_winograd_weight_transform""" """Attribute of nn.contrib_conv2d_winograd_weight_transform"""
pass
@register_relay_attr_node @register_relay_attr_node
class GlobalPool2DAttrs(Attrs): class GlobalPool2DAttrs(Attrs):
"""Attribute of nn.global_pool""" """Attribute of nn.global_pool"""
pass
...@@ -29,10 +29,9 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, ...@@ -29,10 +29,9 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
def Pooling(data, kernel, stride, pad, pool_type, name): def Pooling(data, kernel, stride, pad, pool_type, name):
if pool_type == 'max': if pool_type == 'max':
return relay.nn.max_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad) return relay.nn.max_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad)
elif pool_type == 'avg': if pool_type == 'avg':
return relay.nn.avg_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad, return relay.nn.avg_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad,
count_include_pad=True) count_include_pad=True)
else:
raise ValueError("Invalid pooling type: " + pool_type) raise ValueError("Invalid pooling type: " + pool_type)
def Inception7A(data, def Inception7A(data,
......
...@@ -172,7 +172,6 @@ class TypeCall(Type): ...@@ -172,7 +172,6 @@ class TypeCall(Type):
@register_relay_node @register_relay_node
class TypeConstraint(Type): class TypeConstraint(Type):
"""Abstract class representing a type constraint.""" """Abstract class representing a type constraint."""
pass
@register_relay_node @register_relay_node
......
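Several hunks, like the Conv2DAttrs and TypeConstraint ones above, simply delete a trailing `pass` from classes whose body is a docstring. A docstring is already a complete class body, so the `pass` is redundant (pylint's unnecessary-pass, W0107) and removing it changes nothing at runtime. Sketch with a made-up class name:

    class TypeConstraintSketch:
        """A docstring alone is a valid class body, so no trailing pass is needed."""

    assert TypeConstraintSketch.__doc__ is not None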
...@@ -389,7 +389,7 @@ class ProxyServerHandler(object): ...@@ -389,7 +389,7 @@ class ProxyServerHandler(object):
if key in pool_src: if key in pool_src:
self._pair_up(pool_src.pop(key), handler) self._pair_up(pool_src.pop(key), handler)
return return
elif key not in pool_dst: if key not in pool_dst:
pool_dst[key] = handler pool_dst[key] = handler
def cleanup(): def cleanup():
"""Cleanup client connection if timeout""" """Cleanup client connection if timeout"""
......
...@@ -95,7 +95,6 @@ class TCPHandler(object): ...@@ -95,7 +95,6 @@ class TCPHandler(object):
if msg: if msg:
self.on_message(msg) self.on_message(msg)
return True return True
else:
# normal close, remote is closed # normal close, remote is closed
self.close() self.close()
except socket.error as err: except socket.error as err:
......
...@@ -86,7 +86,7 @@ class Scheduler(object): ...@@ -86,7 +86,7 @@ class Scheduler(object):
value: object value: object
The resource to remove The resource to remove
""" """
pass
def summary(self): def summary(self):
"""Get summary information of the scheduler.""" """Get summary information of the scheduler."""
......
...@@ -143,19 +143,16 @@ class Buffer(NodeBase): ...@@ -143,19 +143,16 @@ class Buffer(NodeBase):
@register_node @register_node
class Split(NodeBase): class Split(NodeBase):
"""Split operation on axis.""" """Split operation on axis."""
pass
@register_node @register_node
class Fuse(NodeBase): class Fuse(NodeBase):
"""Fuse operation on axis.""" """Fuse operation on axis."""
pass
@register_node @register_node
class Singleton(NodeBase): class Singleton(NodeBase):
"""Singleton axis.""" """Singleton axis."""
pass
@register_node @register_node
......
...@@ -381,7 +381,7 @@ def stmt_list(stmt): ...@@ -381,7 +381,7 @@ def stmt_list(stmt):
""" """
if isinstance(stmt, Block): if isinstance(stmt, Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest) return stmt_list(stmt.first) + stmt_list(stmt.rest)
elif isinstance(stmt, ProducerConsumer): if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body) return stmt_list(stmt.body)
return [stmt] return [stmt]
......
...@@ -33,7 +33,6 @@ class TensorSlice(NodeGeneric, _expr.ExprOp): ...@@ -33,7 +33,6 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
@register_node @register_node
class TensorIntrinCall(NodeBase): class TensorIntrinCall(NodeBase):
"""Intermediate structure for calling a tensor intrinsic.""" """Intermediate structure for calling a tensor intrinsic."""
pass
itervar_cls = None itervar_cls = None
...@@ -144,7 +143,6 @@ class Operation(NodeBase): ...@@ -144,7 +143,6 @@ class Operation(NodeBase):
@register_node @register_node
class PlaceholderOp(Operation): class PlaceholderOp(Operation):
"""Placeholder operation.""" """Placeholder operation."""
pass
@register_node @register_node
...@@ -164,7 +162,6 @@ class ComputeOp(Operation): ...@@ -164,7 +162,6 @@ class ComputeOp(Operation):
@register_node @register_node
class TensorComputeOp(Operation): class TensorComputeOp(Operation):
"""Tensor operation.""" """Tensor operation."""
pass
@register_node @register_node
...@@ -179,7 +176,7 @@ class ScanOp(Operation): ...@@ -179,7 +176,7 @@ class ScanOp(Operation):
@register_node @register_node
class ExternOp(Operation): class ExternOp(Operation):
"""Extern operation.""" """Extern operation."""
pass
@register_node @register_node
class HybridOp(Operation): class HybridOp(Operation):
......
...@@ -61,7 +61,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits ...@@ -61,7 +61,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
if out_dtype is None: if out_dtype is None:
out_dtype = data.dtype out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC" assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
if dorefa: if dorefa:
assert layout == "NCHW", "Cannot support dorea with NHWC layout yet" assert layout == "NCHW", "Cannot support dorea with NHWC layout yet"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout) wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
......
...@@ -554,7 +554,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -554,7 +554,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout" data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
layout = attrs[data_layout_key] layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"] out_dtype = attrs["out_dtype"]
if out_dtype == "" or out_dtype == "same": if out_dtype in ("same", ""):
out_dtype = tinfos[0].dtype out_dtype = tinfos[0].dtype
if layout != 'NCHW': if layout != 'NCHW':
......
...@@ -93,9 +93,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou ...@@ -93,9 +93,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
if layout == 'NCHW': if layout == 'NCHW':
return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
elif layout == 'HWCN': if layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype) return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
......
...@@ -362,7 +362,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): ...@@ -362,7 +362,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout" data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
layout = attrs[data_layout_key] layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"] out_dtype = attrs["out_dtype"]
if out_dtype == "" or out_dtype == "same": if out_dtype in ("", "same"):
out_dtype = tinfos[0].dtype out_dtype = tinfos[0].dtype
data, kernel = tinfos[0:2] data, kernel = tinfos[0:2]
...@@ -428,7 +428,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): ...@@ -428,7 +428,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
) )
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs) return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
elif groups != CI: if groups != CI:
workload = autotvm.task.args_to_workload( workload = autotvm.task.args_to_workload(
[tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype], [tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype],
group_conv2d_nchw) group_conv2d_nchw)
......
...@@ -96,7 +96,7 @@ def schedule_reduce(outs): ...@@ -96,7 +96,7 @@ def schedule_reduce(outs):
"""Internal travserse function""" """Internal travserse function"""
if isinstance(operator, tvm.tensor.PlaceholderOp): if isinstance(operator, tvm.tensor.PlaceholderOp):
return return
elif tag.is_injective(operator.tag): if tag.is_injective(operator.tag):
sch[operator].compute_inline() sch[operator].compute_inline()
for tensor in operator.input_tensors: for tensor in operator.input_tensors:
if tensor.op not in scheduled_ops: if tensor.op not in scheduled_ops:
......
...@@ -92,14 +92,14 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits ...@@ -92,14 +92,14 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
if layout == 'NCHW': if layout == 'NCHW':
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits, return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
elif layout == 'NHWC': if layout == 'NHWC':
return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
def _get_workload(data, kernel, stride, padding, out_dtype, layout): def _get_workload(data, kernel, stride, padding, out_dtype, layout):
""" Get the workload structure. """ """ Get the workload structure. """
assert layout == "NCHW" or layout == "NHWC", \ assert layout in ("NCHW", "NHWC"), \
"Only support layouts NCHW and NHWC" "Only support layouts NCHW and NHWC"
if layout == "NCHW": if layout == "NCHW":
_, CI, IH, IW = [x.value for x in data.shape] _, CI, IH, IW = [x.value for x in data.shape]
......
...@@ -48,11 +48,10 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N ...@@ -48,11 +48,10 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
# default declaration # default declaration
if layout == 'NCHW': if layout == 'NCHW':
return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype) return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
elif layout == 'HWCN': if layout == 'HWCN':
return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype) return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
elif layout == 'NHWC': if layout == 'NHWC':
return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype) return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
......
...@@ -17,12 +17,11 @@ def upsampling_python(data, scale, layout='NCHW'): ...@@ -17,12 +17,11 @@ def upsampling_python(data, scale, layout='NCHW'):
for c in range(oshape[1]): for c in range(oshape[1]):
output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
return output_np return output_np
elif layout == 'NHWC': if layout == 'NHWC':
oshape = (ishape[0], ishape[1]*scale, ishape[1]*scale, ishape[3]) oshape = (ishape[0], ishape[1]*scale, ishape[1]*scale, ishape[3])
output_np = np.zeros(oshape, dtype=data.dtype) output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]): for b in range(oshape[0]):
for c in range(oshape[3]): for c in range(oshape[3]):
output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale) output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale)
return output_np return output_np
else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
...@@ -59,7 +59,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits ...@@ -59,7 +59,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
if out_dtype is None: if out_dtype is None:
out_dtype = data.dtype out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC" assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout) wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout) sch = _get_schedule(wkl, layout)
......
...@@ -71,11 +71,10 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out ...@@ -71,11 +71,10 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
_get_default_config(cfg, data, kernel, strides, padding, out_dtype) _get_default_config(cfg, data, kernel, strides, padding, out_dtype)
return _declaration_conv_impl(cfg, data, kernel, strides, return _declaration_conv_impl(cfg, data, kernel, strides,
padding, dilation, layout, out_dtype) padding, dilation, layout, out_dtype)
elif layout == 'HWCN': if layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype) return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
elif layout == 'NHWC': if layout == 'NHWC':
return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
......
...@@ -223,9 +223,8 @@ class Environment(object): ...@@ -223,9 +223,8 @@ class Environment(object):
"""The target host""" """The target host"""
if self.TARGET == "pynq": if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf" return "llvm -target=armv7-none-linux-gnueabihf"
elif self.TARGET == "sim": if self.TARGET == "sim":
return "llvm" return "llvm"
else:
raise ValueError("Unknown target %s" % self.TARGET) raise ValueError("Unknown target %s" % self.TARGET)
......
...@@ -169,7 +169,7 @@ def clean_cast(graph): ...@@ -169,7 +169,7 @@ def clean_cast(graph):
op_name = node.attr("op_name") op_name = node.attr("op_name")
if op_name == "cast": if op_name == "cast":
return _clean_cast(node.get_children(), target_type) return _clean_cast(node.get_children(), target_type)
elif op_name == "relu": if op_name == "relu":
data, has_clip = _clean_cast( data, has_clip = _clean_cast(
node.get_children(), target_type) node.get_children(), target_type)
data = nnvm.sym.relu(data) data = nnvm.sym.relu(data)
......
...@@ -64,7 +64,7 @@ def gemm(env, mock=False): ...@@ -64,7 +64,7 @@ def gemm(env, mock=False):
dev.get_task_qid(dev.QID_COMPUTE)) dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop) dev.vta_push_uop)
if index == 0 or index == 2: if index in (0, 2):
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopPush", "int32", "VTAUopPush",
0, 0, 0, 0,
......
...@@ -77,7 +77,6 @@ def fold_uop_loop(stmt_in): ...@@ -77,7 +77,6 @@ def fold_uop_loop(stmt_in):
args.append(m[1]) args.append(m[1])
args += op.args[base_args+3:] args += op.args[base_args+3:]
return tvm.call_extern("int32", "VTAUopPush", *args) return tvm.call_extern("int32", "VTAUopPush", *args)
else:
if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"): if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
raise RuntimeError("unexpected op %s" % op) raise RuntimeError("unexpected op %s" % op)
return op return op
...@@ -165,21 +164,20 @@ def cpu_access_rewrite(stmt_in): ...@@ -165,21 +164,20 @@ def cpu_access_rewrite(stmt_in):
op.condition, let_stmt) op.condition, let_stmt)
del rw_info[buffer_var] del rw_info[buffer_var]
return alloc return alloc
elif isinstance(op, tvm.expr.Load): if isinstance(op, tvm.expr.Load):
buffer_var = op.buffer_var buffer_var = op.buffer_var
if not buffer_var in rw_info: if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var( rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle") buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var] new_var = rw_info[buffer_var]
return tvm.make.Load(op.dtype, new_var, op.index) return tvm.make.Load(op.dtype, new_var, op.index)
elif isinstance(op, tvm.stmt.Store): if isinstance(op, tvm.stmt.Store):
buffer_var = op.buffer_var buffer_var = op.buffer_var
if not buffer_var in rw_info: if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var( rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle") buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var] new_var = rw_info[buffer_var]
return tvm.make.Store(new_var, op.value, op.index) return tvm.make.Store(new_var, op.value, op.index)
else:
raise RuntimeError("not reached") raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform( stmt = tvm.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
...@@ -233,22 +231,19 @@ def lift_alloc_to_scope_begin(stmt_in): ...@@ -233,22 +231,19 @@ def lift_alloc_to_scope_begin(stmt_in):
if op.attr_key == "virtual_thread": if op.attr_key == "virtual_thread":
lift_stmt.append([]) lift_stmt.append([])
return None
def _post_order(op): def _post_order(op):
if isinstance(op, tvm.stmt.Allocate): if isinstance(op, tvm.stmt.Allocate):
lift_stmt[-1].append(op) lift_stmt[-1].append(op)
return op.body return op.body
elif isinstance(op, tvm.stmt.AttrStmt): if isinstance(op, tvm.stmt.AttrStmt):
if op.attr_key == "storage_scope": if op.attr_key == "storage_scope":
lift_stmt[-1].append(op) lift_stmt[-1].append(op)
return op.body return op.body
elif op.attr_key == "virtual_thread": if op.attr_key == "virtual_thread":
return _merge_block(lift_stmt.pop() + [op], op.body) return _merge_block(lift_stmt.pop() + [op], op.body)
return op return op
elif isinstance(op, tvm.stmt.For): if isinstance(op, tvm.stmt.For):
return _merge_block(lift_stmt.pop() + [op], op.body) return _merge_block(lift_stmt.pop() + [op], op.body)
else:
raise RuntimeError("not reached") raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform( stmt = tvm.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
...@@ -297,7 +292,7 @@ def inject_coproc_sync(stmt_in): ...@@ -297,7 +292,7 @@ def inject_coproc_sync(stmt_in):
sync = tvm.make.Call( sync = tvm.make.Call(
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0) "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync)) return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
elif _match_pragma(stmt, "trim_loop"): if _match_pragma(stmt, "trim_loop"):
op = stmt.body op = stmt.body
assert isinstance(op, tvm.stmt.For) assert isinstance(op, tvm.stmt.For)
return tvm.make.For( return tvm.make.For(
...@@ -584,7 +579,7 @@ def annotate_alu_coproc_scope(stmt_in): ...@@ -584,7 +579,7 @@ def annotate_alu_coproc_scope(stmt_in):
tvm.make.StringImm("VTAPushALUOp")) tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt) irb.emit(stmt)
return irb.get() return irb.get()
elif _match_pragma(stmt, "skip_alu"): if _match_pragma(stmt, "skip_alu"):
return tvm.make.Evaluate(0) return tvm.make.Evaluate(0)
return stmt return stmt
......
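In the lift_alloc_to_scope_begin hunk the lone `return None` line appears on only one side of the diff; the flattened rendering does not show whether it was added or removed (pylint has gripes in both directions, useless-return vs inconsistent-return-statements), but either way `_pre_order`'s observable behaviour is the same, because a Python function that falls off the end returns None. A small sketch with illustrative names:

    def _with_explicit_return(bucket):
        bucket.append([])
        return None

    def _with_implicit_return(bucket):
        bucket.append([])

    a, b = [], []
    assert _with_explicit_return(a) is _with_implicit_return(b) is None
    assert a == b == [[]]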
...@@ -193,7 +193,7 @@ def _build(funcs, target, target_host): ...@@ -193,7 +193,7 @@ def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target) tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta": if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", target_host=target_host) return tvm.build(funcs, target="ext_dev", target_host=target_host)
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu": if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
return tvm.build(funcs, target=target_host) return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target) return tvm.build(funcs, target=target)
...@@ -279,9 +279,8 @@ def schedule_conv2d(attrs, outs, target): ...@@ -279,9 +279,8 @@ def schedule_conv2d(attrs, outs, target):
target = tvm.target.create(target) target = tvm.target.create(target)
if target.device_name == "vta": if target.device_name == "vta":
return schedule_packed_conv2d(outs) return schedule_packed_conv2d(outs)
elif str(target).startswith("llvm"): if str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs]) return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target) raise RuntimeError("not support target %s" % target)
return _nn.schedule_conv2d(attrs, outs, target) return _nn.schedule_conv2d(attrs, outs, target)
......