Commit e20ef0d4 by Marcus Shawcroft, committed by Tianqi Chen

Fix pylint 2.2.2 gripes. (#2642)

parent 81334be3
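The hunks below apply the same handful of pylint fixes across the code base: dropping redundant `else` blocks after a `return` or `raise` (pylint's no-else-return family of messages), removing `pass` statements that sit under a docstring, collapsing repeated equality checks into membership tests, and a couple of comprehension/comparison idioms. As a minimal sketch only — the function names here are hypothetical, not taken from the patch — the recurring before/after shape looks like this:

```python
# Illustration only: hypothetical functions showing the pattern pylint 2.2.2
# complains about and the shape this commit rewrites it into.

def pick_axis_before(layout):
    """Redundant `else` after return/raise (pylint: no-else-return and friends)."""
    if layout == "NCHW":
        return 1
    elif layout == "NHWC":
        return 3
    else:
        raise ValueError("unsupported layout " + layout)

def pick_axis_after(layout):
    """Same logic; each branch ends the function, so no `else` is needed."""
    if layout == "NCHW":
        return 1
    if layout == "NHWC":
        return 3
    raise ValueError("unsupported layout " + layout)

class ConverterError(Exception):
    """A docstring is already a valid class body, so the trailing `pass` is dropped."""

def is_spatial_rank(rank):
    """`rank == 1 or rank == 2` becomes a membership test."""
    return rank in (1, 2)
```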
......@@ -31,7 +31,7 @@ else:
class NNVMError(Exception):
"""Error that will be throwed by all nnvm functions"""
pass
def _load_lib():
"""Load libary by searching possible path."""
......
......@@ -42,8 +42,7 @@ class AttrScope(object):
if attr:
ret.update(attr)
return ret
else:
return attr
return attr
def __enter__(self):
# pylint: disable=protected-access
......
......@@ -23,13 +23,11 @@ class GraphKey(tvm.node.NodeBase):
@tvm.register_node
class GraphCacheEntry(tvm.node.NodeBase):
"""CacheEntry of compilation into a TVM Function"""
pass
@tvm.register_node
class GraphFunc(tvm.node.NodeBase):
"""Compiled result of a graph into a TVM Function"""
pass
class Engine(object):
......
......@@ -73,9 +73,8 @@ class Caffe2OpConverter(object):
if hasattr(cls, '_impl'):
return getattr(cls, '_impl')
else:
raise NotImplementedError('{} not implemented'.format(
cls.__name__))
raise NotImplementedError('{} not implemented'.format(
cls.__name__))
_caffe2_internal_args = {
......@@ -175,11 +174,10 @@ class Concat(Caffe2OpConverter):
order = order if isinstance(order, str) else order.decode('UTF-8')
if order == 'NCHW':
return 1
elif order == 'NHWC':
if order == 'NHWC':
return 3
else:
raise RuntimeError(
"Unsupported storage order: {} in caffe2".format(order))
raise RuntimeError(
"Unsupported storage order: {} in caffe2".format(order))
return AttrCvt(
op_name='concatenate',
......
......@@ -98,33 +98,33 @@ def ActivationParams(op, insym, symtab):
par = getattr(op, whichActivation)
if whichActivation == 'linear':
return _sym.__add_scalar__(_sym.__mul_scalar__(insym, scalar=par.alpha), scalar=par.beta)
elif whichActivation == 'ReLU':
if whichActivation == 'ReLU':
return _sym.relu(insym)
elif whichActivation == 'leakyReLU':
if whichActivation == 'leakyReLU':
return _sym.leaky_relu(insym, alpha=par.alpha)
elif whichActivation == 'thresholdedReLU':
if whichActivation == 'thresholdedReLU':
alpha_tensor = _sym.full_like(insym, fill_value=float(par.alpha))
return _sym.elemwise_mul(insym, _sym.greater(insym, alpha_tensor))
elif whichActivation == 'PReLU':
if whichActivation == 'PReLU':
return _sym.prelu(insym, alpha=par.alpha)
elif whichActivation == 'tanh':
if whichActivation == 'tanh':
return _sym.tanh(insym)
elif whichActivation == 'scaledTanh':
if whichActivation == 'scaledTanh':
return _sym.__mul_scalar__(_sym.tanh(_sym.__mul_scalar__(
insym, scalar=par.beta)), scalar=par.alpha)
elif whichActivation == 'sigmoid':
if whichActivation == 'sigmoid':
return _sym.sigmoid(insym)
elif whichActivation == 'sigmoidHard':
if whichActivation == 'sigmoidHard':
transformX = (par.alpha * insym) + par.beta
return _sym.clip(transformX, a_min=0, a_max=1)
elif whichActivation == 'ELU':
if whichActivation == 'ELU':
return _sym.__mul_scalar__(_sym.__add_scalar__(
_sym.exp(insym), scalar=-1), scalar=par.alpha)
elif whichActivation == 'softsign':
if whichActivation == 'softsign':
return insym / (1 + (_sym.relu(insym) + _sym.relu(_sym.negative(insym))))
elif whichActivation == 'softplus':
if whichActivation == 'softplus':
return _sym.log(_sym.__add_scalar__(_sym.exp(insym), scalar=1))
elif whichActivation == 'parametricSoftplus':
if whichActivation == 'parametricSoftplus':
alpha = list(par.alpha.floatValue)
beta = list(par.alpha.floatValue)
if len(alpha) == 1:
......@@ -136,8 +136,7 @@ def ActivationParams(op, insym, symtab):
betasym = symtab.new_const(beta)
return _sym.broadcast_mul(_sym.log(_sym.broadcast_add(
_sym.exp(insym), betasym)), alphasym)
else:
raise NotImplementedError('%s not implemented' % whichActivation)
raise NotImplementedError('%s not implemented' % whichActivation)
def ScaleLayerParams(op, insym, symtab):
"""Scale layer params."""
......@@ -157,10 +156,9 @@ def PoolingLayerParams(op, insym, symtab):
if op.globalPooling:
if op.type == 0:
return _sym.global_max_pool2d(insym)
elif op.type == 1:
if op.type == 1:
return _sym.global_avg_pool2d(insym)
else:
raise NotImplementedError("Only max and average pooling implemented")
raise NotImplementedError("Only max and average pooling implemented")
else:
params = {'pool_size':list(op.kernelSize),
......@@ -190,10 +188,9 @@ def PoolingLayerParams(op, insym, symtab):
if op.type == 0:
return _sym.max_pool2d(insym, **params)
elif op.type == 1:
if op.type == 1:
return _sym.avg_pool2d(insym, **params)
else:
raise NotImplementedError("Only max and average pooling implemented")
raise NotImplementedError("Only max and average pooling implemented")
def SoftmaxLayerParams(op, insym, symtab):
return _sym.softmax(_sym.flatten(insym))
......
......@@ -921,8 +921,6 @@ class GraphProto(object):
if layer_num != self.net.n-1:
self._outs.insert(0, sym)
return
def from_darknet(self):
"""To convert the darknet symbol to nnvm symbols."""
for i in range(self.net.n):
......
......@@ -47,35 +47,34 @@ def _convert_activation(insym, keras_layer, _):
beta = keras_layer.beta if hasattr(keras_layer, "beta") else 0
return _sym.__add_scalar__(_sym.__mul_scalar__(insym, \
scalar=alpha), scalar=beta)
elif act_type == 'softmax':
if act_type == 'softmax':
return _sym.softmax(insym, axis=1)
elif act_type == 'sigmoid':
if act_type == 'sigmoid':
return _sym.sigmoid(insym)
elif act_type == 'tanh':
if act_type == 'tanh':
return _sym.tanh(insym)
elif act_type == 'relu':
if act_type == 'relu':
return _sym.relu(insym)
elif act_type == 'softplus':
if act_type == 'softplus':
return _sym.log(_sym.__add_scalar__(_sym.exp(insym), scalar=1))
elif act_type == 'elu':
if act_type == 'elu':
alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1
return _get_elu(insym, alpha)
elif act_type == 'selu':
if act_type == 'selu':
# Alpha, Gamma values, obtained from https://arxiv.org/abs/1706.02515
alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") \
else 1.6732632423543772848170429916717
gamma = keras_layer.gamma if hasattr(keras_layer, "gamma") \
else 1.0507009873554804934193349852946
return gamma * _get_elu(insym, alpha)
elif act_type == 'relu6':
if act_type == 'relu6':
return _sym.clip(insym, a_min=0, a_max=6)
elif act_type == 'softsign':
if act_type == 'softsign':
return insym / (1 + (_sym.relu(insym) + _sym.relu(_sym.negative(insym))))
elif act_type == 'hard_sigmoid':
if act_type == 'hard_sigmoid':
transformX = (0.2 * insym) + 0.5
return _sym.clip(transformX, a_min=0, a_max=1)
else:
raise TypeError("Unsupported activation type : {}".format(act_type))
raise TypeError("Unsupported activation type : {}".format(act_type))
def _convert_advanced_activation(insym, keras_layer, symtab):
......@@ -84,12 +83,12 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
if keras_layer.max_value:
return _sym.clip(insym, a_min=0, a_max=keras_layer.max_value)
return _sym.relu(insym)
elif act_type == 'LeakyReLU':
if act_type == 'LeakyReLU':
return _sym.leaky_relu(insym, alpha=keras_layer.alpha)
elif act_type == 'ELU':
if act_type == 'ELU':
alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1
return _get_elu(insym, alpha)
elif act_type == 'PReLU':
if act_type == 'PReLU':
assert hasattr(keras_layer, "alpha"), \
"alpha required for PReLU."
_check_data_format(keras_layer)
......@@ -97,12 +96,11 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
return -symtab.new_const(keras_layer.get_weights()[0] \
.transpose(np.roll(range(size), 1))) \
* _sym.relu(-insym) + _sym.relu(insym)
elif act_type == 'ThresholdedReLU':
if act_type == 'ThresholdedReLU':
theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0
theta_tensor = _sym.full_like(insym[0], fill_value=float(theta))
return _sym.elemwise_mul(insym[0], _sym.greater(insym[0], theta_tensor, out_type="float32"))
else:
raise TypeError("Unsupported advanced activation type : {}".format(act_type))
raise TypeError("Unsupported advanced activation type : {}".format(act_type))
def _convert_merge(insym, keras_layer, _):
......@@ -280,31 +278,29 @@ def _convert_pooling(insym, keras_layer, symtab):
# global pool in keras = global pool + flatten in nnvm
if pool_type == 'GlobalMaxPooling2D':
return _convert_flatten(_sym.global_max_pool2d(insym), keras_layer, symtab)
elif pool_type == 'GlobalAveragePooling2D':
if pool_type == 'GlobalAveragePooling2D':
return _convert_flatten(_sym.global_avg_pool2d(insym), keras_layer, symtab)
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {'pool_size': [pool_h, pool_w],
'strides': [stride_h, stride_w],
'padding': [0, 0]}
if keras_layer.padding == 'valid':
pass
elif keras_layer.padding == 'same':
in_h = keras_layer.input_shape[1]
in_w = keras_layer.input_shape[2]
pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
params['padding'] = [pad_t, pad_l, pad_b, pad_r]
else:
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {'pool_size': [pool_h, pool_w],
'strides': [stride_h, stride_w],
'padding': [0, 0]}
if keras_layer.padding == 'valid':
pass
elif keras_layer.padding == 'same':
in_h = keras_layer.input_shape[1]
in_w = keras_layer.input_shape[2]
pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
params['padding'] = [pad_t, pad_l, pad_b, pad_r]
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
if pool_type == 'MaxPooling2D':
return _sym.max_pool2d(insym, **params)
elif pool_type == 'AveragePooling2D':
# TODO: in keras, padded zeros are not calculated
return _sym.avg_pool2d(insym, **params)
else:
raise TypeError("Unsupported pooling type : {}".format(keras_layer))
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
if pool_type == 'MaxPooling2D':
return _sym.max_pool2d(insym, **params)
if pool_type == 'AveragePooling2D':
# TODO: in keras, padded zeros are not calculated
return _sym.avg_pool2d(insym, **params)
raise TypeError("Unsupported pooling type : {}".format(keras_layer))
def _convert_upsample(insym, keras_layer, _):
......
......@@ -424,7 +424,7 @@ def _topo_sort(symbol):
if childs is None:
dep_cnts[name] = 0
else:
dep_cnts[name] = len(set([c.attr('name') for c in childs]))
dep_cnts[name] = len({c.attr('name') for c in childs})
for child in childs:
child_name = child.attr('name')
if child_name not in deps:
......
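The `_topo_sort` hunk just above swaps `len(set([...]))` for a set comprehension, which gives the same count without materialising a temporary list. A tiny self-contained sketch of the two forms, using made-up child names rather than the real symbol objects:

```python
# Made-up stand-in for the child symbol names seen in _topo_sort.
child_names = ["conv0", "bn0", "conv0", "relu0"]

# Old form: builds a throwaway list, then a set.
dep_cnt_old = len(set([name for name in child_names]))

# New form preferred by pylint: set comprehension, same result.
dep_cnt_new = len({name for name in child_names})

assert dep_cnt_old == dep_cnt_new == 3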
......@@ -9,8 +9,7 @@ def dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
raise NotImplementedError("Only 2d kernel supported.")
return _impl
......
......@@ -68,8 +68,7 @@ def _dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def _dimension_constraint():
......@@ -433,8 +432,7 @@ def _reshape():
op_name="reshape",
extras={'shape':tuple(params_new[0].asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
else:
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
return _impl
def _bias_add():
......@@ -1394,7 +1392,7 @@ class GraphProto(object):
self._nodes[name] = _sym.Variable(name=name,
shape=self._params[name].shape)
else:
if key != 'dtype' and key != '_output_shapes' and key != '_class':
if key not in ('dtype', '_output_shapes', '_class'):
raise NotImplementedError \
("Other attributes for a Const(param) Node {} ? .".format(key))
......
......@@ -115,6 +115,8 @@ class TFParser(object):
"""TODO: Load checkpoint model."""
raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
"not supported yet.")
# pylint: disable=unreachable
return 0
def parse(self):
"""Parse tensorflow models: checkpoints, saved models, and single pb
......
......@@ -50,10 +50,9 @@ class Symbol(SymbolBase):
"""x.__add__(y) <=> x+y"""
if isinstance(other, Symbol):
return __add_symbol__(self, other)
elif isinstance(other, _Number):
if isinstance(other, _Number):
return __add_scalar__(self, scalar=other)
else:
raise TypeError("type %s not supported" % str(type(other)))
raise TypeError("type %s not supported" % str(type(other)))
def __radd__(self, other):
return self.__add__(other)
......@@ -64,14 +63,12 @@ class Symbol(SymbolBase):
return __sub_symbol__(self, other)
if isinstance(other, _Number):
return __sub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __rsub__(self, other):
if isinstance(other, _Number):
return __rsub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __mul__(self, other):
"""x.__mul__(y) <=> x*y"""
......@@ -79,8 +76,7 @@ class Symbol(SymbolBase):
return __mul_symbol__(self, other)
if isinstance(other, _Number):
return __mul_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __rmul__(self, other):
return self.__mul__(other)
......@@ -91,28 +87,24 @@ class Symbol(SymbolBase):
return __div_symbol__(self, other)
if isinstance(other, _Number):
return __div_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __rdiv__(self, other):
if isinstance(other, _Number):
return __rdiv_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __lshift__(self, other):
"""x.__lshift__(y) <=> x << y"""
if isinstance(other, _Number):
return __lshift_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __rshift__(self, other):
"""x.__rshift__(y) <=> x >> y"""
if isinstance(other, _Number):
return __rshift_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
return self.__div__(other)
......@@ -126,14 +118,12 @@ class Symbol(SymbolBase):
return __pow_symbol__(self, other)
if isinstance(other, _Number):
return __pow_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __rpow__(self, other):
if isinstance(other, _Number):
return __rpow_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
raise TypeError('type %s not supported' % str(type(other)))
def __neg__(self):
"""x.__neg__() <=> -x"""
......@@ -238,12 +228,11 @@ class Symbol(SymbolBase):
"""internal function to get list option"""
if option == 'all':
return _ctypes.c_int(0)
elif option == 'read_only':
if option == 'read_only':
return _ctypes.c_int(1)
elif option == 'aux_state':
if option == 'aux_state':
return _ctypes.c_int(2)
else:
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
def list_input_variables(self, option='all'):
"""List all the input variables in the symbol.
......
......@@ -23,11 +23,10 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
def Pooling(data, kernel, stride, pad, pool_type, name):
if pool_type == 'max':
return sym.max_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad, name=name)
elif pool_type == 'avg':
if pool_type == 'avg':
return sym.avg_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad, name=name,
count_include_pad=True)
else:
raise ValueError("Invalid pooling type: " + pool_type)
raise ValueError("Invalid pooling type: " + pool_type)
def Inception7A(data,
num_1x1,
......
......@@ -88,7 +88,6 @@ def _get_yolo_detections(l, im_shape, net_shape, thresh, relative, dets):
before_correct_dets.append(detection)
dets.extend(_correct_boxes(before_correct_dets, im_shape[0], im_shape[1],
net_shape[0], net_shape[1], relative))
return
def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets):
data = l['output']
......@@ -114,7 +113,6 @@ def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets):
_correct_boxes(before_correct_dets, im_shape[0], im_shape[1],
net_shape[0], net_shape[1], relative)
dets.extend(before_correct_dets)
return
def fill_network_boxes(net_shape, im_shape,
thresh, relative, tvm_out):
......
......@@ -129,14 +129,13 @@ class AttrDict(object):
lowercase = self[key].lower()
if lowercase == "1":
return True
elif lowercase == "0":
if lowercase == "0":
return False
elif lowercase == "true":
if lowercase == "true":
return True
elif lowercase == "false":
if lowercase == "false":
return False
else:
raise ValueError("Wrong bool format for key %s" % key)
raise ValueError("Wrong bool format for key %s" % key)
def get_str(self, key):
"""Get string from attr dict
......
......@@ -32,7 +32,6 @@ else:
class TVMError(Exception):
"""Error thrown by TVM function"""
pass
def _load_lib():
......
......@@ -51,7 +51,6 @@ class Function(_FunctionBase):
tvm.register_func: How to register global function.
tvm.get_global_func: How to get global function.
"""
pass
class ModuleBase(object):
......@@ -207,11 +206,11 @@ def get_global_func(name, allow_missing=False):
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value:
return Function(handle, False)
else:
if allow_missing:
return None
else:
raise ValueError("Cannot find global function %s" % name)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
......
......@@ -36,16 +36,16 @@ def convert_to_node(value):
"""
if isinstance(value, _CLASS_NODE_BASE):
return value
elif isinstance(value, bool):
if isinstance(value, bool):
return const(value, 'uint1x1')
elif isinstance(value, Number):
if isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
if isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
if isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._Array(*value)
elif isinstance(value, dict):
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_NODE_BASE) and
......@@ -54,12 +54,12 @@ def convert_to_node(value):
vlist.append(item[0])
vlist.append(convert_to_node(item[1]))
return _api_internal._Map(*vlist)
elif isinstance(value, NodeGeneric):
if isinstance(value, NodeGeneric):
return value.asnode()
elif value is None:
if value is None:
return None
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
raise ValueError("don't know how to convert type %s to node" % type(value))
def const(value, dtype=None):
......
......@@ -31,11 +31,11 @@ class IntervalSet(IntSet):
@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass
@register_node
class ModularSet(IntSet):
"""Represent range of (coeff * x + base) for x in Z """
pass
_init_api("tvm.arith")
......@@ -69,15 +69,14 @@ class Future(object):
class FutureError(RuntimeError):
"""Base error class of all future events"""
pass
# pylint:disable=redefined-builtin
class TimeoutError(FutureError):
"""Error raised when a task is timeout."""
pass
class ExecutionError(FutureError):
"""
Error raised when future execution crashes or fails.
"""
pass
......@@ -83,7 +83,7 @@ def encode(inp, result, protocol='json'):
"v": AUTOTVM_LOG_VERSION
}
return json.dumps(json_dict)
elif protocol == 'pickle':
if protocol == 'pickle':
row = (str(inp.target),
str(base64.b64encode(pickle.dumps([inp.task.name,
inp.task.args,
......@@ -92,8 +92,8 @@ def encode(inp, result, protocol='json'):
str(base64.b64encode(pickle.dumps(inp.config)).decode()),
str(base64.b64encode(pickle.dumps(tuple(result))).decode()))
return '\t'.join(row)
else:
raise RuntimeError("Invalid log protocol: " + protocol)
raise RuntimeError("Invalid log protocol: " + protocol)
def decode(row, protocol='json'):
......@@ -136,7 +136,7 @@ def decode(row, protocol='json'):
result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["r"]])
return inp, result
elif protocol == 'pickle':
if protocol == 'pickle':
items = row.split("\t")
tgt = _target.create(items[0])
task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
......@@ -146,8 +146,8 @@ def decode(row, protocol='json'):
tsk = task.Task(task_tuple[0], task_tuple[1])
tsk.workload = task_tuple[3]
return MeasureInput(tgt, tsk, config), MeasureResult(*result)
else:
raise RuntimeError("Invalid log protocol: " + protocol)
raise RuntimeError("Invalid log protocol: " + protocol)
def load_from_file(filename):
......
......@@ -32,7 +32,6 @@ class InstantiationError(ValueError):
raised by cfg.raise_error
e.g. too many unrolling, too many threads in a block
"""
pass
class TransformSpace(object):
......@@ -321,17 +320,17 @@ class ReorderSpace(TransformSpace):
if np.sum(tmp_pt) == size:
merged.append(list(tmp_stack))
return
else:
for i in range(len(chains)):
# use i == np.argmax(....) here to take spatial order into consideration
# if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
if (tmp_pt[i] < len(chains[i]) and
(i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))):
tmp_stack.append(chains[i][tmp_pt[i]])
tmp_pt[i] += 1
self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
tmp_pt[i] -= 1
tmp_stack.pop()
for i in range(len(chains)):
# use i == np.argmax(....) here to take spatial order into consideration
# if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
if (tmp_pt[i] < len(chains[i]) and
(i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))):
tmp_stack.append(chains[i][tmp_pt[i]])
tmp_pt[i] += 1
self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
tmp_pt[i] -= 1
tmp_stack.pop()
class ReorderEntity(object):
......@@ -441,7 +440,7 @@ class AnnotateSpace(TransformSpace):
if now == self.num_axis:
# only vectorize inner most dimension
vec_ct = tmp_stack.count('vec')
if vec_ct == 0 or vec_ct == 1:
if vec_ct in (0, 1):
self.entities.append(AnnotateEntity(list(tmp_stack)))
else:
for ann in self.anns[now]:
......
......@@ -294,7 +294,7 @@ def get_config():
class FlopCalculationError(RuntimeError):
"""Error happens when estimating FLOP for a compute op"""
pass
def compute_flop(sch):
"""Calculate number of FLOP (floating number operations) of the compute ops in a schedule
......@@ -328,29 +328,29 @@ def compute_flop(sch):
if len(source) != 1:
raise FlopCalculationError("Found multiple output in the source of reduce op")
return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
elif isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
return 0
elif isinstance(exp, expr.Cast):
if isinstance(exp, expr.Cast):
return _count_flop(exp.value)
elif isinstance(exp, expr.Var):
if isinstance(exp, expr.Var):
return 0
elif isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)):
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)):
base = 1 if "float" in exp.a.dtype else 0
if isinstance(exp, expr.Not): # unary
return base + _count_flop(exp.a)
return base + _count_flop(exp.a) + _count_flop(exp.b)
elif isinstance(exp, expr.Select):
if isinstance(exp, expr.Select):
return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
_count_flop(exp.false_value))
elif isinstance(exp, expr.Call):
if isinstance(exp, expr.Call):
return sum([_count_flop(x) for x in exp.args])
else:
raise FlopCalculationError("Found unsupported operator in the compute expr")
raise FlopCalculationError("Found unsupported operator in the compute expr")
def traverse(ops):
"""accumulate flops"""
......
......@@ -69,7 +69,7 @@ class Tuner(object):
results: Array of autotvm.measure.MeasureResult
result for measurement
"""
pass
def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
"""Begin tuning
......
......@@ -90,7 +90,7 @@ class Range(NodeBase):
You do not need to create Range explicitly.
Python list and tuple will be converted automatically to Range in api functions.
"""
pass
@register_node
class LoweredFunc(NodeBase):
......
......@@ -151,14 +151,14 @@ def find_libdevice_path(arch):
selected_ver = 0
selected_path = None
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver == 9.0 or cuda_ver == 9.1:
if cuda_ver in (9.0, 9.1):
path = os.path.join(lib_path, "libdevice.10.bc")
else:
for fn in os.listdir(lib_path):
if not fn.startswith("libdevice"):
continue
ver = int(fn.split(".")[-3].split("_")[-1])
if ver > selected_ver and ver <= arch:
if selected_ver < ver <= arch:
selected_ver = ver
selected_path = fn
if selected_path is None:
......
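The libdevice hunk above also folds `ver > selected_ver and ver <= arch` into the chained comparison `selected_ver < ver <= arch`, which pylint prefers; the two are equivalent, as this small sketch with made-up version numbers shows:

```python
# Made-up values purely to illustrate the equivalent comparisons.
selected_ver, arch = 30, 52

for ver in (10, 35, 52, 60):
    old_style = ver > selected_ver and ver <= arch
    chained = selected_ver < ver <= arch
    assert old_style == chained
```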
......@@ -118,8 +118,7 @@ def _find_vpi_path():
vpi_found = [p for p in vpi_path if os.path.exists(p) and os.path.isfile(p)]
if vpi_found:
return os.path.dirname(vpi_found[0])
else:
raise ValueError("Cannot find tvm_vpi.vpi, make sure you did `make verilog`")
raise ValueError("Cannot find tvm_vpi.vpi, make sure you did `make verilog`")
def search_path():
"""Get the search directory."""
......
......@@ -189,9 +189,9 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
return entry
elif ty is Symbol.ConstVar:
if ty is Symbol.ConstVar:
return entry if isinstance(node.ctx, ast.Load) else None
elif ty is Symbol.BufferVar:
if ty is Symbol.BufferVar:
if isinstance(node.ctx, ast.Load):
return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
_expr.Call.Halide, entry.op, entry.value_index)
......@@ -274,12 +274,12 @@ class HybridParser(ast.NodeVisitor):
buf, args = lhs
return _make.Provide(buf.op, 0, rhs, args)
return util.make_nop()
else:
lhs, args = self.visit(lhs)
_internal_assert(isinstance(lhs, Tensor), \
"An array access's LHS is expected to be a expr.Call!")
res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
return res
lhs, args = self.visit(lhs)
_internal_assert(isinstance(lhs, Tensor), \
"An array access's LHS is expected to be a expr.Call!")
res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
return res
def visit_Index(self, node):
......@@ -347,7 +347,7 @@ class HybridParser(ast.NodeVisitor):
if isinstance(cond, _expr.UIntImm):
if cond.value:
return visit_list_to_block(self.visit, node.body)
elif node.orelse:
if node.orelse:
return visit_list_to_block(self.visit, node.orelse)
return util.make_nop()
......@@ -451,7 +451,7 @@ class HybridParser(ast.NodeVisitor):
bodies.append(body)
return concat_list_to_block(bodies)
elif iter_var is None:
if iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!")
offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, 'int32')):
......
......@@ -60,7 +60,7 @@ def replace_io(body, rmap):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Provide(buf.op, op.value_index, op.value, op.args)
elif isinstance(op, _expr.Call) and op.func in rmap.keys():
if isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Call(buf.dtype, buf.name, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
......
......@@ -495,7 +495,7 @@ def _rule_float_suffix(op):
"""
if op.dtype == "float32":
return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
elif op.dtype == "float64":
if op.dtype == "float64":
return call_pure_extern(op.dtype, op.name, *op.args)
return op
......
......@@ -56,7 +56,7 @@ def static_cast(dtype, expr):
if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits:
if src_type.lanes == target_type.lanes:
return expr
elif src_type.lanes == 1 and target_type.lanes > 1:
if src_type.lanes == 1 and target_type.lanes > 1:
return Broadcast(expr, target_type.lanes)
return Cast(dtype, expr)
......
......@@ -23,7 +23,6 @@ class NDArray(NDArrayBase):
Instead, this is a minimal data structure to demonstrate
how we can use TVM in existing projects which might have their own array containers.
"""
pass
def cpu(dev_id=0):
......
......@@ -43,8 +43,8 @@ try:
from antlr4.tree.Tree import TerminalNode
except ImportError:
raise ParseError("Couldn't find ANTLR runtime." +
"Try running `pip{} install antlr4-python{}-runtime`."
.format(PYTHON_VERSION, PYTHON_VERSION))
"Try running `pip{version} install antlr4-python{version}-runtime`."
.format(version=PYTHON_VERSION))
BINARY_OPS = {
RelayParser.MUL: op.multiply,
......@@ -179,33 +179,31 @@ class ParseTreeToRelayIR(RelayVisitor):
# variables
if node_type == RelayLexer.GLOBAL_VAR:
return lookup(deque([self.global_var_scope]), node_text[1:])
elif node_type == RelayLexer.LOCAL_VAR:
if node_type == RelayLexer.LOCAL_VAR:
# Remove the leading '%' and lookup the name.
var = lookup(self.var_scopes, name)
if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))
return var
elif node_type == RelayLexer.GRAPH_VAR:
if node_type == RelayLexer.GRAPH_VAR:
try:
return self.graph_expr[int(name)]
except IndexError:
raise ParseError("Couldn't resolve `{}`".format(name))
# data types
elif node_type == RelayLexer.NAT:
if node_type == RelayLexer.NAT:
return int(node_text)
elif node_type == RelayLexer.FLOAT:
if node_type == RelayLexer.FLOAT:
return float(node_text)
elif node_type == RelayLexer.BOOL_LIT:
if node_type == RelayLexer.BOOL_LIT:
if node_text == "True":
return True
elif node_text == "False":
if node_text == "False":
return False
else:
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
else:
raise ParseError("todo: {}".format(node_text))
raise ParseError("todo: {}".format(node_text))
def visit_list(self, ctx_list):
# type: (List[ParserRuleContext]) -> List[Any]
......
......@@ -8,7 +8,7 @@ from .expr import Expr, Call
class Pattern(RelayNode):
"""Base type for pattern matching constructs."""
pass
@register_relay_node
class PatternWildcard(Pattern):
......
......@@ -10,7 +10,6 @@ from . import _backend
class CachedFunc(NodeBase):
"""Low-level tensor function to back a relay primitive function.
"""
pass
@register_relay_node
......@@ -34,7 +33,6 @@ class CCacheKey(NodeBase):
class CCacheValue(NodeBase):
"""Value in the CompileEngine, including usage statistics.
"""
pass
def _get_cache_key(source_func, target):
......
......@@ -49,7 +49,6 @@ class TupleValue(Value):
@register_relay_node
class Closure(Value):
"""A closure produced by the interpreter."""
pass
@register_relay_node
......
......@@ -444,7 +444,6 @@ def create_executor(kind="debug",
target = _target.create(target)
if kind == "debug":
return _interpreter.Interpreter(mod, ctx, target)
elif kind == "graph":
if kind == "graph":
return GraphExecutor(mod, ctx, target)
else:
raise RuntimeError("unknown mode {0}".format(mode))
raise RuntimeError("unknown mode {0}".format(mode))
......@@ -15,8 +15,7 @@ def dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
raise NotImplementedError("Only 2d kernel supported.")
return _impl
......@@ -104,9 +103,8 @@ class Caffe2OpConverter(object):
if hasattr(cls, '_impl'):
return getattr(cls, '_impl')
else:
raise NotImplementedError('{} not implemented'.format(
cls.__name__))
raise NotImplementedError('{} not implemented'.format(
cls.__name__))
_caffe2_internal_args = [
......@@ -234,11 +232,10 @@ class Concat(Caffe2OpConverter):
order = order if isinstance(order, str) else order.decode('UTF-8')
if order == 'NCHW':
return 1
elif order == 'NHWC':
if order == 'NHWC':
return 3
else:
raise RuntimeError(
"Unsupported storage order: {} in caffe2".format(order))
raise RuntimeError(
"Unsupported storage order: {} in caffe2".format(order))
return AttrCvt(
op_name='concatenate',
......
......@@ -10,7 +10,6 @@ from .. import op as _op
class RequiredAttr(object):
"""Dummpy class to represent required attr"""
pass
class StrAttrsDict(object):
......
......@@ -100,37 +100,37 @@ def _ActivationParams(op, inexpr, etab):
alpha = _expr.const(par.alpha, dtype='float32')
beta = _expr.const(par.beta, dtype='float32')
return _op.add(_op.multiply(inexpr, alpha), beta)
elif whichActivation == 'ReLU':
if whichActivation == 'ReLU':
return _op.nn.relu(inexpr)
elif whichActivation == 'leakyReLU':
if whichActivation == 'leakyReLU':
_op.nn.leaky_relu(inexpr, alpha=_expr.const(par.alpha, dtype='float32'))
elif whichActivation == 'thresholdedReLU':
alpha_tensor = _op.full_like(inexpr, fill_value=_expr.const(par.alpha, dtype='float32'))
return _op.multiply(inexpr, _op.greater(inexpr, alpha_tensor).as_type('float32'))
elif whichActivation == 'PReLU':
if whichActivation == 'PReLU':
return _op.nn.prelu(inexpr, alpha=_expr.const(par.alpha, dtype='float32'))
elif whichActivation == 'tanh':
if whichActivation == 'tanh':
return _op.tanh(inexpr)
elif whichActivation == 'scaledTanh':
if whichActivation == 'scaledTanh':
alpha = _expr.const(par.alpha, dtype='float32')
beta = _expr.const(par.beta, dtype='float32')
return _op.multiply(_op.tanh(_op.multiply(inexpr, beta)), alpha)
elif whichActivation == 'sigmoid':
if whichActivation == 'sigmoid':
return _op.sigmoid(inexpr)
elif whichActivation == 'sigmoidHard':
if whichActivation == 'sigmoidHard':
alpha = _expr.const(par.alpha, dtype='float32')
beta = _expr.const(par.beta, dtype='float32')
transformX = (alpha * inexpr) + beta
return _op.clip(transformX, a_min=0., a_max=1.)
elif whichActivation == 'ELU':
if whichActivation == 'ELU':
return _op.multiply(_op.add(_op.exp(inexpr), _expr.const(-1, dtype='float32')),
_expr.const(par.alpha, dtype='float32'))
elif whichActivation == 'softsign':
if whichActivation == 'softsign':
return inexpr / (_expr.const(1, dtype='float32') + (
op.nn.relu(inexpr) + _op.nn.relu(_op.negative(inexpr))))
elif whichActivation == 'softplus':
if whichActivation == 'softplus':
return _op.log(_op.add(_op.exp(inexpr), _expr.const(1, dtype='float32')))
elif whichActivation == 'parametricSoftplus':
if whichActivation == 'parametricSoftplus':
alpha = list(par.alpha.floatValue)
beta = list(par.alpha.floatValue)
if len(alpha) == 1:
......@@ -142,8 +142,7 @@ def _ActivationParams(op, inexpr, etab):
alpha_expr = etab.new_const(alpha)
beta_expr = etab.new_const(beta)
return _op.multiply(_op.log(_op.add(_op.exp(inexpr), beta_expr)), alpha_expr)
else:
raise NotImplementedError('%s not implemented' % whichActivation)
raise NotImplementedError('%s not implemented' % whichActivation)
def _ScaleLayerParams(op, inexpr, etab):
......@@ -163,10 +162,9 @@ def _PoolingLayerParams(op, inexpr, etab):
if op.globalPooling:
if op.type == 0:
return _op.nn.global_max_pool2d(inexpr)
elif op.type == 1:
if op.type == 1:
return _op.nn.global_avg_pool2d(inexpr)
else:
raise NotImplementedError("Only max and average pooling implemented")
raise NotImplementedError("Only max and average pooling implemented")
else:
params = {'pool_size':list(op.kernelSize),
......@@ -196,10 +194,9 @@ def _PoolingLayerParams(op, inexpr, etab):
if op.type == 0:
return _op.nn.max_pool2d(inexpr, **params)
elif op.type == 1:
if op.type == 1:
return _op.nn.avg_pool2d(inexpr, **params)
else:
raise NotImplementedError("Only max and average pooling implemented")
raise NotImplementedError("Only max and average pooling implemented")
def _SoftmaxLayerParams(op, inexpr, etab):
......
......@@ -60,21 +60,21 @@ def _convert_activation(inexpr, keras_layer, _):
alpha = _expr.const(alpha, dtype='float32')
beta = _expr.const(beta, dtype='float32')
return _op.add(_op.multiply(inexpr, alpha), beta)
elif act_type == 'softmax':
if act_type == 'softmax':
return _op.nn.softmax(inexpr, axis=1)
elif act_type == 'sigmoid':
if act_type == 'sigmoid':
return _op.sigmoid(inexpr)
elif act_type == 'tanh':
if act_type == 'tanh':
return _op.tanh(inexpr)
elif act_type == 'relu':
if act_type == 'relu':
return _op.nn.relu(inexpr)
elif act_type == 'softplus':
if act_type == 'softplus':
return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
elif act_type == 'elu':
if act_type == 'elu':
alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
alpha = _expr.const(alpha, dtype='float32')
return _get_elu(inexpr, alpha)
elif act_type == 'selu':
if act_type == 'selu':
# Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') \
else 1.6732632423543772848170429916717
......@@ -83,15 +83,15 @@ def _convert_activation(inexpr, keras_layer, _):
alpha = _expr.const(alpha, dtype='float32')
gamma = _expr.const(gamma, dtype='float32')
return gamma * _get_elu(inexpr, alpha)
elif act_type == 'relu6':
if act_type == 'relu6':
return _op.clip(inexpr, a_min=0., a_max=6.)
elif act_type == 'softsign':
if act_type == 'softsign':
return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
elif act_type == 'hard_sigmoid':
if act_type == 'hard_sigmoid':
x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
return _op.clip(x, a_min=0., a_max=1.)
else:
raise TypeError("Unsupported activation type : {}".format(act_type))
raise TypeError("Unsupported activation type : {}".format(act_type))
def _convert_advanced_activation(inexpr, keras_layer, etab):
......@@ -100,25 +100,25 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
if keras_layer.max_value:
return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
return _op.nn.relu(inexpr)
elif act_type == 'LeakyReLU':
if act_type == 'LeakyReLU':
return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
elif act_type == 'ELU':
if act_type == 'ELU':
alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
alpha = _expr.const(alpha, dtype='float32')
return _get_elu(inexpr, alpha)
elif act_type == 'PReLU':
if act_type == 'PReLU':
assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
_check_data_format(keras_layer)
size = len(keras_layer.alpha.shape)
alpha = etab.new_const(keras_layer.get_weights()[0] \
.transpose(np.roll(range(size), 1)))
return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
elif act_type == 'ThresholdedReLU':
if act_type == 'ThresholdedReLU':
theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
return _op.multiply(inexpr, _op.greater(inexpr, \
_expr.const(theta, dtype='float32')).astype('float32'))
else:
raise TypeError("Unsupported advanced activation type : {}".format(act_type))
raise TypeError("Unsupported advanced activation type : {}".format(act_type))
def _convert_merge(inexpr, keras_layer, _):
......@@ -297,31 +297,29 @@ def _convert_pooling(inexpr, keras_layer, etab):
# global pool in keras = global pool + flatten in nnvm/relay
if pool_type == 'GlobalMaxPooling2D':
return _convert_flatten(_op.nn.global_max_pool2d(inexpr), keras_layer, etab)
elif pool_type == 'GlobalAveragePooling2D':
if pool_type == 'GlobalAveragePooling2D':
return _convert_flatten(_op.nn.global_avg_pool2d(inexpr), keras_layer, etab)
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {'pool_size': [pool_h, pool_w],
'strides': [stride_h, stride_w],
'padding': [0, 0]}
if keras_layer.padding == 'valid':
pass
elif keras_layer.padding == 'same':
in_h = keras_layer.input_shape[1]
in_w = keras_layer.input_shape[2]
pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
params['padding'] = [pad_t, pad_l, pad_b, pad_r]
else:
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {'pool_size': [pool_h, pool_w],
'strides': [stride_h, stride_w],
'padding': [0, 0]}
if keras_layer.padding == 'valid':
pass
elif keras_layer.padding == 'same':
in_h = keras_layer.input_shape[1]
in_w = keras_layer.input_shape[2]
pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
params['padding'] = [pad_t, pad_l, pad_b, pad_r]
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
if pool_type == 'MaxPooling2D':
return _op.nn.max_pool2d(inexpr, **params)
elif pool_type == 'AveragePooling2D':
params['count_include_pad'] = False
return _op.nn.avg_pool2d(inexpr, **params)
else:
raise TypeError("Unsupported pooling type : {}".format(keras_layer))
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
if pool_type == 'MaxPooling2D':
return _op.nn.max_pool2d(inexpr, **params)
if pool_type == 'AveragePooling2D':
params['count_include_pad'] = False
return _op.nn.avg_pool2d(inexpr, **params)
raise TypeError("Unsupported pooling type : {}".format(keras_layer))
def _convert_upsample(inexpr, keras_layer, _):
......
......@@ -39,7 +39,7 @@ def _mx_fully_connected(inputs, attrs):
def _get_channel_axis(layout, op_name):
if layout == "NCHW":
return 1
elif layout == "NHWC":
if layout == "NHWC":
return 3
raise RuntimeError("layout: {} is not supported in {}".format(layout, op_name))
......@@ -49,11 +49,11 @@ def _mx_activations(inputs, attrs):
assert len(inputs) == 1
if act_type == "sigmoid":
return _op.sigmoid(inputs[0])
elif act_type == "tanh":
if act_type == "tanh":
return _op.tanh(inputs[0])
elif act_type == "relu":
if act_type == "relu":
return _op.nn.relu(inputs[0])
elif act_type == "softrelu":
if act_type == "softrelu":
def _stable_softrelu(x):
# log(1 + exp(-abs(x))) + relu(x)
one = _expr.const(1, dtype="float32")
......@@ -147,7 +147,7 @@ def _mx_pooling(inputs, attrs):
if global_pool:
return _op.nn.global_max_pool2d(inputs[0])
return _pool2d(_op.nn.max_pool2d, False)
elif pool_type == "avg":
if pool_type == "avg":
if global_pool:
return _op.nn.global_avg_pool2d(inputs[0])
return _pool2d(_op.nn.avg_pool2d, True)
......@@ -209,10 +209,10 @@ def _mx_leaky_relu(inputs, attrs):
act_type = attrs.get_str("act_type")
if act_type == "leaky":
return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25))
elif act_type == "prelu":
if act_type == "prelu":
assert len(inputs) == 2
return _op.nn.prelu(*inputs)
elif act_type == "elu":
if act_type == "elu":
# -slope * relu(1-exp(x)) + relu(x)
slope = attrs.get_float("slope", 0.25)
one = _expr.const(1, dtype="float32")
......@@ -220,7 +220,7 @@ def _mx_leaky_relu(inputs, attrs):
mslope = _op.nn.relu(_op.subtract(one, _op.exp(x)))
mslope = _op.multiply(mslope, _expr.const(-slope, dtype="float32"))
return _op.add(mslope, _op.nn.relu(x))
elif act_type == "rrelu":
if act_type == "rrelu":
# NOTE this is only converted for inference.
lower_bound = attrs.get_float("lower_bound")
upper_bound = attrs.get_float("upper_bound")
......
......@@ -18,8 +18,7 @@ def dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
raise NotImplementedError("Only 2d kernel supported.")
return _impl
......
......@@ -175,8 +175,7 @@ def _dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def _dimension_constraint():
......@@ -522,8 +521,7 @@ def _reshape():
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
else:
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
return _impl
def _bias_add():
......@@ -1385,7 +1383,7 @@ class GraphProto(object):
shape=self._params[name].shape,
dtype=self._params[name].dtype)]
else:
if key != 'dtype' and key != '_output_shapes' and key != '_class':
if key not in ('dtype', '_output_shapes', '_class'):
raise NotImplementedError \
("Other attributes for a Const(param) Node {} ? .".format(key))
......
......@@ -126,15 +126,14 @@ class OperatorConverter(object):
if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
elif tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
elif tensor_wrapper.tensor.Type() == TensorType.INT32:
if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
else:
raise NotImplementedError("Not support tensor type {}"
.format(str(tensor_wrapper.tensor.Type())))
raise NotImplementedError("Not support tensor type {}"
.format(str(tensor_wrapper.tensor.Type())))
def get_tensor_type_str(self, tensor_type):
"""Get tensor type string representation when given TFLite tensor type"""
......@@ -145,12 +144,11 @@ class OperatorConverter(object):
if tensor_type == TensorType.UINT8:
return "uint8"
elif tensor_type == TensorType.FLOAT32:
if tensor_type == TensorType.FLOAT32:
return "float32"
elif tensor_type == TensorType.INT32:
if tensor_type == TensorType.INT32:
return "int32"
else:
raise NotImplementedError("Not support tensor type {}".format(str(tensor_type)))
raise NotImplementedError("Not support tensor type {}".format(str(tensor_type)))
def convert_conv2d(self, op):
"""Convert TFLite conv2d"""
......@@ -192,7 +190,7 @@ class OperatorConverter(object):
in_expr = self.get_expr(input_tensor_idx)
if input_shape_length == 1 or input_shape_length == 2:
if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
......@@ -275,7 +273,7 @@ class OperatorConverter(object):
in_expr = self.get_expr(input_tensor_idx)
# TFLite is N H W C, our layout is N C H W
if input_shape_length == 1 or input_shape_length == 2:
if input_shape_length in (1, 2):
# The rule is channel first (after N but before H, W).
# length of 1 means N*H*W*C, do nothing.
# length of 2 means N*H*W, C, do nothing.
......@@ -299,7 +297,7 @@ class OperatorConverter(object):
# 3: N H W C, reshape to N H*W C, transpose to N C H*W
# 4: N H W C, transpose to N C H W
# add more if we need target shapes in future
if output_shape_length == 1 or output_shape_length == 2:
if output_shape_length in (1, 2):
pass
elif output_shape_length == 3:
out = _op.transpose(out, axes=(0, 2, 1))
......@@ -320,16 +318,15 @@ class OperatorConverter(object):
assert fused_activation_fn != ActivationFunctionType.NONE
if fused_activation_fn == ActivationFunctionType.RELU6:
return _op.clip(in_expr, a_min=0, a_max=6)
elif fused_activation_fn == ActivationFunctionType.RELU:
if fused_activation_fn == ActivationFunctionType.RELU:
return _op.nn.relu(in_expr)
elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
return _op.clip(in_expr, a_min=-1, a_max=1)
elif fused_activation_fn == ActivationFunctionType.TANH:
if fused_activation_fn == ActivationFunctionType.TANH:
return _op.tanh(in_expr)
else:
fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
raise NotImplementedError("Unsupported fused activation fn {}"
.format(fused_activation_fn_str))
fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
raise NotImplementedError("Unsupported fused activation fn {}"
.format(fused_activation_fn_str))
def convert_conv(self, op, conv_type):
"""convolution implementation."""
......@@ -401,7 +398,7 @@ class OperatorConverter(object):
# weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type == TensorType.UINT8 or weight_tensor_type == TensorType.FLOAT32
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
in_expr = self.get_expr(input_tensor_idx)
......@@ -434,7 +431,7 @@ class OperatorConverter(object):
bias_tensor = input_tensors[2]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type == TensorType.INT32 or bias_tensor_type == TensorType.FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
dtype=bias_tensor_type_str)
......
......@@ -57,7 +57,7 @@ def compute_conv2d(attrs, inputs, out_type, target):
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "")
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
assert layout in ["NCHW", "NHWC", "NCHW4c"]
......@@ -95,15 +95,15 @@ def schedule_conv2d(attrs, outs, target):
with target:
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
elif groups == 1 and layout == "NCHW4c":
if groups == 1 and layout == "NCHW4c":
return topi.generic.schedule_conv2d_nchw(outs)
elif groups == 1 and layout == "NHWC":
if groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs)
elif groups != 1:
if groups != 1:
if layout == "NCHW":
# TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
return topi.generic.schedule_depthwise_conv2d_nchw(outs)
elif layout == "NHWC" and kernel_layout == "HWOI":
if layout == "NHWC" and kernel_layout == "HWOI":
return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
raise ValueError("No compatible schedule")
......@@ -127,7 +127,7 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
groups = attrs.groups
layout = attrs.data_layout
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "")
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
......
......@@ -6,19 +6,18 @@ from ..base import register_relay_attr_node
@register_relay_attr_node
class Conv2DAttrs(Attrs):
"""Attribute of nn.conv2d"""
pass
@register_relay_attr_node
class Conv2DWinogradAttrs(Attrs):
"""Attribute of nn.contrib_conv2d_winograd_without_weight_transform"""
pass
@register_relay_attr_node
class Conv2DWinogradWeightTransformAttrs(Attrs):
"""Attribute of nn.contrib_conv2d_winograd_weight_transform"""
pass
@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
"""Attribute of nn.global_pool"""
pass
......@@ -29,11 +29,10 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
def Pooling(data, kernel, stride, pad, pool_type, name):
if pool_type == 'max':
return relay.nn.max_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad)
elif pool_type == 'avg':
if pool_type == 'avg':
return relay.nn.avg_pool2d(data=data, pool_size=kernel, strides=stride, padding=pad,
count_include_pad=True)
else:
raise ValueError("Invalid pooling type: " + pool_type)
raise ValueError("Invalid pooling type: " + pool_type)
def Inception7A(data,
num_1x1,
......
......@@ -172,7 +172,6 @@ class TypeCall(Type):
@register_relay_node
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
pass
@register_relay_node
......
......@@ -389,7 +389,7 @@ class ProxyServerHandler(object):
if key in pool_src:
self._pair_up(pool_src.pop(key), handler)
return
elif key not in pool_dst:
if key not in pool_dst:
pool_dst[key] = handler
def cleanup():
"""Cleanup client connection if timeout"""
......
......@@ -95,9 +95,8 @@ class TCPHandler(object):
if msg:
self.on_message(msg)
return True
else:
# normal close, remote is closed
self.close()
# normal close, remote is closed
self.close()
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
pass
......
......@@ -86,7 +86,7 @@ class Scheduler(object):
value: object
The resource to remove
"""
pass
def summary(self):
"""Get summary information of the scheduler."""
......
......@@ -143,19 +143,16 @@ class Buffer(NodeBase):
@register_node
class Split(NodeBase):
"""Split operation on axis."""
pass
@register_node
class Fuse(NodeBase):
"""Fuse operation on axis."""
pass
@register_node
class Singleton(NodeBase):
"""Singleton axis."""
pass
@register_node
......
......@@ -381,7 +381,7 @@ def stmt_list(stmt):
"""
if isinstance(stmt, Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest)
elif isinstance(stmt, ProducerConsumer):
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]
......
......@@ -33,7 +33,6 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
@register_node
class TensorIntrinCall(NodeBase):
"""Intermediate structure for calling a tensor intrinsic."""
pass
itervar_cls = None
......@@ -144,7 +143,6 @@ class Operation(NodeBase):
@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
pass
@register_node
......@@ -164,7 +162,6 @@ class ComputeOp(Operation):
@register_node
class TensorComputeOp(Operation):
"""Tensor operation."""
pass
@register_node
......@@ -179,7 +176,7 @@ class ScanOp(Operation):
@register_node
class ExternOp(Operation):
"""Extern operation."""
pass
@register_node
class HybridOp(Operation):
......
......@@ -61,7 +61,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC"
assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
if dorefa:
assert layout == "NCHW", "Cannot support dorea with NHWC layout yet"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
......
......@@ -554,7 +554,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"]
if out_dtype == "" or out_dtype == "same":
if out_dtype in ("same", ""):
out_dtype = tinfos[0].dtype
if layout != 'NCHW':
......
......@@ -93,10 +93,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
if layout == 'NCHW':
return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
elif layout == 'HWCN':
if layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout))
raise ValueError("not support this layout {} yet".format(layout))
@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, ["cuda", "gpu"],
......
......@@ -362,7 +362,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"]
if out_dtype == "" or out_dtype == "same":
if out_dtype in ("", "same"):
out_dtype = tinfos[0].dtype
data, kernel = tinfos[0:2]
......@@ -428,7 +428,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
)
dispatch_ctx.update(target, new_workload, cfg)
return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
elif groups != CI:
if groups != CI:
workload = autotvm.task.args_to_workload(
[tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype],
group_conv2d_nchw)
......
......@@ -96,7 +96,7 @@ def schedule_reduce(outs):
"""Internal travserse function"""
if isinstance(operator, tvm.tensor.PlaceholderOp):
return
elif tag.is_injective(operator.tag):
if tag.is_injective(operator.tag):
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op not in scheduled_ops:
......
......@@ -92,14 +92,14 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits
if layout == 'NCHW':
return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
elif layout == 'NHWC':
if layout == 'NHWC':
return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa)
raise ValueError("not support this layout {} yet".format(layout))
def _get_workload(data, kernel, stride, padding, out_dtype, layout):
""" Get the workload structure. """
assert layout == "NCHW" or layout == "NHWC", \
assert layout in ("NCHW", "NHWC"), \
"Only support layouts NCHW and NHWC"
if layout == "NCHW":
_, CI, IH, IW = [x.value for x in data.shape]
......
......@@ -48,12 +48,11 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
# default declaration
if layout == 'NCHW':
return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
elif layout == 'HWCN':
if layout == 'HWCN':
return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
elif layout == 'NHWC':
if layout == 'NHWC':
return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout))
raise ValueError("not support this layout {} yet".format(layout))
@tvm.target.generic_func
......
......@@ -17,12 +17,11 @@ def upsampling_python(data, scale, layout='NCHW'):
for c in range(oshape[1]):
output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
return output_np
elif layout == 'NHWC':
if layout == 'NHWC':
oshape = (ishape[0], ishape[1]*scale, ishape[1]*scale, ishape[3])
output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]):
for c in range(oshape[3]):
output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale)
return output_np
else:
raise ValueError("not support this layout {} yet".format(layout))
raise ValueError("not support this layout {} yet".format(layout))
......@@ -59,7 +59,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC"
assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"
wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
......
......@@ -71,12 +71,11 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out
_get_default_config(cfg, data, kernel, strides, padding, out_dtype)
return _declaration_conv_impl(cfg, data, kernel, strides,
padding, dilation, layout, out_dtype)
elif layout == 'HWCN':
if layout == 'HWCN':
return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
elif layout == 'NHWC':
if layout == 'NHWC':
return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
else:
raise ValueError("not support this layout {} yet".format(layout))
raise ValueError("not support this layout {} yet".format(layout))
def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
......
......@@ -223,10 +223,9 @@ class Environment(object):
"""The target host"""
if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf"
elif self.TARGET == "sim":
if self.TARGET == "sim":
return "llvm"
else:
raise ValueError("Unknown target %s" % self.TARGET)
raise ValueError("Unknown target %s" % self.TARGET)
def get_env():
......
......@@ -169,7 +169,7 @@ def clean_cast(graph):
op_name = node.attr("op_name")
if op_name == "cast":
return _clean_cast(node.get_children(), target_type)
elif op_name == "relu":
if op_name == "relu":
data, has_clip = _clean_cast(
node.get_children(), target_type)
data = nnvm.sym.relu(data)
......
......@@ -64,7 +64,7 @@ def gemm(env, mock=False):
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2:
if index in (0, 2):
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 0,
......
......@@ -77,10 +77,9 @@ def fold_uop_loop(stmt_in):
args.append(m[1])
args += op.args[base_args+3:]
return tvm.call_extern("int32", "VTAUopPush", *args)
else:
if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
raise RuntimeError("unexpected op %s" % op)
return op
if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
raise RuntimeError("unexpected op %s" % op)
return op
ret = tvm.ir_pass.IRTransform(
stmt.body, None, _post_order, ["Call"])
......@@ -165,22 +164,21 @@ def cpu_access_rewrite(stmt_in):
op.condition, let_stmt)
del rw_info[buffer_var]
return alloc
elif isinstance(op, tvm.expr.Load):
if isinstance(op, tvm.expr.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.make.Load(op.dtype, new_var, op.index)
elif isinstance(op, tvm.stmt.Store):
if isinstance(op, tvm.stmt.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.make.Store(new_var, op.value, op.index)
else:
raise RuntimeError("not reached")
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
......@@ -233,23 +231,20 @@ def lift_alloc_to_scope_begin(stmt_in):
if op.attr_key == "virtual_thread":
lift_stmt.append([])
return None
def _post_order(op):
if isinstance(op, tvm.stmt.Allocate):
lift_stmt[-1].append(op)
return op.body
elif isinstance(op, tvm.stmt.AttrStmt):
if isinstance(op, tvm.stmt.AttrStmt):
if op.attr_key == "storage_scope":
lift_stmt[-1].append(op)
return op.body
elif op.attr_key == "virtual_thread":
if op.attr_key == "virtual_thread":
return _merge_block(lift_stmt.pop() + [op], op.body)
return op
elif isinstance(op, tvm.stmt.For):
if isinstance(op, tvm.stmt.For):
return _merge_block(lift_stmt.pop() + [op], op.body)
else:
raise RuntimeError("not reached")
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1
......@@ -297,7 +292,7 @@ def inject_coproc_sync(stmt_in):
sync = tvm.make.Call(
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
elif _match_pragma(stmt, "trim_loop"):
if _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.stmt.For)
return tvm.make.For(
......@@ -584,7 +579,7 @@ def annotate_alu_coproc_scope(stmt_in):
tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt)
return irb.get()
elif _match_pragma(stmt, "skip_alu"):
if _match_pragma(stmt, "skip_alu"):
return tvm.make.Evaluate(0)
return stmt
......
......@@ -193,7 +193,7 @@ def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", target_host=target_host)
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target)
......@@ -279,10 +279,9 @@ def schedule_conv2d(attrs, outs, target):
target = tvm.target.create(target)
if target.device_name == "vta":
return schedule_packed_conv2d(outs)
elif str(target).startswith("llvm"):
if str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target)
raise RuntimeError("not support target %s" % target)
return _nn.schedule_conv2d(attrs, outs, target)
......