Commit 89c124bc by Tianqi Chen

Update tvm, fix lint due to pylint update (#423)

parent e722dbcb
# pylint: disable=invalid-name, unused-argument
"""CoreML frontend."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .. import symbol as _sym
from .common import SymbolTable
...@@ -77,6 +78,7 @@ def ConvolutionLayerParams(op, insym, symtab):
    return ret

def BatchnormLayerParams(op, insym, symtab):
    """Get layer of batchnorm parameter"""
    # this changes the symbol
    if op.instanceNormalization:
        raise NotImplementedError("instance normalization not implemented")
...@@ -89,6 +91,7 @@ def BatchnormLayerParams(op, insym, symtab):
    return _sym.batch_norm(data=insym, **params)

def ActivationParams(op, insym, symtab):
    """Get activation parameters"""
    whichActivation = op.WhichOneof('NonlinearityType')
    par = getattr(op, whichActivation)
    if whichActivation == 'linear':
...@@ -129,6 +132,8 @@ def ActivationParams(op, insym, symtab):
        betasym = symtab.new_const(beta)
        return _sym.broadcast_mul(_sym.log(_sym.broadcast_add(
            _sym.exp(insym), betasym)), alphasym)
    else:
        raise NotImplementedError('%s not implemented' % whichActivation)
def ScaleLayerParams(op, insym, symtab):
    """Scale layer params."""
...@@ -144,6 +149,7 @@ def ScaleLayerParams(op, insym, symtab):
    return ret

def PoolingLayerParams(op, insym, symtab):
    """get pooling parameters"""
    if op.globalPooling:
        if op.type == 0:
            return _sym.global_max_pool2d(insym)
......
...@@ -3,6 +3,7 @@ import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment