Commit f4789db6 by Yao Wang Committed by Tianqi Chen

Add optimizer (#334)

parent bd40bcd1
...@@ -28,7 +28,6 @@ This level enables fully connected multi-layer perceptron. ...@@ -28,7 +28,6 @@ This level enables fully connected multi-layer perceptron.
:nosignatures: :nosignatures:
nnvm.symbol.dense nnvm.symbol.dense
nnvm.symbol.matmul
nnvm.symbol.relu nnvm.symbol.relu
nnvm.symbol.tanh nnvm.symbol.tanh
nnvm.symbol.sigmoid nnvm.symbol.sigmoid
...@@ -40,12 +39,6 @@ This level enables fully connected multi-layer perceptron. ...@@ -40,12 +39,6 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.elemwise_mul nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div nnvm.symbol.elemwise_div
nnvm.symbol.elemwise_sum nnvm.symbol.elemwise_sum
nnvm.symbol.full
nnvm.symbol.full_like
nnvm.symbol.ones
nnvm.symbol.ones_like
nnvm.symbol.zeros
nnvm.symbol.zeros_like
nnvm.symbol.flatten nnvm.symbol.flatten
nnvm.symbol.concatenate nnvm.symbol.concatenate
nnvm.symbol.expand_dims nnvm.symbol.expand_dims
...@@ -57,7 +50,6 @@ This level enables fully connected multi-layer perceptron. ...@@ -57,7 +50,6 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.log_softmax nnvm.symbol.log_softmax
nnvm.symbol.pad nnvm.symbol.pad
nnvm.symbol.block_grad nnvm.symbol.block_grad
nnvm.symbol.indicator
**Level 2: Convolutions** **Level 2: Convolutions**
...@@ -81,8 +73,6 @@ This level enables typical convnet models. ...@@ -81,8 +73,6 @@ This level enables typical convnet models.
:nosignatures: :nosignatures:
nnvm.symbol.reshape nnvm.symbol.reshape
nnvm.symbol.reshape_like
nnvm.symbol.expand_like
nnvm.symbol.copy nnvm.symbol.copy
nnvm.symbol.negative nnvm.symbol.negative
nnvm.symbol.leaky_relu nnvm.symbol.leaky_relu
...@@ -109,11 +99,21 @@ This level enables typical convnet models. ...@@ -109,11 +99,21 @@ This level enables typical convnet models.
nnvm.symbol.broadcast_sub nnvm.symbol.broadcast_sub
nnvm.symbol.broadcast_mul nnvm.symbol.broadcast_mul
nnvm.symbol.broadcast_div nnvm.symbol.broadcast_div
nnvm.symbol.clip
nnvm.symbol.greater
nnvm.symbol.less
nnvm.symbol.expand_like
nnvm.symbol.reshape_like
nnvm.symbol.full
nnvm.symbol.full_like
nnvm.symbol.ones
nnvm.symbol.ones_like
nnvm.symbol.zeros
nnvm.symbol.zeros_like
Detailed Definitions Detailed Definitions
-------------------- --------------------
.. autofunction:: nnvm.symbol.dense .. autofunction:: nnvm.symbol.dense
.. autofunction:: nnvm.symbol.matmul
.. autofunction:: nnvm.symbol.relu .. autofunction:: nnvm.symbol.relu
.. autofunction:: nnvm.symbol.tanh .. autofunction:: nnvm.symbol.tanh
.. autofunction:: nnvm.symbol.sigmoid .. autofunction:: nnvm.symbol.sigmoid
...@@ -125,12 +125,6 @@ Detailed Definitions ...@@ -125,12 +125,6 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.elemwise_mul .. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div .. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.elemwise_sum .. autofunction:: nnvm.symbol.elemwise_sum
.. autofunction:: nnvm.symbol.full
.. autofunction:: nnvm.symbol.full_like
.. autofunction:: nnvm.symbol.ones
.. autofunction:: nnvm.symbol.ones_like
.. autofunction:: nnvm.symbol.zeros
.. autofunction:: nnvm.symbol.zeros_like
.. autofunction:: nnvm.symbol.flatten .. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate .. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims .. autofunction:: nnvm.symbol.expand_dims
...@@ -142,7 +136,6 @@ Detailed Definitions ...@@ -142,7 +136,6 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.log_softmax .. autofunction:: nnvm.symbol.log_softmax
.. autofunction:: nnvm.symbol.pad .. autofunction:: nnvm.symbol.pad
.. autofunction:: nnvm.symbol.block_grad .. autofunction:: nnvm.symbol.block_grad
.. autofunction:: nnvm.symbol.indicator
.. autofunction:: nnvm.symbol.conv2d .. autofunction:: nnvm.symbol.conv2d
.. autofunction:: nnvm.symbol.conv2d_transpose .. autofunction:: nnvm.symbol.conv2d_transpose
...@@ -152,8 +145,6 @@ Detailed Definitions ...@@ -152,8 +145,6 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.global_avg_pool2d .. autofunction:: nnvm.symbol.global_avg_pool2d
.. autofunction:: nnvm.symbol.reshape .. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.reshape_like
.. autofunction:: nnvm.symbol.expand_like
.. autofunction:: nnvm.symbol.copy .. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.negative .. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.leaky_relu .. autofunction:: nnvm.symbol.leaky_relu
...@@ -175,3 +166,14 @@ Detailed Definitions ...@@ -175,3 +166,14 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.broadcast_sub .. autofunction:: nnvm.symbol.broadcast_sub
.. autofunction:: nnvm.symbol.broadcast_mul .. autofunction:: nnvm.symbol.broadcast_mul
.. autofunction:: nnvm.symbol.broadcast_div .. autofunction:: nnvm.symbol.broadcast_div
.. autofunction:: nnvm.symbol.clip
.. autofunction:: nnvm.symbol.greater
.. autofunction:: nnvm.symbol.less
.. autofunction:: nnvm.symbol.expand_like
.. autofunction:: nnvm.symbol.reshape_like
.. autofunction:: nnvm.symbol.full
.. autofunction:: nnvm.symbol.full_like
.. autofunction:: nnvm.symbol.ones
.. autofunction:: nnvm.symbol.ones_like
.. autofunction:: nnvm.symbol.zeros
.. autofunction:: nnvm.symbol.zeros_like
...@@ -241,6 +241,16 @@ struct MatMulParam : public dmlc::Parameter<MatMulParam> { ...@@ -241,6 +241,16 @@ struct MatMulParam : public dmlc::Parameter<MatMulParam> {
} }
}; };
struct ClipParam : public dmlc::Parameter<ClipParam> {
double a_min, a_max;
DMLC_DECLARE_PARAMETER(ClipParam) {
DMLC_DECLARE_FIELD(a_min)
.describe("Minimum value such that value smaller then this will be clipped.");
DMLC_DECLARE_FIELD(a_max)
.describe("Maximum value such that value larger then this will be clipped.");
}
};
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
......
...@@ -54,6 +54,9 @@ OpHandle = ctypes.c_void_p ...@@ -54,6 +54,9 @@ OpHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p GraphHandle = ctypes.c_void_p
# Global dict of str to symbol to initialize variables
_all_var_init = {}
#---------------------------- #----------------------------
# helper function definition # helper function definition
#---------------------------- #----------------------------
......
...@@ -4,9 +4,12 @@ from __future__ import absolute_import as _abs ...@@ -4,9 +4,12 @@ from __future__ import absolute_import as _abs
import logging import logging
import tvm import tvm
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from . import graph_attr, graph_util from . import graph_attr, graph_util
from .. import graph as _graph from .. import graph as _graph
from .. import symbol as sym
from .._base import _all_var_init
OPT_PASS_LEVEL = { OPT_PASS_LEVEL = {
"SimplifyInference": 0, "SimplifyInference": 0,
...@@ -201,6 +204,9 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -201,6 +204,9 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
initialize : bool, optional
Whether to initialize variables in global dict _all_var_init.
Returns Returns
------- -------
graph : Graph graph : Graph
...@@ -230,6 +236,10 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -230,6 +236,10 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
if not isinstance(dtype, str): if not isinstance(dtype, str):
idtype, _ = graph_util.infer_dtype(graph, **dtype) idtype, _ = graph_util.infer_dtype(graph, **dtype)
dtype.update(zip(graph.index.input_names, idtype)) dtype.update(zip(graph.index.input_names, idtype))
# Initialize all variables specified in _all_var_init
init_var = {}
if _all_var_init:
init_var = initialize_variables(shape, dtype)
# Apply optimization # Apply optimization
graph = optimize(graph, shape, dtype) graph = optimize(graph, shape, dtype)
# Precompute prune # Precompute prune
...@@ -250,6 +260,11 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -250,6 +260,11 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
with target: with target:
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile") graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module") libmod = graph_attr._move_out_module(graph, "module")
# Write variable initial values into params
if init_var:
if params is None:
params = {}
params.update(init_var)
return graph, libmod, params return graph, libmod, params
...@@ -329,3 +344,45 @@ def precompute_prune(graph, params): ...@@ -329,3 +344,45 @@ def precompute_prune(graph, params):
with tvm.build_config(auto_unroll_max_step=0): with tvm.build_config(auto_unroll_max_step=0):
out_arrs = _run_graph(pre_graph, params) out_arrs = _run_graph(pre_graph, params)
return graph, dict(zip(out_names, out_arrs)) return graph, dict(zip(out_names, out_arrs))
def initialize_variables(ishape, idtype):
""" Initialize variables stored in _all_var_init dictionary.
Parameters
----------
ishape : dict of str to tuple of int
The input shape to the graph
idtype : str or dict of str to str
The input types to the graph
Returns
-------
init_var : dict of str to tvm.ndarray
Dict mapping variable names to their initial values.
"""
symbol_init_dict = {}
const_init_dict = {}
init_var = {}
for key, value in _all_var_init.items():
if isinstance(value, sym.Symbol):
symbol_init_dict[key] = value
else:
const_init_dict[key] = tvm.nd.array(value)
# Make sure variables are initialized only once.
_all_var_init.clear()
if symbol_init_dict:
# Create dummy params to run initialization graph
params = {}
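# Dummy values suffice here: symbolic initializers such as zeros_like only
# need the input shapes and dtypes, not actual data.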
for name, shape in ishape.items():
dtype = idtype if isinstance(idtype, str) else idtype[name]
params[name] = tvm.nd.empty(shape, dtype, ctx=tvm.cpu())
init_group_sym = sym.Group(symbol_init_dict.values())
graph = _graph.create(init_group_sym)
with tvm.build_config(auto_unroll_max_step=0):
init_values = _run_graph(graph, params)
init_var.update(dict(zip(symbol_init_dict.keys(), init_values)))
init_var.update(const_init_dict)
for name, data in init_var.items():
ishape[name] = data.shape
return init_var
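For orientation, a minimal sketch (not part of the commit; assumes an llvm target is available, and the names x, w, y are illustrative) of how Variable(init=...), _all_var_init and build() fit together:

# Hedged sketch: Variable(init=...) records the initializer in _all_var_init,
# and nnvm.compiler.build() folds it into the returned params dict.
import numpy as np
import nnvm
import nnvm.symbol as sym

x = sym.Variable("x")
w = sym.Variable("w", init=np.ones((1, 3), "float32"))   # constant initializer
y = sym.elemwise_mul(x, w)
graph, lib, params = nnvm.compiler.build(y, "llvm", shape={"x": (1, 3)})
# params["w"] now holds the initial value produced by initialize_variables().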
# pylint: disable=too-few-public-methods, no-member
"""API for scheduling learning rate."""
from .. import symbol as sym
class LRScheduler(object):
"""Base class of a learning rate scheduler.
A scheduler returns a new learning rate based on the number of updates that have
been performed.
Parameters
----------
base_lr : float, optional
The initial learning rate.
"""
def __init__(self, base_lr=0.01, name='LRScheduler'):
self.name = name
self.base_lr = base_lr
def __call__(self, num_update):
"""Return a new learning rate based on number of updates.
Parameters
----------
num_update: nnvm Symbol
the number of updates applied to weight.
"""
raise NotImplementedError("__call__ method must be overridden.")
class FactorScheduler(LRScheduler):
"""Reduce the learning rate by a factor for every *n* steps.
It returns a new learning rate by::
base_lr * pow(factor, num_update/step)
Parameters
----------
step : int
Change the learning rate every `step` updates.
factor : float, optional
The factor by which the learning rate is multiplied.
stop_factor_lr : float, optional
Stop updating the learning rate once it is less than this value.
"""
def __init__(self, step, factor=1, stop_factor_lr=1e-8, name='FactorScheduler', **kwargs):
super(FactorScheduler, self).__init__(name=name, **kwargs)
if step < 1:
raise ValueError("Schedule step must be greater or equal than 1 round")
if factor > 1.0:
raise ValueError("Factor must be no more than 1 to make lr reduce")
self.step = step
self.factor = factor
self.stop_factor_lr = stop_factor_lr
def __call__(self, num_update):
updated_lr = self.base_lr * self.factor ** (num_update / self.step)
return sym.clip(updated_lr, a_min=self.stop_factor_lr, a_max=self.base_lr)
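For intuition, the resulting schedule evaluated with plain numbers (illustrative only; in NNVM num_update is a symbolic expression):

# Illustrative only: evaluating the FactorScheduler formula with concrete numbers.
base_lr, factor, step = 0.1, 0.5, 1
for num_update in range(4):
    print(base_lr * factor ** (num_update / step))  # 0.1, 0.05, 0.025, 0.0125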
# pylint: disable=invalid-name, no-member, too-few-public-methods, too-many-arguments, too-many-locals, protected-access
"""Optimizer API"""
from . import graph_util
from .. import symbol as sym
class Optimizer(object):
"""Base class inherited by all optimizers.
Parameters
----------
learning_rate : float, optional
The initial learning rate.
lr_scheduler : LRScheduler, optional
The learning rate scheduler.
rescale_grad : float, optional
Multiply the gradient with `rescale_grad` before updating. Often
chosen to be ``1.0/batch_size``.
clip_gradient : float, optional
Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
wd : float, optional
The weight decay (or L2 regularization) coefficient. Modifies the objective
by adding a penalty for having large weights.
name : string, optional
The name of the optimizer.
"""
def __init__(self, learning_rate=0.01, lr_scheduler=None,
rescale_grad=1, clip_gradient=None, wd=0, name="Optimizer"):
self.name = name
self.lr = learning_rate
self.lr_scheduler = lr_scheduler
self.rescale_grad = rescale_grad
self.clip_gradient = clip_gradient
self.wd = wd
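# update_t is a symbolic step counter: a variable initialized to zero whose
# _assign node increments it by one whenever it runs as part of the update graph.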
init_update_t = sym.Variable(name+'_t', init=sym.zeros(shape=(1,), dtype="int32"))
self.update_t = sym._assign(init_update_t, init_update_t + 1)
def minimize(self, obj, var=None):
"""Minimize given obj symbol respect to var. If var is not set, all input
variables of obj will be used.
Parameters
----------
obj : nnvm Symbol or list of nnvm Symbols
Symbols to be minimized.
var : nnvm Symbol or list of nnvm Symbols, optional
Symbols to take the gradient with respect to.
Returns
-------
group_sym : nnvm Symbol
Group symbol representing the update symbols.
"""
raise NotImplementedError()
def _get_lr(self):
"""Gets the learning rate with learning rate scheduler.
Returns
-------
lr : float
Learning rate.
"""
if self.lr_scheduler is not None:
lr = self.lr_scheduler(self.update_t)
else:
lr = self.lr
return lr
class SGD(Optimizer):
"""The SGD optimizer
"""
def __init__(self, name='SGD', **kwargs):
super(SGD, self).__init__(name=name, **kwargs)
def minimize(self, obj, var=None):
variables = var or obj.list_input_variables()
if not isinstance(variables, list):
variables = [variables]
grads = graph_util.gradients(obj, variables)
updates = []
lr_t = self._get_lr()
for v, g in zip(variables, grads):
g = self.rescale_grad * g
if self.clip_gradient is not None:
g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
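# SGD update with weight decay: v <- v - lr * (g + wd * v)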
updates.append(sym._assign(v, v - lr_t * (g + self.wd * v)))
return sym.Group(updates)
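For orientation, a short usage sketch (not part of the commit; the symbols are illustrative and follow the pattern of test_sgd further below):

import nnvm
import nnvm.compiler.optimizer as optimizer

data = nnvm.sym.Variable("data")
weight = nnvm.sym.Variable("weight")
loss = nnvm.sym.elemwise_mul(data, weight ** 2)

opt = optimizer.SGD(learning_rate=0.1, rescale_grad=0.2, wd=0.1, clip_gradient=0.25)
update_sym = opt.minimize(loss, var=weight)  # Group of _assign update symbols
# update_sym can then be compiled with nnvm.compiler.build and executed with the
# TVM graph runtime, as test_sgd below demonstrates.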
class Adam(Optimizer):
"""The Adam optimizer.
This class implements the optimizer described in *Adam: A Method for
Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999,
epsilon=1e-8, name='Adam', **kwargs):
super(Adam, self).__init__(learning_rate=learning_rate, name=name, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = []
self.v = []
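# First and second moment estimates; one Variable per optimized parameter, created in minimize().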
def minimize(self, obj, var=None):
variables = var or obj.list_input_variables()
if not isinstance(variables, list):
variables = [variables]
grads = graph_util.gradients(obj, variables)
updates = []
for i, v in enumerate(variables):
self.m.append(sym.Variable(self.name + '_m' + str(i), init=sym.zeros_like(v)))
self.v.append(sym.Variable(self.name + '_v' + str(i), init=sym.zeros_like(v)))
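# Bias correction folded into the learning rate: sqrt(1 - beta2^t) / (1 - beta1^t)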
rate = sym.sqrt(1 - self.beta2 ** self.update_t) / (1 - self.beta1 ** self.update_t)
lr_t = self._get_lr() * rate
for variable, g, m, v in zip(variables, grads, self.m, self.v):
g = self.rescale_grad * g
if self.clip_gradient is not None:
g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
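# Exponential moving averages of the gradient (m) and of the squared gradient (v)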
update_m = sym._assign(m, self.beta1 * m + (1 - self.beta1) * g)
update_v = sym._assign(v, self.beta2 * v + (1 - self.beta2) * g * g)
update_var = sym._assign(variable, variable - lr_t * (update_m / (sym.sqrt(update_v) \
+ self.epsilon) + self.wd * variable))
updates.append(update_var)
return sym.Group(updates)
# pylint: disable=invalid-name, unused-import # pylint: disable=invalid-name, unused-import, protected-access
"""Symbolic graph construction API. """Symbolic graph construction API.
This namespace contains most of the registered operators. This namespace contains most of the registered operators.
...@@ -8,10 +8,12 @@ from __future__ import absolute_import as _abs ...@@ -8,10 +8,12 @@ from __future__ import absolute_import as _abs
import sys as _sys import sys as _sys
import os as _os import os as _os
import ctypes as _ctypes import ctypes as _ctypes
from numbers import Number as _Number from numbers import Number as _Number
import numpy as np
from . import _base from . import _base
from ._base import _LIB, check_call as _check_call, _FFI_MODE from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
from .attribute import AttrScope from .attribute import AttrScope
from . import _symbol_internal as _internal from . import _symbol_internal as _internal
...@@ -309,13 +311,19 @@ class Symbol(SymbolBase): ...@@ -309,13 +311,19 @@ class Symbol(SymbolBase):
self.handle, deps.handle)) self.handle, deps.handle))
def Variable(name, **kwargs): def Variable(name, init=None, **kwargs):
"""Create a symbolic variable with specified name. """Create a symbolic variable with specified name.
Parameters Parameters
---------- ----------
name : str name : str
Name of the variable. Name of the variable.
init : Symbol or numpy.ndarray, optional
Initial value for the variable, given as a Symbol or a numpy ndarray.
A symbolic initial value must be shape-inferable through InferShape,
e.g. sym.zeros_like(v) where v is an input or parameter; otherwise,
pass a numpy ndarray instead.
kwargs : dict of string -> string kwargs : dict of string -> string
Additional attributes to set on the variable. Additional attributes to set on the variable.
...@@ -333,6 +341,11 @@ def Variable(name, **kwargs): ...@@ -333,6 +341,11 @@ def Variable(name, **kwargs):
attr = AttrScope.current.get(kwargs) attr = AttrScope.current.get(kwargs)
if attr: if attr:
ret._set_attr(**attr) ret._set_attr(**attr)
if init is not None:
if not isinstance(init, (Symbol, np.ndarray)):
raise TypeError('Expect a Symbol or numpy ndarray '
'for variable `init`')
_all_var_init[name] = init
return ret return ret
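As a hedged illustration of the new init argument (the variable names here are invented for the example):

import numpy as np
import nnvm.symbol as sym

v = sym.Variable("v", shape=(3, 3))
m = sym.Variable("m", init=sym.zeros_like(v))        # symbolic initial value
t = sym.Variable("t", init=np.zeros((1,), "int32"))  # constant initial value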
......
...@@ -123,6 +123,21 @@ class AttrDict(object): ...@@ -123,6 +123,21 @@ class AttrDict(object):
else: else:
raise ValueError("Wrong bool format for key %s" % key) raise ValueError("Wrong bool format for key %s" % key)
def get_string(self, key):
"""Get string from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : str
The result value
"""
return self[key]
def __repr__(self): def __repr__(self):
return str({k : self[k] for k in self.keys()}) return str({k : self[k] for k in self.keys()})
......
...@@ -143,3 +143,95 @@ reg.register_schedule("broadcast_div", _fschedule_broadcast) ...@@ -143,3 +143,95 @@ reg.register_schedule("broadcast_div", _fschedule_broadcast)
# broadcast_to # broadcast_to
reg.register_pattern("broadcast_to", OpPattern.BROADCAST) reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast) reg.register_schedule("broadcast_to", _fschedule_broadcast)
# clip
reg.register_pattern("clip", OpPattern.ELEMWISE)
reg.register_schedule("clip", _fschedule_elemwise)
# elemwise sum
@reg.register_compute("elemwise_sum")
def compute_elemwise_sum(attrs, inputs, _):
"""Compute definition of elemwise sum"""
num_args = attrs.get_int("num_args")
assert num_args == len(inputs), "Number of tensors does not match num_args."
return topi.tensor.elemwise_sum(inputs, num_args)
reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE)
reg.register_schedule("elemwise_sum", _fschedule_elemwise)
# full
@reg.register_compute("full")
def compute_full(attrs, inputs, _):
"""Compute definition of full"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
fill_value = attrs.get_float("fill_value")
return topi.tensor.full(shape, dtype, fill_value)
reg.register_pattern("full", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("full", _fschedule_elemwise)
# full_like
@reg.register_compute("full_like")
def compute_full_like(attrs, inputs, _):
"""Compute definition of full_like"""
fill_value = attrs.get_float("fill_value")
return topi.tensor.full_like(inputs[0], fill_value)
reg.register_pattern("full_like", OpPattern.ELEMWISE)
reg.register_schedule("full_like", _fschedule_elemwise)
# zeros
@reg.register_compute("zeros")
def compute_zeros(attrs, inputs, _):
"""Compute definition of zeros"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
return topi.tensor.full(shape, dtype, 0)
reg.register_pattern("zeros", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("zeros", _fschedule_elemwise)
# zeros_like
@reg.register_compute("zeros_like")
def compute_zeros_like(_, inputs, out_info):
"""Compute definition of zeros_like"""
return topi.tensor.full_like(inputs[0], 0)
reg.register_pattern("zeros_like", OpPattern.ELEMWISE)
reg.register_schedule("zeros_like", _fschedule_elemwise)
# ones
@reg.register_compute("ones")
def compute_ones(attrs, inputs, _):
"""Compute definition of ones"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
return topi.tensor.full(shape, dtype, 1)
reg.register_pattern("ones", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("ones", _fschedule_elemwise)
# ones_like
@reg.register_compute("ones_like")
def compute_ones_like(_, inputs, out_info):
"""Compute definition of ones_like"""
return topi.tensor.full_like(inputs[0], 1)
reg.register_pattern("ones_like", OpPattern.ELEMWISE)
reg.register_schedule("ones_like", _fschedule_elemwise)
# greater
@reg.register_compute("greater")
def compute_greater(_, inputs, out_info):
"""Compute definition of greater"""
return topi.tensor.greater(inputs[0], inputs[1], 'float32')
reg.register_pattern("greater", OpPattern.ELEMWISE)
reg.register_schedule("greater", _fschedule_elemwise)
# less
@reg.register_compute("less")
def compute_less(_, inputs, out_info):
"""Compute definition of less"""
return topi.tensor.less(inputs[0], inputs[1], 'float32')
reg.register_pattern("less", OpPattern.ELEMWISE)
reg.register_schedule("less", _fschedule_elemwise)
# block_grad
reg.register_compute("block_grad", _compute_unary(topi.identity))
reg.register_pattern("block_grad", OpPattern.ELEMWISE)
reg.register_schedule("block_grad", _fschedule_elemwise)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""Tensor transformation ops""" """Tensor transformation ops"""
from __future__ import absolute_import from __future__ import absolute_import
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg from . import registry as reg
from .registry import OpPattern from .registry import OpPattern
...@@ -10,6 +11,32 @@ from .registry import OpPattern ...@@ -10,6 +11,32 @@ from .registry import OpPattern
reg.register_pattern("expand_dims", OpPattern.BROADCAST) reg.register_pattern("expand_dims", OpPattern.BROADCAST)
reg.register_schedule("expand_dims", _fschedule_broadcast) reg.register_schedule("expand_dims", _fschedule_broadcast)
# expand_like
@reg.register_compute("expand_like")
def compute_expand_like(attrs, inputs, _):
"""Compute definition of expand_like"""
exclude = attrs.get_bool("exclude")
axis = attrs.get_int_tuple("axis")
if exclude:
exclude_axis = (axis,) if isinstance(axis, int) else axis
axis = []
for item in range(len(inputs[1].shape)):
if item not in exclude_axis:
axis.append(item)
axis = tuple(axis)
return topi.transform.expand_like(inputs[0], inputs[1], axis)
reg.register_pattern("expand_like", OpPattern.BROADCAST)
reg.register_schedule("expand_like", _fschedule_broadcast)
# reshape_like
@reg.register_compute("reshape_like")
def compute_reshape_like(attrs, inputs, out_info):
"""Compute definition of reshape_like"""
return topi.reshape(inputs[0], inputs[1].shape)
reg.register_pattern("reshape_like", OpPattern.INJECTIVE)
reg.register_schedule("reshape_like", _fschedule_injective)
# transpose # transpose
reg.register_pattern("transpose", OpPattern.INJECTIVE) reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective) reg.register_schedule("transpose", _fschedule_injective)
......
...@@ -130,15 +130,14 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu) ...@@ -130,15 +130,14 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
// y = relu(x) // y = relu(x)
// grad = indicator(x > 0) // grad = indicator(x > 0) * ograd
NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero", NodeEntry sub0 = MakeNode("zeros_like", n->attrs.name + "_sub0",
{n->inputs[0]}); {n->inputs[0]});
NodeEntry sub1 = MakeNode("greater", n->attrs.name + "_sub1",
{n->inputs[0], sub0}, {{"exclude", "true"}});
return std::vector<NodeEntry>{ return std::vector<NodeEntry>{
MakeNode("elemwise_mul", n->attrs.name + "_grad", { MakeNode("elemwise_mul", n->attrs.name + "_grad",
ograds[0], {ograds[0], sub1})
MakeNode("greater", n->attrs.name + "_grad_mask",
{n->inputs[0], zero}, {{"exclude", "true"}})
})
}; };
}) })
.set_support_level(1); .set_support_level(1);
...@@ -358,23 +357,21 @@ NNVM_REGISTER_OP(log_softmax) ...@@ -358,23 +357,21 @@ NNVM_REGISTER_OP(log_softmax)
// grad_x = sum(grad_x, keepdim, axis) // grad_x = sum(grad_x, keepdim, axis)
// grad_x = neg grad_x // grad_x = neg grad_x
// grad_x = grad_x + ones_like(grad_x) // grad_x = grad_x + ones_like(grad_x)
// grad_x = expand_dims(grad_x, axis)
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed); const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed);
NodeEntry output = NodeEntry{n, 0, 0}; NodeEntry output = NodeEntry{n, 0, 0};
NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output}); NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output});
NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0}, NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0},
{{"axis", std::to_string(param.axis)}, {"keepdims", "true"}}); {{"axis", std::to_string(param.axis)}, {"keepdims", "true"}});
NodeEntry sub2 = MakeNode("negative", n->attrs.name + "_grad_sub2", {sub1}); NodeEntry sub2 = MakeNode("full_like", n->attrs.name + "_grad_sub2", {n->inputs[0]},
NodeEntry sub3 = MakeNode("ones_like", n->attrs.name + "_grad_sub3", {sub2}); {{"fill_value", "-1"}});
NodeEntry sub4 = MakeNode("elemwise_add", n->attrs.name + "_grad_sub4", {sub2, sub3}); NodeEntry sub3 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub3", {sub1, sub2});
return std::vector<NodeEntry> { return std::vector<NodeEntry> {
MakeNode("expand_like", n->attrs.name + "_grad", {sub4, output}, MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, ograds[0]})
{{"axis", std::to_string(param.axis)}})
}; };
}) })
.set_support_level(1); .set_support_level(1);
// leaky_rlu // leaky_relu
DMLC_REGISTER_PARAMETER(LeakyReLUParam); DMLC_REGISTER_PARAMETER(LeakyReLUParam);
NNVM_REGISTER_OP(leaky_relu) NNVM_REGISTER_OP(leaky_relu)
...@@ -407,14 +404,15 @@ NNVM_REGISTER_OP(leaky_relu) ...@@ -407,14 +404,15 @@ NNVM_REGISTER_OP(leaky_relu)
NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero", NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
{n->inputs[0]}); {n->inputs[0]});
NodeEntry sub0 = MakeNode("greater", n->attrs.name + "_pos_grad", NodeEntry sub0 = MakeNode("greater", n->attrs.name + "_pos_grad",
{n->inputs[0], zero}, {{"exclude", "true"}}); {n->inputs[0], zero});
NodeEntry sub1 = MakeNode("less", n->attrs.name + "_neg_grad", NodeEntry sub1 = MakeNode("less", n->attrs.name + "_neg_grad",
{n->inputs[0], zero}, {{"exclude", "true"}}); {n->inputs[0], zero});
NodeEntry sub2 = MakeNode("__mul_scalar__", n->attrs.name + "_neg_mul_2", NodeEntry sub2 = MakeNode("__mul_scalar__", n->attrs.name + "_neg_mul_2",
{sub1}, {sub1},
{{"scalar", std::to_string(param.alpha)}}); {{"scalar", std::to_string(param.alpha)}});
NodeEntry sub3 = MakeNode("elemwise_add", n->attrs.name + "_sub3", {sub0, sub2});
return std::vector<NodeEntry>{ return std::vector<NodeEntry>{
MakeNode("elemwise_add", n->attrs.name + "_add_grad", {sub0, sub2}) MakeNode("elemwise_mul", n->attrs.name + "_grad", {ograds[0], sub3})
}; };
}) })
.set_support_level(1); .set_support_level(1);
......
...@@ -190,7 +190,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add) ...@@ -190,7 +190,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
// y = n0 + n1 // y = n0 + n1
// grad_0 = grad_y // grad_0 = grad_y
// grad_1 = grad_y // grad_1 = grad_y
return std::vector<NodeEntry>{ograds[0], ograds[0]}; return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}),
MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}) };
}); });
NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub) NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
...@@ -311,7 +314,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy) ...@@ -311,7 +314,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
const std::vector<NodeEntry>& ograds){ const std::vector<NodeEntry>& ograds){
// y = copy(n0) // y = copy(n0)
// grad_0 = grad_y // grad_0 = grad_y
return std::vector<NodeEntry>{ograds[0]}; return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}) };
}); });
DMLC_REGISTER_PARAMETER(InitOpParam); DMLC_REGISTER_PARAMETER(InitOpParam);
...@@ -329,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full) ...@@ -329,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full)
.add_arguments(InitOpWithScalarParam::__FIELDS__()) .add_arguments(InitOpWithScalarParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_support_level(1); .set_support_level(4);
NNVM_REGISTER_INIT_OP(zeros) NNVM_REGISTER_INIT_OP(zeros)
.describe(R"code(Fill target with zeros .describe(R"code(Fill target with zeros
...@@ -341,7 +345,7 @@ NNVM_REGISTER_INIT_OP(zeros) ...@@ -341,7 +345,7 @@ NNVM_REGISTER_INIT_OP(zeros)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_support_level(1); .set_support_level(4);
NNVM_REGISTER_INIT_OP(ones) NNVM_REGISTER_INIT_OP(ones)
.describe(R"code(Fill target with ones .describe(R"code(Fill target with ones
...@@ -353,7 +357,7 @@ NNVM_REGISTER_INIT_OP(ones) ...@@ -353,7 +357,7 @@ NNVM_REGISTER_INIT_OP(ones)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_support_level(1); .set_support_level(4);
// full_like // full_like
NNVM_REGISTER_INIT_LIKE_OP(full_like) NNVM_REGISTER_INIT_LIKE_OP(full_like)
...@@ -364,21 +368,21 @@ as the input array ...@@ -364,21 +368,21 @@ as the input array
.add_arguments(FillValueParam::__FIELDS__()) .add_arguments(FillValueParam::__FIELDS__())
.set_attr_parser(ParamParser<FillValueParam>) .set_attr_parser(ParamParser<FillValueParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>)
.set_support_level(1); .set_support_level(4);
NNVM_REGISTER_INIT_LIKE_OP(zeros_like) NNVM_REGISTER_INIT_LIKE_OP(zeros_like)
.describe(R"code(Return an array of zeros with the same shape and type .describe(R"code(Return an array of zeros with the same shape and type
as the input array. as the input array.
)code") )code")
.set_support_level(1); .set_support_level(4);
NNVM_REGISTER_INIT_LIKE_OP(ones_like) NNVM_REGISTER_INIT_LIKE_OP(ones_like)
.describe(R"code(Return an array of ones with the same shape and type .describe(R"code(Return an array of ones with the same shape and type
as the input array. as the input array.
)code") )code")
.set_support_level(1); .set_support_level(4);
// unary scalar op // unary scalar op
DMLC_REGISTER_PARAMETER(ScalarParam); DMLC_REGISTER_PARAMETER(ScalarParam);
...@@ -415,7 +419,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__) ...@@ -415,7 +419,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){ const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{ograds[0]}; return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}) };
}); });
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
...@@ -601,10 +606,11 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum) ...@@ -601,10 +606,11 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum)
CHECK_EQ(ograds.size(), 1); CHECK_EQ(ograds.size(), 1);
std::vector<NodeEntry> ret; std::vector<NodeEntry> ret;
for (size_t i = 0; i < n->inputs.size(); i++) { for (size_t i = 0; i < n->inputs.size(); i++) {
ret.push_back(ograds[0]); ret.push_back(MakeNode("copy", n->attrs.name + "_grad_0", {ograds[0]}));
} }
return ret; return ret;
}); })
.set_support_level(4);
NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad) NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad)
.describe(R"code(Blocks gradient computation for input. .describe(R"code(Blocks gradient computation for input.
...@@ -614,7 +620,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad) ...@@ -614,7 +620,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad)
"FInplaceIdentity", [](const NodeAttrs& attrs){ "FInplaceIdentity", [](const NodeAttrs& attrs){
return std::vector<bool>{true}; return std::vector<bool>{true};
}) })
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes); .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_support_level(4);
DMLC_REGISTER_PARAMETER(IndicatorParam); DMLC_REGISTER_PARAMETER(IndicatorParam);
...@@ -628,7 +635,7 @@ with 1.0 if (left > right), otherwise 0.0 element-wise. ...@@ -628,7 +635,7 @@ with 1.0 if (left > right), otherwise 0.0 element-wise.
.add_argument("rhs", "Tensor", "Second input") .add_argument("rhs", "Tensor", "Second input")
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_support_level(1); .set_support_level(4);
NNVM_REGISTER_INDICATOR_OP(less) NNVM_REGISTER_INDICATOR_OP(less)
...@@ -640,7 +647,7 @@ with 1.0 if (left < right), otherwise 0.0 element-wise. ...@@ -640,7 +647,7 @@ with 1.0 if (left < right), otherwise 0.0 element-wise.
.add_argument("rhs", "Tensor", "Second input") .add_argument("rhs", "Tensor", "Second input")
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_support_level(1); .set_support_level(4);
NNVM_REGISTER_INDICATOR_OP(_max_mask) NNVM_REGISTER_INDICATOR_OP(_max_mask)
.describe(R"code(Function that returns a mask tensor .describe(R"code(Function that returns a mask tensor
...@@ -668,5 +675,73 @@ with 1.0 if the value is minimum over given axes, otherwise 0.0 element-wise. ...@@ -668,5 +675,73 @@ with 1.0 if the value is minimum over given axes, otherwise 0.0 element-wise.
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_support_level(1); .set_support_level(1);
DMLC_REGISTER_PARAMETER(ClipParam);
NNVM_REGISTER_OP(clip)
.describe(R"doc(Clips (limits) the values in an array.
Given an interval, values outside the interval are clipped to the interval edges.
Clipping ``x`` between ``a_min`` and ``a_max`` would be::
clip(x, a_min, a_max) = max(min(x, a_max), a_min)
Example::
x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
clip(x,1,8) = [ 1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ClipParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ClipParam params = get<ClipParam>(attrs.parsed);
return Array<Tensor>{
topi::clip(inputs[0], tvm::make_const(tvm::Float(32), params.a_min),
tvm::make_const(tvm::Float(32), params.a_max)) };
})
.add_argument("data", "NDArray-or-Symbol", "Input array.")
.add_arguments(ClipParam::__FIELDS__())
.set_attr<nnvm::FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
// y = clip(x, a_min, a_max)
// min_mask = greater_equal(x, a_min*ones_like(x))
// => ones_like(x) - less(x, a_min)
// max_mask = less_equal(x, a_max*ones_like(x))
// => ones_like(x) - greater(x, a_max)
// grad_x = min_mask * max_mask * grad_y
CHECK_EQ(ograds.size(), 1);
NodeEntry sub0 = MakeNode("ones_like", n->attrs.name + "_grad_sub_0",
{n->inputs[0]});
// min_mask
NodeEntry sub1 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_1",
{sub0}, {{"scalar", n->attrs.dict["a_min"]}});
NodeEntry sub2 = MakeNode("less", n->attrs.name + "_grad_sub_2",
{n->inputs[0], sub1});
NodeEntry sub3 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_3",
{sub0, sub2});
// max_mask
NodeEntry sub4 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_4",
{sub0}, {{"scalar", n->attrs.dict["a_max"]}});
NodeEntry sub5 = MakeNode("greater", n->attrs.name + "_grad_sub_5",
{n->inputs[0], sub4});
NodeEntry sub6 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_6",
{sub0, sub5});
// min_mask * max_mask
NodeEntry sub7 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_7",
{sub3, sub6});
return std::vector<NodeEntry>{
MakeNode("elemwise_mul", n->attrs.name + "_grad",
{sub7, ograds[0]})
};
})
.set_support_level(4);
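As a numpy-only sanity check of this gradient (illustrative, not part of the commit), the two masks reduce to grad_x = (x >= a_min) * (x <= a_max) * grad_y:

import numpy as np

x = np.array([-2.0, 0.5, 3.0])
a_min, a_max = 0.0, 1.0
grad_y = np.ones_like(x)
min_mask = 1.0 - (x < a_min)           # ones_like(x) - less(x, a_min)
max_mask = 1.0 - (x > a_max)           # ones_like(x) - greater(x, a_max)
grad_x = min_mask * max_mask * grad_y  # -> [0., 1., 0.]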
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
...@@ -137,7 +137,20 @@ Example:: ...@@ -137,7 +137,20 @@ Example::
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed); const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis); Array<Expr> axis;
if (param.exclude) {
std::set<dim_t> exclude_axis;
for (dim_t i = 0; i < param.axis.ndim(); ++i) {
exclude_axis.insert(param.axis[i]);
}
for (dim_t i = 0; i < inputs[0].ndim(); ++i) {
if (exclude_axis.count(i) == 0) {
axis.push_back(make_const(Int(32), i));
}
}
} else {
axis = ShapeToArray(param.axis);
}
return Array<Tensor>{ return Array<Tensor>{
topi::sum(inputs[0], axis, param.keepdims) }; topi::sum(inputs[0], axis, param.keepdims) };
}) })
...@@ -150,7 +163,6 @@ Example:: ...@@ -150,7 +163,6 @@ Example::
MakeNode("expand_like", n->attrs.name + "_grad", MakeNode("expand_like", n->attrs.name + "_grad",
{ograds[0], n->inputs[0]}, {ograds[0], n->inputs[0]},
{{"axis", axis.str()}, {{"axis", axis.str()},
{"keepdims", std::to_string(param.keepdims)},
{"exclude", std::to_string(param.exclude)}}) {"exclude", std::to_string(param.exclude)}})
}; };
}); });
......
...@@ -48,6 +48,15 @@ This is an experimental operator. ...@@ -48,6 +48,15 @@ This is an experimental operator.
.set_attr<FInplaceOption>( .set_attr<FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) { "FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{1, 0}}; return std::vector<std::pair<int, int> >{{1, 0}};
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
MakeNode("zeros_like", n->attrs.name + "_zero_grad",
{n->inputs[0]}),
ograds[0]
};
}); });
} // namespace top } // namespace top
......
...@@ -229,29 +229,24 @@ will return a new array with shape ``(2,5,3,4)``. ...@@ -229,29 +229,24 @@ will return a new array with shape ``(2,5,3,4)``.
NNVM_REGISTER_OP(expand_like) NNVM_REGISTER_OP(expand_like)
.describe(R"code(Expand an input array with the shape of second array. .describe(R"code(Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and expanding dims. This operation can always be composed of unsqueezing and expanding dims.
Examples:: Examples::
input = [ 12. 19. 27.] input = [ 12. 19. 27.]
input.shape = (3,) input.shape = (3,)
new_shape_array = [[[1,2],[2,3],[1,3]], new_shape_array = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]], [[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]] [[7,1],[7,2],[7,3]]]
new_shape_array.shape = (3, 3, 2) new_shape_array.shape = (3, 3, 2)
expand_like(input, [1,2], new_shape_array) = expand_like(input, [1,2], new_shape_array) =
[[[12,12],[12,12],[12,12]], [[[12,12],[12,12],[12,12]],
[[19,19],[19,19],[19,19]], [[19,19],[19,19],[19,19]],
[[27,27],[27,27],[27,27]]] [[27,27],[27,27],[27,27]]]
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("input", "Tensor", "Source input") .add_argument("input", "Tensor", "Source input")
.add_argument("shape_like", "Tensor", "Input with new shape") .add_argument("shape_like", "Tensor", "Input with new shape")
.add_arguments(ReduceParam::__FIELDS__()) .add_arguments(IndicatorParam::__FIELDS__())
.set_attr_parser(ParamParser<ReduceParam>) .set_attr_parser(ParamParser<IndicatorParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>)
.set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>) .set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_num_inputs(2) .set_num_inputs(2)
...@@ -259,7 +254,7 @@ Examples:: ...@@ -259,7 +254,7 @@ Examples::
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed); const IndicatorParam& param = nnvm::get<IndicatorParam>(n->attrs.parsed);
std::ostringstream axis; std::ostringstream axis;
axis << param.axis; axis << param.axis;
...@@ -267,11 +262,11 @@ Examples:: ...@@ -267,11 +262,11 @@ Examples::
MakeNode("sum", n->attrs.name + "_grad", MakeNode("sum", n->attrs.name + "_grad",
{ograds[0]}, {ograds[0]},
{{"axis", axis.str()}, {{"axis", axis.str()},
{"keepdims", std::to_string(param.keepdims)}, {"exclude", std::to_string(param.exclude)}}),
{"exclude", std::to_string(param.exclude)}}) MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]})
}; };
}) })
.set_support_level(1); .set_support_level(4);
// split // split
DMLC_REGISTER_PARAMETER(SplitParam); DMLC_REGISTER_PARAMETER(SplitParam);
...@@ -564,13 +559,10 @@ The significance of each is explained below: ...@@ -564,13 +559,10 @@ The significance of each is explained below:
NNVM_REGISTER_OP(reshape_like) NNVM_REGISTER_OP(reshape_like)
.describe(R"code(Reshapes the input array by the size of another array. .describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array. the input array into an output array with the same shape as the second input array.
.. note:: .. note::
Sizes for both arrays should be compatible.
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.add_argument("shape_like", "Tensor", "Input data.") .add_argument("shape_like", "Tensor", "Input data.")
...@@ -589,10 +581,12 @@ the input array into an output array with the same shape as the second input arr ...@@ -589,10 +581,12 @@ the input array into an output array with the same shape as the second input arr
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
return MakeGradNode("reshape_like", n, return std::vector<NodeEntry>{
{ograds[0], n->inputs[0]}); MakeNode("reshape_like", n->attrs.name + "_grad", {ograds[0], n->inputs[0]}),
MakeNode("zeros_like", n->attrs.name + "_zero_grad", { n->inputs[1]})
};
}) })
.set_support_level(3); .set_support_level(4);
// squeeze // squeeze
DMLC_REGISTER_PARAMETER(SqueezeParam); DMLC_REGISTER_PARAMETER(SqueezeParam);
...@@ -680,7 +674,8 @@ Examples:: ...@@ -680,7 +674,8 @@ Examples::
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{ return std::vector<NodeEntry>{
MakeNode("reshape_like", n->attrs.name + "_grad", {n->inputs[0]}) MakeNode("reshape_like", n->attrs.name + "_grad",
{ograds[0], n->inputs[0]})
}; };
}) })
.set_support_level(1); .set_support_level(1);
......
import numpy as np
import tvm
import nnvm
import nnvm.compiler.optimizer as optimizer
import nnvm.compiler.lr_scheduler as lr_scheduler
from nnvm.testing.config import ctx_list
from tvm.contrib import graph_runtime
def helper(symbol, inputs, params, update_func, run_times, target, ctx, dtype="float32"):
ishapes = {}
np_inputs = {}
params_dict = {}
for (name, shape, s) in inputs:
ishapes.update({name: shape})
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
for (name, shape, s) in params:
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
params_dict.update({name: np_inputs[name]})
graph, lib, rt_params = nnvm.compiler.build(symbol, target, shape=ishapes)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**np_inputs)
m.set_input(**rt_params)
for _ in range(run_times):
m.run()
y_np = update_func(**np_inputs)
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_sgd():
for target, ctx in ctx_list():
data = nnvm.sym.Variable("data")
weight = nnvm.sym.Variable("weight")
out = nnvm.sym.elemwise_mul(data, weight ** 2)
dshape = (1, 2, 3)
wshape = dshape
base_lr = 0.1
lr_factor = 0.5
rescale_grad = 0.2
wd = 0.1
clip_gradient = 0.25
scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor)
opt = optimizer.SGD(learning_rate=base_lr, lr_scheduler=scheduler,
rescale_grad=rescale_grad, clip_gradient=clip_gradient,
wd=wd)
opt_sym = opt.minimize(out, var=weight)
inputs = [("data", dshape, data)]
params = [("weight", wshape, weight)]
def update_func(data, weight):
gradient_0 = data * 2 * weight * rescale_grad
gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient)
weight_0 = weight - base_lr * lr_factor * (gradient_0 + wd * weight)
gradient_1 = data * 2 * weight_0 * rescale_grad
gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient)
weight_1 = weight_0 - base_lr * (lr_factor ** 2) * (gradient_1 + wd * weight_0)
return weight_1
helper(opt_sym, inputs, params, update_func, 2, target, ctx)
def test_adam():
for target, ctx in ctx_list():
data = nnvm.sym.Variable("data")
weight = nnvm.sym.Variable("weight")
out = nnvm.sym.elemwise_mul(data, weight ** 2)
dshape = (1, 2, 3)
wshape = dshape
base_lr = 0.1
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
lr_factor = 0.5
rescale_grad = 0.2
wd = 0.1
clip_gradient = 0.25
scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor)
opt = optimizer.Adam(learning_rate=base_lr, beta1=beta1, beta2=beta2, epsilon=epsilon,
lr_scheduler=scheduler, rescale_grad=rescale_grad,
clip_gradient=clip_gradient, wd=wd)
opt_sym = opt.minimize(out, var=weight)
inputs = [("data", dshape, data)]
params = [("weight", wshape, weight)]
def update_func(data, weight):
rate_0 = np.sqrt(1 - beta2) / (1 - beta1)
lr_0 = base_lr * lr_factor * rate_0
gradient_0 = data * 2 * weight * rescale_grad
gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient)
m_0 = (1 - beta1) * gradient_0
v_0 = (1 - beta2) * (gradient_0 ** 2)
weight_0 = weight - lr_0 * (m_0 / (np.sqrt(v_0) + epsilon) + wd * weight)
rate_1 = np.sqrt(1 - beta2 ** 2) / (1 - beta1 ** 2)
lr_1 = base_lr * (lr_factor ** 2) * rate_1
gradient_1 = data * 2 * weight_0 * rescale_grad
gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient)
m_1 = beta1 * m_0 + (1 - beta1) * gradient_1
v_1 = beta2 * v_0 + (1 - beta2) * (gradient_1 ** 2)
weight_1 = weight_0 - lr_1 * (m_1 / (np.sqrt(v_1) + epsilon) + wd * weight_0)
return weight_1
helper(opt_sym, inputs, params, update_func, 2, target, ctx)
if __name__ == "__main__":
test_sgd()
test_adam()
...@@ -8,15 +8,14 @@ from nnvm.testing.config import ctx_list ...@@ -8,15 +8,14 @@ from nnvm.testing.config import ctx_list
def helper(symbol, inputs, dtype, def helper(symbol, inputs, dtype,
np_forward, np_backward=None): np_forward, np_backward=None, need_input=True, need_head_grads=True):
ishapes = {} ishapes = {}
input_syms = [] input_syms = []
np_inputs = {} np_inputs = {}
for (k, v) in inputs.items(): for (name, shape, s) in inputs:
ishapes.update({k: v[0]}) ishapes.update({name: shape})
np_inputs.update({k: np.random.uniform(size=v[0]).astype(dtype)}) np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
if len(v) > 1: input_syms.append(s)
input_syms.append(v[1])
for target, ctx in ctx_list(): for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes) graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
...@@ -25,23 +24,26 @@ def helper(symbol, inputs, dtype, ...@@ -25,23 +24,26 @@ def helper(symbol, inputs, dtype,
y_np = np_forward(**np_inputs) y_np = np_forward(**np_inputs)
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype)) out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
# backward # backward
if np_backward: if np_backward:
graph._set_symbol_list_attr("grad_ys", symbol) graph._set_symbol_list_attr("grad_ys", symbol)
for x in input_syms: graph._set_symbol_list_attr("grad_xs", input_syms)
graph._set_symbol_list_attr("grad_xs", x) graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape))
graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads"))
graph = graph.apply("Gradient") graph = graph.apply("Gradient")
ishapes.update({"head_grads": y_np.shape}) ishapes.update({"head_grads": y_np.shape})
graph, lib, _ = nnvm.compiler.build(graph, target, ishapes) graph, lib, _ = nnvm.compiler.build(graph, target, ishapes)
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
head_grads = np.random.uniform(size=y_np.shape).astype(dtype) head_grads = np.random.uniform(size=y_np.shape).astype(dtype)
y_np = head_grads * np_backward(**np_inputs) y_np = np_backward(head_grads=head_grads, **np_inputs)
m.run(head_grads=head_grads, **np_inputs) b_inputs = {}
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype)) if need_input:
b_inputs.update(np_inputs)
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) if need_head_grads:
b_inputs.update({"head_grads":head_grads})
m.run(**b_inputs)
for i in range(len(y_np)):
out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5)
def test_relu(): def test_relu():
...@@ -52,10 +54,15 @@ def test_relu(): ...@@ -52,10 +54,15 @@ def test_relu():
x = (x < 0) * x * 0.3 + (x > 0) * x - 0.2 x = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
return (x > 0) * x return (x > 0) * x
def backward(head_grads, x):
sub = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
return [(sub > 0).astype("float") * \
((x > 0).astype("float") + 0.3 * (x < 0).astype("float")) * head_grads]
dtype = "float32" dtype = "float32"
dshape = (1, 3, 32, 32) dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward) helper(y, inputs, dtype, forward, backward)
def test_sym_scalar_pow(): def test_sym_scalar_pow():
...@@ -66,12 +73,12 @@ def test_sym_scalar_pow(): ...@@ -66,12 +73,12 @@ def test_sym_scalar_pow():
def forward(x): def forward(x):
return x**scalar return x**scalar
def backward(x): def backward(head_grads, x):
return scalar * x**(scalar - 1) return [scalar * x**(scalar - 1) * head_grads]
dtype = "float32" dtype = "float32"
dshape = (1, 3, 32, 32) dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward) helper(y, inputs, dtype, forward, backward)
...@@ -83,12 +90,12 @@ def test_scalar_sym_pow(): ...@@ -83,12 +90,12 @@ def test_scalar_sym_pow():
def forward(x): def forward(x):
return scalar**x return scalar**x
def backward(x): def backward(head_grads, x):
return np.log(scalar) * scalar**x return [np.log(scalar) * scalar**x * head_grads]
dtype = "float32" dtype = "float32"
dshape = (1, 3, 32, 32) dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward) helper(y, inputs, dtype, forward, backward)
...@@ -99,12 +106,12 @@ def test_exp(): ...@@ -99,12 +106,12 @@ def test_exp():
def forward(x): def forward(x):
return np.exp(x) return np.exp(x)
def backward(x): def backward(head_grads, x):
return np.exp(x) return [np.exp(x) * head_grads]
dtype = "float32" dtype = "float32"
dshape = (1, 3, 32, 32) dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward) helper(y, inputs, dtype, forward, backward)
...@@ -115,12 +122,12 @@ def test_log(): ...@@ -115,12 +122,12 @@ def test_log():
def forward(x): def forward(x):
return np.log(x) return np.log(x)
def backward(x): def backward(head_grads, x):
return 1. / x return [1. / x * head_grads]
dtype = "float32" dtype = "float32"
dshape = (1, 3, 32, 32) dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward) helper(y, inputs, dtype, forward, backward)
...@@ -131,13 +138,13 @@ def test_tanh(): ...@@ -131,13 +138,13 @@ def test_tanh():
def forward(x): def forward(x):
return np.sinh(x) / np.cosh(x) return np.sinh(x) / np.cosh(x)
def backward(x): def backward(head_grads, x):
y_np = forward(x) y_np = forward(x)
return (1 - y_np**2) return [(1 - y_np**2) * head_grads]
dtype = "float32" dtype = "float32"
dshape = (1, 3, 32, 32) dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward) helper(y, inputs, dtype, forward, backward)
...@@ -148,13 +155,13 @@ def test_sigmoid(): ...@@ -148,13 +155,13 @@ def test_sigmoid():
def forward(x): def forward(x):
return 1.0 / (1.0 + np.exp(-x)) return 1.0 / (1.0 + np.exp(-x))
def backward(x): def backward(head_grads, x):
y_np = forward(x) y_np = forward(x)
return y_np *(1 - y_np) return [y_np *(1 - y_np) * head_grads]
dtype = "float32" dtype = "float32"
dshape = (1, 3, 32, 32) dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward) helper(y, inputs, dtype, forward, backward)
...@@ -165,10 +172,15 @@ def test_softmax(): ...@@ -165,10 +172,15 @@ def test_softmax():
def forward(x): def forward(x):
return topi.testing.softmax_python(x) return topi.testing.softmax_python(x)
def backward(head_grads, x):
y = topi.testing.softmax_python(x)
grad = y * (head_grads - np.sum(y * head_grads, axis=1, keepdims=True))
return [grad]
dtype = "float32" dtype = "float32"
dshape = (10, 1000) dshape = (10, 1000)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward) helper(y, inputs, dtype, forward, backward)
def test_log_softmax(): def test_log_softmax():
...@@ -178,26 +190,32 @@ def test_log_softmax(): ...@@ -178,26 +190,32 @@ def test_log_softmax():
def forward(x): def forward(x):
return topi.testing.log_softmax_python(x) return topi.testing.log_softmax_python(x)
def backward(head_grads, x):
y = topi.testing.log_softmax_python(x)
grad = head_grads - np.exp(y) * np.sum(head_grads, axis=1, keepdims=True)
return [grad]
dtype = "float32" dtype = "float32"
dshape = (10, 1000) dshape = (10, 1000)
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward) helper(y, inputs, dtype, forward, backward)
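For reference, the two closed-form rules used above are: for softmax output y, grad_x = y * (g - sum(y * g)); for log_softmax, grad_x = g - softmax(x) * sum(g). A small standalone numpy version (illustrative only, not part of the commit):

import numpy as np

def softmax_np(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

def softmax_grad_np(g, x):
    # adjoint of softmax along the last axis
    y = softmax_np(x)
    return y * (g - np.sum(y * g, axis=-1, keepdims=True))

def log_softmax_grad_np(g, x):
    # adjoint of log_softmax: g - softmax(x) * sum(g)
    return g - softmax_np(x) * np.sum(g, axis=-1, keepdims=True)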
def test_dense(): def test_dense():
x = sym.Variable("x") x = sym.Variable("x", shape=(10, 100))
y = sym.dense(x, units=3, name="dense") w = sym.Variable("dense_weight", shape=(3, 100))
b = sym.Variable("dense_bias", shape=(3,))
y = sym.dense(x, w, b, use_bias=True, units=3, name="dense")
y = sym.flatten(y) y = sym.flatten(y)
def forward(x, dense_weight, dense_bias): def forward(x, dense_weight, dense_bias):
return np.dot(x, dense_weight.T) + dense_bias return np.dot(x, dense_weight.T) + dense_bias
dtype = "float32" dtype = "float32"
inputs = { inputs = [
'x': ((10, 100), x), ('x', (10, 100), x),
'dense_weight': ((3, 100),), ('dense_weight', (3, 100), w),
'dense_bias': ((3,),) ('dense_bias', (3,), b)
} ]
helper(y, inputs, dtype, forward) helper(y, inputs, dtype, forward)
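The dense case above only checks the forward path (no backward is passed to helper). If a gradient check were added, the numpy rules would look roughly like this sketch (dense_backward is hypothetical and not part of this commit):

import numpy as np

def dense_backward(head_grads, x, dense_weight, dense_bias):
    # y = x.dot(W.T) + b, so:
    #   dx = g.dot(W), dW = g.T.dot(x), db = g summed over the batch axis
    return [np.dot(head_grads, dense_weight),
            np.dot(head_grads.T, x),
            np.sum(head_grads, axis=0)]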
...@@ -215,13 +233,13 @@ def test_batchnorm(): ...@@ -215,13 +233,13 @@ def test_batchnorm():
return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta
dtype = "float32" dtype = "float32"
inputs = { inputs = [
'x': ((10, 20), x), ('x', (10, 20), x),
'gamma': ((20,),), ('gamma', (20,), gamma),
'beta': ((20,),), ('beta', (20,), beta),
'moving_mean': ((20,),), ('moving_mean', (20,), moving_mean),
'moving_var': ((20,),) ('moving_var', (20,), moving_var)
} ]
helper(y, inputs, dtype, forward) helper(y, inputs, dtype, forward)
...@@ -283,9 +301,12 @@ def verify_squeeze(dshape, axis): ...@@ -283,9 +301,12 @@ def verify_squeeze(dshape, axis):
def forward(x): def forward(x):
return np.squeeze(x, axis=axis) + 1 return np.squeeze(x, axis=axis) + 1
def backward(head_grads, x):
return [np.reshape(head_grads, x.shape)]
dtype = "float32" dtype = "float32"
inputs = {'x': (dshape, x)} inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward) helper(y, inputs, dtype, forward, backward)
def test_squeeze(): def test_squeeze():
...@@ -304,7 +325,7 @@ def test_pad(): ...@@ -304,7 +325,7 @@ def test_pad():
mode='constant', constant_values=1.) mode='constant', constant_values=1.)
dtype = "float32" dtype = "float32"
inputs = {'x': ((1, 3, 28, 28), x)} inputs = [('x', (1, 3, 28, 28), x)]
helper(y, inputs, dtype, forward) helper(y, inputs, dtype, forward)
......
...@@ -6,6 +6,46 @@ import nnvm.symbol as sym ...@@ -6,6 +6,46 @@ import nnvm.symbol as sym
import nnvm.compiler import nnvm.compiler
from nnvm.testing.config import ctx_list from nnvm.testing.config import ctx_list
def helper(symbol, inputs, dtype,
np_forward, np_backward=None, need_input=True, need_head_grads=True):
"""Build symbol for each target and compare its outputs with np_forward; when
np_backward is given, apply the "Gradient" pass and compare gradients as well."""
ishapes = {}
input_syms = []
np_inputs = {}
for (name, shape, s) in inputs:
ishapes.update({name: shape})
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
input_syms.append(s)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
m = graph_runtime.create(graph, lib, ctx)
m.run(**np_inputs)
y_np = np_forward(**np_inputs)
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
# backward
if np_backward:
graph._set_symbol_list_attr("grad_ys", symbol)
graph._set_symbol_list_attr("grad_xs", input_syms)
graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape))
graph = graph.apply("Gradient")
ishapes.update({"head_grads": y_np.shape})
graph, lib, _ = nnvm.compiler.build(graph, target, ishapes)
m = graph_runtime.create(graph, lib, ctx)
head_grads = np.random.uniform(size=y_np.shape).astype(dtype)
y_np = np_backward(head_grads=head_grads, **np_inputs)
b_inputs = {}
if need_input:
b_inputs.update(np_inputs)
if need_head_grads:
b_inputs.update({"head_grads":head_grads})
m.run(**b_inputs)
for i in range(len(y_np)):
out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5)
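A minimal sketch of how this helper is meant to be called, assuming the imports and the helper defined in this file (the relu example itself is illustrative, not part of this change):

def example_relu_check():
    x = sym.Variable("x")
    y = sym.relu(x)
    def forward(x):
        return np.maximum(x, 0)
    def backward(head_grads, x):
        # relu passes the incoming gradient through only where x > 0
        return [(x > 0).astype("float32") * head_grads]
    inputs = [('x', (3, 4, 5), x)]
    helper(y, inputs, "float32", forward, backward)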
def verify_transpose(dshape, axes): def verify_transpose(dshape, axes):
x = sym.Variable("x") x = sym.Variable("x")
if axes: if axes:
...@@ -66,13 +106,245 @@ def verify_reshape(dshape, oshape): ...@@ -66,13 +106,245 @@ def verify_reshape(dshape, oshape):
out = m.get_output(0, tvm.nd.empty(out_np.shape)) out = m.get_output(0, tvm.nd.empty(out_np.shape))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_reshape(): def test_reshape():
verify_reshape((2, 3, 4), (-1, 2, 1)) verify_reshape((2, 3, 4), (-1, 2, 1))
verify_reshape((2, 3, 4), (8, 3)) verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2)) verify_reshape((4, 7), (2, 7, 2))
def test_clip():
x = sym.Variable("x")
a_min=0.2
a_max=0.75
y = sym.clip(x, a_min=a_min, a_max=a_max)
def forward(x):
return np.clip(x, a_min=a_min, a_max=a_max)
def backward(head_grads, x):
mask1 = np.greater_equal(x, a_min).astype("float")
mask2 = np.less_equal(x, a_max).astype("float")
return [head_grads * mask1 * mask2]
dtype = "float32"
inputs = [('x', (3, 4, 5), x)]
helper(y, inputs, dtype, forward, backward)
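The clip backward above masks head_grads to the interior of [a_min, a_max]; an equivalent single-mask formulation, as a numpy-only sketch:

import numpy as np

def clip_grad(head_grads, x, a_min=0.2, a_max=0.75):
    # the gradient is 1 inside the clipping range and 0 outside it
    inside = np.logical_and(x >= a_min, x <= a_max)
    return head_grads * inside.astype(head_grads.dtype)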
def test_greater():
l = sym.Variable("l")
r = sym.Variable("r")
y = sym.greater(l, r)
def forward(l, r):
return np.greater(l, r).astype("float32")
def backward(head_grads, l, r):
return [np.zeros_like(l)]
dtype = "float32"
inputs = [('l', (3, 4, 5), l),
('r', (3, 4, 5), r)]
helper(y, inputs, dtype, forward, backward, need_head_grads=False)
def test_less():
l = sym.Variable("l")
r = sym.Variable("r")
y = sym.less(l, r)
def forward(l, r):
return np.less(l, r).astype("float32")
def backward(head_grads, l, r):
return [np.zeros_like(l)]
dtype = "float32"
inputs = [('l', (3, 4, 5), l),
('r', (3, 4, 5), r)]
helper(y, inputs, dtype, forward, backward, need_head_grads=False)
def test_reshape_like():
x = sym.Variable("x")
y = sym.Variable("y")
z = sym.reshape_like(x, y)
def forward(x, y):
return np.reshape(x, y.shape)
def backward(head_grads, x, y):
return [np.reshape(head_grads, x.shape),
np.zeros_like(y)]
dtype = "float32"
inputs = [('x', (3, 4, 5), x),
('y', (5, 4, 3), y)]
helper(z, inputs, dtype, forward, backward)
def verify_expand_like(in_shape, out_shape, axis, exclude):
x = sym.Variable("x")
y = sym.Variable("y")
z = sym.expand_like(x, y, axis=axis, exclude=exclude)
def forward(x, y):
odim = len(out_shape)
real_axis = [i if i >= 0 else i + odim for i in axis]
real_axis = sorted(real_axis)
if exclude:
real_axis = list(set(range(odim)) - set(real_axis))
for i in real_axis:
x = np.expand_dims(x, i).astype(x.dtype)
for i in real_axis:
x = np.concatenate([x]*out_shape[i], axis=i).astype(x.dtype)
return x
def backward(head_grads, x, y):
odim = len(out_shape)
real_axis = [i if i >= 0 else i + odim for i in axis]
real_axis = sorted(real_axis)
if exclude:
real_axis = list(set(range(odim)) - set(real_axis))
return [np.sum(head_grads, axis=tuple(real_axis)),
np.zeros_like(y)]
dtype = "float32"
inputs = [('x', in_shape, x),
('y', out_shape, y)]
helper(z, inputs, dtype, forward, backward, need_input=False)
def test_expand_like():
verify_expand_like((3,), (3, 2), [1], False)
verify_expand_like((2,), (2, 3), [1], False)
verify_expand_like((3, 4), (3, 5, 4), [1], False)
verify_expand_like((5, 7), (5, 6, 7, 8), [0, 2], True)
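The numpy emulation in verify_expand_like tiles x with expand_dims plus concatenate; the same behaviour can be written with broadcasting, which makes it explicit that the backward (a sum over the expanded axes) is the adjoint of the forward. Sketch only, taking the already-resolved real_axis; np.broadcast_to is an assumption here, not what the test uses:

import numpy as np

def expand_like_np(x, out_shape, real_axis):
    # insert the new axes, then broadcast up to the target shape
    for i in sorted(real_axis):
        x = np.expand_dims(x, i)
    return np.broadcast_to(x, out_shape)

def expand_like_grad_np(head_grads, real_axis):
    # the adjoint of broadcasting is summation over the broadcast axes
    return np.sum(head_grads, axis=tuple(sorted(real_axis)))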
def verify_elemwise_sum(num_args):
s = [sym.Variable("input" + str(i)) for i in range(num_args)]
y = sym.elemwise_sum(*s, num_args=num_args)
def forward(**inputs):
return np.sum(np.array(list(inputs.values())), axis=0)
def backward(head_grads, **inputs):
return [head_grads] * num_args
dtype = "float32"
inputs = [("input" + str(i), (3, 4, 5), s[i])
for i in range(num_args)]
helper(y, inputs, dtype, forward, backward, need_input=False)
def test_elemwise_sum():
verify_elemwise_sum(1)
verify_elemwise_sum(5)
verify_elemwise_sum(7)
def test_block_grad():
x = sym.Variable("x")
y = sym.block_grad(x)
def forward(x):
return x
def backward(head_grads, x):
return [np.zeros_like(head_grads)]
dtype = "float32"
inputs = [('x', (3, 4, 5), x)]
helper(y, inputs, dtype, forward, backward, need_head_grads=False)
def test_full():
shape = (3, 4, 5)
value = 7
dtype = "float32"
for target, ctx in ctx_list():
data = sym.Variable("data", dtype=dtype)
# full_like
s = sym.full_like(data=data, fill_value=value, name="s")
graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
m = graph_runtime.create(graph, lib, ctx)
m.run(data=np.random.uniform(size=shape).astype(dtype))
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=value, dtype=dtype),
atol=1e-5, rtol=1e-5)
# ones_like
s = sym.ones_like(data=data, name="s")
graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
m = graph_runtime.create(graph, lib, ctx)
m.run(data=np.random.uniform(size=shape).astype(dtype))
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=1, dtype=dtype),
atol=1e-5, rtol=1e-5)
# zeros_like
s = sym.zeros_like(data=data, name="s")
graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
m = graph_runtime.create(graph, lib, ctx)
m.run(data=np.random.uniform(size=shape).astype(dtype))
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=0, dtype=dtype),
atol=1e-5, rtol=1e-5)
# full
s = sym.full(shape=shape, dtype=dtype, fill_value=value, name="s")
graph, lib, _ = nnvm.compiler.build(s, target)
m = graph_runtime.create(graph, lib, ctx)
m.run()
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=value, dtype=dtype),
atol=1e-5, rtol=1e-5)
# ones
s = sym.ones(shape=shape, dtype=dtype, name="s")
graph, lib, _ = nnvm.compiler.build(s, target)
m = graph_runtime.create(graph, lib, ctx)
m.run()
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=1, dtype=dtype),
atol=1e-5, rtol=1e-5)
# zeros
s = sym.zeros(shape=shape, dtype=dtype, name="s")
graph, lib, _ = nnvm.compiler.build(s, target)
m = graph_runtime.create(graph, lib, ctx)
m.run()
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=0, dtype=dtype),
atol=1e-5, rtol=1e-5)
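The six cases in test_full all follow the same build, run, and compare pattern; a table-driven variant could factor that out, along the lines of this sketch (check_const_op is hypothetical, and it assumes the imports and symbol constructors used above):

def check_const_op(target, ctx, s, expected, shape, dtype="float32", data_shape=None):
    # pass an input shape dict only when the op actually consumes "data"
    if data_shape is not None:
        graph, lib, _ = nnvm.compiler.build(s, target, {"data": data_shape})
    else:
        graph, lib, _ = nnvm.compiler.build(s, target)
    m = graph_runtime.create(graph, lib, ctx)
    if data_shape is not None:
        m.run(data=np.random.uniform(size=data_shape).astype(dtype))
    else:
        m.run()
    out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
    np.testing.assert_allclose(out.asnumpy(), expected, atol=1e-5, rtol=1e-5)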
if __name__ == "__main__": if __name__ == "__main__":
test_reshape() test_reshape()
test_reduce() test_reduce()
test_tranpose() test_tranpose()
test_clip()
test_greater()
test_less()
test_reshape_like()
test_expand_like()
test_elemwise_sum()
test_block_grad()
test_full()
print(nnvm.compiler.engine.dump()) print(nnvm.compiler.engine.dump())