Commit 0aa8ce01 by Tianqi Chen

[TOP] level4 except argmax/min, correct split (#9)

* [TOP] level4 except argmax/min, correct split

* [DOCS] Add doc generator for top
parent 5f677945
# Makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
PAPER =
BUILDDIR = _build
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
endif
# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# the i18n builder cannot share the environment and doctrees with the others
I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext
help:
@echo "Please use \`make <target>' where <target> is one of"
@echo " html to make standalone HTML files"
@echo " dirhtml to make HTML files named index.html in directories"
@echo " singlehtml to make a single large HTML file"
@echo " pickle to make pickle files"
@echo " json to make JSON files"
@echo " htmlhelp to make HTML files and a HTML help project"
@echo " qthelp to make HTML files and a qthelp project"
@echo " applehelp to make an Apple Help Book"
@echo " devhelp to make HTML files and a Devhelp project"
@echo " epub to make an epub"
@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
@echo " latexpdf to make LaTeX files and run them through pdflatex"
@echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
@echo " text to make text files"
@echo " man to make manual pages"
@echo " texinfo to make Texinfo files"
@echo " info to make Texinfo files and run them through makeinfo"
@echo " gettext to make PO message catalogs"
@echo " changes to make an overview of all changed/added/deprecated items"
@echo " xml to make Docutils-native XML files"
@echo " pseudoxml to make pseudoxml-XML files for display purposes"
@echo " linkcheck to check all external links for integrity"
@echo " doctest to run all doctests embedded in the documentation (if enabled)"
@echo " coverage to run coverage check of the documentation (if enabled)"
clean:
rm -rf $(BUILDDIR)/*
rm -rf gen_modules
html:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
singlehtml:
$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
pickle:
$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
@echo
@echo "Build finished; now you can process the pickle files."
json:
$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
@echo
@echo "Build finished; now you can process the JSON files."
htmlhelp:
$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
@echo
@echo "Build finished; now you can run HTML Help Workshop with the" \
".hhp project file in $(BUILDDIR)/htmlhelp."
qthelp:
$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
@echo
@echo "Build finished; now you can run "qcollectiongenerator" with the" \
".qhcp project file in $(BUILDDIR)/qthelp, like this:"
@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/rabit.qhcp"
@echo "To view the help file:"
@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/rabit.qhc"
applehelp:
$(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
@echo
@echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
@echo "N.B. You won't be able to view it unless you put it in" \
"~/Library/Documentation/Help or install it in your application" \
"bundle."
devhelp:
$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
@echo
@echo "Build finished."
@echo "To view the help file:"
@echo "# mkdir -p $$HOME/.local/share/devhelp/rabit"
@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/rabit"
@echo "# devhelp"
epub:
$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
@echo
@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
latex:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo
@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
@echo "Run \`make' in that directory to run these through (pdf)latex" \
"(use \`make latexpdf' here to do that automatically)."
latexpdf:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through pdflatex..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
latexpdfja:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through platex and dvipdfmx..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
text:
$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
@echo
@echo "Build finished. The text files are in $(BUILDDIR)/text."
man:
$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
@echo
@echo "Build finished. The manual pages are in $(BUILDDIR)/man."
texinfo:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo
@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
@echo "Run \`make' in that directory to run these through makeinfo" \
"(use \`make info' here to do that automatically)."
info:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo "Running Texinfo files through makeinfo..."
make -C $(BUILDDIR)/texinfo info
@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
gettext:
$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
@echo
@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
changes:
$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
@echo
@echo "The overview file is in $(BUILDDIR)/changes."
linkcheck:
$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
@echo
@echo "Link check complete; look for any errors in the above output " \
"or in $(BUILDDIR)/linkcheck/output.txt."
doctest:
$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
coverage:
$(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
@echo "Testing of coverage in the sources finished, look at the " \
"results in $(BUILDDIR)/coverage/python.txt."
xml:
$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
@echo
@echo "Build finished. The XML files are in $(BUILDDIR)/xml."
pseudoxml:
$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
@echo
@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
# -*- coding: utf-8 -*-
#
# documentation build configuration file, created by
# sphinx-quickstart on Thu Jul 23 19:40:08 2015.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
import sys
import os, subprocess
import shlex
import recommonmark
from recommonmark.parser import CommonMarkParser
from recommonmark.transform import AutoStructify
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../python/'))
# -- General configuration ------------------------------------------------
# General information about the project.
project = u'nnvm'
author = u'%s developers' % project
copyright = u'2017, %s' % author
github_doc_root = 'https://github.com/dmlc/nnvm/tree/master/docs/'
# add markdown parser
CommonMarkParser.github_doc_root = github_doc_root
source_parsers = {
'.md': CommonMarkParser
}
os.environ['NNVM_BUILD_DOC'] = '1'
# Version information.
import nnvm
version = nnvm.__version__
release = nnvm.__version__
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'sphinx.ext.mathjax'
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
# source_suffix = ['.rst', '.md']
source_suffix = ['.rst', '.md']
# The encoding of source files.
#source_encoding = 'utf-8-sig'
# generate autosummary even if no references
autosummary_generate = True
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
# The reST default role (used for this markup: `text`) to use for all
# documents.
#default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents.
#keep_warnings = False
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
# The theme is set by the make target
html_theme = os.environ.get('NNVM_THEME', 'rtd')
on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
# only import the RTD theme and set it when building docs locally
if not on_rtd and html_theme == 'rtd':
    import sphinx_rtd_theme
    html_theme = 'sphinx_rtd_theme'
    html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, '%s.tex' % project, project,
author, 'manual'),
]
# hook for doxygen
def run_doxygen(folder):
    """Copy prebuilt doxygen HTML into the Sphinx build tree.

    `folder` is kept for API compatibility; the doxygen build itself
    ("cd %s; make doc") is currently disabled.
    """
    try:
        for cmd in ["rm -rf _build/html/doxygen",
                    "mkdir -p _build/html",
                    "cp -rf doxygen/html _build/html/doxygen"]:
            retcode = subprocess.call(cmd, shell=True)
            if retcode < 0:
                sys.stderr.write("command terminated by signal %s" % (-retcode))
    except OSError as e:
        sys.stderr.write("doxygen execution failed: %s" % e)
intersphinx_mapping = {
'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
'matplotlib': ('http://matplotlib.org/', None),
}
def generate_doxygen_xml(app):
    """Copy the doxygen output into place when the Sphinx builder is initialized."""
    run_doxygen('..')
def setup(app):
    # Add hook for building doxygen xml when needed
    # no c++ API for now
    app.connect("builder-inited", generate_doxygen_xml)
    app.add_config_value('recommonmark_config', {
        'url_resolver': lambda url: github_doc_root + url,
        'auto_doc_ref': True
    }, True)
    app.add_transform(AutoStructify)
NNVM Documentation
==================
Welcome to NNVM documentation.
Contents
--------
.. toctree::
:maxdepth: 1
top
NNVM Core Operator Specification
================================
Each operator's attributes are stored in JSON format; tuples are stored as JSON arrays.
A short usage sketch follows the Tier 1 operator list below.
## Tier 1: Basic Operators
***Enables fully connected nets***
- **dense**
- attributes
- units: int. Number of hidden units in the output.
- use_bias: bool. Whether to add a bias term.
- inputs
- data, 2D Tensor
- weight, 2D Tensor
- bias, optional, 1D Tensor
- outputs
- output, 2D Tensor
- **relu**
- inputs
- data, nD Tensor
- outputs
- output, nD Tensor
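Below is a minimal usage sketch of the Tier 1 operators above, written against the nnvm Python symbol frontend; the layer name "fc", the unit count, and the shapes in the comments are illustrative assumptions, not part of the spec.

import nnvm.symbol as sym

# Build a one-layer fully connected net from the Tier 1 operators.
x = sym.Variable("x")                  # 2D input, e.g. (batch, in_dim)
h = sym.dense(x, units=64, name="fc")  # adds fc_weight and fc_bias inputs
y = sym.relu(h)                        # element-wise, shape preserved
assert h.list_input_names() == ["x", "fc_weight", "fc_bias"]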
NNVM Core Primitives
====================
**Level 1: Basic Ops**
.. autosummary::
:nosignatures:
nnvm.symbol.dense
nnvm.symbol.relu
nnvm.symbol.tanh
nnvm.symbol.sigmoid
nnvm.symbol.exp
nnvm.symbol.log
nnvm.symbol.elemwise_add
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.split
nnvm.symbol.dropout
nnvm.symbol.batch_norm
nnvm.symbol.softmax
nnvm.symbol.log_softmax
**Level 2: Convolutions**
.. autosummary::
:nosignatures:
nnvm.symbol.conv2d
nnvm.symbol.conv2d_transpose
nnvm.symbol.max_pool2d
nnvm.symbol.avg_pool2d
nnvm.symbol.global_max_pool2d
nnvm.symbol.global_avg_pool2d
**Level 3: Additional Tensor Ops**
.. autosummary::
:nosignatures:
nnvm.symbol.reshape
nnvm.symbol.copy
nnvm.symbol.negative
nnvm.symbol.leaky_relu
nnvm.symbol.__add_scalar__
nnvm.symbol.__sub_scalar__
nnvm.symbol.__rsub_scalar__
nnvm.symbol.__mul_scalar__
nnvm.symbol.__div_scalar__
nnvm.symbol.__rdiv_scalar__
nnvm.symbol.__pow_scalar__
nnvm.symbol.__rpow_scalar__
**Level 4: Broadcast and Reductions**
.. autosummary::
:nosignatures:
nnvm.symbol.transpose
nnvm.symbol.broadcast_to
nnvm.symbol.sum
nnvm.symbol.min
nnvm.symbol.max
nnvm.symbol.broadcast_add
nnvm.symbol.broadcast_sub
nnvm.symbol.broadcast_mul
nnvm.symbol.broadcast_div
.. autofunction:: nnvm.symbol.dense
.. autofunction:: nnvm.symbol.relu
.. autofunction:: nnvm.symbol.tanh
.. autofunction:: nnvm.symbol.sigmoid
.. autofunction:: nnvm.symbol.exp
.. autofunction:: nnvm.symbol.log
.. autofunction:: nnvm.symbol.elemwise_add
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.split
.. autofunction:: nnvm.symbol.dropout
.. autofunction:: nnvm.symbol.batch_norm
.. autofunction:: nnvm.symbol.softmax
.. autofunction:: nnvm.symbol.log_softmax
.. autofunction:: nnvm.symbol.conv2d
.. autofunction:: nnvm.symbol.conv2d_transpose
.. autofunction:: nnvm.symbol.max_pool2d
.. autofunction:: nnvm.symbol.avg_pool2d
.. autofunction:: nnvm.symbol.global_max_pool2d
.. autofunction:: nnvm.symbol.global_avg_pool2d
.. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__
.. autofunction:: nnvm.symbol.__rsub_scalar__
.. autofunction:: nnvm.symbol.__mul_scalar__
.. autofunction:: nnvm.symbol.__div_scalar__
.. autofunction:: nnvm.symbol.__rdiv_scalar__
.. autofunction:: nnvm.symbol.__pow_scalar__
.. autofunction:: nnvm.symbol.__rpow_scalar__
.. autofunction:: nnvm.symbol.transpose
.. autofunction:: nnvm.symbol.broadcast_to
.. autofunction:: nnvm.symbol.sum
.. autofunction:: nnvm.symbol.min
.. autofunction:: nnvm.symbol.max
.. autofunction:: nnvm.symbol.broadcast_add
.. autofunction:: nnvm.symbol.broadcast_sub
.. autofunction:: nnvm.symbol.broadcast_mul
.. autofunction:: nnvm.symbol.broadcast_div
@@ -25,6 +25,9 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
// numpy convention: an int gives equal sections, a tuple gives cut indices
Tuple<int> indices_or_sections;
int axis;
// additional hint whether it is equal_split mode
// deduced from indices_or_sections
bool equal_split;
DMLC_DECLARE_PARAMETER(SplitParam) {
DMLC_DECLARE_FIELD(indices_or_sections)
@@ -73,6 +76,54 @@ struct ScalarParam : public dmlc::Parameter<ScalarParam> {
}
};
struct TransposeParam : public dmlc::Parameter<TransposeParam> {
TShape axes;
DMLC_DECLARE_PARAMETER(TransposeParam) {
DMLC_DECLARE_FIELD(axes).set_default(TShape())
.describe("Target axis order. By default the axes will be inverted.");
}
};
struct BroadcastToParam : public dmlc::Parameter<BroadcastToParam> {
TShape shape;
DMLC_DECLARE_PARAMETER(BroadcastToParam) {
DMLC_DECLARE_FIELD(shape).set_default(TShape())
.describe("The shape of the desired array."
" We can set the dim to zero if it's same as the original."
" E.g `A = broadcast_to(B, shape=(10, 0, 0))` ");
}
};
struct ReduceParam : public dmlc::Parameter<ReduceParam> {
TShape axis;
bool keepdims;
bool exclude;
DMLC_DECLARE_PARAMETER(ReduceParam) {
DMLC_DECLARE_FIELD(axis).set_default(TShape())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(exclude).set_default(false)
.describe("Whether to perform reduction on axis that are NOT in axis instead.");
}
};
} // namespace top
} // namespace nnvm
@@ -219,5 +219,6 @@ def _init_symbol_module(symbol_class, root_namespace):
    function = _make_atomic_symbol_function(hdl, name)
    if function.__name__.startswith('_'):
        setattr(module_internal, function.__name__, function)
        setattr(module_obj, function.__name__, function)
    else:
        setattr(module_obj, function.__name__, function)
@@ -213,5 +213,6 @@ def _init_symbol_module(symbol_class, root_namespace):
    function = _make_atomic_symbol_function(handle, op_names[i])
    if function.__name__.startswith('_'):
        setattr(module_internal, function.__name__, function)
        setattr(module_obj, function.__name__, function)
    else:
        setattr(module_obj, function.__name__, function)
@@ -31,9 +31,9 @@ class Symbol(SymbolBase):
    def __add__(self, other):
        if isinstance(other, Symbol):
            return _internal.__add_symbol__(self, other)
            return __add_symbol__(self, other)
        elif isinstance(other, _Number):
            return _internal.__add_scalar__(self, scalar=other)
            return __add_scalar__(self, scalar=other)
        else:
            raise TypeError("type %s not supported" % str(type(other)))
@@ -42,23 +42,23 @@ class Symbol(SymbolBase):
    def __sub__(self, other):
        if isinstance(other, Symbol):
            return _internal.__sub_symbol__(self, other)
            return __sub_symbol__(self, other)
        if isinstance(other, _Number):
            return _internal.__sub_scalar__(self, scalar=other)
            return __sub_scalar__(self, scalar=other)
        else:
            raise TypeError('type %s not supported' % str(type(other)))

    def __rsub__(self, other):
        if isinstance(other, _Number):
            return _internal.__rsub_scalar__(self, scalar=other)
            return __rsub_scalar__(self, scalar=other)
        else:
            raise TypeError('type %s not supported' % str(type(other)))

    def __mul__(self, other):
        if isinstance(other, Symbol):
            return _internal.__mul_symbol__(self, other)
            return __mul_symbol__(self, other)
        if isinstance(other, _Number):
            return _internal.__mul_scalar__(self, scalar=other)
            return __mul_scalar__(self, scalar=other)
        else:
            raise TypeError('type %s not supported' % str(type(other)))
@@ -67,15 +67,15 @@ class Symbol(SymbolBase):
    def __div__(self, other):
        if isinstance(other, Symbol):
            return _internal.__div_symbol__(self, other)
            return __div_symbol__(self, other)
        if isinstance(other, _Number):
            return _internal.__div_scalar__(self, scalar=other)
            return __div_scalar__(self, scalar=other)
        else:
            raise TypeError('type %s not supported' % str(type(other)))

    def __rdiv__(self, other):
        if isinstance(other, _Number):
            return _internal.__rdiv_scalar__(self, scalar=other)
            return __rdiv_scalar__(self, scalar=other)
        else:
            raise TypeError('type %s not supported' % str(type(other)))
@@ -87,15 +87,15 @@ class Symbol(SymbolBase):
    def __pow__(self, other):
        if isinstance(other, Symbol):
            return _internal.__pow_symbol__(self, other)
            return __pow_symbol__(self, other)
        if isinstance(other, _Number):
            return _internal.__pow_scalar__(self, scalar=other)
            return __pow_scalar__(self, scalar=other)
        else:
            raise TypeError('type %s not supported' % str(type(other)))

    def __rpow__(self, other):
        if isinstance(other, _Number):
            return _internal.__rpow_scalar__(self, scalar=other)
            return __rpow_scalar__(self, scalar=other)
        else:
            raise TypeError('type %s not supported' % str(type(other)))
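A small usage sketch of the dispatch above: combining a Symbol with another Symbol routes to the __*_symbol__ ops, while combining with a Python number routes to the __*_scalar__ ops, with the r-variants handling reflected operands.

import nnvm.symbol as sym

x = sym.Variable("x")
y = sym.Variable("y")
z = x + y      # Symbol + Symbol -> __add_symbol__
w = x - 1.0    # Symbol - number -> __sub_scalar__(scalar=1.0)
v = 2.0 ** x   # number ** Symbol -> __rpow_scalar__(scalar=2.0)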
@@ -107,8 +107,8 @@ inline bool ElemwiseType(const NodeAttrs& attrs,
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "first input") \
.add_argument("rhs", "NDArray-or-Symbol", "second input")
.add_argument("lhs", "Tensor", "first input") \
.add_argument("rhs", "Tensor", "second input")
} // namespace top
} // namespace nnvm
......
@@ -98,6 +98,7 @@ NNVM_REGISTER_OP(dropout)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr_parser(ParamParser<DropoutParam>)
.add_arguments(DropoutParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 2>)
.set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
@@ -171,6 +172,7 @@ axis to be the last item in the input shape.
.set_num_inputs(5)
.set_num_outputs(3)
.set_attr_parser(ParamParser<BatchNormParam>)
.add_arguments(BatchNormParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", BatchNormInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<5, 3>)
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
@@ -199,6 +201,7 @@ NNVM_REGISTER_OP(softmax)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SoftmaxParam>)
.add_arguments(SoftmaxParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1);
@@ -213,6 +216,7 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SoftmaxParam>)
.add_arguments(SoftmaxParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1);
@@ -229,6 +233,7 @@ NNVM_REGISTER_OP(leaky_relu)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LeakyReLUParam>)
.add_arguments(LeakyReLUParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1);
@@ -61,8 +61,10 @@ NNVM_REGISTER_OP(max_pool2d)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as::
out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1
out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
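For example, with height = 32, padding = (1, 1), pool_size = (3, 3) and strides = (2, 2) (illustrative values), out_height = floor((32 + 2*1 - 3)/2) + 1 = floor(31/2) + 1 = 16.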
@@ -85,8 +87,10 @@ NNVM_REGISTER_OP(avg_pool2d)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as::
out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1
out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
/*!
* Copyright (c) 2017 by Contributors
* \file broadcast.cc
* \brief broadcast operator.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"
namespace nnvm {
namespace top {
// broadcast_to
DMLC_REGISTER_PARAMETER(BroadcastToParam);
inline bool BroadcastToInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& ishape = (*in_attrs)[0];
if (ishape.ndim() == 0) return false;
const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
CHECK_EQ(ishape.ndim(), param.shape.ndim())
<< "Operand of shape " << ishape
<< " cannot be broadcasted to " << param.shape;
TShape oshape = param.shape;
for (dim_t i = 0; i < ishape.ndim(); ++i) {
if (oshape[i] != 0) {
CHECK(ishape[i] == oshape[i] || ishape[i] == 1)
<< "Array cannot be broadcasted from " <<
ishape << " to " << param.shape;
} else {
oshape[i] = ishape[i];
}
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
return true;
}
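A minimal Python sketch (hypothetical helper, not part of the codebase) of the rule BroadcastToInferShape implements:

# A 0 in the target shape means "keep the input's size on that axis";
# otherwise the input dim must equal the target dim or be 1.
def broadcast_to_shape(ishape, tshape):
    assert len(ishape) == len(tshape), "ndim must match"
    out = []
    for i, t in zip(ishape, tshape):
        if t == 0:
            out.append(i)            # keep original size
        else:
            assert i == t or i == 1  # broadcastable
            out.append(t)
    return tuple(out)

assert broadcast_to_shape((4, 1), (0, 4)) == (4, 4)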
NNVM_REGISTER_OP(broadcast_to)
.describe(R"code(Broadcasts the input array to a new shape.
Broadcasting is a mechanism that allows NDArrays to perform arithmetic operations
with arrays of different shapes efficiently without creating multiple copies of arrays.
See also `Broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ for more explanation.
Broadcasting is allowed on axes with size 1, such as from `(2,1,3,1)` to
`(2,8,3,9)`. Elements will be duplicated on the broadcasted axes.
For example::
broadcast_to([[1,2,3]], shape=(2,3)) = [[ 1., 2., 3.],
                                        [ 1., 2., 3.]]
A dimension you do not want to change can be given as `0`, meaning the original size is copied.
So with `shape=(2,0)` we obtain the same result as in the example above.
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<BroadcastToParam>)
.add_arguments(BroadcastToParam::__FIELDS__())
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(4);
// binary broadcast op
inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& lhs = (*in_attrs)[0];
const TShape& rhs = (*in_attrs)[1];
// avoid pre-mature shape inference.
if (lhs.ndim() == 0 || rhs.ndim() == 0) return false;
if (lhs == rhs) {
NNVM_ASSIGN_INPUT_SHAPE(attrs, *out_attrs, 0, lhs);
return true;
}
TShape out(std::max(lhs.ndim(), rhs.ndim()));
dim_t bl = out.ndim() - lhs.ndim();
dim_t br = out.ndim() - rhs.ndim();
for (dim_t i = 0; i < out.ndim(); ++i) {
dim_t l = 1, r = 1;
if (i >= bl) l = lhs[i - bl];
if (i >= br) r = rhs[i - br];
if (l != r) {
if (l == 0 || r == 0) {
out[i] = 0;
} else {
CHECK(l == 1 || r == 1)
<< "operands could not be broadcast together with shapes "
<< lhs << " " << rhs;
out[i] = std::max(l, r);
}
} else {
out[i] = l;
}
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out);
return true;
}
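BinaryBroadcastShape follows the standard numpy broadcast rule; as a sanity sketch, numpy computes the same output shapes that appear in the tests later in this commit:

import numpy as np

# Shapes are right-aligned; each dim pair must match or contain a 1.
assert np.broadcast(np.empty((5, 1, 1)), np.empty((1, 4, 4))).shape == (5, 4, 4)
assert np.broadcast(np.empty((6, 1, 4)), np.empty((5, 4))).shape == (6, 5, 4)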
#define NNVM_REGISTER_BINARY_BROADCAST_OP(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
.set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.add_argument("lhs", "Tensor", "first input") \
.add_argument("rhs", "Tensor", "second input")
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_add)
.add_alias("__add_symbol__")
.describe(R"code(Returns element-wise sum of the input arrays with broadcasting.
Example::
x = [[ 1., 1., 1.],
[ 1., 1., 1.]]
y = [[ 0.],
[ 1.]]
broadcast_add(x, y) = [[ 1., 1., 1.],
[ 2., 2., 2.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_sub)
.add_alias("__sub_symbol__")
.describe(R"code(Returns element-wise difference of the input arrays with broadcasting.
Example::
x = [[ 1., 1., 1.],
[ 1., 1., 1.]]
y = [[ 0.],
[ 1.]]
broadcast_sub(x, y) = [[ 1., 1., 1.],
[ 0., 0., 0.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mul)
.add_alias("__mul_symbol__")
.describe(R"code(Returns element-wise product of the input arrays with broadcasting.
Example::
x = [[ 1., 1., 1.],
[ 1., 1., 1.]]
y = [[ 0.],
[ 1.]]
broadcast_mul(x, y) = [[ 0., 0., 0.],
[ 1., 1., 1.]]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_div)
.add_alias("__div_symbol__")
.describe(R"code(Returns element-wise division of the input arrays with broadcasting.
Example::
x = [[ 6., 6., 6.],
[ 6., 6., 6.]]
y = [[ 2.],
[ 3.]]
broadcast_div(x, y) = [[ 3., 3., 3.],
[ 2., 2., 2.]]
)code" NNVM_ADD_FILELINE);
} // namespace top
} // namespace nnvm
@@ -17,17 +17,17 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid.
.. math::
y = 1 / (1 + exp(-x))
Y = 1 / (1 + exp(-X))
)code" NNVM_ADD_FILELINE)
.set_support_level(1);
// tanh
NNVM_REGISTER_ELEMWISE_UNARY_OP(tanh)
.describe(R"code(Returns the hyperbolic tangent of the input array, computed element-wise.
.describe(R"code(Computes hyperbolic tangent.
.. math::
tanh(x) = sinh(x) / cosh(x)
Y = sinh(X) / cosh(X)
)code" NNVM_ADD_FILELINE)
.set_support_level(1);
@@ -100,6 +100,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__add_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__sub_scalar__)
@@ -107,6 +108,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__sub_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__rsub_scalar__)
@@ -114,6 +116,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__rsub_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__mul_scalar__)
@@ -121,6 +124,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__mul_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__div_scalar__)
@@ -128,6 +132,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__div_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__rdiv_scalar__)
@@ -135,6 +140,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__rdiv_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__pow_scalar__)
@@ -142,6 +148,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__pow_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__rpow_scalar__)
@@ -149,6 +156,7 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(__rpow_scalar__)
)code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3);
/*!
* Copyright (c) 2017 by Contributors
* \file reduce.cc
* \brief reduce operator.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"
namespace nnvm {
namespace top {
// reduce
DMLC_REGISTER_PARAMETER(ReduceParam);
inline TShape ReduceShapeImpl(const TShape& ishape,
const TShape& axis,
bool keepdims,
bool exclude) {
if (axis.ndim() == 0) {
if (keepdims) {
return TShape(ishape.ndim());
} else {
return TShape(1);
}
}
CHECK_LT(axis[axis.ndim() - 1], ishape.ndim())
<< "Reduction axis " << axis[axis.ndim() - 1]
<< " Exceeds input dimensions " << ishape;
if (keepdims) {
TShape oshape(ishape);
if (exclude) {
for (dim_t i = 0, j = 0; i < ishape.ndim(); ++i) {
if (j < axis.ndim() && i == axis[j]) {
++j;
continue;
}
oshape[i] = 1;
}
return oshape;
}
for (dim_t i = 0; i < axis.ndim(); ++i) {
oshape[axis[i]] = 1;
}
return oshape;
}
if (exclude) {
TShape oshape = TShape(axis.ndim());
for (dim_t i = 0; i < axis.ndim(); ++i) {
oshape[i] = ishape[axis[i]];
}
return oshape;
}
TShape oshape = TShape(std::max<dim_t>(1, ishape.ndim() - axis.ndim()));
for (dim_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
if (j < axis.ndim() && i == axis[j]) {
++j;
continue;
}
oshape[k++] = ishape[i];
}
return oshape;
}
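A Python sketch of ReduceShapeImpl's output shape. One assumption: `exclude` has no direct numpy analogue, so it is modeled by reducing over the complement of `axis`.

import numpy as np

def reduce_shape(ishape, axis=(), keepdims=False, exclude=False):
    # axis is a tuple of ints, matching ReduceParam's TShape field.
    x = np.zeros(ishape)
    if len(axis) == 0:
        # axis=() reduces everything: all-ones shape with keepdims, else (1,)
        return (1,) * (x.ndim if keepdims else 1)
    if exclude:
        axis = tuple(i for i in range(x.ndim) if i not in axis)
    return np.sum(x, axis=axis, keepdims=keepdims).shape

assert reduce_shape((4, 5, 10), axis=(0, 2)) == (5,)
assert reduce_shape((4, 5, 10), axis=(0, 2), keepdims=True) == (1, 5, 1)
assert reduce_shape((4, 5), keepdims=True) == (1, 1)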
inline bool ReduceShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if ((*in_attrs)[0].ndim() == 0) return false;
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
NNVM_ASSIGN_INPUT_SHAPE(
attrs, *out_attrs, 0,
ReduceShapeImpl((*in_attrs)[0], param.axis,
param.keepdims, param.exclude));
return true;
}
template<typename PType>
inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
PType param;
param.Init(attrs->dict);
std::sort(&param.axis[0], &param.axis[param.axis.ndim()]);
attrs->parsed = std::move(param);
}
#define NNVM_REGISTER_REDUCE_OP(op) \
NNVM_REGISTER_OP(op) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(AxesParamParser<ReduceParam>) \
.set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.add_argument("data", "Tensor", "The input") \
.add_arguments(ReduceParam::__FIELDS__())
NNVM_REGISTER_REDUCE_OP(sum)
.describe(R"code(Computes the sum of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
sum(data, axis=1)
[[ 4. 8.]
[ 10. 9.]
[ 21. 6.]]
sum(data, axis=[1,2])
[ 12. 19. 27.]
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_REDUCE_OP(max)
.describe(R"code(Computes the max of array elements over given axes.
)code" NNVM_ADD_FILELINE);
NNVM_REGISTER_REDUCE_OP(min)
.describe(R"code(Computes the min of array elements over given axes.
)code" NNVM_ADD_FILELINE);
} // namespace top
} // namespace nnvm
@@ -7,6 +7,7 @@
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include <cctype>
#include "../op_common.h"
#include "../elemwise_op_common.h"
@@ -31,7 +32,7 @@ inline bool FlattenInferShape(const NodeAttrs& attrs,
}
NNVM_REGISTER_OP(flatten)
.describe(R"code(Flattens the input array into a 2-D array by collapsing the higher dimensions.
.describe(R"code(Flattens the input into a 2-D array.
For an input array with shape ``(d1, d2, ..., dk)``, `flatten` operation reshapes
the input array into an output array of shape ``(d1, d2*...*dk)``.
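In numpy terms, a one-line sketch of the same shape rule:

import numpy as np

x = np.zeros((2, 3, 4))
assert x.reshape(x.shape[0], -1).shape == (2, 12)  # (d1, d2*...*dk)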
@@ -134,16 +135,28 @@ Example::
.set_num_outputs(1)
.set_num_inputs(kVarg)
.set_attr_parser(ParamParser<ConcatenateParam>)
.add_arguments(ConcatenateParam::__FIELDS__())
.add_argument("data", "Tensor-or-Tensor[]", "List of arrays to concatenate")
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.add_arguments(ConcatenateParam::__FIELDS__())
.set_support_level(1);
// split
DMLC_REGISTER_PARAMETER(SplitParam);
inline void SplitParamParser(nnvm::NodeAttrs* attrs) {
SplitParam param;
param.Init(attrs->dict);
if (!std::isdigit(attrs->dict.at("indices_or_sections")[0])) {
param.equal_split = false;
} else {
CHECK_EQ(param.indices_or_sections.ndim(), 1);
param.equal_split = true;
}
attrs->parsed = std::move(param);
}
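A sketch (hypothetical helper mirroring the check above) of how the mode is deduced from the raw attribute string: an integer like "5" requests equal sections, while a tuple like "(11,)" lists cut indices and disables equal_split.

def is_equal_split(indices_or_sections_str):
    # Mirrors std::isdigit(attrs->dict.at("indices_or_sections")[0]).
    return indices_or_sections_str[0].isdigit()

assert is_equal_split("5")
assert not is_equal_split("(11,)")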
inline bool SplitInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
@@ -151,7 +164,7 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
const TShape& dshape = (*in_shape)[0];
if (dshape.ndim() == 0) return false;
if (param.indices_or_sections.ndim() == 1) {
if (param.equal_split) {
int num_outputs = param.indices_or_sections[0];
CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs));
CHECK_LT(param.axis, dshape.ndim());
@@ -164,30 +177,30 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape);
}
} else {
dim_t num_outputs = param.indices_or_sections.ndim();
dim_t num_outputs = param.indices_or_sections.ndim() + 1;
CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs));
CHECK_LT(param.axis, dshape.ndim());
TShape oshape = dshape;
CHECK_EQ(oshape[param.axis] % num_outputs, 0)
<< "indices_or_sections need to be able to divide input.shape[axis]";
dim_t total = 0;
for (size_t i = 0; i < out_shape->size(); ++i) {
oshape[param.axis] = param.indices_or_sections[i];
for (size_t i = 1; i < num_outputs; ++i) {
oshape[param.axis] = param.indices_or_sections[i - 1];
total += oshape[param.axis];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i - 1, oshape);
}
CHECK_EQ(total, dshape[param.axis])
CHECK_LT(total, dshape[param.axis])
<< "The sum of sections must match the input.shape[axis]";
oshape[param.axis] = dshape[param.axis] - total;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, num_outputs - 1, oshape);
}
return true;
}
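A numpy sketch of the corrected semantics (numpy convention): a list given as indices_or_sections holds cut points along the axis, so k indices yield k + 1 outputs.

import numpy as np

x = np.zeros((10, 20))
parts = np.split(x, [11], axis=1)  # one cut point -> two outputs
assert [p.shape for p in parts] == [(10, 11), (10, 9)]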
inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) {
const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
if (param.indices_or_sections.ndim() == 1) {
if (param.equal_split) {
return static_cast<uint32_t>(param.indices_or_sections[0]);
} else {
return static_cast<uint32_t>(param.indices_or_sections.ndim());
return static_cast<uint32_t>(param.indices_or_sections.ndim()) + 1;
}
}
@@ -199,12 +212,12 @@ along which to split the array.
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attr_parser(ParamParser<SplitParam>)
.set_attr_parser(SplitParamParser)
.set_num_outputs(SplitNumOutputs)
.add_arguments(SplitParam::__FIELDS__())
.add_argument("data", "Tensor", "List of arrays to concatenate")
.set_attr<FInferShape>("FInferShape", SplitInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.add_arguments(SplitParam::__FIELDS__())
.set_support_level(1);
// cast
@@ -225,9 +238,9 @@ NNVM_REGISTER_OP(cast)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data array")
.set_attr_parser(ParamParser<CastParam>)
.add_arguments(CastParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", CastInferType)
.add_arguments(CastParam::__FIELDS__())
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(1);
@@ -377,10 +390,77 @@ The significance of each is explained below:
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReshapeParam>)
.add_arguments(ReshapeParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ReshapeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(3);
// transpose
DMLC_REGISTER_PARAMETER(TransposeParam);
inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& shp = (*in_attrs)[0];
if (shp.ndim() == 0) return false;
TShape ret(shp.ndim());
if (param.axes.ndim() == 0) {
for (dim_t i = 0; i < shp.ndim(); ++i) {
ret[i] = shp[shp.ndim() - 1 - i];
}
} else {
CHECK_EQ(shp.ndim(), param.axes.ndim());
for (size_t i = 0; i < shp.ndim(); ++i) {
CHECK(param.axes[i] < shp.ndim());
ret[i] = shp[param.axes[i]];
}
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, ret);
return true;
}
NNVM_REGISTER_OP(transpose)
.describe(R"code(Permutes the dimensions of an array.
Examples::
x = [[ 1, 2],
[ 3, 4]]
transpose(x) = [[ 1., 3.],
[ 2., 4.]]
x = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
transpose(x) = [[[ 1., 5.],
[ 3., 7.]],
[[ 2., 6.],
[ 4., 8.]]]
transpose(x, axes=(1,0,2)) = [[[ 1., 2.],
[ 5., 6.]],
[[ 3., 4.],
[ 7., 8.]]]
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<TransposeParam>)
.add_arguments(TransposeParam::__FIELDS__())
.set_attr<nnvm::FInferShape>("FInferShape", TransposeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Source input")
.set_support_level(4);
} // namespace top
} // namespace nnvm
@@ -38,7 +38,7 @@ def test_concatenate():
def test_split():
    x1 = sym.Variable("x", shape=(10, 20))
    z = sym.split(x1, indices_or_sections=[11, 9], name="y")
    z = sym.split(x1, indices_or_sections=[11], name="y")
    sdict = infer_shape(z)
    assert(sdict["y"][0] == [10, 11])
    assert(sdict["y"][1] == [10, 9])
@@ -195,6 +195,59 @@ def test_reshape():
check((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
check((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
# Level 4
def test_transpose():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.transpose(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert(tuple(sdict["y"][0]) == tuple(out_shape))

    check((4, 1), (1, 4))
    check((0, 1, 2, 3), (1, 2, 3, 0), axes=(1, 2, 3, 0))

def test_broadcast_to():
    def check(in_shape, tshape, out_shape):
        x = sym.Variable("x", shape=in_shape)
        y = sym.broadcast_to(x, shape=tshape, name="y")
        sdict = infer_shape(y)
        assert(tuple(sdict["y"][0]) == tuple(out_shape))

    check((4, 1), (0, 4), (4, 4))
    check((4, 1, 5), (0, 4, 5), (4, 4, 5))

def test_broadcast_binary():
    def check(lhs_shape, rhs_shape, out_shape):
        x = sym.Variable("x", shape=lhs_shape)
        y = sym.Variable("y", shape=rhs_shape)
        z = sym.broadcast_add(x, y, name="y")
        sdict = infer_shape(z)
        assert(tuple(sdict["y"][0]) == tuple(out_shape))

    check((4, 1), (4,), (4, 4))
    check((5, 1, 1), (1, 4, 4), (5, 4, 4))
    check((6, 1, 4), (5, 4), (6, 5, 4))

def test_reduce():
    def check(in_shape, out_shape, **kwargs):
        x = sym.Variable("x", shape=in_shape)
        y = sym.sum(x, name="y", **kwargs)
        sdict = infer_shape(y)
        assert(tuple(sdict["y"][0]) == tuple(out_shape))

    check((4, 5), (4,), axis=1)
    check((4, 5), (4, 1), axis=1, keepdims=True)
    check((4, 5), (1, 5), axis=0, keepdims=True)
    check((4, 5), (1, 1), axis=(), keepdims=True)
    check((4, 5), (1,), axis=())
    check((4, 5, 10), (5,), axis=(0, 2))
    check((4, 5, 10), (1, 5, 1), axis=(0, 2), keepdims=True)
if __name__ == "__main__":
    test_dense()
    test_concatenate()
@@ -206,3 +259,7 @@ if __name__ == "__main__":
    test_max_pool2d()
    test_global_pool2d()
    test_reshape()
    test_broadcast_to()
    test_broadcast_binary()
    test_reduce()
    test_transpose()
import nnvm.symbol as sym
def test_fullc():
def test_dense():
    x = sym.Variable('x')
    x1 = sym.dense(x, units=3, name="dense")
    x2 = sym.flatten(x1)
    x3 = sym.softmax(x2)
    assert x2.list_input_names() == ['x', 'dense_weight', 'dense_bias']

def test_concatenate_split():
    x = sym.Variable('x')
    y = sym.Variable('y')
@@ -15,7 +16,7 @@ def test_concatenate_split():
    z = sym.split(y, indices_or_sections=10)
    assert len(z.list_output_names()) == 10
    z = sym.split(y, indices_or_sections=[10, 20])
    assert len(z.list_output_names()) == 2
    assert len(z.list_output_names()) == 3
def test_unary():
@@ -26,6 +27,7 @@ def test_unary():
    x = sym.tanh(x)
    assert x.list_input_names() == ['x']

def test_batchnorm():
    x = sym.Variable('x')
    x = sym.batch_norm(x, name="bn")
@@ -35,6 +37,6 @@ def test_batchnorm():
if __name__ == "__main__":
    test_concatenate_split()
    test_fullc()
    test_dense()
    test_unary()
    test_batchnorm()
import nnvm.symbol as sym
def test_conv2d():
    x = sym.Variable('x')
    y = sym.conv2d(x, channels=3, kernel_size=(3, 3),
                   name="y", use_bias=False)
    assert y.list_input_names() == ["x", "y_weight"]

def test_max_pool2d():
    x = sym.Variable('x')
    y = sym.max_pool2d(x, pool_size=(3, 3), name="y")
    y = sym.global_max_pool2d(y)
    assert y.list_input_names() == ["x"]

if __name__ == "__main__":
    test_conv2d()
    test_max_pool2d()
import nnvm.symbol as sym
def test_binary_broadcast():
    x = sym.Variable('x')
    y = sym.Variable('y')
    z = x + y
    z = x * y
    z = x - y
    z = x / y

def test_broadcast_to():
    x = sym.Variable('x')
    y = sym.broadcast_to(x, shape=(3, 3))
    assert y.list_input_names() == ["x"]

if __name__ == "__main__":
    test_binary_broadcast()
    test_broadcast_to()